khora_lanes/render_lane/
forward_plus_lane.rs

1// Copyright 2025 eraflo
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Forward+ (Tiled Forward) rendering lane implementation.
16//!
17//! This module implements a Forward+ rendering strategy that uses a compute shader
18//! to perform per-tile light culling before the main render pass. This approach
19//! significantly reduces the number of lights processed per fragment, making it
20//! ideal for scenes with many lights (>20).
21//!
22//! # Architecture
23//!
24//! The Forward+ pipeline works in two stages:
25//!
26//! 1. **Light Culling (Compute Pass)**: The screen is divided into tiles (16x16 pixels).
27//!    For each tile, the compute shader determines which lights intersect the tile's
28//!    frustum and builds a list of affecting light indices.
29//!
30//! 2. **Rendering (Render Pass)**: Each fragment looks up its tile's light list and
31//!    only evaluates lighting for those specific lights, rather than all lights in the scene.
32//!
33//! # Performance Characteristics
34//!
35//! - **O(tiles × lights)** for light culling (compute pass)
36//! - **O(fragments × lights_per_tile)** for shading (render pass)
37//! - **Suitable for**: Scenes with many lights (>20)
38//! - **Break-even point**: ~20 lights (vs standard forward rendering)
39//!
40//! # SAA Compliance (Symbiotic Adaptive Architecture)
41//!
42//! This lane integrates with GORNA through:
43//! - `estimate_cost()`: Provides accurate cost estimation including compute overhead
44//! - Configurable tile size and max lights per tile
45//! - Runtime-adjustable configuration via `ForwardPlusTileConfig`
46
47use super::RenderWorld;
48use crate::render_lane::{RenderLane, ShaderComplexity};
49
50use khora_core::{
51    asset::{AssetUUID, Material},
52    renderer::{
53        api::{
54            buffer::BufferId,
55            command::{
56                ComputePassDescriptor, LoadOp, Operations, RenderPassColorAttachment,
57                RenderPassDescriptor, StoreOp,
58            },
59            PrimitiveTopology,
60        },
61        traits::CommandEncoder,
62        BindGroupId, ComputePipelineId, ForwardPlusTileConfig, GpuMesh, RenderContext,
63        RenderPipelineId,
64    },
65};
66use khora_data::assets::Assets;
67use std::sync::RwLock;
68
// --- Cost Estimation Constants ---
//
// Weights combined by `ForwardPlusLane::compute_pass_cost` and
// `ForwardPlusLane::estimate_cost` into a single `f32` cost figure.
// NOTE(review): the values look hand-tuned rather than measured — confirm
// before treating the resulting cost as an absolute quantity.

/// Base cost per triangle rendered.
///
/// Multiplied by the summed triangle count of every mesh that resolves to a
/// `GpuMesh` in `estimate_cost`.
const TRIANGLE_COST: f32 = 0.001;

/// Cost per draw call issued.
///
/// `estimate_cost` counts one draw call per mesh that resolves to a `GpuMesh`.
const DRAW_CALL_COST: f32 = 0.1;

/// Fixed overhead for the compute pass (tile frustum + dispatch).
///
/// Added unconditionally by `compute_pass_cost`, regardless of tile or light count.
const COMPUTE_PASS_OVERHEAD: f32 = 0.5;

/// Cost factor per tile in the light culling pass.
///
/// Multiplied by the total number of screen tiles in `compute_pass_cost`.
const PER_TILE_COST: f32 = 0.0001;

/// Cost factor per light-tile intersection test.
///
/// Multiplied by `total_tiles * light_count` in `compute_pass_cost`.
const LIGHT_TILE_TEST_COST: f32 = 0.00001;
85
86// --- ForwardPlusLane ---
87
/// GPU resource handles for the Forward+ compute pass.
///
/// Every handle is optional and the derived `Default` leaves all of them as
/// `None`. `ForwardPlusLane::render` reads these handles each frame: the
/// culling handles drive the light-culling compute pass and
/// `forward_bind_group` is bound in the render pass.
///
/// NOTE(review): the code that creates these resources is not part of this
/// file — presumably the lane's owner fills them in during initialization;
/// confirm against the caller.
#[derive(Debug, Clone, Default)]
pub struct ForwardPlusGpuResources {
    /// Light buffer containing all GpuLight instances.
    pub light_buffer: Option<BufferId>,
    /// Buffer containing per-tile light index lists.
    pub light_index_buffer: Option<BufferId>,
    /// Buffer containing (offset, count) pairs per tile.
    pub light_grid_buffer: Option<BufferId>,
    /// Uniform buffer for culling parameters.
    pub culling_uniforms_buffer: Option<BufferId>,
    /// Bind group for the culling compute shader.
    ///
    /// Bound at index 0 of the compute pass by `ForwardPlusLane::render`.
    pub culling_bind_group: Option<BindGroupId>,
    /// Bind group for the forward pass (light data).
    ///
    /// Bound at index 3 of the render pass by `ForwardPlusLane::render` when
    /// present. It is not part of the `is_initialized` check.
    pub forward_bind_group: Option<BindGroupId>,
    /// Compute pipeline for light culling.
    pub culling_pipeline: Option<ComputePipelineId>,
    /// Render pipeline for the Forward+ pass.
    ///
    /// NOTE(review): nothing in this file reads this field, and it is not
    /// part of the `is_initialized` check — confirm where it is consumed.
    pub render_pipeline: Option<RenderPipelineId>,
}
111
112impl ForwardPlusGpuResources {
113    /// Returns true if all required resources are initialized.
114    pub fn is_initialized(&self) -> bool {
115        self.light_buffer.is_some()
116            && self.light_index_buffer.is_some()
117            && self.light_grid_buffer.is_some()
118            && self.culling_uniforms_buffer.is_some()
119            && self.culling_bind_group.is_some()
120            && self.culling_pipeline.is_some()
121    }
122}
123
/// A rendering lane that implements Forward+ (Tiled Forward) rendering.
///
/// Forward+ divides the screen into tiles and uses a compute shader to determine
/// which lights affect each tile before the main render pass. This significantly
/// reduces per-fragment lighting cost for scenes with many lights.
///
/// # Configuration
///
/// The lane is configured via `ForwardPlusTileConfig`, which controls:
/// - **Tile size**: 16x16 or 32x32 pixels (trade-off between culling granularity and overhead)
/// - **Max lights per tile**: Memory budget for per-tile light lists
/// - **Depth pre-pass**: Optional optimization for depth-bounded light culling
///
/// # Construction
///
/// Use `new`, `with_config` or `with_complexity`; each of them starts from the
/// `Default` implementation, which assumes a 1920x1080 screen until
/// `set_screen_size` is called.
#[derive(Debug, Clone)]
pub struct ForwardPlusLane {
    /// Tile configuration for light culling.
    pub tile_config: ForwardPlusTileConfig,

    /// Shader complexity for cost estimation.
    ///
    /// Its `cost_multiplier()` scales the geometry cost in `estimate_cost`.
    pub shader_complexity: ShaderComplexity,

    /// Current screen dimensions (for tile count calculation).
    ///
    /// Stored as `(width, height)`; private, updated through `set_screen_size`.
    screen_size: (u32, u32),

    /// GPU resources for compute and render passes.
    ///
    /// All handles are `None` by default; the culling compute pass is skipped
    /// in `render` until `gpu_resources.is_initialized()` returns `true`.
    pub gpu_resources: ForwardPlusGpuResources,
}
150
151impl Default for ForwardPlusLane {
152    fn default() -> Self {
153        Self {
154            tile_config: ForwardPlusTileConfig::default(),
155            shader_complexity: ShaderComplexity::SimpleLit,
156            screen_size: (1920, 1080),
157            gpu_resources: ForwardPlusGpuResources::default(),
158        }
159    }
160}
161
162impl ForwardPlusLane {
163    /// Creates a new `ForwardPlusLane` with default settings.
164    ///
165    /// Default configuration:
166    /// - Tile size: 16x16 pixels
167    /// - Max lights per tile: 128
168    /// - No depth pre-pass
169    pub fn new() -> Self {
170        Self::default()
171    }
172
173    /// Creates a new `ForwardPlusLane` with the specified configuration.
174    ///
175    /// # Arguments
176    ///
177    /// * `config` - The tile configuration for light culling
178    pub fn with_config(config: ForwardPlusTileConfig) -> Self {
179        Self {
180            tile_config: config,
181            ..Default::default()
182        }
183    }
184
185    /// Creates a new `ForwardPlusLane` with the specified shader complexity.
186    pub fn with_complexity(complexity: ShaderComplexity) -> Self {
187        Self {
188            shader_complexity: complexity,
189            ..Default::default()
190        }
191    }
192
193    /// Updates the screen size used for tile calculations.
194    ///
195    /// This should be called when the window is resized to recalculate
196    /// tile counts and buffer sizes.
197    pub fn set_screen_size(&mut self, width: u32, height: u32) {
198        self.screen_size = (width, height);
199    }
200
201    /// Calculates the number of tiles in each dimension.
202    pub fn tile_count(&self) -> (u32, u32) {
203        self.tile_config
204            .tile_dimensions(self.screen_size.0, self.screen_size.1)
205    }
206
207    /// Calculates the total number of tiles on screen.
208    pub fn total_tiles(&self) -> u32 {
209        let (tiles_x, tiles_y) = self.tile_count();
210        tiles_x * tiles_y
211    }
212
213    /// Returns the effective number of lights in the scene.
214    ///
215    /// This counts all light types (directional, point, spot) that will be
216    /// processed by the light culling pass.
217    pub fn effective_light_count(&self, render_world: &RenderWorld) -> usize {
218        render_world.directional_light_count()
219            + render_world.point_light_count()
220            + render_world.spot_light_count()
221    }
222
223    /// Estimates the cost of the compute pass (light culling).
224    fn compute_pass_cost(&self, render_world: &RenderWorld) -> f32 {
225        let total_tiles = self.total_tiles() as f32;
226        let light_count = self.effective_light_count(render_world) as f32;
227
228        // Compute pass cost = overhead + per-tile cost + light-tile tests
229        COMPUTE_PASS_OVERHEAD
230            + (total_tiles * PER_TILE_COST)
231            + (total_tiles * light_count * LIGHT_TILE_TEST_COST)
232    }
233
234    /// Calculates the per-fragment light cost factor.
235    ///
236    /// For Forward+, this uses sqrt(total_lights) instead of linear scaling
237    /// because lights are culled per-tile, so each fragment only processes
238    /// a subset of lights.
239    fn fragment_light_factor(&self, render_world: &RenderWorld) -> f32 {
240        let total_lights = self.effective_light_count(render_world) as f32;
241
242        if total_lights == 0.0 {
243            return 1.0;
244        }
245
246        // Sublinear scaling: sqrt(lights) because of tile culling
247        // Clamped to max_lights_per_tile
248        let effective_lights = total_lights
249            .sqrt()
250            .min(self.tile_config.max_lights_per_tile as f32);
251
252        1.0 + (effective_lights * 0.02)
253    }
254}
255
256impl RenderLane for ForwardPlusLane {
257    fn strategy_name(&self) -> &'static str {
258        "ForwardPlus"
259    }
260
261    fn get_pipeline_for_material(
262        &self,
263        material_uuid: Option<AssetUUID>,
264        materials: &Assets<Box<dyn Material>>,
265    ) -> RenderPipelineId {
266        // Verify material exists (same logic as LitForwardLane)
267        if let Some(uuid) = material_uuid {
268            if materials.get(&uuid).is_none() {
269                let _ = uuid;
270            }
271        }
272
273        // Forward+ uses pipeline ID 2 to differentiate from standard forward (ID 1)
274        // This will be the pipeline created with the forward_plus.wgsl shader
275        RenderPipelineId(2)
276    }
277
278    fn render(
279        &self,
280        render_world: &RenderWorld,
281        encoder: &mut dyn CommandEncoder,
282        render_ctx: &RenderContext,
283        gpu_meshes: &RwLock<Assets<GpuMesh>>,
284        materials: &RwLock<Assets<Box<dyn Material>>>,
285    ) {
286        // Acquire read locks
287        let gpu_mesh_assets = gpu_meshes.read().unwrap();
288        let material_assets = materials.read().unwrap();
289
290        // Pre-compute pipelines
291        let pipelines: Vec<RenderPipelineId> = render_world
292            .meshes
293            .iter()
294            .map(|mesh| self.get_pipeline_for_material(mesh.material_uuid, &material_assets))
295            .collect();
296
297        // --- STAGE 1: Light Culling (Compute Pass) ---
298        let (tiles_x, tiles_y) = self.tile_count();
299
300        // Only dispatch compute pass if GPU resources are initialized
301        if self.gpu_resources.is_initialized() {
302            let compute_pass_desc = ComputePassDescriptor {
303                label: Some("Forward+ Light Culling"),
304                timestamp_writes: None,
305            };
306
307            {
308                let mut compute_pass = encoder.begin_compute_pass(&compute_pass_desc);
309
310                // Set compute pipeline
311                if let Some(ref pipeline) = self.gpu_resources.culling_pipeline {
312                    compute_pass.set_pipeline(pipeline);
313                }
314
315                // Bind resources: uniforms, lights, light_index_list, light_grid
316                if let Some(ref bind_group) = self.gpu_resources.culling_bind_group {
317                    compute_pass.set_bind_group(0, bind_group);
318                }
319
320                // Dispatch one workgroup per tile
321                compute_pass.dispatch_workgroups(tiles_x, tiles_y, 1);
322            }
323            // compute_pass dropped here, ending the pass
324        }
325
326        // --- STAGE 2: Rendering (Render Pass) ---
327
328        let color_attachment = RenderPassColorAttachment {
329            view: render_ctx.color_target,
330            resolve_target: None,
331            ops: Operations {
332                load: LoadOp::Clear(render_ctx.clear_color),
333                store: StoreOp::Store,
334            },
335        };
336
337        let render_pass_desc = RenderPassDescriptor {
338            label: Some("Forward+ Render Pass"),
339            color_attachments: &[color_attachment],
340        };
341
342        let mut render_pass = encoder.begin_render_pass(&render_pass_desc);
343
344        // Track current pipeline to avoid redundant state changes
345        let mut current_pipeline: Option<RenderPipelineId> = None;
346
347        // Bind Forward+ lighting data (lights, indices, grid, tile info)
348        // This bind group is created from the compute pass outputs
349        if let Some(ref forward_bg) = self.gpu_resources.forward_bind_group {
350            render_pass.set_bind_group(3, forward_bg);
351        }
352
353        // Iterate over all extracted meshes
354        for (i, extracted_mesh) in render_world.meshes.iter().enumerate() {
355            if let Some(gpu_mesh) = gpu_mesh_assets.get(&extracted_mesh.gpu_mesh_uuid) {
356                let pipeline = &pipelines[i];
357
358                // Only bind pipeline if changed
359                if current_pipeline != Some(*pipeline) {
360                    render_pass.set_pipeline(pipeline);
361                    current_pipeline = Some(*pipeline);
362                }
363
364                // Bind buffers
365                render_pass.set_vertex_buffer(0, &gpu_mesh.vertex_buffer, 0);
366                render_pass.set_index_buffer(&gpu_mesh.index_buffer, 0, gpu_mesh.index_format);
367
368                // Issue draw call
369                render_pass.draw_indexed(0..gpu_mesh.index_count, 0, 0..1);
370            }
371        }
372    }
373
374    fn estimate_cost(
375        &self,
376        render_world: &RenderWorld,
377        gpu_meshes: &RwLock<Assets<GpuMesh>>,
378    ) -> f32 {
379        let gpu_mesh_assets = gpu_meshes.read().unwrap();
380
381        let mut total_triangles = 0u32;
382        let mut draw_call_count = 0u32;
383
384        for extracted_mesh in &render_world.meshes {
385            if let Some(gpu_mesh) = gpu_mesh_assets.get(&extracted_mesh.gpu_mesh_uuid) {
386                let triangle_count = match gpu_mesh.primitive_topology {
387                    PrimitiveTopology::TriangleList => gpu_mesh.index_count / 3,
388                    PrimitiveTopology::TriangleStrip => {
389                        if gpu_mesh.index_count >= 3 {
390                            gpu_mesh.index_count - 2
391                        } else {
392                            0
393                        }
394                    }
395                    PrimitiveTopology::LineList
396                    | PrimitiveTopology::LineStrip
397                    | PrimitiveTopology::PointList => 0,
398                };
399
400                total_triangles += triangle_count;
401                draw_call_count += 1;
402            }
403        }
404
405        // Base geometry cost
406        let geometry_cost =
407            (total_triangles as f32 * TRIANGLE_COST) + (draw_call_count as f32 * DRAW_CALL_COST);
408
409        // Shader complexity multiplier
410        let shader_multiplier = self.shader_complexity.cost_multiplier();
411
412        // Compute pass overhead
413        let compute_cost = self.compute_pass_cost(render_world);
414
415        // Per-fragment light factor (sublinear for Forward+)
416        let light_factor = self.fragment_light_factor(render_world);
417
418        // Total cost
419        compute_cost + (geometry_cost * shader_multiplier * light_factor)
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use khora_core::renderer::TileSize;

    /// `new()` must produce the default tile configuration and shader complexity.
    #[test]
    fn test_forward_plus_lane_creation() {
        let lane = ForwardPlusLane::new();
        let config = &lane.tile_config;

        assert_eq!(config.tile_size, TileSize::X16);
        assert_eq!(config.max_lights_per_tile, 128);
        assert_eq!(lane.shader_complexity, ShaderComplexity::SimpleLit);
    }

    /// `with_config()` must store the supplied tile configuration unchanged.
    #[test]
    fn test_forward_plus_lane_with_config() {
        let lane = ForwardPlusLane::with_config(ForwardPlusTileConfig {
            tile_size: TileSize::X32,
            max_lights_per_tile: 256,
            use_depth_prepass: true,
        });
        let stored = &lane.tile_config;

        assert_eq!(stored.tile_size, TileSize::X32);
        assert_eq!(stored.max_lights_per_tile, 256);
        assert!(stored.use_depth_prepass);
    }

    /// A 1920x1080 screen with 16px tiles yields a 120x68 tile grid.
    #[test]
    fn test_tile_count_calculation() {
        let mut lane = ForwardPlusLane::new();
        lane.set_screen_size(1920, 1080);

        let (columns, rows) = lane.tile_count();
        assert_eq!(columns, 120); // 1920 / 16
        assert_eq!(rows, 68); // ceil(1080 / 16)
    }

    /// The lane identifies itself as "ForwardPlus".
    #[test]
    fn test_strategy_name() {
        let name = ForwardPlusLane::new().strategy_name();
        assert_eq!(name, "ForwardPlus");
    }

    /// Without a material the lane still selects pipeline ID 2.
    #[test]
    fn test_pipeline_id() {
        let lane = ForwardPlusLane::new();
        let empty_materials = Assets::<Box<dyn Material>>::new();

        let selected = lane.get_pipeline_for_material(None, &empty_materials);
        assert_eq!(selected, RenderPipelineId(2));
    }
}