/// WGSL source of the tile-based light culling compute shader (Forward+).
///
/// Entry point is `cs_main` with `@workgroup_size(16, 16, 1)`; each workgroup
/// derives its tile from `workgroup_id.xy`, so one workgroup handles one
/// 16x16 pixel tile.
///
/// Bindings (all in group 0):
/// - binding 0: `LightCullingUniforms` uniform buffer
/// - binding 1: read-only storage array of `GpuLight`
/// - binding 2: `light_index_list`, written at
///   `tile_index * MAX_LIGHTS_PER_TILE + i`
/// - binding 3: `light_grid`, two `u32` per tile: `[offset, count]`
///
/// At most `MAX_LIGHTS_PER_TILE` (128) indices are stored per tile; lights
/// beyond that are dropped and `count` is clamped.
///
/// The shader has no depth input, so point/spot lights are culled against the
/// four side planes of the tile frustum only. Lights with `light_type == 0`
/// (directional) are added to every tile without testing.
pub const LIGHT_CULLING_WGSL: &str = r#"// Light Culling Compute Shader for Forward+ Rendering
//
// This shader performs tile-based light culling by:
// 1. Computing the side planes of the view-space frustum for each screen tile
// 2. Testing each light's bounding sphere against those planes
// 3. Building a per-tile light index list
//
// Workgroup size: 16x16 threads = 256 threads per tile
// Each workgroup processes one screen tile

// --- Uniforms and Buffers ---

struct LightCullingUniforms {
    view_projection: mat4x4<f32>,
    inverse_projection: mat4x4<f32>,
    screen_dimensions: vec2<f32>,
    tile_count: vec2<u32>,
    num_lights: u32,
    tile_size: u32,
    _padding: vec2<f32>,
};

@group(0) @binding(0)
var<uniform> uniforms: LightCullingUniforms;

// GPU-friendly light structure (64 bytes)
struct GpuLight {
    position: vec3<f32>,
    range: f32,
    color: vec3<f32>,
    intensity: f32,
    direction: vec3<f32>,
    light_type: u32, // 0 = directional, 1 = point, 2 = spot
    inner_cone_cos: f32,
    outer_cone_cos: f32,
    _padding: vec2<f32>,
};

@group(0) @binding(1)
var<storage, read> lights: array<GpuLight>;

// Output: Per-tile light indices
// Format: For tile at (x, y), offset = (y * tiles_x + x) * max_lights_per_tile
@group(0) @binding(2)
var<storage, read_write> light_index_list: array<u32>;

// Output: Per-tile (offset, count) pairs
// Format: light_grid[tile_index * 2] = offset, light_grid[tile_index * 2 + 1] = count
@group(0) @binding(3)
var<storage, read_write> light_grid: array<atomic<u32>>;

// --- Workgroup Shared Memory ---

const TILE_SIZE: u32 = 16u;
const MAX_LIGHTS_PER_TILE: u32 = 128u;

var<workgroup> shared_light_count: atomic<u32>;
var<workgroup> shared_light_indices: array<u32, MAX_LIGHTS_PER_TILE>;
var<workgroup> tile_frustum_planes: array<vec4<f32>, 4>; // Left, Right, Top, Bottom

// --- Helper Functions ---

// Unproject an NDC position to view space using the inverse projection matrix
fn ndc_to_view(ndc: vec3<f32>) -> vec3<f32> {
    let clip = vec4<f32>(ndc, 1.0);
    let view = uniforms.inverse_projection * clip;
    return view.xyz / view.w;
}

// Build the plane through 3 points as (unit normal, d), normal = cross(p1 - p0, p2 - p0)
fn create_plane(p0: vec3<f32>, p1: vec3<f32>, p2: vec3<f32>) -> vec4<f32> {
    let v0 = p1 - p0;
    let v1 = p2 - p0;
    let normal = normalize(cross(v0, v1));
    let d = -dot(normal, p0);
    return vec4<f32>(normal, d);
}

// Flip `plane` if necessary so that `inside_point` lies on its positive side.
// This keeps the normals pointing into the tile frustum regardless of the
// handedness or y direction encoded in the projection matrix.
fn orient_plane_toward(plane: vec4<f32>, inside_point: vec3<f32>) -> vec4<f32> {
    if (dot(plane.xyz, inside_point) + plane.w < 0.0) {
        return -plane;
    }
    return plane;
}

// Signed distance from a point to a plane (positive on the normal's side)
fn sphere_plane_distance(center: vec3<f32>, plane: vec4<f32>) -> f32 {
    return dot(center, plane.xyz) + plane.w;
}

// Test if a light sphere intersects the tile frustum.
// Only the 4 side planes are tested: no per-tile depth bounds are available
// to this shader, so no depth rejection is performed.
fn light_intersects_tile(light_pos_view: vec3<f32>, light_range: f32) -> bool {
    for (var i = 0u; i < 4u; i++) {
        let distance = sphere_plane_distance(light_pos_view, tile_frustum_planes[i]);
        if (distance < -light_range) {
            return false; // Sphere lies completely on the outer side of this plane
        }
    }

    return true;
}

// --- Main Compute Shader ---

@compute @workgroup_size(16, 16, 1)
fn cs_main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(local_invocation_index) local_index: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    let tile_x = workgroup_id.x;
    let tile_y = workgroup_id.y;
    let tile_index = tile_y * uniforms.tile_count.x + tile_x;

    // Reset the shared counter (first thread only)
    if (local_index == 0u) {
        atomicStore(&shared_light_count, 0u);
    }

    // Construct tile frustum planes (first 4 threads, one plane each)
    if (local_index < 4u) {
        // Tile corners in NDC (-1 to 1)
        let tile_min_ndc = vec2<f32>(
            f32(tile_x * TILE_SIZE) / uniforms.screen_dimensions.x * 2.0 - 1.0,
            f32(tile_y * TILE_SIZE) / uniforms.screen_dimensions.y * 2.0 - 1.0
        );
        let tile_max_ndc = vec2<f32>(
            f32((tile_x + 1u) * TILE_SIZE) / uniforms.screen_dimensions.x * 2.0 - 1.0,
            f32((tile_y + 1u) * TILE_SIZE) / uniforms.screen_dimensions.y * 2.0 - 1.0
        );

        // Unproject the corners on the NDC z = 1.0 slice
        let tl = ndc_to_view(vec3<f32>(tile_min_ndc.x, tile_max_ndc.y, 1.0));
        let tr = ndc_to_view(vec3<f32>(tile_max_ndc.x, tile_max_ndc.y, 1.0));
        let bl = ndc_to_view(vec3<f32>(tile_min_ndc.x, tile_min_ndc.y, 1.0));
        let br = ndc_to_view(vec3<f32>(tile_max_ndc.x, tile_min_ndc.y, 1.0));
        let origin = vec3<f32>(0.0, 0.0, 0.0);

        // Tile centre on the same slice: a point known to be inside the tile
        // frustum, used to make every plane normal point inward.
        let inside = ndc_to_view(vec3<f32>((tile_min_ndc + tile_max_ndc) * 0.5, 1.0));

        switch (local_index) {
            case 0u: { tile_frustum_planes[0] = orient_plane_toward(create_plane(origin, tl, bl), inside); } // Left
            case 1u: { tile_frustum_planes[1] = orient_plane_toward(create_plane(origin, br, tr), inside); } // Right
            case 2u: { tile_frustum_planes[2] = orient_plane_toward(create_plane(origin, tr, tl), inside); } // Top
            case 3u: { tile_frustum_planes[3] = orient_plane_toward(create_plane(origin, bl, br), inside); } // Bottom
            default: {}
        }
    }

    workgroupBarrier();

    // --- Light Culling ---
    // Thread t tests lights t, t + 256, t + 512, ...
    let lights_per_thread = (uniforms.num_lights + 255u) / 256u;

    for (var i = 0u; i < lights_per_thread; i++) {
        let light_index = local_index + i * 256u;

        if (light_index >= uniforms.num_lights) {
            break;
        }

        let light = lights[light_index];

        // Directional lights always affect all tiles
        var affects_tile = light.light_type == 0u;

        // Point and spot lights need the frustum test
        if (light.light_type != 0u) {
            // World -> clip -> view. The homogeneous clip position is
            // unprojected directly, without dividing by clip-space w first,
            // so a light on the camera plane (clip w == 0) does not produce
            // NaN/Inf coordinates.
            let light_clip_pos = uniforms.view_projection * vec4<f32>(light.position, 1.0);
            let light_view_h = uniforms.inverse_projection * light_clip_pos;
            let light_view_pos = light_view_h.xyz / light_view_h.w;

            affects_tile = light_intersects_tile(light_view_pos, light.range);
        }

        if (affects_tile) {
            // The counter may run past the capacity; only in-range slots are stored
            let slot = atomicAdd(&shared_light_count, 1u);
            if (slot < MAX_LIGHTS_PER_TILE) {
                shared_light_indices[slot] = light_index;
            }
        }
    }

    workgroupBarrier();

    // --- Write Results ---
    // Shared memory is read-only from here on, so no further barrier is needed.
    let count = min(atomicLoad(&shared_light_count), MAX_LIGHTS_PER_TILE);
    let base_offset = tile_index * MAX_LIGHTS_PER_TILE;

    // First thread writes the (offset, count) grid entry for this tile
    if (local_index == 0u) {
        atomicStore(&light_grid[tile_index * 2u], base_offset);
        atomicStore(&light_grid[tile_index * 2u + 1u], count);
    }

    // All threads write their portion of the light indices
    for (var i = local_index; i < count; i += 256u) {
        light_index_list[base_offset + i] = shared_light_indices[i];
    }
}
"#;
Light culling compute shader for Forward+ rendering.
Performs tile-based light culling by:
- Computing frustum planes for each 16x16 pixel tile
- Testing each light against the tile frustum
- Building a per-tile light index list
Outputs:
- light_index_list: Indices of lights affecting each tile
- light_grid: (offset, count) pairs per tile