1use std::{
14 collections::HashSet,
15 sync::{
16 Arc,
17 atomic::{AtomicU64, Ordering},
18 },
19};
20
21use dashmap::DashMap;
22
23use crate::{
24 ShardState, ShardStateLifeCycle, ShardStateMap, init_or_get_shard_state_in_map,
25 recycle_shard_state_slot,
26};
27
/// Global counter handing out fresh [`RouteId`] values; the first id issued is 1.
static NEXT_ROUTE_ID: AtomicU64 = AtomicU64::new(1);

/// Opaque identifier given to each route when it is pushed, so two stack
/// entries holding the same destination type still get distinct ids.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct RouteId(u64);

impl RouteId {
    /// Draws the next unused id from [`NEXT_ROUTE_ID`].
    fn new() -> Self {
        // Relaxed is enough: callers only rely on uniqueness, not on ordering
        // with respect to any other memory operation.
        let raw = NEXT_ROUTE_ID.fetch_add(1, Ordering::Relaxed);
        RouteId(raw)
    }
}
39
40#[derive(Clone, Debug, Eq, Hash, PartialEq)]
41struct RouteShardKey {
42 route_id: RouteId,
43 shard_id: String,
44}
45
46struct RouteEntry {
47 route_id: RouteId,
48 destination: Arc<dyn RouterDestination>,
49}
50
51pub struct RouterController {
53 route_stack: Vec<RouteEntry>,
54 version: u64,
55 scope_shards: ShardStateMap<String>,
56 route_shards: ShardStateMap<RouteShardKey>,
57}
58
59impl RouterController {
60 pub fn new() -> Self {
62 Self {
63 route_stack: Vec::new(),
64 version: 0,
65 scope_shards: DashMap::new(),
66 route_shards: DashMap::new(),
67 }
68 }
69
70 pub fn with_root(root_dest: impl RouterDestination + 'static) -> Self {
72 let mut router = Self::new();
73 router.push(root_dest);
74 router
75 }
76
77 pub fn with_root_shared(root_dest: Arc<dyn RouterDestination>) -> Self {
79 let mut router = Self::new();
80 router.push_shared(root_dest);
81 router
82 }
83
84 pub fn version(&self) -> u64 {
86 self.version
87 }
88
89 pub fn push<T: RouterDestination + 'static>(&mut self, destination: T) {
91 self.push_shared(Arc::new(destination));
92 }
93
94 pub fn push_shared(&mut self, destination: Arc<dyn RouterDestination>) {
96 self.route_stack.push(RouteEntry {
97 route_id: RouteId::new(),
98 destination,
99 });
100 self.bump_version();
101 }
102
103 pub fn pop(&mut self) -> Option<Arc<dyn RouterDestination>> {
107 let removed = self.route_stack.pop()?;
108 self.prune_route_shards(removed.route_id);
109 self.bump_version();
110 Some(removed.destination)
111 }
112
113 pub fn replace<T: RouterDestination + 'static>(
117 &mut self,
118 destination: T,
119 ) -> Option<Arc<dyn RouterDestination>> {
120 let previous = self.pop();
121 self.push(destination);
122 previous
123 }
124
125 pub fn replace_shared(
127 &mut self,
128 destination: Arc<dyn RouterDestination>,
129 ) -> Option<Arc<dyn RouterDestination>> {
130 let previous = self.pop();
131 self.push_shared(destination);
132 previous
133 }
134
135 pub fn is_empty(&self) -> bool {
137 self.route_stack.is_empty()
138 }
139
140 pub fn len(&self) -> usize {
142 self.route_stack.len()
143 }
144
145 pub fn last(&self) -> Option<&dyn RouterDestination> {
147 self.route_stack.last().map(|entry| &*entry.destination)
148 }
149
150 pub(crate) fn current_route_id(&self) -> Option<RouteId> {
152 self.route_stack.last().map(|entry| entry.route_id)
153 }
154
155 pub fn exec_current(&self) -> bool {
159 let Some(entry) = self.route_stack.last() else {
160 return false;
161 };
162 entry.destination.exec_component();
163 true
164 }
165
166 pub fn init_or_get<T, F, R>(&self, id: &str, f: F) -> R
168 where
169 T: Default + Send + Sync + 'static,
170 F: FnOnce(ShardState<T>) -> R,
171 {
172 self.init_or_get_with_lifecycle(id, ShardStateLifeCycle::Shard, f)
173 }
174
175 pub fn init_or_get_with_lifecycle<T, F, R>(
177 &self,
178 id: &str,
179 life_cycle: ShardStateLifeCycle,
180 f: F,
181 ) -> R
182 where
183 T: Default + Send + Sync + 'static,
184 F: FnOnce(ShardState<T>) -> R,
185 {
186 match life_cycle {
187 ShardStateLifeCycle::Scope => {
188 init_or_get_shard_state_in_map(&self.scope_shards, id.to_owned(), id, "scope", f)
189 }
190 ShardStateLifeCycle::Shard => {
191 let route_id = self.current_route_id().unwrap_or_else(|| {
192 panic!("route-scoped shard state requires a non-empty router stack")
193 });
194 init_or_get_shard_state_in_map(
195 &self.route_shards,
196 RouteShardKey {
197 route_id,
198 shard_id: id.to_owned(),
199 },
200 id,
201 "route",
202 f,
203 )
204 }
205 }
206 }
207
208 pub fn clear(&mut self) {
210 if self.route_stack.is_empty() {
211 return;
212 }
213 let removed_route_ids: HashSet<_> = self
214 .route_stack
215 .drain(..)
216 .map(|entry| entry.route_id)
217 .collect();
218 let keys: Vec<_> = self
219 .route_shards
220 .iter()
221 .filter(|entry| removed_route_ids.contains(&entry.key().route_id))
222 .map(|entry| entry.key().clone())
223 .collect();
224 for key in keys {
225 if let Some((_, slot)) = self.route_shards.remove(&key) {
226 recycle_shard_state_slot(slot);
227 }
228 }
229 self.bump_version();
230 }
231
232 pub fn reset(&mut self, root_dest: impl RouterDestination + 'static) {
234 self.clear();
235 self.push(root_dest);
236 }
237
238 pub fn reset_shared(&mut self, root_dest: Arc<dyn RouterDestination>) {
240 self.clear();
241 self.push_shared(root_dest);
242 }
243
244 fn bump_version(&mut self) {
245 self.version = self.version.wrapping_add(1);
246 }
247
248 fn prune_route_shards(&self, route_id: RouteId) {
249 let keys: Vec<_> = self
250 .route_shards
251 .iter()
252 .filter(|entry| entry.key().route_id == route_id)
253 .map(|entry| entry.key().clone())
254 .collect();
255 for key in keys {
256 if let Some((_, slot)) = self.route_shards.remove(&key) {
257 recycle_shard_state_slot(slot);
258 }
259 }
260 }
261}
262
263impl Default for RouterController {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269impl Drop for RouterController {
270 fn drop(&mut self) {
271 let scope_slots: Vec<_> = self
272 .scope_shards
273 .iter()
274 .map(|entry| *entry.value())
275 .collect();
276 let route_slots: Vec<_> = self
277 .route_shards
278 .iter()
279 .map(|entry| *entry.value())
280 .collect();
281
282 self.scope_shards.clear();
283 self.route_shards.clear();
284
285 for slot in scope_slots.into_iter().chain(route_slots) {
286 recycle_shard_state_slot(slot);
287 }
288 }
289}
290
/// A destination that can be pushed onto a [`RouterController`] stack.
///
/// Destinations are shared as `Arc<dyn RouterDestination>`, so they must be
/// `Send + Sync`.
pub trait RouterDestination: Send + Sync {
    /// Returns this destination's static identifier.
    fn shard_id(&self) -> &'static str;

    /// Executes this destination's component; `RouterController::exec_current`
    /// calls this on the top-of-stack destination.
    fn exec_component(&self);
}
298
#[cfg(test)]
mod tests {
    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicU64, AtomicUsize, Ordering},
    };

    use crate::ShardStateLifeCycle;

    use super::{RouterController, RouterDestination};

    static TEST_SHARD_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

    /// Returns a shard id unique within this test binary, so tests cannot
    /// observe each other's state.
    fn unique_shard_id(prefix: &str) -> &'static str {
        let id = TEST_SHARD_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
        // Leaked so the id can be returned as `&'static str`.
        Box::leak(format!("{prefix}::{id}").into_boxed_str())
    }

    #[derive(Default)]
    struct CounterState {
        value: AtomicUsize,
    }

    struct DummyDestination;

    impl RouterDestination for DummyDestination {
        fn exec_component(&self) {}

        fn shard_id(&self) -> &'static str {
            "dummy"
        }
    }

    /// Increments the `CounterState` stored under `shard_id` and returns the new value.
    fn increment_state(
        router: &RouterController,
        shard_id: &str,
        life_cycle: ShardStateLifeCycle,
    ) -> usize {
        router.init_or_get_with_lifecycle::<CounterState, _, _>(shard_id, life_cycle, |state| {
            state.with(|value| value.value.fetch_add(1, Ordering::SeqCst) + 1)
        })
    }

    #[test]
    fn route_scoped_state_is_released_on_pop() {
        let shard_id = unique_shard_id("route_scoped");
        let mut router = RouterController::with_root(DummyDestination);

        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Shard),
            1
        );
        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Shard),
            2
        );

        assert!(router.pop().is_some());
        router.push(DummyDestination);
        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Shard),
            1
        );
    }

    #[test]
    fn route_scoped_state_is_released_on_clear() {
        let shard_id = unique_shard_id("route_clear");
        let mut router = RouterController::with_root(DummyDestination);
        router.push(DummyDestination);

        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Shard),
            1
        );
        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Shard),
            2
        );

        router.clear();
        assert!(router.is_empty());
        assert_eq!(router.len(), 0);

        router.push(DummyDestination);
        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Shard),
            1
        );
    }

    #[test]
    fn route_scoped_state_is_released_on_replace() {
        let shard_id = unique_shard_id("route_replace");
        let mut router = RouterController::with_root(DummyDestination);

        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Shard),
            1
        );

        assert!(router.replace(DummyDestination).is_some());
        assert_eq!(router.len(), 1);
        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Shard),
            1
        );
    }

    #[test]
    fn scope_scoped_state_persists_inside_scope_but_resets_across_scopes() {
        let shard_id = unique_shard_id("scope_scoped");
        let mut router = RouterController::with_root(DummyDestination);

        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Scope),
            1
        );

        router.push(DummyDestination);
        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Scope),
            2
        );

        assert!(router.pop().is_some());
        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Scope),
            3
        );

        drop(router);

        let router = RouterController::with_root(DummyDestination);
        assert_eq!(
            increment_state(&router, shard_id, ShardStateLifeCycle::Scope),
            1
        );
    }

    #[test]
    fn route_scoped_state_requires_active_route() {
        let shard_id = unique_shard_id("route_context_required");
        let router = RouterController::new();
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = increment_state(&router, shard_id, ShardStateLifeCycle::Shard);
        }));
        assert!(result.is_err());
    }

    #[test]
    fn version_changes_only_when_the_stack_changes() {
        let mut router = RouterController::new();
        let initial = router.version();

        router.push(DummyDestination);
        let after_push = router.version();
        assert_ne!(initial, after_push);

        assert!(router.pop().is_some());
        let after_pop = router.version();
        assert_ne!(after_push, after_pop);

        // Popping or clearing an empty stack is a no-op.
        assert!(router.pop().is_none());
        router.clear();
        assert_eq!(router.version(), after_pop);
    }

    #[test]
    fn exec_current_reports_whether_a_route_ran() {
        let mut router = RouterController::new();
        assert!(!router.exec_current());
        assert!(router.last().is_none());

        router.push(DummyDestination);
        assert!(router.exec_current());
        assert_eq!(router.last().map(|dest| dest.shard_id()), Some("dummy"));
    }
}