1use std::{
8 any::{Any, TypeId},
9 cell::RefCell,
10 marker::PhantomData,
11 sync::Arc,
12};
13
14use im::HashMap as ImHashMap;
15use parking_lot::RwLock;
16use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
17
18use crate::{
19 execution_context::{with_execution_context, with_execution_context_mut},
20 runtime::{
21 RuntimePhase, compute_context_slot_key, current_component_instance_key_from_scope,
22 current_phase, ensure_build_phase, record_component_invalidation_for_instance_key,
23 },
24};
25
26#[derive(Clone, Copy, Debug, Eq, PartialEq)]
27pub(crate) struct ContextSnapshotEntry {
28 slot: u32,
29 generation: u64,
30 key: SlotKey,
31}
32
33pub(crate) type ContextMap = ImHashMap<TypeId, ContextSnapshotEntry>;
34
/// Identity of a context slot: the providing component instance, the slot hash computed by
/// `compute_context_slot_key` for the providing call, and the stored value's type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct SlotKey {
    instance_logic_id: u64,
    slot_hash: u64,
    type_id: TypeId,
}

/// One row of the slot table.
struct SlotEntry {
    /// Key of the provider that occupies (or last occupied) this slot.
    key: SlotKey,
    /// Incremented when the slot is freed, so handles captured earlier become stale.
    generation: u64,
    /// Type-erased stored value (a `RwLock<T>` when filled by `provide_context`);
    /// `None` while the slot is free.
    value: Option<Arc<dyn Any + Send + Sync>>,
    /// Slot-table epoch in which the slot was last provided; reset to 0 when freed.
    last_alive_epoch: u64,
}
48
49#[derive(Default)]
50struct SlotTable {
51 entries: Vec<SlotEntry>,
52 free_list: Vec<u32>,
53 key_to_slot: HashMap<SlotKey, u32>,
54 epoch: u64,
55}
56
57impl SlotTable {
58 fn begin_epoch(&mut self) {
59 self.epoch = self.epoch.wrapping_add(1);
60 }
61}
62
63fn with_slot_table<R>(f: impl FnOnce(&SlotTable) -> R) -> R {
64 CONTEXT_GLOBALS.with(|globals| f(&globals.slot_table.borrow()))
65}
66
67fn with_slot_table_mut<R>(f: impl FnOnce(&mut SlotTable) -> R) -> R {
68 CONTEXT_GLOBALS.with(|globals| f(&mut globals.slot_table.borrow_mut()))
69}
70
71#[derive(Default)]
72struct ContextSnapshotTracker {
73 previous_by_instance_key: HashMap<u64, ContextMap>,
74 current_by_instance_key: HashMap<u64, ContextMap>,
75}
76
77#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
78struct ContextReadDependencyKey {
79 slot: u32,
80 generation: u64,
81}
82
83#[derive(Default)]
84struct ContextReadDependencyTracker {
85 readers_by_context: HashMap<ContextReadDependencyKey, HashSet<u64>>,
86 contexts_by_reader: HashMap<u64, HashSet<ContextReadDependencyKey>>,
87}
88
89fn with_context_snapshot_tracker<R>(f: impl FnOnce(&ContextSnapshotTracker) -> R) -> R {
90 CONTEXT_GLOBALS.with(|globals| f(&globals.snapshot_tracker.borrow()))
91}
92
93fn with_context_snapshot_tracker_mut<R>(f: impl FnOnce(&mut ContextSnapshotTracker) -> R) -> R {
94 CONTEXT_GLOBALS.with(|globals| f(&mut globals.snapshot_tracker.borrow_mut()))
95}
96
97fn with_context_read_dependency_tracker<R>(
98 f: impl FnOnce(&ContextReadDependencyTracker) -> R,
99) -> R {
100 CONTEXT_GLOBALS.with(|globals| f(&globals.read_dependency_tracker.borrow()))
101}
102
103fn with_context_read_dependency_tracker_mut<R>(
104 f: impl FnOnce(&mut ContextReadDependencyTracker) -> R,
105) -> R {
106 CONTEXT_GLOBALS.with(|globals| f(&mut globals.read_dependency_tracker.borrow_mut()))
107}
108
109struct ContextGlobals {
110 slot_table: RefCell<SlotTable>,
111 snapshot_tracker: RefCell<ContextSnapshotTracker>,
112 read_dependency_tracker: RefCell<ContextReadDependencyTracker>,
113}
114
115impl ContextGlobals {
116 fn new() -> Self {
117 Self {
118 slot_table: RefCell::new(SlotTable::default()),
119 snapshot_tracker: RefCell::new(ContextSnapshotTracker::default()),
120 read_dependency_tracker: RefCell::new(ContextReadDependencyTracker::default()),
121 }
122 }
123}
124
125thread_local! {
126 static CONTEXT_GLOBALS: ContextGlobals = ContextGlobals::new();
127}
128
129fn current_context_map() -> ContextMap {
130 with_execution_context(|context| {
131 context
132 .context_stack
133 .last()
134 .cloned()
135 .unwrap_or_else(ContextMap::new)
136 })
137}
138
139fn resolve_snapshot_entry(entry: ContextSnapshotEntry) -> Option<ContextSnapshotEntry> {
140 with_slot_table(|table| {
141 let live_entry = table.entries.get(entry.slot as usize);
142 if let Some(live_entry) = live_entry
143 && live_entry.value.is_some()
144 && live_entry.key == entry.key
145 {
146 return Some(ContextSnapshotEntry {
147 slot: entry.slot,
148 generation: live_entry.generation,
149 key: entry.key,
150 });
151 }
152
153 let slot = table.key_to_slot.get(&entry.key).copied()?;
154 let live_entry = table.entries.get(slot as usize)?;
155 live_entry.value.as_ref()?;
156 Some(ContextSnapshotEntry {
157 slot,
158 generation: live_entry.generation,
159 key: entry.key,
160 })
161 })
162}
163
164fn normalize_context_snapshot(snapshot: &ContextMap) -> ContextMap {
165 snapshot
166 .iter()
167 .fold(ContextMap::new(), |acc, (type_id, entry)| {
168 if let Some(resolved) = resolve_snapshot_entry(*entry) {
169 acc.update(*type_id, resolved)
170 } else {
171 acc
172 }
173 })
174}
175
176pub(crate) fn begin_frame_component_context_tracking() {
177 with_context_snapshot_tracker_mut(|tracker| tracker.current_by_instance_key.clear());
178}
179
180pub(crate) fn finalize_frame_component_context_tracking() {
181 with_context_snapshot_tracker_mut(|tracker| {
182 tracker.previous_by_instance_key = std::mem::take(&mut tracker.current_by_instance_key);
183 });
184}
185
186pub(crate) fn finalize_frame_component_context_tracking_partial() {
187 with_context_snapshot_tracker_mut(|tracker| {
188 let current = std::mem::take(&mut tracker.current_by_instance_key);
189 tracker.previous_by_instance_key.extend(current);
190 });
191}
192
193pub(crate) fn reset_component_context_tracking() {
194 with_context_snapshot_tracker_mut(|tracker| {
195 *tracker = ContextSnapshotTracker::default();
196 });
197}
198
199pub(crate) fn previous_component_context_snapshots() -> HashMap<u64, ContextMap> {
200 with_context_snapshot_tracker(|tracker| tracker.previous_by_instance_key.clone())
201}
202
203pub(crate) fn context_from_previous_snapshot_for_instance<T>(
204 instance_key: u64,
205) -> Option<Context<T>>
206where
207 T: Send + Sync + 'static,
208{
209 with_context_snapshot_tracker(|tracker| {
210 let map = tracker.previous_by_instance_key.get(&instance_key)?;
211 map.get(&TypeId::of::<T>())
212 .copied()
213 .and_then(resolve_snapshot_entry)
214 .map(|entry| Context::new(entry.slot, entry.generation))
215 })
216}
217
218pub(crate) fn remove_previous_component_context_snapshots(instance_keys: &HashSet<u64>) {
219 if instance_keys.is_empty() {
220 return;
221 }
222 with_context_snapshot_tracker_mut(|tracker| {
223 tracker
224 .previous_by_instance_key
225 .retain(|instance_key, _| !instance_keys.contains(instance_key));
226 tracker
227 .current_by_instance_key
228 .retain(|instance_key, _| !instance_keys.contains(instance_key));
229 });
230}
231
232pub(crate) fn remove_context_read_dependencies(instance_keys: &HashSet<u64>) {
233 if instance_keys.is_empty() {
234 return;
235 }
236 with_context_read_dependency_tracker_mut(|tracker| {
237 for instance_key in instance_keys {
238 let Some(context_keys) = tracker.contexts_by_reader.remove(instance_key) else {
239 continue;
240 };
241 for context_key in context_keys {
242 let mut remove_entry = false;
243 if let Some(readers) = tracker.readers_by_context.get_mut(&context_key) {
244 readers.remove(instance_key);
245 remove_entry = readers.is_empty();
246 }
247 if remove_entry {
248 tracker.readers_by_context.remove(&context_key);
249 }
250 }
251 }
252 });
253}
254
255pub(crate) fn reset_context_read_dependencies() {
256 with_context_read_dependency_tracker_mut(|tracker| {
257 *tracker = ContextReadDependencyTracker::default();
258 });
259}
260
261pub(crate) fn record_current_context_snapshot_for(instance_key: u64) {
262 with_context_snapshot_tracker_mut(|tracker| {
263 tracker
264 .current_by_instance_key
265 .insert(instance_key, current_context_map());
266 });
267}
268
269pub(crate) fn with_context_snapshot<R>(snapshot: &ContextMap, f: impl FnOnce() -> R) -> R {
270 struct ContextSnapshotGuard {
271 previous_stack: Option<Vec<ContextMap>>,
272 }
273
274 impl Drop for ContextSnapshotGuard {
275 fn drop(&mut self) {
276 if let Some(previous_stack) = self.previous_stack.take() {
277 with_execution_context_mut(|context| {
278 context.context_stack = previous_stack;
279 });
280 }
281 }
282 }
283
284 let normalized_snapshot = normalize_context_snapshot(snapshot);
285 let previous_stack = with_execution_context_mut(|context| {
286 std::mem::replace(&mut context.context_stack, vec![normalized_snapshot])
287 });
288 let _guard = ContextSnapshotGuard {
289 previous_stack: Some(previous_stack),
290 };
291
292 f()
293}
294
295fn track_context_read_dependency(slot: u32, generation: u64) {
296 if !matches!(current_phase(), Some(RuntimePhase::Build)) {
297 return;
298 }
299 let Some(reader_instance_key) = current_component_instance_key_from_scope() else {
300 return;
301 };
302
303 let key = ContextReadDependencyKey { slot, generation };
304 with_context_read_dependency_tracker_mut(|tracker| {
305 tracker
306 .readers_by_context
307 .entry(key)
308 .or_default()
309 .insert(reader_instance_key);
310 tracker
311 .contexts_by_reader
312 .entry(reader_instance_key)
313 .or_default()
314 .insert(key);
315 });
316}
317
318fn context_read_subscribers(slot: u32, generation: u64) -> Vec<u64> {
319 let key = ContextReadDependencyKey { slot, generation };
320 with_context_read_dependency_tracker(|tracker| {
321 tracker
322 .readers_by_context
323 .get(&key)
324 .map(|readers| readers.iter().copied().collect())
325 .unwrap_or_default()
326 })
327}
328
329pub(crate) fn begin_recompose_context_slot_epoch() {
330 with_slot_table_mut(SlotTable::begin_epoch);
332 with_execution_context_mut(|context| {
333 context.context_stack.clear();
334 context.context_stack.push(ContextMap::new());
335 });
336}
337
338pub(crate) fn recycle_recomposed_context_slots_for_instance_logic_ids(
339 instance_logic_ids: &HashSet<u64>,
340) {
341 if instance_logic_ids.is_empty() {
342 return;
343 }
344
345 with_slot_table_mut(|table| {
347 let epoch = table.epoch;
348 let mut freed: Vec<(u32, SlotKey)> = Vec::new();
349 for (slot, entry) in table.entries.iter_mut().enumerate() {
350 if !instance_logic_ids.contains(&entry.key.instance_logic_id) {
351 continue;
352 }
353 if entry.value.is_none() {
354 continue;
355 }
356 if entry.last_alive_epoch == epoch {
357 continue;
358 }
359
360 freed.push((slot as u32, entry.key));
361 entry.value = None;
362 entry.generation = entry.generation.wrapping_add(1);
363 entry.last_alive_epoch = 0;
364 }
365
366 for (slot, key) in freed {
367 table.key_to_slot.remove(&key);
368 table.free_list.push(slot);
369 }
370 });
371}
372
373pub(crate) fn live_context_slot_instance_logic_ids() -> HashSet<u64> {
374 with_slot_table(|table| {
375 table
376 .entries
377 .iter()
378 .filter(|entry| entry.value.is_some())
379 .map(|entry| entry.key.instance_logic_id)
380 .collect()
381 })
382}
383
384pub(crate) fn drop_context_slots_for_instance_logic_ids(instance_logic_ids: &HashSet<u64>) {
385 if instance_logic_ids.is_empty() {
386 return;
387 }
388
389 with_slot_table_mut(|table| {
390 let mut freed: Vec<(u32, SlotKey)> = Vec::new();
391 for (slot, entry) in table.entries.iter_mut().enumerate() {
392 if entry.value.is_none() {
393 continue;
394 }
395 if !instance_logic_ids.contains(&entry.key.instance_logic_id) {
396 continue;
397 }
398 freed.push((slot as u32, entry.key));
399 entry.value = None;
400 entry.generation = entry.generation.wrapping_add(1);
401 entry.last_alive_epoch = 0;
402 }
403
404 for (slot, key) in freed {
405 table.key_to_slot.remove(&key);
406 table.free_list.push(slot);
407 }
408 });
409}
410
/// A copyable handle to a context value provided by `provide_context`.
///
/// The handle stores only a slot index and the slot generation observed when it was created;
/// the value itself lives in the thread-local slot table. When the slot is freed its
/// generation changes, and accessing the value through an older handle panics as stale.
pub struct Context<T> {
    slot: u32,
    generation: u64,
    _marker: PhantomData<T>,
}

// Manual impls: deriving would add `T: Copy` / `T: Clone` bounds, but the handle never
// stores a `T`.
impl<T> Copy for Context<T> {}

impl<T> Clone for Context<T> {
    fn clone(&self) -> Self {
        *self
    }
}

// Manual impl for the same reason: `#[derive(Debug)]` would require `T: Debug`.
impl<T> std::fmt::Debug for Context<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Context")
            .field("type", &std::any::type_name::<T>())
            .field("slot", &self.slot)
            .field("generation", &self.generation)
            .finish()
    }
}

impl<T> Context<T> {
    /// Creates a handle for `slot` at `generation`; validity is checked on each access.
    fn new(slot: u32, generation: u64) -> Self {
        Self {
            slot,
            generation,
            _marker: PhantomData,
        }
    }
}
468
469impl<T> Context<T>
470where
471 T: Send + Sync + 'static,
472{
473 fn load_entry(&self) -> Arc<dyn Any + Send + Sync> {
474 with_slot_table(|table| {
475 let entry = table
476 .entries
477 .get(self.slot as usize)
478 .unwrap_or_else(|| panic!("Context points to freed slot: {}", self.slot));
479
480 if entry.generation != self.generation {
481 panic!(
482 "Context is stale (slot {}, generation {}, current generation {})",
483 self.slot, self.generation, entry.generation
484 );
485 }
486
487 if entry.key.type_id != TypeId::of::<T>() {
488 panic!(
489 "Context type mismatch for slot {}: expected {}, stored {:?}",
490 self.slot,
491 std::any::type_name::<T>(),
492 entry.key.type_id
493 );
494 }
495
496 entry
497 .value
498 .as_ref()
499 .unwrap_or_else(|| panic!("Context slot {} has been recycled", self.slot))
500 .clone()
501 })
502 }
503
504 fn load_lock(&self) -> Arc<RwLock<T>> {
505 self.load_entry()
506 .downcast::<RwLock<T>>()
507 .unwrap_or_else(|_| panic!("Context slot {} downcast failed", self.slot))
508 }
509
510 pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
512 track_context_read_dependency(self.slot, self.generation);
513 let lock = self.load_lock();
514 let guard = lock.read();
515 f(&guard)
516 }
517
518 pub fn with_mut<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
520 let lock = self.load_lock();
521 let result = {
522 let mut guard = lock.write();
523 f(&mut guard)
524 };
525 let subscribers = context_read_subscribers(self.slot, self.generation);
526 for instance_key in subscribers {
527 record_component_invalidation_for_instance_key(instance_key);
528 }
529 result
530 }
531
532 pub fn get(&self) -> T
534 where
535 T: Clone,
536 {
537 self.with(Clone::clone)
538 }
539
540 pub fn set(&self, value: T) {
542 self.with_mut(|v| *v = value);
543 }
544
545 pub fn replace(&self, value: T) -> T {
547 self.with_mut(|v| std::mem::replace(v, value))
548 }
549}
550
551fn push_context_layer(type_id: TypeId, slot: u32, generation: u64, key: SlotKey) {
552 with_execution_context_mut(|context| {
553 let parent = context
554 .context_stack
555 .last()
556 .cloned()
557 .unwrap_or_else(ContextMap::new);
558 let next = parent.update(
559 type_id,
560 ContextSnapshotEntry {
561 slot,
562 generation,
563 key,
564 },
565 );
566 context.context_stack.push(next);
567 });
568}
569
570fn pop_context_layer() {
571 with_execution_context_mut(|context| {
572 let popped = context.context_stack.pop();
573 debug_assert!(popped.is_some(), "Context stack underflow");
574 if context.context_stack.is_empty() {
575 context.context_stack.push(ContextMap::new());
576 }
577 });
578}
579
580pub fn provide_context<T, I, F, R>(init: I, f: F) -> R
616where
617 T: Send + Sync + 'static,
618 I: FnOnce() -> T,
619 F: FnOnce() -> R,
620{
621 ensure_build_phase();
622
623 let (instance_logic_id, slot_hash) = compute_context_slot_key();
624 let type_id = TypeId::of::<T>();
625 let slot_key = SlotKey {
626 instance_logic_id,
627 slot_hash,
628 type_id,
629 };
630
631 let (slot, generation) = {
632 with_slot_table_mut(|table| {
633 let epoch = table.epoch;
634
635 if let Some(slot) = table.key_to_slot.get(&slot_key).copied() {
636 let entry = table
637 .entries
638 .get_mut(slot as usize)
639 .expect("context slot entry should exist");
640 entry.last_alive_epoch = epoch;
641
642 if entry.value.is_none() {
643 entry.value = Some(Arc::new(RwLock::new(init())));
644 entry.generation = entry.generation.wrapping_add(1);
645 }
646
647 let generation = entry.generation;
648 (slot, generation)
649 } else if let Some(slot) = table.free_list.pop() {
650 let entry = table
651 .entries
652 .get_mut(slot as usize)
653 .expect("context slot entry should exist");
654
655 entry.key = slot_key;
656 entry.value = Some(Arc::new(RwLock::new(init())));
657 entry.last_alive_epoch = epoch;
658
659 let generation = entry.generation;
660 table.key_to_slot.insert(slot_key, slot);
661 (slot, generation)
662 } else {
663 let generation = 0;
664 let slot = table.entries.len() as u32;
665 table.entries.push(SlotEntry {
666 key: slot_key,
667 generation,
668 value: Some(Arc::new(RwLock::new(init()))),
669 last_alive_epoch: epoch,
670 });
671 table.key_to_slot.insert(slot_key, slot);
672 (slot, generation)
673 }
674 })
675 };
676
677 push_context_layer(type_id, slot, generation, slot_key);
678
679 struct ContextScopeGuard;
680 impl Drop for ContextScopeGuard {
681 fn drop(&mut self) {
682 pop_context_layer();
683 }
684 }
685
686 let guard = ContextScopeGuard;
687 let result = f();
688 drop(guard);
689 result
690}
691
692pub fn use_context<T>() -> Option<Context<T>>
711where
712 T: Send + Sync + 'static,
713{
714 ensure_build_phase();
715
716 with_execution_context(|context| {
717 let map = context
718 .context_stack
719 .last()
720 .expect("Context stack must always contain at least one layer");
721 map.get(&TypeId::of::<T>())
722 .copied()
723 .map(|entry| Context::new(entry.slot, entry.generation))
724 })
725}
#[cfg(test)]
mod tests {
    use std::{any::TypeId, sync::Arc};

    use parking_lot::RwLock;

    use super::{
        ContextMap, ContextSnapshotEntry, SlotEntry, SlotKey, SlotTable, with_context_snapshot,
        with_slot_table_mut,
    };
    use crate::execution_context::{
        reset_execution_context, with_execution_context, with_execution_context_mut,
    };
    use crate::runtime::{RuntimePhase, push_phase};

    /// Builds a slot key whose value type is `T`.
    fn key_for<T: 'static>(instance_logic_id: u64, slot_hash: u64) -> SlotKey {
        SlotKey {
            instance_logic_id,
            slot_hash,
            type_id: TypeId::of::<T>(),
        }
    }

    /// Returns `base` extended with `T` mapped to the given slot, generation and key.
    fn layer_with<T: 'static>(
        base: &ContextMap,
        slot: u32,
        generation: u64,
        key: SlotKey,
    ) -> ContextMap {
        base.update(
            TypeId::of::<T>(),
            ContextSnapshotEntry {
                slot,
                generation,
                key,
            },
        )
    }

    /// Empties the slot table and execution context, leaving one empty context layer.
    fn reset_test_state() {
        with_slot_table_mut(|table| *table = SlotTable::default());
        reset_execution_context();
        with_execution_context_mut(|context| {
            context.context_stack = vec![ContextMap::new()];
        });
    }

    #[test]
    fn with_context_snapshot_restores_stack_after_panic() {
        reset_test_state();

        let base = ContextMap::new();
        let parent = layer_with::<u8>(&base, 1, 2, key_for::<u8>(7, 11));
        let expected_stack = vec![base, parent];
        with_execution_context_mut(|context| {
            context.context_stack = expected_stack.clone();
        });

        let snapshot = layer_with::<u32>(&ContextMap::new(), 3, 4, key_for::<u32>(17, 19));
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            with_context_snapshot(&snapshot, || {
                with_execution_context(|context| {
                    assert_eq!(context.context_stack.len(), 1);
                    assert!(context.context_stack.last().is_some());
                });
                panic!("expected panic in context snapshot test");
            });
        }));
        assert!(outcome.is_err());

        // The guard must have put the caller's two-layer stack back in place.
        with_execution_context(|context| {
            assert_eq!(context.context_stack, expected_stack);
        });
    }

    #[test]
    fn with_context_snapshot_remaps_stale_generation() {
        reset_test_state();

        let key = key_for::<u8>(31, 37);
        with_slot_table_mut(|table| {
            *table = SlotTable::default();
            table.entries.push(SlotEntry {
                key,
                generation: 2,
                value: Some(Arc::new(RwLock::new(42_u8))),
                last_alive_epoch: 1,
            });
            table.key_to_slot.insert(key, 0);
        });

        // The snapshot remembers the right slot but an outdated generation.
        let snapshot = layer_with::<u8>(&ContextMap::new(), 0, 1, key);

        let _phase_guard = push_phase(RuntimePhase::Build);
        with_context_snapshot(&snapshot, || {
            let context = super::use_context::<u8>().expect("context should be remapped");
            assert_eq!(context.get(), 42);
        });
    }

    #[test]
    fn with_context_snapshot_remaps_stale_slot_by_key() {
        reset_test_state();

        let key = key_for::<u16>(41, 43);
        let old_key = key_for::<u16>(47, 53);
        with_slot_table_mut(|table| {
            *table = SlotTable::default();
            // Slot 0: freed, previously owned by a different provider.
            table.entries.push(SlotEntry {
                key: old_key,
                generation: 9,
                value: None,
                last_alive_epoch: 0,
            });
            // Slot 1: where `key` lives now.
            table.entries.push(SlotEntry {
                key,
                generation: 4,
                value: Some(Arc::new(RwLock::new(7_u16))),
                last_alive_epoch: 2,
            });
            table.key_to_slot.insert(key, 1);
        });

        // The snapshot still points at slot 0, so resolution must fall back to the key index.
        let snapshot = layer_with::<u16>(&ContextMap::new(), 0, 3, key);

        let _phase_guard = push_phase(RuntimePhase::Build);
        with_context_snapshot(&snapshot, || {
            let context =
                super::use_context::<u16>().expect("context should be remapped by key");
            assert_eq!(context.get(), 7);
        });
    }
}