// tessera_ui/renderer/reorder.rs

1use std::{
2    any::TypeId,
3    collections::{BinaryHeap, HashMap},
4};
5
6use petgraph::{
7    graph::{DiGraph, NodeIndex},
8    visit::IntoNodeIdentifiers,
9};
10
11use crate::{
12    px::{Px, PxPosition, PxRect, PxSize},
13    renderer::command::{BarrierRequirement, Command},
14};
15
/// Scheduling class of a render instruction.
///
/// The derived `Ord` follows declaration order and is used as the primary
/// priority key when instructions are sorted, so the variants must stay listed
/// from lowest to highest priority.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum InstructionCategory {
    /// Lowest priority: a draw that can be batched together with others.
    ContinuationDraw,
    /// Medium priority: a draw that requires a barrier.
    BarrierDraw,
    /// High priority: must run before the barrier draws that depend on it.
    Compute,
    /// Highest priority: a state-changing command acting as a reordering fence.
    StateChange,
}
29
30/// A wrapper for a command with additional information for sorting.
31pub(crate) struct InstructionInfo {
32    pub(crate) original_index: usize,
33    pub(crate) command: Command,
34    pub(crate) type_id: TypeId,
35    pub(crate) size: PxSize,
36    pub(crate) position: PxPosition,
37    pub(crate) category: InstructionCategory,
38    pub(crate) rect: PxRect,
39}
40
41impl InstructionInfo {
42    /// Creates a new `InstructionInfo` from a command and its context.
43    ///
44    /// It calculates the instruction category and the bounding rectangle.
45    pub(crate) fn new(
46        (command, type_id, size, position): (Command, TypeId, PxSize, PxPosition),
47        original_index: usize,
48    ) -> Self {
49        let (category, rect) = match &command {
50            Command::Compute(command) => {
51                // Compute commands should have proper scoping based on their barrier requirement
52                // instead of always using global scope
53                let barrier_req = command.barrier();
54                let rect = match barrier_req {
55                    BarrierRequirement::Global => PxRect {
56                        x: Px(0),
57                        y: Px(0),
58                        width: Px(i32::MAX),
59                        height: Px(i32::MAX),
60                    },
61                    BarrierRequirement::PaddedLocal {
62                        top,
63                        right,
64                        bottom,
65                        left,
66                    } => {
67                        let padded_x = (position.x - left).max(Px(0));
68                        let padded_y = (position.y - top).max(Px(0));
69                        let padded_width = size.width + left + right;
70                        let padded_height = size.height + top + bottom;
71                        PxRect {
72                            x: padded_x,
73                            y: padded_y,
74                            width: padded_width,
75                            height: padded_height,
76                        }
77                    }
78                    BarrierRequirement::Absolute(rect) => rect,
79                };
80                (InstructionCategory::Compute, rect)
81            }
82            Command::Draw(draw_command) => {
83                let barrier = draw_command.barrier();
84                let category = if barrier.is_some() {
85                    InstructionCategory::BarrierDraw
86                } else {
87                    InstructionCategory::ContinuationDraw
88                };
89
90                let rect = match barrier {
91                    Some(BarrierRequirement::Global) => PxRect {
92                        x: Px(0),
93                        y: Px(0),
94                        width: Px(i32::MAX),
95                        height: Px(i32::MAX),
96                    },
97                    Some(BarrierRequirement::PaddedLocal {
98                        top,
99                        right,
100                        bottom,
101                        left,
102                    }) => {
103                        let padded_x = (position.x - left).max(Px(0));
104                        let padded_y = (position.y - top).max(Px(0));
105                        let padded_width = size.width + left + right;
106                        let padded_height = size.height + top + bottom;
107                        PxRect {
108                            x: padded_x,
109                            y: padded_y,
110                            width: padded_width,
111                            height: padded_height,
112                        }
113                    }
114                    Some(BarrierRequirement::Absolute(rect)) => rect,
115                    None => PxRect {
116                        x: position.x,
117                        y: position.y,
118                        width: size.width,
119                        height: size.height,
120                    },
121                };
122                (category, rect)
123            }
124            Command::ClipPush(rect) => (InstructionCategory::StateChange, *rect),
125            Command::ClipPop => (
126                InstructionCategory::StateChange,
127                PxRect {
128                    x: position.x,
129                    y: position.y,
130                    width: Px::ZERO,
131                    height: Px::ZERO,
132                },
133            ),
134        };
135
136        Self {
137            original_index,
138            command,
139            type_id,
140            size,
141            position,
142            category,
143            rect,
144        }
145    }
146}
147
148/// A node in the priority queue for topological sorting.
149#[derive(Debug, Clone, Copy, PartialEq, Eq)]
150struct PriorityNode {
151    category: InstructionCategory,
152    type_id: TypeId,
153    original_index: usize,
154    node_index: NodeIndex,
155    batch_potential: usize,
156}
157
158impl Ord for PriorityNode {
159    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
160        // This is the core heuristic for optimal batching:
161        // 1. Higher category is always higher priority.
162        // 2. For the same category, nodes with smaller batch potential are prioritized.
163        //    This helps to get "lonely" nodes out of the way, clearing the path for
164        //    larger batches to be processed contiguously.
165        // 3. The original index is used as a final tie-breaker for stability.
166        self.category
167            .cmp(&other.category)
168            .then_with(|| self.batch_potential.cmp(&other.batch_potential).reverse())
169            .then_with(|| self.original_index.cmp(&other.original_index).reverse())
170    }
171}
172
173impl PartialOrd for PriorityNode {
174    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
175        Some(self.cmp(other))
176    }
177}
178
179pub(crate) fn reorder_instructions(
180    commands: impl IntoIterator<Item = (Command, TypeId, PxSize, PxPosition)>,
181) -> Vec<(Command, TypeId, PxSize, PxPosition)> {
182    let instructions: Vec<InstructionInfo> = commands
183        .into_iter()
184        .enumerate()
185        .map(|(i, cmd)| InstructionInfo::new(cmd, i))
186        .collect();
187
188    if instructions.is_empty() {
189        return vec![];
190    }
191
192    let mut potentials = HashMap::new();
193    for info in &instructions {
194        *potentials.entry((info.category, info.type_id)).or_insert(0) += 1;
195    }
196
197    let graph = build_dependency_graph(&instructions);
198
199    let sorted_node_indices = priority_topological_sort(&graph, &instructions, &potentials);
200
201    let mut sorted_instructions = Vec::with_capacity(instructions.len());
202    let mut original_infos: Vec<_> = instructions.into_iter().map(Some).collect();
203
204    for node_index in sorted_node_indices {
205        let original_index = node_index.index();
206        if let Some(info) = original_infos[original_index].take() {
207            sorted_instructions.push((info.command, info.type_id, info.size, info.position));
208        }
209    }
210
211    sorted_instructions
212}
213
214fn priority_topological_sort(
215    graph: &DiGraph<(), ()>,
216    instructions: &[InstructionInfo],
217    potentials: &HashMap<(InstructionCategory, TypeId), usize>,
218) -> Vec<NodeIndex> {
219    let mut in_degree = vec![0; graph.node_count()];
220    for edge in graph.raw_edges() {
221        in_degree[edge.target().index()] += 1;
222    }
223
224    let mut ready_queue = BinaryHeap::new();
225    for node_index in graph.node_identifiers() {
226        if in_degree[node_index.index()] == 0 {
227            let info = &instructions[node_index.index()];
228            ready_queue.push(PriorityNode {
229                category: info.category,
230                type_id: info.type_id,
231                original_index: info.original_index,
232                node_index,
233                batch_potential: potentials[&(info.category, info.type_id)],
234            });
235        }
236    }
237
238    let mut sorted_list = Vec::with_capacity(instructions.len());
239    while let Some(priority_node) = ready_queue.pop() {
240        let u = priority_node.node_index;
241        sorted_list.push(u);
242
243        for v in graph.neighbors(u) {
244            in_degree[v.index()] -= 1;
245            if in_degree[v.index()] == 0 {
246                let info = &instructions[v.index()];
247                ready_queue.push(PriorityNode {
248                    category: info.category,
249                    type_id: info.type_id,
250                    original_index: info.original_index,
251                    node_index: v,
252                    batch_potential: potentials[&(info.category, info.type_id)],
253                });
254            }
255        }
256    }
257
258    if sorted_list.len() != instructions.len() {
259        // This indicates a cycle in the graph, which should not happen
260        // in a well-formed UI command stream.
261        // Fallback to original order.
262        return (0..instructions.len()).map(NodeIndex::new).collect();
263    }
264
265    sorted_list
266}
267
268fn build_dependency_graph(instructions: &[InstructionInfo]) -> DiGraph<(), ()> {
269    let mut graph = DiGraph::new();
270    let node_indices: Vec<NodeIndex> = instructions.iter().map(|_| graph.add_node(())).collect();
271
272    for i in 0..instructions.len() {
273        for j in 0..instructions.len() {
274            if i == j {
275                continue;
276            }
277
278            let inst_i = &instructions[i];
279            let inst_j = &instructions[j];
280
281            // Rule 0: State changes act as fences.
282            // If one of two commands is a state change, their relative order must be preserved.
283            if inst_i.original_index < inst_j.original_index
284                && (inst_i.category == InstructionCategory::StateChange
285                    || inst_j.category == InstructionCategory::StateChange)
286            {
287                graph.add_edge(node_indices[i], node_indices[j], ());
288            }
289
290            // Rule 1: Explicit dependency (Compute -> BarrierDraw)
291            // If inst_j is a BarrierDraw and inst_i is a Compute that appeared
292            // earlier, then j depends on i.
293            if inst_i.category == InstructionCategory::Compute
294                && inst_j.category == InstructionCategory::BarrierDraw
295                && inst_i.original_index < inst_j.original_index
296            {
297                graph.add_edge(node_indices[i], node_indices[j], ());
298            }
299
300            // Rule 2: Implicit dependency (Overlapping Draws)
301            // If both are draw commands and their original order matters (j came after i)
302            // and their rectangles are not orthogonal (i.e., they might overlap),
303            // then j depends on i to maintain painter's algorithm.
304            if (inst_i.category == InstructionCategory::BarrierDraw
305                || inst_i.category == InstructionCategory::ContinuationDraw)
306                && (inst_j.category == InstructionCategory::BarrierDraw
307                    || inst_j.category == InstructionCategory::ContinuationDraw)
308                && inst_i.original_index < inst_j.original_index
309                && !inst_i.rect.is_orthogonal(&inst_j.rect)
310            {
311                graph.add_edge(node_indices[i], node_indices[j], ());
312            }
313
314            // Rule 3: Implicit dependency (Draw -> Compute)
315            // If inst_j is a Compute command and inst_i is a Draw command that
316            // appeared earlier, and their areas are not orthogonal, then j depends on i.
317            if (inst_i.category == InstructionCategory::BarrierDraw
318                || inst_i.category == InstructionCategory::ContinuationDraw)
319                && inst_j.category == InstructionCategory::Compute
320                && inst_i.original_index < inst_j.original_index
321                && !inst_i.rect.is_orthogonal(&inst_j.rect)
322            {
323                graph.add_edge(node_indices[i], node_indices[j], ());
324            }
325        }
326    }
327
328    graph
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        px::{Px, PxPosition, PxRect, PxSize},
        renderer::{
            BarrierRequirement, command::Command, compute::ComputeCommand, drawer::DrawCommand,
        },
    };
    use std::any::TypeId;
    use std::fmt::Debug;

    /// One entry of the command stream accepted by `reorder_instructions`.
    type Entry = (Command, TypeId, PxSize, PxPosition);

    // --- Mock commands ---
    // Two draw types and two compute types so tests can exercise batching
    // across distinct type ids.

    #[derive(Debug, PartialEq, Clone)]
    struct MockDrawCommand {
        barrier_req: Option<BarrierRequirement>,
    }

    impl DrawCommand for MockDrawCommand {
        fn barrier(&self) -> Option<BarrierRequirement> {
            self.barrier_req
        }
    }

    #[derive(Debug, PartialEq, Clone)]
    struct MockDrawCommand2 {
        barrier_req: Option<BarrierRequirement>,
    }

    impl DrawCommand for MockDrawCommand2 {
        fn barrier(&self) -> Option<BarrierRequirement> {
            self.barrier_req
        }
    }

    #[derive(Debug, PartialEq, Clone)]
    struct MockComputeCommand {
        barrier_req: BarrierRequirement,
    }

    impl ComputeCommand for MockComputeCommand {
        fn barrier(&self) -> BarrierRequirement {
            self.barrier_req
        }
    }

    #[derive(Debug, PartialEq, Clone)]
    struct MockComputeCommand2 {
        barrier_req: BarrierRequirement,
    }

    impl ComputeCommand for MockComputeCommand2 {
        fn barrier(&self) -> BarrierRequirement {
            self.barrier_req
        }
    }

    // --- Helpers ---

    /// Shorthand for a position at `(x, y)`.
    fn at(x: i32, y: i32) -> PxPosition {
        PxPosition::new(Px(x), Px(y))
    }

    /// Shorthand for an absolute barrier covering the given rectangle.
    fn absolute(x: i32, y: i32, width: i32, height: i32) -> Option<BarrierRequirement> {
        Some(BarrierRequirement::Absolute(PxRect::new(
            Px(x),
            Px(y),
            Px(width),
            Px(height),
        )))
    }

    /// Builds a 10x10 command of the first mock type. Compute commands default
    /// to a global barrier when none is given.
    fn create_cmd(
        pos: PxPosition,
        barrier_req: Option<BarrierRequirement>,
        is_compute: bool,
    ) -> (Command, TypeId, PxSize, PxPosition) {
        let size = PxSize::new(Px(10), Px(10));
        let (command, type_id) = if is_compute {
            let barrier_req = barrier_req.unwrap_or(BarrierRequirement::Global);
            (
                Command::Compute(Box::new(MockComputeCommand { barrier_req })),
                TypeId::of::<MockComputeCommand>(),
            )
        } else {
            (
                Command::Draw(Box::new(MockDrawCommand { barrier_req })),
                TypeId::of::<MockDrawCommand>(),
            )
        };
        (command, type_id, size, pos)
    }

    /// Same as `create_cmd`, but using the second mock type (distinct type id).
    fn create_cmd2(
        pos: PxPosition,
        barrier_req: Option<BarrierRequirement>,
        is_compute: bool,
    ) -> (Command, TypeId, PxSize, PxPosition) {
        let size = PxSize::new(Px(10), Px(10));
        let (command, type_id) = if is_compute {
            let barrier_req = barrier_req.unwrap_or(BarrierRequirement::Global);
            (
                Command::Compute(Box::new(MockComputeCommand2 { barrier_req })),
                TypeId::of::<MockComputeCommand2>(),
            )
        } else {
            (
                Command::Draw(Box::new(MockDrawCommand2 { barrier_req })),
                TypeId::of::<MockDrawCommand2>(),
            )
        };
        (command, type_id, size, pos)
    }

    /// Extracts the position of every command; positions identify commands in
    /// the assertions below.
    fn get_positions(commands: &[(Command, TypeId, PxSize, PxPosition)]) -> Vec<PxPosition> {
        commands.iter().map(|entry| entry.3).collect()
    }

    /// Runs the reorder pass and returns the resulting positions.
    fn reordered_positions(commands: Vec<Entry>) -> Vec<PxPosition> {
        get_positions(&reorder_instructions(commands))
    }

    /// Picks `all[i]` for each `i` in `order`.
    fn select(all: &[PxPosition], order: &[usize]) -> Vec<PxPosition> {
        order.iter().map(|&index| all[index]).collect()
    }

    // --- Test cases ---

    #[test]
    fn test_empty_instructions() {
        let commands: Vec<Entry> = Vec::new();
        assert!(reorder_instructions(commands).is_empty());
    }

    #[test]
    fn test_no_dependencies_preserves_order() {
        let commands = vec![
            create_cmd(at(0, 0), None, false),  // 0
            create_cmd(at(20, 0), None, false), // 1
        ];
        let original = get_positions(&commands);
        assert_eq!(reordered_positions(commands), original);
    }

    #[test]
    fn test_compute_before_barrier_preserves_order() {
        let commands = vec![
            create_cmd(at(0, 0), Some(BarrierRequirement::Global), true), // 0: Compute
            create_cmd(at(20, 20), Some(BarrierRequirement::Global), false), // 1: BarrierDraw
        ];
        let original = get_positions(&commands);
        assert_eq!(reordered_positions(commands), original);
    }

    #[test]
    fn test_opt() {
        // Case 1: no dependencies, batching only.
        // Potentials: T1 -> 2, T2 -> 1. T2 has the lower potential, so it is
        // prioritized. Expected order: [1, 0, 2].
        let commands = vec![
            create_cmd(at(0, 0), None, false),    // 0 (T1)
            create_cmd2(at(10, 10), None, false), // 1 (T2)
            create_cmd(at(20, 20), None, false),  // 2 (T1)
        ];
        let expected = vec![at(10, 10), at(0, 0), at(20, 20)];
        assert_eq!(reordered_positions(commands), expected);

        // Case 2: batching with dependencies.
        // Potentials: T1 -> 2, T2 -> 1.
        // Dependencies: 2 > 0, 2 > 1.
        // Initial ready queue: [0, 1]; node 1 has the lower potential (1 vs 2),
        // so it goes first. Expected order: [1, 0, 2].
        let commands = vec![
            create_cmd(at(0, 0), None, false),    // 0 (T1)
            create_cmd2(at(10, 10), None, false), // 1 (T2)
            create_cmd(at(5, 5), None, false),    // 2 (T1)
        ];
        let expected = vec![at(10, 10), at(0, 0), at(5, 5)];
        assert_eq!(expected, reordered_positions(commands));
    }

    #[test]
    fn test_overlapping_draw_preserves_order() {
        let commands = vec![
            create_cmd(at(0, 0), None, false), // 0
            create_cmd(at(5, 5), None, false), // 1 (overlaps with 0)
        ];
        let original = get_positions(&commands);
        assert_eq!(reordered_positions(commands), original);
    }

    #[test]
    fn test_draw_before_overlapping_compute_preserves_order() {
        let commands = vec![
            create_cmd(at(0, 0), Some(BarrierRequirement::Global), false), // 0: BarrierDraw
            create_cmd(at(20, 20), Some(BarrierRequirement::Global), true), // 1: Compute
        ];
        let original = get_positions(&commands);
        assert_eq!(reordered_positions(commands), original);
    }

    #[test]
    fn test_reorder_based_on_priority_with_no_overlap() {
        let commands = vec![
            create_cmd(at(0, 0), absolute(0, 0, 10, 10), false), // 0: BarrierDraw, rect A
            create_cmd(at(100, 100), absolute(100, 100, 10, 10), true), // 1: Compute, rect B
            create_cmd(at(200, 200), None, false),               // 2: ContinuationDraw
        ];
        let original = get_positions(&commands);

        // No dependencies as all rects are orthogonal.
        // Priority: Compute (1) > BarrierDraw (0) > ContinuationDraw (2).
        let expected = select(&original, &[1, 0, 2]);
        assert_eq!(reordered_positions(commands), expected);
    }

    #[test]
    fn test_complex_reordering_with_dependencies() {
        let commands = vec![
            // 0: Compute. Must run first.
            create_cmd(at(0, 0), Some(BarrierRequirement::Global), true),
            // 1: BarrierDraw. Depends on 0. Orthogonal to 4.
            create_cmd(at(50, 50), absolute(50, 50, 10, 10), false),
            // 2: ContinuationDraw. Overlaps with 3.
            create_cmd(at(200, 200), None, false),
            // 3: ContinuationDraw.
            create_cmd(at(205, 205), None, false),
            // 4: BarrierDraw. Depends on 0. Orthogonal to 1.
            create_cmd(at(80, 80), absolute(80, 80, 10, 10), false),
        ];
        let original = get_positions(&commands);

        // Dependencies:
        //   0 -> 1 (Compute -> Barrier)
        //   0 -> 4 (Compute -> Barrier)
        //   2 -> 3 (Overlapping Draw)
        // Potentials: Compute: 1, BarrierDraw: 2, ContinuationDraw: 2.
        // Ready queue starts with [0 (Compute), 2 (ContinuationDraw)].
        //   1. Pop 0 -> [0]; 1 and 4 become ready. Queue by priority: [1, 4, 2].
        //   2. Pop 1 -> [0, 1].
        //   3. Pop 4 -> [0, 1, 4].
        //   4. Pop 2 -> [0, 1, 4, 2]; 3 becomes ready.
        //   5. Pop 3 -> [0, 1, 4, 2, 3].
        let expected = select(&original, &[0, 1, 4, 2, 3]);
        assert_eq!(reordered_positions(commands), expected);
    }
}
666}