flux_middle/rty/
canonicalize.rs

1//! A canonical type is a type where all [existentials] and [constraint predicates] are *hoisted* to
2//! the top level. For example, the canonical version of `∃a. {∃b. i32[a + b] | b > 0}` is
3//! `∃a,b. {i32[a + b] | b > 0}`.
4//!
5//! Type constructors introduce scopes that can limit the hoisting. For instance, it is generally
6//! not permitted to hoist an existential out of a generic argument. For example, in `Vec<∃v. i32[v]>`
7//! the existential inside the `Vec` cannot be hoisted out.
8//!
9//! However, some type constructors are more "lenient" with respect to hoisting. Consider the tuple
10//! `(∃a. i32[a], ∃b. i32[b])`. Hoisting the existentials results in `∃a,b. (i32[a], i32[b])` which
11//! is an equivalent type (in the sense that subtyping holds both ways). The same applies to shared
12//! references: `&∃a. i32[a]` is equivalent to `∃a. &i32[a]`. We refer to this class of type
//! constructors as *transparent*. Hoisting existentials out of transparent type constructors is useful
14//! as it allows the logical information to be extracted from the type.
15//!
//! An important case is mutable references. In some situations, it is sound to hoist out of mutable
17//! references. For example, if we have a variable in the environment of type `&mut ∃v. T[v]`, it is
18//! sound to treat it as `&mut T[a]` for a freshly generated `a` (assuming the lifetime of the
19//! reference is alive). However, this may result in a type that is *too specific* because the index
20//! `a` cannot be updated anymore.
21//!
//! By default, we do *shallow* hoisting, i.e., we stop at the first type constructor. This is enough
//! for cases where we only need to inspect the structure of a type one level deep. The amount of
//! hoisting can be controlled by configuring the [`Hoister`] struct.
25//!
//! It's also important to note that canonicalization doesn't imply any form of semantic equality
27//! and it is just a best effort to facilitate syntactic manipulation. For example, the types
28//! `∃a,b. (i32[a], i32[b])` and `∃a,b. (i32[b], i32[a])` are semantically equal but hoisting won't
29//! account for it.
30//!
31//! [existentials]: TyKind::Exists
32//! [constraint predicates]: TyKind::Constr
33use flux_arc_interner::List;
34use flux_macros::{TypeFoldable, TypeVisitable};
35use itertools::Itertools;
36use rustc_ast::Mutability;
37use rustc_span::Symbol;
38use rustc_type_ir::{BoundVar, INNERMOST};
39
40use super::{
41    BaseTy, Binder, BoundVariableKind, Expr, FnSig, GenericArg, PolyFnSig, SubsetTy, Ty, TyCtor,
42    TyKind, TyOrBase,
43    fold::{TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable},
44};
45use crate::rty::{BoundReftKind, ExprKind, GenericArgsExt, HoleKind};
46
/// Pulls existentials and constraint predicates out of a type, handing each hoisted piece to a
/// delegate.
///
/// The flags below decide which type constructors the hoister looks through. A hoister created
/// with [`Hoister::with_delegate`] looks through none of them, i.e., it only hoists at the top
/// level.
///
/// The struct is generic on a delegate `D` because we use it both for *local* hoisting, keeping
/// variables bound with a [`Binder`], and for *freeing* variables into the refinement context.
// Should we use a builder for this?
pub struct Hoister<D> {
    /// Receives every hoisted existential and constraint predicate.
    pub delegate: D,
    /// Whether existentials are hoisted at all. Predicates of constraint types are hoisted
    /// regardless of this flag.
    existentials: bool,
    /// Look through the boxed type of a `Box`.
    in_boxes: bool,
    /// Look through shared references (`&T`).
    in_shr_refs: bool,
    /// Look through mutable references (`&mut T`).
    in_mut_refs: bool,
    /// Look through [`TyKind::StrgRef`] types.
    in_strg_refs: bool,
    /// Look through the components of tuples.
    in_tuples: bool,
    /// Look through [`TyKind::Downcast`] types.
    in_downcast: bool,
    /// Look through references (of either mutability) whose pointee, possibly behind further
    /// references, is a slice or an array.
    slices: bool,
}
64
65pub trait HoisterDelegate {
66    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty;
67    fn hoist_constr(&mut self, pred: Expr);
68}
69
70impl<D> Hoister<D> {
71    pub fn with_delegate(delegate: D) -> Self {
72        Hoister {
73            delegate,
74            in_tuples: false,
75            in_shr_refs: false,
76            in_mut_refs: false,
77            in_strg_refs: false,
78            in_boxes: false,
79            in_downcast: false,
80            existentials: true,
81            slices: false,
82        }
83    }
84
85    pub fn hoist_inside_shr_refs(mut self, shr_refs: bool) -> Self {
86        self.in_shr_refs = shr_refs;
87        self
88    }
89
90    pub fn hoist_inside_mut_refs(mut self, mut_refs: bool) -> Self {
91        self.in_mut_refs = mut_refs;
92        self
93    }
94
95    pub fn hoist_inside_strg_refs(mut self, strg_refs: bool) -> Self {
96        self.in_strg_refs = strg_refs;
97        self
98    }
99
100    pub fn hoist_inside_tuples(mut self, tuples: bool) -> Self {
101        self.in_tuples = tuples;
102        self
103    }
104
105    pub fn hoist_inside_boxes(mut self, boxes: bool) -> Self {
106        self.in_boxes = boxes;
107        self
108    }
109
110    pub fn hoist_inside_downcast(mut self, downcast: bool) -> Self {
111        self.in_downcast = downcast;
112        self
113    }
114
115    pub fn hoist_existentials(mut self, exists: bool) -> Self {
116        self.existentials = exists;
117        self
118    }
119
120    pub fn hoist_slices(mut self, slices: bool) -> Self {
121        self.slices = slices;
122        self
123    }
124
125    pub fn transparent(self) -> Self {
126        self.hoist_inside_boxes(true)
127            .hoist_inside_downcast(true)
128            .hoist_inside_mut_refs(false)
129            .hoist_inside_shr_refs(true)
130            .hoist_inside_strg_refs(true)
131            .hoist_inside_tuples(true)
132            .hoist_slices(true)
133    }
134
135    pub fn shallow(self) -> Self {
136        self.hoist_inside_boxes(false)
137            .hoist_inside_downcast(false)
138            .hoist_inside_mut_refs(false)
139            .hoist_inside_shr_refs(false)
140            .hoist_inside_strg_refs(false)
141            .hoist_inside_tuples(false)
142    }
143}
144
145impl<D: HoisterDelegate> Hoister<D> {
146    pub fn hoist(&mut self, ty: &Ty) -> Ty {
147        ty.fold_with(self)
148    }
149}
150
151/// Is `ty` of the form `&m (&m ... (&m T))` where `T` is an existentially indexed slice or array?
152/// We need to do a "transitive" check to deal with cases like `&mut &mut [i32]`
153/// which arise from closures like that in `tests/tests/pos/surface/closure03.rs`.
154// TODO(nilehmann) we could generalize this and hoist anything inside a mutable reference that
155// doesn't lose generality. This includes:
156// - The length of a slice (which can't change)
157// - Parameters with unit sorts
158// - Constraints that are not inside an existential (e.g., the `p` in  &mut {i32 | p})
159fn is_indexed_slice(ty: &Ty) -> bool {
160    let Some(bty) = ty.as_bty_skipping_existentials() else { return false };
161    match bty {
162        BaseTy::Slice(_) | BaseTy::Array(..) => true,
163        BaseTy::Ref(_, ty, _) => is_indexed_slice(ty),
164        _ => false,
165    }
166}
167
168impl<D: HoisterDelegate> TypeFolder for Hoister<D> {
169    fn fold_ty(&mut self, ty: &Ty) -> Ty {
170        match ty.kind() {
171            TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()),
172            TyKind::Exists(ty_ctor) if self.existentials => {
173                // Avoid hoisting useless parameters for unit sorts. This is important for
174                // canonicalization because we assume mutable references won't be under a
175                // binder after we canonicalize them.
176                // FIXME(nilehmann) this same logic is repeated in a couple of places, e.g.,
177                // TyCtor::to_ty
178                match &ty_ctor.vars()[..] {
179                    [BoundVariableKind::Refine(sort, ..)] => {
180                        if sort.is_unit() {
181                            ty_ctor.replace_bound_reft(&Expr::unit())
182                        } else if let Some(def_id) = sort.is_unit_adt() {
183                            ty_ctor.replace_bound_reft(&Expr::unit_struct(def_id))
184                        } else {
185                            self.delegate.hoist_exists(ty_ctor)
186                        }
187                    }
188                    _ => self.delegate.hoist_exists(ty_ctor),
189                }
190                .fold_with(self)
191            }
192            TyKind::Constr(pred, ty) => {
193                self.delegate.hoist_constr(pred.clone());
194                ty.fold_with(self)
195            }
196            TyKind::StrgRef(..) if self.in_strg_refs => ty.super_fold_with(self),
197            TyKind::Downcast(..) if self.in_downcast => ty.super_fold_with(self),
198            _ => ty.clone(),
199        }
200    }
201
202    fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy {
203        match bty {
204            BaseTy::Adt(adt_def, args) if adt_def.is_box() && self.in_boxes => {
205                let (boxed, alloc) = args.box_args();
206                let args = List::from_arr([GenericArg::Ty(boxed.fold_with(self)), alloc.clone()]);
207                BaseTy::Adt(adt_def.clone(), args)
208            }
209            BaseTy::Ref(re, ty, mutability) if is_indexed_slice(ty) && self.slices => {
210                BaseTy::Ref(*re, ty.fold_with(self), *mutability)
211            }
212            BaseTy::Ref(re, ty, Mutability::Not) if self.in_shr_refs => {
213                BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not)
214            }
215            BaseTy::Ref(re, ty, Mutability::Mut) if self.in_mut_refs => {
216                BaseTy::Ref(*re, ty.fold_with(self), Mutability::Mut)
217            }
218            BaseTy::Tuple(tys) if self.in_tuples => BaseTy::Tuple(tys.fold_with(self)),
219            _ => bty.clone(),
220        }
221    }
222}
223
224#[derive(Default)]
225pub struct LocalHoister {
226    vars: Vec<BoundVariableKind>,
227    preds: Vec<Expr>,
228    pub name: Option<Symbol>,
229}
230
231impl LocalHoister {
232    pub fn new(vars: Vec<BoundVariableKind>) -> Self {
233        LocalHoister { vars, preds: vec![], name: None }
234    }
235
236    pub fn bind<T>(self, f: impl FnOnce(List<BoundVariableKind>, Vec<Expr>) -> T) -> Binder<T> {
237        let vars = List::from_vec(self.vars);
238        Binder::bind_with_vars(f(vars.clone(), self.preds), vars)
239    }
240}
241
242impl HoisterDelegate for &mut LocalHoister {
243    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
244        ty_ctor.replace_bound_refts_with(|sort, mode, kind| {
245            let idx = self.vars.len();
246            let kind = if let Some(name) = self.name { BoundReftKind::Named(name) } else { kind };
247            self.vars
248                .push(BoundVariableKind::Refine(sort.clone(), mode, kind));
249            Expr::bvar(INNERMOST, BoundVar::from_usize(idx), kind)
250        })
251    }
252
253    fn hoist_constr(&mut self, pred: Expr) {
254        self.preds.push(pred);
255    }
256}
257
258impl PolyFnSig {
259    /// Convert a function signature with existentials to one where they are all
260    /// bound at the top level. Performs a transparent (i.e. not shallow)
261    /// canonicalization.
262    /// The uses the `LocalHoister` machinery to convert a function template _without_
263    /// binders, e.g. `fn ({v.i32 | *}) -> {v.i32|*})`
264    /// into one _with_ input binders, e.g. `forall <a:int>. fn ({i32[a]|*}) -> {v.i32|*}`
265    /// after which the hole-filling machinery can be used to fill in the holes.
266    /// This lets us get "dependent signatures" for closures, where the output
267    /// can refer to the input. e.g. see `tests/pos/surface/closure09.rs`
268    pub fn hoist_input_binders(&self) -> Self {
269        let original_vars = self.vars().to_vec();
270        let fn_sig = self.skip_binder_ref();
271
272        let mut delegate =
273            LocalHoister { vars: original_vars, preds: fn_sig.requires().to_vec(), name: None };
274        let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
275
276        let inputs = fn_sig
277            .inputs()
278            .iter()
279            .map(|ty| hoister.hoist(ty))
280            .collect_vec();
281
282        delegate.bind(|_vars, mut preds| {
283            let mut keep_hole = true;
284            preds.retain(|pred| {
285                if let ExprKind::Hole(HoleKind::Pred) = pred.kind() {
286                    std::mem::replace(&mut keep_hole, false)
287                } else {
288                    true
289                }
290            });
291
292            FnSig::new(
293                fn_sig.safety,
294                fn_sig.abi,
295                preds.into(),
296                inputs.into(),
297                fn_sig.output().clone(),
298                fn_sig.no_panic.clone(),
299                fn_sig.lifted,
300            )
301        })
302    }
303}
304
305impl Ty {
306    /// Hoist existentials and predicates inside the type stopping when encountering the first
307    /// type constructor.
308    pub fn shallow_canonicalize(&self) -> CanonicalTy {
309        let mut delegate = LocalHoister::default();
310        let ty = self.shift_in_escaping(1);
311        let ty = Hoister::with_delegate(&mut delegate).hoist(&ty);
312        let constr_ty = delegate.bind(|_, preds| {
313            let pred = Expr::and_from_iter(preds);
314            CanonicalConstrTy { ty, pred }
315        });
316        if constr_ty.vars().is_empty() {
317            CanonicalTy::Constr(constr_ty.skip_binder().shift_out_escaping(1))
318        } else {
319            CanonicalTy::Exists(constr_ty)
320        }
321    }
322}
323
/// The body of a (shallowly) canonicalized type: a type `ty` together with a predicate `pred`,
/// standing for the constraint type `{ty | pred}`.
// Field order matters: the derived visitor/folder traverse `ty` before `pred`.
#[derive(TypeVisitable, TypeFoldable)]
pub struct CanonicalConstrTy {
    /// Guaranteed to not have any (shallow) [existential] or [constraint] types
    ///
    /// [existential]: TyKind::Exists
    /// [constraint]: TyKind::Constr
    ty: Ty,
    /// As built by [`Ty::shallow_canonicalize`], the conjunction of the predicates hoisted out of
    /// the original type.
    pred: Expr,
}
333
334impl CanonicalConstrTy {
335    pub fn ty(&self) -> Ty {
336        self.ty.clone()
337    }
338
339    pub fn pred(&self) -> Expr {
340        self.pred.clone()
341    }
342
343    pub fn to_ty(&self) -> Ty {
344        Ty::constr(self.pred(), self.ty())
345    }
346}
347
/// A (shallowly) canonicalized type. This can be either of the form `{T | p}` or `∃v0,…,vn. {T | p}`,
/// where `T` doesn't have any (shallow) [existential] or [constraint] types.
///
/// When canonicalizing a type without a [constraint] type, `p` will be [`Expr::tt()`].
///
/// [existential]: TyKind::Exists
/// [constraint]: TyKind::Constr
#[derive(TypeVisitable)]
pub enum CanonicalTy {
    /// A type of the form `{T | p}`. [`Ty::shallow_canonicalize`] produces this variant when no
    /// variables were hoisted.
    Constr(CanonicalConstrTy),
    /// A type of the form `∃v0,…,vn. {T | p}`. [`Ty::shallow_canonicalize`] produces this variant
    /// only when the binder has at least one variable.
    Exists(Binder<CanonicalConstrTy>),
}
362
363impl CanonicalTy {
364    pub fn to_ty(&self) -> Ty {
365        match self {
366            CanonicalTy::Constr(constr_ty) => constr_ty.to_ty(),
367            CanonicalTy::Exists(poly_constr_ty) => {
368                Ty::exists(poly_constr_ty.as_ref().map(CanonicalConstrTy::to_ty))
369            }
370        }
371    }
372
373    pub fn as_ty_or_base(&self) -> TyOrBase {
374        match self {
375            CanonicalTy::Constr(constr_ty) => {
376                if let TyKind::Indexed(bty, idx) = constr_ty.ty.kind() {
377                    // given {b[e] | p} return λv. {b[v] | p ∧ v == e}
378
379                    // HACK(nilehmann) avoid adding trivial `v == ()` equalities, if we don't do it,
380                    // some debug assertions fail. The assertions expect types to be unrefined so they
381                    // only check for syntactical equality. We should change those cases to handle
382                    // refined types and/or ensure some canonical representation for unrefined types.
383                    let pred = if idx.is_unit() {
384                        constr_ty.pred.shift_in_escaping(1)
385                    } else {
386                        Expr::and(
387                            constr_ty.pred.shift_in_escaping(1),
388                            Expr::eq(Expr::nu(), idx.shift_in_escaping(1)),
389                        )
390                    };
391                    let sort = bty.sort();
392                    let constr = SubsetTy::new(bty.shift_in_escaping(1), Expr::nu(), pred);
393                    TyOrBase::Base(Binder::bind_with_sort(constr, sort))
394                } else {
395                    TyOrBase::Ty(self.to_ty())
396                }
397            }
398            CanonicalTy::Exists(poly_constr_ty) => {
399                let constr = poly_constr_ty.as_ref().skip_binder();
400                if let TyKind::Indexed(bty, idx) = constr.ty.kind()
401                    && idx.is_nu()
402                {
403                    let ctor = poly_constr_ty
404                        .as_ref()
405                        .map(|constr| SubsetTy::new(bty.clone(), idx, &constr.pred));
406                    TyOrBase::Base(ctor)
407                } else {
408                    TyOrBase::Ty(self.to_ty())
409                }
410            }
411        }
412    }
413}
414
// Pretty-printing for canonical types.
mod pretty {
    use super::*;
    use crate::pretty::*;

    impl Pretty for CanonicalConstrTy {
        // Prints `T` when the predicate is trivially true and `{ T | p }` otherwise.
        fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            if self.pred().is_trivially_true() {
                w!(cx, f, "{:?}", &self.ty)
            } else {
                w!(cx, f, "{{ {:?} | {:?} }}", &self.ty, &self.pred)
            }
        }
    }

    impl Pretty for CanonicalTy {
        // `Constr` prints like `CanonicalConstrTy`. `Exists` prints `{vars. T | p }`, dropping
        // `| p` when `p` is trivially true. When no bound variable survives filtering, it prints
        // like a constraint type (`{ T | p }` or `T`).
        fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                CanonicalTy::Constr(constr) => w!(cx, f, "{:?}", constr),
                CanonicalTy::Exists(poly_constr) => {
                    // Bound variables reported as redundant are registered as removable for the
                    // printing context (semantics of "removable" live in `crate::pretty`).
                    let redundant_bvars = poly_constr.skip_binder_ref().redundant_bvars();
                    cx.with_bound_vars_removable(poly_constr.vars(), redundant_bvars, None, || {
                        // NOTE(review): `ty` and `pred` are formatted *before* `filter_vars` is
                        // queried below; presumably formatting records which variables are used,
                        // so this order should be preserved — confirm in `crate::pretty`.
                        let constr = poly_constr.skip_binder_ref();
                        let ty_fmt = format_cx!(cx, "{:?}", &constr.ty);
                        let pred_fmt = if !constr.pred().is_trivially_true() {
                            Some(format_cx!(cx, "{:?}", constr.pred()))
                        } else {
                            None
                        };

                        // Variables of the binder that remain after filtering by the current
                        // binder layer; only these are printed.
                        let vars = cx
                            .bvar_env
                            .peek_layer()
                            .unwrap()
                            .filter_vars(poly_constr.vars());

                        if vars.is_empty() {
                            if let Some(pred_fmt) = pred_fmt {
                                write!(f, "{{ {ty_fmt} | {pred_fmt} }}")
                            } else {
                                write!(f, "{ty_fmt}")
                            }
                        } else {
                            // Writes the opening `{`, the variables and the `. ` separator.
                            cx.fmt_bound_vars(false, "{", &vars, ". ", f)?;
                            if let Some(pred_fmt) = pred_fmt {
                                write!(f, "{ty_fmt} | {pred_fmt} }}")
                            } else {
                                write!(f, "{ty_fmt} }}")
                            }
                        }
                    })
                }
            }
        }
    }

    impl_debug_with_default_cx!(CanonicalTy, CanonicalConstrTy);
}