flux_middle/rty/normalize.rs

1use std::ops::ControlFlow;
2
3use itertools::Itertools;
4use rustc_data_structures::{fx::FxIndexSet, unord::UnordMap};
5use rustc_hir::def_id::{CrateNum, DefIndex, LOCAL_CRATE};
6use rustc_macros::{TyDecodable, TyEncodable};
7use toposort_scc::IndexGraph;
8
9use super::{ESpan, fold::TypeSuperFoldable};
10use crate::{
11    def_id::{FluxDefId, FluxId, FluxLocalDefId},
12    fhir::SpecFuncKind,
13    global_env::GlobalEnv,
14    rty::{
15        Binder, Expr, ExprKind,
16        fold::{TypeFoldable, TypeFolder, TypeSuperVisitable, TypeVisitable, TypeVisitor},
17    },
18};
19
/// The normalized bodies of the flux-defs of a single crate, keyed by the
/// `DefIndex` part of their `FluxId`.
#[derive(TyEncodable, TyDecodable)]
pub struct NormalizedDefns {
    /// The crate these defns belong to; `func_info` debug-asserts that every
    /// lookup targets this crate.
    krate: CrateNum,
    /// Normalization info for each defn of `krate`.
    defns: UnordMap<FluxId<DefIndex>, NormalizeInfo>,
}
25
26impl Default for NormalizedDefns {
27    fn default() -> Self {
28        Self { krate: LOCAL_CRATE, defns: UnordMap::default() }
29    }
30}
31
/// This type represents what we know about a flux-def *after*
/// normalization, i.e. after "inlining" all or some transitively
/// called flux-defs.
/// - When `FLUX_SMT_DEFINE_FUN=1` is set we inline
///   all *polymorphic* flux-defs, since they cannot
///   be represented as `define-fun` in SMTLIB, but leave
///   all *monomorphic* flux-defs un-inlined.
/// - When the above flag is not set, we replace *every* flux-def
///   with its (transitively) inlined body.
#[derive(Clone, TyEncodable, TyDecodable)]
pub struct NormalizeInfo {
    /// The actual definition, with the `Binder` representing the parameters.
    pub body: Binder<Expr>,
    /// Whether or not this function is inlined (i.e. NOT represented as `define-fun`).
    pub inline: bool,
    /// The rank of this defn in the topological sort of all the flux-defs, needed so
    /// we can specify the `define-fun` in the correct order, without any "forward"
    /// dependencies which the SMT solver cannot handle.
    pub rank: usize,
    /// Whether or not this function is uninterpreted by default. The `Normalizer`
    /// never inlines a hidden function, even when `inline` is set.
    pub hide: bool,
}
54
/// A `TypeFolder` that normalizes expressions. It inlines calls to flux-defs whose
/// `NormalizeInfo` allows it, applies lambdas to their arguments, and reduces
/// field projections.
pub(super) struct Normalizer<'a, 'genv, 'tcx> {
    /// Global environment used to look up normalized flux-defs.
    genv: GlobalEnv<'genv, 'tcx>,
    /// Optional table of already-normalized local flux-defs. For local ids it is
    /// consulted before `genv`.
    defs: Option<&'a UnordMap<FluxLocalDefId, NormalizeInfo>>,
}
59
60impl NormalizedDefns {
61    pub fn new(
62        genv: GlobalEnv,
63        defns: &[(FluxLocalDefId, Binder<Expr>, bool)],
64    ) -> Result<Self, Vec<FluxLocalDefId>> {
65        // 1. Topologically sort the Defns
66        let ds = toposort(defns)?;
67
68        // 2. Expand each defn in the sorted order
69        let mut normalized = UnordMap::default();
70        let mut ids = vec![];
71        for (rank, i) in ds.iter().enumerate() {
72            let (id, body, hide) = &defns[*i];
73            let body = body.fold_with(&mut Normalizer::new(genv, Some(&normalized)));
74
75            let inline = genv.should_inline_fun(id.to_def_id());
76            let info = NormalizeInfo { body: body.clone(), inline, rank, hide: *hide };
77            ids.push(*id);
78            normalized.insert(*id, info);
79        }
80        Ok(Self {
81            krate: LOCAL_CRATE,
82            defns: normalized
83                .into_items()
84                .map(|(id, body)| (id.local_def_index(), body))
85                .collect(),
86        })
87    }
88
89    pub fn func_info(&self, did: FluxDefId) -> NormalizeInfo {
90        debug_assert_eq!(self.krate, did.krate());
91        self.defns.get(&did.index()).unwrap().clone()
92    }
93}
94
95/// Returns
96/// * either Ok(d1...dn) which are topologically sorted such that
97///   forall i < j, di does not depend on i.e. "call" dj
98/// * or Err(d1...dn) where d1 'calls' d2 'calls' ... 'calls' dn 'calls' d1
99fn toposort<T>(
100    defns: &[(FluxLocalDefId, Binder<Expr>, T)],
101) -> Result<Vec<usize>, Vec<FluxLocalDefId>> {
102    // 1. Make a Symbol to Index map
103    let s2i: UnordMap<FluxLocalDefId, usize> = defns
104        .iter()
105        .enumerate()
106        .map(|(i, defn)| (defn.0, i))
107        .collect();
108
109    // 2. Make the dependency graph
110    let mut adj_list = Vec::with_capacity(defns.len());
111    for defn in defns {
112        let deps = local_deps(&defn.1);
113        let ddeps = deps
114            .iter()
115            .filter_map(|s| s2i.get(s).copied())
116            .collect_vec();
117        adj_list.push(ddeps);
118    }
119    let mut g = IndexGraph::from_adjacency_list(&adj_list);
120    g.transpose();
121
122    // 3. Topologically sort the graph
123    match g.toposort_or_scc() {
124        Ok(is) => Ok(is),
125        Err(mut scc) => {
126            let cycle = scc.pop().unwrap();
127            Err(cycle.iter().map(|i| defns[*i].0).collect())
128        }
129    }
130}
131
132pub fn local_deps(body: &Binder<Expr>) -> FxIndexSet<FluxLocalDefId> {
133    struct DepsVisitor(FxIndexSet<FluxLocalDefId>);
134    impl TypeVisitor for DepsVisitor {
135        #[allow(clippy::disallowed_methods, reason = "refinement functions cannot be extern specs")]
136        fn visit_expr(&mut self, expr: &Expr) -> ControlFlow<!> {
137            if let ExprKind::App(func, _) = expr.kind()
138                && let ExprKind::GlobalFunc(SpecFuncKind::Def(did)) = func.kind()
139                && let Some(did) = did.as_local()
140            {
141                self.0.insert(did);
142            }
143            expr.super_visit_with(self)
144        }
145    }
146    let mut visitor = DepsVisitor(Default::default());
147    body.visit_with(&mut visitor);
148    visitor.0
149}
150
151impl<'a, 'genv, 'tcx> Normalizer<'a, 'genv, 'tcx> {
152    pub(super) fn new(
153        genv: GlobalEnv<'genv, 'tcx>,
154        defs: Option<&'a UnordMap<FluxLocalDefId, NormalizeInfo>>,
155    ) -> Self {
156        Self { genv, defs }
157    }
158
159    #[allow(clippy::disallowed_methods, reason = "refinement functions cannot be extern specs")]
160    fn func_defn(&self, did: FluxDefId) -> Binder<Expr> {
161        if let Some(defs) = self.defs
162            && let Some(local_id) = did.as_local()
163        {
164            defs.get(&local_id).unwrap().body.clone()
165        } else {
166            self.genv.normalized_info(did).body
167        }
168    }
169
170    #[allow(clippy::disallowed_methods, reason = "refinement functions cannot be extern specs")]
171    fn inline(&self, did: &FluxDefId) -> bool {
172        let info = if let Some(defs) = self.defs
173            && let Some(local_id) = did.as_local()
174            && let Some(info) = defs.get(&local_id)
175        {
176            info
177        } else {
178            &self.genv.normalized_info(*did)
179        };
180        info.inline && !info.hide
181    }
182
183    fn at_base(expr: Expr, espan: Option<ESpan>) -> Expr {
184        match espan {
185            Some(espan) => BaseSpanner::new(espan).fold_expr(&expr),
186            None => expr,
187        }
188    }
189
190    fn app(&mut self, func: &Expr, args: &[Expr], espan: Option<ESpan>) -> Expr {
191        match func.kind() {
192            ExprKind::GlobalFunc(SpecFuncKind::Def(did)) if self.inline(did) => {
193                let res = self.func_defn(*did).replace_bound_refts(args);
194                Self::at_base(res, espan)
195            }
196            ExprKind::Abs(lam) => {
197                let res = lam.apply(args);
198                Self::at_base(res, espan)
199            }
200            _ => Expr::app(func.clone(), args.into()).at_opt(espan),
201        }
202    }
203}
204
205impl TypeFolder for Normalizer<'_, '_, '_> {
206    fn fold_expr(&mut self, expr: &Expr) -> Expr {
207        let expr = expr.super_fold_with(self);
208        let span = expr.span();
209        match expr.kind() {
210            ExprKind::App(func, args) => self.app(func, args, span),
211            ExprKind::FieldProj(e, proj) => e.proj_and_reduce(*proj),
212            _ => expr,
213        }
214    }
215}
216
/// A `TypeFolder` that applies `Expr::at_base(espan)` to every node of an expression.
struct BaseSpanner {
    /// The span passed to `at_base` for each node.
    espan: ESpan,
}
220
221impl BaseSpanner {
222    fn new(espan: ESpan) -> Self {
223        Self { espan }
224    }
225}
226
227impl TypeFolder for BaseSpanner {
228    fn fold_expr(&mut self, expr: &Expr) -> Expr {
229        expr.super_fold_with(self).at_base(self.espan)
230    }
231}