// flux_middle/rty/canonicalize.rs

//! A canonical type is a type where all [existentials] and [constraint predicates] are *hoisted* to
//! the top level. For example, the canonical version of `∃a. {∃b. i32[a + b] | b > 0}` is
//! `∃a,b. {i32[a + b] | b > 0}`.
//!
//! Type constructors introduce scopes that can limit the hoisting. For instance, it is generally
//! not permitted to hoist an existential out of a generic argument: in `Vec<∃v. i32[v]>`
//! the existential inside the `Vec` cannot be hoisted out.
//!
//! However, some type constructors are more "lenient" with respect to hoisting. Consider the tuple
//! `(∃a. i32[a], ∃b. i32[b])`. Hoisting the existentials results in `∃a,b. (i32[a], i32[b])` which
//! is an equivalent type (in the sense that subtyping holds both ways). The same applies to shared
//! references: `&∃a. i32[a]` is equivalent to `∃a. &i32[a]`. We refer to this class of type
//! constructors as *transparent*. Hoisting existentials out of transparent type constructors is
//! useful as it allows the logical information to be extracted from the type.
//!
//! An important case is mutable references. In some situations, it is sound to hoist out of mutable
//! references. For example, if we have a variable in the environment of type `&mut ∃v. T[v]`, it is
//! sound to treat it as `&mut T[a]` for a freshly generated `a` (assuming the lifetime of the
//! reference is alive). However, this may result in a type that is *too specific* because the index
//! `a` cannot be updated anymore.
//!
//! By default, we do *shallow* hoisting, i.e., we stop at the first type constructor. This is enough
//! for cases where we need to inspect the structure of a type one level deep. The amount of
//! hoisting can be controlled by configuring the [`Hoister`] struct.
//!
//! It's also important to note that canonicalization doesn't imply any form of semantic equality;
//! it is just a best effort to facilitate syntactic manipulation. For example, the types
//! `∃a,b. (i32[a], i32[b])` and `∃a,b. (i32[b], i32[a])` are semantically equal but hoisting won't
//! account for it.
//!
//! [existentials]: TyKind::Exists
//! [constraint predicates]: TyKind::Constr
use std::fmt::Write;

use flux_arc_interner::List;
use flux_macros::{TypeFoldable, TypeVisitable};
use itertools::Itertools;
use rustc_ast::Mutability;
use rustc_type_ir::{BoundVar, INNERMOST};

use super::{
    BaseTy, Binder, BoundVariableKind, Expr, FnSig, GenericArg, GenericArgsExt, PolyFnSig,
    SubsetTy, Ty, TyCtor, TyKind, TyOrBase,
    fold::{TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable},
};
use crate::rty::{ExprKind, HoleKind};

/// The [`Hoister`] struct is responsible for hoisting existentials and predicates out of a type.
/// It can be configured to stop hoisting at specific type constructors.
///
/// The struct is generic on a delegate `D` because we use it both to do *local* hoisting, keeping
/// variables bound with a [`Binder`], and to *free* variables into the refinement context.
// Should we use a builder for this?
pub struct Hoister<D> {
    /// The delegate that hoisted existential binders and constraint predicates are handed to.
    delegate: D,
    /// Whether `TyKind::Exists` binders are hoisted at all.
    existentials: bool,
    /// Whether to descend into references (of any mutability) whose target is, possibly behind
    /// further references, an indexed slice.
    slices: bool,
    // Each of the following flags controls whether hoisting descends into the corresponding
    // type constructor.
    in_boxes: bool,
    in_downcast: bool,
    in_mut_refs: bool,
    in_shr_refs: bool,
    in_strg_refs: bool,
    in_tuples: bool,
}

66pub trait HoisterDelegate {
67    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty;
68    fn hoist_constr(&mut self, pred: Expr);
69}
70
71impl<D> Hoister<D> {
72    pub fn with_delegate(delegate: D) -> Self {
73        Hoister {
74            delegate,
75            in_tuples: false,
76            in_shr_refs: false,
77            in_mut_refs: false,
78            in_strg_refs: false,
79            in_boxes: false,
80            in_downcast: false,
81            existentials: true,
82            slices: false,
83        }
84    }
85
86    pub fn hoist_inside_shr_refs(mut self, shr_refs: bool) -> Self {
87        self.in_shr_refs = shr_refs;
88        self
89    }
90
91    pub fn hoist_inside_mut_refs(mut self, mut_refs: bool) -> Self {
92        self.in_mut_refs = mut_refs;
93        self
94    }
95
96    pub fn hoist_inside_strg_refs(mut self, strg_refs: bool) -> Self {
97        self.in_strg_refs = strg_refs;
98        self
99    }
100
101    pub fn hoist_inside_tuples(mut self, tuples: bool) -> Self {
102        self.in_tuples = tuples;
103        self
104    }
105
106    pub fn hoist_inside_boxes(mut self, boxes: bool) -> Self {
107        self.in_boxes = boxes;
108        self
109    }
110
111    pub fn hoist_inside_downcast(mut self, downcast: bool) -> Self {
112        self.in_downcast = downcast;
113        self
114    }
115
116    pub fn hoist_existentials(mut self, exists: bool) -> Self {
117        self.existentials = exists;
118        self
119    }
120
121    pub fn hoist_slices(mut self, slices: bool) -> Self {
122        self.slices = slices;
123        self
124    }
125
126    pub fn transparent(self) -> Self {
127        self.hoist_inside_boxes(true)
128            .hoist_inside_downcast(true)
129            .hoist_inside_mut_refs(false)
130            .hoist_inside_shr_refs(true)
131            .hoist_inside_strg_refs(true)
132            .hoist_inside_tuples(true)
133            .hoist_slices(true)
134    }
135
136    pub fn shallow(self) -> Self {
137        self.hoist_inside_boxes(false)
138            .hoist_inside_downcast(false)
139            .hoist_inside_mut_refs(false)
140            .hoist_inside_shr_refs(false)
141            .hoist_inside_strg_refs(false)
142            .hoist_inside_tuples(false)
143    }
144}
145
146impl<D: HoisterDelegate> Hoister<D> {
147    pub fn hoist(&mut self, ty: &Ty) -> Ty {
148        ty.fold_with(self)
149    }
150}
151
152/// Is `ty` of the form `&m (&m ... (&m T))` where `T` is an exi-indexed slice?
153/// We need to do a "transitive" check to deal with cases like `&mut &mut [i32]`
154/// which arise from closures like that in `tests/tests/pos/surface/closure03.rs`.
155fn is_indexed_slice(ty: &Ty) -> bool {
156    let Some(bty) = ty.as_bty_skipping_existentials() else {
157        return false;
158    };
159    match bty {
160        BaseTy::Slice(_) => true,
161        BaseTy::Ref(_, ty, _) => is_indexed_slice(ty),
162        _ => false,
163    }
164}
165
166impl<D: HoisterDelegate> TypeFolder for Hoister<D> {
167    fn fold_ty(&mut self, ty: &Ty) -> Ty {
168        match ty.kind() {
169            TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()),
170            TyKind::Exists(ty_ctor) if self.existentials => {
171                // Avoid hoisting useless parameters for unit sorts. This is important for
172                // canonicalization because we assume mutable references won't be under a
173                // binder after we canonicalize them.
174                // FIXME(nilehmann) this same logic is repeated in a couple of places, e.g.,
175                // TyCtor::to_ty
176                match &ty_ctor.vars()[..] {
177                    [BoundVariableKind::Refine(sort, ..)] => {
178                        if sort.is_unit() {
179                            ty_ctor.replace_bound_reft(&Expr::unit())
180                        } else if let Some(def_id) = sort.is_unit_adt() {
181                            ty_ctor.replace_bound_reft(&Expr::unit_struct(def_id))
182                        } else {
183                            self.delegate.hoist_exists(ty_ctor)
184                        }
185                    }
186                    _ => self.delegate.hoist_exists(ty_ctor),
187                }
188                .fold_with(self)
189            }
190            TyKind::Constr(pred, ty) => {
191                self.delegate.hoist_constr(pred.clone());
192                ty.fold_with(self)
193            }
194            TyKind::StrgRef(..) if self.in_strg_refs => ty.super_fold_with(self),
195            TyKind::Downcast(..) if self.in_downcast => ty.super_fold_with(self),
196            _ => ty.clone(),
197        }
198    }
199
200    fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy {
201        match bty {
202            BaseTy::Adt(adt_def, args) if adt_def.is_box() && self.in_boxes => {
203                let (boxed, alloc) = args.box_args();
204                let args = List::from_arr([
205                    GenericArg::Ty(boxed.fold_with(self)),
206                    GenericArg::Ty(alloc.clone()),
207                ]);
208                BaseTy::Adt(adt_def.clone(), args)
209            }
210            BaseTy::Ref(re, ty, mutability) if is_indexed_slice(ty) && self.slices => {
211                BaseTy::Ref(*re, ty.fold_with(self), *mutability)
212            }
213            BaseTy::Ref(re, ty, Mutability::Not) if self.in_shr_refs => {
214                BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not)
215            }
216            BaseTy::Ref(re, ty, Mutability::Mut) if self.in_mut_refs => {
217                BaseTy::Ref(*re, ty.fold_with(self), Mutability::Mut)
218            }
219            BaseTy::Tuple(tys) if self.in_tuples => BaseTy::Tuple(tys.fold_with(self)),
220            _ => bty.clone(),
221        }
222    }
223}
224
225#[derive(Default)]
226pub struct LocalHoister {
227    vars: Vec<BoundVariableKind>,
228    preds: Vec<Expr>,
229}
230
231impl LocalHoister {
232    pub fn new(vars: Vec<BoundVariableKind>) -> Self {
233        LocalHoister { vars, preds: vec![] }
234    }
235
236    pub fn bind<T>(self, f: impl FnOnce(List<BoundVariableKind>, Vec<Expr>) -> T) -> Binder<T> {
237        let vars = List::from_vec(self.vars);
238        Binder::bind_with_vars(f(vars.clone(), self.preds), vars)
239    }
240}
241
242impl HoisterDelegate for &mut LocalHoister {
243    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
244        ty_ctor.replace_bound_refts_with(|sort, mode, kind| {
245            let idx = self.vars.len();
246            self.vars
247                .push(BoundVariableKind::Refine(sort.clone(), mode, kind));
248            Expr::bvar(INNERMOST, BoundVar::from_usize(idx), kind)
249        })
250    }
251
252    fn hoist_constr(&mut self, pred: Expr) {
253        self.preds.push(pred);
254    }
255}
256
257impl PolyFnSig {
258    /// Convert a function signature with existentials to one where they are all
259    /// bound at the top level. Performs a transparent (i.e. not shallow)
260    /// canonicalization.
261    /// The uses the `LocalHoister` machinery to convert a function template _without_
262    /// binders, e.g. `fn ({v.i32 | *}) -> {v.i32|*})`
263    /// into one _with_ input binders, e.g. `forall <a:int>. fn ({i32[a]|*}) -> {v.i32|*}`
264    /// after which the hole-filling machinery can be used to fill in the holes.
265    /// This lets us get "dependent signatures" for closures, where the output
266    /// can refer to the input. e.g. see `tests/pos/surface/closure09.rs`
267    pub fn hoist_input_binders(&self) -> Self {
268        let original_vars = self.vars().to_vec();
269        let fn_sig = self.skip_binder_ref();
270        let mut delegate = LocalHoister { vars: original_vars, preds: fn_sig.requires().to_vec() };
271        let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
272
273        let inputs = fn_sig
274            .inputs()
275            .iter()
276            .map(|ty| hoister.hoist(ty))
277            .collect_vec();
278
279        delegate.bind(|_vars, mut preds| {
280            let mut keep_hole = true;
281            preds.retain(|pred| {
282                if let ExprKind::Hole(HoleKind::Pred) = pred.kind() {
283                    std::mem::replace(&mut keep_hole, false)
284                } else {
285                    true
286                }
287            });
288
289            FnSig::new(
290                fn_sig.safety,
291                fn_sig.abi,
292                preds.into(),
293                inputs.into(),
294                fn_sig.output().clone(),
295            )
296        })
297    }
298}
299
300impl Ty {
301    /// Hoist existentials and predicates inside the type stopping when encountering the first
302    /// type constructor.
303    pub fn shallow_canonicalize(&self) -> CanonicalTy {
304        let mut delegate = LocalHoister::default();
305        let ty = self.shift_in_escaping(1);
306        let ty = Hoister::with_delegate(&mut delegate).hoist(&ty);
307        let constr_ty = delegate.bind(|_, preds| {
308            let pred = Expr::and_from_iter(preds);
309            CanonicalConstrTy { ty, pred }
310        });
311        if constr_ty.vars().is_empty() {
312            CanonicalTy::Constr(constr_ty.skip_binder().shift_out_escaping(1))
313        } else {
314            CanonicalTy::Exists(constr_ty)
315        }
316    }
317}
318
319#[derive(TypeVisitable, TypeFoldable)]
320pub struct CanonicalConstrTy {
321    /// Guaranteed to not have any (shallow) [existential] or [constraint] types
322    ///
323    /// [existential]: TyKind::Exists
324    /// [constraint]: TyKind::Constr
325    ty: Ty,
326    pred: Expr,
327}
328
329impl CanonicalConstrTy {
330    pub fn ty(&self) -> Ty {
331        self.ty.clone()
332    }
333
334    pub fn pred(&self) -> Expr {
335        self.pred.clone()
336    }
337
338    pub fn to_ty(&self) -> Ty {
339        Ty::constr(self.pred(), self.ty())
340    }
341}
342
343/// A (shallowly) canonicalized type. This can be either of the form `{T | p}` or `∃v0,…,vn. {T | p}`,
344/// where `T` doesnt have any (shallow) [existential] or [constraint] types.
345///
346/// When canonicalizing a type without a [constraint] type, `p` will be [`Expr::tt()`].
347///
348/// [existential]: TyKind::Exists
349/// [constraint]: TyKind::Constr
350#[derive(TypeVisitable)]
351pub enum CanonicalTy {
352    /// A type of the form `{T | p}`
353    Constr(CanonicalConstrTy),
354    /// A type of the form `∃v0,…,vn. {T | p}`
355    Exists(Binder<CanonicalConstrTy>),
356}
357
358impl CanonicalTy {
359    pub fn to_ty(&self) -> Ty {
360        match self {
361            CanonicalTy::Constr(constr_ty) => constr_ty.to_ty(),
362            CanonicalTy::Exists(poly_constr_ty) => {
363                Ty::exists(poly_constr_ty.as_ref().map(CanonicalConstrTy::to_ty))
364            }
365        }
366    }
367
368    pub fn as_ty_or_base(&self) -> TyOrBase {
369        match self {
370            CanonicalTy::Constr(constr_ty) => {
371                if let TyKind::Indexed(bty, idx) = constr_ty.ty.kind() {
372                    // given {b[e] | p} return λv. {b[v] | p ∧ v == e}
373
374                    // HACK(nilehmann) avoid adding trivial `v == ()` equalities, if we don't do it,
375                    // some debug assertions fail. The assertions expect types to be unrefined so they
376                    // only check for syntactical equality. We should change those cases to handle
377                    // refined types and/or ensure some canonical representation for unrefined types.
378                    let pred = if idx.is_unit() {
379                        constr_ty.pred.clone()
380                    } else {
381                        Expr::and(&constr_ty.pred, Expr::eq(Expr::nu(), idx.shift_in_escaping(1)))
382                    };
383                    let sort = bty.sort();
384                    let constr = SubsetTy::new(bty.shift_in_escaping(1), Expr::nu(), pred);
385                    TyOrBase::Base(Binder::bind_with_sort(constr, sort))
386                } else {
387                    TyOrBase::Ty(self.to_ty())
388                }
389            }
390            CanonicalTy::Exists(poly_constr_ty) => {
391                let constr = poly_constr_ty.as_ref().skip_binder();
392                if let TyKind::Indexed(bty, idx) = constr.ty.kind()
393                    && idx.is_nu()
394                {
395                    let ctor = poly_constr_ty
396                        .as_ref()
397                        .map(|constr| SubsetTy::new(bty.clone(), idx, &constr.pred));
398                    TyOrBase::Base(ctor)
399                } else {
400                    TyOrBase::Ty(self.to_ty())
401                }
402            }
403        }
404    }
405}
406
407mod pretty {
408    use super::*;
409    use crate::pretty::*;
410
411    impl Pretty for CanonicalConstrTy {
412        fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413            if self.pred().is_trivially_true() {
414                w!(cx, f, "{:?}", &self.ty)
415            } else {
416                w!(cx, f, "{{ {:?} | {:?} }}", &self.ty, &self.pred)
417            }
418        }
419    }
420
421    impl Pretty for CanonicalTy {
422        fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423            match self {
424                CanonicalTy::Constr(constr) => w!(cx, f, "{:?}", constr),
425                CanonicalTy::Exists(poly_constr) => {
426                    let redundant_bvars = poly_constr.skip_binder_ref().redundant_bvars();
427                    cx.with_bound_vars_removable(
428                        poly_constr.vars(),
429                        redundant_bvars,
430                        None,
431                        |f_body| {
432                            let constr = poly_constr.skip_binder_ref();
433                            if constr.pred().is_trivially_true() {
434                                w!(cx, f_body, "{:?}", &constr.ty)
435                            } else {
436                                w!(cx, f_body, "{:?} | {:?}", &constr.ty, &constr.pred)
437                            }
438                        },
439                        |(), bound_var_layer, body| {
440                            let vars = poly_constr
441                                .vars()
442                                .into_iter()
443                                .enumerate()
444                                .filter_map(|(idx, var)| {
445                                    let not_removed = !bound_var_layer
446                                        .successfully_removed_vars
447                                        .contains(&BoundVar::from_usize(idx));
448                                    let refine_var = matches!(var, BoundVariableKind::Refine(..));
449                                    if not_removed && refine_var { Some(var.clone()) } else { None }
450                                })
451                                .collect_vec();
452                            if vars.is_empty() {
453                                write!(f, "{}", body)
454                            } else {
455                                let left = "{";
456                                let right = format!(". {} }}", body);
457                                cx.fmt_bound_vars(false, left, &vars, &right, f)
458                            }
459                        },
460                    )
461                }
462            }
463        }
464    }
465
466    impl_debug_with_default_cx!(CanonicalTy, CanonicalConstrTy);
467}