1use flux_arc_interner::List;
34use flux_macros::{TypeFoldable, TypeVisitable};
35use itertools::Itertools;
36use rustc_ast::Mutability;
37use rustc_span::Symbol;
38use rustc_type_ir::{BoundVar, INNERMOST};
39
40use super::{
41 BaseTy, Binder, BoundVariableKind, Expr, FnSig, GenericArg, PolyFnSig, SubsetTy, Ty, TyCtor,
42 TyKind, TyOrBase,
43 fold::{TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable},
44};
45use crate::rty::{BoundReftKind, ExprKind, GenericArgsExt, HoleKind};
46
/// A type folder that pulls existential binders and constraint predicates out of a type,
/// reporting each one to a [`HoisterDelegate`].
///
/// The boolean flags control which nested type constructors the folder descends into; when a
/// flag is off, the corresponding sub-type is returned unchanged.
pub struct Hoister<D> {
    /// Receives every hoisted existential (`hoist_exists`) and constraint (`hoist_constr`).
    pub delegate: D,
    /// Descend into the boxed type of `Box` ADTs (the allocator argument is left untouched).
    in_boxes: bool,
    /// Descend structurally into `TyKind::Downcast` types.
    in_downcast: bool,
    /// Descend through `&mut` references.
    in_mut_refs: bool,
    /// Descend through shared (`&`) references.
    in_shr_refs: bool,
    /// Descend structurally into `TyKind::StrgRef` types.
    in_strg_refs: bool,
    /// Descend into tuple components.
    in_tuples: bool,
    /// Hoist `TyKind::Exists` binders; when false, existential types are left as they are.
    existentials: bool,
    /// Descend through references (of either mutability) whose pointee satisfies
    /// `is_indexed_slice`, regardless of `in_shr_refs` / `in_mut_refs`.
    slices: bool,
}
64
/// Callbacks invoked by [`Hoister`] while it folds a type.
pub trait HoisterDelegate {
    /// Called for an existential type being hoisted. The returned type replaces the
    /// existential in the folded result (the hoister then keeps folding it).
    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty;
    /// Called with the predicate of each `TyKind::Constr` the hoister removes.
    fn hoist_constr(&mut self, pred: Expr);
}
69
70impl<D> Hoister<D> {
71 pub fn with_delegate(delegate: D) -> Self {
72 Hoister {
73 delegate,
74 in_tuples: false,
75 in_shr_refs: false,
76 in_mut_refs: false,
77 in_strg_refs: false,
78 in_boxes: false,
79 in_downcast: false,
80 existentials: true,
81 slices: false,
82 }
83 }
84
85 pub fn hoist_inside_shr_refs(mut self, shr_refs: bool) -> Self {
86 self.in_shr_refs = shr_refs;
87 self
88 }
89
90 pub fn hoist_inside_mut_refs(mut self, mut_refs: bool) -> Self {
91 self.in_mut_refs = mut_refs;
92 self
93 }
94
95 pub fn hoist_inside_strg_refs(mut self, strg_refs: bool) -> Self {
96 self.in_strg_refs = strg_refs;
97 self
98 }
99
100 pub fn hoist_inside_tuples(mut self, tuples: bool) -> Self {
101 self.in_tuples = tuples;
102 self
103 }
104
105 pub fn hoist_inside_boxes(mut self, boxes: bool) -> Self {
106 self.in_boxes = boxes;
107 self
108 }
109
110 pub fn hoist_inside_downcast(mut self, downcast: bool) -> Self {
111 self.in_downcast = downcast;
112 self
113 }
114
115 pub fn hoist_existentials(mut self, exists: bool) -> Self {
116 self.existentials = exists;
117 self
118 }
119
120 pub fn hoist_slices(mut self, slices: bool) -> Self {
121 self.slices = slices;
122 self
123 }
124
125 pub fn transparent(self) -> Self {
126 self.hoist_inside_boxes(true)
127 .hoist_inside_downcast(true)
128 .hoist_inside_mut_refs(false)
129 .hoist_inside_shr_refs(true)
130 .hoist_inside_strg_refs(true)
131 .hoist_inside_tuples(true)
132 .hoist_slices(true)
133 }
134
135 pub fn shallow(self) -> Self {
136 self.hoist_inside_boxes(false)
137 .hoist_inside_downcast(false)
138 .hoist_inside_mut_refs(false)
139 .hoist_inside_shr_refs(false)
140 .hoist_inside_strg_refs(false)
141 .hoist_inside_tuples(false)
142 }
143}
144
145impl<D: HoisterDelegate> Hoister<D> {
146 pub fn hoist(&mut self, ty: &Ty) -> Ty {
147 ty.fold_with(self)
148 }
149}
150
151fn is_indexed_slice(ty: &Ty) -> bool {
160 let Some(bty) = ty.as_bty_skipping_existentials() else { return false };
161 match bty {
162 BaseTy::Slice(_) | BaseTy::Array(..) => true,
163 BaseTy::Ref(_, ty, _) => is_indexed_slice(ty),
164 _ => false,
165 }
166}
167
/// The folding logic behind [`Hoister::hoist`]. Which constructors are traversed is decided
/// by the hoister's flags; anything not traversed is cloned unchanged.
impl<D: HoisterDelegate> TypeFolder for Hoister<D> {
    fn fold_ty(&mut self, ty: &Ty) -> Ty {
        match ty.kind() {
            // Keep the index; only the base type is folded (see `fold_bty`).
            TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()),
            TyKind::Exists(ty_ctor) if self.existentials => {
                // A single refinement variable of unit sort (or unit-ADT sort) is instantiated
                // directly with the unit value instead of being reported to the delegate.
                // Every other binder is handed to the delegate. The resulting type is folded
                // again so that nested existentials/constraints are hoisted as well.
                match &ty_ctor.vars()[..] {
                    [BoundVariableKind::Refine(sort, ..)] => {
                        if sort.is_unit() {
                            ty_ctor.replace_bound_reft(&Expr::unit())
                        } else if let Some(def_id) = sort.is_unit_adt() {
                            ty_ctor.replace_bound_reft(&Expr::unit_struct(def_id))
                        } else {
                            self.delegate.hoist_exists(ty_ctor)
                        }
                    }
                    _ => self.delegate.hoist_exists(ty_ctor),
                }
                .fold_with(self)
            }
            // Constraints are always hoisted (this arm is not gated by any flag); the
            // predicate goes to the delegate and folding continues on the inner type.
            TyKind::Constr(pred, ty) => {
                self.delegate.hoist_constr(pred.clone());
                ty.fold_with(self)
            }
            TyKind::StrgRef(..) if self.in_strg_refs => ty.super_fold_with(self),
            TyKind::Downcast(..) if self.in_downcast => ty.super_fold_with(self),
            _ => ty.clone(),
        }
    }

    fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy {
        // Arm order matters: the indexed-slice reference arm comes before the shared/mutable
        // reference arms, so with `slices` enabled such references are traversed regardless of
        // `in_shr_refs` / `in_mut_refs`.
        match bty {
            BaseTy::Adt(adt_def, args) if adt_def.is_box() && self.in_boxes => {
                // Only the boxed type is folded; the allocator argument is kept as-is.
                let (boxed, alloc) = args.box_args();
                let args = List::from_arr([GenericArg::Ty(boxed.fold_with(self)), alloc.clone()]);
                BaseTy::Adt(adt_def.clone(), args)
            }
            BaseTy::Ref(re, ty, mutability) if is_indexed_slice(ty) && self.slices => {
                BaseTy::Ref(*re, ty.fold_with(self), *mutability)
            }
            BaseTy::Ref(re, ty, Mutability::Not) if self.in_shr_refs => {
                BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not)
            }
            BaseTy::Ref(re, ty, Mutability::Mut) if self.in_mut_refs => {
                BaseTy::Ref(*re, ty.fold_with(self), Mutability::Mut)
            }
            BaseTy::Tuple(tys) if self.in_tuples => BaseTy::Tuple(tys.fold_with(self)),
            _ => bty.clone(),
        }
    }
}
223
/// A hoisting delegate (implemented for `&mut LocalHoister`) that accumulates hoisted bound
/// variables and predicates so they can be re-bound afterwards with [`LocalHoister::bind`].
#[derive(Default)]
pub struct LocalHoister {
    /// Bound variables collected so far. Each hoisted variable is appended and referenced by
    /// its position in this vector.
    vars: Vec<BoundVariableKind>,
    /// Predicates collected from hoisted constraints.
    preds: Vec<Expr>,
    /// When set, every hoisted variable gets `BoundReftKind::Named(name)` instead of its
    /// original kind.
    pub name: Option<Symbol>,
}
230
231impl LocalHoister {
232 pub fn new(vars: Vec<BoundVariableKind>) -> Self {
233 LocalHoister { vars, preds: vec![], name: None }
234 }
235
236 pub fn bind<T>(self, f: impl FnOnce(List<BoundVariableKind>, Vec<Expr>) -> T) -> Binder<T> {
237 let vars = List::from_vec(self.vars);
238 Binder::bind_with_vars(f(vars.clone(), self.preds), vars)
239 }
240}
241
242impl HoisterDelegate for &mut LocalHoister {
243 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
244 ty_ctor.replace_bound_refts_with(|sort, mode, kind| {
245 let idx = self.vars.len();
246 let kind = if let Some(name) = self.name { BoundReftKind::Named(name) } else { kind };
247 self.vars
248 .push(BoundVariableKind::Refine(sort.clone(), mode, kind));
249 Expr::bvar(INNERMOST, BoundVar::from_usize(idx), kind)
250 })
251 }
252
253 fn hoist_constr(&mut self, pred: Expr) {
254 self.preds.push(pred);
255 }
256}
257
258impl PolyFnSig {
259 pub fn hoist_input_binders(&self) -> Self {
269 let original_vars = self.vars().to_vec();
270 let fn_sig = self.skip_binder_ref();
271
272 let mut delegate =
273 LocalHoister { vars: original_vars, preds: fn_sig.requires().to_vec(), name: None };
274 let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
275
276 let inputs = fn_sig
277 .inputs()
278 .iter()
279 .map(|ty| hoister.hoist(ty))
280 .collect_vec();
281
282 delegate.bind(|_vars, mut preds| {
283 let mut keep_hole = true;
284 preds.retain(|pred| {
285 if let ExprKind::Hole(HoleKind::Pred) = pred.kind() {
286 std::mem::replace(&mut keep_hole, false)
287 } else {
288 true
289 }
290 });
291
292 FnSig::new(
293 fn_sig.safety,
294 fn_sig.abi,
295 preds.into(),
296 inputs.into(),
297 fn_sig.output().clone(),
298 fn_sig.no_panic.clone(),
299 fn_sig.lifted,
300 )
301 })
302 }
303}
304
305impl Ty {
306 pub fn shallow_canonicalize(&self) -> CanonicalTy {
309 let mut delegate = LocalHoister::default();
310 let ty = self.shift_in_escaping(1);
311 let ty = Hoister::with_delegate(&mut delegate).hoist(&ty);
312 let constr_ty = delegate.bind(|_, preds| {
313 let pred = Expr::and_from_iter(preds);
314 CanonicalConstrTy { ty, pred }
315 });
316 if constr_ty.vars().is_empty() {
317 CanonicalTy::Constr(constr_ty.skip_binder().shift_out_escaping(1))
318 } else {
319 CanonicalTy::Exists(constr_ty)
320 }
321 }
322}
323
/// A type paired with a predicate, read as `{ ty | pred }` (see `to_ty`, which rebuilds it
/// as `Ty::constr(pred, ty)`).
#[derive(TypeVisitable, TypeFoldable)]
pub struct CanonicalConstrTy {
    /// The type left after hoisting.
    ty: Ty,
    /// The conjunction of hoisted constraint predicates (built with `Expr::and_from_iter` in
    /// `shallow_canonicalize`).
    pred: Expr,
}
333
334impl CanonicalConstrTy {
335 pub fn ty(&self) -> Ty {
336 self.ty.clone()
337 }
338
339 pub fn pred(&self) -> Expr {
340 self.pred.clone()
341 }
342
343 pub fn to_ty(&self) -> Ty {
344 Ty::constr(self.pred(), self.ty())
345 }
346}
347
/// The result of [`Ty::shallow_canonicalize`].
#[derive(TypeVisitable)]
pub enum CanonicalTy {
    /// No variables were hoisted: just a constrained type.
    Constr(CanonicalConstrTy),
    /// Hoisted variables are bound by the binder around the constrained type.
    Exists(Binder<CanonicalConstrTy>),
}
362
impl CanonicalTy {
    /// Converts back into an ordinary [`Ty`]: `Constr` becomes `Ty::constr(pred, ty)`, and
    /// `Exists` becomes `Ty::exists` over that constrained type under the same binder.
    pub fn to_ty(&self) -> Ty {
        match self {
            CanonicalTy::Constr(constr_ty) => constr_ty.to_ty(),
            CanonicalTy::Exists(poly_constr_ty) => {
                Ty::exists(poly_constr_ty.as_ref().map(CanonicalConstrTy::to_ty))
            }
        }
    }

    /// Tries to present this type as a refined base type (`TyOrBase::Base`), falling back to
    /// `TyOrBase::Ty(self.to_ty())` when the underlying type is not of the needed indexed shape.
    ///
    /// - `Constr` over `TyKind::Indexed(bty, idx)`: builds a new single binder of `bty`'s sort
    ///   around `SubsetTy::new(bty, ν, pred ∧ ν == idx)`, omitting the equality when
    ///   `idx.is_unit()`. `bty`, `pred` and `idx` are shifted in by one for the new binder.
    /// - `Exists` over `TyKind::Indexed(bty, idx)` with `idx.is_nu()` (presumably the
    ///   innermost bound variable, as produced by `Expr::nu()` — TODO confirm): reuses the
    ///   existing binder for the subset type, so no shifting is needed.
    pub fn as_ty_or_base(&self) -> TyOrBase {
        match self {
            CanonicalTy::Constr(constr_ty) => {
                if let TyKind::Indexed(bty, idx) = constr_ty.ty.kind() {
                    let pred = if idx.is_unit() {
                        constr_ty.pred.shift_in_escaping(1)
                    } else {
                        Expr::and(
                            constr_ty.pred.shift_in_escaping(1),
                            Expr::eq(Expr::nu(), idx.shift_in_escaping(1)),
                        )
                    };
                    let sort = bty.sort();
                    let constr = SubsetTy::new(bty.shift_in_escaping(1), Expr::nu(), pred);
                    TyOrBase::Base(Binder::bind_with_sort(constr, sort))
                } else {
                    TyOrBase::Ty(self.to_ty())
                }
            }
            CanonicalTy::Exists(poly_constr_ty) => {
                let constr = poly_constr_ty.as_ref().skip_binder();
                if let TyKind::Indexed(bty, idx) = constr.ty.kind()
                    && idx.is_nu()
                {
                    // `bty`/`idx` come from the same binder body that `map` visits, so they can
                    // be reused inside the closure unchanged.
                    let ctor = poly_constr_ty
                        .as_ref()
                        .map(|constr| SubsetTy::new(bty.clone(), idx, &constr.pred));
                    TyOrBase::Base(ctor)
                } else {
                    TyOrBase::Ty(self.to_ty())
                }
            }
        }
    }
}
414
/// Pretty-printing for canonical types.
mod pretty {
    use super::*;
    use crate::pretty::*;

    impl Pretty for CanonicalConstrTy {
        /// Prints just `ty` when the predicate is trivially true, otherwise `{ ty | pred }`.
        fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            if self.pred().is_trivially_true() {
                w!(cx, f, "{:?}", &self.ty)
            } else {
                w!(cx, f, "{{ {:?} | {:?} }}", &self.ty, &self.pred)
            }
        }
    }

    impl Pretty for CanonicalTy {
        /// `Constr` prints like [`CanonicalConstrTy`]. `Exists` prints `{ vars. ty | pred }`
        /// (or `{ vars. ty }` when the predicate is trivially true). The `vars. ` prefix is
        /// omitted when the bound-variable layer's `filter_vars` leaves no variables.
        fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                CanonicalTy::Constr(constr) => w!(cx, f, "{:?}", constr),
                CanonicalTy::Exists(poly_constr) => {
                    let redundant_bvars = poly_constr.skip_binder_ref().redundant_bvars();
                    cx.with_bound_vars_removable(poly_constr.vars(), redundant_bvars, None, || {
                        // The body is formatted to strings *before* the variable layer is
                        // inspected below. NOTE(review): presumably `filter_vars` depends on
                        // state recorded while formatting, so keep this order — confirm.
                        let constr = poly_constr.skip_binder_ref();
                        let ty_fmt = format_cx!(cx, "{:?}", &constr.ty);
                        let pred_fmt = if !constr.pred().is_trivially_true() {
                            Some(format_cx!(cx, "{:?}", constr.pred()))
                        } else {
                            None
                        };

                        // `unwrap` relies on the layer pushed by `with_bound_vars_removable`.
                        let vars = cx
                            .bvar_env
                            .peek_layer()
                            .unwrap()
                            .filter_vars(poly_constr.vars());

                        if vars.is_empty() {
                            if let Some(pred_fmt) = pred_fmt {
                                write!(f, "{{ {ty_fmt} | {pred_fmt} }}")
                            } else {
                                write!(f, "{ty_fmt}")
                            }
                        } else {
                            // Opens with `{` and the variables; the closing `}` is written below.
                            cx.fmt_bound_vars(false, "{", &vars, ". ", f)?;
                            if let Some(pred_fmt) = pred_fmt {
                                write!(f, "{ty_fmt} | {pred_fmt} }}")
                            } else {
                                write!(f, "{ty_fmt} }}")
                            }
                        }
                    })
                }
            }
        }
    }

    impl_debug_with_default_cx!(CanonicalTy, CanonicalConstrTy);
}