khora_lanes/physics_lane/
native_lanes.rs

1// Copyright 2025 eraflo
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use khora_core::ecs::entity::EntityId;
16use khora_core::physics::{
17    ContactManifold, DynamicTree, ImpulseSolver, NarrowPhase, VelocityState,
18};
19use khora_data::ecs::{Collider, CollisionPair, CollisionPairs, GlobalTransform, RigidBody, World};
20use std::collections::HashMap;
21use std::sync::RwLock;
22
23/// The Broadphase Lane manages spatial partitioning and potential collision pair generation.
24/// It maintains a persistent Dynamic AABB Tree to minimize update costs.
25pub struct NativeBroadphaseLane {
26    /// Spatial hierarchy for efficient overlap queries.
27    /// Wrapped in RwLock for thread-safe access if multiple lanes query it.
28    tree: RwLock<DynamicTree<EntityId>>,
29    /// Mapping from EntityId to tree node handle for efficient updates.
30    handles: RwLock<HashMap<EntityId, i32>>,
31}
32
33impl Default for NativeBroadphaseLane {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl NativeBroadphaseLane {
40    /// Creates a new `NativeBroadphaseLane`.
41    pub fn new() -> Self {
42        Self {
43            tree: RwLock::new(DynamicTree::new()),
44            handles: RwLock::new(HashMap::new()),
45        }
46    }
47
48    /// Executes the broad-phase step: updates the tree and generates collision pairs.
49    pub fn step(&self, world: &mut World) {
50        // 1. Sync ECS components to the Dynamic Tree
51        self.sync_tree(world);
52
53        // 2. Clear old pairs and generate new ones
54        self.generate_pairs(world);
55    }
56
57    fn sync_tree(&self, world: &mut World) {
58        let mut tree = self.tree.write().unwrap();
59        let mut handles = self.handles.write().unwrap();
60
61        // Track seen entities to remove dead ones
62        let mut current_entities = std::collections::HashSet::new();
63
64        // Query entities with Collider and GlobalTransform
65        let query = world.query::<(EntityId, &Collider, &GlobalTransform)>();
66        for (entity_id, collider, transform) in query {
67            let world_aabb = collider.shape.compute_aabb().transform(&transform.0 .0);
68            current_entities.insert(entity_id);
69
70            let displacement = world
71                .get::<RigidBody>(entity_id)
72                .map(|b| b.linear_velocity)
73                .unwrap_or(khora_core::math::Vec3::ZERO);
74
75            if let Some(&handle) = handles.get(&entity_id) {
76                // Update with displacement prediction if it's a dynamic body
77                tree.update(handle, world_aabb, displacement, false);
78            } else {
79                let handle = tree.insert(world_aabb, entity_id);
80                handles.insert(entity_id, handle);
81            }
82        }
83
84        // Cleanup removed entities
85        let mut to_remove = Vec::new();
86        for &entity_id in handles.keys() {
87            if !current_entities.contains(&entity_id) {
88                to_remove.push(entity_id);
89            }
90        }
91
92        for entity_id in to_remove {
93            if let Some(handle) = handles.remove(&entity_id) {
94                tree.remove(handle);
95            }
96        }
97    }
98
99    fn generate_pairs(&self, world: &mut World) {
100        let tree = self.tree.read().unwrap();
101        let mut all_pairs = Vec::new();
102
103        tree.query_pairs(|&a, &b| {
104            all_pairs.push(CollisionPair {
105                entity_a: a,
106                entity_b: b,
107            });
108        });
109
110        // Store pairs in a singleton CollisionPairs component
111        // We find the first entity with CollisionPairs or spawn one.
112        let mut found = false;
113        {
114            let mut query = world.query_mut::<&mut CollisionPairs>();
115            if let Some(pairs_comp) = query.next() {
116                pairs_comp.pairs = all_pairs.clone();
117                found = true;
118            }
119        }
120
121        if !found {
122            world.spawn(CollisionPairs { pairs: all_pairs });
123        }
124    }
125}
126
127impl khora_core::lane::Lane for NativeBroadphaseLane {
128    fn strategy_name(&self) -> &'static str {
129        "NativeBroadphase"
130    }
131
132    fn lane_kind(&self) -> khora_core::lane::LaneKind {
133        khora_core::lane::LaneKind::Physics
134    }
135
136    fn execute(
137        &self,
138        ctx: &mut khora_core::lane::LaneContext,
139    ) -> Result<(), khora_core::lane::LaneError> {
140        use khora_core::lane::{LaneError, Slot};
141
142        let world = ctx
143            .get::<Slot<World>>()
144            .ok_or(LaneError::missing("Slot<World>"))?
145            .get();
146
147        self.step(world);
148        Ok(())
149    }
150
151    fn as_any(&self) -> &dyn std::any::Any {
152        self
153    }
154
155    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
156        self
157    }
158}
159
160/// The Solver Lane resolves collisions and updates velocities/positions.
161/// It uses the Sequential Impulse method for stable constraint resolution.
162pub struct NativeSolverLane {
163    solver: SequentialImpulseSolver,
164}
165
166impl Default for NativeSolverLane {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172impl NativeSolverLane {
173    /// Creates a new `NativeSolverLane`.
174    pub fn new() -> Self {
175        Self {
176            solver: SequentialImpulseSolver::new(),
177        }
178    }
179
180    /// Executes the solver step: integrates velocities, resolves collisions, and integrates positions.
181    pub fn step(&self, world: &mut World, dt: f32) {
182        // 1. Integrate Forces (Gravity, etc.)
183        self.integrate_velocities(world, dt);
184
185        // 2. Resolve Constraints (Collisions)
186        self.solver.solve_collisions(world);
187
188        // 3. Integrate Positions
189        self.integrate_positions(world, dt);
190    }
191
192    fn integrate_velocities(&self, world: &mut World, dt: f32) {
193        let gravity = khora_core::math::Vec3::new(0.0, -9.81, 0.0);
194        let query = world.query_mut::<&mut RigidBody>();
195        for rb in query {
196            if rb.body_type == khora_core::physics::BodyType::Dynamic {
197                // v = v + a*dt
198                rb.linear_velocity = rb.linear_velocity + (gravity * dt);
199            }
200        }
201    }
202
203    fn integrate_positions(&self, world: &mut World, dt: f32) {
204        let query = world.query_mut::<(&mut khora_data::ecs::Transform, &RigidBody)>();
205        for (transform, rb) in query {
206            if rb.body_type == khora_core::physics::BodyType::Dynamic {
207                // Integrate translation
208                transform.translation = transform.translation + (rb.linear_velocity * dt);
209
210                // Integrate rotation: dq = [w * dt / 2, 1] * q
211                let angular_velocity = rb.angular_velocity;
212                let w_mag = angular_velocity.length();
213                if w_mag > 0.0001 {
214                    let axis = angular_velocity / w_mag;
215                    let angle = w_mag * dt;
216                    let delta_rot = khora_core::math::Quat::from_axis_angle(axis, angle);
217                    transform.rotation = delta_rot * transform.rotation;
218                    transform.rotation = transform.rotation.normalize();
219                }
220            }
221        }
222    }
223}
224
225impl khora_core::lane::Lane for NativeSolverLane {
226    fn strategy_name(&self) -> &'static str {
227        "NativeSolver"
228    }
229
230    fn lane_kind(&self) -> khora_core::lane::LaneKind {
231        khora_core::lane::LaneKind::Physics
232    }
233
234    fn execute(
235        &self,
236        ctx: &mut khora_core::lane::LaneContext,
237    ) -> Result<(), khora_core::lane::LaneError> {
238        use khora_core::lane::{LaneError, Slot};
239
240        let dt = ctx
241            .get::<khora_core::lane::PhysicsDeltaTime>()
242            .ok_or(LaneError::missing("PhysicsDeltaTime"))?
243            .0;
244        let world = ctx
245            .get::<Slot<World>>()
246            .ok_or(LaneError::missing("Slot<World>"))?
247            .get();
248
249        self.step(world, dt);
250        Ok(())
251    }
252
253    fn as_any(&self) -> &dyn std::any::Any {
254        self
255    }
256
257    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
258        self
259    }
260}
261
262/// Encapsulates the Impulse-based constraint resolution logic.
263/// This solver follows the Sequential Impulse method, which is a popular approach
264/// for solving constraints (like collisions) by applying a series of impulses
265/// until the system reaches a stable state.
266struct SequentialImpulseSolver {
267    narrow_phase: NarrowPhase,
268    impulse_solver: ImpulseSolver,
269}
270
271impl SequentialImpulseSolver {
272    fn new() -> Self {
273        Self {
274            narrow_phase: NarrowPhase::new(),
275            impulse_solver: ImpulseSolver::new(),
276        }
277    }
278
279    /// Iterates over all detected collision pairs and resolves them.
280    /// This is the entry point for the collision resolution phase of the lane.
281    fn solve_collisions(&self, world: &mut World) {
282        // 1. Gather all potential collision pairs from the broad-phase output.
283        let candidates = {
284            let mut pairs = Vec::new();
285            let query = world.query::<&CollisionPairs>();
286            for p in query {
287                pairs.extend_from_slice(&p.pairs);
288            }
289            pairs
290        };
291
292        // 2. Process each pair sequentially.
293        for pair in candidates {
294            self.solve_pair(world, pair.entity_a, pair.entity_b);
295        }
296    }
297
298    /// Performs narrow-phase detection and resolves a single pair of entities.
299    fn solve_pair(&self, world: &mut World, a_id: EntityId, b_id: EntityId) {
300        let mut manifold = None;
301
302        // Scope the immutable borrows to allow mutable access to the world later.
303        {
304            let a_coll = world.get::<Collider>(a_id);
305            let a_trans = world.get::<GlobalTransform>(a_id);
306            let b_coll = world.get::<Collider>(b_id);
307            let b_trans = world.get::<GlobalTransform>(b_id);
308
309            if let (Some(ac), Some(at), Some(bc), Some(bt)) = (a_coll, a_trans, b_coll, b_trans) {
310                // Perform intersection test using the NarrowPhase instance.
311                manifold = self.narrow_phase.detect(&ac.shape, &at.0, &bc.shape, &bt.0);
312            }
313        }
314
315        // If a collision was detected, apply the resolution impulse.
316        if let Some(m) = manifold {
317            // Retrieve both rigid bodies mutably using the safe multi-entity access API.
318            let [a_rb_opt, b_rb_opt] = world.get_many_mut::<RigidBody, 2>([a_id, b_id]);
319
320            match (a_rb_opt, b_rb_opt) {
321                (Some(rb_a), Some(rb_b)) => {
322                    self.apply_impulse(rb_a, rb_b, &m);
323                }
324                (Some(rb_a), None) => {
325                    // Resolve against a static object (which has No RigidBody component).
326                    self.apply_impulse_static(rb_a, &m);
327                }
328                (None, Some(rb_b)) => {
329                    // Mirror of the above for symmetry.
330                    self.apply_impulse_static(rb_b, &m.inverted());
331                }
332                _ => {}
333            }
334        }
335    }
336
337    /// Resolves two entities by delegating the math to the core ImpulseSolver.
338    fn apply_impulse(&self, a: &mut RigidBody, b: &mut RigidBody, m: &ContactManifold) {
339        let state_a = VelocityState {
340            linear_velocity: a.linear_velocity,
341            angular_velocity: a.angular_velocity,
342            mass: a.mass,
343            body_type: a.body_type,
344        };
345
346        let state_b = VelocityState {
347            linear_velocity: b.linear_velocity,
348            angular_velocity: b.angular_velocity,
349            mass: b.mass,
350            body_type: b.body_type,
351        };
352
353        let (new_a, new_b) = self.impulse_solver.resolve(state_a, state_b, m);
354
355        a.linear_velocity = new_a.linear_velocity;
356        b.linear_velocity = new_b.linear_velocity;
357    }
358
359    /// Resolves a dynamic body against a static one.
360    fn apply_impulse_static(&self, a: &mut RigidBody, m: &ContactManifold) {
361        let state_a = VelocityState {
362            linear_velocity: a.linear_velocity,
363            angular_velocity: a.angular_velocity,
364            mass: a.mass,
365            body_type: a.body_type,
366        };
367
368        // Represent static as a default state with Static body type.
369        let state_static = VelocityState {
370            linear_velocity: khora_core::math::Vec3::ZERO,
371            angular_velocity: khora_core::math::Vec3::ZERO,
372            mass: 0.0,
373            body_type: khora_core::physics::BodyType::Static,
374        };
375
376        let (new_a, _) = self.impulse_solver.resolve(state_a, state_static, m);
377        a.linear_velocity = new_a.linear_velocity;
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use khora_core::math::{AffineTransform, Vec3};

    /// Two unit spheres whose centres are 1.5 apart along +X overlap by 2·1 − 1.5 = 0.5.
    /// The contact normal should point along +X (from A towards B).
    #[test]
    fn test_sphere_sphere_collision() {
        let left = Collider::new_sphere(1.0);
        let right = Collider::new_sphere(1.0);
        let left_pose = AffineTransform::from_translation(Vec3::new(0.0, 0.0, 0.0));
        let right_pose = AffineTransform::from_translation(Vec3::new(1.5, 0.0, 0.0));

        let contact = NarrowPhase::new()
            .detect(&left.shape, &left_pose, &right.shape, &right_pose)
            .unwrap();

        let normal_error = (contact.normal.x - 1.0).abs();
        let depth_error = (contact.depth - 0.5).abs();

        assert!(contact.depth > 0.0);
        assert!(normal_error < 0.001);
        assert!(depth_error < 0.001);
    }
}