// - updatePhysics: updates debris particles
// - renderCS: raymarches the scene into a UAV output (swapchain backbuffer)

// Length of the ballPos/ballCol arrays in GameState and the bound of every
// per-ball loop in this file.
#define MAX_BALLS 12
// Not referenced in the visible part of this file; presumably the capacity of
// the brick buffer on the host side -- TODO confirm.
#define MAX_BRICKS 1024

// Per-frame constants, exposed to the shaders through the GameStateCB cbuffer.
// Layout: 20 scalar floats (80 bytes) followed by 12 + 12 + 3 float4 entries
// (432 bytes), i.e. 512 bytes in total. Whoever fills this buffer must use the
// same field order.
struct GameState {
    float paddleX;        // x coordinate of the paddle centre
    float time;           // feeds hash12 for spawn randomness and per-pixel noise
    float resW;           // not read in the visible shader code
    float resH;           // not read in the visible shader code
    float lightInt;       // scales the global light and part of the ambient term
    float debugMode;      // not read in the visible shader code
    float paddleVel;      // added (scaled) to debris x velocity on paddle bounce
    float spawnTrigger;   // > 0.5 requests a spawn; decoded as colourType * 100 + (1-based block of 8 particles)
    float spawnX;         // x centre of the spawn burst
    float spawnY;         // y centre of the spawn burst
    float maxRaySteps;    // iteration limit of the primary raymarch loop
    float debrisCount;    // number of debris slots simulated and raymarched
    float brickCount;     // number of entries read from the bricks buffer
    float ballBrightness; // multiplier for ball light, emission and glow
    float autoplay;       // not read in the visible shader code
    float brickWidth;     // brick box half-extent along x (passed to sdBox)
    float brickHeight;    // brick box half-extent along y (passed to sdBox)
    float padding1;       // padding1..3 bring the scalar section to 20 floats (five 16-byte rows)
    float padding2;
    float padding3;
    float4 ballPos[MAX_BALLS]; // xyz = ball centre; z > 0 also marks the slot as active; w not read here
    float4 ballCol[MAX_BALLS]; // rgb = ball light/emission colour; a not read here
    float4 padTail[3]; // pads the struct from 464 to 512 bytes
};

// One debris chunk, stored in the `debris` buffer.
struct Particle {
    float4 pos;   // xyz = position; w > 0.5 marks the particle as alive
    float4 vel;   // xyz = velocity per update; w = angular velocity added to color.a
    float4 color; // rgb = albedo; a = rotation angle used when raymarching the chunk
};

// One brick, stored in the `bricks` buffer.
struct Brick {
    float4 pos; // xy = brick centre; z = colour type index; w > 0.5 marks the brick as present
};

// Per-frame game constants (slot b0).
cbuffer GameStateCB : register(b0) { GameState state; }

// Read-only brick list (slot t0); the first state.brickCount entries are scanned.
StructuredBuffer<Brick> bricks : register(t0);
// Render target written by renderCS (slot u0).
RWTexture2D<float4> outputTex : register(u0);
// Debris particles (slot u1): written by updatePhysics, read by the raymarcher.
RWStructuredBuffer<Particle> debris : register(u1);

// Hashes a 2D coordinate to a pseudo-random scalar in [0, 1).
float hash12(float2 p) {
    float3 seed = frac(p.xyx * 0.1031);
    const float fold = dot(seed, seed.yzx + 33.33);
    seed = seed + fold;
    return frac(seed.z * (seed.x + seed.y));
}

// Signed distance from p to an origin-centred, axis-aligned box with
// half-extents b (negative inside the box).
float sdBox(float3 p, float3 b) {
    const float3 delta = abs(p) - b;
    // Distance to the surface when p is outside on at least one axis.
    const float outside = length(max(delta, 0.0));
    // Negative depth when p is inside on every axis, otherwise zero.
    const float inside = min(max(delta.x, max(delta.y, delta.z)), 0.0);
    return outside + inside;
}

// Rotates the 2D vector p by the angle a (radians).
float2 rot2(float2 p, float a) {
    const float cosA = cos(a);
    const float sinA = sin(a);
    float2 rotated;
    rotated.x = p.x * cosA - p.y * sinA;
    rotated.y = p.x * sinA + p.y * cosA;
    return rotated;
}

// Scalar step: 1.0 when x has reached the edge a, otherwise 0.0.
float stepf(float a, float x) {
    if (x >= a) {
        return 1.0;
    }
    return 0.0;
}

// Distance field of every opaque object (floor, bricks, side walls, debris and
// optionally the paddle); balls are not part of it.
//
// p             - sample position
// includePaddle - when false the paddle is left out of the field
// isBallShadow  - when true, debris chunks lying within sqrt(0.52) units (in xy)
//                 of any active ball are skipped
// closestMat    - out: material id of the nearest object
//                 (0 floor, 1 paddle, 3 brick, 4 wall, 5 debris)
// matAuxY/Z     - out: material data: floor -> p.x / p.z, brick -> index / type,
//                 debris -> index / 0
// Returns the distance to the nearest object.
float mapOpaque(in float3 p, bool includePaddle, bool isBallShadow, out float closestMat, out float matAuxY, out float matAuxZ) {
    float nearest = 1000.0;
    closestMat = 0.0;
    matAuxY = 0.0;
    matAuxZ = 0.0;

    // Floor plane at y = -15.
    const float floorDist = p.y + 15.0;
    if (floorDist < nearest) {
        nearest = floorDist;
        closestMat = 0.0;
        matAuxY = p.x;
        matAuxZ = p.z;
    }

    // Bricks: slightly rounded boxes lying in the z = 0 plane.
    const float3 brickHalf = float3(state.brickWidth, state.brickHeight, 0.6);
    const uint brickTotal = (uint)state.brickCount;
    [loop]
    for (uint brickIdx = 0; brickIdx < brickTotal; ++brickIdx) {
        Brick brick = bricks[brickIdx];
        if (!(brick.pos.w > 0.5)) { continue; }

        const float3 local = float3(p.x - brick.pos.x, p.y - brick.pos.y, p.z);
        const float brickDist = sdBox(local, brickHalf) - 0.05;
        if (brickDist < nearest) {
            nearest = brickDist;
            closestMat = 3.0;
            matAuxY = (float)brickIdx;
            matAuxZ = brick.pos.z;
        }
    }

    // Left and right walls at x = -21 and x = +21.
    const float3 wallHalf = float3(1.0, 20.0, 5.0);
    const float leftWall = sdBox(p - float3(-21.0, 0.0, 0.0), wallHalf);
    const float rightWall = sdBox(p - float3(21.0, 0.0, 0.0), wallHalf);
    const float wallDist = min(leftWall, rightWall);
    if (wallDist < nearest) {
        nearest = wallDist;
        closestMat = 4.0;
    }

    // Debris: small rounded cubes, rotated by the per-particle angle in color.a.
    const uint debrisTotal = (uint)state.debrisCount;
    [loop]
    for (uint chunk = 0; chunk < debrisTotal; ++chunk) {
        Particle piece = debris[chunk];
        if (!(piece.pos.w > 0.5)) { continue; }

        if (isBallShadow) {
            // Chunks sitting right next to a ball are ignored for ball shadows.
            bool nearBall = false;
            [unroll]
            for (uint ball = 0; ball < MAX_BALLS; ++ball) {
                if (state.ballPos[ball].z > 0.0) {
                    const float offX = piece.pos.x - state.ballPos[ball].x;
                    const float offY = piece.pos.y - state.ballPos[ball].y;
                    if ((offX * offX + offY * offY) < 0.52) { nearBall = true; }
                }
            }
            if (nearBall) { continue; }
        }

        // The same angle is applied twice: first in the xz plane, then in xy.
        const float spin = piece.color.a;
        float3 local = p - piece.pos.xyz;
        local.xz = rot2(local.xz, spin);
        local.xy = rot2(local.xy, spin);

        const float chunkDist = sdBox(local, float3(0.25, 0.25, 0.25)) - 0.02;
        if (chunkDist < nearest) {
            nearest = chunkDist;
            closestMat = 5.0;
            matAuxY = (float)chunk;
            matAuxZ = 0.0;
        }
    }

    // Paddle: rounded box centred at (paddleX, -12, 0).
    if (includePaddle) {
        const float3 paddleCentre = float3(state.paddleX, -12.0, 0.0);
        const float paddleDist = sdBox(p - paddleCentre, float3(3.0, 0.6, 2.5)) - 0.1;
        if (paddleDist < nearest) {
            nearest = paddleDist;
            closestMat = 1.0;
        }
    }

    return nearest;
}

// Full scene distance field: everything from mapOpaque (paddle included) plus
// the active balls, which are spheres of radius 0.7.
// Returns float4(distance, material id, aux y, aux z); balls report material 2
// with the ball index in aux y.
float4 mapScene(in float3 p) {
    float matId;
    float auxA;
    float auxB;
    float nearest = mapOpaque(p, true, false, matId, auxA, auxB);

    [unroll]
    for (uint ball = 0; ball < MAX_BALLS; ++ball) {
        // A ball slot is active when its z coordinate is positive.
        if (state.ballPos[ball].z > 0.0) {
            const float ballDist = length(p - state.ballPos[ball].xyz) - 0.7;
            if (ballDist < nearest) {
                nearest = ballDist;
                matId = 2.0;
                auxA = (float)ball;
                auxB = 0.0;
            }
        }
    }

    return float4(nearest, matId, auxA, auxB);
}

// Estimates the surface normal at p from forward differences of mapScene,
// using an offset of 0.001 along each axis.
float3 calcNormal(float3 p) {
    const float h = 0.001;
    const float centre = mapScene(p).x;
    const float gradX = mapScene(p + float3(h, 0.0, 0.0)).x - centre;
    const float gradY = mapScene(p + float3(0.0, h, 0.0)).x - centre;
    const float gradZ = mapScene(p + float3(0.0, 0.0, h)).x - centre;
    return normalize(float3(gradX, gradY, gradZ));
}

// Analytic ray/box test against the paddle's bounding box (half-extents
// 3.1 x 0.7 x 2.6, centred at (paddleX, -12, 0)).
// Returns true when the ray ro + t * rd crosses the box, the exit point lies
// in front of the origin and the entry point is closer than maxDist.
bool hitPaddle(float3 ro, float3 rd, float maxDist) {
    const float3 centre = float3(state.paddleX, -12.0, 0.0);
    const float3 halfSize = float3(3.1, 0.7, 2.6);

    const float3 invDir = 1.0 / rd;
    const float3 scaledOrigin = invDir * (ro - centre);
    const float3 scaledExtent = abs(invDir) * halfSize;

    // Per-axis slab entry and exit distances.
    const float3 tEnter = -scaledOrigin - scaledExtent;
    const float3 tExit = -scaledOrigin + scaledExtent;
    const float tNear = max(max(tEnter.x, tEnter.y), tEnter.z);
    const float tFar = min(min(tExit.x, tExit.y), tExit.z);

    if (!(tNear < tFar)) { return false; }
    if (!(tFar > 0.0)) { return false; }
    return tNear < maxDist;
}

// Marches a shadow ray through the opaque distance field.
//
// ro, rd        - ray origin and direction
// k             - penumbra sharpness (larger = harder shadow)
// maxDist       - marching stops once the travelled distance exceeds this
// noise         - jitter in [0, 1) added to the start offset
// includePaddle - forwarded to mapOpaque
// isBallShadow  - forwarded to mapOpaque
// Returns float2(visibility clamped to [0, 1], number of steps taken).
float2 softShadowSteps(float3 ro, float3 rd, float k, float maxDist, float noise, bool includePaddle, bool isBallShadow) {
    float visibility = 1.0;
    float travelled = 0.05 + (noise * 0.02);
    float stepCount = 0.0;

    [loop]
    for (uint iter = 0; iter < 160; ++iter) {
        float matId;
        float auxA;
        float auxB;
        const float dist = mapOpaque(ro + rd * travelled, includePaddle, isBallShadow, matId, auxA, auxB);
        stepCount += 1.0;

        visibility = min(visibility, k * dist / travelled);
        // Step size is bounded on both sides to keep the march stable.
        travelled += clamp(dist, 0.02, 0.5);

        const bool occluded = visibility < 0.001;
        const bool outOfRange = travelled > maxDist;
        if (occluded || outOfRange) { break; }
    }

    return float2(clamp(visibility, 0.0, 1.0), stepCount);
}

// Entry point: spawns and simulates debris particles, one thread per particle.
//
// Spawning: when state.spawnTrigger > 0.5 its integer value is decoded as
// colourType * 100 + slot, where slot is a 1-based index of a block of eight
// consecutive particles. Threads inside that block re-initialise their
// particle around (spawnX, spawnY) and return without simulating it.
//
// Simulation: every other live particle (pos.w >= 0.5) gets gravity, velocity
// integration, damping, and bounces off the side limits (x = +/-19.5), the
// floor (y = -14.5) and the paddle.
[numthreads(64, 1, 1)]
void updatePhysics(uint3 id : SV_DispatchThreadID) {
    uint idx = id.x;
    // Hard upper bound on the particle index, then the active particle count.
    if (idx >= 8192) return;
    if (idx >= (uint)state.debrisCount) return;

    if (state.spawnTrigger > 0.5) {
        uint rawVal = (uint)state.spawnTrigger;
        int cType = (int)(rawVal / 100);
        uint slot = rawVal % 100;

        // slot is 1-based. A slot of 0 carries no spawn request; checking it
        // explicitly avoids relying on unsigned wrap-around of (slot - 1).
        if (slot > 0u) {
            uint blockStart = (slot - 1u) * 8u;

            if (idx >= blockStart && idx < blockStart + 8u) {
                float rnd = hash12(float2((float)idx, state.time));
                float rnd2 = hash12(float2(state.time, (float)idx));

                // Colour palette indexed by the brick type that was destroyed.
                float3 tint;
                if (cType == 0) { tint = float3(1.0, 0.1, 0.1); }
                else if (cType == 1) { tint = float3(1.0, 0.5, 0.0); }
                else if (cType == 2) { tint = float3(0.1, 1.0, 0.1); }
                else if (cType == 3) { tint = float3(0.0, 0.5, 1.0); }
                else if (cType == 4) { tint = float3(0.7, 0.0, 1.0); }
                else if (cType == 5) { tint = float3(0.9, 0.9, 0.9); }
                else { tint = float3(0.4, 0.4, 0.4); }

                Particle fresh;
                // pos.w = 1 marks the particle as alive.
                fresh.pos = float4(state.spawnX + (rnd - 0.5) * 2.0, state.spawnY + (rnd2 - 0.5) * 1.0, 0.0, 1.0);
                // vel.w is the angular velocity.
                fresh.vel = float4((rnd - 0.5) * 0.4, (rnd2) * 0.3 + 0.1, (hash12(float2(rnd, rnd2)) - 0.5) * 0.4, (rnd - 0.5) * 0.8);
                // color.a holds the rotation angle. It is written together with
                // the colour so the random start angle is not overwritten by a
                // later whole-vector colour assignment.
                fresh.color = float4(tint, rnd * 6.28);
                debris[idx] = fresh;
                return;
            }
        }
    }

    Particle p = debris[idx];
    if (p.pos.w < 0.5) return; // dead slot

    // Gravity, integration, rotation and damping (vel.w is damped as well).
    p.vel.y -= 0.015;
    p.pos.x += p.vel.x; p.pos.y += p.vel.y; p.pos.z += p.vel.z;
    p.color.a += p.vel.w;
    p.vel *= 0.98;

    // Side limits: clamp and reflect with energy loss.
    if (p.pos.x > 19.5) { p.pos.x = 19.5; p.vel.x *= -0.6; }
    if (p.pos.x < -19.5) { p.pos.x = -19.5; p.vel.x *= -0.6; }

    // Floor: bounce with friction; very small bounces are stopped.
    if (p.pos.y < -14.5) {
        p.pos.y = -14.5; p.vel.y = -p.vel.y * 0.4;
        p.vel.x *= 0.8; p.vel.z *= 0.8; p.vel.w *= 0.8;
        if (abs(p.vel.y) < 0.05) { p.vel.y = 0.0; }
    }

    // Paddle: falling particles inside the paddle footprint bounce off its top
    // and pick up part of the paddle's horizontal velocity.
    float padTop = -11.4;
    float padL = state.paddleX - 3.2;
    float padR = state.paddleX + 3.2;
    if (p.pos.y < padTop + 0.5 && p.pos.y > padTop - 1.0) {
        if (p.pos.x > padL && p.pos.x < padR && abs(p.pos.z) < 2.5) {
            if (p.vel.y < 0.0) {
                p.pos.y = padTop + 0.5; p.vel.y = -p.vel.y * 0.3;
                p.vel.x += state.paddleVel * 0.3; p.vel.w += (p.vel.x * 0.2);
            }
        }
    }
    debris[idx] = p;
}

// Entry point: raymarches the scene for one pixel and writes the shaded,
// tonemapped colour to outputTex.
//
// Pipeline per pixel:
//   1. Primary ray from a fixed camera at (0, 0, 48), sphere-traced through
//      mapScene for at most state.maxRaySteps steps or 100 units.
//   2. On a hit: global light with soft shadow, one point light per active
//      ball (with its own soft shadow), then material-specific albedo/emission.
//      Material ids: 0 floor, 1 paddle, 2 ball, 3 brick, 4 wall, 5 debris.
//   3. Additive ball glow, tonemapping, gamma and vignette.
//
// Loop counters carry unique names because legacy HLSL compilers place a
// for-loop variable in the enclosing scope, where reusing one name with
// different types triggers a redeclaration diagnostic.
[numthreads(8, 8, 1)]
void renderCS(uint3 id : SV_DispatchThreadID) {
    uint2 dims;
    outputTex.GetDimensions(dims.x, dims.y);
    if (id.x >= dims.x || id.y >= dims.y) return;

    // Centred pixel coordinates, normalised by the image height.
    float2 uv = (float2(id.xy) - float2(dims) * 0.5) / (float)dims.y;
    // Per-pixel, per-time noise used to jitter the shadow ray start.
    float noiseVal = hash12(float2((float)id.x, (float)id.y) + float2(state.time, state.time));
    float3 ro = float3(0.0, 0.0, 48.0);
    float3 rd = normalize(float3(uv.x, -uv.y, -1.2));

    // Primary ray march.
    float t = 0.0;
    bool hit = false;
    float4 mData = float4(0.0, 0.0, 0.0, 0.0);
    // Step counter (primary + shadow rays); it is accumulated but does not
    // influence the colour written below.
    float totalSteps = 0.0;
    int limit = (int)state.maxRaySteps;
    [loop]
    for (int marchIdx = 0; marchIdx < limit; ++marchIdx) {
        totalSteps += 1.0;
        float4 res = mapScene(ro + rd * t);
        if (res.x < 0.001) { hit = true; mData = res; break; }
        t += res.x;
        if (t > 100.0) { break; }
    }

    float3 col = float3(0.01, 0.01, 0.02); // background colour
    if (hit) {
        float3 p = ro + rd * t;
        float3 n = calcNormal(p);
        int mat = (int)mData.y;

        // True when at least one active ball is above y = -10.8. The flag
        // decides whether the paddle takes part in shadow rays and which
        // paddle look is used.
        bool shadowCasterActive = false;
        [unroll]
        for (uint casterIdx = 0; casterIdx < MAX_BALLS; ++casterIdx) {
            if (state.ballPos[casterIdx].z > 0.0 && state.ballPos[casterIdx].y > -10.8) { shadowCasterActive = true; }
        }

        // Global light.
        float3 lightPos = float3(10.0, 30.0, 30.0);
        float3 l = normalize(lightPos - p);
        float diff = max(dot(n, l), 0.0);
        // Offset along the normal so shadow rays do not start inside the surface.
        float3 shadowOrigin = p + n * 0.05;
        float2 shaRes1 = softShadowSteps(shadowOrigin, l, 8.0, 50.0, noiseVal, shadowCasterActive, false);
        totalSteps += shaRes1.y;
        float3 globalLight = float3(1.0, 1.0, 1.0) * diff * shaRes1.x * state.lightInt;

        // Point lights from the active balls (balls do not light themselves).
        float3 activeBallLight = float3(0.0, 0.0, 0.0);
        if (mat != 2) {
            [unroll]
            for (uint lightIdx = 0; lightIdx < MAX_BALLS; ++lightIdx) {
                if (state.ballPos[lightIdx].z > 0.0) {
                    float3 ballP = state.ballPos[lightIdx].xyz;
                    float3 vecToBall = ballP - p;
                    float distToBall = length(vecToBall);
                    float3 lBall = normalize(vecToBall);
                    float att = 1.0 / (1.0 + distToBall * distToBall * 0.2);
                    float diffBall = max(dot(n, lBall), 0.0);
                    float2 shaRes2 = softShadowSteps(shadowOrigin, lBall, 8.0, distToBall - 0.5, noiseVal, shadowCasterActive, true);
                    totalSteps += shaRes2.y;
                    float3 lightColor = state.ballCol[lightIdx].rgb * state.ballBrightness;
                    // Light that passes through the paddle's bounds takes a fixed blue tint.
                    if (hitPaddle(shadowOrigin, lBall, distToBall - 0.5)) {
                        lightColor = float3(0.2, 0.8, 2.0) * 0.8 * state.ballBrightness;
                    }
                    activeBallLight += lightColor * diffBall * att * shaRes2.x * 2.0;
                }
            }
        }

        // Material response.
        float3 albedo = float3(0.5, 0.5, 0.5);
        float3 em = float3(0.0, 0.0, 0.0);
        if (mat == 0) {
            // Floor: dark base with a cyan grid whose lines widen with distance.
            float thick = 0.02 + (t * 0.0015);
            float gridTrig = 1.0 - thick;
            float grid = stepf(gridTrig, frac(p.x * 0.5)) + stepf(gridTrig, frac(p.z * 0.5));
            albedo = float3(0.05, 0.05, 0.05) + float3(0.2, 0.2, 0.2) * grid * float3(0.0, 1.0, 1.0);
        } else if (mat == 1) {
            // Paddle: strong glow when no ball casts shadows, otherwise a
            // dimmer emission plus a view-dependent rim term.
            albedo = float3(0.1, 0.6, 0.8);
            if (!shadowCasterActive) {
                em = float3(0.1, 0.7, 1.0) * 3.0;
            } else {
                em = albedo * 0.5;
                // saturate keeps the base inside [0, 1]: rounding can push
                // dot(n, -rd) slightly above 1, and pow of a negative base is NaN.
                float fresnel = pow(saturate(1.0 - dot(n, -rd)), 3.0);
                em += float3(0.5, 0.8, 1.0) * fresnel * 1.0;
            }
        } else if (mat == 2) {
            // Ball: purely emissive, coloured by its own slot.
            int bIdx = (int)mData.z;
            albedo = float3(0.0, 0.0, 0.0);
            em = state.ballCol[bIdx].rgb * 2.0 * state.ballBrightness;
        } else if (mat == 3) {
            // Brick: palette lookup by brick type, slightly self-lit.
            int bType = (int)mData.w;
            if (bType == 0) { albedo = float3(1.0, 0.1, 0.1); }
            else if (bType == 1) { albedo = float3(1.0, 0.5, 0.0); }
            else if (bType == 2) { albedo = float3(0.1, 1.0, 0.1); }
            else if (bType == 3) { albedo = float3(0.0, 0.5, 1.0); }
            else if (bType == 4) { albedo = float3(0.7, 0.0, 1.0); }
            else if (bType == 5) { albedo = float3(0.9, 0.9, 0.9); }
            else { albedo = float3(0.4, 0.4, 0.4); }
            em = albedo * 0.3;
        } else if (mat == 4) {
            // Wall: dark surface with emissive horizontal strips every 2 units.
            albedo = float3(0.1, 0.1, 0.1);
            float repeatY = frac(p.y * 0.5);
            float thickness = 0.05 + (t * 0.002);
            float strip = stepf(1.0 - thickness, repeatY);
            em = float3(0.0, 0.5, 1.0) * strip * 2.0;
        } else if (mat == 5) {
            // Debris: colour stored on the particle itself.
            uint chunkIdx = (uint)mData.z;
            Particle part = debris[chunkIdx];
            albedo = part.color.rgb;
            em = albedo * 0.2;
        }

        float3 ambient = (float3(0.05, 0.05, 0.05) + state.lightInt * 0.05) * albedo;
        col = ambient + (albedo * globalLight) + activeBallLight + em;
    }

    // Additive glow around each active ball, based on the ray's distance to
    // the ball centre; skipped when the hit surface is in front of the ball.
    [unroll]
    for (uint glowIdx = 0; glowIdx < MAX_BALLS; ++glowIdx) {
        if (state.ballPos[glowIdx].z > 0.0) {
            float3 ballP = state.ballPos[glowIdx].xyz;
            float3 rayToBall = cross(rd, ballP - ro);
            float distToRay = length(rayToBall);
            float3 glow = state.ballCol[glowIdx].rgb * (0.3 / (0.05 + distToRay * distToRay)) * state.ballBrightness;
            if (!hit || t > length(ballP - ro)) { col += glow; }
        }
    }

    // Rational tonemapping curve, clamped to [0, 1].
    float a = 2.51; float b = 0.03; float c = 2.43; float d = 0.59; float e = 0.14;
    col = clamp((col * (a * col + b)) / (col * (c * col + d) + e), 0.0, 1.0);
    // Gamma encode, then darken towards the image corners.
    col = pow(col, 1.0 / 2.2);
    float dist = length(uv);
    col *= 1.0 - dist * 0.3;
    outputTex[id.xy] = float4(col, 1.0);
}
