LIGHT_CULLING_WGSL

Constant LIGHT_CULLING_WGSL 

Source
/// Light culling compute shader (WGSL) for Forward+ rendering.
///
/// One 16x16 workgroup processes one screen tile:
///
/// 1. Four threads build the tile's side planes (left, right, top, bottom) in
///    view space by un-projecting the tile corners with `inverse_projection`.
/// 2. All 256 threads test the lights in a strided loop; directional lights
///    (`light_type == 0`) are always accepted, other lights are accepted when
///    their bounding sphere is not fully outside any side plane.
/// 3. Up to `MAX_LIGHTS_PER_TILE` (128) accepted indices are written to
///    `light_index_list`, and `(offset, count)` is written to `light_grid`.
///
/// Bindings (group 0): 0 = uniforms, 1 = lights (read-only storage),
/// 2 = `light_index_list`, 3 = `light_grid`.
///
/// Notes for maintainers:
///
/// * The shader has no depth-texture binding, so it performs no per-tile
///   depth-range culling; only the four side planes are tested.
/// * Side planes are explicitly oriented toward the tile centre, so the test
///   does not depend on the handedness of the view space.
/// * Workgroups dispatched outside `tile_count` exit without writing.
pub const LIGHT_CULLING_WGSL: &str = r#"// Light Culling Compute Shader for Forward+ Rendering
//
// This shader performs tile-based light culling by:
// 1. Computing the frustum for each screen tile
// 2. Testing each light against the tile frustum
// 3. Building a per-tile light index list
//
// Workgroup size: 16x16 threads = 256 threads per tile
// Each workgroup processes one screen tile

// --- Uniforms and Buffers ---

struct LightCullingUniforms {
    view_projection: mat4x4<f32>,
    inverse_projection: mat4x4<f32>,
    screen_dimensions: vec2<f32>,
    tile_count: vec2<u32>,
    num_lights: u32,
    tile_size: u32,
    _padding: vec2<f32>,
};

@group(0) @binding(0)
var<uniform> uniforms: LightCullingUniforms;

// GPU-friendly light structure (64 bytes)
struct GpuLight {
    position: vec3<f32>,
    range: f32,
    color: vec3<f32>,
    intensity: f32,
    direction: vec3<f32>,
    light_type: u32,      // 0 = directional, 1 = point, 2 = spot
    inner_cone_cos: f32,
    outer_cone_cos: f32,
    _padding: vec2<f32>,
};

@group(0) @binding(1)
var<storage, read> lights: array<GpuLight>;

// Output: Per-tile light indices
// Format: For tile at (x, y), offset = (y * tiles_x + x) * max_lights_per_tile
@group(0) @binding(2)
var<storage, read_write> light_index_list: array<u32>;

// Output: Per-tile (offset, count) pairs
// Format: light_grid[tile_index * 2] = offset, light_grid[tile_index * 2 + 1] = count
@group(0) @binding(3)
var<storage, read_write> light_grid: array<atomic<u32>>;

// --- Workgroup Shared Memory ---

const TILE_SIZE: u32 = 16u;
const MAX_LIGHTS_PER_TILE: u32 = 128u;
// Must equal the product of the @workgroup_size dimensions of cs_main.
const THREADS_PER_TILE: u32 = 256u;

var<workgroup> shared_light_count: atomic<u32>;
var<workgroup> shared_light_indices: array<u32, MAX_LIGHTS_PER_TILE>;
var<workgroup> tile_frustum_planes: array<vec4<f32>, 4>;  // Left, Right, Top, Bottom

// --- Helper Functions ---

// Convert an NDC position to a view-space position
fn ndc_to_view(ndc: vec3<f32>) -> vec3<f32> {
    let clip = vec4<f32>(ndc, 1.0);
    let view = uniforms.inverse_projection * clip;
    return view.xyz / view.w;
}

// Create a plane (xyz = unit normal, w = offset) through 3 points.
// The normal direction follows cross(p1 - p0, p2 - p0).
fn create_plane(p0: vec3<f32>, p1: vec3<f32>, p2: vec3<f32>) -> vec4<f32> {
    let v0 = p1 - p0;
    let v1 = p2 - p0;
    let normal = normalize(cross(v0, v1));
    let d = -dot(normal, p0);
    return vec4<f32>(normal, d);
}

// Signed distance from a point to a plane (positive on the normal side)
fn sphere_plane_distance(center: vec3<f32>, plane: vec4<f32>) -> f32 {
    return dot(center, plane.xyz) + plane.w;
}

// Return the plane flipped, if necessary, so that the point inside lies on
// its positive side. This makes the normals point into the tile frustum
// regardless of corner winding or view-space handedness.
fn orient_plane_towards(plane: vec4<f32>, inside: vec3<f32>) -> vec4<f32> {
    if (sphere_plane_distance(inside, plane) < 0.0) {
        return -plane;
    }
    return plane;
}

// Test if a light sphere intersects the tile frustum (side planes only).
// No depth-range test is done: this shader has no depth input to derive
// per-tile depth bounds from.
fn light_intersects_tile(light_pos_view: vec3<f32>, light_range: f32) -> bool {
    for (var i = 0u; i < 4u; i++) {
        let signed_dist = sphere_plane_distance(light_pos_view, tile_frustum_planes[i]);
        if (signed_dist < -light_range) {
            return false;  // Completely outside this plane
        }
    }
    return true;
}

// --- Main Compute Shader ---

@compute @workgroup_size(16, 16, 1)
fn cs_main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(local_invocation_index) local_index: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    let tile_x = workgroup_id.x;
    let tile_y = workgroup_id.y;

    // Skip workgroups dispatched past the tile grid; their tile_index would
    // alias another tile's output slots. The condition only depends on
    // workgroup_id and uniform data, so the whole workgroup exits together
    // and the barriers below are never reached by a partial workgroup.
    if (tile_x >= uniforms.tile_count.x || tile_y >= uniforms.tile_count.y) {
        return;
    }

    let tile_index = tile_y * uniforms.tile_count.x + tile_x;

    // Initialize shared memory (first thread only)
    if (local_index == 0u) {
        atomicStore(&shared_light_count, 0u);
    }

    // Construct tile frustum planes (first 4 threads, one plane each)
    if (local_index < 4u) {
        // Calculate tile corners in NDC (-1 to 1)
        let tile_min_ndc = vec2<f32>(
            f32(tile_x * TILE_SIZE) / uniforms.screen_dimensions.x * 2.0 - 1.0,
            f32(tile_y * TILE_SIZE) / uniforms.screen_dimensions.y * 2.0 - 1.0
        );
        let tile_max_ndc = vec2<f32>(
            f32((tile_x + 1u) * TILE_SIZE) / uniforms.screen_dimensions.x * 2.0 - 1.0,
            f32((tile_y + 1u) * TILE_SIZE) / uniforms.screen_dimensions.y * 2.0 - 1.0
        );

        // Un-project the tile corners at NDC depth 1.0. The side planes pass
        // through the view-space origin, so only the corner directions matter.
        let tl = ndc_to_view(vec3<f32>(tile_min_ndc.x, tile_max_ndc.y, 1.0));
        let tr = ndc_to_view(vec3<f32>(tile_max_ndc.x, tile_max_ndc.y, 1.0));
        let bl = ndc_to_view(vec3<f32>(tile_min_ndc.x, tile_min_ndc.y, 1.0));
        let br = ndc_to_view(vec3<f32>(tile_max_ndc.x, tile_min_ndc.y, 1.0));
        let origin = vec3<f32>(0.0, 0.0, 0.0);

        // A point strictly inside the tile frustum, used to orient the planes.
        let tile_center = 0.25 * (tl + tr + bl + br);

        // Create frustum planes (normals pointing inward)
        switch (local_index) {
            case 0u: {  // Left
                tile_frustum_planes[0] = orient_plane_towards(create_plane(origin, tl, bl), tile_center);
            }
            case 1u: {  // Right
                tile_frustum_planes[1] = orient_plane_towards(create_plane(origin, br, tr), tile_center);
            }
            case 2u: {  // Top
                tile_frustum_planes[2] = orient_plane_towards(create_plane(origin, tr, tl), tile_center);
            }
            case 3u: {  // Bottom
                tile_frustum_planes[3] = orient_plane_towards(create_plane(origin, bl, br), tile_center);
            }
            default: {}
        }
    }

    workgroupBarrier();

    // --- Light Culling ---
    // Each thread processes multiple lights in a strided pattern
    let lights_per_thread = (uniforms.num_lights + THREADS_PER_TILE - 1u) / THREADS_PER_TILE;

    for (var i = 0u; i < lights_per_thread; i++) {
        let light_index = local_index + i * THREADS_PER_TILE;

        if (light_index >= uniforms.num_lights) {
            break;
        }

        let light = lights[light_index];

        // Directional lights always affect all tiles
        var affects_tile = true;

        // Point and spot lights need frustum test
        if (light.light_type != 0u) {
            // World -> clip -> view. The homogeneous clip position is
            // un-projected directly, without an intermediate divide by clip w,
            // which is zero for a light lying in the camera plane.
            let light_clip_pos = uniforms.view_projection * vec4<f32>(light.position, 1.0);
            let light_view_h = uniforms.inverse_projection * light_clip_pos;
            let light_view_pos = light_view_h.xyz / light_view_h.w;

            affects_tile = light_intersects_tile(light_view_pos, light.range);
        }

        if (affects_tile) {
            // Slots past the capacity are counted but dropped; the count is
            // clamped when it is published below.
            let slot = atomicAdd(&shared_light_count, 1u);
            if (slot < MAX_LIGHTS_PER_TILE) {
                shared_light_indices[slot] = light_index;
            }
        }
    }

    workgroupBarrier();

    // --- Write Results ---
    let count = min(atomicLoad(&shared_light_count), MAX_LIGHTS_PER_TILE);
    let base_offset = tile_index * MAX_LIGHTS_PER_TILE;

    // First thread writes the light grid entry
    if (local_index == 0u) {
        atomicStore(&light_grid[tile_index * 2u], base_offset);
        atomicStore(&light_grid[tile_index * 2u + 1u], count);
    }

    // All threads write their portion of light indices
    for (var i = local_index; i < count; i += THREADS_PER_TILE) {
        light_index_list[base_offset + i] = shared_light_indices[i];
    }
}
"#;
Expand description

Light culling compute shader for Forward+ rendering.

Performs tile-based light culling by:

  • Computing frustum planes for each 16x16 pixel tile
  • Testing each light against the tile frustum
  • Building a per-tile light index list

Outputs:

  • light_index_list: Indices of lights affecting each tile
  • light_grid: (offset, count) pairs per tile