1use std::{
8 any::{Any, TypeId},
9 cell::RefCell,
10 marker::PhantomData,
11 sync::Arc,
12};
13
14use im::HashMap as ImHashMap;
15use parking_lot::RwLock;
16use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
17
18use crate::{
19 execution_context::{with_execution_context, with_execution_context_mut},
20 runtime::{
21 RuntimePhase, compute_context_slot_key, current_phase,
22 current_replay_boundary_instance_key_from_scope, ensure_build_phase,
23 record_replay_boundary_invalidation_for_instance_key,
24 },
25};
26
27#[derive(Clone, Copy, Debug, Eq, PartialEq)]
28pub(crate) struct ContextSnapshotEntry {
29 slot: u32,
30 generation: u64,
31 key: SlotKey,
32}
33
34pub(crate) type ContextMap = ImHashMap<TypeId, ContextSnapshotEntry>;
35
/// Identity of a context slot: the providing instance's logic id, a
/// per-provide hash (as produced by `compute_context_slot_key`), and the
/// provided value's type.
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)]
struct SlotKey {
    instance_logic_id: u64,
    slot_hash: u64,
    type_id: TypeId,
}

/// One row of the slot table.
struct SlotEntry {
    /// Key of the value currently (or most recently) stored in this row.
    key: SlotKey,
    /// Incremented when the row is freed or re-initialised, so that handles
    /// issued for an earlier occupant can be detected as stale.
    generation: u64,
    /// The stored `RwLock<T>` behind type erasure; `None` once freed.
    value: Option<Arc<dyn Any + Send + Sync>>,
    /// Epoch in which the row was last provided; reset to `0` when freed.
    last_alive_epoch: u64,
}

/// Generational storage for provided context values.
#[derive(Default)]
struct SlotTable {
    entries: Vec<SlotEntry>,
    /// Indices of freed rows that may be reused.
    free_list: Vec<u32>,
    /// Index of rows holding a live value, by key.
    key_to_slot: HashMap<SlotKey, u32>,
    /// Current recompose epoch (wraps on overflow).
    epoch: u64,
}

impl SlotTable {
    /// Moves to the next recompose epoch, wrapping around at `u64::MAX`.
    fn begin_epoch(&mut self) {
        let next_epoch = self.epoch.wrapping_add(1);
        self.epoch = next_epoch;
    }
}
63
64fn with_slot_table<R>(f: impl FnOnce(&SlotTable) -> R) -> R {
65 CONTEXT_GLOBALS.with(|globals| f(&globals.slot_table.borrow()))
66}
67
68fn with_slot_table_mut<R>(f: impl FnOnce(&mut SlotTable) -> R) -> R {
69 CONTEXT_GLOBALS.with(|globals| f(&mut globals.slot_table.borrow_mut()))
70}
71
72#[derive(Default)]
73struct ContextSnapshotTracker {
74 previous_by_instance_key: HashMap<u64, ContextMap>,
75 current_by_instance_key: HashMap<u64, ContextMap>,
76}
77
78#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
79struct ContextReadDependencyKey {
80 slot: u32,
81 generation: u64,
82}
83
84#[derive(Default)]
85struct ContextReadDependencyTracker {
86 readers_by_context: HashMap<ContextReadDependencyKey, HashSet<u64>>,
87 contexts_by_reader: HashMap<u64, HashSet<ContextReadDependencyKey>>,
88}
89
90fn with_context_snapshot_tracker<R>(f: impl FnOnce(&ContextSnapshotTracker) -> R) -> R {
91 CONTEXT_GLOBALS.with(|globals| f(&globals.snapshot_tracker.borrow()))
92}
93
94fn with_context_snapshot_tracker_mut<R>(f: impl FnOnce(&mut ContextSnapshotTracker) -> R) -> R {
95 CONTEXT_GLOBALS.with(|globals| f(&mut globals.snapshot_tracker.borrow_mut()))
96}
97
98fn with_context_read_dependency_tracker<R>(
99 f: impl FnOnce(&ContextReadDependencyTracker) -> R,
100) -> R {
101 CONTEXT_GLOBALS.with(|globals| f(&globals.read_dependency_tracker.borrow()))
102}
103
104fn with_context_read_dependency_tracker_mut<R>(
105 f: impl FnOnce(&mut ContextReadDependencyTracker) -> R,
106) -> R {
107 CONTEXT_GLOBALS.with(|globals| f(&mut globals.read_dependency_tracker.borrow_mut()))
108}
109
110struct ContextGlobals {
111 slot_table: RefCell<SlotTable>,
112 snapshot_tracker: RefCell<ContextSnapshotTracker>,
113 read_dependency_tracker: RefCell<ContextReadDependencyTracker>,
114}
115
116impl ContextGlobals {
117 fn new() -> Self {
118 Self {
119 slot_table: RefCell::new(SlotTable::default()),
120 snapshot_tracker: RefCell::new(ContextSnapshotTracker::default()),
121 read_dependency_tracker: RefCell::new(ContextReadDependencyTracker::default()),
122 }
123 }
124}
125
126thread_local! {
127 static CONTEXT_GLOBALS: ContextGlobals = ContextGlobals::new();
128}
129
130fn current_context_map() -> ContextMap {
131 with_execution_context(|context| {
132 context
133 .context_stack
134 .last()
135 .cloned()
136 .unwrap_or_else(ContextMap::new)
137 })
138}
139
140fn resolve_snapshot_entry(entry: ContextSnapshotEntry) -> Option<ContextSnapshotEntry> {
141 with_slot_table(|table| {
142 let live_entry = table.entries.get(entry.slot as usize);
143 if let Some(live_entry) = live_entry
144 && live_entry.value.is_some()
145 && live_entry.key == entry.key
146 {
147 return Some(ContextSnapshotEntry {
148 slot: entry.slot,
149 generation: live_entry.generation,
150 key: entry.key,
151 });
152 }
153
154 let slot = table.key_to_slot.get(&entry.key).copied()?;
155 let live_entry = table.entries.get(slot as usize)?;
156 live_entry.value.as_ref()?;
157 Some(ContextSnapshotEntry {
158 slot,
159 generation: live_entry.generation,
160 key: entry.key,
161 })
162 })
163}
164
165fn normalize_context_snapshot(snapshot: &ContextMap) -> ContextMap {
166 snapshot
167 .iter()
168 .fold(ContextMap::new(), |acc, (type_id, entry)| {
169 if let Some(resolved) = resolve_snapshot_entry(*entry) {
170 acc.update(*type_id, resolved)
171 } else {
172 acc
173 }
174 })
175}
176
177pub(crate) fn begin_frame_component_context_tracking() {
178 with_context_snapshot_tracker_mut(|tracker| tracker.current_by_instance_key.clear());
179}
180
181pub(crate) fn finalize_frame_component_context_tracking() {
182 with_context_snapshot_tracker_mut(|tracker| {
183 tracker.previous_by_instance_key = std::mem::take(&mut tracker.current_by_instance_key);
184 });
185}
186
187pub(crate) fn finalize_frame_component_context_tracking_partial() {
188 with_context_snapshot_tracker_mut(|tracker| {
189 let current = std::mem::take(&mut tracker.current_by_instance_key);
190 tracker.previous_by_instance_key.extend(current);
191 });
192}
193
194pub(crate) fn reset_component_context_tracking() {
195 with_context_snapshot_tracker_mut(|tracker| {
196 *tracker = ContextSnapshotTracker::default();
197 });
198}
199
200pub(crate) fn previous_component_context_snapshots() -> HashMap<u64, ContextMap> {
201 with_context_snapshot_tracker(|tracker| tracker.previous_by_instance_key.clone())
202}
203
204pub(crate) fn context_from_previous_snapshot_for_instance<T>(
205 instance_key: u64,
206) -> Option<Context<T>>
207where
208 T: Send + Sync + 'static,
209{
210 with_context_snapshot_tracker(|tracker| {
211 let map = tracker.previous_by_instance_key.get(&instance_key)?;
212 map.get(&TypeId::of::<T>())
213 .copied()
214 .and_then(resolve_snapshot_entry)
215 .map(|entry| Context::new(entry.slot, entry.generation))
216 })
217}
218
219pub(crate) fn remove_previous_component_context_snapshots(instance_keys: &HashSet<u64>) {
220 if instance_keys.is_empty() {
221 return;
222 }
223 with_context_snapshot_tracker_mut(|tracker| {
224 tracker
225 .previous_by_instance_key
226 .retain(|instance_key, _| !instance_keys.contains(instance_key));
227 tracker
228 .current_by_instance_key
229 .retain(|instance_key, _| !instance_keys.contains(instance_key));
230 });
231}
232
233pub(crate) fn remove_context_read_dependencies(instance_keys: &HashSet<u64>) {
234 if instance_keys.is_empty() {
235 return;
236 }
237 with_context_read_dependency_tracker_mut(|tracker| {
238 for instance_key in instance_keys {
239 let Some(context_keys) = tracker.contexts_by_reader.remove(instance_key) else {
240 continue;
241 };
242 for context_key in context_keys {
243 let mut remove_entry = false;
244 if let Some(readers) = tracker.readers_by_context.get_mut(&context_key) {
245 readers.remove(instance_key);
246 remove_entry = readers.is_empty();
247 }
248 if remove_entry {
249 tracker.readers_by_context.remove(&context_key);
250 }
251 }
252 }
253 });
254}
255
256pub(crate) fn reset_context_read_dependencies() {
257 with_context_read_dependency_tracker_mut(|tracker| {
258 *tracker = ContextReadDependencyTracker::default();
259 });
260}
261
262pub(crate) fn record_current_context_snapshot_for(instance_key: u64) {
263 with_context_snapshot_tracker_mut(|tracker| {
264 tracker
265 .current_by_instance_key
266 .insert(instance_key, current_context_map());
267 });
268}
269
270pub(crate) fn with_context_snapshot<R>(snapshot: &ContextMap, f: impl FnOnce() -> R) -> R {
271 struct ContextSnapshotGuard {
272 previous_stack: Option<Vec<ContextMap>>,
273 }
274
275 impl Drop for ContextSnapshotGuard {
276 fn drop(&mut self) {
277 if let Some(previous_stack) = self.previous_stack.take() {
278 with_execution_context_mut(|context| {
279 context.context_stack = previous_stack;
280 });
281 }
282 }
283 }
284
285 let normalized_snapshot = normalize_context_snapshot(snapshot);
286 let previous_stack = with_execution_context_mut(|context| {
287 std::mem::replace(&mut context.context_stack, vec![normalized_snapshot])
288 });
289 let _guard = ContextSnapshotGuard {
290 previous_stack: Some(previous_stack),
291 };
292
293 f()
294}
295
296fn track_context_read_dependency(slot: u32, generation: u64) {
297 if !matches!(current_phase(), Some(RuntimePhase::Build)) {
298 return;
299 }
300 let Some(reader_instance_key) = current_replay_boundary_instance_key_from_scope() else {
301 return;
302 };
303
304 let key = ContextReadDependencyKey { slot, generation };
305 with_context_read_dependency_tracker_mut(|tracker| {
306 tracker
307 .readers_by_context
308 .entry(key)
309 .or_default()
310 .insert(reader_instance_key);
311 tracker
312 .contexts_by_reader
313 .entry(reader_instance_key)
314 .or_default()
315 .insert(key);
316 });
317}
318
319fn context_read_subscribers(slot: u32, generation: u64) -> Vec<u64> {
320 let key = ContextReadDependencyKey { slot, generation };
321 with_context_read_dependency_tracker(|tracker| {
322 tracker
323 .readers_by_context
324 .get(&key)
325 .map(|readers| readers.iter().copied().collect())
326 .unwrap_or_default()
327 })
328}
329
330pub(crate) fn begin_recompose_context_slot_epoch() {
331 with_slot_table_mut(SlotTable::begin_epoch);
333 with_execution_context_mut(|context| {
334 context.context_stack.clear();
335 context.context_stack.push(ContextMap::new());
336 });
337}
338
339pub(crate) fn recycle_recomposed_context_slots_for_instance_logic_ids(
340 instance_logic_ids: &HashSet<u64>,
341) {
342 if instance_logic_ids.is_empty() {
343 return;
344 }
345
346 with_slot_table_mut(|table| {
348 let epoch = table.epoch;
349 let mut freed: Vec<(u32, SlotKey)> = Vec::new();
350 for (slot, entry) in table.entries.iter_mut().enumerate() {
351 if !instance_logic_ids.contains(&entry.key.instance_logic_id) {
352 continue;
353 }
354 if entry.value.is_none() {
355 continue;
356 }
357 if entry.last_alive_epoch == epoch {
358 continue;
359 }
360
361 freed.push((slot as u32, entry.key));
362 entry.value = None;
363 entry.generation = entry.generation.wrapping_add(1);
364 entry.last_alive_epoch = 0;
365 }
366
367 for (slot, key) in freed {
368 table.key_to_slot.remove(&key);
369 table.free_list.push(slot);
370 }
371 });
372}
373
374pub(crate) fn live_context_slot_instance_logic_ids() -> HashSet<u64> {
375 with_slot_table(|table| {
376 table
377 .entries
378 .iter()
379 .filter(|entry| entry.value.is_some())
380 .map(|entry| entry.key.instance_logic_id)
381 .collect()
382 })
383}
384
385pub(crate) fn drop_context_slots_for_instance_logic_ids(instance_logic_ids: &HashSet<u64>) {
386 if instance_logic_ids.is_empty() {
387 return;
388 }
389
390 with_slot_table_mut(|table| {
391 let mut freed: Vec<(u32, SlotKey)> = Vec::new();
392 for (slot, entry) in table.entries.iter_mut().enumerate() {
393 if entry.value.is_none() {
394 continue;
395 }
396 if !instance_logic_ids.contains(&entry.key.instance_logic_id) {
397 continue;
398 }
399 freed.push((slot as u32, entry.key));
400 entry.value = None;
401 entry.generation = entry.generation.wrapping_add(1);
402 entry.last_alive_epoch = 0;
403 }
404
405 for (slot, key) in freed {
406 table.key_to_slot.remove(&key);
407 table.free_list.push(slot);
408 }
409 });
410}
411
/// Copyable handle to a context value of type `T` provided by `provide_context`.
///
/// The handle names a slot in this thread's slot table together with the
/// generation it was issued for. Accessing the value panics if the slot has
/// since been freed or reused.
pub struct Context<T> {
    slot: u32,
    generation: u64,
    _marker: PhantomData<T>,
}

// `Copy`/`Clone` are written by hand so they do not require `T: Copy` / `T: Clone`;
// the handle never owns a `T`.
impl<T> Copy for Context<T> {}

impl<T> Clone for Context<T> {
    fn clone(&self) -> Self {
        *self
    }
}

// Likewise, `Debug` does not require `T: Debug`: only the handle's slot and
// generation are printed, never the value.
impl<T> std::fmt::Debug for Context<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Context")
            .field("slot", &self.slot)
            .field("generation", &self.generation)
            .finish()
    }
}

impl<T> Context<T> {
    /// Creates a handle for `slot` at `generation`.
    fn new(slot: u32, generation: u64) -> Self {
        Self {
            slot,
            generation,
            _marker: PhantomData,
        }
    }
}
469
470impl<T> Context<T>
471where
472 T: Send + Sync + 'static,
473{
474 fn load_entry(&self) -> Arc<dyn Any + Send + Sync> {
475 with_slot_table(|table| {
476 let entry = table
477 .entries
478 .get(self.slot as usize)
479 .unwrap_or_else(|| panic!("Context points to freed slot: {}", self.slot));
480
481 if entry.generation != self.generation {
482 panic!(
483 "Context is stale (slot {}, generation {}, current generation {})",
484 self.slot, self.generation, entry.generation
485 );
486 }
487
488 if entry.key.type_id != TypeId::of::<T>() {
489 panic!(
490 "Context type mismatch for slot {}: expected {}, stored {:?}",
491 self.slot,
492 std::any::type_name::<T>(),
493 entry.key.type_id
494 );
495 }
496
497 entry
498 .value
499 .as_ref()
500 .unwrap_or_else(|| panic!("Context slot {} has been recycled", self.slot))
501 .clone()
502 })
503 }
504
505 fn load_lock(&self) -> Arc<RwLock<T>> {
506 self.load_entry()
507 .downcast::<RwLock<T>>()
508 .unwrap_or_else(|_| panic!("Context slot {} downcast failed", self.slot))
509 }
510
511 pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
513 track_context_read_dependency(self.slot, self.generation);
514 let lock = self.load_lock();
515 let guard = lock.read();
516 f(&guard)
517 }
518
519 pub fn with_mut<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
521 let lock = self.load_lock();
522 let result = {
523 let mut guard = lock.write();
524 f(&mut guard)
525 };
526 let subscribers = context_read_subscribers(self.slot, self.generation);
527 for instance_key in subscribers {
528 record_replay_boundary_invalidation_for_instance_key(instance_key);
529 }
530 result
531 }
532
533 pub fn get(&self) -> T
535 where
536 T: Clone,
537 {
538 self.with(Clone::clone)
539 }
540
541 pub fn set(&self, value: T) {
543 self.with_mut(|v| *v = value);
544 }
545
546 pub fn replace(&self, value: T) -> T {
548 self.with_mut(|v| std::mem::replace(v, value))
549 }
550}
551
552fn push_context_layer(type_id: TypeId, slot: u32, generation: u64, key: SlotKey) {
553 with_execution_context_mut(|context| {
554 let parent = context
555 .context_stack
556 .last()
557 .cloned()
558 .unwrap_or_else(ContextMap::new);
559 let next = parent.update(
560 type_id,
561 ContextSnapshotEntry {
562 slot,
563 generation,
564 key,
565 },
566 );
567 context.context_stack.push(next);
568 });
569}
570
571fn pop_context_layer() {
572 with_execution_context_mut(|context| {
573 let popped = context.context_stack.pop();
574 debug_assert!(popped.is_some(), "Context stack underflow");
575 if context.context_stack.is_empty() {
576 context.context_stack.push(ContextMap::new());
577 }
578 });
579}
580
581pub fn provide_context<T, I, F, R>(init: I, f: F) -> R
617where
618 T: Send + Sync + 'static,
619 I: FnOnce() -> T,
620 F: FnOnce() -> R,
621{
622 ensure_build_phase();
623
624 let (instance_logic_id, slot_hash) = compute_context_slot_key();
625 let type_id = TypeId::of::<T>();
626 let slot_key = SlotKey {
627 instance_logic_id,
628 slot_hash,
629 type_id,
630 };
631
632 let (slot, generation) = {
633 with_slot_table_mut(|table| {
634 let epoch = table.epoch;
635
636 if let Some(slot) = table.key_to_slot.get(&slot_key).copied() {
637 let entry = table
638 .entries
639 .get_mut(slot as usize)
640 .expect("context slot entry should exist");
641 entry.last_alive_epoch = epoch;
642
643 if entry.value.is_none() {
644 entry.value = Some(Arc::new(RwLock::new(init())));
645 entry.generation = entry.generation.wrapping_add(1);
646 }
647
648 let generation = entry.generation;
649 (slot, generation)
650 } else if let Some(slot) = table.free_list.pop() {
651 let entry = table
652 .entries
653 .get_mut(slot as usize)
654 .expect("context slot entry should exist");
655
656 entry.key = slot_key;
657 entry.value = Some(Arc::new(RwLock::new(init())));
658 entry.last_alive_epoch = epoch;
659
660 let generation = entry.generation;
661 table.key_to_slot.insert(slot_key, slot);
662 (slot, generation)
663 } else {
664 let generation = 0;
665 let slot = table.entries.len() as u32;
666 table.entries.push(SlotEntry {
667 key: slot_key,
668 generation,
669 value: Some(Arc::new(RwLock::new(init()))),
670 last_alive_epoch: epoch,
671 });
672 table.key_to_slot.insert(slot_key, slot);
673 (slot, generation)
674 }
675 })
676 };
677
678 push_context_layer(type_id, slot, generation, slot_key);
679
680 struct ContextScopeGuard;
681 impl Drop for ContextScopeGuard {
682 fn drop(&mut self) {
683 pop_context_layer();
684 }
685 }
686
687 let guard = ContextScopeGuard;
688 let result = f();
689 drop(guard);
690 result
691}
692
693pub fn use_context<T>() -> Option<Context<T>>
712where
713 T: Send + Sync + 'static,
714{
715 ensure_build_phase();
716
717 with_execution_context(|context| {
718 let map = context
719 .context_stack
720 .last()
721 .expect("Context stack must always contain at least one layer");
722 map.get(&TypeId::of::<T>())
723 .copied()
724 .map(|entry| Context::new(entry.slot, entry.generation))
725 })
726}
// Unit tests for context-snapshot isolation and snapshot-entry remapping.
#[cfg(test)]
mod tests {
    use std::{any::TypeId, sync::Arc};

    use parking_lot::RwLock;

    use super::{
        ContextMap, ContextSnapshotEntry, SlotEntry, SlotKey, SlotTable, with_context_snapshot,
        with_slot_table_mut,
    };
    use crate::execution_context::{
        reset_execution_context, with_execution_context, with_execution_context_mut,
    };
    use crate::runtime::{RuntimePhase, push_phase};

    /// Clears this thread's slot table, resets the execution context, and leaves
    /// exactly one empty layer on the context stack.
    fn reset_test_state() {
        with_slot_table_mut(|table| *table = SlotTable::default());
        reset_execution_context();
        with_execution_context_mut(|context| {
            context.context_stack = vec![ContextMap::new()];
        });
    }

    /// `with_context_snapshot` must put the caller's full context stack back even
    /// when the closure panics, because the restore happens in a drop guard.
    #[test]
    fn with_context_snapshot_restores_stack_after_panic() {
        reset_test_state();

        // Two-layer stack to restore: an empty base plus a layer providing `u8`.
        let base_layer = ContextMap::new();
        let parent_layer = base_layer.update(
            TypeId::of::<u8>(),
            ContextSnapshotEntry {
                slot: 1,
                generation: 2,
                key: SlotKey {
                    instance_logic_id: 7,
                    slot_hash: 11,
                    type_id: TypeId::of::<u8>(),
                },
            },
        );
        let original_stack = vec![base_layer, parent_layer];
        with_execution_context_mut(|context| {
            context.context_stack = original_stack.clone();
        });

        // The slot table is empty, so this entry does not resolve. Only the shape
        // of the stack inside the snapshot is checked.
        let snapshot = ContextMap::new().update(
            TypeId::of::<u32>(),
            ContextSnapshotEntry {
                slot: 3,
                generation: 4,
                key: SlotKey {
                    instance_logic_id: 17,
                    slot_hash: 19,
                    type_id: TypeId::of::<u32>(),
                },
            },
        );
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            with_context_snapshot(&snapshot, || {
                // Inside the snapshot, the stack is replaced by a single layer.
                with_execution_context(|context| {
                    assert_eq!(context.context_stack.len(), 1);
                    assert!(context.context_stack.last().is_some());
                });
                panic!("expected panic in context snapshot test");
            });
        }));
        assert!(result.is_err());

        // After unwinding, the original two-layer stack must be back in place.
        with_execution_context(|context| {
            assert_eq!(context.context_stack, original_stack);
        });
    }

    /// The snapshot records slot 0 at generation 1, but the live entry is at
    /// generation 2 with the same key. Normalisation keeps slot 0 and adopts the
    /// live generation, so `get` succeeds instead of failing the stale check.
    #[test]
    fn with_context_snapshot_remaps_stale_generation() {
        reset_test_state();

        let key = SlotKey {
            instance_logic_id: 31,
            slot_hash: 37,
            type_id: TypeId::of::<u8>(),
        };
        {
            with_slot_table_mut(|table| {
                *table = SlotTable::default();
                table.entries.push(SlotEntry {
                    key,
                    generation: 2,
                    value: Some(Arc::new(RwLock::new(42_u8))),
                    last_alive_epoch: 1,
                });
                table.key_to_slot.insert(key, 0);
            });
        }

        let snapshot = ContextMap::new().update(
            TypeId::of::<u8>(),
            ContextSnapshotEntry {
                slot: 0,
                generation: 1,
                key,
            },
        );

        // `use_context` calls `ensure_build_phase`, so run inside the build phase.
        let _phase_guard = push_phase(RuntimePhase::Build);
        with_context_snapshot(&snapshot, || {
            let context = super::use_context::<u8>().expect("context should be remapped");
            assert_eq!(context.get(), 42);
        });
    }

    /// The snapshot points at slot 0, which holds a freed entry for a different
    /// key. The snapshot's key is live at slot 1, so normalisation remaps the
    /// entry through `key_to_slot`.
    #[test]
    fn with_context_snapshot_remaps_stale_slot_by_key() {
        reset_test_state();

        let key = SlotKey {
            instance_logic_id: 41,
            slot_hash: 43,
            type_id: TypeId::of::<u16>(),
        };
        let old_key = SlotKey {
            instance_logic_id: 47,
            slot_hash: 53,
            type_id: TypeId::of::<u16>(),
        };
        {
            with_slot_table_mut(|table| {
                *table = SlotTable::default();
                // Slot 0: freed entry (no value) for an unrelated key.
                table.entries.push(SlotEntry {
                    key: old_key,
                    generation: 9,
                    value: None,
                    last_alive_epoch: 0,
                });
                // Slot 1: the live entry the snapshot should resolve to.
                table.entries.push(SlotEntry {
                    key,
                    generation: 4,
                    value: Some(Arc::new(RwLock::new(7_u16))),
                    last_alive_epoch: 2,
                });
                table.key_to_slot.insert(key, 1);
            });
        }

        let snapshot = ContextMap::new().update(
            TypeId::of::<u16>(),
            ContextSnapshotEntry {
                slot: 0,
                generation: 3,
                key,
            },
        );

        // `use_context` calls `ensure_build_phase`, so run inside the build phase.
        let _phase_guard = push_phase(RuntimePhase::Build);
        with_context_snapshot(&snapshot, || {
            let context = super::use_context::<u16>().expect("context should be remapped by key");
            assert_eq!(context.get(), 7);
        });
    }
}