liquid_fixpoint/
graph.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::Types;
4
5fn dfs_finish_order<'a, T: Types>(
6    node: &'a T::KVar,
7    graph: &'a HashMap<T::KVar, Vec<T::KVar>>,
8    visited: &mut HashSet<T::KVar>,
9    order: &mut Vec<T::KVar>,
10) {
11    if visited.contains(node) {
12        return;
13    }
14
15    visited.insert(node.clone());
16
17    if let Some(neighbors) = graph.get(node) {
18        for neighbor in neighbors {
19            dfs_finish_order::<T>(neighbor, graph, visited, order);
20        }
21    }
22
23    order.push(node.clone());
24}
25
26fn reverse_graph<T: Types>(
27    graph: &HashMap<T::KVar, Vec<T::KVar>>,
28) -> HashMap<T::KVar, Vec<T::KVar>> {
29    let mut reversed = HashMap::new();
30
31    for (node, neighbors) in graph {
32        for neighbor in neighbors {
33            reversed
34                .entry(neighbor.clone())
35                .or_insert_with(Vec::new)
36                .push(node.clone());
37        }
38    }
39
40    for node in graph.keys() {
41        reversed.entry(node.clone()).or_insert_with(Vec::new);
42    }
43
44    reversed
45}
46
47fn dfs_collect_scc<'a, T: Types>(
48    node: &'a T::KVar,
49    graph: &'a HashMap<T::KVar, Vec<T::KVar>>,
50    visited: &mut HashSet<T::KVar>,
51    scc: &mut Vec<T::KVar>,
52) {
53    if visited.contains(node) {
54        return;
55    }
56
57    visited.insert(node.clone());
58    scc.push(node.clone());
59
60    if let Some(neighbors) = graph.get(node) {
61        for neighbor in neighbors {
62            dfs_collect_scc::<T>(neighbor, graph, visited, scc);
63        }
64    }
65}
66
67fn find_sccs<T: Types>(graph: &HashMap<T::KVar, Vec<T::KVar>>) -> Vec<Vec<T::KVar>> {
68    let mut visited = HashSet::new();
69    let mut order = Vec::new();
70
71    // First pass: original graph
72    for node in graph.keys() {
73        if !visited.contains(node) {
74            dfs_finish_order::<T>(node, graph, &mut visited, &mut order);
75        }
76    }
77
78    let reversed = reverse_graph::<T>(graph);
79    visited.clear();
80    let mut sccs = Vec::new();
81
82    // Second pass: reversed graph in reverse finishing order
83    while let Some(node) = order.pop() {
84        if !visited.contains(&node) {
85            let mut scc = Vec::new();
86            dfs_collect_scc::<T>(&node, &reversed, &mut visited, &mut scc);
87            sccs.push(scc);
88        }
89    }
90
91    sccs
92}
93
94pub fn topological_sort_sccs<T: Types>(
95    graph: &HashMap<T::KVar, Vec<T::KVar>>,
96) -> Vec<Vec<T::KVar>> {
97    let sccs = find_sccs::<T>(graph);
98
99    // Map each node to its SCC index
100    let mut node_to_scc = HashMap::new();
101    for (i, scc) in sccs.iter().enumerate() {
102        for node in scc {
103            node_to_scc.insert(node.clone(), i);
104        }
105    }
106
107    // Build condensed graph (DAG of SCCs)
108    let mut condensed_graph: HashMap<usize, HashSet<usize>> = HashMap::new();
109    for (node, neighbors) in graph {
110        let &from = node_to_scc.get(node).unwrap();
111        for neighbor in neighbors {
112            let &to = node_to_scc.get(neighbor).unwrap();
113            if from != to {
114                condensed_graph.entry(from).or_default().insert(to);
115            }
116        }
117    }
118
119    // Perform topological sort on SCC graph using DFS
120    fn dfs_topo(
121        node: usize,
122        graph: &HashMap<usize, HashSet<usize>>,
123        visited: &mut HashSet<usize>,
124        result: &mut Vec<usize>,
125    ) {
126        if visited.contains(&node) {
127            return;
128        }
129
130        visited.insert(node);
131
132        if let Some(neighbors) = graph.get(&node) {
133            for &neighbor in neighbors {
134                dfs_topo(neighbor, graph, visited, result);
135            }
136        }
137
138        result.push(node);
139    }
140
141    let mut visited = HashSet::new();
142    let mut result = Vec::new();
143
144    for i in 0..sccs.len() {
145        if !visited.contains(&i) {
146            dfs_topo(i, &condensed_graph, &mut visited, &mut result);
147        }
148    }
149
150    result.reverse(); // topological order
151    result.into_iter().map(|i| sccs[i].clone()).collect()
152}