1use std::ops::ControlFlow;
2
3use itertools::Itertools;
4use rustc_data_structures::{fx::FxIndexSet, unord::UnordMap};
5use rustc_hir::def_id::{CrateNum, DefIndex, LOCAL_CRATE};
6use rustc_macros::{TyDecodable, TyEncodable};
7use toposort_scc::IndexGraph;
8
9use super::{ESpan, fold::TypeSuperFoldable};
10use crate::{
11 def_id::{FluxDefId, FluxId, FluxLocalDefId},
12 fhir::SpecFuncKind,
13 global_env::GlobalEnv,
14 rty::{
15 Binder, Expr, ExprKind,
16 fold::{TypeFoldable, TypeFolder, TypeSuperVisitable, TypeVisitable, TypeVisitor},
17 },
18};
19
/// Table of normalized refinement-function definitions, keyed by each definition's `DefIndex`.
///
/// `NormalizedDefns::new` builds it for `LOCAL_CRATE`; the derived `TyEncodable`/`TyDecodable`
/// impls make it serializable.
#[derive(TyEncodable, TyDecodable)]
pub struct NormalizedDefns {
    /// Crate that owns the entries of `defns`; `func_info` debug-asserts that lookups target it.
    krate: CrateNum,
    /// Normalization info for each definition, keyed by its index within `krate`.
    defns: UnordMap<FluxId<DefIndex>, NormalizeInfo>,
}
25
26impl Default for NormalizedDefns {
27 fn default() -> Self {
28 Self { krate: LOCAL_CRATE, defns: UnordMap::default() }
29 }
30}
31
/// Normalization result for a single refinement-function definition.
#[derive(Clone, TyEncodable, TyDecodable)]
pub struct NormalizeInfo {
    /// The definition's body after normalization, bound over the function's parameters
    /// (call sites instantiate it with `replace_bound_refts(args)`).
    pub body: Binder<Expr>,
    /// Result of `genv.should_inline_fun` for this definition.
    pub inline: bool,
    /// Position of this definition in the dependency order computed by `toposort`; local
    /// callees that are part of the same batch receive a smaller rank.
    pub rank: usize,
    /// When `true`, `Normalizer` never inlines this definition, regardless of `inline`.
    pub hide: bool,
}
54
55pub(super) struct Normalizer<'a, 'genv, 'tcx> {
56 genv: GlobalEnv<'genv, 'tcx>,
57 defs: Option<&'a UnordMap<FluxLocalDefId, NormalizeInfo>>,
58}
59
60impl NormalizedDefns {
61 pub fn new(
62 genv: GlobalEnv,
63 defns: &[(FluxLocalDefId, Binder<Expr>, bool)],
64 ) -> Result<Self, Vec<FluxLocalDefId>> {
65 let ds = toposort(defns)?;
67
68 let mut normalized = UnordMap::default();
70 let mut ids = vec![];
71 for (rank, i) in ds.iter().enumerate() {
72 let (id, body, hide) = &defns[*i];
73 let body = body.fold_with(&mut Normalizer::new(genv, Some(&normalized)));
74
75 let inline = genv.should_inline_fun(id.to_def_id());
76 let info = NormalizeInfo { body: body.clone(), inline, rank, hide: *hide };
77 ids.push(*id);
78 normalized.insert(*id, info);
79 }
80 Ok(Self {
81 krate: LOCAL_CRATE,
82 defns: normalized
83 .into_items()
84 .map(|(id, body)| (id.local_def_index(), body))
85 .collect(),
86 })
87 }
88
89 pub fn func_info(&self, did: FluxDefId) -> NormalizeInfo {
90 debug_assert_eq!(self.krate, did.krate());
91 self.defns.get(&did.index()).unwrap().clone()
92 }
93}
94
95fn toposort<T>(
100 defns: &[(FluxLocalDefId, Binder<Expr>, T)],
101) -> Result<Vec<usize>, Vec<FluxLocalDefId>> {
102 let s2i: UnordMap<FluxLocalDefId, usize> = defns
104 .iter()
105 .enumerate()
106 .map(|(i, defn)| (defn.0, i))
107 .collect();
108
109 let mut adj_list = Vec::with_capacity(defns.len());
111 for defn in defns {
112 let deps = local_deps(&defn.1);
113 let ddeps = deps
114 .iter()
115 .filter_map(|s| s2i.get(s).copied())
116 .collect_vec();
117 adj_list.push(ddeps);
118 }
119 let mut g = IndexGraph::from_adjacency_list(&adj_list);
120 g.transpose();
121
122 match g.toposort_or_scc() {
124 Ok(is) => Ok(is),
125 Err(mut scc) => {
126 let cycle = scc.pop().unwrap();
127 Err(cycle.iter().map(|i| defns[*i].0).collect())
128 }
129 }
130}
131
132pub fn local_deps(body: &Binder<Expr>) -> FxIndexSet<FluxLocalDefId> {
133 struct DepsVisitor(FxIndexSet<FluxLocalDefId>);
134 impl TypeVisitor for DepsVisitor {
135 #[allow(clippy::disallowed_methods, reason = "refinement functions cannot be extern specs")]
136 fn visit_expr(&mut self, expr: &Expr) -> ControlFlow<!> {
137 if let ExprKind::App(func, _) = expr.kind()
138 && let ExprKind::GlobalFunc(SpecFuncKind::Def(did)) = func.kind()
139 && let Some(did) = did.as_local()
140 {
141 self.0.insert(did);
142 }
143 expr.super_visit_with(self)
144 }
145 }
146 let mut visitor = DepsVisitor(Default::default());
147 body.visit_with(&mut visitor);
148 visitor.0
149}
150
151impl<'a, 'genv, 'tcx> Normalizer<'a, 'genv, 'tcx> {
152 pub(super) fn new(
153 genv: GlobalEnv<'genv, 'tcx>,
154 defs: Option<&'a UnordMap<FluxLocalDefId, NormalizeInfo>>,
155 ) -> Self {
156 Self { genv, defs }
157 }
158
159 #[allow(clippy::disallowed_methods, reason = "refinement functions cannot be extern specs")]
160 fn func_defn(&self, did: FluxDefId) -> Binder<Expr> {
161 if let Some(defs) = self.defs
162 && let Some(local_id) = did.as_local()
163 {
164 defs.get(&local_id).unwrap().body.clone()
165 } else {
166 self.genv.normalized_info(did).body
167 }
168 }
169
170 #[allow(clippy::disallowed_methods, reason = "refinement functions cannot be extern specs")]
171 fn inline(&self, did: &FluxDefId) -> bool {
172 let info = if let Some(defs) = self.defs
173 && let Some(local_id) = did.as_local()
174 && let Some(info) = defs.get(&local_id)
175 {
176 info
177 } else {
178 &self.genv.normalized_info(*did)
179 };
180 info.inline && !info.hide
181 }
182
183 fn at_base(expr: Expr, espan: Option<ESpan>) -> Expr {
184 match espan {
185 Some(espan) => BaseSpanner::new(espan).fold_expr(&expr),
186 None => expr,
187 }
188 }
189
190 fn app(&mut self, func: &Expr, args: &[Expr], espan: Option<ESpan>) -> Expr {
191 match func.kind() {
192 ExprKind::GlobalFunc(SpecFuncKind::Def(did)) if self.inline(did) => {
193 let res = self.func_defn(*did).replace_bound_refts(args);
194 Self::at_base(res, espan)
195 }
196 ExprKind::Abs(lam) => {
197 let res = lam.apply(args);
198 Self::at_base(res, espan)
199 }
200 _ => Expr::app(func.clone(), args.into()).at_opt(espan),
201 }
202 }
203}
204
205impl TypeFolder for Normalizer<'_, '_, '_> {
206 fn fold_expr(&mut self, expr: &Expr) -> Expr {
207 let expr = expr.super_fold_with(self);
208 let span = expr.span();
209 match expr.kind() {
210 ExprKind::App(func, args) => self.app(func, args, span),
211 ExprKind::FieldProj(e, proj) => e.proj_and_reduce(*proj),
212 _ => expr,
213 }
214 }
215}
216
217struct BaseSpanner {
218 espan: ESpan,
219}
220
221impl BaseSpanner {
222 fn new(espan: ESpan) -> Self {
223 Self { espan }
224 }
225}
226
227impl TypeFolder for BaseSpanner {
228 fn fold_expr(&mut self, expr: &Expr) -> Expr {
229 expr.super_fold_with(self).at_base(self.espan)
230 }
231}