flux_middle/rty/
fold.rs

//! This module follows the implementation of folding in rustc. For more information, read the
//! documentation in `rustc_type_ir::fold`.
3
4use std::ops::ControlFlow;
5
6use flux_arc_interner::{Internable, List};
7use flux_common::bug;
8use itertools::Itertools;
9use rustc_data_structures::fx::FxHashMap;
10use rustc_hash::FxHashSet;
11use rustc_type_ir::{BoundVar, DebruijnIndex, INNERMOST};
12
13use super::{
14    BaseTy, Binder, BoundVariableKinds, Const, EVid, EarlyReftParam, Ensures, Expr, ExprKind,
15    GenericArg, Name, OutlivesPredicate, PolyFuncSort, PtrKind, ReBound, ReErased, Region, Sort,
16    SubsetTy, Ty, TyKind, TyOrBase, normalize::Normalizer,
17};
18use crate::{
19    global_env::GlobalEnv,
20    rty::{BoundReft, Var, VariantSig, expr::HoleKind},
21};
22
23pub trait TypeVisitor: Sized {
24    type BreakTy = !;
25
26    fn visit_binder<T: TypeVisitable>(&mut self, t: &Binder<T>) -> ControlFlow<Self::BreakTy> {
27        t.super_visit_with(self)
28    }
29
30    fn visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::BreakTy> {
31        expr.super_visit_with(self)
32    }
33
34    fn visit_sort(&mut self, sort: &Sort) -> ControlFlow<Self::BreakTy> {
35        sort.super_visit_with(self)
36    }
37
38    fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<Self::BreakTy> {
39        ty.super_visit_with(self)
40    }
41
42    fn visit_bty(&mut self, bty: &BaseTy) -> ControlFlow<Self::BreakTy> {
43        bty.super_visit_with(self)
44    }
45}
46
47pub trait FallibleTypeFolder: Sized {
48    type Error;
49
50    fn try_fold_binder<T: TypeFoldable>(
51        &mut self,
52        t: &Binder<T>,
53    ) -> Result<Binder<T>, Self::Error> {
54        t.try_super_fold_with(self)
55    }
56
57    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
58        sort.try_super_fold_with(self)
59    }
60
61    fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, Self::Error> {
62        ty.try_super_fold_with(self)
63    }
64
65    fn try_fold_bty(&mut self, bty: &BaseTy) -> Result<BaseTy, Self::Error> {
66        bty.try_super_fold_with(self)
67    }
68
69    fn try_fold_subset_ty(&mut self, constr: &SubsetTy) -> Result<SubsetTy, Self::Error> {
70        constr.try_super_fold_with(self)
71    }
72
73    fn try_fold_region(&mut self, re: &Region) -> Result<Region, Self::Error> {
74        Ok(*re)
75    }
76
77    fn try_fold_const(&mut self, c: &Const) -> Result<Const, Self::Error> {
78        c.try_super_fold_with(self)
79    }
80
81    fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, Self::Error> {
82        expr.try_super_fold_with(self)
83    }
84}
85
86pub trait TypeFolder: FallibleTypeFolder<Error = !> {
87    fn fold_binder<T: TypeFoldable>(&mut self, t: &Binder<T>) -> Binder<T> {
88        t.super_fold_with(self)
89    }
90
91    fn fold_sort(&mut self, sort: &Sort) -> Sort {
92        sort.super_fold_with(self)
93    }
94
95    fn fold_ty(&mut self, ty: &Ty) -> Ty {
96        ty.super_fold_with(self)
97    }
98
99    fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy {
100        bty.super_fold_with(self)
101    }
102
103    fn fold_subset_ty(&mut self, constr: &SubsetTy) -> SubsetTy {
104        constr.super_fold_with(self)
105    }
106
107    fn fold_region(&mut self, re: &Region) -> Region {
108        *re
109    }
110
111    fn fold_const(&mut self, c: &Const) -> Const {
112        c.super_fold_with(self)
113    }
114
115    fn fold_expr(&mut self, expr: &Expr) -> Expr {
116        expr.super_fold_with(self)
117    }
118}
119
120impl<F> FallibleTypeFolder for F
121where
122    F: TypeFolder,
123{
124    type Error = !;
125
126    fn try_fold_binder<T: TypeFoldable>(
127        &mut self,
128        t: &Binder<T>,
129    ) -> Result<Binder<T>, Self::Error> {
130        Ok(self.fold_binder(t))
131    }
132
133    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
134        Ok(self.fold_sort(sort))
135    }
136
137    fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, Self::Error> {
138        Ok(self.fold_ty(ty))
139    }
140
141    fn try_fold_bty(&mut self, bty: &BaseTy) -> Result<BaseTy, Self::Error> {
142        Ok(self.fold_bty(bty))
143    }
144
145    fn try_fold_subset_ty(&mut self, ty: &SubsetTy) -> Result<SubsetTy, Self::Error> {
146        Ok(self.fold_subset_ty(ty))
147    }
148
149    fn try_fold_region(&mut self, re: &Region) -> Result<Region, Self::Error> {
150        Ok(self.fold_region(re))
151    }
152
153    fn try_fold_const(&mut self, c: &Const) -> Result<Const, Self::Error> {
154        Ok(self.fold_const(c))
155    }
156
157    fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, Self::Error> {
158        Ok(self.fold_expr(expr))
159    }
160}
161
162pub trait TypeVisitable: Sized {
163    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy>;
164
165    fn has_escaping_bvars(&self) -> bool {
166        self.has_escaping_bvars_at_or_above(INNERMOST)
167    }
168
169    /// Returns `true` if `self` has any late-bound vars that are either
170    /// bound by `binder` or bound by some binder outside of `binder`.
171    /// If `binder` is `ty::INNERMOST`, this indicates whether
172    /// there are any late-bound vars that appear free.
173    fn has_escaping_bvars_at_or_above(&self, binder: DebruijnIndex) -> bool {
174        struct HasEscapingVars {
175            /// Anything bound by `outer_index` or "above" is escaping.
176            outer_index: DebruijnIndex,
177        }
178
179        impl TypeVisitor for HasEscapingVars {
180            type BreakTy = ();
181
182            fn visit_binder<T: TypeVisitable>(&mut self, t: &Binder<T>) -> ControlFlow<()> {
183                self.outer_index.shift_in(1);
184                t.super_visit_with(self)?;
185                self.outer_index.shift_out(1);
186                ControlFlow::Continue(())
187            }
188
189            // TODO(nilehmann) keep track of the outermost binder to optimize this, i.e.,
190            // what rustc calls outer_exclusive_binder.
191            fn visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
192                if let ExprKind::Var(Var::Bound(debruijn, _)) = expr.kind() {
193                    if *debruijn >= self.outer_index {
194                        ControlFlow::Break(())
195                    } else {
196                        ControlFlow::Continue(())
197                    }
198                } else {
199                    expr.super_visit_with(self)
200                }
201            }
202        }
203        let mut visitor = HasEscapingVars { outer_index: binder };
204        self.visit_with(&mut visitor).is_break()
205    }
206
207    /// Returns the set of all free variables.
208    /// For example, `Vec<i32[n]>{v : v > m}` returns `{n, m}`.
209    fn fvars(&self) -> FxHashSet<Name> {
210        struct CollectFreeVars(FxHashSet<Name>);
211
212        impl TypeVisitor for CollectFreeVars {
213            fn visit_expr(&mut self, e: &Expr) -> ControlFlow<Self::BreakTy> {
214                if let ExprKind::Var(Var::Free(name)) = e.kind() {
215                    self.0.insert(*name);
216                }
217                e.super_visit_with(self)
218            }
219        }
220
221        let mut collector = CollectFreeVars(FxHashSet::default());
222        let _ = self.visit_with(&mut collector);
223        collector.0
224    }
225
226    fn early_params(&self) -> FxHashSet<EarlyReftParam> {
227        struct CollectEarlyParams(FxHashSet<EarlyReftParam>);
228
229        impl TypeVisitor for CollectEarlyParams {
230            fn visit_expr(&mut self, e: &Expr) -> ControlFlow<Self::BreakTy> {
231                if let ExprKind::Var(Var::EarlyParam(param)) = e.kind() {
232                    self.0.insert(*param);
233                }
234                e.super_visit_with(self)
235            }
236        }
237
238        let mut collector = CollectEarlyParams(FxHashSet::default());
239        let _ = self.visit_with(&mut collector);
240        collector.0
241    }
242
243    /// Gives the indices of the provided bvars which:
244    ///   1. Only occur a single time.
245    ///   2. In their occurrence, are either
246    ///      a. The direct argument in an index (e.g. `exists b0. usize[b0]`)
247    ///      b. The direct argument of a constructor in an index (e.g.
248    ///      `exists b0. Vec<usize>[{len: b0}]`)
249    ///
250    /// This is to be used for "re-sugaring" existentials into surface syntax
251    /// that doesn't use existentials.
252    ///
253    /// For 2b., we do need to be careful to ensure that if a constructor has
254    /// multiple arguments, they _all_ are redundant bvars, e.g. as in
255    ///
256    ///     exists b0, b1. RMat<f32>[{rows: b0, cols: b1}]
257    ///
258    /// which may be rewritten as `RMat<f32>`,
259    /// versus the (unlikely) edge case
260    ///
261    ///     exists b0. RMat<f32>[{rows: b0, cols: b0}]
262    ///
263    /// for which the existential is now necessary.
264    ///
265    /// NOTE: this only applies to refinement bvars.
266    fn redundant_bvars(&self) -> FxHashSet<BoundVar> {
267        struct RedundantBVarFinder {
268            current_index: DebruijnIndex,
269            total_bvar_occurrences: FxHashMap<BoundVar, usize>,
270            bvars_appearing_in_index: FxHashSet<BoundVar>,
271        }
272
273        impl TypeVisitor for RedundantBVarFinder {
274            // Here we count all times we see a bvar
275            fn visit_expr(&mut self, e: &Expr) -> ControlFlow<Self::BreakTy> {
276                if let ExprKind::Var(Var::Bound(debruijn, BoundReft { var, .. })) = e.kind()
277                    && debruijn == &self.current_index
278                {
279                    self.total_bvar_occurrences
280                        .entry(*var)
281                        .and_modify(|count| {
282                            *count += 1;
283                        })
284                        .or_insert(1);
285                }
286                e.super_visit_with(self)
287            }
288
289            fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<Self::BreakTy> {
290                // Here we check for bvars specifically as the direct arguments
291                // to an index or as the direct arguments to a Ctor in an index.
292                if let TyKind::Indexed(_bty, expr) = ty.kind() {
293                    match expr.kind() {
294                        ExprKind::Var(Var::Bound(debruijn, BoundReft { var, .. })) => {
295                            if debruijn == &self.current_index {
296                                self.bvars_appearing_in_index.insert(*var);
297                            }
298                        }
299                        ExprKind::Ctor(_ctor, exprs) => {
300                            exprs.iter().for_each(|expr| {
301                                if let ExprKind::Var(Var::Bound(debruijn, BoundReft { var, .. })) =
302                                    expr.kind()
303                                    && debruijn == &self.current_index
304                                {
305                                    self.bvars_appearing_in_index.insert(*var);
306                                }
307                            });
308                        }
309                        _ => {}
310                    }
311                }
312                ty.super_visit_with(self)
313            }
314
315            fn visit_binder<T: TypeVisitable>(
316                &mut self,
317                t: &Binder<T>,
318            ) -> ControlFlow<Self::BreakTy> {
319                self.current_index.shift_in(1);
320                t.super_visit_with(self)?;
321                self.current_index.shift_out(1);
322                ControlFlow::Continue(())
323            }
324        }
325
326        let mut finder = RedundantBVarFinder {
327            current_index: INNERMOST,
328            total_bvar_occurrences: FxHashMap::default(),
329            bvars_appearing_in_index: FxHashSet::default(),
330        };
331        let _ = self.visit_with(&mut finder);
332
333        finder
334            .bvars_appearing_in_index
335            .into_iter()
336            .filter(|var_index| finder.total_bvar_occurrences.get(var_index) == Some(&1))
337            .collect()
338    }
339}
340
341pub trait TypeSuperVisitable: TypeVisitable {
342    fn super_visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy>;
343}
344
345pub trait TypeFoldable: TypeVisitable {
346    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error>;
347
348    fn fold_with<F: TypeFolder>(&self, folder: &mut F) -> Self {
349        self.try_fold_with(folder).into_ok()
350    }
351
352    /// Normalize expressions by applying beta reductions for tuples and lambda abstractions.
353    fn normalize(&self, genv: GlobalEnv) -> Self {
354        self.fold_with(&mut Normalizer::new(genv, None))
355    }
356
357    /// Replaces all [holes] with the result of calling a closure. The closure takes a list with
358    /// all the *layers* of [bound] variables at the point the hole was found. Each layer corresponds
359    /// to the list of bound variables at that level. The list is ordered from outermost to innermost
360    /// binder, i.e., the last element is the binder closest to the hole.
361    ///
362    /// [holes]: ExprKind::Hole
363    /// [bound]: Binder
364    fn replace_holes(&self, f: impl FnMut(&[BoundVariableKinds], HoleKind) -> Expr) -> Self {
365        struct ReplaceHoles<F>(F, Vec<BoundVariableKinds>);
366
367        impl<F> TypeFolder for ReplaceHoles<F>
368        where
369            F: FnMut(&[BoundVariableKinds], HoleKind) -> Expr,
370        {
371            fn fold_binder<T: TypeFoldable>(&mut self, t: &Binder<T>) -> Binder<T> {
372                self.1.push(t.vars().clone());
373                let t = t.super_fold_with(self);
374                self.1.pop();
375                t
376            }
377
378            fn fold_expr(&mut self, e: &Expr) -> Expr {
379                if let ExprKind::Hole(kind) = e.kind() {
380                    self.0(&self.1, kind.clone())
381                } else {
382                    e.super_fold_with(self)
383                }
384            }
385        }
386
387        self.fold_with(&mut ReplaceHoles(f, vec![]))
388    }
389
390    /// Remove all refinements and turn each underlying [`BaseTy`] into a [`TyKind::Exists`] with a
391    /// [`TyKind::Constr`] and a [`hole`]. For example, `Vec<{v. i32[v] | v > 0}>[n]` becomes
392    /// `{n. Vec<{v. i32[v] | *}>[n] | *}`.
393    ///
394    /// [`hole`]: ExprKind::Hole
395    fn with_holes(&self) -> Self {
396        struct WithHoles;
397
398        impl TypeFolder for WithHoles {
399            fn fold_ty(&mut self, ty: &Ty) -> Ty {
400                if let Some(bty) = ty.as_bty_skipping_existentials() {
401                    Ty::exists_with_constr(bty.fold_with(self), Expr::hole(HoleKind::Pred))
402                } else {
403                    ty.super_fold_with(self)
404                }
405            }
406
407            fn fold_subset_ty(&mut self, constr: &SubsetTy) -> SubsetTy {
408                SubsetTy::new(constr.bty.clone(), constr.idx.clone(), Expr::hole(HoleKind::Pred))
409            }
410        }
411
412        self.fold_with(&mut WithHoles)
413    }
414
415    fn replace_evars(&self, f: &mut impl FnMut(EVid) -> Option<Expr>) -> Result<Self, EVid> {
416        struct Folder<F>(F);
417        impl<F: FnMut(EVid) -> Option<Expr>> FallibleTypeFolder for Folder<F> {
418            type Error = EVid;
419
420            fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, Self::Error> {
421                if let ExprKind::Var(Var::EVar(evid)) = expr.kind() {
422                    if let Some(sol) = (self.0)(*evid) { Ok(sol.clone()) } else { Err(*evid) }
423                } else {
424                    expr.try_super_fold_with(self)
425                }
426            }
427        }
428
429        self.try_fold_with(&mut Folder(f))
430    }
431
432    fn shift_in_escaping(&self, amount: u32) -> Self {
433        struct Shifter {
434            current_index: DebruijnIndex,
435            amount: u32,
436        }
437
438        impl TypeFolder for Shifter {
439            fn fold_binder<T>(&mut self, t: &Binder<T>) -> Binder<T>
440            where
441                T: TypeFoldable,
442            {
443                self.current_index.shift_in(1);
444                let r = t.super_fold_with(self);
445                self.current_index.shift_out(1);
446                r
447            }
448
449            fn fold_region(&mut self, re: &Region) -> Region {
450                if let ReBound(debruijn, br) = *re
451                    && debruijn >= self.current_index
452                {
453                    ReBound(debruijn.shifted_in(self.amount), br)
454                } else {
455                    *re
456                }
457            }
458
459            fn fold_expr(&mut self, expr: &Expr) -> Expr {
460                if let ExprKind::Var(Var::Bound(debruijn, breft)) = expr.kind()
461                    && *debruijn >= self.current_index
462                {
463                    Expr::bvar(debruijn.shifted_in(self.amount), breft.var, breft.kind)
464                } else {
465                    expr.super_fold_with(self)
466                }
467            }
468        }
469        self.fold_with(&mut Shifter { amount, current_index: INNERMOST })
470    }
471
472    fn shift_out_escaping(&self, amount: u32) -> Self {
473        struct Shifter {
474            amount: u32,
475            current_index: DebruijnIndex,
476        }
477
478        impl TypeFolder for Shifter {
479            fn fold_binder<T: TypeFoldable>(&mut self, t: &Binder<T>) -> Binder<T> {
480                self.current_index.shift_in(1);
481                let t = t.super_fold_with(self);
482                self.current_index.shift_out(1);
483                t
484            }
485
486            fn fold_region(&mut self, re: &Region) -> Region {
487                if let ReBound(debruijn, br) = *re
488                    && debruijn >= self.current_index
489                {
490                    ReBound(debruijn.shifted_out(self.amount), br)
491                } else {
492                    *re
493                }
494            }
495
496            fn fold_expr(&mut self, expr: &Expr) -> Expr {
497                if let ExprKind::Var(Var::Bound(debruijn, breft)) = expr.kind()
498                    && debruijn >= &self.current_index
499                {
500                    Expr::bvar(debruijn.shifted_out(self.amount), breft.var, breft.kind)
501                } else {
502                    expr.super_fold_with(self)
503                }
504            }
505        }
506        self.fold_with(&mut Shifter { amount, current_index: INNERMOST })
507    }
508
509    fn erase_regions(&self) -> Self {
510        struct RegionEraser;
511        impl TypeFolder for RegionEraser {
512            fn fold_region(&mut self, r: &Region) -> Region {
513                match *r {
514                    ReBound(..) => *r,
515                    _ => ReErased,
516                }
517            }
518        }
519
520        self.fold_with(&mut RegionEraser)
521    }
522}
523
524pub trait TypeSuperFoldable: TypeFoldable {
525    fn try_super_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error>;
526
527    fn super_fold_with<F: TypeFolder>(&self, folder: &mut F) -> Self {
528        self.try_super_fold_with(folder).into_ok()
529    }
530}
531
532impl<T: TypeVisitable> TypeVisitable for OutlivesPredicate<T> {
533    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
534        self.0.visit_with(visitor)?;
535        self.1.visit_with(visitor)
536    }
537}
538
539impl<T: TypeFoldable> TypeFoldable for OutlivesPredicate<T> {
540    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
541        Ok(OutlivesPredicate(self.0.try_fold_with(folder)?, self.1.try_fold_with(folder)?))
542    }
543}
544
545impl TypeVisitable for Sort {
546    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
547        visitor.visit_sort(self)
548    }
549}
550
551impl TypeSuperVisitable for Sort {
552    fn super_visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
553        match self {
554            Sort::Tuple(sorts) => sorts.visit_with(visitor),
555            Sort::App(_, args) => args.visit_with(visitor),
556            Sort::Func(fsort) => fsort.visit_with(visitor),
557            Sort::Alias(_, alias_ty) => alias_ty.visit_with(visitor),
558            Sort::Int
559            | Sort::Bool
560            | Sort::Real
561            | Sort::Str
562            | Sort::Char
563            | Sort::BitVec(_)
564            | Sort::Loc
565            | Sort::Param(_)
566            | Sort::Var(_)
567            | Sort::Infer(_)
568            | Sort::Err => ControlFlow::Continue(()),
569        }
570    }
571}
572
573impl TypeFoldable for Sort {
574    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
575        folder.try_fold_sort(self)
576    }
577}
578
579impl TypeSuperFoldable for Sort {
580    fn try_super_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
581        let sort = match self {
582            Sort::Tuple(sorts) => Sort::tuple(sorts.try_fold_with(folder)?),
583            Sort::App(ctor, sorts) => Sort::app(ctor.clone(), sorts.try_fold_with(folder)?),
584            Sort::Func(fsort) => Sort::Func(fsort.try_fold_with(folder)?),
585            Sort::Alias(kind, alias_ty) => Sort::Alias(*kind, alias_ty.try_fold_with(folder)?),
586            Sort::Int
587            | Sort::Bool
588            | Sort::Real
589            | Sort::Loc
590            | Sort::Str
591            | Sort::Char
592            | Sort::BitVec(_)
593            | Sort::Param(_)
594            | Sort::Var(_)
595            | Sort::Infer(_)
596            | Sort::Err => self.clone(),
597        };
598        Ok(sort)
599    }
600}
601
602impl TypeVisitable for PolyFuncSort {
603    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
604        self.fsort.visit_with(visitor)
605    }
606}
607
608impl TypeFoldable for PolyFuncSort {
609    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
610        Ok(PolyFuncSort { params: self.params.clone(), fsort: self.fsort.try_fold_with(folder)? })
611    }
612}
613
614impl<T> TypeVisitable for Binder<T>
615where
616    T: TypeVisitable,
617{
618    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
619        visitor.visit_binder(self)
620    }
621}
622
623impl<T> TypeSuperVisitable for Binder<T>
624where
625    T: TypeVisitable,
626{
627    fn super_visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
628        self.skip_binder_ref().visit_with(visitor)
629    }
630}
631
632impl<T> TypeFoldable for Binder<T>
633where
634    T: TypeFoldable,
635{
636    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
637        folder.try_fold_binder(self)
638    }
639}
640
641impl<T> TypeSuperFoldable for Binder<T>
642where
643    T: TypeFoldable,
644{
645    fn try_super_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
646        Ok(Binder::bind_with_vars(
647            self.skip_binder_ref().try_fold_with(folder)?,
648            self.vars().try_fold_with(folder)?,
649        ))
650    }
651}
652
653impl TypeVisitable for VariantSig {
654    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
655        self.fields.visit_with(visitor)?;
656        self.idx.visit_with(visitor)
657    }
658}
659
660impl TypeFoldable for VariantSig {
661    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
662        let args = self.args.try_fold_with(folder)?;
663        let fields = self.fields.try_fold_with(folder)?;
664        let idx = self.idx.try_fold_with(folder)?;
665        let requires = self.requires.try_fold_with(folder)?;
666        Ok(VariantSig::new(self.adt_def.clone(), args, fields, idx, requires))
667    }
668}
669
670impl<T: TypeVisitable> TypeVisitable for Vec<T> {
671    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
672        self.iter().try_for_each(|t| t.visit_with(visitor))
673    }
674}
675
676impl TypeVisitable for Ensures {
677    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
678        match self {
679            Ensures::Type(path, ty) => {
680                path.to_expr().visit_with(visitor)?;
681                ty.visit_with(visitor)
682            }
683            Ensures::Pred(e) => e.visit_with(visitor),
684        }
685    }
686}
687
688impl TypeFoldable for Ensures {
689    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
690        let c = match self {
691            Ensures::Type(path, ty) => {
692                let path_expr = path.to_expr().try_fold_with(folder)?;
693                Ensures::Type(
694                    path_expr.to_path().unwrap_or_else(|| {
695                        bug!("invalid path `{path_expr:?}` produced when folding `{self:?}`",)
696                    }),
697                    ty.try_fold_with(folder)?,
698                )
699            }
700            Ensures::Pred(e) => Ensures::Pred(e.try_fold_with(folder)?),
701        };
702        Ok(c)
703    }
704}
705
706impl TypeVisitable for super::TyOrBase {
707    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
708        match self {
709            Self::Ty(ty) => ty.visit_with(visitor),
710            Self::Base(bty) => bty.visit_with(visitor),
711        }
712    }
713}
714
715impl TypeVisitable for Ty {
716    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
717        visitor.visit_ty(self)
718    }
719}
720
721impl TypeSuperVisitable for Ty {
722    fn super_visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
723        match self.kind() {
724            TyKind::Indexed(bty, idxs) => {
725                bty.visit_with(visitor)?;
726                idxs.visit_with(visitor)
727            }
728            TyKind::Exists(exists) => exists.visit_with(visitor),
729            TyKind::StrgRef(_, path, ty) => {
730                path.to_expr().visit_with(visitor)?;
731                ty.visit_with(visitor)
732            }
733            TyKind::Ptr(_, path) => path.to_expr().visit_with(visitor),
734            TyKind::Constr(pred, ty) => {
735                pred.visit_with(visitor)?;
736                ty.visit_with(visitor)
737            }
738            TyKind::Downcast(.., args, _, fields) => {
739                args.visit_with(visitor)?;
740                fields.visit_with(visitor)
741            }
742            TyKind::Blocked(ty) => ty.visit_with(visitor),
743            TyKind::Infer(_) | TyKind::Param(_) | TyKind::Discr(..) | TyKind::Uninit => {
744                ControlFlow::Continue(())
745            }
746        }
747    }
748}
749
750impl TypeFoldable for Ty {
751    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
752        folder.try_fold_ty(self)
753    }
754}
755
756impl TypeSuperFoldable for Ty {
757    fn try_super_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Ty, F::Error> {
758        let ty = match self.kind() {
759            TyKind::Indexed(bty, idxs) => {
760                Ty::indexed(bty.try_fold_with(folder)?, idxs.try_fold_with(folder)?)
761            }
762            TyKind::Exists(exists) => TyKind::Exists(exists.try_fold_with(folder)?).intern(),
763            TyKind::StrgRef(re, path, ty) => {
764                Ty::strg_ref(
765                    re.try_fold_with(folder)?,
766                    path.to_expr()
767                        .try_fold_with(folder)?
768                        .to_path()
769                        .expect("type folding produced an invalid path"),
770                    ty.try_fold_with(folder)?,
771                )
772            }
773            TyKind::Ptr(pk, path) => {
774                let pk = match pk {
775                    PtrKind::Mut(re) => PtrKind::Mut(re.try_fold_with(folder)?),
776                    PtrKind::Box => PtrKind::Box,
777                };
778                Ty::ptr(
779                    pk,
780                    path.to_expr()
781                        .try_fold_with(folder)?
782                        .to_path()
783                        .expect("type folding produced an invalid path"),
784                )
785            }
786            TyKind::Constr(pred, ty) => {
787                Ty::constr(pred.try_fold_with(folder)?, ty.try_fold_with(folder)?)
788            }
789            TyKind::Downcast(adt, args, ty, variant, fields) => {
790                Ty::downcast(
791                    adt.clone(),
792                    args.clone(),
793                    ty.clone(),
794                    *variant,
795                    fields.try_fold_with(folder)?,
796                )
797            }
798            TyKind::Blocked(ty) => Ty::blocked(ty.try_fold_with(folder)?),
799            TyKind::Infer(_) | TyKind::Param(_) | TyKind::Uninit | TyKind::Discr(..) => {
800                self.clone()
801            }
802        };
803        Ok(ty)
804    }
805}
806
807impl TypeVisitable for Region {
808    fn visit_with<V: TypeVisitor>(&self, _visitor: &mut V) -> ControlFlow<V::BreakTy> {
809        ControlFlow::Continue(())
810    }
811}
812
813impl TypeFoldable for Region {
814    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
815        folder.try_fold_region(self)
816    }
817}
818
819impl TypeSuperFoldable for Const {
820    fn try_super_fold_with<F: FallibleTypeFolder>(
821        &self,
822        _folder: &mut F,
823    ) -> Result<Self, F::Error> {
824        // FIXME(nilehmann) we are not folding the type in `ConstKind::Value` because it's a rustc::ty::Ty
825        Ok(self.clone())
826    }
827}
828
829impl TypeVisitable for Const {
830    fn visit_with<V: TypeVisitor>(&self, _visitor: &mut V) -> ControlFlow<V::BreakTy> {
831        ControlFlow::Continue(())
832    }
833}
834
835impl TypeFoldable for Const {
836    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
837        folder.try_fold_const(self)
838    }
839}
840
841impl TypeVisitable for BaseTy {
842    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
843        visitor.visit_bty(self)
844    }
845}
846
847impl TypeSuperVisitable for BaseTy {
848    fn super_visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
849        match self {
850            BaseTy::Adt(_, args) => args.visit_with(visitor),
851            BaseTy::FnDef(_, args) => args.visit_with(visitor),
852            BaseTy::Slice(ty) => ty.visit_with(visitor),
853            BaseTy::RawPtr(ty, _) => ty.visit_with(visitor),
854            BaseTy::RawPtrMetadata(ty) => ty.visit_with(visitor),
855            BaseTy::Ref(_, ty, _) => ty.visit_with(visitor),
856            BaseTy::FnPtr(poly_fn_sig) => poly_fn_sig.visit_with(visitor),
857            BaseTy::Tuple(tys) => tys.visit_with(visitor),
858            BaseTy::Alias(_, alias_ty) => alias_ty.visit_with(visitor),
859            BaseTy::Array(ty, _) => ty.visit_with(visitor),
860            BaseTy::Coroutine(_, resume_ty, upvars) => {
861                resume_ty.visit_with(visitor)?;
862                upvars.visit_with(visitor)
863            }
864            BaseTy::Dynamic(exi_preds, _) => exi_preds.visit_with(visitor),
865            BaseTy::Int(_)
866            | BaseTy::Uint(_)
867            | BaseTy::Bool
868            | BaseTy::Float(_)
869            | BaseTy::Str
870            | BaseTy::Char
871            | BaseTy::Closure(..)
872            | BaseTy::Never
873            | BaseTy::Infer(_)
874            | BaseTy::Foreign(_)
875            | BaseTy::Param(_) => ControlFlow::Continue(()),
876        }
877    }
878}
879
880impl TypeFoldable for BaseTy {
881    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
882        folder.try_fold_bty(self)
883    }
884}
885
886impl TypeSuperFoldable for BaseTy {
887    fn try_super_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
888        let bty = match self {
889            BaseTy::Adt(adt_def, args) => BaseTy::adt(adt_def.clone(), args.try_fold_with(folder)?),
890            BaseTy::FnDef(def_id, args) => BaseTy::fn_def(*def_id, args.try_fold_with(folder)?),
891            BaseTy::Foreign(def_id) => BaseTy::Foreign(*def_id),
892            BaseTy::Slice(ty) => BaseTy::Slice(ty.try_fold_with(folder)?),
893            BaseTy::RawPtr(ty, mu) => BaseTy::RawPtr(ty.try_fold_with(folder)?, *mu),
894            BaseTy::RawPtrMetadata(ty) => BaseTy::RawPtrMetadata(ty.try_fold_with(folder)?),
895            BaseTy::Ref(re, ty, mutbl) => {
896                BaseTy::Ref(re.try_fold_with(folder)?, ty.try_fold_with(folder)?, *mutbl)
897            }
898            BaseTy::FnPtr(decl) => BaseTy::FnPtr(decl.try_fold_with(folder)?),
899            BaseTy::Tuple(tys) => BaseTy::Tuple(tys.try_fold_with(folder)?),
900            BaseTy::Alias(kind, alias_ty) => BaseTy::Alias(*kind, alias_ty.try_fold_with(folder)?),
901            BaseTy::Array(ty, c) => {
902                BaseTy::Array(ty.try_fold_with(folder)?, c.try_fold_with(folder)?)
903            }
904            BaseTy::Closure(did, args, gen_args) => {
905                BaseTy::Closure(*did, args.try_fold_with(folder)?, gen_args.clone())
906            }
907            BaseTy::Coroutine(did, resume_ty, args) => {
908                BaseTy::Coroutine(
909                    *did,
910                    resume_ty.try_fold_with(folder)?,
911                    args.try_fold_with(folder)?,
912                )
913            }
914            BaseTy::Dynamic(preds, region) => {
915                BaseTy::Dynamic(preds.try_fold_with(folder)?, region.try_fold_with(folder)?)
916            }
917            BaseTy::Int(_)
918            | BaseTy::Param(_)
919            | BaseTy::Uint(_)
920            | BaseTy::Bool
921            | BaseTy::Float(_)
922            | BaseTy::Str
923            | BaseTy::Char
924            | BaseTy::Infer(_)
925            | BaseTy::Never => self.clone(),
926        };
927        Ok(bty)
928    }
929}
930
931impl TypeVisitable for SubsetTy {
932    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
933        self.bty.visit_with(visitor)?;
934        self.idx.visit_with(visitor)?;
935        self.pred.visit_with(visitor)
936    }
937}
938impl TypeFoldable for TyOrBase {
939    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
940        match self {
941            Self::Ty(ty) => Ok(Self::Ty(ty.try_fold_with(folder)?)),
942            Self::Base(bty) => Ok(Self::Base(bty.try_fold_with(folder)?)),
943        }
944    }
945}
946
947impl TypeFoldable for SubsetTy {
948    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
949        folder.try_fold_subset_ty(self)
950    }
951}
952
953impl TypeSuperFoldable for SubsetTy {
954    fn try_super_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
955        Ok(SubsetTy {
956            bty: self.bty.try_fold_with(folder)?,
957            idx: self.idx.try_fold_with(folder)?,
958            pred: self.pred.try_fold_with(folder)?,
959        })
960    }
961}
962
963impl TypeVisitable for GenericArg {
964    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
965        match self {
966            GenericArg::Ty(ty) => ty.visit_with(visitor),
967            GenericArg::Base(ty) => ty.visit_with(visitor),
968            GenericArg::Lifetime(_) => ControlFlow::Continue(()),
969            GenericArg::Const(_) => ControlFlow::Continue(()),
970        }
971    }
972}
973
974impl TypeFoldable for GenericArg {
975    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
976        let arg = match self {
977            GenericArg::Ty(ty) => GenericArg::Ty(ty.try_fold_with(folder)?),
978            GenericArg::Base(ctor) => GenericArg::Base(ctor.try_fold_with(folder)?),
979            GenericArg::Lifetime(re) => GenericArg::Lifetime(re.try_fold_with(folder)?),
980            GenericArg::Const(c) => GenericArg::Const(c.try_fold_with(folder)?),
981        };
982        Ok(arg)
983    }
984}
985
986impl TypeVisitable for Expr {
987    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
988        visitor.visit_expr(self)
989    }
990}
991
992impl TypeSuperVisitable for Expr {
993    fn super_visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
994        match self.kind() {
995            ExprKind::Var(_) => ControlFlow::Continue(()),
996            ExprKind::BinaryOp(_, e1, e2) => {
997                e1.visit_with(visitor)?;
998                e2.visit_with(visitor)
999            }
1000            ExprKind::Tuple(flds) => flds.visit_with(visitor),
1001            ExprKind::Ctor(_, flds) => flds.visit_with(visitor),
1002            ExprKind::FieldProj(e, _) | ExprKind::PathProj(e, _) | ExprKind::UnaryOp(_, e) => {
1003                e.visit_with(visitor)
1004            }
1005            ExprKind::App(func, sorts, arg) => {
1006                func.visit_with(visitor)?;
1007                sorts.visit_with(visitor)?;
1008                arg.visit_with(visitor)
1009            }
1010            ExprKind::IfThenElse(p, e1, e2) => {
1011                p.visit_with(visitor)?;
1012                e1.visit_with(visitor)?;
1013                e2.visit_with(visitor)
1014            }
1015            ExprKind::KVar(kvar) => kvar.visit_with(visitor),
1016            ExprKind::Alias(alias, args) => {
1017                alias.visit_with(visitor)?;
1018                args.visit_with(visitor)
1019            }
1020            ExprKind::Abs(body) => body.visit_with(visitor),
1021            ExprKind::BoundedQuant(_, _, body) => body.visit_with(visitor),
1022            ExprKind::ForAll(expr) => expr.visit_with(visitor),
1023            ExprKind::Let(init, body) => {
1024                init.visit_with(visitor)?;
1025                body.visit_with(visitor)
1026            }
1027            ExprKind::Constant(_)
1028            | ExprKind::Hole(_)
1029            | ExprKind::Local(_)
1030            | ExprKind::GlobalFunc(..)
1031            | ExprKind::InternalFunc(..)
1032            | ExprKind::ConstDefId(_) => ControlFlow::Continue(()),
1033        }
1034    }
1035}
1036
1037impl TypeFoldable for Expr {
1038    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
1039        folder.try_fold_expr(self)
1040    }
1041}
1042
1043impl TypeSuperFoldable for Expr {
1044    fn try_super_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
1045        let span = self.span();
1046        let expr = match self.kind() {
1047            ExprKind::Var(var) => Expr::var(*var),
1048            ExprKind::Local(local) => Expr::local(*local),
1049            ExprKind::Constant(c) => Expr::constant(*c),
1050            ExprKind::ConstDefId(did) => Expr::const_def_id(*did),
1051            ExprKind::BinaryOp(op, e1, e2) => {
1052                Expr::binary_op(
1053                    op.try_fold_with(folder)?,
1054                    e1.try_fold_with(folder)?,
1055                    e2.try_fold_with(folder)?,
1056                )
1057            }
1058            ExprKind::UnaryOp(op, e) => Expr::unary_op(*op, e.try_fold_with(folder)?),
1059            ExprKind::FieldProj(e, proj) => Expr::field_proj(e.try_fold_with(folder)?, *proj),
1060            ExprKind::Tuple(flds) => Expr::tuple(flds.try_fold_with(folder)?),
1061            ExprKind::Ctor(ctor, flds) => Expr::ctor(*ctor, flds.try_fold_with(folder)?),
1062            ExprKind::PathProj(e, field) => Expr::path_proj(e.try_fold_with(folder)?, *field),
1063            ExprKind::App(func, sorts, arg) => {
1064                Expr::app(
1065                    func.try_fold_with(folder)?,
1066                    sorts.try_fold_with(folder)?,
1067                    arg.try_fold_with(folder)?,
1068                )
1069            }
1070            ExprKind::IfThenElse(p, e1, e2) => {
1071                Expr::ite(
1072                    p.try_fold_with(folder)?,
1073                    e1.try_fold_with(folder)?,
1074                    e2.try_fold_with(folder)?,
1075                )
1076            }
1077            ExprKind::Hole(kind) => Expr::hole(kind.try_fold_with(folder)?),
1078            ExprKind::KVar(kvar) => Expr::kvar(kvar.try_fold_with(folder)?),
1079            ExprKind::Abs(lam) => Expr::abs(lam.try_fold_with(folder)?),
1080            ExprKind::BoundedQuant(kind, rng, body) => {
1081                Expr::bounded_quant(*kind, *rng, body.try_fold_with(folder)?)
1082            }
1083            ExprKind::GlobalFunc(kind) => Expr::global_func(kind.clone()),
1084            ExprKind::InternalFunc(kind) => Expr::internal_func(kind.clone()),
1085            ExprKind::Alias(alias, args) => {
1086                Expr::alias(alias.try_fold_with(folder)?, args.try_fold_with(folder)?)
1087            }
1088            ExprKind::ForAll(expr) => Expr::forall(expr.try_fold_with(folder)?),
1089            ExprKind::Let(init, body) => {
1090                Expr::let_(init.try_fold_with(folder)?, body.try_fold_with(folder)?)
1091            }
1092        };
1093        Ok(expr.at_opt(span))
1094    }
1095}
1096
1097impl<T> TypeVisitable for List<T>
1098where
1099    T: TypeVisitable,
1100    [T]: Internable,
1101{
1102    fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
1103        self.iter().try_for_each(|t| t.visit_with(visitor))
1104    }
1105}
1106
1107impl<T> TypeFoldable for List<T>
1108where
1109    T: TypeFoldable,
1110    [T]: Internable,
1111{
1112    fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
1113        self.iter().map(|t| t.try_fold_with(folder)).try_collect()
1114    }
1115}
1116
/// Implements `TypeFoldable` and `TypeVisitable` as no-ops for every listed type.
///
/// Used for types that are `Copy` and which **do not carry arena-allocated data** (i.e., don't
/// need to be folded): folding returns a copy of the value and visiting always continues.
/// Each entry in the list, including the last one, must be followed by a comma.
macro_rules! TrivialTypeTraversalImpls {
    ($($trivial:ty,)+) => {
        $(
            impl $crate::rty::fold::TypeFoldable for $trivial {
                fn try_fold_with<F: $crate::rty::fold::FallibleTypeFolder>(
                    &self,
                    _folder: &mut F,
                ) -> ::core::result::Result<Self, F::Error> {
                    Ok(*self)
                }

                #[inline]
                fn fold_with<F: $crate::rty::fold::TypeFolder>(&self, _folder: &mut F) -> Self {
                    *self
                }
            }

            impl $crate::rty::fold::TypeVisitable for $trivial {
                #[inline]
                fn visit_with<V: $crate::rty::fold::TypeVisitor>(
                    &self,
                    _visitor: &mut V,
                ) -> ::core::ops::ControlFlow<V::BreakTy> {
                    ::core::ops::ControlFlow::Continue(())
                }
            }
        )+
    };
}
1152
1153// For things that don't carry any arena-allocated data (and are copy...), just add them to this list.
1154TrivialTypeTraversalImpls! {
1155    (),
1156    bool,
1157    usize,
1158    crate::fhir::InferMode,
1159    crate::rty::BoundReftKind,
1160    crate::rty::BvSize,
1161    crate::rty::KVid,
1162    crate::def_id::FluxDefId,
1163    crate::def_id::FluxLocalDefId,
1164    rustc_span::Symbol,
1165    rustc_hir::def_id::DefId,
1166    rustc_hir::Safety,
1167    rustc_abi::ExternAbi,
1168    rustc_type_ir::ClosureKind,
1169    flux_rustc_bridge::ty::BoundRegionKind,
1170}