khora_data/ecs/
query.rs

1// Copyright 2025 eraflo
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use khora_core::ecs::entity::EntityId;
16
17use crate::ecs::{
18    page::{AnyVec, ComponentPage},
19    Component, DomainBitset, QueryMode, QueryPlan, World,
20};
21use std::{any::TypeId, marker::PhantomData};
22
23// ------------------------- //
24// ---- WorldQuery Part ---- //
25// ------------------------- //
26
27/// A trait implemented by types that can be used to query data from the `World`.
28///
29/// This "sealed" trait provides the necessary information for the query engine to
30/// find the correct pages and safely access component data. It is implemented for
31/// component references (`&T`, `&mut T`), `EntityId`, filters like `Without<T>`,
32/// and tuples of other `WorldQuery` types.
33pub trait WorldQuery {
34    /// The type of item that the query iterator will yield (e.g., `(&'a Position, &'a mut Velocity)`).
35    type Item<'a>;
36
37    /// Returns the sorted list of `TypeId`s for the components to be INCLUDED in the query.
38    /// This signature is used to find `ComponentPage`s that contain all these components.
39    fn type_ids() -> Vec<TypeId>;
40
41    /// Returns the sorted list of `TypeId`s for components to be EXCLUDED from the query.
42    /// Used to filter out pages that contain these components.
43    fn without_type_ids() -> Vec<TypeId> {
44        Vec::new()
45    }
46
47    /// Fetches the query's item from a specific row in a `ComponentPage`.
48    ///
49    /// # Safety
50    ///
51    /// This function is unsafe because the caller (the `Query` iterator) must guarantee
52    /// several invariants:
53    /// 1. The page pointed to by `page_ptr` is valid and matches the query's signature.
54    /// 2. `row_index` is a valid index within the page's columns.
55    /// 3. Aliasing rules are not violated (e.g., no two `&mut T` to the same data).
56    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a>;
57
58    /// Fetches the query's item directly from the world for a specific entity.
59    ///
60    /// # Safety
61    /// Caller must ensure `world` is valid and no aliasing rules are violated.
62    unsafe fn fetch_from_world<'a>(
63        world: *const World,
64        entity_id: EntityId,
65    ) -> Option<Self::Item<'a>>;
66}
67
68// Implementation for a query of a single, immutable component reference.
69impl<T: Component> WorldQuery for &T {
70    type Item<'a> = &'a T;
71
72    fn type_ids() -> Vec<TypeId> {
73        vec![TypeId::of::<T>()]
74    }
75
76    /// Fetches a reference to the component `T` from the specified row.
77    ///
78    /// # Safety
79    /// The caller MUST guarantee that the page contains a column for component `T`
80    /// and that `row_index` is in bounds.
81    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
82        // 1. Get a reference to the `ComponentPage`.
83        let page = &*page_ptr;
84
85        // 2. Get the type-erased column for the component `T`.
86        // We can unwrap because the caller guarantees the column exists.
87        let column: &dyn AnyVec = &**page.columns.get(&TypeId::of::<T>()).unwrap();
88
89        // 3. Downcast the column to its concrete `Vec<T>` type.
90        // First, cast to `&dyn Any`, then `downcast_ref`.
91        let vec: &Vec<T> = column.as_any().downcast_ref::<Vec<T>>().unwrap();
92
93        // 4. Get the component from the vector at the specified row.
94        // We use `get_unchecked` for performance, as the caller guarantees the
95        // index is in bounds. This avoids a bounds check.
96        vec.get_unchecked(row_index)
97    }
98
99    unsafe fn fetch_from_world<'a>(
100        world: *const World,
101        entity_id: EntityId,
102    ) -> Option<Self::Item<'a>> {
103        let world = &*world;
104
105        // Get the entity's metadata.
106        let metadata = world.entities.get(entity_id.index as usize)?.1.as_ref()?;
107
108        // Get the domain for the component type.
109        let domain = world.storage.registry.get_domain(TypeId::of::<T>())?;
110
111        // Get the location of the entity in the domain.
112        let location = metadata.locations.get(&domain)?;
113
114        // Get the page for the entity.
115        let page = &world.storage.pages[location.page_id as usize];
116        let column = page.columns.get(&TypeId::of::<T>())?;
117        let vec = column.as_any().downcast_ref::<Vec<T>>()?;
118        vec.get(location.row_index as usize)
119    }
120}
121
122// Implementation for a query of a single, mutable component reference.
123impl<T: Component> WorldQuery for &mut T {
124    type Item<'a> = &'a mut T;
125
126    fn type_ids() -> Vec<TypeId> {
127        vec![TypeId::of::<T>()]
128    }
129
130    /// Fetches a mutable reference to the component `T` from the specified row.
131    ///
132    /// # Safety
133    /// The caller MUST guarantee that:
134    /// 1. The page contains a column for component `T`.
135    /// 2. `row_index` is in bounds.
136    /// 3. No other mutable reference to this specific component exists at the same time.
137    ///    The query engine must enforce this.
138    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
139        // UNSAFE: We cast the const pointer to a mutable one.
140        // This is safe ONLY if the query engine guarantees no other access.
141        let page = &mut *(page_ptr as *mut ComponentPage);
142        let column = page.columns.get_mut(&TypeId::of::<T>()).unwrap();
143        let vec = column.as_any_mut().downcast_mut::<Vec<T>>().unwrap();
144        vec.get_unchecked_mut(row_index)
145    }
146
147    unsafe fn fetch_from_world<'a>(
148        world: *const World,
149        entity_id: EntityId,
150    ) -> Option<Self::Item<'a>> {
151        let world_mut = &mut *(world as *mut World);
152        let metadata = world_mut
153            .entities
154            .get_mut(entity_id.index as usize)?
155            .1
156            .as_mut()?;
157        let _domain = world_mut.storage.registry.get_domain(TypeId::of::<T>())?;
158        let location = metadata.locations.get(&_domain)?;
159
160        let page = &mut world_mut.storage.pages[location.page_id as usize];
161        let column = page.columns.get_mut(&TypeId::of::<T>())?;
162        let vec = column.as_any_mut().downcast_mut::<Vec<T>>()?;
163        vec.get_mut(location.row_index as usize)
164    }
165}
166
// Implementation for tuples of WorldQuery types.
// A macro generates one impl per tuple arity so the element-wise logic is written
// once. Every method simply delegates to each element in declaration order.
macro_rules! impl_query_tuple {
    ($($Q:ident),*) => {
        impl<$($Q: WorldQuery),*> WorldQuery for ($($Q,)*) {
            type Item<'a> = ($($Q::Item<'a>,)*);

            // Union of the elements' required component types, in canonical form.
            fn type_ids() -> Vec<TypeId> {
                let mut ids = Vec::new();
                $(ids.extend($Q::type_ids());)*
                // Equal `TypeId`s are indistinguishable, so stability buys nothing;
                // the unstable sort yields the same order without scratch memory.
                ids.sort_unstable();
                ids.dedup(); // Ensure unique TypeIds for canonical signature
                ids
            }

            // Union of the elements' excluded component types, in canonical form.
            fn without_type_ids() -> Vec<TypeId> {
                let mut ids = Vec::new();
                $(ids.extend($Q::without_type_ids());)*
                ids.sort_unstable();
                ids.dedup(); // Ensure unique TypeIds for canonical signature
                ids
            }

            // Safety: the caller's guarantees for the tuple must hold for every
            // element, since each element fetches from the same page and row.
            unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
                ($($Q::fetch(page_ptr, row_index),)*)
            }

            // Safety: same contract as `fetch`, applied to every element.
            // The whole tuple is `None` as soon as one element cannot be fetched.
            unsafe fn fetch_from_world<'a>(
                world: *const World,
                entity_id: EntityId,
            ) -> Option<Self::Item<'a>> {
                Some(($($Q::fetch_from_world(world, entity_id)?,)*))
            }
        }
    };
}
204
205impl_query_tuple!(Q1);
206impl_query_tuple!(Q1, Q2);
207impl_query_tuple!(Q1, Q2, Q3);
208impl_query_tuple!(Q1, Q2, Q3, Q4);
209impl_query_tuple!(Q1, Q2, Q3, Q4, Q5);
210impl_query_tuple!(Q1, Q2, Q3, Q4, Q5, Q6);
211impl_query_tuple!(Q1, Q2, Q3, Q4, Q5, Q6, Q7);
212impl_query_tuple!(Q1, Q2, Q3, Q4, Q5, Q6, Q7, Q8);
213impl_query_tuple!(Q1, Q2, Q3, Q4, Q5, Q6, Q7, Q8, Q9);
214impl_query_tuple!(Q1, Q2, Q3, Q4, Q5, Q6, Q7, Q8, Q9, Q10);
215impl_query_tuple!(Q1, Q2, Q3, Q4, Q5, Q6, Q7, Q8, Q9, Q10, Q11);
216impl_query_tuple!(Q1, Q2, Q3, Q4, Q5, Q6, Q7, Q8, Q9, Q10, Q11, Q12);
217
218// To fetch an entity's ID, we need to access the page's own entity list.
219// We also need to query for the entity ID itself.
220impl WorldQuery for EntityId {
221    type Item<'a> = EntityId;
222
223    // EntityId is not a component, it doesn't have a TypeId in the page signature.
224    fn type_ids() -> Vec<TypeId> {
225        Vec::new()
226    }
227
228    // It doesn't have a `without` filter either.
229    // The default implementation is sufficient.
230
231    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
232        let page = &*page_ptr;
233        // The entity ID is fetched from the page's own list of entities.
234        *page.entities.get_unchecked(row_index)
235    }
236
237    unsafe fn fetch_from_world<'a>(
238        _world: *const World,
239        entity_id: EntityId,
240    ) -> Option<Self::Item<'a>> {
241        // We already have the entity ID, just return it.
242        Some(entity_id)
243    }
244}
245
246// -------------------- //
247// ---- Query Part ---- //
248// -------------------- //
249
250/// An iterator that yields the results of a `WorldQuery`.
251///
252/// This struct is created by the [`World::query()`] method. It holds a reference
253/// to the world and iterates through all `ComponentPage`s that match the query's
254/// signature, fetching the requested data for each entity in those pages.
255pub struct Query<'a, Q: WorldQuery> {
256    /// A raw pointer to the world. Using a pointer avoids lifetime variance issues
257    /// and gives us the flexibility to provide either mutable or immutable access.
258    world_ptr: *const World,
259
260    /// The list of page indices that match this query (used in Native mode).
261    matching_page_indices: Vec<u32>,
262
263    /// The execution plan for this query.
264    plan: QueryPlan,
265
266    /// The index of the current page we are iterating through.
267    current_page_index: usize,
268
269    /// The index of the next row to fetch within the current page.
270    current_row_index: usize,
271
272    /// A marker to associate this iterator with the lifetime `'a` and the query type `Q`.
273    /// It tells Rust that our iterator behaves as if it's borrowing `&'a Q` from the world.
274    _phantom: PhantomData<(&'a (), Q)>,
275
276    /// Pre-computed bitset intersection for fast-failing transversal lookups.
277    combined_bitset: Option<DomainBitset>,
278}
279
280impl<'a, Q: WorldQuery> Query<'a, Q> {
281    /// (Internal) Creates a new `Query` iterator.
282    ///
283    /// This is intended to be called only by `World::query()`.
284    /// It takes the world, the plan (strategy), and the pre-calculated list
285    /// of matching pages as arguments.
286    pub(crate) fn new(world: &'a World, plan: QueryPlan, matching_page_indices: Vec<u32>) -> Self {
287        let combined_bitset = world.compute_query_bitset(&plan);
288        Self {
289            world_ptr: world as *const _,
290            matching_page_indices,
291            plan,
292            current_page_index: 0,
293            current_row_index: 0,
294            _phantom: PhantomData,
295            combined_bitset,
296        }
297    }
298}
299
300impl<'a, Q: WorldQuery> Iterator for Query<'a, Q> {
301    /// The type of item yielded by this iterator, as defined by the `WorldQuery` trait.
302    type Item = Q::Item<'a>;
303
304    /// Advances the iterator and returns the next item.
305    fn next(&mut self) -> Option<Self::Item> {
306        // Delegate to the appropriate execution path based on the pre-calculated plan.
307        match self.plan.mode {
308            QueryMode::Native => self.next_native(),
309            QueryMode::Transversal => self.next_transversal(),
310        }
311    }
312}
313
314impl<'a, Q: WorldQuery> Query<'a, Q> {
315    /// (Internal) Performs a "Native" iteration, fetching data from a single domain.
316    /// This is the most efficient execution path.
317    fn next_native(&mut self) -> Option<Q::Item<'a>> {
318        loop {
319            // 1. Check if there are any pages left to iterate through.
320            if self.current_page_index >= self.matching_page_indices.len() {
321                return None; // No more pages, iteration is finished.
322            }
323
324            // Unsafe block because we are dereferencing a raw pointer.
325            // This is safe because the `Query` is only created from a valid `World` reference.
326            let world = unsafe { &*self.world_ptr };
327
328            // 2. Get the current page.
329            let page_id = self.matching_page_indices[self.current_page_index];
330            let page = &world.storage.pages[page_id as usize];
331
332            // 3. Check if there are rows left in the current page.
333            if self.current_row_index < page.row_count() {
334                let item = unsafe {
335                    // Safe because the page signature matches the query requirements.
336                    Q::fetch(page as *const _, self.current_row_index)
337                };
338
339                // Advance the row index for the next call.
340                self.current_row_index += 1;
341                return Some(item);
342            } else {
343                self.current_page_index += 1;
344                self.current_row_index = 0; // Reset the row index for the new page.
345                                            // The `loop` will then re-evaluate with the new page index.
346            }
347        }
348    }
349
350    /// (Internal) Performs a "Transversal" iteration, joining data across domains.
351    /// This uses the driver domain to find entities and then pulls peer data from other domains.
352    fn next_transversal(&mut self) -> Option<Q::Item<'a>> {
353        loop {
354            if self.current_page_index >= self.matching_page_indices.len() {
355                return None;
356            }
357
358            let world = unsafe { &*self.world_ptr };
359            let page_id = self.matching_page_indices[self.current_page_index];
360            let page = &world.storage.pages[page_id as usize];
361
362            if self.current_row_index < page.row_count() {
363                // In transversal mode, we use the EntityId from the driver page to look up
364                // counterpart components in the peer domains.
365                let entity_id = page.entities[self.current_row_index];
366                self.current_row_index += 1;
367
368                // Optimization: Skip metadata lookup if the entity is not in the combined bitset.
369                if let Some(bitset) = &self.combined_bitset {
370                    if !bitset.is_set(entity_id.index) {
371                        continue;
372                    }
373                }
374
375                // Attempt to fetch the Full item (driver + peers) from the world.
376                if let Some(item) = unsafe { Q::fetch_from_world(world as *const _, entity_id) } {
377                    return Some(item);
378                }
379                // If the join failed for this entity, we continue to the next one.
380            } else {
381                self.current_page_index += 1;
382                self.current_row_index = 0;
383            }
384        }
385    }
386}
387
388/// A `WorldQuery` filter that matches entities that do NOT have component `T`.
389///
390/// This is used as a marker in a query tuple to exclude entities. For example,
391/// `Query<(&Position, Without<Velocity>)>` will iterate over all entities
392/// that have a `Position` but no `Velocity`.
393pub struct Without<T: Component>(PhantomData<T>);
394
395// `Without<T>` itself doesn't fetch any data, so its `WorldQuery` implementation
396// is mostly empty. It acts as a signal to the query engine.
397impl<T: Component> WorldQuery for Without<T> {
398    /// This query item is a zero-sized unit type, as it fetches no data.
399    type Item<'a> = ();
400
401    /// A `Without` filter does not add any component types to the query's main
402    /// signature for page matching. The filtering is handled separately.
403    fn type_ids() -> Vec<TypeId> {
404        Vec::new() // Returns an empty Vec.
405    }
406
407    /// This is the key part of the filter: it returns the TypeId of the component to filter out.
408    fn without_type_ids() -> Vec<TypeId> {
409        // This is the key change: it returns the TypeId of the component to filter out.
410        vec![TypeId::of::<T>()]
411    }
412
413    /// Fetches nothing. Returns a unit type `()`.
414    unsafe fn fetch<'a>(_page_ptr: *const ComponentPage, _row_index: usize) -> Self::Item<'a> {
415        // This function will be called but its result is ignored.
416    }
417
418    unsafe fn fetch_from_world<'a>(
419        world: *const World,
420        entity_id: EntityId,
421    ) -> Option<Self::Item<'a>> {
422        let world = unsafe { &*world };
423        let metadata = world.entities.get(entity_id.index as usize)?.1.as_ref()?;
424
425        // Check ALL pages associated with this entity.
426        // If ANY page contains the forbidden component, filtering failed.
427        for location in metadata.locations.values() {
428            let page = &world.storage.pages[location.page_id as usize];
429            if page.type_ids.binary_search(&TypeId::of::<T>()).is_ok() {
430                return None;
431            }
432        }
433        Some(())
434    }
435}
436
437// ------------------------- //
438// ---- QueryMut Part ---- //
439// ------------------------- //
440
441/// An iterator that yields the results of a mutable `WorldQuery`.
442///
443/// This struct is created by the [`World::query_mut()`] method.
444pub struct QueryMut<'a, Q: WorldQuery> {
445    world_ptr: *mut World,
446    matching_page_indices: Vec<u32>,
447    plan: QueryPlan,
448    current_page_index: usize,
449    current_row_index: usize,
450    _phantom: PhantomData<(&'a (), Q)>,
451    /// Pre-computed bitset intersection for fast-failing transversal lookups.
452    combined_bitset: Option<DomainBitset>,
453}
454
455impl<'a, Q: WorldQuery> QueryMut<'a, Q> {
456    /// Creates a new `QueryMut` iterator.
457    ///
458    /// This is intended to be called only by `World::query_mut()`.
459    pub(crate) fn new(
460        world: &'a mut World,
461        plan: QueryPlan,
462        matching_page_indices: Vec<u32>,
463    ) -> Self {
464        let combined_bitset = world.compute_query_bitset(&plan);
465        Self {
466            world_ptr: world as *mut _,
467            matching_page_indices,
468            plan,
469            current_page_index: 0,
470            current_row_index: 0,
471            _phantom: PhantomData,
472            combined_bitset,
473        }
474    }
475}
476
477impl<'a, Q: WorldQuery> Iterator for QueryMut<'a, Q> {
478    type Item = Q::Item<'a>;
479
480    fn next(&mut self) -> Option<Self::Item> {
481        match self.plan.mode {
482            QueryMode::Native => self.next_native(),
483            QueryMode::Transversal => self.next_transversal(),
484        }
485    }
486}
487
488impl<'a, Q: WorldQuery> QueryMut<'a, Q> {
489    fn next_native(&mut self) -> Option<Q::Item<'a>> {
490        loop {
491            if self.current_page_index >= self.matching_page_indices.len() {
492                return None;
493            }
494
495            // SAFETY: `world_ptr` is guaranteed to be valid for the lifetime 'a.
496            let world = unsafe { &mut *self.world_ptr };
497            let page_id = self.matching_page_indices[self.current_page_index] as usize;
498
499            // We get a mutable reference to the page, which is a safe operation
500            // because `world` is a mutable reference.
501            let page = &mut world.storage.pages[page_id];
502
503            if self.current_row_index < page.row_count() {
504                let item = unsafe { Q::fetch(page as *mut _ as *const _, self.current_row_index) };
505
506                self.current_row_index += 1;
507                return Some(item);
508            } else {
509                self.current_page_index += 1;
510                self.current_row_index = 0;
511            }
512        }
513    }
514
515    fn next_transversal(&mut self) -> Option<Q::Item<'a>> {
516        loop {
517            if self.current_page_index >= self.matching_page_indices.len() {
518                return None;
519            }
520
521            // Get the current page.
522            let world = unsafe { &mut *self.world_ptr };
523            let page_id = self.matching_page_indices[self.current_page_index] as usize;
524            let page = &mut world.storage.pages[page_id];
525
526            // Check if there are rows left in the current page.
527            if self.current_row_index < page.row_count() {
528                let entity_id = page.entities[self.current_row_index];
529                self.current_row_index += 1;
530
531                // Skip entities that are not in the combined bitset.
532                if let Some(combined) = &self.combined_bitset {
533                    if !combined.is_set(entity_id.index) {
534                        continue;
535                    }
536                }
537
538                // Attempt to fetch the Full item (driver + peers) from the world.
539                if let Some(item) = unsafe { Q::fetch_from_world(world as *const _, entity_id) } {
540                    return Some(item);
541                }
542            } else {
543                self.current_page_index += 1;
544                self.current_row_index = 0;
545            }
546        }
547    }
548}