1use super::RenderWorld;
48use crate::render_lane::{RenderLane, ShaderComplexity};
49
50use khora_core::{
51 asset::{AssetUUID, Material},
52 renderer::{
53 api::{
54 buffer::BufferId,
55 command::{
56 ComputePassDescriptor, LoadOp, Operations, RenderPassColorAttachment,
57 RenderPassDescriptor, StoreOp,
58 },
59 PrimitiveTopology,
60 },
61 traits::CommandEncoder,
62 BindGroupId, ComputePipelineId, ForwardPlusTileConfig, GpuMesh, RenderContext,
63 RenderPipelineId,
64 },
65};
66use khora_data::assets::Assets;
67use std::sync::RwLock;
68
/// Cost charged per triangle when estimating geometry work in `estimate_cost`.
const TRIANGLE_COST: f32 = 0.001;

/// Cost charged per draw call when estimating geometry work in `estimate_cost`.
const DRAW_CALL_COST: f32 = 0.1;

/// Fixed cost added by `compute_pass_cost`, independent of tile and light counts.
const COMPUTE_PASS_OVERHEAD: f32 = 0.5;

/// Cost charged per screen tile by `compute_pass_cost`.
const PER_TILE_COST: f32 = 1.0e-4;

/// Cost charged per (tile, light) pair by `compute_pass_cost`.
const LIGHT_TILE_TEST_COST: f32 = 1.0e-5;
85
86#[derive(Debug, Clone, Default)]
93pub struct ForwardPlusGpuResources {
94 pub light_buffer: Option<BufferId>,
96 pub light_index_buffer: Option<BufferId>,
98 pub light_grid_buffer: Option<BufferId>,
100 pub culling_uniforms_buffer: Option<BufferId>,
102 pub culling_bind_group: Option<BindGroupId>,
104 pub forward_bind_group: Option<BindGroupId>,
106 pub culling_pipeline: Option<ComputePipelineId>,
108 pub render_pipeline: Option<RenderPipelineId>,
110}
111
112impl ForwardPlusGpuResources {
113 pub fn is_initialized(&self) -> bool {
115 self.light_buffer.is_some()
116 && self.light_index_buffer.is_some()
117 && self.light_grid_buffer.is_some()
118 && self.culling_uniforms_buffer.is_some()
119 && self.culling_bind_group.is_some()
120 && self.culling_pipeline.is_some()
121 }
122}
123
124#[derive(Debug, Clone)]
137pub struct ForwardPlusLane {
138 pub tile_config: ForwardPlusTileConfig,
140
141 pub shader_complexity: ShaderComplexity,
143
144 screen_size: (u32, u32),
146
147 pub gpu_resources: ForwardPlusGpuResources,
149}
150
151impl Default for ForwardPlusLane {
152 fn default() -> Self {
153 Self {
154 tile_config: ForwardPlusTileConfig::default(),
155 shader_complexity: ShaderComplexity::SimpleLit,
156 screen_size: (1920, 1080),
157 gpu_resources: ForwardPlusGpuResources::default(),
158 }
159 }
160}
161
162impl ForwardPlusLane {
163 pub fn new() -> Self {
170 Self::default()
171 }
172
173 pub fn with_config(config: ForwardPlusTileConfig) -> Self {
179 Self {
180 tile_config: config,
181 ..Default::default()
182 }
183 }
184
185 pub fn with_complexity(complexity: ShaderComplexity) -> Self {
187 Self {
188 shader_complexity: complexity,
189 ..Default::default()
190 }
191 }
192
193 pub fn set_screen_size(&mut self, width: u32, height: u32) {
198 self.screen_size = (width, height);
199 }
200
201 pub fn tile_count(&self) -> (u32, u32) {
203 self.tile_config
204 .tile_dimensions(self.screen_size.0, self.screen_size.1)
205 }
206
207 pub fn total_tiles(&self) -> u32 {
209 let (tiles_x, tiles_y) = self.tile_count();
210 tiles_x * tiles_y
211 }
212
213 pub fn effective_light_count(&self, render_world: &RenderWorld) -> usize {
218 render_world.directional_light_count()
219 + render_world.point_light_count()
220 + render_world.spot_light_count()
221 }
222
223 fn compute_pass_cost(&self, render_world: &RenderWorld) -> f32 {
225 let total_tiles = self.total_tiles() as f32;
226 let light_count = self.effective_light_count(render_world) as f32;
227
228 COMPUTE_PASS_OVERHEAD
230 + (total_tiles * PER_TILE_COST)
231 + (total_tiles * light_count * LIGHT_TILE_TEST_COST)
232 }
233
234 fn fragment_light_factor(&self, render_world: &RenderWorld) -> f32 {
240 let total_lights = self.effective_light_count(render_world) as f32;
241
242 if total_lights == 0.0 {
243 return 1.0;
244 }
245
246 let effective_lights = total_lights
249 .sqrt()
250 .min(self.tile_config.max_lights_per_tile as f32);
251
252 1.0 + (effective_lights * 0.02)
253 }
254}
255
256impl RenderLane for ForwardPlusLane {
257 fn strategy_name(&self) -> &'static str {
258 "ForwardPlus"
259 }
260
261 fn get_pipeline_for_material(
262 &self,
263 material_uuid: Option<AssetUUID>,
264 materials: &Assets<Box<dyn Material>>,
265 ) -> RenderPipelineId {
266 if let Some(uuid) = material_uuid {
268 if materials.get(&uuid).is_none() {
269 let _ = uuid;
270 }
271 }
272
273 RenderPipelineId(2)
276 }
277
278 fn render(
279 &self,
280 render_world: &RenderWorld,
281 encoder: &mut dyn CommandEncoder,
282 render_ctx: &RenderContext,
283 gpu_meshes: &RwLock<Assets<GpuMesh>>,
284 materials: &RwLock<Assets<Box<dyn Material>>>,
285 ) {
286 let gpu_mesh_assets = gpu_meshes.read().unwrap();
288 let material_assets = materials.read().unwrap();
289
290 let pipelines: Vec<RenderPipelineId> = render_world
292 .meshes
293 .iter()
294 .map(|mesh| self.get_pipeline_for_material(mesh.material_uuid, &material_assets))
295 .collect();
296
297 let (tiles_x, tiles_y) = self.tile_count();
299
300 if self.gpu_resources.is_initialized() {
302 let compute_pass_desc = ComputePassDescriptor {
303 label: Some("Forward+ Light Culling"),
304 timestamp_writes: None,
305 };
306
307 {
308 let mut compute_pass = encoder.begin_compute_pass(&compute_pass_desc);
309
310 if let Some(ref pipeline) = self.gpu_resources.culling_pipeline {
312 compute_pass.set_pipeline(pipeline);
313 }
314
315 if let Some(ref bind_group) = self.gpu_resources.culling_bind_group {
317 compute_pass.set_bind_group(0, bind_group);
318 }
319
320 compute_pass.dispatch_workgroups(tiles_x, tiles_y, 1);
322 }
323 }
325
326 let color_attachment = RenderPassColorAttachment {
329 view: render_ctx.color_target,
330 resolve_target: None,
331 ops: Operations {
332 load: LoadOp::Clear(render_ctx.clear_color),
333 store: StoreOp::Store,
334 },
335 };
336
337 let render_pass_desc = RenderPassDescriptor {
338 label: Some("Forward+ Render Pass"),
339 color_attachments: &[color_attachment],
340 };
341
342 let mut render_pass = encoder.begin_render_pass(&render_pass_desc);
343
344 let mut current_pipeline: Option<RenderPipelineId> = None;
346
347 if let Some(ref forward_bg) = self.gpu_resources.forward_bind_group {
350 render_pass.set_bind_group(3, forward_bg);
351 }
352
353 for (i, extracted_mesh) in render_world.meshes.iter().enumerate() {
355 if let Some(gpu_mesh) = gpu_mesh_assets.get(&extracted_mesh.gpu_mesh_uuid) {
356 let pipeline = &pipelines[i];
357
358 if current_pipeline != Some(*pipeline) {
360 render_pass.set_pipeline(pipeline);
361 current_pipeline = Some(*pipeline);
362 }
363
364 render_pass.set_vertex_buffer(0, &gpu_mesh.vertex_buffer, 0);
366 render_pass.set_index_buffer(&gpu_mesh.index_buffer, 0, gpu_mesh.index_format);
367
368 render_pass.draw_indexed(0..gpu_mesh.index_count, 0, 0..1);
370 }
371 }
372 }
373
374 fn estimate_cost(
375 &self,
376 render_world: &RenderWorld,
377 gpu_meshes: &RwLock<Assets<GpuMesh>>,
378 ) -> f32 {
379 let gpu_mesh_assets = gpu_meshes.read().unwrap();
380
381 let mut total_triangles = 0u32;
382 let mut draw_call_count = 0u32;
383
384 for extracted_mesh in &render_world.meshes {
385 if let Some(gpu_mesh) = gpu_mesh_assets.get(&extracted_mesh.gpu_mesh_uuid) {
386 let triangle_count = match gpu_mesh.primitive_topology {
387 PrimitiveTopology::TriangleList => gpu_mesh.index_count / 3,
388 PrimitiveTopology::TriangleStrip => {
389 if gpu_mesh.index_count >= 3 {
390 gpu_mesh.index_count - 2
391 } else {
392 0
393 }
394 }
395 PrimitiveTopology::LineList
396 | PrimitiveTopology::LineStrip
397 | PrimitiveTopology::PointList => 0,
398 };
399
400 total_triangles += triangle_count;
401 draw_call_count += 1;
402 }
403 }
404
405 let geometry_cost =
407 (total_triangles as f32 * TRIANGLE_COST) + (draw_call_count as f32 * DRAW_CALL_COST);
408
409 let shader_multiplier = self.shader_complexity.cost_multiplier();
411
412 let compute_cost = self.compute_pass_cost(render_world);
414
415 let light_factor = self.fragment_light_factor(render_world);
417
418 compute_cost + (geometry_cost * shader_multiplier * light_factor)
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use khora_core::renderer::TileSize;

    /// A lane built with `new()` carries the default tile configuration and
    /// the `SimpleLit` shader complexity.
    #[test]
    fn test_forward_plus_lane_creation() {
        let default_lane = ForwardPlusLane::new();
        let tiles = &default_lane.tile_config;

        assert_eq!(tiles.tile_size, TileSize::X16);
        assert_eq!(tiles.max_lights_per_tile, 128);
        assert_eq!(default_lane.shader_complexity, ShaderComplexity::SimpleLit);
    }

    /// `with_config` stores the supplied tile configuration unchanged.
    #[test]
    fn test_forward_plus_lane_with_config() {
        let lane = ForwardPlusLane::with_config(ForwardPlusTileConfig {
            tile_size: TileSize::X32,
            max_lights_per_tile: 256,
            use_depth_prepass: true,
        });
        let stored = &lane.tile_config;

        assert_eq!(stored.tile_size, TileSize::X32);
        assert_eq!(stored.max_lights_per_tile, 256);
        assert!(stored.use_depth_prepass);
    }

    /// A 1920x1080 screen with the default tile configuration yields a
    /// 120 x 68 tile grid.
    #[test]
    fn test_tile_count_calculation() {
        let mut lane = ForwardPlusLane::new();
        lane.set_screen_size(1920, 1080);

        let (columns, rows) = lane.tile_count();

        assert_eq!(columns, 120);
        assert_eq!(rows, 68);
    }

    /// The lane reports the "ForwardPlus" strategy name.
    #[test]
    fn test_strategy_name() {
        let name = ForwardPlusLane::new().strategy_name();

        assert_eq!(name, "ForwardPlus");
    }

    /// Without a material, the lane selects `RenderPipelineId(2)`.
    #[test]
    fn test_pipeline_id() {
        let lane = ForwardPlusLane::new();
        let empty_materials = Assets::<Box<dyn Material>>::new();

        let selected = lane.get_pipeline_for_material(None, &empty_materials);

        assert_eq!(selected, RenderPipelineId(2));
    }
}