1use std::fmt::Write;
34
35use flux_arc_interner::List;
36use flux_macros::{TypeFoldable, TypeVisitable};
37use itertools::Itertools;
38use rustc_ast::Mutability;
39use rustc_span::Symbol;
40use rustc_type_ir::{BoundVar, INNERMOST};
41
42use super::{
43 BaseTy, Binder, BoundVariableKind, Expr, FnSig, GenericArg, GenericArgsExt, PolyFnSig,
44 SubsetTy, Ty, TyCtor, TyKind, TyOrBase,
45 fold::{TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable},
46};
47use crate::rty::{BoundReftKind, ExprKind, HoleKind};
48
/// Hoists existential binders and constraint predicates out of a type.
///
/// Every existential (`TyKind::Exists`) and constraint (`TyKind::Constr`) that the hoister opens
/// is reported to the delegate `D`, which decides what the hoisted binders and predicates become.
/// The boolean flags control which type constructors the hoister is allowed to look inside; they
/// are set through the builder methods (`hoist_inside_*`, `hoist_existentials`, `hoist_slices`,
/// `transparent`, `shallow`).
pub struct Hoister<D> {
    /// Receives every hoisted existential and predicate.
    pub delegate: D,
    /// Descend into the boxed type of a `Box` ADT (its allocator argument is left untouched).
    in_boxes: bool,
    /// Descend into `TyKind::Downcast` types.
    in_downcast: bool,
    /// Descend into the pointee of `&mut T`.
    in_mut_refs: bool,
    /// Descend into the pointee of `&T`.
    in_shr_refs: bool,
    /// Descend into `TyKind::StrgRef` types.
    in_strg_refs: bool,
    /// Descend into the components of tuples.
    in_tuples: bool,
    /// Whether `TyKind::Exists` is opened at all. Constraints (`TyKind::Constr`) are hoisted
    /// regardless of this flag.
    existentials: bool,
    /// Descend into references whose pointee is a slice or an array (see `is_indexed_slice`),
    /// irrespective of the reference's mutability.
    slices: bool,
}
66
67pub trait HoisterDelegate {
68 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty;
69 fn hoist_constr(&mut self, pred: Expr);
70}
71
72impl<D> Hoister<D> {
73 pub fn with_delegate(delegate: D) -> Self {
74 Hoister {
75 delegate,
76 in_tuples: false,
77 in_shr_refs: false,
78 in_mut_refs: false,
79 in_strg_refs: false,
80 in_boxes: false,
81 in_downcast: false,
82 existentials: true,
83 slices: false,
84 }
85 }
86
87 pub fn hoist_inside_shr_refs(mut self, shr_refs: bool) -> Self {
88 self.in_shr_refs = shr_refs;
89 self
90 }
91
92 pub fn hoist_inside_mut_refs(mut self, mut_refs: bool) -> Self {
93 self.in_mut_refs = mut_refs;
94 self
95 }
96
97 pub fn hoist_inside_strg_refs(mut self, strg_refs: bool) -> Self {
98 self.in_strg_refs = strg_refs;
99 self
100 }
101
102 pub fn hoist_inside_tuples(mut self, tuples: bool) -> Self {
103 self.in_tuples = tuples;
104 self
105 }
106
107 pub fn hoist_inside_boxes(mut self, boxes: bool) -> Self {
108 self.in_boxes = boxes;
109 self
110 }
111
112 pub fn hoist_inside_downcast(mut self, downcast: bool) -> Self {
113 self.in_downcast = downcast;
114 self
115 }
116
117 pub fn hoist_existentials(mut self, exists: bool) -> Self {
118 self.existentials = exists;
119 self
120 }
121
122 pub fn hoist_slices(mut self, slices: bool) -> Self {
123 self.slices = slices;
124 self
125 }
126
127 pub fn transparent(self) -> Self {
128 self.hoist_inside_boxes(true)
129 .hoist_inside_downcast(true)
130 .hoist_inside_mut_refs(false)
131 .hoist_inside_shr_refs(true)
132 .hoist_inside_strg_refs(true)
133 .hoist_inside_tuples(true)
134 .hoist_slices(true)
135 }
136
137 pub fn shallow(self) -> Self {
138 self.hoist_inside_boxes(false)
139 .hoist_inside_downcast(false)
140 .hoist_inside_mut_refs(false)
141 .hoist_inside_shr_refs(false)
142 .hoist_inside_strg_refs(false)
143 .hoist_inside_tuples(false)
144 }
145}
146
147impl<D: HoisterDelegate> Hoister<D> {
148 pub fn hoist(&mut self, ty: &Ty) -> Ty {
149 ty.fold_with(self)
150 }
151}
152
153fn is_indexed_slice(ty: &Ty) -> bool {
162 let Some(bty) = ty.as_bty_skipping_existentials() else { return false };
163 match bty {
164 BaseTy::Slice(_) | BaseTy::Array(..) => true,
165 BaseTy::Ref(_, ty, _) => is_indexed_slice(ty),
166 _ => false,
167 }
168}
169
170impl<D: HoisterDelegate> TypeFolder for Hoister<D> {
171 fn fold_ty(&mut self, ty: &Ty) -> Ty {
172 match ty.kind() {
173 TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()),
174 TyKind::Exists(ty_ctor) if self.existentials => {
175 match &ty_ctor.vars()[..] {
181 [BoundVariableKind::Refine(sort, ..)] => {
182 if sort.is_unit() {
183 ty_ctor.replace_bound_reft(&Expr::unit())
184 } else if let Some(def_id) = sort.is_unit_adt() {
185 ty_ctor.replace_bound_reft(&Expr::unit_struct(def_id))
186 } else {
187 self.delegate.hoist_exists(ty_ctor)
188 }
189 }
190 _ => self.delegate.hoist_exists(ty_ctor),
191 }
192 .fold_with(self)
193 }
194 TyKind::Constr(pred, ty) => {
195 self.delegate.hoist_constr(pred.clone());
196 ty.fold_with(self)
197 }
198 TyKind::StrgRef(..) if self.in_strg_refs => ty.super_fold_with(self),
199 TyKind::Downcast(..) if self.in_downcast => ty.super_fold_with(self),
200 _ => ty.clone(),
201 }
202 }
203
204 fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy {
205 match bty {
206 BaseTy::Adt(adt_def, args) if adt_def.is_box() && self.in_boxes => {
207 let (boxed, alloc) = args.box_args();
208 let args = List::from_arr([
209 GenericArg::Ty(boxed.fold_with(self)),
210 GenericArg::Ty(alloc.clone()),
211 ]);
212 BaseTy::Adt(adt_def.clone(), args)
213 }
214 BaseTy::Ref(re, ty, mutability) if is_indexed_slice(ty) && self.slices => {
215 BaseTy::Ref(*re, ty.fold_with(self), *mutability)
216 }
217 BaseTy::Ref(re, ty, Mutability::Not) if self.in_shr_refs => {
218 BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not)
219 }
220 BaseTy::Ref(re, ty, Mutability::Mut) if self.in_mut_refs => {
221 BaseTy::Ref(*re, ty.fold_with(self), Mutability::Mut)
222 }
223 BaseTy::Tuple(tys) if self.in_tuples => BaseTy::Tuple(tys.fold_with(self)),
224 _ => bty.clone(),
225 }
226 }
227}
228
229#[derive(Default)]
230pub struct LocalHoister {
231 vars: Vec<BoundVariableKind>,
232 preds: Vec<Expr>,
233 pub name: Option<Symbol>,
234}
235
236impl LocalHoister {
237 pub fn new(vars: Vec<BoundVariableKind>) -> Self {
238 LocalHoister { vars, preds: vec![], name: None }
239 }
240
241 pub fn bind<T>(self, f: impl FnOnce(List<BoundVariableKind>, Vec<Expr>) -> T) -> Binder<T> {
242 let vars = List::from_vec(self.vars);
243 Binder::bind_with_vars(f(vars.clone(), self.preds), vars)
244 }
245}
246
247impl HoisterDelegate for &mut LocalHoister {
248 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
249 ty_ctor.replace_bound_refts_with(|sort, mode, kind| {
250 let idx = self.vars.len();
251 let kind = if let Some(name) = self.name { BoundReftKind::Named(name) } else { kind };
252 self.vars
253 .push(BoundVariableKind::Refine(sort.clone(), mode, kind));
254 Expr::bvar(INNERMOST, BoundVar::from_usize(idx), kind)
255 })
256 }
257
258 fn hoist_constr(&mut self, pred: Expr) {
259 self.preds.push(pred);
260 }
261}
262
263impl PolyFnSig {
264 pub fn hoist_input_binders(&self) -> Self {
274 let original_vars = self.vars().to_vec();
275 let fn_sig = self.skip_binder_ref();
276
277 let mut delegate =
278 LocalHoister { vars: original_vars, preds: fn_sig.requires().to_vec(), name: None };
279 let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
280
281 let inputs = fn_sig
282 .inputs()
283 .iter()
284 .map(|ty| hoister.hoist(ty))
285 .collect_vec();
286
287 delegate.bind(|_vars, mut preds| {
288 let mut keep_hole = true;
289 preds.retain(|pred| {
290 if let ExprKind::Hole(HoleKind::Pred) = pred.kind() {
291 std::mem::replace(&mut keep_hole, false)
292 } else {
293 true
294 }
295 });
296
297 FnSig::new(
298 fn_sig.safety,
299 fn_sig.abi,
300 preds.into(),
301 inputs.into(),
302 fn_sig.output().clone(),
303 fn_sig.no_panic.clone(),
304 fn_sig.lifted,
305 )
306 })
307 }
308}
309
310impl Ty {
311 pub fn shallow_canonicalize(&self) -> CanonicalTy {
314 let mut delegate = LocalHoister::default();
315 let ty = self.shift_in_escaping(1);
316 let ty = Hoister::with_delegate(&mut delegate).hoist(&ty);
317 let constr_ty = delegate.bind(|_, preds| {
318 let pred = Expr::and_from_iter(preds);
319 CanonicalConstrTy { ty, pred }
320 });
321 if constr_ty.vars().is_empty() {
322 CanonicalTy::Constr(constr_ty.skip_binder().shift_out_escaping(1))
323 } else {
324 CanonicalTy::Exists(constr_ty)
325 }
326 }
327}
328
329#[derive(TypeVisitable, TypeFoldable)]
330pub struct CanonicalConstrTy {
331 ty: Ty,
336 pred: Expr,
337}
338
339impl CanonicalConstrTy {
340 pub fn ty(&self) -> Ty {
341 self.ty.clone()
342 }
343
344 pub fn pred(&self) -> Expr {
345 self.pred.clone()
346 }
347
348 pub fn to_ty(&self) -> Ty {
349 Ty::constr(self.pred(), self.ty())
350 }
351}
352
353#[derive(TypeVisitable)]
361pub enum CanonicalTy {
362 Constr(CanonicalConstrTy),
364 Exists(Binder<CanonicalConstrTy>),
366}
367
368impl CanonicalTy {
369 pub fn to_ty(&self) -> Ty {
370 match self {
371 CanonicalTy::Constr(constr_ty) => constr_ty.to_ty(),
372 CanonicalTy::Exists(poly_constr_ty) => {
373 Ty::exists(poly_constr_ty.as_ref().map(CanonicalConstrTy::to_ty))
374 }
375 }
376 }
377
378 pub fn as_ty_or_base(&self) -> TyOrBase {
379 match self {
380 CanonicalTy::Constr(constr_ty) => {
381 if let TyKind::Indexed(bty, idx) = constr_ty.ty.kind() {
382 let pred = if idx.is_unit() {
389 constr_ty.pred.clone()
390 } else {
391 Expr::and(&constr_ty.pred, Expr::eq(Expr::nu(), idx.shift_in_escaping(1)))
392 };
393 let sort = bty.sort();
394 let constr = SubsetTy::new(bty.shift_in_escaping(1), Expr::nu(), pred);
395 TyOrBase::Base(Binder::bind_with_sort(constr, sort))
396 } else {
397 TyOrBase::Ty(self.to_ty())
398 }
399 }
400 CanonicalTy::Exists(poly_constr_ty) => {
401 let constr = poly_constr_ty.as_ref().skip_binder();
402 if let TyKind::Indexed(bty, idx) = constr.ty.kind()
403 && idx.is_nu()
404 {
405 let ctor = poly_constr_ty
406 .as_ref()
407 .map(|constr| SubsetTy::new(bty.clone(), idx, &constr.pred));
408 TyOrBase::Base(ctor)
409 } else {
410 TyOrBase::Ty(self.to_ty())
411 }
412 }
413 }
414 }
415}
416
417mod pretty {
418 use super::*;
419 use crate::pretty::*;
420
421 impl Pretty for CanonicalConstrTy {
422 fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423 if self.pred().is_trivially_true() {
424 w!(cx, f, "{:?}", &self.ty)
425 } else {
426 w!(cx, f, "{{ {:?} | {:?} }}", &self.ty, &self.pred)
427 }
428 }
429 }
430
431 impl Pretty for CanonicalTy {
432 fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433 match self {
434 CanonicalTy::Constr(constr) => w!(cx, f, "{:?}", constr),
435 CanonicalTy::Exists(poly_constr) => {
436 let redundant_bvars = poly_constr.skip_binder_ref().redundant_bvars();
437 cx.with_bound_vars_removable(
438 poly_constr.vars(),
439 redundant_bvars,
440 None,
441 |f_body| {
442 let constr = poly_constr.skip_binder_ref();
443 if constr.pred().is_trivially_true() {
444 w!(cx, f_body, "{:?}", &constr.ty)
445 } else {
446 w!(cx, f_body, "{:?} | {:?}", &constr.ty, &constr.pred)
447 }
448 },
449 |(), bound_var_layer, body| {
450 let vars = poly_constr
451 .vars()
452 .into_iter()
453 .enumerate()
454 .filter_map(|(idx, var)| {
455 let not_removed = !bound_var_layer
456 .successfully_removed_vars
457 .contains(&BoundVar::from_usize(idx));
458 let refine_var = matches!(var, BoundVariableKind::Refine(..));
459 if not_removed && refine_var { Some(var.clone()) } else { None }
460 })
461 .collect_vec();
462 if vars.is_empty() {
463 write!(f, "{}", body)
464 } else {
465 let left = "{";
466 let right = format!(". {} }}", body);
467 cx.fmt_bound_vars(false, left, &vars, &right, f)
468 }
469 },
470 )
471 }
472 }
473 }
474 }
475
476 impl_debug_with_default_cx!(CanonicalTy, CanonicalConstrTy);
477}