import { useTexture } from "@react-three/drei";
import { useFrame } from "@react-three/fiber";
import { forwardRef, useCallback, useMemo, useRef } from "react";
import { MathUtils, MeshStandardMaterial, ShaderMaterial } from "three";
import CustomShaderMaterial from "three-custom-shader-material";

/** Props accepted by the MirrorMaterial component. */
type MaterialProps = {
  /** Source passed to drei's `useTexture`; the loaded texture becomes the material's `map`. */
  texture: string;
  /** When truthy, sets `uPower` to 1 so the animated emissive stripes show; omitted means off. */
  shiny?: boolean;
};

/**
 * Metallic "mirror" material: a MeshStandardMaterial (metalness 0.9,
 * roughness 0.1) extended through three-custom-shader-material.
 *
 * - On mount, the diffuse colour fades in from plain white to the textured
 *   colour: `uAlphaMultiplier` ramps from 0 to 1 at one unit per second.
 * - When `shiny` is truthy, animated diagonal emissive stripes are added.
 *   Their tint cycles red -> white -> lavender with `sin(uTime)`.
 *
 * The forwarded ref receives the underlying material instance.
 */
const MirrorMaterial = forwardRef<MeshStandardMaterial, MaterialProps>(
  ({ texture, shiny }, ref) => {
    const map = useTexture(texture);
    const materialRef = useRef<ShaderMaterial>(null);

    // The uniforms object must keep a stable identity across renders.
    // CustomShaderMaterial rebuilds the material when the `uniforms` prop
    // changes, and an inline literal reset uTime / the fade-in on every
    // re-render. Values that change over time (including uPower) are written
    // in useFrame instead.
    const uniforms = useMemo(
      () => ({
        uTime: { value: 0 },
        uPower: { value: 0 },
        uAlphaMultiplier: { value: 0 },
      }),
      [],
    );

    // A stable callback ref. An inline function is a new ref on every
    // render, which makes React detach (null) and re-attach the forwarded
    // ref each time.
    const combinedRef = useCallback(
      (instance: MeshStandardMaterial | null) => {
        if (typeof ref === "function") {
          ref(instance);
        } else if (ref) {
          ref.current = instance; // Forward ref to the parent
        }
        // @ts-expect-error - the CSM instance carries ShaderMaterial-style `uniforms`
        materialRef.current = instance; // Internal ref for local use
      },
      [ref],
    );

    // Per-frame uniform updates, written straight to the material.
    useFrame((_, delta) => {
      const material = materialRef.current;
      if (!material) return;
      const u = material.uniforms;
      u.uTime.value += delta;
      // Fade-in factor: rises by `delta` each frame and is clamped at 1.
      u.uAlphaMultiplier.value = MathUtils.clamp(
        u.uAlphaMultiplier.value + delta,
        0,
        1,
      );
      // Applied here rather than via the memoized uniforms, so toggling
      // `shiny` does not rebuild the material.
      u.uPower.value = shiny ? 1 : 0;
    });

    return (
      <CustomShaderMaterial
        baseMaterial={MeshStandardMaterial}
        ref={combinedRef}
        map={map}
        metalness={0.9}
        roughness={0.1}
        color="white"
        uniforms={uniforms}
        silent
        vertexShader={`
          #define M_PI 3.1415926535897932384626433832795
          uniform float uAlphaMultiplier;
          uniform float uTime;
          varying vec3 vPosition;
          void main() {
              //csm_Position = position + normalize(position) * 0.2 * (1.0 - (sin(uTime * M_PI / 2.0) + 1.0) / 2.0);
              vPosition = position;
          }
        `}
        fragmentShader={`
          uniform float uAlphaMultiplier;
          uniform float uTime;
          uniform float uPower;
          varying vec3 vPosition;
          void main() {
              float angle = radians(45.0);
              float rotatedY = vPosition.y * cos(angle) - vPosition.x * sin(angle);

              // Create stripes using the rotated y coordinate
              float stripes = mod((rotatedY - uTime * 0.04) * 25.0, 1.0);
              stripes = pow(stripes, 3.0);

              float stripePower = (stripes * uPower);

              // #E63223
              vec3 color1 = vec3(0.9, 0.2, 0.1);

              vec3 color3 = vec3(1.0, 1.0, 1.0);

              // #D2B4FF
              vec3 color2 = vec3(0.82, 0.71, 1.0);

              float t = (sin(uTime) + 1.0) / 2.0; // Value between 0 and 1
              vec3 mixColor = mix(
                mix(color1, color3, smoothstep(0.0, 0.5, t)), // First mix (color1 -> white)
                mix(color3, color2, smoothstep(0.5, 1.0, t)), // Second mix (white -> color2)
                step(0.5, t));

              csm_Emissive = mixColor * stripePower;
              csm_DiffuseColor = mix(vec4(1.0), csm_DiffuseColor, uAlphaMultiplier);
          }
        `}
      />
    );
  },
);

// Explicit name for React DevTools and warnings. The forwardRef render
// function is an anonymous arrow, so it carries no name of its own.
MirrorMaterial.displayName = "MirrorMaterial";

export { MirrorMaterial as default };
