// flux_middle/rty/subst.rs

1use std::cmp::Ordering;
2
3use flux_common::{bug, tracked_span_bug};
4use rustc_type_ir::DebruijnIndex;
5
6use self::fold::FallibleTypeFolder;
7use super::fold::{TypeFolder, TypeSuperFoldable};
8use crate::rty::*;
9
10/// Substitution for late bound variables
11pub(super) struct BoundVarReplacer<D> {
12    current_index: DebruijnIndex,
13    delegate: D,
14}
15
16pub trait BoundVarReplacerDelegate {
17    fn replace_expr(&mut self, var: BoundReft) -> Expr;
18    fn replace_region(&mut self, br: BoundRegion) -> Region;
19}
20
21pub(crate) struct FnMutDelegate<F1, F2> {
22    pub exprs: F1,
23    pub regions: F2,
24}
25
26impl<F1, F2> FnMutDelegate<F1, F2>
27where
28    F1: FnMut(BoundReft) -> Expr,
29    F2: FnMut(BoundRegion) -> Region,
30{
31    pub(crate) fn new(exprs: F1, regions: F2) -> Self {
32        Self { exprs, regions }
33    }
34}
35
36impl<F1, F2> BoundVarReplacerDelegate for FnMutDelegate<F1, F2>
37where
38    F1: FnMut(BoundReft) -> Expr,
39    F2: FnMut(BoundRegion) -> Region,
40{
41    fn replace_expr(&mut self, var: BoundReft) -> Expr {
42        (self.exprs)(var)
43    }
44
45    fn replace_region(&mut self, br: BoundRegion) -> Region {
46        (self.regions)(br)
47    }
48}
49
50impl<D> BoundVarReplacer<D> {
51    pub(super) fn new(delegate: D) -> BoundVarReplacer<D> {
52        BoundVarReplacer { delegate, current_index: INNERMOST }
53    }
54}
55
56impl<D> TypeFolder for BoundVarReplacer<D>
57where
58    D: BoundVarReplacerDelegate,
59{
60    fn enter_binder(&mut self, _: &BoundVariableKinds) {
61        self.current_index.shift_in(1);
62    }
63
64    fn exit_binder(&mut self) {
65        self.current_index.shift_out(1);
66    }
67
68    fn fold_expr(&mut self, e: &Expr) -> Expr {
69        if let ExprKind::Var(Var::Bound(debruijn, breft)) = e.kind() {
70            match debruijn.cmp(&self.current_index) {
71                Ordering::Less => Expr::bvar(*debruijn, breft.var, breft.kind),
72                Ordering::Equal => {
73                    self.delegate
74                        .replace_expr(*breft)
75                        .shift_in_escaping(self.current_index.as_u32())
76                }
77                Ordering::Greater => Expr::bvar(debruijn.shifted_out(1), breft.var, breft.kind),
78            }
79        } else {
80            e.super_fold_with(self)
81        }
82    }
83
84    fn fold_region(&mut self, re: &Region) -> Region {
85        if let ReBound(debruijn, br) = *re {
86            match debruijn.cmp(&self.current_index) {
87                Ordering::Less => *re,
88                Ordering::Equal => {
89                    let region = self.delegate.replace_region(br);
90                    if let ReBound(debruijn1, br) = region {
91                        // If the callback returns a late-bound region,
92                        // that region should always use the INNERMOST
93                        // debruijn index. Then we adjust it to the
94                        // correct depth.
95                        tracked_span_assert_eq!(debruijn1, INNERMOST);
96                        Region::ReBound(debruijn, br)
97                    } else {
98                        region
99                    }
100                }
101                Ordering::Greater => ReBound(debruijn.shifted_out(1), br),
102            }
103        } else {
104            *re
105        }
106    }
107}
108
109/// Substitution for generics, i.e., early bound types, lifetimes, const generics, and refinements.
110/// Note that a substitution for refinement parameters (a list of expressions) must always be
111/// specified, while the behavior of other generics parameters (types, lifetimes and consts) can be
112/// configured with [`GenericsSubstDelegate`].
113pub struct GenericsSubstFolder<'a, D> {
114    current_index: DebruijnIndex,
115    delegate: D,
116    refinement_args: &'a [Expr],
117}
118
119pub trait GenericsSubstDelegate {
120    type Error = !;
121
122    fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, Self::Error>;
123    fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, Self::Error>;
124    fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, Self::Error>;
125    fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region;
126    fn expr_for_param_const(&self, param_const: ParamConst) -> Expr;
127    fn const_for_param(&mut self, param: &Const) -> Const;
128}
129
130/// A substitution with an explicit list of generic arguments.
131pub(crate) struct GenericArgsDelegate<'a, 'tcx>(
132    pub(crate) &'a [GenericArg],
133    pub(crate) TyCtxt<'tcx>,
134);
135
136impl GenericsSubstDelegate for GenericArgsDelegate<'_, '_> {
137    fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, !> {
138        match self.0.get(param_ty.index as usize) {
139            Some(GenericArg::Base(ctor)) => Ok(ctor.sort()),
140            Some(arg) => {
141                tracked_span_bug!("expected base type for generic parameter, found `{arg:?}`")
142            }
143            None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
144        }
145    }
146
147    fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, !> {
148        match self.0.get(param_ty.index as usize) {
149            Some(GenericArg::Ty(ty)) => Ok(ty.clone()),
150            Some(arg) => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
151            None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
152        }
153    }
154
155    fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, !> {
156        match self.0.get(param_ty.index as usize) {
157            Some(GenericArg::Base(ctor)) => Ok(ctor.clone()),
158            Some(arg) => {
159                tracked_span_bug!("expected base type for generic parameter, found `{arg:?}`")
160            }
161            None => tracked_span_bug!("type parameter out of range"),
162        }
163    }
164
165    fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region {
166        match self.0.get(ebr.index as usize) {
167            Some(GenericArg::Lifetime(re)) => *re,
168            Some(arg) => bug!("expected region for generic parameter, found `{arg:?}`"),
169            None => bug!("region parameter out of range"),
170        }
171    }
172
173    fn const_for_param(&mut self, param: &Const) -> Const {
174        if let ConstKind::Param(param_const) = &param.kind {
175            match self.0.get(param_const.index as usize) {
176                Some(GenericArg::Const(cst)) => cst.clone(),
177                Some(arg) => bug!("expected const for generic parameter, found `{arg:?}`"),
178                None => bug!("generic parameter out of range"),
179            }
180        } else {
181            param.clone()
182        }
183    }
184
185    fn expr_for_param_const(&self, param_const: ParamConst) -> Expr {
186        match self.0.get(param_const.index as usize) {
187            Some(GenericArg::Const(cst)) => Expr::from_const(self.1, cst),
188            Some(arg) => bug!("expected const for generic parameter, found `{arg:?}`"),
189            None => bug!("generic parameter out of range"),
190        }
191    }
192}
193
194/// A substitution meant to be used only for sorts. It'll panic if used on a type. This is used to
195/// break cycles during wf checking. During wf-checking we use [`rty::Sort`], but we can't yet
196/// generate (in general) an [`rty::GenericArg`] because conversion from [`fhir`] into [`rty`]
197/// requires the results of wf checking. Perhaps, we could also solve this problem by doing
198/// wf-checking with a different "IR" for sorts that sits in between [`fhir`] and [`rty`].
199///
200/// [`rty::Sort`]: crate::rty::Sort
201/// [`rty::GenericArg`]: crate::rty::GenericArg
202/// [`fhir`]: crate::fhir
203/// [`rty`]: crate::rty
204pub(crate) struct GenericsSubstForSort<F, E>
205where
206    F: FnMut(ParamTy) -> Result<Sort, E>,
207{
208    /// Implementation of [`GenericsSubstDelegate::sort_for_param`]
209    pub(crate) sort_for_param: F,
210}
211
212impl<F, E> GenericsSubstDelegate for GenericsSubstForSort<F, E>
213where
214    F: FnMut(ParamTy) -> Result<Sort, E>,
215{
216    type Error = E;
217
218    fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, E> {
219        (self.sort_for_param)(param_ty)
220    }
221
222    fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, E> {
223        bug!("unexpected type param {param_ty:?}");
224    }
225
226    fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, E> {
227        bug!("unexpected base type param {param_ty:?}");
228    }
229
230    fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region {
231        bug!("unexpected region param {ebr:?}");
232    }
233
234    fn const_for_param(&mut self, param: &Const) -> Const {
235        bug!("unexpected const param {param:?}");
236    }
237
238    fn expr_for_param_const(&self, param_const: ParamConst) -> Expr {
239        bug!("unexpected param_const {param_const:?}");
240    }
241}
242
243impl<'a, D> GenericsSubstFolder<'a, D> {
244    pub fn new(delegate: D, refine: &'a [Expr]) -> Self {
245        Self { current_index: INNERMOST, delegate, refinement_args: refine }
246    }
247}
248
249impl<D: GenericsSubstDelegate> FallibleTypeFolder for GenericsSubstFolder<'_, D> {
250    type Error = D::Error;
251
252    fn try_enter_binder(&mut self, _: &BoundVariableKinds) {
253        self.current_index.shift_in(1);
254    }
255
256    fn try_exit_binder(&mut self) {
257        self.current_index.shift_out(1);
258    }
259
260    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, D::Error> {
261        if let Sort::Param(param_ty) = sort {
262            self.delegate.sort_for_param(*param_ty)
263        } else {
264            sort.try_super_fold_with(self)
265        }
266    }
267
268    fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, D::Error> {
269        match ty.kind() {
270            TyKind::Param(param_ty) => self.delegate.ty_for_param(*param_ty),
271            TyKind::Indexed(BaseTy::Param(param_ty), idx) => {
272                let idx = idx.try_fold_with(self)?;
273                Ok(self
274                    .delegate
275                    .ctor_for_param(*param_ty)?
276                    .replace_bound_reft(&idx)
277                    .to_ty())
278            }
279            _ => ty.try_super_fold_with(self),
280        }
281    }
282
283    fn try_fold_subset_ty(&mut self, sty: &SubsetTy) -> Result<SubsetTy, D::Error> {
284        if let BaseTy::Param(param_ty) = &sty.bty {
285            let idx = sty.idx.try_fold_with(self)?;
286            let pred = sty.pred.try_fold_with(self)?;
287            Ok(self
288                .delegate
289                .ctor_for_param(*param_ty)?
290                .replace_bound_reft(&idx)
291                .strengthen(pred))
292        } else {
293            sty.try_super_fold_with(self)
294        }
295    }
296
297    fn try_fold_region(&mut self, re: &Region) -> Result<Region, D::Error> {
298        if let ReEarlyParam(ebr) = *re { Ok(self.delegate.region_for_param(ebr)) } else { Ok(*re) }
299    }
300
301    fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, D::Error> {
302        match expr.kind() {
303            ExprKind::Var(Var::EarlyParam(var)) => Ok(self.expr_for_param(var.index)),
304            ExprKind::Var(Var::ConstGeneric(param_const)) => {
305                Ok(self.delegate.expr_for_param_const(*param_const))
306            }
307            _ => expr.try_super_fold_with(self),
308        }
309    }
310
311    fn try_fold_const(&mut self, c: &Const) -> Result<Const, D::Error> {
312        Ok(self.delegate.const_for_param(c))
313    }
314}
315
316impl<D> GenericsSubstFolder<'_, D> {
317    fn expr_for_param(&self, idx: u32) -> Expr {
318        self.refinement_args[idx as usize].shift_in_escaping(self.current_index.as_u32())
319    }
320}
321
322pub(crate) struct SortSubst<D> {
323    delegate: D,
324}
325
326impl<D> SortSubst<D> {
327    pub(crate) fn new(delegate: D) -> Self {
328        Self { delegate }
329    }
330}
331
332impl<D: SortSubstDelegate> TypeFolder for SortSubst<D> {
333    fn fold_sort(&mut self, sort: &Sort) -> Sort {
334        match sort {
335            Sort::Var(var) => self.delegate.sort_for_param(*var),
336            Sort::BitVec(BvSize::Param(var)) => Sort::BitVec(self.delegate.bv_size_for_param(*var)),
337            _ => sort.super_fold_with(self),
338        }
339    }
340}
341
342trait SortSubstDelegate {
343    fn sort_for_param(&self, var: ParamSort) -> Sort;
344    fn bv_size_for_param(&self, var: ParamSort) -> BvSize;
345}
346
347impl SortSubstDelegate for &[SortArg] {
348    fn sort_for_param(&self, var: ParamSort) -> Sort {
349        match &self[var.index()] {
350            SortArg::Sort(sort) => sort.clone(),
351            SortArg::BvSize(_) => tracked_span_bug!("unexpected bv size for sort param"),
352        }
353    }
354
355    fn bv_size_for_param(&self, var: ParamSort) -> BvSize {
356        match self[var.index()] {
357            SortArg::BvSize(size) => size,
358            SortArg::Sort(_) => tracked_span_bug!("unexpected sort for bv size param"),
359        }
360    }
361}
362
363impl SortSubstDelegate for &[Sort] {
364    fn sort_for_param(&self, var: ParamSort) -> Sort {
365        self[var.index()].clone()
366    }
367
368    fn bv_size_for_param(&self, _var: ParamSort) -> BvSize {
369        tracked_span_bug!("unexpected bv size parameter")
370    }
371}