1use std::cmp::Ordering;
2
3use flux_common::{bug, tracked_span_bug};
4use rustc_type_ir::DebruijnIndex;
5
6use self::fold::FallibleTypeFolder;
7use super::fold::{TypeFolder, TypeSuperFoldable};
8use crate::rty::*;
9
10pub(super) struct BoundVarReplacer<D> {
12 current_index: DebruijnIndex,
13 delegate: D,
14}
15
16pub trait BoundVarReplacerDelegate {
17 fn replace_expr(&mut self, var: BoundReft) -> Expr;
18 fn replace_region(&mut self, br: BoundRegion) -> Region;
19}
20
/// A `BoundVarReplacerDelegate` assembled from two closures, one per kind of
/// bound variable.
pub(crate) struct FnMutDelegate<F1, F2> {
    /// Closure invoked to replace a bound refinement variable.
    pub exprs: F1,
    /// Closure invoked to replace a bound region.
    pub regions: F2,
}
25
26impl<F1, F2> FnMutDelegate<F1, F2>
27where
28 F1: FnMut(BoundReft) -> Expr,
29 F2: FnMut(BoundRegion) -> Region,
30{
31 pub(crate) fn new(exprs: F1, regions: F2) -> Self {
32 Self { exprs, regions }
33 }
34}
35
36impl<F1, F2> BoundVarReplacerDelegate for FnMutDelegate<F1, F2>
37where
38 F1: FnMut(BoundReft) -> Expr,
39 F2: FnMut(BoundRegion) -> Region,
40{
41 fn replace_expr(&mut self, var: BoundReft) -> Expr {
42 (self.exprs)(var)
43 }
44
45 fn replace_region(&mut self, br: BoundRegion) -> Region {
46 (self.regions)(br)
47 }
48}
49
50impl<D> BoundVarReplacer<D> {
51 pub(super) fn new(delegate: D) -> BoundVarReplacer<D> {
52 BoundVarReplacer { delegate, current_index: INNERMOST }
53 }
54}
55
56impl<D> TypeFolder for BoundVarReplacer<D>
57where
58 D: BoundVarReplacerDelegate,
59{
60 fn enter_binder(&mut self, _: &BoundVariableKinds) {
61 self.current_index.shift_in(1);
62 }
63
64 fn exit_binder(&mut self) {
65 self.current_index.shift_out(1);
66 }
67
68 fn fold_expr(&mut self, e: &Expr) -> Expr {
69 if let ExprKind::Var(Var::Bound(debruijn, breft)) = e.kind() {
70 match debruijn.cmp(&self.current_index) {
71 Ordering::Less => Expr::bvar(*debruijn, breft.var, breft.kind),
72 Ordering::Equal => {
73 self.delegate
74 .replace_expr(*breft)
75 .shift_in_escaping(self.current_index.as_u32())
76 }
77 Ordering::Greater => Expr::bvar(debruijn.shifted_out(1), breft.var, breft.kind),
78 }
79 } else {
80 e.super_fold_with(self)
81 }
82 }
83
84 fn fold_region(&mut self, re: &Region) -> Region {
85 if let ReBound(debruijn, br) = *re {
86 match debruijn.cmp(&self.current_index) {
87 Ordering::Less => *re,
88 Ordering::Equal => {
89 let region = self.delegate.replace_region(br);
90 if let ReBound(debruijn1, br) = region {
91 tracked_span_assert_eq!(debruijn1, INNERMOST);
96 Region::ReBound(debruijn, br)
97 } else {
98 region
99 }
100 }
101 Ordering::Greater => ReBound(debruijn.shifted_out(1), br),
102 }
103 } else {
104 *re
105 }
106 }
107}
108
109pub struct GenericsSubstFolder<'a, D> {
114 current_index: DebruijnIndex,
115 delegate: D,
116 refinement_args: &'a [Expr],
117}
118
119pub trait GenericsSubstDelegate {
120 type Error = !;
121
122 fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, Self::Error>;
123 fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, Self::Error>;
124 fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, Self::Error>;
125 fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region;
126 fn expr_for_param_const(&self, param_const: ParamConst) -> Expr;
127 fn const_for_param(&mut self, param: &Const) -> Const;
128}
129
130pub(crate) struct GenericArgsDelegate<'a, 'tcx>(
132 pub(crate) &'a [GenericArg],
133 pub(crate) TyCtxt<'tcx>,
134);
135
136impl GenericsSubstDelegate for GenericArgsDelegate<'_, '_> {
137 fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, !> {
138 match self.0.get(param_ty.index as usize) {
139 Some(GenericArg::Base(ctor)) => Ok(ctor.sort()),
140 Some(arg) => {
141 tracked_span_bug!("expected base type for generic parameter, found `{arg:?}`")
142 }
143 None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
144 }
145 }
146
147 fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, !> {
148 match self.0.get(param_ty.index as usize) {
149 Some(GenericArg::Ty(ty)) => Ok(ty.clone()),
150 Some(arg) => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
151 None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
152 }
153 }
154
155 fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, !> {
156 match self.0.get(param_ty.index as usize) {
157 Some(GenericArg::Base(ctor)) => Ok(ctor.clone()),
158 Some(arg) => {
159 tracked_span_bug!("expected base type for generic parameter, found `{arg:?}`")
160 }
161 None => tracked_span_bug!("type parameter out of range"),
162 }
163 }
164
165 fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region {
166 match self.0.get(ebr.index as usize) {
167 Some(GenericArg::Lifetime(re)) => *re,
168 Some(arg) => bug!("expected region for generic parameter, found `{arg:?}`"),
169 None => bug!("region parameter out of range"),
170 }
171 }
172
173 fn const_for_param(&mut self, param: &Const) -> Const {
174 if let ConstKind::Param(param_const) = ¶m.kind {
175 match self.0.get(param_const.index as usize) {
176 Some(GenericArg::Const(cst)) => cst.clone(),
177 Some(arg) => bug!("expected const for generic parameter, found `{arg:?}`"),
178 None => bug!("generic parameter out of range"),
179 }
180 } else {
181 param.clone()
182 }
183 }
184
185 fn expr_for_param_const(&self, param_const: ParamConst) -> Expr {
186 match self.0.get(param_const.index as usize) {
187 Some(GenericArg::Const(cst)) => Expr::from_const(self.1, cst),
188 Some(arg) => bug!("expected const for generic parameter, found `{arg:?}`"),
189 None => bug!("generic parameter out of range"),
190 }
191 }
192}
193
194pub(crate) struct GenericsSubstForSort<F, E>
205where
206 F: FnMut(ParamTy) -> Result<Sort, E>,
207{
208 pub(crate) sort_for_param: F,
210}
211
212impl<F, E> GenericsSubstDelegate for GenericsSubstForSort<F, E>
213where
214 F: FnMut(ParamTy) -> Result<Sort, E>,
215{
216 type Error = E;
217
218 fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, E> {
219 (self.sort_for_param)(param_ty)
220 }
221
222 fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, E> {
223 bug!("unexpected type param {param_ty:?}");
224 }
225
226 fn ctor_for_param(&mut self, param_ty: ParamTy) -> Result<SubsetTyCtor, E> {
227 bug!("unexpected base type param {param_ty:?}");
228 }
229
230 fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region {
231 bug!("unexpected region param {ebr:?}");
232 }
233
234 fn const_for_param(&mut self, param: &Const) -> Const {
235 bug!("unexpected const param {param:?}");
236 }
237
238 fn expr_for_param_const(&self, param_const: ParamConst) -> Expr {
239 bug!("unexpected param_const {param_const:?}");
240 }
241}
242
243impl<'a, D> GenericsSubstFolder<'a, D> {
244 pub fn new(delegate: D, refine: &'a [Expr]) -> Self {
245 Self { current_index: INNERMOST, delegate, refinement_args: refine }
246 }
247}
248
249impl<D: GenericsSubstDelegate> FallibleTypeFolder for GenericsSubstFolder<'_, D> {
250 type Error = D::Error;
251
252 fn try_enter_binder(&mut self, _: &BoundVariableKinds) {
253 self.current_index.shift_in(1);
254 }
255
256 fn try_exit_binder(&mut self) {
257 self.current_index.shift_out(1);
258 }
259
260 fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, D::Error> {
261 if let Sort::Param(param_ty) = sort {
262 self.delegate.sort_for_param(*param_ty)
263 } else {
264 sort.try_super_fold_with(self)
265 }
266 }
267
268 fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, D::Error> {
269 match ty.kind() {
270 TyKind::Param(param_ty) => self.delegate.ty_for_param(*param_ty),
271 TyKind::Indexed(BaseTy::Param(param_ty), idx) => {
272 let idx = idx.try_fold_with(self)?;
273 Ok(self
274 .delegate
275 .ctor_for_param(*param_ty)?
276 .replace_bound_reft(&idx)
277 .to_ty())
278 }
279 _ => ty.try_super_fold_with(self),
280 }
281 }
282
283 fn try_fold_subset_ty(&mut self, sty: &SubsetTy) -> Result<SubsetTy, D::Error> {
284 if let BaseTy::Param(param_ty) = &sty.bty {
285 let idx = sty.idx.try_fold_with(self)?;
286 let pred = sty.pred.try_fold_with(self)?;
287 Ok(self
288 .delegate
289 .ctor_for_param(*param_ty)?
290 .replace_bound_reft(&idx)
291 .strengthen(pred))
292 } else {
293 sty.try_super_fold_with(self)
294 }
295 }
296
297 fn try_fold_region(&mut self, re: &Region) -> Result<Region, D::Error> {
298 if let ReEarlyParam(ebr) = *re { Ok(self.delegate.region_for_param(ebr)) } else { Ok(*re) }
299 }
300
301 fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, D::Error> {
302 match expr.kind() {
303 ExprKind::Var(Var::EarlyParam(var)) => Ok(self.expr_for_param(var.index)),
304 ExprKind::Var(Var::ConstGeneric(param_const)) => {
305 Ok(self.delegate.expr_for_param_const(*param_const))
306 }
307 _ => expr.try_super_fold_with(self),
308 }
309 }
310
311 fn try_fold_const(&mut self, c: &Const) -> Result<Const, D::Error> {
312 Ok(self.delegate.const_for_param(c))
313 }
314}
315
316impl<D> GenericsSubstFolder<'_, D> {
317 fn expr_for_param(&self, idx: u32) -> Expr {
318 self.refinement_args[idx as usize].shift_in_escaping(self.current_index.as_u32())
319 }
320}
321
/// Folder that substitutes sort variables (`Sort::Var`) and bit-vector size
/// parameters (`BvSize::Param`) using a delegate.
pub(crate) struct SortSubst<D> {
    /// Source of the replacement for each parameter.
    delegate: D,
}
325
326impl<D> SortSubst<D> {
327 pub(crate) fn new(delegate: D) -> Self {
328 Self { delegate }
329 }
330}
331
332impl<D: SortSubstDelegate> TypeFolder for SortSubst<D> {
333 fn fold_sort(&mut self, sort: &Sort) -> Sort {
334 match sort {
335 Sort::Var(var) => self.delegate.sort_for_param(*var),
336 Sort::BitVec(BvSize::Param(var)) => Sort::BitVec(self.delegate.bv_size_for_param(*var)),
337 _ => sort.super_fold_with(self),
338 }
339 }
340}
341
342trait SortSubstDelegate {
343 fn sort_for_param(&self, var: ParamSort) -> Sort;
344 fn bv_size_for_param(&self, var: ParamSort) -> BvSize;
345}
346
347impl SortSubstDelegate for &[SortArg] {
348 fn sort_for_param(&self, var: ParamSort) -> Sort {
349 match &self[var.index()] {
350 SortArg::Sort(sort) => sort.clone(),
351 SortArg::BvSize(_) => tracked_span_bug!("unexpected bv size for sort param"),
352 }
353 }
354
355 fn bv_size_for_param(&self, var: ParamSort) -> BvSize {
356 match self[var.index()] {
357 SortArg::BvSize(size) => size,
358 SortArg::Sort(_) => tracked_span_bug!("unexpected sort for bv size param"),
359 }
360 }
361}
362
363impl SortSubstDelegate for &[Sort] {
364 fn sort_for_param(&self, var: ParamSort) -> Sort {
365 self[var.index()].clone()
366 }
367
368 fn bv_size_for_param(&self, _var: ParamSort) -> BvSize {
369 tracked_span_bug!("unexpected bv size parameter")
370 }
371}