1use flux_arc_interner::List;
34use flux_macros::{TypeFoldable, TypeVisitable};
35use itertools::Itertools;
36use rustc_ast::Mutability;
37use rustc_type_ir::{BoundVar, INNERMOST};
38
39use super::{
40 BaseTy, Binder, BoundVariableKind, Expr, FnSig, GenericArg, GenericArgsExt, PolyFnSig,
41 SubsetTy, Ty, TyCtor, TyKind, TyOrBase,
42 fold::{TypeFoldable, TypeFolder, TypeSuperFoldable},
43};
44use crate::rty::{ExprKind, HoleKind};
45
/// A type folder that walks a type and reports the existentials and constraints it
/// meets to a delegate `D`.
///
/// The boolean switches select which type constructors the folder descends into; a
/// constructor whose switch is off is returned unchanged.
pub struct Hoister<D> {
    /// Receives every hoisted existential and constraint.
    delegate: D,

    // --- which existentials to open -------------------------------------------------
    /// Open `TyKind::Exists` types at all.
    existentials: bool,

    // --- which constructors to look inside ------------------------------------------
    /// Descend into the boxed type of `Box` ADTs.
    in_boxes: bool,
    /// Descend into `TyKind::Downcast`.
    in_downcast: bool,
    /// Descend into the referent of mutable references.
    in_mut_refs: bool,
    /// Descend into the referent of shared references.
    in_shr_refs: bool,
    /// Descend into `TyKind::StrgRef`.
    in_strg_refs: bool,
    /// Descend into the fields of tuples.
    in_tuples: bool,
    /// Descend into references to slices (of either mutability), independently of
    /// `in_shr_refs` / `in_mut_refs`.
    slices: bool,
}
63
64pub trait HoisterDelegate {
65 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty;
66 fn hoist_constr(&mut self, pred: Expr);
67}
68
69impl<D> Hoister<D> {
70 pub fn with_delegate(delegate: D) -> Self {
71 Hoister {
72 delegate,
73 in_tuples: false,
74 in_shr_refs: false,
75 in_mut_refs: false,
76 in_strg_refs: false,
77 in_boxes: false,
78 in_downcast: false,
79 existentials: true,
80 slices: false,
81 }
82 }
83
84 pub fn hoist_inside_shr_refs(mut self, shr_refs: bool) -> Self {
85 self.in_shr_refs = shr_refs;
86 self
87 }
88
89 pub fn hoist_inside_mut_refs(mut self, mut_refs: bool) -> Self {
90 self.in_mut_refs = mut_refs;
91 self
92 }
93
94 pub fn hoist_inside_strg_refs(mut self, strg_refs: bool) -> Self {
95 self.in_strg_refs = strg_refs;
96 self
97 }
98
99 pub fn hoist_inside_tuples(mut self, tuples: bool) -> Self {
100 self.in_tuples = tuples;
101 self
102 }
103
104 pub fn hoist_inside_boxes(mut self, boxes: bool) -> Self {
105 self.in_boxes = boxes;
106 self
107 }
108
109 pub fn hoist_inside_downcast(mut self, downcast: bool) -> Self {
110 self.in_downcast = downcast;
111 self
112 }
113
114 pub fn hoist_existentials(mut self, exists: bool) -> Self {
115 self.existentials = exists;
116 self
117 }
118
119 pub fn hoist_slices(mut self, slices: bool) -> Self {
120 self.slices = slices;
121 self
122 }
123
124 pub fn transparent(self) -> Self {
125 self.hoist_inside_boxes(true)
126 .hoist_inside_downcast(true)
127 .hoist_inside_mut_refs(false)
128 .hoist_inside_shr_refs(true)
129 .hoist_inside_strg_refs(true)
130 .hoist_inside_tuples(true)
131 .hoist_slices(true)
132 }
133
134 pub fn shallow(self) -> Self {
135 self.hoist_inside_boxes(false)
136 .hoist_inside_downcast(false)
137 .hoist_inside_mut_refs(false)
138 .hoist_inside_shr_refs(false)
139 .hoist_inside_strg_refs(false)
140 .hoist_inside_tuples(false)
141 }
142}
143
144impl<D: HoisterDelegate> Hoister<D> {
145 pub fn hoist(&mut self, ty: &Ty) -> Ty {
146 ty.fold_with(self)
147 }
148}
149
150fn is_indexed_slice(ty: &Ty) -> bool {
154 let Some(bty) = ty.as_bty_skipping_existentials() else {
155 return false;
156 };
157 match bty {
158 BaseTy::Slice(_) => true,
159 BaseTy::Ref(_, ty, _) => is_indexed_slice(ty),
160 _ => false,
161 }
162}
163
164impl<D: HoisterDelegate> TypeFolder for Hoister<D> {
165 fn fold_ty(&mut self, ty: &Ty) -> Ty {
166 match ty.kind() {
167 TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()),
168 TyKind::Exists(ty_ctor) if self.existentials => {
169 match &ty_ctor.vars()[..] {
175 [BoundVariableKind::Refine(sort, ..)] => {
176 if sort.is_unit() {
177 ty_ctor.replace_bound_reft(&Expr::unit())
178 } else if let Some(def_id) = sort.is_unit_adt() {
179 ty_ctor.replace_bound_reft(&Expr::unit_struct(def_id))
180 } else {
181 self.delegate.hoist_exists(ty_ctor)
182 }
183 }
184 _ => self.delegate.hoist_exists(ty_ctor),
185 }
186 .fold_with(self)
187 }
188 TyKind::Constr(pred, ty) => {
189 self.delegate.hoist_constr(pred.clone());
190 ty.fold_with(self)
191 }
192 TyKind::StrgRef(..) if self.in_strg_refs => ty.super_fold_with(self),
193 TyKind::Downcast(..) if self.in_downcast => ty.super_fold_with(self),
194 _ => ty.clone(),
195 }
196 }
197
198 fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy {
199 match bty {
200 BaseTy::Adt(adt_def, args) if adt_def.is_box() && self.in_boxes => {
201 let (boxed, alloc) = args.box_args();
202 let args = List::from_arr([
203 GenericArg::Ty(boxed.fold_with(self)),
204 GenericArg::Ty(alloc.clone()),
205 ]);
206 BaseTy::Adt(adt_def.clone(), args)
207 }
208 BaseTy::Ref(re, ty, mutability) if is_indexed_slice(ty) && self.slices => {
209 BaseTy::Ref(*re, ty.fold_with(self), *mutability)
210 }
211 BaseTy::Ref(re, ty, Mutability::Not) if self.in_shr_refs => {
212 BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not)
213 }
214 BaseTy::Ref(re, ty, Mutability::Mut) if self.in_mut_refs => {
215 BaseTy::Ref(*re, ty.fold_with(self), Mutability::Mut)
216 }
217 BaseTy::Tuple(tys) if self.in_tuples => BaseTy::Tuple(tys.fold_with(self)),
218 _ => bty.clone(),
219 }
220 }
221}
222
223#[derive(Default)]
224pub struct LocalHoister {
225 vars: Vec<BoundVariableKind>,
226 preds: Vec<Expr>,
227}
228
229impl LocalHoister {
230 pub fn new(vars: Vec<BoundVariableKind>) -> Self {
231 LocalHoister { vars, preds: vec![] }
232 }
233
234 pub fn bind<T>(self, f: impl FnOnce(List<BoundVariableKind>, Vec<Expr>) -> T) -> Binder<T> {
235 let vars = List::from_vec(self.vars);
236 Binder::bind_with_vars(f(vars.clone(), self.preds), vars)
237 }
238}
239
240impl HoisterDelegate for &mut LocalHoister {
241 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
242 ty_ctor.replace_bound_refts_with(|sort, mode, kind| {
243 let idx = self.vars.len();
244 self.vars
245 .push(BoundVariableKind::Refine(sort.clone(), mode, kind));
246 Expr::bvar(INNERMOST, BoundVar::from_usize(idx), kind)
247 })
248 }
249
250 fn hoist_constr(&mut self, pred: Expr) {
251 self.preds.push(pred);
252 }
253}
254
255impl PolyFnSig {
256 pub fn hoist_input_binders(&self) -> Self {
266 let original_vars = self.vars().to_vec();
267 let fn_sig = self.skip_binder_ref();
268 let mut delegate = LocalHoister { vars: original_vars, preds: fn_sig.requires().to_vec() };
269 let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
270
271 let inputs = fn_sig
272 .inputs()
273 .iter()
274 .map(|ty| hoister.hoist(ty))
275 .collect_vec();
276
277 delegate.bind(|_vars, mut preds| {
278 let mut keep_hole = true;
279 preds.retain(|pred| {
280 if let ExprKind::Hole(HoleKind::Pred) = pred.kind() {
281 std::mem::replace(&mut keep_hole, false)
282 } else {
283 true
284 }
285 });
286
287 FnSig::new(
288 fn_sig.safety,
289 fn_sig.abi,
290 preds.into(),
291 inputs.into(),
292 fn_sig.output().clone(),
293 )
294 })
295 }
296}
297
298impl Ty {
299 pub fn shallow_canonicalize(&self) -> CanonicalTy {
302 let mut delegate = LocalHoister::default();
303 let ty = self.shift_in_escaping(1);
304 let ty = Hoister::with_delegate(&mut delegate).hoist(&ty);
305 let constr_ty = delegate.bind(|_, preds| {
306 let pred = Expr::and_from_iter(preds);
307 CanonicalConstrTy { ty, pred }
308 });
309 if constr_ty.vars().is_empty() {
310 CanonicalTy::Constr(constr_ty.skip_binder().shift_out_escaping(1))
311 } else {
312 CanonicalTy::Exists(constr_ty)
313 }
314 }
315}
316
317#[derive(TypeVisitable, TypeFoldable)]
318pub struct CanonicalConstrTy {
319 ty: Ty,
324 pred: Expr,
325}
326
327impl CanonicalConstrTy {
328 pub fn ty(&self) -> Ty {
329 self.ty.clone()
330 }
331
332 pub fn pred(&self) -> Expr {
333 self.pred.clone()
334 }
335
336 pub fn to_ty(&self) -> Ty {
337 Ty::constr(self.pred(), self.ty())
338 }
339}
340
341pub enum CanonicalTy {
349 Constr(CanonicalConstrTy),
351 Exists(Binder<CanonicalConstrTy>),
353}
354
355impl CanonicalTy {
356 pub fn to_ty(&self) -> Ty {
357 match self {
358 CanonicalTy::Constr(constr_ty) => constr_ty.to_ty(),
359 CanonicalTy::Exists(poly_constr_ty) => {
360 Ty::exists(poly_constr_ty.as_ref().map(CanonicalConstrTy::to_ty))
361 }
362 }
363 }
364
365 pub fn as_ty_or_base(&self) -> TyOrBase {
366 match self {
367 CanonicalTy::Constr(constr_ty) => {
368 if let TyKind::Indexed(bty, idx) = constr_ty.ty.kind() {
369 let pred = if idx.is_unit() {
376 constr_ty.pred.clone()
377 } else {
378 Expr::and(&constr_ty.pred, Expr::eq(Expr::nu(), idx.shift_in_escaping(1)))
379 };
380 let sort = bty.sort();
381 let constr = SubsetTy::new(bty.shift_in_escaping(1), Expr::nu(), pred);
382 TyOrBase::Base(Binder::bind_with_sort(constr, sort))
383 } else {
384 TyOrBase::Ty(self.to_ty())
385 }
386 }
387 CanonicalTy::Exists(poly_constr_ty) => {
388 let constr = poly_constr_ty.as_ref().skip_binder();
389 if let TyKind::Indexed(bty, idx) = constr.ty.kind()
390 && idx.is_nu()
391 {
392 let ctor = poly_constr_ty
393 .as_ref()
394 .map(|constr| SubsetTy::new(bty.clone(), idx, &constr.pred));
395 TyOrBase::Base(ctor)
396 } else {
397 TyOrBase::Ty(self.to_ty())
398 }
399 }
400 }
401 }
402}
403
404mod pretty {
405 use super::*;
406 use crate::pretty::*;
407
408 impl Pretty for CanonicalConstrTy {
409 fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410 w!(cx, f, "{{ {:?} | {:?} }}", &self.ty, &self.pred)
411 }
412 }
413
414 impl Pretty for CanonicalTy {
415 fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 match self {
417 CanonicalTy::Constr(constr) => w!(cx, f, "{:?}", constr),
418 CanonicalTy::Exists(poly_constr) => {
419 cx.with_bound_vars(poly_constr.vars(), || {
420 cx.fmt_bound_vars(false, "∃", poly_constr.vars(), ". ", f)?;
421 w!(cx, f, "{:?}", poly_constr.as_ref().skip_binder())
422 })
423 }
424 }
425 }
426 }
427
428 impl_debug_with_default_cx!(CanonicalTy, CanonicalConstrTy);
429}