1use std::fmt::Write;
34
35use flux_arc_interner::List;
36use flux_macros::{TypeFoldable, TypeVisitable};
37use itertools::Itertools;
38use rustc_ast::Mutability;
39use rustc_type_ir::{BoundVar, INNERMOST};
40
41use super::{
42 BaseTy, Binder, BoundVariableKind, Expr, FnSig, GenericArg, GenericArgsExt, PolyFnSig,
43 SubsetTy, Ty, TyCtor, TyKind, TyOrBase,
44 fold::{TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable},
45};
46use crate::rty::{ExprKind, HoleKind};
47
/// A [`TypeFolder`] that pulls existential binders and refinement constraints out of a type
/// and reports them to a [`HoisterDelegate`].
///
/// The boolean flags select which type constructors the hoister may look through while
/// searching for existentials and constraints to hoist. Positions that are not enabled are
/// left untouched.
pub struct Hoister<D> {
    /// Receives every hoisted existential and constraint.
    delegate: D,
    /// Look through the boxed type of a `Box` ADT (the allocator argument is never folded).
    in_boxes: bool,
    /// Look through `TyKind::Downcast`.
    in_downcast: bool,
    /// Look through mutable references (`Mutability::Mut`).
    in_mut_refs: bool,
    /// Look through shared references (`Mutability::Not`).
    in_shr_refs: bool,
    /// Look through `TyKind::StrgRef`.
    in_strg_refs: bool,
    /// Look through the element types of tuples.
    in_tuples: bool,
    /// Hoist `TyKind::Exists` at all; when `false`, existentials are left in place.
    existentials: bool,
    /// Look through references of either mutability whose target `is_indexed_slice`.
    slices: bool,
}
65
/// Callbacks invoked by [`Hoister`] for every existential and constraint it hoists.
pub trait HoisterDelegate {
    /// Called for an existential `TyKind::Exists(ty_ctor)` that the hoister does not instantiate
    /// directly with a unit value. The returned type replaces the existential, and the hoister
    /// folds it again, so structure nested inside is hoisted too.
    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty;
    /// Called with the predicate of a `TyKind::Constr(pred, ty)`. The hoister drops the
    /// constraint from the type and continues with `ty`.
    fn hoist_constr(&mut self, pred: Expr);
}
70
71impl<D> Hoister<D> {
72 pub fn with_delegate(delegate: D) -> Self {
73 Hoister {
74 delegate,
75 in_tuples: false,
76 in_shr_refs: false,
77 in_mut_refs: false,
78 in_strg_refs: false,
79 in_boxes: false,
80 in_downcast: false,
81 existentials: true,
82 slices: false,
83 }
84 }
85
86 pub fn hoist_inside_shr_refs(mut self, shr_refs: bool) -> Self {
87 self.in_shr_refs = shr_refs;
88 self
89 }
90
91 pub fn hoist_inside_mut_refs(mut self, mut_refs: bool) -> Self {
92 self.in_mut_refs = mut_refs;
93 self
94 }
95
96 pub fn hoist_inside_strg_refs(mut self, strg_refs: bool) -> Self {
97 self.in_strg_refs = strg_refs;
98 self
99 }
100
101 pub fn hoist_inside_tuples(mut self, tuples: bool) -> Self {
102 self.in_tuples = tuples;
103 self
104 }
105
106 pub fn hoist_inside_boxes(mut self, boxes: bool) -> Self {
107 self.in_boxes = boxes;
108 self
109 }
110
111 pub fn hoist_inside_downcast(mut self, downcast: bool) -> Self {
112 self.in_downcast = downcast;
113 self
114 }
115
116 pub fn hoist_existentials(mut self, exists: bool) -> Self {
117 self.existentials = exists;
118 self
119 }
120
121 pub fn hoist_slices(mut self, slices: bool) -> Self {
122 self.slices = slices;
123 self
124 }
125
126 pub fn transparent(self) -> Self {
127 self.hoist_inside_boxes(true)
128 .hoist_inside_downcast(true)
129 .hoist_inside_mut_refs(false)
130 .hoist_inside_shr_refs(true)
131 .hoist_inside_strg_refs(true)
132 .hoist_inside_tuples(true)
133 .hoist_slices(true)
134 }
135
136 pub fn shallow(self) -> Self {
137 self.hoist_inside_boxes(false)
138 .hoist_inside_downcast(false)
139 .hoist_inside_mut_refs(false)
140 .hoist_inside_shr_refs(false)
141 .hoist_inside_strg_refs(false)
142 .hoist_inside_tuples(false)
143 }
144}
145
146impl<D: HoisterDelegate> Hoister<D> {
147 pub fn hoist(&mut self, ty: &Ty) -> Ty {
148 ty.fold_with(self)
149 }
150}
151
152fn is_indexed_slice(ty: &Ty) -> bool {
156 let Some(bty) = ty.as_bty_skipping_existentials() else {
157 return false;
158 };
159 match bty {
160 BaseTy::Slice(_) => true,
161 BaseTy::Ref(_, ty, _) => is_indexed_slice(ty),
162 _ => false,
163 }
164}
165
166impl<D: HoisterDelegate> TypeFolder for Hoister<D> {
167 fn fold_ty(&mut self, ty: &Ty) -> Ty {
168 match ty.kind() {
169 TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()),
170 TyKind::Exists(ty_ctor) if self.existentials => {
171 match &ty_ctor.vars()[..] {
177 [BoundVariableKind::Refine(sort, ..)] => {
178 if sort.is_unit() {
179 ty_ctor.replace_bound_reft(&Expr::unit())
180 } else if let Some(def_id) = sort.is_unit_adt() {
181 ty_ctor.replace_bound_reft(&Expr::unit_struct(def_id))
182 } else {
183 self.delegate.hoist_exists(ty_ctor)
184 }
185 }
186 _ => self.delegate.hoist_exists(ty_ctor),
187 }
188 .fold_with(self)
189 }
190 TyKind::Constr(pred, ty) => {
191 self.delegate.hoist_constr(pred.clone());
192 ty.fold_with(self)
193 }
194 TyKind::StrgRef(..) if self.in_strg_refs => ty.super_fold_with(self),
195 TyKind::Downcast(..) if self.in_downcast => ty.super_fold_with(self),
196 _ => ty.clone(),
197 }
198 }
199
200 fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy {
201 match bty {
202 BaseTy::Adt(adt_def, args) if adt_def.is_box() && self.in_boxes => {
203 let (boxed, alloc) = args.box_args();
204 let args = List::from_arr([
205 GenericArg::Ty(boxed.fold_with(self)),
206 GenericArg::Ty(alloc.clone()),
207 ]);
208 BaseTy::Adt(adt_def.clone(), args)
209 }
210 BaseTy::Ref(re, ty, mutability) if is_indexed_slice(ty) && self.slices => {
211 BaseTy::Ref(*re, ty.fold_with(self), *mutability)
212 }
213 BaseTy::Ref(re, ty, Mutability::Not) if self.in_shr_refs => {
214 BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not)
215 }
216 BaseTy::Ref(re, ty, Mutability::Mut) if self.in_mut_refs => {
217 BaseTy::Ref(*re, ty.fold_with(self), Mutability::Mut)
218 }
219 BaseTy::Tuple(tys) if self.in_tuples => BaseTy::Tuple(tys.fold_with(self)),
220 _ => bty.clone(),
221 }
222 }
223}
224
/// A [`HoisterDelegate`] that records hoisted variables and predicates so they can be
/// re-bound afterwards with `LocalHoister::bind`.
#[derive(Default)]
pub struct LocalHoister {
    /// Variables of the binder being built; each hoisted existential variable is appended.
    vars: Vec<BoundVariableKind>,
    /// Hoisted constraint predicates, in the order they were reported.
    preds: Vec<Expr>,
}
230
231impl LocalHoister {
232 pub fn new(vars: Vec<BoundVariableKind>) -> Self {
233 LocalHoister { vars, preds: vec![] }
234 }
235
236 pub fn bind<T>(self, f: impl FnOnce(List<BoundVariableKind>, Vec<Expr>) -> T) -> Binder<T> {
237 let vars = List::from_vec(self.vars);
238 Binder::bind_with_vars(f(vars.clone(), self.preds), vars)
239 }
240}
241
242impl HoisterDelegate for &mut LocalHoister {
243 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
244 ty_ctor.replace_bound_refts_with(|sort, mode, kind| {
245 let idx = self.vars.len();
246 self.vars
247 .push(BoundVariableKind::Refine(sort.clone(), mode, kind));
248 Expr::bvar(INNERMOST, BoundVar::from_usize(idx), kind)
249 })
250 }
251
252 fn hoist_constr(&mut self, pred: Expr) {
253 self.preds.push(pred);
254 }
255}
256
257impl PolyFnSig {
258 pub fn hoist_input_binders(&self) -> Self {
268 let original_vars = self.vars().to_vec();
269 let fn_sig = self.skip_binder_ref();
270 let mut delegate = LocalHoister { vars: original_vars, preds: fn_sig.requires().to_vec() };
271 let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
272
273 let inputs = fn_sig
274 .inputs()
275 .iter()
276 .map(|ty| hoister.hoist(ty))
277 .collect_vec();
278
279 delegate.bind(|_vars, mut preds| {
280 let mut keep_hole = true;
281 preds.retain(|pred| {
282 if let ExprKind::Hole(HoleKind::Pred) = pred.kind() {
283 std::mem::replace(&mut keep_hole, false)
284 } else {
285 true
286 }
287 });
288
289 FnSig::new(
290 fn_sig.safety,
291 fn_sig.abi,
292 preds.into(),
293 inputs.into(),
294 fn_sig.output().clone(),
295 )
296 })
297 }
298}
299
300impl Ty {
301 pub fn shallow_canonicalize(&self) -> CanonicalTy {
304 let mut delegate = LocalHoister::default();
305 let ty = self.shift_in_escaping(1);
306 let ty = Hoister::with_delegate(&mut delegate).hoist(&ty);
307 let constr_ty = delegate.bind(|_, preds| {
308 let pred = Expr::and_from_iter(preds);
309 CanonicalConstrTy { ty, pred }
310 });
311 if constr_ty.vars().is_empty() {
312 CanonicalTy::Constr(constr_ty.skip_binder().shift_out_escaping(1))
313 } else {
314 CanonicalTy::Exists(constr_ty)
315 }
316 }
317}
318
/// A type paired with a predicate, `{ ty | pred }`, as produced by hoisting constraints out of
/// a type in `Ty::shallow_canonicalize`.
#[derive(TypeVisitable, TypeFoldable)]
pub struct CanonicalConstrTy {
    /// The type with its hoisted existentials and constraints removed.
    ty: Ty,
    /// Conjunction of the hoisted constraints (may be trivially true when none were found).
    pred: Expr,
}
328
329impl CanonicalConstrTy {
330 pub fn ty(&self) -> Ty {
331 self.ty.clone()
332 }
333
334 pub fn pred(&self) -> Expr {
335 self.pred.clone()
336 }
337
338 pub fn to_ty(&self) -> Ty {
339 Ty::constr(self.pred(), self.ty())
340 }
341}
342
/// The result of `Ty::shallow_canonicalize`.
///
/// `Constr` is produced when no existential variables were hoisted. `Exists` binds the hoisted
/// variables around the constrained type.
#[derive(TypeVisitable)]
pub enum CanonicalTy {
    /// `{ T | p }` with no hoisted variables.
    Constr(CanonicalConstrTy),
    /// `{ T | p }` under a binder for the hoisted existential variables.
    Exists(Binder<CanonicalConstrTy>),
}
357
358impl CanonicalTy {
359 pub fn to_ty(&self) -> Ty {
360 match self {
361 CanonicalTy::Constr(constr_ty) => constr_ty.to_ty(),
362 CanonicalTy::Exists(poly_constr_ty) => {
363 Ty::exists(poly_constr_ty.as_ref().map(CanonicalConstrTy::to_ty))
364 }
365 }
366 }
367
368 pub fn as_ty_or_base(&self) -> TyOrBase {
369 match self {
370 CanonicalTy::Constr(constr_ty) => {
371 if let TyKind::Indexed(bty, idx) = constr_ty.ty.kind() {
372 let pred = if idx.is_unit() {
379 constr_ty.pred.clone()
380 } else {
381 Expr::and(&constr_ty.pred, Expr::eq(Expr::nu(), idx.shift_in_escaping(1)))
382 };
383 let sort = bty.sort();
384 let constr = SubsetTy::new(bty.shift_in_escaping(1), Expr::nu(), pred);
385 TyOrBase::Base(Binder::bind_with_sort(constr, sort))
386 } else {
387 TyOrBase::Ty(self.to_ty())
388 }
389 }
390 CanonicalTy::Exists(poly_constr_ty) => {
391 let constr = poly_constr_ty.as_ref().skip_binder();
392 if let TyKind::Indexed(bty, idx) = constr.ty.kind()
393 && idx.is_nu()
394 {
395 let ctor = poly_constr_ty
396 .as_ref()
397 .map(|constr| SubsetTy::new(bty.clone(), idx, &constr.pred));
398 TyOrBase::Base(ctor)
399 } else {
400 TyOrBase::Ty(self.to_ty())
401 }
402 }
403 }
404 }
405}
406
/// Pretty-printing for canonical types.
mod pretty {
    use super::*;
    use crate::pretty::*;

    /// Prints just the type when the predicate is trivially true, and `{ T | p }` otherwise.
    impl Pretty for CanonicalConstrTy {
        fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            if self.pred().is_trivially_true() {
                w!(cx, f, "{:?}", &self.ty)
            } else {
                w!(cx, f, "{{ {:?} | {:?} }}", &self.ty, &self.pred)
            }
        }
    }

    /// `Constr` prints like its `CanonicalConstrTy`.
    ///
    /// `Exists` renders the body as `T` or `T | p`, then prints the remaining refinement
    /// variables around it via `fmt_bound_vars` with `{` and `. <body> }` as delimiters. The
    /// wrapper is skipped entirely when no variables remain.
    impl Pretty for CanonicalTy {
        fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                CanonicalTy::Constr(constr) => w!(cx, f, "{:?}", constr),
                CanonicalTy::Exists(poly_constr) => {
                    // Bound variables the printing context may try to elide from the output.
                    let redundant_bvars = poly_constr.skip_binder_ref().redundant_bvars();
                    cx.with_bound_vars_removable(
                        poly_constr.vars(),
                        redundant_bvars,
                        None,
                        // Renders the body into its own formatter: `T` or `T | p`.
                        |f_body| {
                            let constr = poly_constr.skip_binder_ref();
                            if constr.pred().is_trivially_true() {
                                w!(cx, f_body, "{:?}", &constr.ty)
                            } else {
                                w!(cx, f_body, "{:?} | {:?}", &constr.ty, &constr.pred)
                            }
                        },
                        // Wraps the rendered body with the variables that survived removal.
                        |(), bound_var_layer, body| {
                            // Keep only refinement variables that were not successfully removed.
                            let vars = poly_constr
                                .vars()
                                .into_iter()
                                .enumerate()
                                .filter_map(|(idx, var)| {
                                    let not_removed = !bound_var_layer
                                        .successfully_removed_vars
                                        .contains(&BoundVar::from_usize(idx));
                                    let refine_var = matches!(var, BoundVariableKind::Refine(..));
                                    if not_removed && refine_var { Some(var.clone()) } else { None }
                                })
                                .collect_vec();
                            if vars.is_empty() {
                                write!(f, "{}", body)
                            } else {
                                let left = "{";
                                let right = format!(". {} }}", body);
                                cx.fmt_bound_vars(false, left, &vars, &right, f)
                            }
                        },
                    )
                }
            }
        }
    }

    impl_debug_with_default_cx!(CanonicalTy, CanonicalConstrTy);
}