khora_data/ecs/
query.rs

1// Copyright 2025 eraflo
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use khora_core::ecs::entity::EntityId;
16
17use crate::ecs::{
18    page::{AnyVec, ComponentPage},
19    Component, World,
20};
21use std::{any::TypeId, marker::PhantomData};
22
23// ------------------------- //
24// ---- WorldQuery Part ---- //
25// ------------------------- //
26
27/// A trait implemented by types that can be used to query data from the `World`.
28///
29/// This "sealed" trait provides the necessary information for the query engine to
30/// find the correct pages and safely access component data. It is implemented for
31/// component references (`&T`, `&mut T`), `EntityId`, filters like `Without<T>`,
32/// and tuples of other `WorldQuery` types.
33pub trait WorldQuery {
34    /// The type of item that the query iterator will yield (e.g., `(&'a Position, &'a mut Velocity)`).
35    type Item<'a>;
36
37    /// Returns the sorted list of `TypeId`s for the components to be INCLUDED in the query.
38    /// This signature is used to find `ComponentPage`s that contain all these components.
39    fn type_ids() -> Vec<TypeId>;
40
41    /// Returns the sorted list of `TypeId`s for components to be EXCLUDED from the query.
42    /// Used to filter out pages that contain these components.
43    fn without_type_ids() -> Vec<TypeId> {
44        Vec::new()
45    }
46
47    /// Fetches the query's item from a specific row in a `ComponentPage`.
48    ///
49    /// # Safety
50    ///
51    /// This function is unsafe because the caller (the `Query` iterator) must guarantee
52    /// several invariants:
53    /// 1. The page pointed to by `page_ptr` is valid and matches the query's signature.
54    /// 2. `row_index` is a valid index within the page's columns.
55    /// 3. Aliasing rules are not violated (e.g., no two `&mut T` to the same data).
56    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a>;
57}
58
59// Implementation for a query of a single, immutable component reference.
60impl<T: Component> WorldQuery for &T {
61    type Item<'a> = &'a T;
62
63    fn type_ids() -> Vec<TypeId> {
64        vec![TypeId::of::<T>()]
65    }
66
67    /// Fetches a reference to the component `T` from the specified row.
68    ///
69    /// # Safety
70    /// The caller MUST guarantee that the page contains a column for component `T`
71    /// and that `row_index` is in bounds.
72    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
73        // 1. Get a reference to the `ComponentPage`.
74        let page = &*page_ptr;
75
76        // 2. Get the type-erased column for the component `T`.
77        // We can unwrap because the caller guarantees the column exists.
78        let column: &dyn AnyVec = &**page.columns.get(&TypeId::of::<T>()).unwrap();
79
80        // 3. Downcast the column to its concrete `Vec<T>` type.
81        // First, cast to `&dyn Any`, then `downcast_ref`.
82        let vec: &Vec<T> = column.as_any().downcast_ref::<Vec<T>>().unwrap();
83
84        // 4. Get the component from the vector at the specified row.
85        // We use `get_unchecked` for performance, as the caller guarantees the
86        // index is in bounds. This avoids a bounds check.
87        vec.get_unchecked(row_index)
88    }
89}
90
91// Implementation for a query of a single, mutable component reference.
92impl<T: Component> WorldQuery for &mut T {
93    type Item<'a> = &'a mut T;
94
95    fn type_ids() -> Vec<TypeId> {
96        vec![TypeId::of::<T>()]
97    }
98
99    /// Fetches a mutable reference to the component `T` from the specified row.
100    ///
101    /// # Safety
102    /// The caller MUST guarantee that:
103    /// 1. The page contains a column for component `T`.
104    /// 2. `row_index` is in bounds.
105    /// 3. No other mutable reference to this specific component exists at the same time.
106    ///    The query engine must enforce this.
107    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
108        // UNSAFE: We cast the const pointer to a mutable one.
109        // This is safe ONLY if the query engine guarantees no other access.
110        let page = &mut *(page_ptr as *mut ComponentPage);
111        let column = page.columns.get_mut(&TypeId::of::<T>()).unwrap();
112        let vec = column.as_any_mut().downcast_mut::<Vec<T>>().unwrap();
113        vec.get_unchecked_mut(row_index)
114    }
115}
116
117// Implementation for a 1-item tuple query.
118impl<Q1: WorldQuery> WorldQuery for (Q1,) {
119    type Item<'a> = (Q1::Item<'a>,);
120
121    fn type_ids() -> Vec<TypeId> {
122        Q1::type_ids()
123    }
124
125    fn without_type_ids() -> Vec<TypeId> {
126        Q1::without_type_ids()
127    }
128
129    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
130        (Q1::fetch(page_ptr, row_index),)
131    }
132}
133
134// Implementation for a query of a 2-item tuple.
135// This is now generic over any two types that implement `WorldQuery`.
136impl<Q1: WorldQuery, Q2: WorldQuery> WorldQuery for (Q1, Q2) {
137    type Item<'a> = (Q1::Item<'a>, Q2::Item<'a>);
138
139    fn type_ids() -> Vec<TypeId> {
140        let mut ids = Q1::type_ids();
141        ids.extend(Q2::type_ids());
142        // Sorting is crucial to maintain a canonical signature.
143        ids.sort();
144        // We also remove duplicates, in case of queries like `(&T, &T)`.
145        ids.dedup();
146        ids
147    }
148
149    fn without_type_ids() -> Vec<TypeId> {
150        let mut ids = Q1::without_type_ids();
151        ids.extend(Q2::without_type_ids());
152        ids.sort();
153        ids.dedup();
154        ids
155    }
156
157    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
158        // Fetch data for each part of the tuple individually.
159        let item1 = Q1::fetch(page_ptr, row_index);
160        let item2 = Q2::fetch(page_ptr, row_index);
161        (item1, item2)
162    }
163}
164
165// Implementation for a query of a 3-item tuple.
166impl<Q1: WorldQuery, Q2: WorldQuery, Q3: WorldQuery> WorldQuery for (Q1, Q2, Q3) {
167    type Item<'a> = (Q1::Item<'a>, Q2::Item<'a>, Q3::Item<'a>);
168
169    fn type_ids() -> Vec<TypeId> {
170        let mut ids = Q1::type_ids();
171        ids.extend(Q2::type_ids());
172        ids.extend(Q3::type_ids());
173        ids.sort();
174        ids.dedup();
175        ids
176    }
177
178    fn without_type_ids() -> Vec<TypeId> {
179        let mut ids = Q1::without_type_ids();
180        ids.extend(Q2::without_type_ids());
181        ids.extend(Q3::without_type_ids());
182        ids.sort();
183        ids.dedup();
184        ids
185    }
186
187    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
188        // Fetch data for each part of the tuple individually.
189        let item1 = Q1::fetch(page_ptr, row_index);
190        let item2 = Q2::fetch(page_ptr, row_index);
191        let item3 = Q3::fetch(page_ptr, row_index);
192        (item1, item2, item3)
193    }
194}
195
196// Implementation for a query of a 4-item tuple.
197impl<Q1: WorldQuery, Q2: WorldQuery, Q3: WorldQuery, Q4: WorldQuery> WorldQuery
198    for (Q1, Q2, Q3, Q4)
199{
200    type Item<'a> = (Q1::Item<'a>, Q2::Item<'a>, Q3::Item<'a>, Q4::Item<'a>);
201
202    fn type_ids() -> Vec<TypeId> {
203        let mut ids = Q1::type_ids();
204        ids.extend(Q2::type_ids());
205        ids.extend(Q3::type_ids());
206        ids.extend(Q4::type_ids());
207        ids.sort();
208        ids.dedup();
209        ids
210    }
211
212    fn without_type_ids() -> Vec<TypeId> {
213        let mut ids = Q1::without_type_ids();
214        ids.extend(Q2::without_type_ids());
215        ids.extend(Q3::without_type_ids());
216        ids.extend(Q4::without_type_ids());
217        ids.sort();
218        ids.dedup();
219        ids
220    }
221
222    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
223        (
224            Q1::fetch(page_ptr, row_index),
225            Q2::fetch(page_ptr, row_index),
226            Q3::fetch(page_ptr, row_index),
227            Q4::fetch(page_ptr, row_index),
228        )
229    }
230}
231
232// To fetch an entity's ID, we need to access the page's own entity list.
233// We also need to query for the entity ID itself.
234impl WorldQuery for EntityId {
235    type Item<'a> = EntityId;
236
237    // EntityId is not a component, it doesn't have a TypeId in the page signature.
238    fn type_ids() -> Vec<TypeId> {
239        Vec::new()
240    }
241
242    // It doesn't have a `without` filter either.
243    // The default implementation is sufficient.
244
245    unsafe fn fetch<'a>(page_ptr: *const ComponentPage, row_index: usize) -> Self::Item<'a> {
246        let page = &*page_ptr;
247        // The entity ID is fetched from the page's own list of entities.
248        *page.entities.get_unchecked(row_index)
249    }
250}
251
252// -------------------- //
253// ---- Query Part ---- //
254// -------------------- //
255
256/// An iterator that yields the results of a `WorldQuery`.
257///
258/// This struct is created by the [`World::query()`] method. It holds a reference
259/// to the world and iterates through all `ComponentPage`s that match the query's
260/// signature, fetching the requested data for each entity in those pages.
261pub struct Query<'a, Q: WorldQuery> {
262    /// A raw pointer to the world. Using a pointer avoids lifetime variance issues
263    /// and gives us the flexibility to provide either mutable or immutable access.
264    world_ptr: *const World,
265
266    /// The list of page indices that match this query.
267    matching_page_indices: Vec<u32>,
268
269    /// The index of the current page we are iterating through.
270    current_page_index: usize,
271
272    /// The index of the next row to fetch within the current page.
273    current_row_index: usize,
274
275    /// A marker to associate this iterator with the lifetime `'a` and the query type `Q`.
276    /// It tells Rust that our iterator behaves as if it's borrowing `&'a Q` from the world.
277    _phantom: PhantomData<(&'a (), Q)>,
278}
279
280impl<'a, Q: WorldQuery> Query<'a, Q> {
281    /// (Internal) Creates a new `Query` iterator.
282    ///
283    /// This is intended to be called only by `World::query()`.
284    /// It takes the world and the list of matching pages as arguments.
285    pub(crate) fn new(world: &'a World, matching_page_indices: Vec<u32>) -> Self {
286        Self {
287            world_ptr: world as *const _,
288            matching_page_indices,
289            current_page_index: 0,
290            current_row_index: 0,
291            _phantom: PhantomData,
292        }
293    }
294}
295
296impl<'a, Q: WorldQuery> Iterator for Query<'a, Q> {
297    /// The type of item yielded by this iterator, as defined by the `WorldQuery` trait.
298    type Item = Q::Item<'a>;
299
300    /// Advances the iterator and returns the next item.
301    fn next(&mut self) -> Option<Self::Item> {
302        loop {
303            // 1. Check if there are any pages left to iterate through.
304            if self.current_page_index >= self.matching_page_indices.len() {
305                return None; // No more pages, iteration is finished.
306            }
307
308            // Unsafe block because we are dereferencing a raw pointer.
309            // This is safe because the `Query` is only created from a valid `World` reference.
310            let world = unsafe { &*self.world_ptr };
311
312            // 2. Get the current page.
313            let page_id = self.matching_page_indices[self.current_page_index];
314            let page = &world.pages[page_id as usize];
315
316            // 3. Check if there are rows left in the current page.
317            if self.current_row_index < page.row_count() {
318                // There are rows left, so we can fetch the data.
319                let item = unsafe {
320                    // This is safe because we've already matched the page signature
321                    // and we are iterating within the page's bounds.
322                    Q::fetch(page as *const _, self.current_row_index)
323                };
324
325                // Advance the row index for the next call.
326                self.current_row_index += 1;
327                return Some(item);
328            } else {
329                // 4. No more rows in this page. Move to the next page.
330                self.current_page_index += 1;
331                self.current_row_index = 0; // Reset the row index for the new page.
332                                            // The `loop` will then re-evaluate with the new page index.
333            }
334        }
335    }
336}
337
338/// A `WorldQuery` filter that matches entities that do NOT have component `T`.
339///
340/// This is used as a marker in a query tuple to exclude entities. For example,
341/// `Query<(&Position, Without<Velocity>)>` will iterate over all entities
342/// that have a `Position` but no `Velocity`.
343pub struct Without<T: Component>(PhantomData<T>);
344
345// `Without<T>` itself doesn't fetch any data, so its `WorldQuery` implementation
346// is mostly empty. It acts as a signal to the query engine.
347impl<T: Component> WorldQuery for Without<T> {
348    /// This query item is a zero-sized unit type, as it fetches no data.
349    type Item<'a> = ();
350
351    /// A `Without` filter does not add any component types to the query's main
352    /// signature for page matching. The filtering is handled separately.
353    fn type_ids() -> Vec<TypeId> {
354        Vec::new() // Returns an empty Vec.
355    }
356
357    /// This is the key part of the filter: it returns the TypeId of the component to filter out.
358    fn without_type_ids() -> Vec<TypeId> {
359        // This is the key change: it returns the TypeId of the component to filter out.
360        vec![TypeId::of::<T>()]
361    }
362
363    /// Fetches nothing. Returns a unit type `()`.
364    unsafe fn fetch<'a>(_page_ptr: *const ComponentPage, _row_index: usize) -> Self::Item<'a> {
365        // This function will be called but its result is ignored.
366    }
367}
368
369// ------------------------- //
370// ---- QueryMut Part ---- //
371// ------------------------- //
372
/// An iterator that yields the results of a mutable `WorldQuery`.
///
/// This struct is created by the [`World::query_mut()`] method.
///
/// Unlike `Query`, the world pointer here is created from a `&'a mut World`
/// (see `QueryMut::new`), which is what makes fetching `&mut T` items possible.
/// NOTE(review): it does not by itself prevent a query such as `(&mut T, &mut T)`
/// from producing aliasing references; confirm that is rejected where the query
/// is built.
pub struct QueryMut<'a, Q: WorldQuery> {
    /// Raw pointer to the world, created from the `&'a mut World` given to `new`.
    world_ptr: *mut World,
    /// Indices of the world's pages that match the query, in iteration order.
    matching_page_indices: Vec<u32>,
    /// Position in `matching_page_indices` of the page currently being iterated.
    current_page_index: usize,
    /// The index of the next row to fetch within the current page.
    current_row_index: usize,
    /// Associates the iterator with the lifetime `'a` and the query type `Q`.
    _phantom: PhantomData<(&'a (), Q)>,
}
383
384impl<'a, Q: WorldQuery> QueryMut<'a, Q> {
385    pub(crate) fn new(world: &'a mut World, matching_page_indices: Vec<u32>) -> Self {
386        Self {
387            world_ptr: world as *mut _,
388            matching_page_indices,
389            current_page_index: 0,
390            current_row_index: 0,
391            _phantom: PhantomData,
392        }
393    }
394}
395
396impl<'a, Q: WorldQuery> Iterator for QueryMut<'a, Q> {
397    type Item = Q::Item<'a>;
398
399    fn next(&mut self) -> Option<Self::Item> {
400        loop {
401            if self.current_page_index >= self.matching_page_indices.len() {
402                return None;
403            }
404
405            // SAFETY: `world_ptr` is guaranteed to be valid for the lifetime 'a.
406            let world = unsafe { &mut *self.world_ptr };
407            let page_id = self.matching_page_indices[self.current_page_index] as usize;
408
409            // We get a mutable reference to the page, which is a safe operation
410            // because `world` is a mutable reference.
411            let page = &mut world.pages[page_id];
412
413            if self.current_row_index < page.row_count() {
414                let item = unsafe {
415                    // We can now safely pass a pointer that originates from a mutable reference.
416                    // The `fetch` implementation for `&mut T` will cast this pointer, but the operation
417                    // is now sound because the original source was mutable. This eliminates the UB.
418                    Q::fetch(page as *mut _ as *const _, self.current_row_index)
419                };
420
421                self.current_row_index += 1;
422                return Some(item);
423            } else {
424                // No more rows, advance to the next page.
425                self.current_page_index += 1;
426                self.current_row_index = 0;
427            }
428        }
429    }
430}