use std::collections::{HashMap, HashSet};

use crate::Types;

5fn dfs_finish_order<'a, T: Types>(
6 node: &'a T::KVar,
7 graph: &'a HashMap<T::KVar, Vec<T::KVar>>,
8 visited: &mut HashSet<T::KVar>,
9 order: &mut Vec<T::KVar>,
10) {
11 if visited.contains(node) {
12 return;
13 }
14
15 visited.insert(node.clone());
16
17 if let Some(neighbors) = graph.get(node) {
18 for neighbor in neighbors {
19 dfs_finish_order::<T>(neighbor, graph, visited, order);
20 }
21 }
22
23 order.push(node.clone());
24}
26fn reverse_graph<T: Types>(
27 graph: &HashMap<T::KVar, Vec<T::KVar>>,
28) -> HashMap<T::KVar, Vec<T::KVar>> {
29 let mut reversed = HashMap::new();
30
31 for (node, neighbors) in graph {
32 for neighbor in neighbors {
33 reversed
34 .entry(neighbor.clone())
35 .or_insert_with(Vec::new)
36 .push(node.clone());
37 }
38 }
39
40 for node in graph.keys() {
41 reversed.entry(node.clone()).or_insert_with(Vec::new);
42 }
43
44 reversed
45}
47fn dfs_collect_scc<'a, T: Types>(
48 node: &'a T::KVar,
49 graph: &'a HashMap<T::KVar, Vec<T::KVar>>,
50 visited: &mut HashSet<T::KVar>,
51 scc: &mut Vec<T::KVar>,
52) {
53 if visited.contains(node) {
54 return;
55 }
56
57 visited.insert(node.clone());
58 scc.push(node.clone());
59
60 if let Some(neighbors) = graph.get(node) {
61 for neighbor in neighbors {
62 dfs_collect_scc::<T>(neighbor, graph, visited, scc);
63 }
64 }
65}
67fn find_sccs<T: Types>(graph: &HashMap<T::KVar, Vec<T::KVar>>) -> Vec<Vec<T::KVar>> {
68 let mut visited = HashSet::new();
69 let mut order = Vec::new();
70
71 for node in graph.keys() {
73 if !visited.contains(node) {
74 dfs_finish_order::<T>(node, graph, &mut visited, &mut order);
75 }
76 }
77
78 let reversed = reverse_graph::<T>(graph);
79 visited.clear();
80 let mut sccs = Vec::new();
81
82 while let Some(node) = order.pop() {
84 if !visited.contains(&node) {
85 let mut scc = Vec::new();
86 dfs_collect_scc::<T>(&node, &reversed, &mut visited, &mut scc);
87 sccs.push(scc);
88 }
89 }
90
91 sccs
92}
93
94pub fn topological_sort_sccs<T: Types>(
95 graph: &HashMap<T::KVar, Vec<T::KVar>>,
96) -> Vec<Vec<T::KVar>> {
97 let sccs = find_sccs::<T>(graph);
98
99 let mut node_to_scc = HashMap::new();
101 for (i, scc) in sccs.iter().enumerate() {
102 for node in scc {
103 node_to_scc.insert(node.clone(), i);
104 }
105 }
106
107 let mut condensed_graph: HashMap<usize, HashSet<usize>> = HashMap::new();
109 for (node, neighbors) in graph {
110 let &from = node_to_scc.get(node).unwrap();
111 for neighbor in neighbors {
112 let &to = node_to_scc.get(neighbor).unwrap();
113 if from != to {
114 condensed_graph.entry(from).or_default().insert(to);
115 }
116 }
117 }
118
119 fn dfs_topo(
121 node: usize,
122 graph: &HashMap<usize, HashSet<usize>>,
123 visited: &mut HashSet<usize>,
124 result: &mut Vec<usize>,
125 ) {
126 if visited.contains(&node) {
127 return;
128 }
129
130 visited.insert(node);
131
132 if let Some(neighbors) = graph.get(&node) {
133 for &neighbor in neighbors {
134 dfs_topo(neighbor, graph, visited, result);
135 }
136 }
137
138 result.push(node);
139 }
140
141 let mut visited = HashSet::new();
142 let mut result = Vec::new();
143
144 for i in 0..sccs.len() {
145 if !visited.contains(&i) {
146 dfs_topo(i, &condensed_graph, &mut visited, &mut result);
147 }
148 }
149
150 result.reverse(); result.into_iter().map(|i| sccs[i].clone()).collect()
152}