1use std::cmp::Ordering;
2
3use flux_common::{bug, tracked_span_bug};
4use rustc_type_ir::DebruijnIndex;
5
6use self::fold::FallibleTypeFolder;
7use super::fold::{TypeFolder, TypeSuperFoldable};
8use crate::rty::*;
9
10pub(super) struct BoundVarReplacer<D> {
12 current_index: DebruijnIndex,
13 delegate: D,
14}
15
16pub trait BoundVarReplacerDelegate {
17 fn replace_expr(&mut self, var: BoundReft) -> Expr;
18 fn replace_region(&mut self, br: BoundRegion) -> Region;
19}
20
/// Bound-variable replacement delegate assembled from two closures: `exprs` produces replacements
/// for bound refinement variables and `regions` produces replacements for bound regions.
pub(crate) struct FnMutDelegate<F1, F2> {
    /// Invoked for every bound refinement variable being instantiated.
    pub exprs: F1,
    /// Invoked for every bound region being instantiated.
    pub regions: F2,
}
25
26impl<F1, F2> FnMutDelegate<F1, F2>
27where
28 F1: FnMut(BoundReft) -> Expr,
29 F2: FnMut(BoundRegion) -> Region,
30{
31 pub(crate) fn new(exprs: F1, regions: F2) -> Self {
32 Self { exprs, regions }
33 }
34}
35
36impl<F1, F2> BoundVarReplacerDelegate for FnMutDelegate<F1, F2>
37where
38 F1: FnMut(BoundReft) -> Expr,
39 F2: FnMut(BoundRegion) -> Region,
40{
41 fn replace_expr(&mut self, var: BoundReft) -> Expr {
42 (self.exprs)(var)
43 }
44
45 fn replace_region(&mut self, br: BoundRegion) -> Region {
46 (self.regions)(br)
47 }
48}
49
50impl<D> BoundVarReplacer<D> {
51 pub(super) fn new(delegate: D) -> BoundVarReplacer<D> {
52 BoundVarReplacer { delegate, current_index: INNERMOST }
53 }
54}
55
56impl<D> TypeFolder for BoundVarReplacer<D>
57where
58 D: BoundVarReplacerDelegate,
59{
60 fn fold_binder<T>(&mut self, t: &Binder<T>) -> Binder<T>
61 where
62 T: TypeFoldable,
63 {
64 self.current_index.shift_in(1);
65 let r = t.super_fold_with(self);
66 self.current_index.shift_out(1);
67 r
68 }
69
70 fn fold_expr(&mut self, e: &Expr) -> Expr {
71 if let ExprKind::Var(Var::Bound(debruijn, breft)) = e.kind() {
72 match debruijn.cmp(&self.current_index) {
73 Ordering::Less => Expr::bvar(*debruijn, breft.var, breft.kind),
74 Ordering::Equal => {
75 self.delegate
76 .replace_expr(*breft)
77 .shift_in_escaping(self.current_index.as_u32())
78 }
79 Ordering::Greater => Expr::bvar(debruijn.shifted_out(1), breft.var, breft.kind),
80 }
81 } else {
82 e.super_fold_with(self)
83 }
84 }
85
86 fn fold_region(&mut self, re: &Region) -> Region {
87 if let ReBound(debruijn, br) = *re {
88 match debruijn.cmp(&self.current_index) {
89 Ordering::Less => *re,
90 Ordering::Equal => {
91 let region = self.delegate.replace_region(br);
92 if let ReBound(debruijn1, br) = region {
93 tracked_span_assert_eq!(debruijn1, INNERMOST);
98 Region::ReBound(debruijn, br)
99 } else {
100 region
101 }
102 }
103 Ordering::Greater => ReBound(debruijn.shifted_out(1), br),
104 }
105 } else {
106 *re
107 }
108 }
109}
110
111pub struct GenericsSubstFolder<'a, D> {
116 current_index: DebruijnIndex,
117 delegate: D,
118 refinement_args: &'a [Expr],
119}
120
121pub trait GenericsSubstDelegate {
122 type Error = !;
123
124 fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, Self::Error>;
125 fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, Self::Error>;
126 fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, Self::Error>;
127 fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region;
128 fn expr_for_param_const(&self, param_const: ParamConst) -> Expr;
129 fn const_for_param(&mut self, param: &Const) -> Const;
130}
131
132pub(crate) struct GenericArgsDelegate<'a, 'tcx>(
134 pub(crate) &'a [GenericArg],
135 pub(crate) TyCtxt<'tcx>,
136);
137
138impl GenericsSubstDelegate for GenericArgsDelegate<'_, '_> {
139 fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, !> {
140 match self.0.get(param_ty.index as usize) {
141 Some(GenericArg::Base(ctor)) => Ok(ctor.sort()),
142 Some(arg) => {
143 tracked_span_bug!("expected base type for generic parameter, found `{arg:?}`")
144 }
145 None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
146 }
147 }
148
149 fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, !> {
150 match self.0.get(param_ty.index as usize) {
151 Some(GenericArg::Ty(ty)) => Ok(ty.clone()),
152 Some(arg) => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
153 None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
154 }
155 }
156
157 fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, !> {
158 match self.0.get(param_ty.index as usize) {
159 Some(GenericArg::Base(ctor)) => Ok(ctor.clone()),
160 Some(arg) => {
161 tracked_span_bug!("expected base type for generic parameter, found `{arg:?}`")
162 }
163 None => tracked_span_bug!("type parameter out of range"),
164 }
165 }
166
167 fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region {
168 match self.0.get(ebr.index as usize) {
169 Some(GenericArg::Lifetime(re)) => *re,
170 Some(arg) => bug!("expected region for generic parameter, found `{arg:?}`"),
171 None => bug!("region parameter out of range"),
172 }
173 }
174
175 fn const_for_param(&mut self, param: &Const) -> Const {
176 if let ConstKind::Param(param_const) = ¶m.kind {
177 match self.0.get(param_const.index as usize) {
178 Some(GenericArg::Const(cst)) => cst.clone(),
179 Some(arg) => bug!("expected const for generic parameter, found `{arg:?}`"),
180 None => bug!("generic parameter out of range"),
181 }
182 } else {
183 param.clone()
184 }
185 }
186
187 fn expr_for_param_const(&self, param_const: ParamConst) -> Expr {
188 match self.0.get(param_const.index as usize) {
189 Some(GenericArg::Const(cst)) => Expr::from_const(self.1, cst),
190 Some(arg) => bug!("expected const for generic parameter, found `{arg:?}`"),
191 None => bug!("generic parameter out of range"),
192 }
193 }
194}
195
196pub(crate) struct GenericsSubstForSort<F, E>
207where
208 F: FnMut(ParamTy) -> Result<Sort, E>,
209{
210 pub(crate) sort_for_param: F,
212}
213
214impl<F, E> GenericsSubstDelegate for GenericsSubstForSort<F, E>
215where
216 F: FnMut(ParamTy) -> Result<Sort, E>,
217{
218 type Error = E;
219
220 fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, E> {
221 (self.sort_for_param)(param_ty)
222 }
223
224 fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, E> {
225 bug!("unexpected type param {param_ty:?}");
226 }
227
228 fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, E> {
229 bug!("unexpected base type param {param_ty:?}");
230 }
231
232 fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region {
233 bug!("unexpected region param {ebr:?}");
234 }
235
236 fn const_for_param(&mut self, param: &Const) -> Const {
237 bug!("unexpected const param {param:?}");
238 }
239
240 fn expr_for_param_const(&self, param_const: ParamConst) -> Expr {
241 bug!("unexpected param_const {param_const:?}");
242 }
243}
244
245impl<'a, D> GenericsSubstFolder<'a, D> {
246 pub fn new(delegate: D, refine: &'a [Expr]) -> Self {
247 Self { current_index: INNERMOST, delegate, refinement_args: refine }
248 }
249}
250
251impl<D: GenericsSubstDelegate> FallibleTypeFolder for GenericsSubstFolder<'_, D> {
252 type Error = D::Error;
253
254 fn try_fold_binder<T: TypeFoldable>(&mut self, t: &Binder<T>) -> Result<Binder<T>, D::Error> {
255 self.current_index.shift_in(1);
256 let r = t.try_super_fold_with(self)?;
257 self.current_index.shift_out(1);
258 Ok(r)
259 }
260
261 fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, D::Error> {
262 if let Sort::Param(param_ty) = sort {
263 self.delegate.sort_for_param(*param_ty)
264 } else {
265 sort.try_super_fold_with(self)
266 }
267 }
268
269 fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, D::Error> {
270 match ty.kind() {
271 TyKind::Param(param_ty) => self.delegate.ty_for_param(*param_ty),
272 TyKind::Indexed(BaseTy::Param(param_ty), idx) => {
273 let idx = idx.try_fold_with(self)?;
274 Ok(self
275 .delegate
276 .ctor_for_param(*param_ty)?
277 .replace_bound_reft(&idx)
278 .to_ty())
279 }
280 _ => ty.try_super_fold_with(self),
281 }
282 }
283
284 fn try_fold_subset_ty(&mut self, sty: &SubsetTy) -> Result<SubsetTy, D::Error> {
285 if let BaseTy::Param(param_ty) = &sty.bty {
286 let idx = sty.idx.try_fold_with(self)?;
287 let pred = sty.pred.try_fold_with(self)?;
288 Ok(self
289 .delegate
290 .ctor_for_param(*param_ty)?
291 .replace_bound_reft(&idx)
292 .strengthen(pred))
293 } else {
294 sty.try_super_fold_with(self)
295 }
296 }
297
298 fn try_fold_region(&mut self, re: &Region) -> Result<Region, D::Error> {
299 if let ReEarlyParam(ebr) = *re { Ok(self.delegate.region_for_param(ebr)) } else { Ok(*re) }
300 }
301
302 fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, D::Error> {
303 match expr.kind() {
304 ExprKind::Var(Var::EarlyParam(var)) => Ok(self.expr_for_param(var.index)),
305 ExprKind::Var(Var::ConstGeneric(param_const)) => {
306 Ok(self.delegate.expr_for_param_const(*param_const))
307 }
308 _ => expr.try_super_fold_with(self),
309 }
310 }
311
312 fn try_fold_const(&mut self, c: &Const) -> Result<Const, D::Error> {
313 Ok(self.delegate.const_for_param(c))
314 }
315}
316
317impl<D> GenericsSubstFolder<'_, D> {
318 fn expr_for_param(&self, idx: u32) -> Expr {
319 self.refinement_args[idx as usize].shift_in_escaping(self.current_index.as_u32())
320 }
321}
322
/// Type folder that replaces sort parameters (and bit-vector size parameters) with the values
/// supplied by its delegate.
pub(crate) struct SortSubst<D> {
    /// Supplies the replacement for each parameter.
    delegate: D,
}
326
327impl<D> SortSubst<D> {
328 pub(crate) fn new(delegate: D) -> Self {
329 Self { delegate }
330 }
331}
332
333impl<D: SortSubstDelegate> TypeFolder for SortSubst<D> {
334 fn fold_sort(&mut self, sort: &Sort) -> Sort {
335 match sort {
336 Sort::Var(var) => self.delegate.sort_for_param(*var),
337 Sort::BitVec(BvSize::Param(var)) => Sort::BitVec(self.delegate.bv_size_for_param(*var)),
338 _ => sort.super_fold_with(self),
339 }
340 }
341}
342
343trait SortSubstDelegate {
344 fn sort_for_param(&self, var: ParamSort) -> Sort;
345 fn bv_size_for_param(&self, var: ParamSort) -> BvSize;
346}
347
348impl SortSubstDelegate for &[SortArg] {
349 fn sort_for_param(&self, var: ParamSort) -> Sort {
350 match &self[var.index()] {
351 SortArg::Sort(sort) => sort.clone(),
352 SortArg::BvSize(_) => tracked_span_bug!("unexpected bv size for sort param"),
353 }
354 }
355
356 fn bv_size_for_param(&self, var: ParamSort) -> BvSize {
357 match self[var.index()] {
358 SortArg::BvSize(size) => size,
359 SortArg::Sort(_) => tracked_span_bug!("unexpected sort for bv size param"),
360 }
361 }
362}
363
364impl SortSubstDelegate for &[Sort] {
365 fn sort_for_param(&self, var: ParamSort) -> Sort {
366 self[var.index()].clone()
367 }
368
369 fn bv_size_for_param(&self, _var: ParamSort) -> BvSize {
370 tracked_span_bug!("unexpected bv size parameter")
371 }
372}