import * as d3 from "d3";
import React, { useMemo,useState } from "react";
import { ScaleLinear } from "d3";

/** Props accepted by {@link AxisHorizontal}. */
interface AxisHorizontalProps {
    /** Linear scale mapping data values to horizontal pixel positions along this axis. */
    xScale: ScaleLinear<number, number>;
    /** Target horizontal spacing, in pixels, between ticks (used to derive the tick-count hint). */
    pixelsPerTick: number;
    /**
     * 0-based position of this axis. Axis 0 shows tick value labels, axes after it
     * draw tick connector lines, and axes 0 and 31 get a row label.
     */
    index: number;
}

/**
 * Extra length, in pixels, of the axis baseline: it runs from `-TICK_LENGTH`
 * to the end of the scale's range plus `TICK_LENGTH`.
 */
const TICK_LENGTH = 10;

/** Stroke/fill colour shared by the axis baseline, connectors and labels. */
const AXIS_COLOR = "#D2D7D3";

/** Text styling shared by the row label and the tick value labels. */
const AXIS_TEXT_STYLE: React.CSSProperties = {
    fontSize: "10px",
    fill: AXIS_COLOR,
    textAnchor: "middle",
    alignmentBaseline: "central",
};

/** 0-based axis indexes that get a row label (`index + 1`) to their left. */
const LABELLED_AXIS_INDEXES: readonly number[] = [0, 31];

/**
 * Length, in pixels, of the vertical connector drawn downward (positive y) from
 * each tick on axes after the first.
 * NOTE(review): 500 / 31 looks like the gap between 32 axes on a 500px-tall plot
 * area; it will not meet the neighbouring axis for other sizes — confirm.
 */
const TICK_CONNECTOR_LENGTH = 500 / 31;

/**
 * Renders one horizontal axis at y = 0 in its own coordinate space. The caller
 * positions it vertically, e.g. by wrapping it in a translated `<g>`.
 *
 * Draws:
 * - a baseline from `-TICK_LENGTH` to the end of the scale's range plus `TICK_LENGTH`;
 * - one group per tick, where the tick count hint is derived from `pixelsPerTick`.
 *   Axis 0 shows the tick value, and every later axis draws a vertical connector;
 * - for the indexes in `LABELLED_AXIS_INDEXES`, a row label reading `index + 1`
 *   placed 20px to the left.
 */
export const AxisHorizontal = ({
    xScale,
    pixelsPerTick,
    index,
}: AxisHorizontalProps) => {
    const range = xScale.range();

    const ticks = useMemo(() => {
        // Read the range from the scale itself so the memo depends only on its listed deps.
        const [rangeStart, rangeEnd] = xScale.range();
        const numberOfTicksTarget = Math.floor((rangeEnd - rangeStart) / pixelsPerTick);

        return xScale.ticks(numberOfTicksTarget).map((value) => ({
            value,
            xOffset: xScale(value),
        }));
        // pixelsPerTick changes the tick count, so it must invalidate the memo too.
    }, [xScale, pixelsPerTick]);

    return (
        <>
            {LABELLED_AXIS_INDEXES.includes(index) && (
                <text style={{ ...AXIS_TEXT_STYLE, transform: "translate(-20px, 0)" }}>
                    {index + 1}
                </text>
            )}
            <line
                x1={-TICK_LENGTH}
                x2={range[1] + TICK_LENGTH}
                y1={0}
                y2={0}
                stroke={AXIS_COLOR}
                strokeWidth={0.5}
            />
            {ticks.map(({ value, xOffset }) => (
                <g
                    key={value}
                    transform={`translate(${xOffset},0)`}
                    shapeRendering="crispEdges"
                >
                    {index > 0 && (
                        <line y1={TICK_CONNECTOR_LENGTH} y2={0} stroke={AXIS_COLOR} strokeWidth={0.5} />
                    )}
                    {index === 0 && (
                        <text style={{ ...AXIS_TEXT_STYLE, transform: "translate(0, 10px)" }}>
                            {value}
                        </text>
                    )}
                </g>
            ))}
        </>
    );
};

/** Props accepted by {@link ParallelCoordinateChart}. */
interface ParallelCoordinateChartProps {
    /** Total SVG width in pixels, margins included. */
    width: number;
    /** Total SVG height in pixels, margins included. */
    height: number;
    /** Training rows; each row becomes one polyline with one value per feature axis. */
    train_data: number[][];
    /** Test rows; concatenated after the training rows and drawn with a heavier stroke. */
    test_data: number[][];
    /** Class label per training row; expected to align index-for-index with `train_data`. */
    train_labels: number[];
    /** Class label per test row; expected to align index-for-index with `test_data`. */
    test_labels: number[];
    /** Optional per-side margins in pixels; each omitted side falls back to 30. */
    margin?: Partial<Record<"left" | "right" | "top" | "bottom", number>>;
}

/** Linear scale type used for the horizontal feature axes. */
type XScale = d3.ScaleLinear<number, number, never>;

/**
 * Stroke colours for class labels 0 through 8: the colour at position `k` is
 * used for label `k` by the chart's ordinal colour scale.
 */
const colors: string[] = [
    "rgb(0, 0, 255)",     // label 0
    "rgb(255, 0, 0)",     // label 1
    "rgb(0, 255, 0)",     // label 2
    "rgb(120, 63, 193)",  // label 3
    "rgb(255, 0, 182)",   // label 4
    "rgb(177, 204, 113)", // label 5
    "rgb(255, 211, 0)",   // label 6
    "rgb(0, 159, 255)",   // label 7
    "rgb(154, 77, 66)",   // label 8
];

/** Margin, in pixels, used for any side not given in `margin`. */
const DEFAULT_MARGIN = 30;

/**
 * Returns `[min, max]` over every value in `rows`, or `[0, 0]` when there are
 * no values.
 *
 * A plain loop is used instead of `Math.min(...rows.flat())`. Spreading an
 * array into a call passes one argument per element, and it throws a RangeError
 * once the engine's argument limit is exceeded. Thousands of rows times dozens
 * of features is enough to hit that limit.
 */
const valueExtent = (rows: number[][]): [number, number] => {
    let min = Infinity;
    let max = -Infinity;
    for (const row of rows) {
        for (const value of row) {
            if (value < min) min = value;
            if (value > max) max = value;
        }
    }
    return min <= max ? [min, max] : [0, 0];
};

/**
 * Draws every row of `train_data`, followed by every row of `test_data`, as one
 * polyline.
 *
 * Layout:
 * - Feature `k` of a row is plotted on the k-th horizontal axis: the feature
 *   value sets the x position.
 * - The axes are stacked vertically, with feature 0 at the bottom.
 * - All features share one x-domain, spanning the smallest and largest values
 *   found anywhere in the data.
 *
 * Styling:
 * - Lines are coloured by class label via `colors`.
 * - Test rows use a 2px stroke at opacity 0.1, raised to 0.8 while the pointer
 *   is over the chart.
 * - Training rows stay at opacity 0.05.
 *
 * Renders an empty chart when both datasets are empty.
 */
export const ParallelCoordinateChart = ({ width, height, train_data, test_data, train_labels, test_labels, margin }: ParallelCoordinateChartProps) => {
    const [isHover, setIsHover] = useState<boolean>(false);

    const marginLeft = margin?.left ?? DEFAULT_MARGIN;
    const marginRight = margin?.right ?? DEFAULT_MARGIN;
    const marginTop = margin?.top ?? DEFAULT_MARGIN;
    const marginBottom = margin?.bottom ?? DEFAULT_MARGIN;
    const boundsWidth = width - marginLeft - marginRight;
    const boundsHeight = height - marginTop - marginBottom;

    const data = [...train_data, ...test_data];
    const labels = [...train_labels, ...test_labels];
    // `data` is train rows then test rows, so rows from this index on are test rows.
    const firstTestRow = train_data.length;
    // One axis per feature; the feature count is taken from the first row (none if no rows).
    const indexes = data[0]?.map((_, index) => index) ?? [];

    const [minX, maxX] = valueExtent(data);
    // Every feature shares the same domain, so one scale serves all axes.
    const xScale: XScale = d3
        .scaleLinear()
        .range([0, boundsWidth])
        .domain([minX, maxX]);

    const yScale = d3
        .scalePoint<number>()
        .range([boundsHeight, 0])
        .domain(indexes)
        .padding(0);

    // Domain "0".."8" maps label k to colors[k].
    const colorScale = d3
        .scaleOrdinal<string>()
        .domain(colors.map((_, label) => label.toString()))
        .range(colors);

    const lineGenerator = d3.line();

    // Paths are built once per render; hovering only changes the test-row opacity.
    const allLines = data.map((features, i) => {
        const coordinates = indexes.map((index): [number, number] => [
            xScale(features[index]),
            yScale(index) ?? 0,
        ]);
        const d = lineGenerator(coordinates);
        if (!d) {
            return null;
        }

        const isTest = i >= firstTestRow;
        const testOpacity = isHover ? 0.8 : 0.1;
        return (
            <path
                key={i}
                d={d}
                strokeWidth={isTest ? 2 : undefined}
                stroke={colorScale(String(labels[i]))}
                fill="none"
                opacity={isTest ? testOpacity : 0.05}
            />
        );
    });

    const allAxes = indexes.map((index) => (
        <g key={index} transform={`translate(0,${yScale(index)})`}>
            <AxisHorizontal xScale={xScale} pixelsPerTick={40} index={index} />
        </g>
    ));

    // onMouseEnter/onMouseLeave do not bubble, so moving the pointer between child
    // paths does not toggle hover off and on again (onMouseOver/onMouseOut would).
    return (
        <svg
            width={width}
            height={height}
            onMouseEnter={() => setIsHover(true)}
            onMouseLeave={() => setIsHover(false)}
        >
            <g transform={`translate(${marginLeft},${marginTop})`}>
                {allLines}
                {allAxes}
            </g>
        </svg>
    );
};