import * as d3 from 'd3';
import React, { useMemo,useState } from "react";
import { ScaleLinear } from "d3";

/** Props for the `AxisLeft` component. */
interface AxisLeftProps {
  /** Plot width in pixels; each horizontal grid line spans it plus `TICK_LENGTH` on both ends. */
  width: number;
  /** Scale mapping data values to vertical pixel offsets; its range defines the axis height. */
  yScale: ScaleLinear<number, number>;
  /** Desired pixel spacing between ticks, used to derive the requested tick count. */
  pixelsPerTick: number;
}

/** Distance in pixels by which every grid line overshoots the plot area at each end. */
const TICK_LENGTH = 10;

/**
 * Renders horizontal grid lines with numeric labels for a vertical linear scale.
 *
 * Tick values come from `yScale.ticks()`. The requested count is the scale's pixel
 * extent divided by `pixelsPerTick`, and is never less than one. Each grid line spans
 * the full plot `width` and overshoots it by `TICK_LENGTH` on both sides.
 */
export const AxisLeft = ({ yScale, pixelsPerTick, width }: AxisLeftProps) => {
  const ticks = useMemo(() => {
    // Read the range inside the memo so every input is covered by the dependency list.
    const range = yScale.range();
    // abs() tolerates either range orientation ([height, 0] is the usual y-axis form).
    const height = Math.abs(range[0] - range[1]);
    // d3 returns no ticks for a count of 0, so short charts would lose their axis entirely.
    const numberOfTicksTarget = Math.max(1, Math.floor(height / pixelsPerTick));

    return yScale.ticks(numberOfTicksTarget).map((value) => ({
      value,
      yOffset: yScale(value),
    }));
  }, [yScale, pixelsPerTick]);

  return (
    <>
      {ticks.map(({ value, yOffset }) => (
        <g
          key={value}
          transform={`translate(0, ${yOffset})`}
          shapeRendering={"crispEdges"}
        >
          <line
            x1={-TICK_LENGTH}
            x2={width + TICK_LENGTH}
            stroke="#D2D7D3"
            strokeWidth={0.5}
          />
          <text
            style={{
              fontSize: "10px",
              textAnchor: "middle",
              transform: "translateX(-20px)",
              fill: "#D2D7D3",
            }}
          >
            {value}
          </text>
        </g>
      ))}
    </>
  );
};

/** Props for the `AxisBottom` component. */
interface AxisBottomProps {
  /** Plot height in pixels; vertical grid lines are drawn upward this far plus `TICK_LENGTH`. */
  height: number;
  /** Scale mapping data values to horizontal pixel offsets; its range defines the axis width. */
  xScale: ScaleLinear<number, number>;
  /** Desired pixel spacing between ticks, used to derive the requested tick count. */
  pixelsPerTick: number;
}

/**
 * Renders vertical grid lines with numeric labels for a horizontal linear scale.
 *
 * Lines run from `y = TICK_LENGTH` up to `y = -(height + TICK_LENGTH)`, so the caller
 * is expected to translate this group to the bottom edge of the plot. The requested
 * tick count is the scale's pixel extent divided by `pixelsPerTick`, and is never
 * less than one.
 */
export const AxisBottom = ({
  xScale,
  pixelsPerTick,
  height,
}: AxisBottomProps) => {
  const ticks = useMemo(() => {
    // Read the range inside the memo so every input is covered by the dependency list.
    const range = xScale.range();
    // abs() tolerates a reversed range as well as the usual [0, width].
    const width = Math.abs(range[1] - range[0]);
    // d3 returns no ticks for a count of 0, so narrow charts would lose their axis entirely.
    const numberOfTicksTarget = Math.max(1, Math.floor(width / pixelsPerTick));

    return xScale.ticks(numberOfTicksTarget).map((value) => ({
      value,
      xOffset: xScale(value),
    }));
  }, [xScale, pixelsPerTick]);

  return (
    <>
      {ticks.map(({ value, xOffset }) => (
        <g
          key={value}
          transform={`translate(${xOffset}, 0)`}
          shapeRendering={"crispEdges"}
        >
          <line
            y1={TICK_LENGTH}
            y2={-height - TICK_LENGTH}
            stroke="#D2D7D3"
            strokeWidth={0.5}
          />
          <text
            style={{
              fontSize: "10px",
              textAnchor: "middle",
              transform: "translateY(20px)",
              fill: "#D2D7D3",
            }}
          >
            {value}
          </text>
        </g>
      ))}
    </>
  );
};

/** Props for the `ScatterChart` component. */
interface ScatterChartProps {
  /** Outer SVG width in pixels; margins are subtracted from it to get the plot area. */
  width: number;
  /** Outer SVG height in pixels; margins are subtracted from it to get the plot area. */
  height: number;
  /** Optional per-side margins in pixels; the chart substitutes 30 for any unset side. */
  margin?: { top?: number; right?: number; bottom?: number; left?: number };
  /** Training points. */
  train_data: Array<{ x: number; y: number }>;
  /** Integer class label for each training point; used as an index into the colour palette. */
  train_labels: Array<number>;
  /** Test points. */
  test_data: Array<{ x: number; y: number }>;
  /** Integer class label for each test point; used as an index into the colour palette. */
  test_labels: Array<number>;
}

/**
 * Point fill colours, indexed by class label (label 0 uses the first entry).
 * Each entry is a CSS `rgb()` string built from the channel triple listed here.
 */
const colors = [
  [0, 0, 255],
  [255, 0, 0],
  [0, 255, 0],
  [120, 63, 193],
  [255, 0, 182],
  [177, 204, 113],
  [255, 211, 0],
  [0, 159, 255],
  [154, 77, 66],
].map(([red, green, blue]) => `rgb(${red}, ${green}, ${blue})`);

/**
 * Scatter plot of train and test points, coloured by class label.
 *
 * Both axes are scaled to the combined extent of train and test points. Training
 * points are drawn small and faint. Test points are drawn larger and become more
 * opaque while the pointer is over the chart. Any side of `margin` left unset
 * defaults to 30px. With no points at all, a unit domain is used instead of crashing.
 */
export const ScatterChart = ({
  width,
  height,
  train_data,
  test_data,
  train_labels,
  test_labels,
  margin,
}: ScatterChartProps) => {
  const [isHover, setIsHover] = useState<boolean>(false);

  const marginTop = margin?.top ?? 30;
  const marginRight = margin?.right ?? 30;
  const marginBottom = margin?.bottom ?? 30;
  const marginLeft = margin?.left ?? 30;
  const boundsWidth = width - marginRight - marginLeft;
  const boundsHeight = height - marginTop - marginBottom;

  // Train points come first, so any index >= train_data.length is a test point.
  const data = [...train_data, ...test_data];
  const labels = [...train_labels, ...test_labels];
  const testStart = train_data.length;

  // d3.extent yields [undefined, undefined] for empty input; fall back to a unit
  // domain rather than dereferencing data[0].
  const [minX = 0, maxX = 1] = d3.extent(data, (p) => p.x);
  const [minY = 0, maxY = 1] = d3.extent(data, (p) => p.y);

  // Memoised so hover re-renders don't hand the axes new scale objects,
  // which would invalidate their tick memos every time.
  const xScale = useMemo(
    () => d3.scaleLinear().domain([minX, maxX]).range([0, boundsWidth]),
    [minX, maxX, boundsWidth],
  );
  const yScale = useMemo(
    () => d3.scaleLinear().domain([minY, maxY]).range([boundsHeight, 0]),
    [minY, maxY, boundsHeight],
  );

  // Test points are emphasised while hovering; train points stay at a fixed opacity.
  const testOpacity = isHover ? 0.8 : 0.3;
  const shapes = data.map((d, i) => {
    const isTest = i >= testStart;
    return (
      <circle
        key={i}
        r={isTest ? 8 : 4}
        cx={xScale(d.x)}
        cy={yScale(d.y)}
        opacity={isTest ? testOpacity : 0.2}
        fill={colors[labels[i]]}
      />
    );
  });

  // mouseenter/mouseleave do not bubble, so moving the pointer between child
  // elements no longer toggles the hover state off and back on.
  return (
    <svg
      width={width}
      height={height}
      onMouseEnter={() => setIsHover(true)}
      onMouseLeave={() => setIsHover(false)}
    >
      <g transform={`translate(${marginLeft},${marginTop})`}>
        <AxisLeft yScale={yScale} pixelsPerTick={40} width={boundsWidth} />
        <g transform={`translate(0, ${boundsHeight})`}>
          <AxisBottom
            xScale={xScale}
            pixelsPerTick={40}
            height={boundsHeight}
          />
        </g>
        {shapes}
      </g>
    </svg>
  );
};
