khora_lanes/audio_lane/mixing/
spatial_mixing_lane.rs

1// Copyright 2025 eraflo
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The core audio processing lane, responsible for mixing and spatializing sound sources.
16
17use khora_core::audio::device::StreamInfo;
18use khora_data::ecs::{AudioListener, AudioSource, GlobalTransform, PlaybackState, World};
19
20/// A lane that performs spatialized audio mixing.
/// A lane that performs spatialized audio mixing.
///
/// This is a stateless unit struct: all mutable playback state (the per-source
/// read cursor) lives on the `AudioSource` components it processes, so a single
/// instance can be shared freely.
#[derive(Default)]
pub struct SpatialMixingLane;
23
24impl SpatialMixingLane {
25    /// Creates a new `SpatialMixingLane`.
26    pub fn new() -> Self {
27        Self
28    }
29}
30
31impl khora_core::lane::Lane for SpatialMixingLane {
32    fn strategy_name(&self) -> &'static str {
33        "SpatialMixing"
34    }
35
36    fn lane_kind(&self) -> khora_core::lane::LaneKind {
37        khora_core::lane::LaneKind::Audio
38    }
39
40    fn execute(
41        &self,
42        ctx: &mut khora_core::lane::LaneContext,
43    ) -> Result<(), khora_core::lane::LaneError> {
44        use khora_core::lane::{AudioOutputSlot, AudioStreamInfo, LaneError, Slot};
45
46        let stream_info = ctx
47            .get::<AudioStreamInfo>()
48            .ok_or(LaneError::missing("AudioStreamInfo"))?
49            .0;
50        let output_slot = ctx
51            .get::<AudioOutputSlot>()
52            .ok_or(LaneError::missing("AudioOutputSlot"))?;
53        let output_buffer = output_slot.get();
54        let world = ctx
55            .get::<Slot<World>>()
56            .ok_or(LaneError::missing("Slot<World>"))?
57            .get();
58
59        self.mix(world, output_buffer, &stream_info);
60        Ok(())
61    }
62
63    fn as_any(&self) -> &dyn std::any::Any {
64        self
65    }
66
67    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
68        self
69    }
70}
71
72impl SpatialMixingLane {
73    /// Mixes all active `AudioSource`s into a single output buffer, applying 3D spatialization.
74    pub fn mix(&self, world: &mut World, output_buffer: &mut [f32], stream_info: &StreamInfo) {
75        output_buffer.fill(0.0);
76
77        // --- Step 1: Find the listener (if any) ---
78        let listener_transform = world
79            .query::<(&AudioListener, &GlobalTransform)>()
80            .next()
81            .map(|(_, t)| t.0);
82
83        // --- Step 2 & 3: Process and mix all active sources ---
84        let samples_to_write = output_buffer.len() / stream_info.channels as usize;
85
86        for (source, source_transform) in world.query_mut::<(&mut AudioSource, &GlobalTransform)>()
87        {
88            if source.autoplay && source.state.is_none() {
89                source.state = Some(PlaybackState { cursor: 0.0 });
90            }
91
92            let sound_data = &source.handle;
93            let num_frames = sound_data.samples.len() / sound_data.channels as usize;
94
95            // Stop immediately if the sound is empty.
96            if num_frames == 0 {
97                source.state = None;
98                continue;
99            }
100
101            let resample_ratio = sound_data.sample_rate as f32 / stream_info.sample_rate as f32;
102            let (mut volume, mut pan) = (source.volume, 0.5);
103
104            if let Some(listener_mat) = listener_transform {
105                let source_pos = source_transform.0.translation();
106                let listener_pos = listener_mat.translation();
107                let listener_right = listener_mat.right();
108                let to_source = source_pos - listener_pos;
109                let distance = to_source.length();
110
111                volume *= 1.0 / (1.0 + distance * distance);
112                if distance > 0.001 {
113                    pan = (to_source.normalize().dot(listener_right) + 1.0) * 0.5;
114                }
115            }
116
117            let vol_l = volume * (1.0 - pan).sqrt();
118            let vol_r = volume * pan.sqrt();
119
120            for i in 0..samples_to_write {
121                // Get a mutable reference to the cursor for this iteration.
122                // If the state becomes None mid-loop, we stop processing this source.
123                let cursor = if let Some(state) = source.state.as_mut() {
124                    &mut state.cursor
125                } else {
126                    break;
127                };
128
129                // --- Robust End-of-Sound and Loop Handling ---
130                if *cursor >= num_frames as f32 {
131                    if source.looping {
132                        *cursor %= num_frames as f32;
133                    } else {
134                        source.state = None;
135                        break; // Stop processing samples for this source
136                    }
137                }
138
139                let cursor_floor = cursor.floor() as usize;
140                let cursor_fract = cursor.fract();
141
142                // For looping sounds, the next sample might wrap around to the beginning.
143                let next_frame_idx = (cursor_floor + 1) % num_frames;
144
145                let s1_idx = cursor_floor * sound_data.channels as usize;
146                let s2_idx = next_frame_idx * sound_data.channels as usize;
147
148                // This check prevents panics if sound data is malformed, though unlikely.
149                if s1_idx >= sound_data.samples.len() || s2_idx >= sound_data.samples.len() {
150                    source.state = None;
151                    break;
152                }
153
154                let s1 = sound_data.samples[s1_idx];
155                let s2 = sound_data.samples[s2_idx];
156                let sample = s1 + (s2 - s1) * cursor_fract;
157
158                // Mix into output buffer
159                let out_idx = i * stream_info.channels as usize;
160                if stream_info.channels == 2 {
161                    output_buffer[out_idx] += sample * vol_l;
162                    output_buffer[out_idx + 1] += sample * vol_r;
163                } else {
164                    output_buffer[out_idx] += sample * volume;
165                }
166
167                // Advance cursor
168                *cursor += resample_ratio;
169            }
170        }
171
172        // --- Step 4: Limiter ---
173        for sample in output_buffer.iter_mut() {
174            *sample = sample.clamp(-1.0, 1.0);
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    //! Integration-style tests for `SpatialMixingLane::mix`, exercising panning,
    //! distance attenuation, end-of-sound handling, and looping against a real
    //! ECS `World`.
    use super::*;
    use khora_core::{
        asset::AssetHandle,
        math::{affine_transform::AffineTransform, vector::Vec3},
    };
    use khora_data::assets::SoundData;

    // Helper to create a simple SoundData for tests.
    // Produces `len` mono samples of sin(0), sin(1), ... at the given rate;
    // mono keeps the frame/sample index math trivial in assertions.
    fn create_test_sound(len: usize, sample_rate: u32) -> AssetHandle<SoundData> {
        let samples = (0..len).map(|i| (i as f32).sin()).collect();
        AssetHandle::new(SoundData {
            samples,
            channels: 1, // Mono for simplicity
            sample_rate,
        })
    }

    // Helper for floating-point comparison
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-5
    }

    #[test]
    fn test_panning_right() {
        let mut world = World::new();
        let stream_info = StreamInfo {
            channels: 2,
            sample_rate: 44100,
        };
        let lane = SpatialMixingLane::new();
        let mut buffer = vec![0.0; 128];

        // Listener at the origin
        world.spawn((AudioListener, GlobalTransform(AffineTransform::IDENTITY)));

        // Source sound to the right of the listener
        let sound = create_test_sound(1024, 44100);
        world.spawn((
            AudioSource {
                handle: sound,
                autoplay: true,
                looping: false,
                volume: 1.0,
                state: None,
            },
            GlobalTransform(AffineTransform::from_translation(Vec3::new(10.0, 0.0, 0.0))),
        ));

        lane.mix(&mut world, &mut buffer, &stream_info);

        // The sum of squares is a good measure of the signal's energy
        // (interleaved stereo: even indices = left, odd indices = right).
        let energy_left = buffer.iter().step_by(2).map(|&s| s * s).sum::<f32>();
        let energy_right = buffer
            .iter()
            .skip(1)
            .step_by(2)
            .map(|&s| s * s)
            .sum::<f32>();

        assert!(
            energy_right > energy_left * 100.0,
            "The energy should be much stronger in the right channel"
        );

        // With a sound perfectly to the right, the left channel MUST be silent:
        // the equal-power law gives vol_l = sqrt(1 - pan) = 0 at pan = 1.0.
        assert!(
            approx_eq(energy_left, 0.0),
            "The left channel should be silent for a sound perfectly to the right"
        );
    }

    #[test]
    fn test_distance_attenuation() {
        let stream_info = StreamInfo {
            channels: 2,
            sample_rate: 44100,
        };
        let lane = SpatialMixingLane::new();

        // --- Case 1: Near source ---
        let mut world_near = World::new();
        world_near.spawn((AudioListener, GlobalTransform(AffineTransform::IDENTITY)));
        let sound = create_test_sound(1024, 44100);
        world_near.spawn((
            AudioSource {
                handle: sound.clone(),
                autoplay: true,
                looping: true,
                volume: 1.0,
                state: None,
            },
            GlobalTransform(AffineTransform::from_translation(Vec3::new(1.0, 0.0, 0.0))),
        ));

        let mut buffer_near = vec![0.0; 128];
        lane.mix(&mut world_near, &mut buffer_near, &stream_info);
        let peak_near = buffer_near.iter().map(|s| s.abs()).fold(0.0, f32::max);

        // --- Case 2: Far source ---
        // 100 units away: 1/(1+d^2) attenuation makes this ~10^-4 of full volume.
        let mut world_far = World::new();
        world_far.spawn((AudioListener, GlobalTransform(AffineTransform::IDENTITY)));
        world_far.spawn((
            AudioSource {
                handle: sound,
                autoplay: true,
                looping: true,
                volume: 1.0,
                state: None,
            },
            GlobalTransform(AffineTransform::from_translation(Vec3::new(
                100.0, 0.0, 0.0,
            ))),
        ));

        let mut buffer_far = vec![0.0; 128];
        lane.mix(&mut world_far, &mut buffer_far, &stream_info);
        let peak_far = buffer_far.iter().map(|s| s.abs()).fold(0.0, f32::max);

        assert!(peak_near > 0.01, "The near sound should be audible");
        assert!(
            peak_far < peak_near * 0.1,
            "The far sound should be significantly quieter"
        );
    }

    #[test]
    fn test_sound_finishes_and_stops() {
        let mut world = World::new();
        let stream_info = StreamInfo {
            channels: 2,
            sample_rate: 10,
        };
        let lane = SpatialMixingLane::new();

        let sound = create_test_sound(5, 10); // Sound of 5 samples at 10Hz (0.5s)
        let entity = world.spawn((
            AudioSource {
                handle: sound,
                autoplay: true,
                looping: false, // Does not loop
                volume: 1.0,
                state: None,
            },
            GlobalTransform::default(),
        ));

        // First mix, the sound plays
        let mut buffer = vec![0.0; 20]; // 20 samples = 2 seconds
        lane.mix(&mut world, &mut buffer, &stream_info);

        // After the mix, the sound should have finished
        let source = world.get_mut::<AudioSource>(entity).unwrap();
        assert!(
            source.state.is_none(),
            "The playback state should be `None` after the sound finishes"
        );

        // Check that the sound stops in the buffer: the 5-frame sound fills at
        // most the first 10 interleaved stereo samples; the rest stays zeroed.
        let first_part_energy = buffer[0..10].iter().map(|&s| s * s).sum::<f32>();
        let second_part_energy = buffer[10..20].iter().map(|&s| s * s).sum::<f32>();
        assert!(first_part_energy > 0.0);
        assert!(
            approx_eq(second_part_energy, 0.0),
            "The second half of the buffer should be silent"
        );
    }

    #[test]
    fn test_sound_loops() {
        let mut world = World::new();
        let stream_info = StreamInfo {
            channels: 1,
            sample_rate: 10,
        };
        let lane = SpatialMixingLane::new();

        let sound = create_test_sound(5, 10); // Sound of 5 samples
        let entity = world.spawn((
            AudioSource {
                handle: sound,
                autoplay: true,
                looping: true, // Loops
                volume: 1.0,
                state: None,
            },
            GlobalTransform::default(),
        ));

        // 12 mono frames at matching rates: cursor advances 1.0/frame, so it
        // must wrap past the 5-frame sound at least twice.
        let mut buffer = vec![0.0; 12]; // Buffer of 12 samples
        lane.mix(&mut world, &mut buffer, &stream_info);

        // After the mix, the sound should still be playing
        let source = world.get::<AudioSource>(entity).unwrap();
        let cursor = source.state.as_ref().unwrap().cursor;
        assert!(
            cursor > 0.0 && cursor < 5.0,
            "The cursor should have looped and returned to the beginning, cursor is {}",
            cursor
        );

        // The sound should be present at the start and end of the buffer.
        // Index 0 is skipped because sin(0) == 0 by construction.
        assert!(!approx_eq(buffer[1], 0.0)); // Beginning
        assert!(!approx_eq(buffer[6], 0.0)); // Middle (after loop)
        assert!(!approx_eq(buffer[11], 0.0)); // End
    }
}