/*
 * File: onnix-provider.tsx
 * Project: app-aiscaler-web
 * File Created: Tuesday, 25th April 2023 1:45:51 pm
 * Author: v.anhphamd (v.anhphd@vinbrain.net)
 *
 * Copyright 2023 VinBrain JSC
 */

import {
  ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { OnnixContext } from "./onnix-context";
import { InferenceSession, Tensor } from "onnxruntime-web";
import { OnnxLabel, modelScaleProps } from "../helpers/Interfaces";
import { useImageEditorContext } from "pages/labeler/image-labeling/image-editor-context/image-editor.context";
import { useSAMTool } from "pages/labeler/image-labeling/hooks/use-sam-tool";
import { useAppDispatch, useAppSelector } from "hooks/use-redux";
import { selectImageLabelingLabelOptions } from "store/labeler/image-workspace/image-labeling/image-labeling.selectors";
import { useLocalStorageState, useMount } from "ahooks";
import { SAMMode } from "./onnix-state";
import { getSAMScale } from "../helpers/scaleHelper";
import { modelClicks, modelData, modelResults } from "../helpers/onnxModelAPI";
import { AnnotateType } from "constants/annotation.constant";
import { imageAnnotationCompletedAsync } from "store/labeler/image-workspace/image-annotations/thunks/image-annotation-completed.thunk";
import { v4 } from "uuid";
import { useImageLabelingContext } from "pages/labeler/image-labeling/context/image-labeling.context";
import { selectImageLabelingJobById } from "store/labeler/image-workspace/batch-labeling/batch-labeling.selectors";
import { AIAudioAssistanceApi } from "data-access/impl/ai-audio-assistance";

/**
 * Location of the ONNX model file handed to `InferenceSession.create`.
 * NOTE(review): presumably served from the web app's public folder — confirm.
 */
const MODEL_DIR = "/onnix/scaler-model.onnx";

/**
 * Snapshot of the viewport image together with the scale SAM should use.
 * NOTE: inside this module this name shadows the DOM lib's global
 * `ImageData` type (when the DOM lib is enabled).
 */
type ImageData = {
  /** Image width as reported by the viewport's `getImageData()`. */
  width: number;
  /** Image height as reported by the viewport's `getImageData()`. */
  height: number;
  /** Encoded image payload; forwarded to the embedding API as `base64Image`. */
  data: string;
  /** Result of `getSAMScale(width, height)`. */
  samScale: number;
};

/** Props accepted by `OnnixProvider`. */
type Props = {
  /** Subtree rendered inside `OnnixContext.Provider`. */
  children?: ReactNode;
  /** Labeling job used for metadata lookup and for tagging saved annotations. */
  jobId: number;
};

/**
 * Provides Segment Anything (SAM) tool state to the image-labeling workspace
 * through `OnnixContext`.
 *
 * What this component does (all visible in the body below):
 * - once `imageLoaded` is true, snapshots the viewport image via
 *   `cornerstoneHandler.current.getImageData()` and derives its SAM scale;
 * - creates the ONNX `InferenceSession` from `MODEL_DIR` and obtains the
 *   image embedding from `AIAudioAssistanceApi.extractEmbeddingImage`;
 * - whenever the SAM tool data is flagged `dirty`, runs the model on the
 *   current clicks and writes the resulting boxes/polygons back with
 *   `dirty: false`;
 * - on apply, dispatches one `imageAnnotationCompletedAsync` per box (for
 *   bounding-box labels) or per polygon (for polygon labels).
 *
 * @param children - Subtree that consumes `OnnixContext`.
 * @param jobId - Labeling job whose file metadata may carry an embedding URL
 *   (`samEmbeddings[0]`) and under which annotations are dispatched.
 */
export function OnnixProvider({ children, jobId }: Props) {
  const dispatch = useAppDispatch();
  const { imageLoaded, cornerstoneHandler } = useImageEditorContext();
  const { selectLabelById } = useImageLabelingContext();
  // Starts true and is only cleared by the inference effect below.
  const [loading, setLoading] = useState(true);
  const { data: toolData, setData: setToolData } = useSAMTool();
  const [imageData, setImageData] = useState<ImageData | undefined>(undefined);
  const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
  const [embeddingImage, setEmbeddingImage] = useState<Tensor | null>(null); // Image embedding tensor
  const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);
  const [mode, setMode] = useLocalStorageState("SAM_MODE", SAMMode.DRAW);
  const labelOptions: OnnxLabel[] = useAppSelector(
    selectImageLabelingLabelOptions
  );
  const labelingJob = useAppSelector(selectImageLabelingJobById(jobId));

  // Kept via ahooks `useLocalStorageState` under "SAM_SELECTED_LABEL_ID";
  // the default is the first label option, or "" when there are none.
  const [selectedLabelId, setSelectedLabelId] = useLocalStorageState(
    "SAM_SELECTED_LABEL_ID",
    labelOptions.length > 0 ? labelOptions[0].value : ""
  );
  const [onnxError, setOnnxError] = useState<any>(undefined);

  // True while `initialize` is awaiting the model/embedding. `labelingJob` can
  // change identity during that wait, which would otherwise start a second,
  // duplicate session + embedding request.
  const initializingRef = useRef(false);
  // Id of the most recent inference request. Results from older requests are
  // dropped so a slow, stale run cannot overwrite newer clicks.
  const latestRunIdRef = useRef(0);

  // Ready for inference once both the session and the embedding exist.
  const isInitialized = useMemo(() => {
    return !!model && !!embeddingImage;
  }, [model, embeddingImage]);

  const selectedLabel = useMemo<OnnxLabel | undefined>(() => {
    return labelOptions.find((option) => option.value === selectedLabelId);
  }, [labelOptions, selectedLabelId]);

  /**
   * Creates the ONNX session and fetches the image embedding, storing both
   * only when an embedding is obtained. Any failure (model load or embedding
   * request) is reported through `onnxError`. No-op while already
   * initialized, while a previous call is still in flight, or before the image
   * and labeling job are available.
   */
  const initialize = useCallback(async () => {
    if (!imageData || isInitialized || !labelingJob) return;
    if (initializingRef.current) return;
    initializingRef.current = true;
    try {
      const session = await InferenceSession.create(MODEL_DIR);
      const metadata = labelingJob.file?.additionalProperties?.metadata;
      const embeddings = metadata?.samEmbeddings ?? [];
      const embeddingUrl = embeddings.length > 0 ? embeddings[0] : undefined;
      const payload = { base64Image: imageData.data, embeddingUrl };
      const embedding: Tensor | undefined =
        await AIAudioAssistanceApi.extractEmbeddingImage(payload);
      if (embedding) {
        setModel(session);
        setEmbeddingImage(embedding);
        setOnnxError(undefined);
      }
    } catch (error) {
      // Embedding failures used to be swallowed silently, leaving the
      // provider uninitialized with nothing to show; report them instead.
      setOnnxError(error);
    } finally {
      initializingRef.current = false;
    }
  }, [isInitialized, imageData, labelingJob]);

  useEffect(() => {
    if (imageData) void initialize();
  }, [imageData, initialize]);

  // Fall back to the first option when the stored label id no longer matches.
  useEffect(() => {
    if (labelOptions.length > 0 && !selectedLabel) {
      setSelectedLabelId(labelOptions[0].value);
    }
  }, [selectedLabel, setSelectedLabelId, labelOptions]);

  // Snapshot the viewport image and derive the SAM scale once it has loaded.
  useEffect(() => {
    if (!imageLoaded) return;
    try {
      const image = cornerstoneHandler.current?.getImageData();
      if (image) {
        const samScale = getSAMScale(image.width, image.height);
        setImageData({ ...image, samScale });
        setModelScale({
          width: image.width,
          height: image.height,
          onnxScale: samScale,
          maskHeight: image.height,
          maskWidth: image.width,
        });
      }
    } catch (error) {
      setOnnxError(error);
    }
  }, [imageLoaded, cornerstoneHandler]);

  // Run SAM inference whenever the tool data is marked dirty.
  useEffect(() => {
    const run = async () => {
      if (
        !imageData ||
        !model ||
        !embeddingImage ||
        !modelScale ||
        !toolData ||
        !toolData.dirty ||
        !selectedLabel ||
        !cornerstoneHandler.current
      )
        return;

      latestRunIdRef.current += 1;
      const runId = latestRunIdRef.current;
      try {
        setLoading(true);
        const clicks = modelClicks(toolData);
        const feeds = modelData({
          // Clicks are in image pixels; the model expects SAM-scaled coords.
          clicks: clicks.map((click) => ({
            ...click,
            x: click.x * modelScale.onnxScale,
            y: click.y * modelScale.onnxScale,
          })),
          tensor: embeddingImage,
          last_pred_mask: null,
          modelScale,
        });
        if (!feeds) return;
        const results = await model.run(feeds);
        // A newer request started while this one was running; its result wins.
        if (runId !== latestRunIdRef.current) return;
        const { bboxes, polygons } = modelResults(results.masks, selectedLabel);
        setToolData({ ...toolData, bboxes, polygons, dirty: false });
      } catch (error) {
        console.error("Error", error);
      } finally {
        // Only the latest request owns `loading`; clearing it here also covers
        // the `!feeds` early return, which previously left it stuck at true.
        if (runId === latestRunIdRef.current) setLoading(false);
      }
    };

    void run();
  }, [
    modelScale,
    imageData,
    toolData,
    model,
    embeddingImage,
    cornerstoneHandler,
    selectedLabel,
    setToolData,
  ]);

  // Mirror the current tool data into the viewport's SAM annotations.
  useEffect(() => {
    if (selectedLabel && toolData && cornerstoneHandler.current) {
      cornerstoneHandler.current.updateSAMAnnotations(toolData, selectedLabel);
    }
  }, [cornerstoneHandler, toolData, selectedLabel]);

  // Re-tag existing boxes/polygons with the newly selected label and colour.
  useEffect(() => {
    if (!selectedLabel) return;
    const labelId = parseInt(selectedLabel.value, 10);
    const { color } = selectedLabel;
    setToolData((prev) => {
      if (!prev) return prev;
      return {
        ...prev,
        bboxes: (prev.bboxes ?? []).map((box) => ({ ...box, labelId, color })),
        polygons: (prev.polygons ?? []).map((poly) => ({
          ...poly,
          labelId,
          color,
        })),
      };
    });
  }, [selectedLabel, setToolData]);

  function selectLabel(newLabelId: string) {
    setSelectedLabelId(newLabelId);
  }

  /**
   * Shared teardown for close/apply: calls `removeSAMMeasurements([])`
   * (presumably clears every SAM overlay — confirm), redraws the image, and
   * switches back to DRAW mode.
   */
  function resetSAMView() {
    cornerstoneHandler.current?.removeSAMMeasurements([]);
    cornerstoneHandler.current?.updateImage();
    selectMode(SAMMode.DRAW);
  }

  function handleClose() {
    resetSAMView();
  }

  /**
   * Tears down the SAM view, then persists the current SAM output. Bounding-box
   * labels dispatch one annotation per box (top-left/bottom-right points), and
   * polygon labels dispatch one annotation per polygon. Other label types
   * dispatch nothing.
   */
  function handleApply() {
    resetSAMView();

    if (!toolData || !selectedLabel) return;
    selectLabelById(parseInt(selectedLabel.value, 10));

    if (selectedLabel.type === AnnotateType.BOUNDING_BOX && toolData.bboxes) {
      for (const bbox of toolData.bboxes) {
        const payload = {
          jobId,
          labelId: bbox.labelId,
          data: [
            {
              uuid: v4(),
              type: selectedLabel.type,
              points: [
                { x: bbox.data.x, y: bbox.data.y },
                {
                  x: bbox.data.x + bbox.data.width,
                  y: bbox.data.y + bbox.data.height,
                },
              ],
            },
          ],
        };
        dispatch(imageAnnotationCompletedAsync(payload));
      }
    }

    if (selectedLabel.type === AnnotateType.POLYGON && toolData.polygons) {
      for (const poly of toolData.polygons) {
        const payload = {
          jobId,
          labelId: poly.labelId,
          data: [
            {
              uuid: v4(),
              type: selectedLabel.type,
              points: poly.data,
            },
          ],
        };
        dispatch(imageAnnotationCompletedAsync(payload));
      }
    }
  }

  // TODO: not implemented — currently a no-op exposed through the context.
  function handleReRun() {}

  // Keeps the viewport's SAM mode and the persisted `mode` value in sync.
  function selectMode(newMode: SAMMode) {
    cornerstoneHandler.current?.setSAMMode(newMode);
    setMode(newMode);
  }

  useMount(() => selectMode(SAMMode.DRAW));

  const state = {
    loading,
    isInitialized,
    error: onnxError,
    data: toolData,
    selectedLabel,
    labelOptions,
    mode,
    selectLabel,
    selectMode,
    handleApply,
    handleClose,
    handleReRun,
  };

  return (
    <OnnixContext.Provider value={state}>{children}</OnnixContext.Provider>
  );
}
