import React, { useEffect, useRef, useState, useContext } from "react";
import {
  Stage,
  Layer,
  Image,
  Rect,
  Line,
  Circle as KonvaCircle,
  Group,
} from "react-konva";
import { Circle } from "konva/lib/shapes/Circle";
import useImage from "use-image";
import { KonvaEventObject } from "konva/lib/Node";
import Konva from "konva";
import { Matrix, solve } from "ml-matrix";
import { Vector2d } from "konva/lib/types";

import { db, getSubImage } from "../../db";
import useWizardStore, { ImageLookup, ImageRegistration, ImageRegistrationTransform } from "../stores/useWizardStore";
import { MainContext } from "../contexts/MainContext";
import { WebTools, HttpError } from "../contexts/ImageSyncService";
import { unstable_batchedUpdates } from "react-dom";

// Interfaces

/** Props accepted by the `ImageTransform` registration component. */
interface ImageTransformProps {
  /** Reference image; it is rendered untransformed and labelled "Infrared (reference)". */
  fixed: ImageLookup;
  /** Image whose transform onto `fixed` is being estimated. */
  unfixed: ImageLookup;
  /** Optional listener invoked with the current transform each time it changes. */
  onTransform?: (transform: ImageRegistrationTransform) => void;
}

/** Props for the `ZoomedImage` point magnifier. */
interface ZoomedImageProps {
  /** Bitmap the control point belongs to. */
  image: ImageBitmap;
  /** Extra magnification multiplied onto the IMAGE_SIZE preview scale. */
  scale: number;
  /** Position of the point (in both `arr` and `circles`) this view controls. */
  idx: number;
  /** Control points as `[x, y]` pairs in image-pixel coordinates. */
  arr: number[][];
  /** State setter that owns `arr`. */
  setArr: React.Dispatch<React.SetStateAction<number[][]>>;
  /** Marker positions mirrored from `arr`, in the same coordinate space. */
  circles: { x: number; y: number }[];
  /** State setter that owns `circles`. */
  setCircles: React.Dispatch<React.SetStateAction<{ x: number; y: number }[]>>;
}

/**
 * Flat translation/scale/rotation record.
 *
 * NOTE(review): not referenced anywhere in this file; the component uses
 * `ImageRegistrationTransform` instead. Candidate for removal.
 */
interface Transform {
  /** Horizontal translation. */
  trans_x: number;
  /** Vertical translation. */
  trans_y: number;
  /** Uniform scale factor. */
  scale: number;
  /** Rotation angle. */
  rotation: number;
}

// Constants

/** Edge length, in stage pixels, of each square preview / overlay stage. */
const IMAGE_SIZE = 400;
/** Edge length, in stage pixels, of each per-point magnifier stage. */
const ZOOMED_IMAGE_SIZE = 100;

/** Marker palette cycled by point index: red, green, blue, yellow. */
const MARKER_COLORS: readonly string[] = [
  "#ff0000",
  "#00ff00",
  "#0000ff",
  "#ffff00",
];

/**
 * Hex colour (`#rrggbb`) used for the control point at index `i`.
 * The palette repeats every `MARKER_COLORS.length` points. Callers in this file
 * append a two-digit alpha suffix (e.g. `+ 66`) to obtain `#rrggbbaa`.
 */
const IndexToColor = (i: number) => MARKER_COLORS[i % MARKER_COLORS.length];

// Helper Components

/**
 * Magnified view of a single control point, centred on `arr[idx]`.
 *
 * Releasing the mouse inside the view moves the point so that the clicked
 * pixel becomes its new position. The matching marker in `circles` moves by
 * the same offset. The "X" button deletes both the point and its marker.
 * An empty placeholder of the same size is rendered when there is no image
 * or no point at `idx`.
 */
const ZoomedImage: React.FC<ZoomedImageProps> = ({
  image,
  scale,
  idx,
  arr,
  setArr,
  circles,
  setCircles,
}) => {
  if (!image || !arr[idx]) {
    return (
      <div style={{ width: ZOOMED_IMAGE_SIZE, height: ZOOMED_IMAGE_SIZE }} />
    );
  }

  // Zoomed-stage pixels per image pixel. The preview stretches width and
  // height to IMAGE_SIZE independently, so each axis needs its own factor.
  const zoomX = (IMAGE_SIZE / image.width) * scale;
  const zoomY = (IMAGE_SIZE / image.height) * scale;
  const center = ZOOMED_IMAGE_SIZE / 2;
  const [pointX, pointY] = arr[idx];
  const color = IndexToColor(idx);

  /** Re-centre this point on the pixel under the pointer. */
  const editPoint = (e: KonvaEventObject<MouseEvent>) => {
    const location = e.target.getStage()?.getPointerPosition();
    if (!location) {
      return;
    }

    // Pointer offset from the view centre, converted back to image pixels.
    const dx = (location.x - center) / zoomX;
    const dy = (location.y - center) / zoomY;

    // Build new arrays instead of mutating `arr` in place: it is React state
    // owned by the parent component.
    setArr(arr.map((pt, i) => (i === idx ? [pt[0] + dx, pt[1] + dy] : pt)));

    if (circles[idx]) {
      setCircles(
        circles.map((c, i) => (i === idx ? { x: c.x + dx, y: c.y + dy } : c))
      );
    } else {
      // Marker missing for this point: add one at the updated point, in the
      // same image-pixel space as `arr` (not zoomed-stage coordinates).
      setCircles([...circles, { x: pointX + dx, y: pointY + dy }]);
    }
  };

  return (
    <div className="relative">
      {/* Delete Icon */}
      <div className="absolute top-0 right-0 z-10">
        <button
          className="bg-red-500 text-white h-6 w-6 rounded-full flex items-center justify-center"
          onClick={() => {
            setArr(arr.filter((_, i) => i !== idx));
            setCircles(circles.filter((_, i) => i !== idx));
          }}
        >
          X
        </button>
      </div>
      <Stage width={ZOOMED_IMAGE_SIZE} height={ZOOMED_IMAGE_SIZE}>
        <Layer onMouseUp={editPoint}>
          {/* Offset the image so that (pointX, pointY) lands on the view centre. */}
          <Image
            scaleX={zoomX}
            scaleY={zoomY}
            x={-pointX * zoomX + center}
            y={-pointY * zoomY + center}
            image={image}
          />
          <Line
            points={[
              ZOOMED_IMAGE_SIZE / 4,
              center,
              (ZOOMED_IMAGE_SIZE / 4) * 3,
              center,
            ]}
            stroke={color + 66}
            strokeWidth={1}
          />
          <Line
            points={[
              center,
              ZOOMED_IMAGE_SIZE / 4,
              center,
              (ZOOMED_IMAGE_SIZE / 4) * 3,
            ]}
            stroke={color + 66}
            strokeWidth={1}
          />
          <KonvaCircle
            x={center}
            y={center}
            radius={2}
            fill={color}
            stroke="black"
            strokeWidth={0}
          />
          <KonvaCircle
            x={center}
            y={center}
            radius={ZOOMED_IMAGE_SIZE / 4}
            fill="transparent"
            stroke={color}
            strokeWidth={2}
          />
        </Layer>
      </Stage>
    </div>
  );
};

// Main Component

/**
 * Renders crosshair-and-ring markers for picked control points on one of the
 * IMAGE_SIZE x IMAGE_SIZE preview stages.
 *
 * @param points - Marker positions in the image's own pixel coordinates.
 * @param image - Bitmap the points belong to. Its width and height map points
 *   into preview-stage coordinates.
 * @returns One keyed Konva `Group` per point, coloured by index.
 */
const renderPointMarkers = (
  points: { x: number; y: number }[],
  image: ImageBitmap
) => {
  const toStageX = IMAGE_SIZE / image.width;
  const toStageY = IMAGE_SIZE / image.height;

  return points.map((point, i) => {
    const cx = point.x * toStageX;
    const cy = point.y * toStageY;
    const color = IndexToColor(i);
    return (
      <Group key={i}>
        <KonvaCircle
          key={i + "a"}
          x={cx}
          y={cy}
          radius={10}
          stroke={color + 99}
          strokeWidth={2}
        />
        <KonvaCircle
          key={i + "b"}
          x={cx}
          y={cy}
          radius={1}
          fill={color}
          strokeWidth={0}
        />
        <Line
          key={i + "c"}
          points={[cx - 10, cy, cx + 10, cy]}
          stroke={color + 66}
          strokeWidth={1}
        />
        <Line
          key={i + "d"}
          points={[cx, cy - 10, cx, cy + 10]}
          stroke={color + 66}
          strokeWidth={1}
        />
      </Group>
    );
  });
};

/**
 * Landmark-based registration of `unfixed` onto `fixed`.
 *
 * The user clicks up to three control points on each preview stage. Points are
 * paired by index: point i on both images shares a marker colour. Once both
 * sides hold the same number (>= 3) of points, a least-squares similarity
 * transform (translation, uniform scale, rotation in degrees) is solved. It
 * maps unfixed-image pixels onto fixed-image pixels and drives the blended
 * overlay stage.
 *
 * On mount both bitmaps are loaded. If the wizard store already holds a
 * registration for `unfixed.type` / `unfixed.eye`, it becomes the initial
 * transform. `onTransform` is called whenever the transform changes.
 *
 * NOTE(review): bitmaps are loaded only on mount. Changing `fixed`/`unfixed`
 * afterwards clears the points but does not reload the images. Confirm the
 * parent remounts (e.g. via `key`) when the image pair changes.
 */
export const ImageTransform: React.FC<ImageTransformProps> = ({
  fixed,
  unfixed,
  onTransform,
}) => {
  // NOTE(review): u, p and patient_id are not used in this component.
  const { u, p, patient_id } = useContext(MainContext)!;
  const setFixedLookup = useWizardStore((state) => state.setFixedLookup);
  const [fixedImage, setFixedImage] = useState<ImageBitmap | null>(null);
  const [unfixedImage, setUnfixedImage] = useState<ImageBitmap | null>(null);
  const [opacity, setOpacity] = useState<number>(0.5);
  const [transform, setTransform] = useState<ImageRegistrationTransform>({
    translation: { x: 0, y: 0 },
    scale: 1,
    rotation: 0,
  });
  // Control points as [x, y] in image pixels: src = unfixed, dst = fixed.
  const [ptsSrc, setPtsSrc] = useState<number[][]>([]);
  const [ptsDst, setPtsDst] = useState<number[][]>([]);
  // Marker positions mirrored from ptsDst / ptsSrc for the preview stages.
  const [fixedImageCircles, setFixedImageCircles] = useState<
    { x: number; y: number }[]
  >([]);
  const [unfixedImageCircles, setUnfixedImageCircles] = useState<
    { x: number; y: number }[]
  >([]);

  const fixedStageRef = useRef<Konva.Stage>(null);
  const unfixedStageRef = useRef<Konva.Stage>(null);

  // mount: load images; ignore the result if we unmount before it arrives.
  useEffect(() => {
    let cancelled = false;
    void initAsync(() => cancelled);
    return () => {
      cancelled = true;
    };
  }, []);

  useEffect(() => {
    onTransform?.(transform);
  }, [transform]);

  /**
   * Load both bitmaps. Unless `isCancelled()` reports that the component has
   * gone away, store them, apply any saved registration, and record `fixed`
   * in the wizard store.
   */
  async function initAsync(isCancelled: () => boolean) {
    let fixedBitmap: ImageBitmap;
    let unfixedBitmap: ImageBitmap;
    try {
      const all = await Promise.all([
        getSubImage(fixed.filepath, fixed.subImageKey),
        getSubImage(unfixed.filepath, unfixed.subImageKey),
      ]);
      fixedBitmap = all[0];
      unfixedBitmap = all[1];
    } catch (error) {
      console.error("Error loading images:", error);
      return;
    }
    if (isCancelled()) {
      return;
    }
    const reg = useWizardStore.getState().registrations;
    // Optional chaining: no registrations may exist yet for this image type.
    const saved = reg?.[unfixed.type]?.[unfixed.eye];
    unstable_batchedUpdates(() => {
      setFixedImage(fixedBitmap);
      setUnfixedImage(unfixedBitmap);
      if (saved) {
        setTransform(saved);
      }
      setFixedLookup(fixed.eye, fixed);
    });
  }

  /** Remove every control point and marker, and reset to the identity transform. */
  const clearPoints = () => {
    // Layer 1 is the empty <Layer /> on each preview stage. It appears to be a
    // leftover from imperatively-added point shapes; markers now render from state.
    fixedStageRef.current?.getLayers()[1].destroyChildren();
    unfixedStageRef.current?.getLayers()[1].destroyChildren();
    setFixedImageCircles([]);
    setUnfixedImageCircles([]);
    setPtsSrc([]);
    setPtsDst([]);
    setTransform({
      translation: { x: 0, y: 0 },
      scale: 1,
      rotation: 0,
    });
  };

  useEffect(() => {
    clearPoints();
  }, [fixed, unfixed]);

  /**
   * Preview-stage mouse-up handler. Appends a control point (at most 3 per
   * image) at the pointer, converted from stage pixels to image pixels, and
   * adds the matching marker for the stage identified by its `id`.
   */
  const addPoint = (
    e: KonvaEventObject<MouseEvent>,
    arr: number[][],
    setArr: React.Dispatch<React.SetStateAction<number[][]>>
  ) => {
    const stage = e.target.getStage();
    const location = stage?.getPointerPosition();
    if (!stage || !location || arr.length >= 3) {
      return;
    }
    // The background Image's scale maps image pixels to stage pixels.
    const scale: Vector2d = stage.getLayers()[0].getChildren()[0].scale()!;
    const point = { x: location.x / scale.x, y: location.y / scale.y };
    const stageId = stage.attrs.id;

    if (stageId === "fixed") {
      setFixedImageCircles([...fixedImageCircles, point]);
    }

    if (stageId === "unfixed") {
      setUnfixedImageCircles([...unfixedImageCircles, point]);
    }

    setArr([...arr, [point.x, point.y]]);
  };

  /**
   * Least-squares similarity transform mapping `pts_src` onto `pts_dst`.
   *
   * Model: dst = [a -b; b a] * src + [tx; ty], where a = s*cos(theta) and
   * b = s*sin(theta). The 4x4 system below is the normal equations for
   * (a, b, tx, ty).
   *
   * @param pts_src - Source points `[x, y]`, paired with `pts_dst` by index.
   * @param pts_dst - Destination points `[x, y]`.
   * @returns Translation, scale and rotation (degrees), rounded to 5 decimals.
   * @throws Error if the sets differ in length or the solution is not finite
   *   (e.g. coincident points).
   */
  const solveTransform = (
    pts_src: number[][],
    pts_dst: number[][]
  ): ImageRegistrationTransform => {
    if (pts_src.length !== pts_dst.length) {
      throw new Error("Point sets must have the same length");
    }

    // Correspondence is by index (the pair shares a marker colour). Do not sort
    // the sets independently: that re-pairs points whose x-order differs
    // between the two images.
    const pts_src_mat = new Matrix(pts_src);
    const pts_dst_mat = new Matrix(pts_dst);

    const sum_src = pts_src_mat.sum("column");
    const sum_src_x = sum_src[0];
    const sum_src_y = sum_src[1];
    const sum_dst = pts_dst_mat.sum("column");
    const sum_sq = pts_src_mat.clone().pow(2).sum();
    const N = pts_src.length;

    const A = new Matrix([
      [sum_sq, 0, sum_src_x, sum_src_y],
      [0, sum_sq, -sum_src_y, sum_src_x],
      [sum_src_x, -sum_src_y, N, 0],
      [sum_src_y, sum_src_x, 0, N],
    ]);

    const srcX = pts_src_mat.getColumnVector(0);
    const srcY = pts_src_mat.getColumnVector(1);
    const dstX = pts_dst_mat.getColumnVector(0);
    const dstY = pts_dst_mat.getColumnVector(1);

    const b = Matrix.columnVector([
      srcX.dot(dstX) + srcY.dot(dstY),
      srcX.dot(dstY) - srcY.dot(dstX),
      sum_dst[0],
      sum_dst[1],
    ]);

    const x = solve(A, b).to1DArray();
    if (!x.every(Number.isFinite)) {
      throw new Error("Degenerate point configuration; cannot solve transform");
    }

    return {
      translation: {
        x: Number(x[2].toFixed(5)),
        y: Number(x[3].toFixed(5)),
      },
      scale: Number(Math.sqrt(x[0] ** 2 + x[1] ** 2).toFixed(5)),
      rotation: Number((Math.atan2(x[1], x[0]) * (180 / Math.PI)).toFixed(5)),
    };
  };

  useEffect(() => {
    if (ptsSrc.length >= 3 && ptsSrc.length === ptsDst.length) {
      try {
        const newTransform = solveTransform(ptsSrc, ptsDst);
        setTransform(newTransform);
      } catch (error) {
        console.error("Error calculating transform:", error);
      }
    }
  }, [ptsSrc, ptsDst]);

  // Auto-blend sweep direction (+1 fading in, -1 fading out). It is kept in a
  // ref so it persists across stop/start without causing re-renders.
  const blendDirection = useRef(1);
  const [blendAuto, setBlendAuto] = useState(false);

  useEffect(() => {
    if (!blendAuto) {
      return;
    }
    const timer = setInterval(() => {
      setOpacity((prev) => {
        // Reversing at the bounds is idempotent for a given `prev`, so a
        // double-invoked updater (StrictMode) yields the same result.
        if (prev >= 1) {
          blendDirection.current = -1;
        } else if (prev <= 0) {
          blendDirection.current = 1;
        }
        return Math.min(1, Math.max(0, prev + 0.05 * blendDirection.current));
      });
    }, 20);

    return () => clearInterval(timer);
  }, [blendAuto]);

  if (!fixedImage || !unfixedImage) {
    return <div>Loading images...</div>;
  }

  /** Manual numeric editors for the transform (currently not rendered). */
  function renderTextFields() {
    return <div className="grid grid-cols-3 w-[400px] gap-2">
      <span>Translation:</span>
      <div className="relative w-[90px]">
        <input
          type="number"
          step={0.1}
          value={transform.translation.x}
          onChange={(e) =>
            setTransform({ ...transform, translation: { ...transform.translation, x: e.target.valueAsNumber } })
          }
          className="relative w-[90px] rounded-md p-1 text-xs"
        />
        <div className="absolute -left-3 top-0 bottom-0 opacity-40 font-bold">
          x
        </div>
      </div>
      <div className="relative w-[90px]">
        <input
          type="number"
          step={0.1}
          value={transform.translation.y}
          onChange={(e) =>
            setTransform({ ...transform, translation: { ...transform.translation, y: e.target.valueAsNumber } })
          }
          className="relative w-[90px] rounded-md p-1 text-xs"
        />
        <div className="absolute -left-3 top-0 bottom-0 opacity-40 font-bold">
          y
        </div>
      </div>
      <span>Rotation:</span>
      <div className="relative w-[90px]">
        <input
          type="number"
          step={0.1}
          value={transform.rotation}
          onChange={(e) =>
            setTransform({ ...transform, rotation: e.target.valueAsNumber })
          }
          className="relative w-[90px] rounded-md p-1 text-xs"
        />
        <div className="absolute -right-3 top-0 bottom-0 opacity-40 font-bold">
          °
        </div>
      </div>
      <div></div>
      <span>Scale:</span>
      <input
        type="number"
        step={0.01}
        value={transform.scale}
        onChange={(e) =>
          setTransform({ ...transform, scale: e.target.valueAsNumber })
        }
        className="relative w-[90px] rounded-md px-1 text-xs"
      />
    </div>
  }

  // const textFields = renderTextFields();
  const textFields = undefined;

  return (
    <div className="flex flex-row p-4 justify-between bg-gray-200 rounded-md min-w-[1416px]">
      <div className="grid grid-cols-2 gap-4">
        <div className="flex flex-row">
          <div className="w-[100px]">
            {ptsDst.map((_, i) => (
              <ZoomedImage
                key={i}
                idx={i}
                image={fixedImage}
                scale={5}
                arr={ptsDst}
                setArr={setPtsDst}
                circles={fixedImageCircles}
                setCircles={setFixedImageCircles}
              />
            ))}
          </div>
          <div>
            <Stage
              ref={fixedStageRef}
              id="fixed"
              width={IMAGE_SIZE}
              height={IMAGE_SIZE}
              onMouseUp={(e) => addPoint(e, ptsDst, setPtsDst)}
            >
              <Layer>
                <Image
                  image={fixedImage}
                  scaleX={IMAGE_SIZE / fixedImage.width}
                  scaleY={IMAGE_SIZE / fixedImage.height}
                />
              </Layer>
              <Layer />
              <Layer>{renderPointMarkers(fixedImageCircles, fixedImage)}</Layer>
            </Stage>
            <div className="text-sm">Infrared (reference)</div>
          </div>
        </div>
        <div>
          <div className="flex">
            <Stage
              ref={unfixedStageRef}
              id="unfixed"
              width={IMAGE_SIZE}
              height={IMAGE_SIZE}
              onMouseUp={(e) => addPoint(e, ptsSrc, setPtsSrc)}
            >
              <Layer>
                <Image
                  image={unfixedImage}
                  scaleX={IMAGE_SIZE / unfixedImage.width}
                  scaleY={IMAGE_SIZE / unfixedImage.height}
                />
              </Layer>
              <Layer />
              <Layer>
                {renderPointMarkers(unfixedImageCircles, unfixedImage)}
              </Layer>
            </Stage>
            <div className="w-[100px]">
              {ptsSrc.map((_, i) => (
                <ZoomedImage
                  key={i}
                  idx={i}
                  image={unfixedImage}
                  scale={5}
                  arr={ptsSrc}
                  setArr={setPtsSrc}
                  circles={unfixedImageCircles}
                  setCircles={setUnfixedImageCircles}
                />
              ))}
            </div>
          </div>
          <div className="text-sm">
            {unfixed.type === "FAF" ? "FAF Field 2" : "FAF Field 1M"}
          </div>
        </div>
        <div className="flex justify-end items-start">
          <button
            className="font-semibold bg-blue-500 text-white px-2 py-1 rounded w-[400px]"
            onClick={clearPoints}
          >
            Clear Points
          </button>
        </div>
        {textFields}
      </div>
      <div className="flex place-items-center flex-col w-[400px]">
        {/* Stage is scaled so its coordinates are fixed-image pixels. */}
        <Stage
          width={IMAGE_SIZE}
          height={IMAGE_SIZE}
          scaleX={IMAGE_SIZE / fixedImage.width}
          scaleY={IMAGE_SIZE / fixedImage.height}
        >
          <Layer>
            <Image image={fixedImage} />
            <Image
              image={unfixedImage}
              opacity={opacity}
              x={transform.translation.x}
              y={transform.translation.y}
              scaleX={transform.scale}
              scaleY={transform.scale}
              rotation={transform.rotation}
            />
          </Layer>
        </Stage>
        <div className="flex flex-col items-center space-y-1 [&>*]:space-x-1 mt-4">
          <button
            className={`${
              blendAuto ? "bg-red-500" : "bg-blue-500"
            } text-white p-2 rounded-md`}
            onClick={() => setBlendAuto((prev) => !prev)}
          >
            {blendAuto ? "Stop" : "Start"} Auto Blend
          </button>
          <div className="flex items-center">
            <span>Blend:</span>
            <input
              type="range"
              min={0}
              max={1}
              step={0.01}
              value={opacity}
              onChange={(e) => setOpacity(e.target.valueAsNumber)}
            />
          </div>
        </div>
      </div>
    </div>
  );
};
