import { Tensor } from "onnxruntime-web";
import {
  OnnxLabel,
  clickType,
  modeDataProps,
  modelInputProps,
} from "./Interfaces";
import { SAMBBox, SAMData } from "../context/onnix-state";
import { traceOnnxMaskToPolygons } from "./maskUtils";
import { Point } from "utilities/math/point";

/**
 * Converts a mask tensor into the polygon and bounding box reported for a
 * single label.
 *
 * The mask is traced with `traceOnnxMaskToPolygons`, which receives
 * `mask.data`, `mask.dims[3]` and `mask.dims[2]` (presumably the mask width
 * and height — confirm against the tracer's signature). Only the first traced
 * polygon is reported, so that is the only one that gets copied and measured.
 *
 * @param mask - Mask tensor to trace.
 * @param label - Label whose `value` (parsed as a base-10 integer) and `color`
 *   are attached to every result.
 * @returns `bboxes` and `polygons`, each holding at most one entry; both are
 *   empty when the tracer returns no polygon.
 */
const modelResults = (mask: Tensor, label: OnnxLabel) => {
  const polygons = traceOnnxMaskToPolygons(
    mask.data,
    mask.dims[3],
    mask.dims[2]
  );

  // Explicit radix: the label id is always read as a decimal number.
  const labelId = parseInt(label.value, 10);

  // Only the first polygon is ever returned, so skip the rest up front
  // instead of converting and bounding polygons that would be discarded.
  const samPolygons = polygons.slice(0, 1).map((points) => {
    return {
      labelId,
      color: label.color,
      data: points.map((p) => {
        return {
          x: p.x,
          y: p.y,
        };
      }),
    };
  });

  const samBBoxes: SAMBBox[] = samPolygons.map((poly) => {
    return {
      ...poly,
      data: poly2Bbox(poly.data),
    };
  });

  return {
    bboxes: samBBoxes,
    polygons: samPolygons,
  };
};

/**
 * Computes the axis-aligned bounding box of a list of points.
 *
 * The extremes are folded in a single pass rather than spreading every
 * coordinate into `Math.min` / `Math.max`, so very large polygons cannot hit
 * the engine's argument-count limit.
 *
 * @param points - Points to enclose; the array is not modified.
 * @returns The box as `{ x, y, width, height }`, where `x` / `y` are the
 *   smallest coordinates. For an empty array the seeds are returned
 *   unchanged: `x` and `y` are `Infinity`, `width` and `height` are
 *   `-Infinity`.
 */
const poly2Bbox = (points: readonly Point[]) => {
  let minX = Infinity;
  let minY = Infinity;
  let maxX = -Infinity;
  let maxY = -Infinity;

  for (const { x, y } of points) {
    // Math.min / Math.max keep NaN propagation identical to a variadic call.
    minX = Math.min(minX, x);
    minY = Math.min(minY, y);
    maxX = Math.max(maxX, x);
    maxY = Math.max(maxY, y);
  }

  return {
    x: minX,
    y: minY,
    width: maxX - minX,
    height: maxY - minY,
  };
};

/**
 * Flattens the tool state into the ordered list of model clicks.
 *
 * The list always starts with the two bounding-box corners (upper-left, then
 * bottom-right), followed by every include point as a positive click and then
 * every exclude point as a negative click. `width` and `height` are `null` on
 * every entry.
 *
 * @param toolData - Tool state providing `bbox`, `includePoints` and
 *   `excludePoints`.
 * @returns A newly created array of clicks in the order described above.
 */
const modelClicks = (toolData: SAMData) => {
  // Every click has the same shape; only the position and the kind vary.
  const toClick = (
    x: modelInputProps["x"],
    y: modelInputProps["y"],
    kind: modelInputProps["clickType"]
  ): modelInputProps => ({
    x,
    y,
    width: null,
    height: null,
    clickType: kind,
  });

  const box = toolData.bbox;
  const result: Array<modelInputProps> = [
    toClick(box.x, box.y, clickType.UPPER_LEFT),
    toClick(box.x + box.width, box.y + box.height, clickType.BOTTOM_RIGHT),
  ];

  for (const { x, y } of toolData.includePoints) {
    result.push(toClick(x, y, clickType.POSITIVE));
  }
  for (const { x, y } of toolData.excludePoints) {
    result.push(toClick(x, y, clickType.NEGATIVE));
  }

  return result;
};

/**
 * Assembles the named input tensors for one decoder run.
 *
 * @param clicks - Prompt clicks; each contributes an `(x, y)` pair to
 *   `point_coords` and its `clickType` to `point_labels`.
 * @param tensor - Passed through unchanged as `image_embeddings`.
 * @param modelScale - Supplies `maskHeight` and `maskWidth` for
 *   `orig_im_size`, in that order.
 * @returns The feeds object, or `undefined` when `clicks` is missing.
 */
const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
  // Without clicks there is no prompt to encode, so nothing is built.
  if (!clicks) return undefined;

  const n = clicks.length;

  // One slot more than the number of clicks: the last entry is a padding
  // point at (0, 0).
  const pointCoords = new Float32Array(2 * (n + 1));
  const pointLabels = new Float32Array(n + 1);

  for (let i = 0; i < n; i++) {
    pointCoords[2 * i] = clicks[i].x;
    pointCoords[2 * i + 1] = clicks[i].y;
    pointLabels[i] = clicks[i].clickType;
  }

  // Typed arrays start zero-filled, so the padding coordinates are already
  // (0, 0). The label, however, must be set: SAM's ONNX interface uses -1 for
  // a padding point, whereas the default 0 would be read as a negative click
  // at the image origin (assuming clickType follows SAM's label encoding).
  pointLabels[n] = -1;

  const pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]);
  const pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]);

  const imageSizeTensor = new Tensor("float32", [
    modelScale.maskHeight,
    modelScale.maskWidth,
  ]);

  // No previous mask is supplied: mask_input is all zeros and
  // has_mask_input is 0.
  const maskInput = new Tensor(
    "float32",
    new Float32Array(256 * 256),
    [1, 1, 256, 256]
  );
  const hasMaskInput = new Tensor("float32", [0]);

  return {
    image_embeddings: tensor,
    point_coords: pointCoordsTensor,
    point_labels: pointLabelsTensor,
    orig_im_size: imageSizeTensor,
    mask_input: maskInput,
    has_mask_input: hasMaskInput,
  };
};

export { modelData, modelClicks, modelResults };
