flux_middle/rty/
subst.rs

1use std::cmp::Ordering;
2
3use flux_common::{bug, tracked_span_bug};
4use rustc_type_ir::DebruijnIndex;
5
6use self::fold::FallibleTypeFolder;
7use super::fold::{TypeFolder, TypeSuperFoldable};
8use crate::rty::*;
9
10/// Substitution for late bound variables
11pub(super) struct BoundVarReplacer<D> {
12    current_index: DebruijnIndex,
13    delegate: D,
14}
15
16pub trait BoundVarReplacerDelegate {
17    fn replace_expr(&mut self, var: BoundReft) -> Expr;
18    fn replace_region(&mut self, br: BoundRegion) -> Region;
19}
20
/// A [`BoundVarReplacerDelegate`] assembled from two closures, one per kind of bound variable.
pub(crate) struct FnMutDelegate<F1, F2> {
    /// Invoked for every bound refinement variable that is being replaced.
    pub exprs: F1,
    /// Invoked for every bound region that is being replaced.
    pub regions: F2,
}
25
26impl<F1, F2> FnMutDelegate<F1, F2>
27where
28    F1: FnMut(BoundReft) -> Expr,
29    F2: FnMut(BoundRegion) -> Region,
30{
31    pub(crate) fn new(exprs: F1, regions: F2) -> Self {
32        Self { exprs, regions }
33    }
34}
35
36impl<F1, F2> BoundVarReplacerDelegate for FnMutDelegate<F1, F2>
37where
38    F1: FnMut(BoundReft) -> Expr,
39    F2: FnMut(BoundRegion) -> Region,
40{
41    fn replace_expr(&mut self, var: BoundReft) -> Expr {
42        (self.exprs)(var)
43    }
44
45    fn replace_region(&mut self, br: BoundRegion) -> Region {
46        (self.regions)(br)
47    }
48}
49
50impl<D> BoundVarReplacer<D> {
51    pub(super) fn new(delegate: D) -> BoundVarReplacer<D> {
52        BoundVarReplacer { delegate, current_index: INNERMOST }
53    }
54}
55
56impl<D> TypeFolder for BoundVarReplacer<D>
57where
58    D: BoundVarReplacerDelegate,
59{
60    fn fold_binder<T>(&mut self, t: &Binder<T>) -> Binder<T>
61    where
62        T: TypeFoldable,
63    {
64        self.current_index.shift_in(1);
65        let r = t.super_fold_with(self);
66        self.current_index.shift_out(1);
67        r
68    }
69
70    fn fold_expr(&mut self, e: &Expr) -> Expr {
71        if let ExprKind::Var(Var::Bound(debruijn, breft)) = e.kind() {
72            match debruijn.cmp(&self.current_index) {
73                Ordering::Less => Expr::bvar(*debruijn, breft.var, breft.kind),
74                Ordering::Equal => {
75                    self.delegate
76                        .replace_expr(*breft)
77                        .shift_in_escaping(self.current_index.as_u32())
78                }
79                Ordering::Greater => Expr::bvar(debruijn.shifted_out(1), breft.var, breft.kind),
80            }
81        } else {
82            e.super_fold_with(self)
83        }
84    }
85
86    fn fold_region(&mut self, re: &Region) -> Region {
87        if let ReBound(debruijn, br) = *re {
88            match debruijn.cmp(&self.current_index) {
89                Ordering::Less => *re,
90                Ordering::Equal => {
91                    let region = self.delegate.replace_region(br);
92                    if let ReBound(debruijn1, br) = region {
93                        // If the callback returns a late-bound region,
94                        // that region should always use the INNERMOST
95                        // debruijn index. Then we adjust it to the
96                        // correct depth.
97                        tracked_span_assert_eq!(debruijn1, INNERMOST);
98                        Region::ReBound(debruijn, br)
99                    } else {
100                        region
101                    }
102                }
103                Ordering::Greater => ReBound(debruijn.shifted_out(1), br),
104            }
105        } else {
106            *re
107        }
108    }
109}
110
111/// Substitution for generics, i.e., early bound types, lifetimes, const generics, and refinements.
112/// Note that a substitution for refinement parameters (a list of expressions) must always be
113/// specified, while the behavior of other generics parameters (types, lifetimes and consts) can be
114/// configured with [`GenericsSubstDelegate`].
115pub struct GenericsSubstFolder<'a, D> {
116    current_index: DebruijnIndex,
117    delegate: D,
118    refinement_args: &'a [Expr],
119}
120
121pub trait GenericsSubstDelegate {
122    type Error = !;
123
124    fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, Self::Error>;
125    fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, Self::Error>;
126    fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, Self::Error>;
127    fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region;
128    fn expr_for_param_const(&self, param_const: ParamConst) -> Expr;
129    fn const_for_param(&mut self, param: &Const) -> Const;
130}
131
132/// A substitution with an explicit list of generic arguments.
133pub(crate) struct GenericArgsDelegate<'a, 'tcx>(
134    pub(crate) &'a [GenericArg],
135    pub(crate) TyCtxt<'tcx>,
136);
137
138impl GenericsSubstDelegate for GenericArgsDelegate<'_, '_> {
139    fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, !> {
140        match self.0.get(param_ty.index as usize) {
141            Some(GenericArg::Base(ctor)) => Ok(ctor.sort()),
142            Some(arg) => {
143                tracked_span_bug!("expected base type for generic parameter, found `{arg:?}`")
144            }
145            None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
146        }
147    }
148
149    fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, !> {
150        match self.0.get(param_ty.index as usize) {
151            Some(GenericArg::Ty(ty)) => Ok(ty.clone()),
152            Some(arg) => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
153            None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
154        }
155    }
156
157    fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, !> {
158        match self.0.get(param_ty.index as usize) {
159            Some(GenericArg::Base(ctor)) => Ok(ctor.clone()),
160            Some(arg) => {
161                tracked_span_bug!("expected base type for generic parameter, found `{arg:?}`")
162            }
163            None => tracked_span_bug!("type parameter out of range"),
164        }
165    }
166
167    fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region {
168        match self.0.get(ebr.index as usize) {
169            Some(GenericArg::Lifetime(re)) => *re,
170            Some(arg) => bug!("expected region for generic parameter, found `{arg:?}`"),
171            None => bug!("region parameter out of range"),
172        }
173    }
174
175    fn const_for_param(&mut self, param: &Const) -> Const {
176        if let ConstKind::Param(param_const) = &param.kind {
177            match self.0.get(param_const.index as usize) {
178                Some(GenericArg::Const(cst)) => cst.clone(),
179                Some(arg) => bug!("expected const for generic parameter, found `{arg:?}`"),
180                None => bug!("generic parameter out of range"),
181            }
182        } else {
183            param.clone()
184        }
185    }
186
187    fn expr_for_param_const(&self, param_const: ParamConst) -> Expr {
188        match self.0.get(param_const.index as usize) {
189            Some(GenericArg::Const(cst)) => Expr::from_const(self.1, cst),
190            Some(arg) => bug!("expected const for generic parameter, found `{arg:?}`"),
191            None => bug!("generic parameter out of range"),
192        }
193    }
194}
195
196/// A substitution meant to be used only for sorts. It'll panic if used on a type. This is used to
197/// break cycles during wf checking. During wf-checking we use [`rty::Sort`], but we can't yet
198/// generate (in general) an [`rty::GenericArg`] because conversion from [`fhir`] into [`rty`]
199/// requires the results of wf checking. Perhaps, we could also solve this problem by doing
200/// wf-checking with a different "IR" for sorts that sits in between [`fhir`] and [`rty`].
201///
202/// [`rty::Sort`]: crate::rty::Sort
203/// [`rty::GenericArg`]: crate::rty::GenericArg
204/// [`fhir`]: crate::fhir
205/// [`rty`]: crate::rty
206pub(crate) struct GenericsSubstForSort<F, E>
207where
208    F: FnMut(ParamTy) -> Result<Sort, E>,
209{
210    /// Implementation of [`GenericsSubstDelegate::sort_for_param`]
211    pub(crate) sort_for_param: F,
212}
213
214impl<F, E> GenericsSubstDelegate for GenericsSubstForSort<F, E>
215where
216    F: FnMut(ParamTy) -> Result<Sort, E>,
217{
218    type Error = E;
219
220    fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, E> {
221        (self.sort_for_param)(param_ty)
222    }
223
224    fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, E> {
225        bug!("unexpected type param {param_ty:?}");
226    }
227
228    fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, E> {
229        bug!("unexpected base type param {param_ty:?}");
230    }
231
232    fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region {
233        bug!("unexpected region param {ebr:?}");
234    }
235
236    fn const_for_param(&mut self, param: &Const) -> Const {
237        bug!("unexpected const param {param:?}");
238    }
239
240    fn expr_for_param_const(&self, param_const: ParamConst) -> Expr {
241        bug!("unexpected param_const {param_const:?}");
242    }
243}
244
245impl<'a, D> GenericsSubstFolder<'a, D> {
246    pub fn new(delegate: D, refine: &'a [Expr]) -> Self {
247        Self { current_index: INNERMOST, delegate, refinement_args: refine }
248    }
249}
250
251impl<D: GenericsSubstDelegate> FallibleTypeFolder for GenericsSubstFolder<'_, D> {
252    type Error = D::Error;
253
254    fn try_fold_binder<T: TypeFoldable>(&mut self, t: &Binder<T>) -> Result<Binder<T>, D::Error> {
255        self.current_index.shift_in(1);
256        let r = t.try_super_fold_with(self)?;
257        self.current_index.shift_out(1);
258        Ok(r)
259    }
260
261    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, D::Error> {
262        if let Sort::Param(param_ty) = sort {
263            self.delegate.sort_for_param(*param_ty)
264        } else {
265            sort.try_super_fold_with(self)
266        }
267    }
268
269    fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, D::Error> {
270        match ty.kind() {
271            TyKind::Param(param_ty) => self.delegate.ty_for_param(*param_ty),
272            TyKind::Indexed(BaseTy::Param(param_ty), idx) => {
273                let idx = idx.try_fold_with(self)?;
274                Ok(self
275                    .delegate
276                    .ctor_for_param(*param_ty)?
277                    .replace_bound_reft(&idx)
278                    .to_ty())
279            }
280            _ => ty.try_super_fold_with(self),
281        }
282    }
283
284    fn try_fold_subset_ty(&mut self, sty: &SubsetTy) -> Result<SubsetTy, D::Error> {
285        if let BaseTy::Param(param_ty) = &sty.bty {
286            let idx = sty.idx.try_fold_with(self)?;
287            let pred = sty.pred.try_fold_with(self)?;
288            Ok(self
289                .delegate
290                .ctor_for_param(*param_ty)?
291                .replace_bound_reft(&idx)
292                .strengthen(pred))
293        } else {
294            sty.try_super_fold_with(self)
295        }
296    }
297
298    fn try_fold_region(&mut self, re: &Region) -> Result<Region, D::Error> {
299        if let ReEarlyParam(ebr) = *re { Ok(self.delegate.region_for_param(ebr)) } else { Ok(*re) }
300    }
301
302    fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, D::Error> {
303        match expr.kind() {
304            ExprKind::Var(Var::EarlyParam(var)) => Ok(self.expr_for_param(var.index)),
305            ExprKind::Var(Var::ConstGeneric(param_const)) => {
306                Ok(self.delegate.expr_for_param_const(*param_const))
307            }
308            _ => expr.try_super_fold_with(self),
309        }
310    }
311
312    fn try_fold_const(&mut self, c: &Const) -> Result<Const, D::Error> {
313        Ok(self.delegate.const_for_param(c))
314    }
315}
316
317impl<D> GenericsSubstFolder<'_, D> {
318    fn expr_for_param(&self, idx: u32) -> Expr {
319        self.refinement_args[idx as usize].shift_in_escaping(self.current_index.as_u32())
320    }
321}
322
/// Folder that instantiates sort variables and bit-vector size parameters inside sorts,
/// taking the replacements from its `delegate`.
pub(crate) struct SortSubst<D> {
    /// Source of the replacement sorts and bit-vector sizes.
    delegate: D,
}
326
327impl<D> SortSubst<D> {
328    pub(crate) fn new(delegate: D) -> Self {
329        Self { delegate }
330    }
331}
332
333impl<D: SortSubstDelegate> TypeFolder for SortSubst<D> {
334    fn fold_sort(&mut self, sort: &Sort) -> Sort {
335        match sort {
336            Sort::Var(var) => self.delegate.sort_for_param(*var),
337            Sort::BitVec(BvSize::Param(var)) => Sort::BitVec(self.delegate.bv_size_for_param(*var)),
338            _ => sort.super_fold_with(self),
339        }
340    }
341}
342
343trait SortSubstDelegate {
344    fn sort_for_param(&self, var: ParamSort) -> Sort;
345    fn bv_size_for_param(&self, var: ParamSort) -> BvSize;
346}
347
348impl SortSubstDelegate for &[SortArg] {
349    fn sort_for_param(&self, var: ParamSort) -> Sort {
350        match &self[var.index()] {
351            SortArg::Sort(sort) => sort.clone(),
352            SortArg::BvSize(_) => tracked_span_bug!("unexpected bv size for sort param"),
353        }
354    }
355
356    fn bv_size_for_param(&self, var: ParamSort) -> BvSize {
357        match self[var.index()] {
358            SortArg::BvSize(size) => size,
359            SortArg::Sort(_) => tracked_span_bug!("unexpected sort for bv size param"),
360        }
361    }
362}
363
364impl SortSubstDelegate for &[Sort] {
365    fn sort_for_param(&self, var: ParamSort) -> Sort {
366        self[var.index()].clone()
367    }
368
369    fn bv_size_for_param(&self, _var: ParamSort) -> BvSize {
370        tracked_span_bug!("unexpected bv size parameter")
371    }
372}