import type { CurvePoint } from "./types";

/**
 * Builds a 256-entry lookup table from a set of curve control points.
 *
 * - Fewer than two points yields the identity mapping (`lut[i] === i`).
 * - Exactly two points are joined by a straight line.
 * - Three or more points use monotone cubic Hermite interpolation
 *   (Fritsch–Carlson), so each segment stays monotone between its two
 *   control points and never overshoots them.
 *
 * Inputs at or left of the first control point map to its `y`. Inputs at
 * or right of the last control point map to its `y`. Every output is
 * rounded and clamped to [0, 255].
 *
 * The caller's array is not mutated; points are sorted by `x` on a copy,
 * so they may be supplied in any order.
 *
 * @param points Control points in any order.
 * @returns A lookup table indexed by input level 0..255.
 */
export function interpolateSpline(points: CurvePoint[]): Uint8Array {
  if (points.length < 2) return identityLut();

  const sorted = [...points].sort((a, b) => a.x - b.x);
  if (sorted.length === 2) return linearLut(sorted[0], sorted[1]);
  return monotoneCubicLut(sorted);
}

/** Returns the identity lookup table (`lut[i] === i`). */
function identityLut(): Uint8Array {
  const lut = new Uint8Array(256);
  for (let i = 0; i < 256; i++) lut[i] = i;
  return lut;
}

/**
 * Straight line from `p0` to `p1` (with `p0.x <= p1.x`), held flat at each
 * endpoint's `y` outside [p0.x, p1.x].
 */
function linearLut(p0: CurvePoint, p1: CurvePoint): Uint8Array {
  const lut = new Uint8Array(256);
  for (let i = 0; i < 256; i++) {
    if (i <= p0.x) {
      lut[i] = clamp(p0.y);
    } else if (i >= p1.x) {
      lut[i] = clamp(p1.y);
    } else {
      // p0.x < i < p1.x here, so the denominator is non-zero.
      const t = (i - p0.x) / (p1.x - p0.x);
      lut[i] = clamp(p0.y + t * (p1.y - p0.y));
    }
  }
  return lut;
}

/**
 * Evaluates a monotone cubic Hermite spline through `sorted` (3+ points,
 * ascending `x`) at every integer level 0..255.
 */
function monotoneCubicLut(sorted: CurvePoint[]): Uint8Array {
  const n = sorted.length;
  const xs = sorted.map((p) => p.x);
  const ys = sorted.map((p) => p.y);

  // Segment widths and secant slopes. A zero-width segment (duplicate x)
  // gets slope 0 instead of dividing by zero.
  const dx: number[] = [];
  const ms: number[] = [];
  for (let i = 0; i < n - 1; i++) {
    dx[i] = xs[i + 1] - xs[i];
    const dy = ys[i + 1] - ys[i];
    ms[i] = dx[i] === 0 ? 0 : dy / dx[i];
  }
  const m = fritschCarlsonTangents(ms);

  const lut = new Uint8Array(256);
  let seg = 0;
  for (let x = 0; x < 256; x++) {
    if (x <= xs[0]) {
      lut[x] = clamp(ys[0]);
      continue;
    }
    if (x >= xs[n - 1]) {
      lut[x] = clamp(ys[n - 1]);
      continue;
    }

    // x only increases, so the containing segment index never decreases.
    // Advance a cursor instead of rescanning from segment 0 for every x.
    // Zero-width segments are skipped because xs[seg + 1] <= x for them.
    // The loop stops by seg = n - 2 at the latest, since x < xs[n - 1] here.
    while (x >= xs[seg + 1]) seg++;

    const h = dx[seg];
    const t = (x - xs[seg]) / h;
    const t2 = t * t;
    const t3 = t2 * t;

    // Cubic Hermite basis functions.
    const h00 = 2 * t3 - 3 * t2 + 1;
    const h10 = t3 - 2 * t2 + t;
    const h01 = -2 * t3 + 3 * t2;
    const h11 = t3 - t2;

    const val = h00 * ys[seg] + h10 * h * m[seg] + h01 * ys[seg + 1] + h11 * h * m[seg + 1];
    lut[x] = clamp(val);
  }
  return lut;
}

/**
 * Computes one Hermite tangent per control point from the `ms.length`
 * secant slopes, using the Fritsch–Carlson monotonicity constraints.
 * Returns an array of length `ms.length + 1`.
 */
function fritschCarlsonTangents(ms: number[]): number[] {
  const n = ms.length + 1;
  const m: number[] = [];

  // Initial tangents: one-sided secants at the ends. Interior points get the
  // average of neighbouring secants, or 0 at a local extremum / flat neighbour.
  m[0] = ms[0];
  for (let i = 1; i < n - 1; i++) {
    m[i] = ms[i - 1] * ms[i] <= 0 ? 0 : (ms[i - 1] + ms[i]) / 2;
  }
  m[n - 1] = ms[n - 2];

  // Limit tangents so each segment's interpolant stays monotone.
  for (let i = 0; i < n - 1; i++) {
    if (ms[i] === 0) {
      // Flat (or zero-width) segment: both end tangents must be zero.
      m[i] = 0;
      m[i + 1] = 0;
      continue;
    }
    const alpha = m[i] / ms[i];
    const beta = m[i + 1] / ms[i];
    const tau = alpha * alpha + beta * beta;
    if (tau > 9) {
      // Scale (alpha, beta) back onto the circle of radius 3, which is a
      // sufficient condition for monotonicity.
      const s = 3 / Math.sqrt(tau);
      m[i] = s * alpha * ms[i];
      m[i + 1] = s * beta * ms[i];
    }
  }
  return m;
}

/**
 * Rounds `v` to the nearest integer (halves round toward +Infinity, as with
 * `Math.round`) and limits the result to the byte range [0, 255].
 * A NaN input is returned as NaN.
 */
function clamp(v: number): number {
  const rounded = Math.round(v);
  // `<=` (not `<`) so that -0 comes back as +0, matching Math.max(0, -0).
  if (rounded <= 0) return 0;
  if (rounded >= 255) return 255;
  return rounded;
}

/**
 * Reports whether `points` is exactly the two-point diagonal
 * {(0, 0), (255, 255)}, in either order.
 *
 * Order is ignored because `interpolateSpline` sorts its input by `x`.
 * Both orderings therefore produce the same identity table.
 *
 * Any other point count returns `false`, even if every point lies on the
 * diagonal.
 *
 * NOTE(review): `interpolateSpline` also yields an identity table for fewer
 * than two points, but this returns `false` for them. Confirm that is
 * intended.
 *
 * @param points Control points in any order.
 * @returns `true` only for the default two-point diagonal.
 */
export function isIdentityCurve(points: CurvePoint[]): boolean {
  if (points.length !== 2) return false;
  const [lo, hi] = points[0].x <= points[1].x ? [points[0], points[1]] : [points[1], points[0]];
  return lo.x === 0 && lo.y === 0 && hi.x === 255 && hi.y === 255;
}
