import React, { useContext, useEffect, useRef, useState } from 'react';
import * as d3 from 'd3';
import { GraphLabel } from './GraphLabel';
import { TrainContext } from '../../../../../../context/TrainContext';

/** Palette for the plotted series; each series takes the color at its position in the list. */
export const colorArr: string[] = [
  '#45d188',
  '#b92931',
  '#d7cf53',
  '#75bfca',
];

/** Pixel space kept free around the plot area for the axes and their labels. */
const MARGIN: { top: number; right: number; bottom: number; left: number } = {
  top: 30,
  right: 30,
  bottom: 50,
  left: 50,
};

/** One plotted sample: `x` is the horizontal (epoch) coordinate, `y` the metric value. */
export type DataPoint = {
  x: number;
  y: number;
};

/** A labelled series of points, drawn as one connected line. */
export type DataCombined = {
  label: string;
  data: Array<DataPoint>;
};

/** Props of the training-result chart. */
type TrainResultProps = {
  width: number;
  height: number;
  originalData: Array<number> | Array<Array<number>> | Array<DataCombined>;
};

/** True when `value` looks like a labelled series: an object with a `data` array. */
const isDataCombined = (value: unknown): value is DataCombined =>
  typeof value === 'object' &&
  value !== null &&
  Array.isArray((value as { data?: unknown }).data);

/**
 * Normalises the loosely typed `originalData` prop into labelled series.
 * Non-array input is dropped, and so are entries that are not `{ label, data }`
 * objects (for example the bare `number[]` / `number[][]` shapes the prop type
 * admits). The chart can only draw labelled series, and those shapes used to
 * crash the extent computation.
 */
const toSeries = (input: TrainResultProps['originalData']): DataCombined[] =>
  Array.isArray(input) ? (input as unknown[]).filter(isDataCombined) : [];

/**
 * Pads a `[lo, hi]` extent outward by fractions of `|hi|`.
 * For non-negative data this matches the padding the chart has always used.
 * An empty extent gives `[0, 0]`, so the scales never receive NaN.
 */
const paddedDomain = (
  [lo, hi]: [number, number] | [undefined, undefined],
  lowerPad: number,
  upperPad: number,
): [number, number] => {
  if (lo === undefined || hi === undefined) {
    return [0, 0];
  }
  // Math.abs keeps the padding outward even when the maximum is negative.
  const reference = Math.abs(hi);
  return [lo - reference * lowerPad, hi + reference * upperPad];
};

/**
 * Connected scatter plot of training metrics: one translucent line plus point
 * markers per series, with d3-drawn axes and a hover-driven `GraphLabel` overlay.
 *
 * @param width Chart width as a percentage of the `.metric-list` element's width.
 * @param height Chart height as a percentage of the viewport height.
 * @param originalData Series to plot. Only `{ label, data }` entries are drawn.
 */
export const TrainResultGraph = ({
  width,
  height,
  originalData,
}: TrainResultProps) => {
  const { stepIndex } = useContext(TrainContext);

  const [hovered, setHovered] = useState('');
  // Pixel size of the whole SVG, derived from the percentage props.
  const [widthPx, setWidthPx] = useState(0);
  const [heightPx, setHeightPx] = useState(0);

  // Convert the percentage props to pixels. Re-run when the props change and on
  // window resize, so the chart never keeps a stale size.
  useEffect(() => {
    const measure = () => {
      const vh = window.innerHeight * 0.01;
      setHeightPx(height * vh);
      // `.metric-list` may not be mounted yet; keep the previous width rather than throw.
      const resultDiv = document.querySelector('.metric-list');
      if (resultDiv) {
        const vw = resultDiv.clientWidth * 0.01;
        setWidthPx(width * vw);
      }
    };
    measure();
    window.addEventListener('resize', measure);
    return () => window.removeEventListener('resize', measure);
  }, [width, height]);

  const axesRef = useRef<SVGGElement>(null);
  // Bounds = drawing area inside the axes (total size minus margins), never negative.
  const boundsWidth = Math.max(0, widthPx - MARGIN.right - MARGIN.left);
  const boundsHeight = Math.max(0, heightPx - MARGIN.top - MARGIN.bottom);

  const series = toSeries(originalData);
  const allPoints = series.flatMap((s) => s.data);

  // Y axis: pad by 15% of |max| below and above the data.
  const yScale = d3
    .scaleLinear()
    .domain(paddedDomain(d3.extent(allPoints, (d) => d.y), 0.15, 0.15))
    .range([boundsHeight, 0]);

  // X axis (epochs): pad by 5% of |max| on the left and 15% on the right.
  const xScale = d3
    .scaleLinear()
    .domain(paddedDomain(d3.extent(allPoints, (d) => d.x), 0.05, 0.15))
    .range([0, boundsWidth]);

  // Axes are drawn imperatively with d3 into the <g ref={axesRef}>, not by React.
  useEffect(() => {
    const axesGroup = d3.select(axesRef.current);
    axesGroup.selectAll('*').remove();

    // Integer-only epoch ticks. The scale's own tick suggestion still includes
    // every integer for short runs, but thins the ticks out for long runs so
    // the labels do not overlap.
    const xTickValues = xScale.ticks().filter(Number.isInteger);
    const xAxisGenerator = d3
      .axisBottom(xScale)
      .tickFormat(d3.format('d'))
      .tickValues(xTickValues);

    const xAxis = axesGroup
      .append('g')
      .attr('transform', `translate(0,${boundsHeight})`)
      .call(xAxisGenerator);

    xAxis
      .append('text')
      .attr('x', boundsWidth / 2)
      .attr('y', 40)
      .attr('fill', 'white')
      .style('text-anchor', 'middle')
      .text('epoch');

    axesGroup.append('g').call(d3.axisLeft(yScale));
  }, [xScale, yScale, boundsWidth, boundsHeight]);

  // One line builder serves every series; it only depends on the scales.
  const lineBuilder = d3
    .line<DataPoint>()
    .x((d) => xScale(d.x))
    .y((d) => yScale(d.y));

  const allLines = series.map((datum, index) => {
    const linePath = lineBuilder(datum.data);
    if (!linePath) {
      return null;
    }

    // Cycle through the palette so series beyond colorArr.length still get a color.
    const color = colorArr[index % colorArr.length];
    const lastIndex = datum.data.length - 1;

    // Marker radius: 10 for the point at stepIndex - 1 (assumes stepIndex is
    // 1-based; confirm against TrainContext), 6 for the latest point, 4 otherwise.
    // Class name: series k tags its (k+1)-th-from-last point with its label
    // (series 0 -> last point, series 1 -> second-to-last). GraphLabel presumably
    // uses these classes to anchor labels; verify before changing.
    // Keys are positional: the old `x + y` key collided, e.g. for (1, 2) and (2, 1).
    const allCircles = datum.data.map((item, i) => (
      <circle
        key={i}
        cx={xScale(item.x)}
        cy={yScale(item.y)}
        r={i === stepIndex - 1 ? 10 : i === lastIndex ? 6 : 4}
        fill={color}
        className={i === lastIndex - index ? datum.label : ''}
      />
    ));

    return (
      <g
        key={`${datum.label}-${index}`}
        transform={`translate(${MARGIN.left},${MARGIN.top})`}
        onMouseEnter={() => setHovered(datum.label)}
        onMouseLeave={() => setHovered('')}
      >
        <path
          d={linePath}
          opacity={0.3}
          stroke={color}
          fill="none"
          strokeWidth={6}
        />
        {allCircles}
      </g>
    );
  });

  return (
    <div className="train-result" style={{ position: 'relative' }}>
      <svg width={widthPx} height={heightPx}>
        {allLines}
        <g ref={axesRef} transform={`translate(${MARGIN.left},${MARGIN.top})`} />
      </svg>
      <div
        style={{
          width: boundsWidth,
          height: boundsHeight,
          position: 'absolute',
          top: 0,
          left: 0,
          pointerEvents: 'none',
          marginLeft: MARGIN.left,
          marginTop: MARGIN.top,
        }}
      >
        <GraphLabel hovered={hovered} interactionData={series.map((s) => s.label)} />
      </div>
    </div>
  );
};
