1use flux_arc_interner::List;
34use flux_macros::{TypeFoldable, TypeVisitable};
35use itertools::Itertools;
36use rustc_ast::Mutability;
37use rustc_type_ir::{BoundVar, INNERMOST};
38
39use super::{
40 BaseTy, Binder, BoundVariableKind, Expr, FnSig, GenericArg, GenericArgsExt, PolyFnSig,
41 SubsetTy, Ty, TyCtor, TyKind, TyOrBase,
42 fold::{TypeFoldable, TypeFolder, TypeSuperFoldable},
43};
44use crate::rty::{ExprKind, HoleKind};
45
/// A type folder that hoists existential binders and constraint predicates out of a type.
///
/// Everything that gets hoisted is handed to the delegate `D`. The boolean options decide
/// whether existentials are hoisted at all and which type constructors the traversal is
/// allowed to look inside.
pub struct Hoister<D> {
    /// Receives every hoisted existential and constraint predicate.
    delegate: D,
    /// Whether existential types are opened (via the delegate) when encountered.
    existentials: bool,
    /// Hoist inside the boxed type of a `Box`.
    in_boxes: bool,
    /// Hoist inside downcast types.
    in_downcast: bool,
    /// Hoist inside the referent of mutable references.
    in_mut_refs: bool,
    /// Hoist inside the referent of shared references.
    in_shr_refs: bool,
    /// Hoist inside strong references.
    in_strg_refs: bool,
    /// Hoist inside the components of tuples.
    in_tuples: bool,
}
62
63pub trait HoisterDelegate {
64 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty;
65 fn hoist_constr(&mut self, pred: Expr);
66}
67
68impl<D> Hoister<D> {
69 pub fn with_delegate(delegate: D) -> Self {
70 Hoister {
71 delegate,
72 in_tuples: false,
73 in_shr_refs: false,
74 in_mut_refs: false,
75 in_strg_refs: false,
76 in_boxes: false,
77 in_downcast: false,
78 existentials: true,
79 }
80 }
81
82 pub fn hoist_inside_shr_refs(mut self, shr_refs: bool) -> Self {
83 self.in_shr_refs = shr_refs;
84 self
85 }
86
87 pub fn hoist_inside_mut_refs(mut self, mut_refs: bool) -> Self {
88 self.in_mut_refs = mut_refs;
89 self
90 }
91
92 pub fn hoist_inside_strg_refs(mut self, strg_refs: bool) -> Self {
93 self.in_strg_refs = strg_refs;
94 self
95 }
96
97 pub fn hoist_inside_tuples(mut self, tuples: bool) -> Self {
98 self.in_tuples = tuples;
99 self
100 }
101
102 pub fn hoist_inside_boxes(mut self, boxes: bool) -> Self {
103 self.in_boxes = boxes;
104 self
105 }
106
107 pub fn hoist_inside_downcast(mut self, downcast: bool) -> Self {
108 self.in_downcast = downcast;
109 self
110 }
111
112 pub fn hoist_existentials(mut self, exists: bool) -> Self {
113 self.existentials = exists;
114 self
115 }
116
117 pub fn transparent(self) -> Self {
118 self.hoist_inside_boxes(true)
119 .hoist_inside_downcast(true)
120 .hoist_inside_mut_refs(false)
121 .hoist_inside_shr_refs(true)
122 .hoist_inside_strg_refs(true)
123 .hoist_inside_tuples(true)
124 }
125
126 pub fn shallow(self) -> Self {
127 self.hoist_inside_boxes(false)
128 .hoist_inside_downcast(false)
129 .hoist_inside_mut_refs(false)
130 .hoist_inside_shr_refs(false)
131 .hoist_inside_strg_refs(false)
132 .hoist_inside_tuples(false)
133 }
134}
135
136impl<D: HoisterDelegate> Hoister<D> {
137 pub fn hoist(&mut self, ty: &Ty) -> Ty {
138 ty.fold_with(self)
139 }
140}
141
142impl<D: HoisterDelegate> TypeFolder for Hoister<D> {
143 fn fold_ty(&mut self, ty: &Ty) -> Ty {
144 match ty.kind() {
145 TyKind::Indexed(bty, idx) => Ty::indexed(bty.fold_with(self), idx.clone()),
146 TyKind::Exists(ty_ctor) if self.existentials => {
147 match &ty_ctor.vars()[..] {
153 [BoundVariableKind::Refine(sort, ..)] => {
154 if sort.is_unit() {
155 ty_ctor.replace_bound_reft(&Expr::unit())
156 } else if let Some(def_id) = sort.is_unit_adt() {
157 ty_ctor.replace_bound_reft(&Expr::unit_struct(def_id))
158 } else {
159 self.delegate.hoist_exists(ty_ctor)
160 }
161 }
162 _ => self.delegate.hoist_exists(ty_ctor),
163 }
164 .fold_with(self)
165 }
166 TyKind::Constr(pred, ty) => {
167 self.delegate.hoist_constr(pred.clone());
168 ty.fold_with(self)
169 }
170 TyKind::StrgRef(..) if self.in_strg_refs => ty.super_fold_with(self),
171 TyKind::Downcast(..) if self.in_downcast => ty.super_fold_with(self),
172 _ => ty.clone(),
173 }
174 }
175
176 fn fold_bty(&mut self, bty: &BaseTy) -> BaseTy {
177 match bty {
178 BaseTy::Adt(adt_def, args) if adt_def.is_box() && self.in_boxes => {
179 let (boxed, alloc) = args.box_args();
180 let args = List::from_arr([
181 GenericArg::Ty(boxed.fold_with(self)),
182 GenericArg::Ty(alloc.clone()),
183 ]);
184 BaseTy::Adt(adt_def.clone(), args)
185 }
186 BaseTy::Ref(re, ty, Mutability::Not) if self.in_shr_refs => {
187 BaseTy::Ref(*re, ty.fold_with(self), Mutability::Not)
188 }
189 BaseTy::Ref(re, ty, Mutability::Mut) if self.in_mut_refs => {
190 BaseTy::Ref(*re, ty.fold_with(self), Mutability::Mut)
191 }
192 BaseTy::Tuple(tys) if self.in_tuples => BaseTy::Tuple(tys.fold_with(self)),
193 _ => bty.clone(),
194 }
195 }
196}
197
198#[derive(Default)]
199pub struct LocalHoister {
200 vars: Vec<BoundVariableKind>,
201 preds: Vec<Expr>,
202}
203
204impl LocalHoister {
205 pub fn new(vars: Vec<BoundVariableKind>) -> Self {
206 LocalHoister { vars, preds: vec![] }
207 }
208
209 pub fn bind<T>(self, f: impl FnOnce(List<BoundVariableKind>, Vec<Expr>) -> T) -> Binder<T> {
210 let vars = List::from_vec(self.vars);
211 Binder::bind_with_vars(f(vars.clone(), self.preds), vars)
212 }
213}
214
215impl HoisterDelegate for &mut LocalHoister {
216 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
217 ty_ctor.replace_bound_refts_with(|sort, mode, kind| {
218 let idx = self.vars.len();
219 self.vars
220 .push(BoundVariableKind::Refine(sort.clone(), mode, kind));
221 Expr::bvar(INNERMOST, BoundVar::from_usize(idx), kind)
222 })
223 }
224
225 fn hoist_constr(&mut self, pred: Expr) {
226 self.preds.push(pred);
227 }
228}
229
230impl PolyFnSig {
231 pub fn hoist_input_binders(&self) -> Self {
241 let original_vars = self.vars().to_vec();
242 let fn_sig = self.skip_binder_ref();
243 let mut delegate = LocalHoister { vars: original_vars, preds: fn_sig.requires().to_vec() };
244 let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
245
246 let inputs = fn_sig
247 .inputs()
248 .iter()
249 .map(|ty| hoister.hoist(ty))
250 .collect_vec();
251
252 delegate.bind(|_vars, mut preds| {
253 let mut keep_hole = true;
254 preds.retain(|pred| {
255 if let ExprKind::Hole(HoleKind::Pred) = pred.kind() {
256 std::mem::replace(&mut keep_hole, false)
257 } else {
258 true
259 }
260 });
261
262 FnSig::new(
263 fn_sig.safety,
264 fn_sig.abi,
265 preds.into(),
266 inputs.into(),
267 fn_sig.output().clone(),
268 )
269 })
270 }
271}
272
273impl Ty {
274 pub fn shallow_canonicalize(&self) -> CanonicalTy {
277 let mut delegate = LocalHoister::default();
278 let ty = self.shift_in_escaping(1);
279 let ty = Hoister::with_delegate(&mut delegate).hoist(&ty);
280 let constr_ty = delegate.bind(|_, preds| {
281 let pred = Expr::and_from_iter(preds);
282 CanonicalConstrTy { ty, pred }
283 });
284 if constr_ty.vars().is_empty() {
285 CanonicalTy::Constr(constr_ty.skip_binder().shift_out_escaping(1))
286 } else {
287 CanonicalTy::Exists(constr_ty)
288 }
289 }
290}
291
292#[derive(TypeVisitable, TypeFoldable)]
293pub struct CanonicalConstrTy {
294 ty: Ty,
299 pred: Expr,
300}
301
302impl CanonicalConstrTy {
303 pub fn ty(&self) -> Ty {
304 self.ty.clone()
305 }
306
307 pub fn pred(&self) -> Expr {
308 self.pred.clone()
309 }
310
311 pub fn to_ty(&self) -> Ty {
312 Ty::constr(self.pred(), self.ty())
313 }
314}
315
316pub enum CanonicalTy {
324 Constr(CanonicalConstrTy),
326 Exists(Binder<CanonicalConstrTy>),
328}
329
330impl CanonicalTy {
331 pub fn to_ty(&self) -> Ty {
332 match self {
333 CanonicalTy::Constr(constr_ty) => constr_ty.to_ty(),
334 CanonicalTy::Exists(poly_constr_ty) => {
335 Ty::exists(poly_constr_ty.as_ref().map(CanonicalConstrTy::to_ty))
336 }
337 }
338 }
339
340 pub fn as_ty_or_base(&self) -> TyOrBase {
341 match self {
342 CanonicalTy::Constr(constr_ty) => {
343 if let TyKind::Indexed(bty, idx) = constr_ty.ty.kind() {
344 let pred = if idx.is_unit() {
351 constr_ty.pred.clone()
352 } else {
353 Expr::and(&constr_ty.pred, Expr::eq(Expr::nu(), idx.shift_in_escaping(1)))
354 };
355 let sort = bty.sort();
356 let constr = SubsetTy::new(bty.shift_in_escaping(1), Expr::nu(), pred);
357 TyOrBase::Base(Binder::bind_with_sort(constr, sort))
358 } else {
359 TyOrBase::Ty(self.to_ty())
360 }
361 }
362 CanonicalTy::Exists(poly_constr_ty) => {
363 let constr = poly_constr_ty.as_ref().skip_binder();
364 if let TyKind::Indexed(bty, idx) = constr.ty.kind()
365 && idx.is_nu()
366 {
367 let ctor = poly_constr_ty
368 .as_ref()
369 .map(|constr| SubsetTy::new(bty.clone(), idx, &constr.pred));
370 TyOrBase::Base(ctor)
371 } else {
372 TyOrBase::Ty(self.to_ty())
373 }
374 }
375 }
376 }
377}
378
379mod pretty {
380 use super::*;
381 use crate::pretty::*;
382
383 impl Pretty for CanonicalConstrTy {
384 fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 w!(cx, f, "{{ {:?} | {:?} }}", &self.ty, &self.pred)
386 }
387 }
388
389 impl Pretty for CanonicalTy {
390 fn fmt(&self, cx: &PrettyCx, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391 match self {
392 CanonicalTy::Constr(constr) => w!(cx, f, "{:?}", constr),
393 CanonicalTy::Exists(poly_constr) => {
394 cx.with_bound_vars(poly_constr.vars(), || {
395 cx.fmt_bound_vars(false, "∃", poly_constr.vars(), ". ", f)?;
396 w!(cx, f, "{:?}", poly_constr.as_ref().skip_binder())
397 })
398 }
399 }
400 }
401 }
402
403 impl_debug_with_default_cx!(CanonicalTy, CanonicalConstrTy);
404}