1use std::{
2 any::TypeId,
3 collections::{BinaryHeap, HashMap},
4};
5
6use petgraph::{
7 graph::{DiGraph, NodeIndex},
8 visit::IntoNodeIdentifiers,
9};
10
11use crate::{
12 px::{Px, PxPosition, PxRect, PxSize},
13 renderer::command::{BarrierRequirement, Command},
14};
15
/// Coarse classification of a render command used by the reordering pass.
///
/// The derived `Ord` is load-bearing: the scheduler's max-heap prefers the *greater*
/// category, so a ready `StateChange` is emitted before a ready `Compute`, which is
/// emitted before a `BarrierDraw`, which is emitted before a `ContinuationDraw`.
/// The explicit discriminants spell out that ranking; keep them in ascending order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum InstructionCategory {
    /// A draw command that reports no barrier requirement.
    ContinuationDraw = 0,
    /// A draw command that reports a barrier requirement.
    BarrierDraw = 1,
    /// A compute command.
    Compute = 2,
    /// A clip push or clip pop.
    StateChange = 3,
}
29
30pub(crate) struct InstructionInfo {
32 pub(crate) original_index: usize,
33 pub(crate) command: Command,
34 pub(crate) type_id: TypeId,
35 pub(crate) size: PxSize,
36 pub(crate) position: PxPosition,
37 pub(crate) category: InstructionCategory,
38 pub(crate) rect: PxRect,
39}
40
41impl InstructionInfo {
42 pub(crate) fn new(
46 (command, type_id, size, position): (Command, TypeId, PxSize, PxPosition),
47 original_index: usize,
48 ) -> Self {
49 let (category, rect) = match &command {
50 Command::Compute(command) => {
51 let barrier_req = command.barrier();
54 let rect = match barrier_req {
55 BarrierRequirement::Global => PxRect {
56 x: Px(0),
57 y: Px(0),
58 width: Px(i32::MAX),
59 height: Px(i32::MAX),
60 },
61 BarrierRequirement::PaddedLocal {
62 top,
63 right,
64 bottom,
65 left,
66 } => {
67 let padded_x = (position.x - left).max(Px(0));
68 let padded_y = (position.y - top).max(Px(0));
69 let padded_width = size.width + left + right;
70 let padded_height = size.height + top + bottom;
71 PxRect {
72 x: padded_x,
73 y: padded_y,
74 width: padded_width,
75 height: padded_height,
76 }
77 }
78 BarrierRequirement::Absolute(rect) => rect,
79 };
80 (InstructionCategory::Compute, rect)
81 }
82 Command::Draw(draw_command) => {
83 let barrier = draw_command.barrier();
84 let category = if barrier.is_some() {
85 InstructionCategory::BarrierDraw
86 } else {
87 InstructionCategory::ContinuationDraw
88 };
89
90 let rect = match barrier {
91 Some(BarrierRequirement::Global) => PxRect {
92 x: Px(0),
93 y: Px(0),
94 width: Px(i32::MAX),
95 height: Px(i32::MAX),
96 },
97 Some(BarrierRequirement::PaddedLocal {
98 top,
99 right,
100 bottom,
101 left,
102 }) => {
103 let padded_x = (position.x - left).max(Px(0));
104 let padded_y = (position.y - top).max(Px(0));
105 let padded_width = size.width + left + right;
106 let padded_height = size.height + top + bottom;
107 PxRect {
108 x: padded_x,
109 y: padded_y,
110 width: padded_width,
111 height: padded_height,
112 }
113 }
114 Some(BarrierRequirement::Absolute(rect)) => rect,
115 None => PxRect {
116 x: position.x,
117 y: position.y,
118 width: size.width,
119 height: size.height,
120 },
121 };
122 (category, rect)
123 }
124 Command::ClipPush(rect) => (InstructionCategory::StateChange, *rect),
125 Command::ClipPop => (
126 InstructionCategory::StateChange,
127 PxRect {
128 x: position.x,
129 y: position.y,
130 width: Px::ZERO,
131 height: Px::ZERO,
132 },
133 ),
134 };
135
136 Self {
137 original_index,
138 command,
139 type_id,
140 size,
141 position,
142 category,
143 rect,
144 }
145 }
146}
147
148#[derive(Debug, Clone, Copy, PartialEq, Eq)]
150struct PriorityNode {
151 category: InstructionCategory,
152 type_id: TypeId,
153 original_index: usize,
154 node_index: NodeIndex,
155 batch_potential: usize,
156}
157
158impl Ord for PriorityNode {
159 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
160 self.category
167 .cmp(&other.category)
168 .then_with(|| self.batch_potential.cmp(&other.batch_potential).reverse())
169 .then_with(|| self.original_index.cmp(&other.original_index).reverse())
170 }
171}
172
173impl PartialOrd for PriorityNode {
174 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
175 Some(self.cmp(other))
176 }
177}
178
179pub(crate) fn reorder_instructions(
180 commands: impl IntoIterator<Item = (Command, TypeId, PxSize, PxPosition)>,
181) -> Vec<(Command, TypeId, PxSize, PxPosition)> {
182 let instructions: Vec<InstructionInfo> = commands
183 .into_iter()
184 .enumerate()
185 .map(|(i, cmd)| InstructionInfo::new(cmd, i))
186 .collect();
187
188 if instructions.is_empty() {
189 return vec![];
190 }
191
192 let mut potentials = HashMap::new();
193 for info in &instructions {
194 *potentials.entry((info.category, info.type_id)).or_insert(0) += 1;
195 }
196
197 let graph = build_dependency_graph(&instructions);
198
199 let sorted_node_indices = priority_topological_sort(&graph, &instructions, &potentials);
200
201 let mut sorted_instructions = Vec::with_capacity(instructions.len());
202 let mut original_infos: Vec<_> = instructions.into_iter().map(Some).collect();
203
204 for node_index in sorted_node_indices {
205 let original_index = node_index.index();
206 if let Some(info) = original_infos[original_index].take() {
207 sorted_instructions.push((info.command, info.type_id, info.size, info.position));
208 }
209 }
210
211 sorted_instructions
212}
213
214fn priority_topological_sort(
215 graph: &DiGraph<(), ()>,
216 instructions: &[InstructionInfo],
217 potentials: &HashMap<(InstructionCategory, TypeId), usize>,
218) -> Vec<NodeIndex> {
219 let mut in_degree = vec![0; graph.node_count()];
220 for edge in graph.raw_edges() {
221 in_degree[edge.target().index()] += 1;
222 }
223
224 let mut ready_queue = BinaryHeap::new();
225 for node_index in graph.node_identifiers() {
226 if in_degree[node_index.index()] == 0 {
227 let info = &instructions[node_index.index()];
228 ready_queue.push(PriorityNode {
229 category: info.category,
230 type_id: info.type_id,
231 original_index: info.original_index,
232 node_index,
233 batch_potential: potentials[&(info.category, info.type_id)],
234 });
235 }
236 }
237
238 let mut sorted_list = Vec::with_capacity(instructions.len());
239 while let Some(priority_node) = ready_queue.pop() {
240 let u = priority_node.node_index;
241 sorted_list.push(u);
242
243 for v in graph.neighbors(u) {
244 in_degree[v.index()] -= 1;
245 if in_degree[v.index()] == 0 {
246 let info = &instructions[v.index()];
247 ready_queue.push(PriorityNode {
248 category: info.category,
249 type_id: info.type_id,
250 original_index: info.original_index,
251 node_index: v,
252 batch_potential: potentials[&(info.category, info.type_id)],
253 });
254 }
255 }
256 }
257
258 if sorted_list.len() != instructions.len() {
259 return (0..instructions.len()).map(NodeIndex::new).collect();
263 }
264
265 sorted_list
266}
267
268fn build_dependency_graph(instructions: &[InstructionInfo]) -> DiGraph<(), ()> {
269 let mut graph = DiGraph::new();
270 let node_indices: Vec<NodeIndex> = instructions.iter().map(|_| graph.add_node(())).collect();
271
272 for i in 0..instructions.len() {
273 for j in 0..instructions.len() {
274 if i == j {
275 continue;
276 }
277
278 let inst_i = &instructions[i];
279 let inst_j = &instructions[j];
280
281 if inst_i.original_index < inst_j.original_index
284 && (inst_i.category == InstructionCategory::StateChange
285 || inst_j.category == InstructionCategory::StateChange)
286 {
287 graph.add_edge(node_indices[i], node_indices[j], ());
288 }
289
290 if inst_i.category == InstructionCategory::Compute
294 && inst_j.category == InstructionCategory::BarrierDraw
295 && inst_i.original_index < inst_j.original_index
296 {
297 graph.add_edge(node_indices[i], node_indices[j], ());
298 }
299
300 if (inst_i.category == InstructionCategory::BarrierDraw
305 || inst_i.category == InstructionCategory::ContinuationDraw)
306 && (inst_j.category == InstructionCategory::BarrierDraw
307 || inst_j.category == InstructionCategory::ContinuationDraw)
308 && inst_i.original_index < inst_j.original_index
309 && !inst_i.rect.is_orthogonal(&inst_j.rect)
310 {
311 graph.add_edge(node_indices[i], node_indices[j], ());
312 }
313
314 if (inst_i.category == InstructionCategory::BarrierDraw
318 || inst_i.category == InstructionCategory::ContinuationDraw)
319 && inst_j.category == InstructionCategory::Compute
320 && inst_i.original_index < inst_j.original_index
321 && !inst_i.rect.is_orthogonal(&inst_j.rect)
322 {
323 graph.add_edge(node_indices[i], node_indices[j], ());
324 }
325 }
326 }
327
328 graph
329}
330
#[cfg(test)]
mod tests {
    //! Tests for the reordering pass.
    //!
    //! Two mock draw types and two mock compute types are used so that commands can
    //! differ in `TypeId`, which drives `batch_potential`. Every mock command is
    //! 10x10 px; see `create_cmd` / `create_cmd2`.

    use super::*;
    use crate::{
        px::{Px, PxPosition, PxRect, PxSize},
        renderer::{
            BarrierRequirement, command::Command, compute::ComputeCommand, drawer::DrawCommand,
        },
    };
    use std::any::TypeId;
    use std::fmt::Debug;

    /// Draw command whose barrier requirement is configurable (first draw type).
    #[derive(Debug, PartialEq, Clone)]
    struct MockDrawCommand {
        barrier_req: Option<BarrierRequirement>,
    }

    impl DrawCommand for MockDrawCommand {
        fn barrier(&self) -> Option<BarrierRequirement> {
            self.barrier_req
        }
    }

    /// Second draw type, so tests can mix draws with different `TypeId`s.
    #[derive(Debug, PartialEq, Clone)]
    struct MockDrawCommand2 {
        barrier_req: Option<BarrierRequirement>,
    }

    impl DrawCommand for MockDrawCommand2 {
        fn barrier(&self) -> Option<BarrierRequirement> {
            self.barrier_req
        }
    }

    /// Compute command whose barrier requirement is configurable (first compute type).
    #[derive(Debug, PartialEq, Clone)]
    struct MockComputeCommand {
        barrier_req: BarrierRequirement,
    }

    impl ComputeCommand for MockComputeCommand {
        fn barrier(&self) -> BarrierRequirement {
            self.barrier_req
        }
    }

    /// Second compute type, so tests can mix computes with different `TypeId`s.
    #[derive(Debug, PartialEq, Clone)]
    struct MockComputeCommand2 {
        barrier_req: BarrierRequirement,
    }

    impl ComputeCommand for MockComputeCommand2 {
        fn barrier(&self) -> BarrierRequirement {
            self.barrier_req
        }
    }

    /// Builds a 10x10 command at `pos` using the first mock types.
    ///
    /// When `is_compute` is true, a missing `barrier_req` defaults to `Global` because
    /// compute commands always carry a barrier.
    fn create_cmd(
        pos: PxPosition,
        barrier_req: Option<BarrierRequirement>,
        is_compute: bool,
    ) -> (Command, TypeId, PxSize, PxPosition) {
        let size = PxSize::new(Px(10), Px(10));
        if is_compute {
            let cmd = MockComputeCommand {
                barrier_req: barrier_req.unwrap_or(BarrierRequirement::Global),
            };
            (
                Command::Compute(Box::new(cmd)),
                TypeId::of::<MockComputeCommand>(),
                size,
                pos,
            )
        } else {
            let cmd = MockDrawCommand { barrier_req };
            (
                Command::Draw(Box::new(cmd)),
                TypeId::of::<MockDrawCommand>(),
                size,
                pos,
            )
        }
    }

    /// Same as `create_cmd`, but using the second mock types (a different `TypeId`).
    fn create_cmd2(
        pos: PxPosition,
        barrier_req: Option<BarrierRequirement>,
        is_compute: bool,
    ) -> (Command, TypeId, PxSize, PxPosition) {
        let size = PxSize::new(Px(10), Px(10));
        if is_compute {
            let cmd = MockComputeCommand2 {
                barrier_req: barrier_req.unwrap_or(BarrierRequirement::Global),
            };
            (
                Command::Compute(Box::new(cmd)),
                TypeId::of::<MockComputeCommand2>(),
                size,
                pos,
            )
        } else {
            let cmd = MockDrawCommand2 { barrier_req };
            (
                Command::Draw(Box::new(cmd)),
                TypeId::of::<MockDrawCommand2>(),
                size,
                pos,
            )
        }
    }

    /// Positions are unique per test, so they identify commands in the output order.
    fn get_positions(commands: &[(Command, TypeId, PxSize, PxPosition)]) -> Vec<PxPosition> {
        commands.iter().map(|(_, _, _, pos)| *pos).collect()
    }

    #[test]
    fn test_empty_instructions() {
        let commands = vec![];
        let reordered = reorder_instructions(commands);
        assert!(reordered.is_empty());
    }

    /// Two disjoint draws of the same type: ties are broken by submission order.
    #[test]
    fn test_no_dependencies_preserves_order() {
        let commands = vec![
            create_cmd(PxPosition::new(Px(0), Px(0)), None, false),
            create_cmd(PxPosition::new(Px(20), Px(0)), None, false),
        ];
        let original_positions = get_positions(&commands);
        let reordered = reorder_instructions(commands);
        let reordered_positions = get_positions(&reordered);
        assert_eq!(reordered_positions, original_positions);
    }

    /// A compute followed by a barrier draw always keeps that order.
    #[test]
    fn test_compute_before_barrier_preserves_order() {
        let commands = vec![
            create_cmd(
                PxPosition::new(Px(0), Px(0)),
                Some(BarrierRequirement::Global),
                true,
            ),
            create_cmd(
                PxPosition::new(Px(20), Px(20)),
                Some(BarrierRequirement::Global),
                false,
            ),
        ];
        let original_positions = get_positions(&commands);
        let reordered = reorder_instructions(commands);
        let reordered_positions = get_positions(&reordered);
        assert_eq!(reordered_positions, original_positions);
    }

    /// Batching heuristic: among ready draws, the type with the smaller batch potential
    /// (`MockDrawCommand2`, 1 instance) is emitted first. This leaves the two
    /// `MockDrawCommand`s adjacent.
    ///
    /// NOTE(review): the first case relies on edge-adjacent rects, e.g. (0,0,10,10) and
    /// (10,10,10,10), being treated as non-overlapping by `PxRect::is_orthogonal`.
    #[test]
    fn test_opt() {
        let commands = vec![
            create_cmd(PxPosition::new(Px(0), Px(0)), None, false),
            create_cmd2(PxPosition::new(Px(10), Px(10)), None, false),
            create_cmd(PxPosition::new(Px(20), Px(20)), None, false),
        ];
        let reordered = reorder_instructions(commands);
        let reordered_positions = get_positions(&reordered);

        let expected_positions = vec![
            PxPosition::new(Px(10), Px(10)),
            PxPosition::new(Px(0), Px(0)),
            PxPosition::new(Px(20), Px(20)),
        ];
        assert_eq!(reordered_positions, expected_positions);

        // Here the draw at (5, 5) overlaps both earlier draws, so it must stay last.
        let commands = vec![
            create_cmd(PxPosition::new(Px(0), Px(0)), None, false),
            create_cmd2(PxPosition::new(Px(10), Px(10)), None, false),
            create_cmd(PxPosition::new(Px(5), Px(5)), None, false),
        ];
        let reordered = reorder_instructions(commands);
        let reordered_positions = get_positions(&reordered);

        let expected_positions = vec![
            PxPosition::new(Px(10), Px(10)),
            PxPosition::new(Px(0), Px(0)),
            PxPosition::new(Px(5), Px(5)),
        ];
        assert_eq!(expected_positions, reordered_positions);
    }

    /// Overlapping draws keep painter's order.
    #[test]
    fn test_overlapping_draw_preserves_order() {
        let commands = vec![
            create_cmd(PxPosition::new(Px(0), Px(0)), None, false),
            create_cmd(PxPosition::new(Px(5), Px(5)), None, false),
        ];
        let original_positions = get_positions(&commands);
        let reordered = reorder_instructions(commands);
        let reordered_positions = get_positions(&reordered);
        assert_eq!(reordered_positions, original_positions);
    }

    /// A draw that overlaps a later compute's region must stay before it.
    #[test]
    fn test_draw_before_overlapping_compute_preserves_order() {
        let commands = vec![
            create_cmd(
                PxPosition::new(Px(0), Px(0)),
                Some(BarrierRequirement::Global),
                false,
            ),
            create_cmd(
                PxPosition::new(Px(20), Px(20)),
                Some(BarrierRequirement::Global),
                true,
            ),
        ];
        let original_positions = get_positions(&commands);
        let reordered = reorder_instructions(commands);
        let reordered_positions = get_positions(&reordered);
        assert_eq!(reordered_positions, original_positions);
    }

    /// With no dependencies at all, commands come out in category priority order:
    /// Compute, then BarrierDraw, then ContinuationDraw.
    #[test]
    fn test_reorder_based_on_priority_with_no_overlap() {
        let commands = vec![
            create_cmd(
                PxPosition::new(Px(0), Px(0)),
                Some(BarrierRequirement::Absolute(PxRect::new(
                    Px(0),
                    Px(0),
                    Px(10),
                    Px(10),
                ))),
                false,
            ),
            create_cmd(
                PxPosition::new(Px(100), Px(100)),
                Some(BarrierRequirement::Absolute(PxRect::new(
                    Px(100),
                    Px(100),
                    Px(10),
                    Px(10),
                ))),
                true,
            ),
            create_cmd(PxPosition::new(Px(200), Px(200)), None, false),
        ];
        let original_positions = get_positions(&commands);
        let reordered = reorder_instructions(commands);
        let reordered_positions = get_positions(&reordered);

        let expected_positions = vec![
            original_positions[1],
            original_positions[0],
            original_positions[2],
        ];
        assert_eq!(reordered_positions, expected_positions);
    }

    /// The Global compute (0) goes first because both barrier draws (1, 4) depend on it.
    /// The barrier draws are then preferred over the continuation draws (2, 3).
    /// The two continuation draws overlap each other, so 2 stays before 3.
    #[test]
    fn test_complex_reordering_with_dependencies() {
        let commands = vec![
            create_cmd(
                PxPosition::new(Px(0), Px(0)),
                Some(BarrierRequirement::Global),
                true,
            ),
            create_cmd(
                PxPosition::new(Px(50), Px(50)),
                Some(BarrierRequirement::Absolute(PxRect::new(
                    Px(50),
                    Px(50),
                    Px(10),
                    Px(10),
                ))),
                false,
            ),
            create_cmd(PxPosition::new(Px(200), Px(200)), None, false),
            create_cmd(PxPosition::new(Px(205), Px(205)), None, false),
            create_cmd(
                PxPosition::new(Px(80), Px(80)),
                Some(BarrierRequirement::Absolute(PxRect::new(
                    Px(80),
                    Px(80),
                    Px(10),
                    Px(10),
                ))),
                false,
            ),
        ];
        let original_positions = get_positions(&commands);

        let reordered = reorder_instructions(commands);
        let reordered_positions = get_positions(&reordered);
        let expected_positions = vec![
            original_positions[0],
            original_positions[1],
            original_positions[4],
            original_positions[2],
            original_positions[3],
        ];
        assert_eq!(reordered_positions, expected_positions);
    }
}