import type { Profile } from './ParticleExport';
import { CurveDefinition } from './curve-tools';
import { createStreamParticleMaterial } from './StreamParticleMaterial';
import { sampleCircleProfile, sampleRectProfile } from './sample-utils';
import { THREE } from './v3d';

/** Number of linear segments the position curve is split into by `initializeSpline`. */
const SUB_DIVISIONS = 200;

// Scratch vectors reused by `StreamParticleSystem.update` so its per-particle
// loop does not allocate; their contents are overwritten on every use.
const vec = new THREE.Vector3(),
  vec2 = new THREE.Vector3(),
  vec3 = new THREE.Vector3(),
  vec4 = new THREE.Vector3();

/**
 * The position curve pre-sampled into straight segments. The arrays are read in
 * parallel by `StreamParticleSystem.update`: index `i` of each array describes
 * the same sample.
 */
interface LinearSegments {
  /** Sampled positions on the curve. */
  points: Array<THREE.Vector3>;
  /** Frenet-frame tangent at each sample. */
  tangents: Array<THREE.Vector3>;
  /** Frenet-frame normal at each sample (first cross-section axis). */
  normals: Array<THREE.Vector3>;
  /** Frenet-frame binormal at each sample (second cross-section axis). */
  binormals: Array<THREE.Vector3>;
  /** Data-curve values per sample; `update` scales the cross-section offset by `.x`. */
  metaData: Array<THREE.Vector3>;
  /** Curve length covered by a single segment. */
  subLength: number;
  /** Total length of the sampled curve. */
  totalLength: number;
}

/** Tuning options for a particle system, as consumed by `StreamParticleSystem`. */
export interface ParticleSystemConfiguration {
  /** Number of particles to spawn. */
  numParticles: number;
  /** Global multiplier applied to each particle's speed along the curve. */
  velocity?: number;
  /**
   * Speed multiplier reached by particles whose sampled `slip` is 1; particles
   * with `slip` 0 keep a multiplier of 1 (linear interpolation in between).
   */
  noSlip?: number;
  /** Length at the start of the curve over which particles fade in. */
  fadeStart?: number;
  /** Length at the end of the curve over which particles fade out. */
  fadeEnd?: number;
  /** Material to render with; a default stream-particle material is built when omitted. */
  material?: THREE.Material;
  /** Particle color value (string or numeric form). */
  color?: string | number;
}

/** Constructor options for `StreamParticleSystem`: tuning options plus the geometry inputs. */
export interface StreamParticleSystemProps extends ParticleSystemConfiguration {
  /** Supplies the `positionCurve` particles follow and the `dataCurve` sampled alongside it. */
  curveDefinition: CurveDefinition;
  /** Cross-section (`rect` or `circle`) that particle start positions are sampled from. */
  profile: Profile;
}
/**
 * Point-cloud particle stream that flows along a curve.
 *
 * Construction samples the position curve into `SUB_DIVISIONS` segments
 * (points, Frenet frames, per-sample metadata) and draws one start sample per
 * particle from the cross-section `profile`. Each `update()` call then moves
 * every particle along the sampled curve, offsets it inside the local
 * normal/binormal plane and recomputes its fade-in/fade-out alpha.
 */
export default class StreamParticleSystem {
  /** The renderable `THREE.Points` object; add this to the scene graph. */
  object3d: THREE.Object3D;

  /** Global multiplier applied to every particle's speed along the curve. */
  velocity: number;

  private readonly particleCount: number;

  private readonly geometry: THREE.BufferGeometry;

  private readonly material: THREE.Material;

  /** Texture loaded for the built-in material; undefined when the caller supplied a material. */
  private defaultTexture?: THREE.Texture;

  /** Total length of the sampled curve. */
  private totalLength!: number;

  /** Flat `[x, y, z, ...]` list of the per-particle samples drawn from the profile. */
  private initialPositions!: number[];

  /** Per-particle speed multipliers (random variation times slip factor). */
  private speedMods!: number[];

  private posAttribute!: THREE.Float32BufferAttribute;

  /** Length at the start of the curve over which particles fade in. */
  private readonly fadeStart: number;

  /** Length at the end of the curve over which particles fade out. */
  private readonly fadeEnd: number;

  /** Speed multiplier for particles whose sampled `slip` is 1 (1 when `slip` is 0). */
  private readonly noSlip: number;

  private alphaAttribute!: THREE.Float32BufferAttribute;

  private linearSegments!: LinearSegments;

  /**
   * @param props Configuration; `numParticles` defaults to 1, `velocity` to 1,
   *   `fadeStart`/`fadeEnd` to 0, `noSlip` to 1 and `color` (used only for the
   *   built-in material) to `0xcc608a`.
   * @throws Error when `profile` has neither a `rect` nor a `circle` entry.
   */
  constructor({
    numParticles = 1,
    velocity = 1,
    fadeStart = 0,
    fadeEnd = 0,
    noSlip = 1,
    color = 0xcc608a,
    profile,
    material,
    curveDefinition: curveDef,
  }: StreamParticleSystemProps) {
    this.particleCount = numParticles;
    this.geometry = new THREE.BufferGeometry();

    if (material != null) {
      this.material = material;
    } else {
      // Only load the sprite texture when the built-in material is actually needed.
      this.defaultTexture = new THREE.TextureLoader().load('./particle2.png');
      this.material = createStreamParticleMaterial({
        color: new THREE.Color(color),
        size: 0.05,
        map: this.defaultTexture,
      });
    }

    this.velocity = velocity;
    this.fadeStart = fadeStart;
    this.fadeEnd = fadeEnd;
    this.noSlip = noSlip;

    // The spline must be sampled first: profile samplers receive this.totalLength.
    this.initializeSpline(curveDef);
    this.initializeGeometry(profile);

    this.object3d = new THREE.Points(this.geometry, this.material);

    /* Debug Curve
      const nurbsGeometry = new THREE.BufferGeometry();
      nurbsGeometry.setFromPoints( curveDef.positionCurve.getPoints( 200 ) );
      const nurbsMaterial = new THREE.LineBasicMaterial( { color: 0x333333 } );
      const nurbsLine = new THREE.Line( nurbsGeometry, nurbsMaterial );
      this.object3d.children.push(nurbsLine)
    */
  }

  /**
   * Draws one start sample per particle from `profile` and (re)creates the
   * `position` and `alpha` buffer attributes.
   *
   * A sample's `z` is the particle's initial distance along the curve and
   * `x`/`y` its offset in the curve's normal/binormal plane (see `update`).
   * Requires `initializeSpline` to have run, since samplers receive
   * `this.totalLength`.
   *
   * @throws Error when `profile` has neither a `rect` nor a `circle` entry.
   */
  initializeGeometry(profile: Profile) {
    const vertices: number[] = [];
    const alphas: number[] = [];
    const speedMods: number[] = [];

    let sampleFunc: () => { x: number; y: number; z: number; slip: number };
    if (profile.rect) {
      const { rect } = profile;
      sampleFunc = () => sampleRectProfile(rect, this.totalLength);
    } else if (profile.circle) {
      const { circle } = profile;
      sampleFunc = () => sampleCircleProfile(circle, this.totalLength);
    } else throw new Error('Invalid profile');

    for (let i = 0; i < this.particleCount; i++) {
      const { x, y, z, slip } = sampleFunc();
      vertices.push(x, y, z);
      alphas.push(1.0);

      // Random speed in [0.75, 1.25), further scaled towards `noSlip` by `slip`.
      const speedVariation = 0.5 * THREE.MathUtils.seededRandom() + 0.75;
      const slipVariation = THREE.MathUtils.lerp(1, this.noSlip, slip);

      speedMods.push(speedVariation * slipVariation);
    }

    this.initialPositions = vertices;
    this.speedMods = speedMods;
    this.posAttribute = new THREE.Float32BufferAttribute(vertices, 3).setUsage(
      THREE.DynamicDrawUsage
    );
    this.alphaAttribute = new THREE.Float32BufferAttribute(alphas, 1).setUsage(
      THREE.DynamicDrawUsage
    );
    this.geometry.setAttribute('position', this.posAttribute);
    this.geometry.setAttribute('alpha', this.alphaAttribute);
  }

  /**
   * Samples `curveDefinition.positionCurve` into `SUB_DIVISIONS` segments and
   * stores points, Frenet frames and per-sample `dataCurve` values in
   * `this.linearSegments`; also sets `this.totalLength`.
   */
  initializeSpline(curveDefinition: CurveDefinition) {
    const curve = curveDefinition.positionCurve;
    const { dataCurve } = curveDefinition;

    // Rebuild the curve's cached arc lengths before sampling by length.
    curve.updateArcLengths();

    const points = curve.getSpacedPoints(SUB_DIVISIONS);
    const frames = curve.computeFrenetFrames(SUB_DIVISIONS);
    const lengths = curve.getLengths(SUB_DIVISIONS);
    const totalLength = lengths[lengths.length - 1];

    // NOTE(review): `points` come from getSpacedPoints (spaced by arc length),
    // while metadata is sampled at uniform curve parameter t = i / SUB_DIVISIONS.
    // These coincide only if the position curve has uniform speed — confirm
    // which parameterization `dataCurve` is meant to follow.
    const metaSegments = points.map((_point, i) =>
      dataCurve.getPoint(i / SUB_DIVISIONS)
    );

    this.totalLength = totalLength;
    this.linearSegments = {
      points,
      normals: frames.normals,
      tangents: frames.tangents,
      binormals: frames.binormals,
      metaData: metaSegments,
      subLength: totalLength / SUB_DIVISIONS,
      totalLength,
    };
  }

  /**
   * Moves every particle to its position at `time` and refreshes its alpha.
   *
   * The system is hidden, and the update skipped, unless the grandparent object
   * is visible.
   *
   * @param _delta Unused frame delta.
   * @param time Elapsed time; a particle's distance along the curve is
   *   `z + time * speedMod * velocity`, wrapped to the curve length.
   */
  update(_delta: number, time: number) {
    // TODO: quite ugly visibility hack
    const parentGroupVisible = this.object3d.parent?.parent?.visible;
    this.object3d.visible = !!parentGroupVisible;

    if (!parentGroupVisible) return;

    // NOTE(review): re-seeds THREE's global seeded RNG every frame although
    // nothing below draws from it. Kept because other code sharing
    // THREE.MathUtils.seededRandom may observe the reset.
    THREE.MathUtils.seededRandom(2);

    const { points, normals, binormals, metaData, subLength, totalLength } =
      this.linearSegments;
    const lastIndex = points.length - 1;

    for (let i = 0; i < this.particleCount; i++) {
      const x = this.initialPositions[3 * i];
      const y = this.initialPositions[3 * i + 1];
      const z = this.initialPositions[3 * i + 2];

      // Wrap the travelled distance into [0, totalLength]. JS `%` keeps the
      // dividend's sign, so a negative offset (negative velocity, time or z)
      // must be shifted back into range or the sample indices go negative.
      let splinePos =
        (z + time * this.speedMods[i] * this.velocity) % totalLength;
      if (splinePos < 0) splinePos += totalLength;

      const indexFloat = splinePos / subLength;
      const whole = Math.floor(indexFloat);
      const frac = indexFloat - whole;
      // Clamp to the last sample: rounding near the end of the curve could
      // otherwise push a neighbour index past the sampled arrays.
      const index0 = Math.min(whole, lastIndex);
      const index1 = Math.min(Math.ceil(indexFloat), lastIndex);

      const p = vec.lerpVectors(points[index0], points[index1], frac);
      const n = vec2.lerpVectors(normals[index0], normals[index1], frac);
      const b = vec3.lerpVectors(binormals[index0], binormals[index1], frac);
      const meta = vec4.lerpVectors(metaData[index0], metaData[index1], frac);

      // Offset within the cross-section plane, scaled by the metadata's x value.
      p.addScaledVector(n, x * meta.x).addScaledVector(b, y * meta.x);
      this.posAttribute.setXYZ(i, p.x, p.y, p.z);

      // Fade in over the first `fadeStart` units and out over the last `fadeEnd`.
      const fadeIn = THREE.MathUtils.smoothstep(splinePos, 0, this.fadeStart);
      const fadeOut =
        1 -
        THREE.MathUtils.smoothstep(
          splinePos,
          totalLength - this.fadeEnd,
          totalLength
        );
      this.alphaAttribute.setX(i, fadeIn * fadeOut);
    }

    this.posAttribute.needsUpdate = true;
    this.alphaAttribute.needsUpdate = true;
  }

  /**
   * Frees the GPU resources this instance created: the buffer geometry and,
   * when no material was passed to the constructor, the built-in material and
   * its texture. A caller-supplied material is left for the caller to dispose.
   */
  dispose() {
    this.geometry.dispose();
    if (this.defaultTexture !== undefined) {
      this.material.dispose();
      this.defaultTexture.dispose();
    }
  }
}
