// flux_refineck/type_env.rs

1mod place_ty;
2
3use std::{iter, ops::ControlFlow};
4
5use flux_common::{
6    bug,
7    dbg::{SpanTrace, debug_assert_eq3},
8    tracked_span_bug, tracked_span_dbg_assert_eq,
9};
10use flux_infer::{
11    fixpoint_encoding::KVarEncoding,
12    infer::{ConstrReason, InferCtxt, InferCtxtAt, InferCtxtRoot, InferResult},
13    refine_tree::Scope,
14};
15use flux_macros::DebugAsJson;
16use flux_middle::{
17    PlaceExt as _,
18    global_env::GlobalEnv,
19    pretty::{PrettyCx, PrettyNested},
20    queries::QueryResult,
21    rty::{
22        BaseTy, Binder, BoundReftKind, Ctor, Ensures, Expr, ExprKind, FnSig, GenericArg, HoleKind,
23        INNERMOST, Lambda, List, Loc, Mutability, Path, PtrKind, Region, SortCtor, SubsetTy, Ty,
24        TyKind, VariantIdx,
25        canonicalize::{Hoister, LocalHoister},
26        fold::{FallibleTypeFolder, TypeFoldable, TypeVisitable, TypeVisitor},
27        region_matching::{rty_match_regions, ty_match_regions},
28    },
29};
30use flux_rustc_bridge::{
31    self,
32    mir::{BasicBlock, Body, Local, LocalDecl, LocalDecls, Place, PlaceElem},
33    ty,
34};
35use itertools::{Itertools, izip};
36use rustc_data_structures::unord::UnordMap;
37use rustc_index::{IndexSlice, IndexVec};
38use rustc_middle::{mir::RETURN_PLACE, ty::TyCtxt};
39use rustc_span::{Span, Symbol};
40use rustc_type_ir::BoundVar;
41use serde::Serialize;
42
43use self::place_ty::{LocKind, PlacesTree};
44use super::rty::Sort;
45
/// The typing environment used while checking a function body.
///
/// It maps places/locations of the body to their current refined types, and keeps the MIR
/// local declarations of the body so the unrefined Rust type of a place can be recovered.
#[derive(Clone, Default)]
pub struct TypeEnv<'a> {
    /// Current refined type of every tracked location (locals, unfolded pointers, ...).
    bindings: PlacesTree,
    /// MIR declarations of the locals of the body being checked.
    local_decls: &'a LocalDecls,
}
51
/// The "shape" of the environment at the entry of a basic block, computed while inferring
/// basic block environments. Shapes from different predecessors are combined with
/// [`BasicBlockEnvShape::join`], and the final shape is turned into a [`BasicBlockEnv`] with
/// [`BasicBlockEnvShape::into_bb_env`].
pub struct BasicBlockEnvShape {
    /// Scope the shape's bindings were packed against.
    scope: Scope,
    /// Packed bindings (indices that mention variables free in `scope` are abstracted).
    bindings: PlacesTree,
}
56
/// The environment required at the entry of a basic block: a set of constraints and bindings
/// abstracted over bound refinement variables.
pub struct BasicBlockEnv {
    /// Constraints and bindings under a binder for the block's refinement parameters.
    data: Binder<BasicBlockEnvData>,
    /// Scope associated with this environment.
    scope: Scope,
}
61
/// Contents of a [`BasicBlockEnv`] under its binder.
#[derive(Debug)]
struct BasicBlockEnvData {
    /// Predicates that must hold on entry (checked on `goto`, assumed on entry).
    constrs: List<Expr>,
    /// Types of every location on entry to the block.
    bindings: PlacesTree,
}
67
68impl<'a> TypeEnv<'a> {
69    pub fn new(infcx: &mut InferCtxt, body: &'a Body, fn_sig: &FnSig) -> TypeEnv<'a> {
70        let mut env = TypeEnv { bindings: PlacesTree::default(), local_decls: &body.local_decls };
71
72        for requires in fn_sig.requires() {
73            infcx.assume_pred(requires);
74        }
75
76        for (local, ty) in body.args_iter().zip(fn_sig.inputs()) {
77            let ty = infcx.unpack(ty);
78            infcx.assume_invariants(&ty);
79            env.alloc_with_ty(local, ty);
80        }
81
82        for local in body.vars_and_temps_iter() {
83            env.alloc(local);
84        }
85
86        env.alloc(RETURN_PLACE);
87        env
88    }
89
90    pub fn empty() -> TypeEnv<'a> {
91        TypeEnv { bindings: PlacesTree::default(), local_decls: IndexSlice::empty() }
92    }
93
94    fn alloc_with_ty(&mut self, local: Local, ty: Ty) {
95        let ty = ty_match_regions(&ty, &self.local_decls[local].ty);
96        self.bindings.insert(local.into(), LocKind::Local, ty);
97    }
98
99    fn alloc(&mut self, local: Local) {
100        self.bindings
101            .insert(local.into(), LocKind::Local, Ty::uninit());
102    }
103
104    pub(crate) fn into_infer(self, scope: Scope) -> BasicBlockEnvShape {
105        BasicBlockEnvShape::new(scope, self)
106    }
107
108    pub(crate) fn lookup_rust_ty(&self, genv: GlobalEnv, place: &Place) -> QueryResult<ty::Ty> {
109        Ok(place.ty(genv, self.local_decls)?.ty)
110    }
111
112    pub(crate) fn lookup_place(
113        &mut self,
114        infcx: &mut InferCtxtAt,
115        place: &Place,
116    ) -> InferResult<Ty> {
117        let span = infcx.span;
118        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
119        Ok(result.ty)
120    }
121
122    pub(crate) fn get(&self, path: &Path) -> Ty {
123        self.bindings.get(path)
124    }
125
126    pub fn update_path(&mut self, path: &Path, new_ty: Ty, span: Span) {
127        self.bindings.lookup(path, span).update(new_ty);
128    }
129
130    /// When checking a borrow in the right hand side of an assignment `x = &'?n p`, we use the
131    /// annotated region `'?n` in the type of the result. This region will only be used temporarily
132    /// and then replaced by the region in the type of `x` after the assignment. See [`TypeEnv::assign`]
133    pub(crate) fn borrow(
134        &mut self,
135        infcx: &mut InferCtxtAt,
136        re: Region,
137        mutbl: Mutability,
138        place: &Place,
139    ) -> InferResult<Ty> {
140        let span = infcx.span;
141        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
142        if result.is_strg && mutbl == Mutability::Mut {
143            Ok(Ty::ptr(PtrKind::Mut(re), result.path()))
144        } else {
145            // FIXME(nilehmann) we should block the place here. That would require a notion
146            // of shared vs mutable block types because sometimes blocked places from a shared
147            // reference never get unblocked and we should still allow reads through them.
148            Ok(Ty::mk_ref(re, result.ty, mutbl))
149        }
150    }
151
152    // FIXME(nilehmann) this is only used in a single place and we have it because [`TypeEnv`]
153    // doesn't expose a lookup without unfolding
154    pub(crate) fn ptr_to_ref_at_place(
155        &mut self,
156        infcx: &mut InferCtxtAt,
157        place: &Place,
158    ) -> InferResult {
159        let lookup = self.bindings.lookup(place, infcx.span);
160        let TyKind::Ptr(PtrKind::Mut(re), path) = lookup.ty.kind() else {
161            tracked_span_bug!("ptr_to_borrow called on non mutable pointer type")
162        };
163
164        let ref_ty =
165            self.ptr_to_ref(infcx, ConstrReason::Other, *re, path, PtrToRefBound::Infer)?;
166
167        self.bindings.lookup(place, infcx.span).update(ref_ty);
168
169        Ok(())
170    }
171
172    /// Convert a (strong) pointer to a mutable reference.
173    ///
174    /// This roughly implements the following inference rule:
175    /// ```text
176    ///                   t₁ <: t₂
177    /// -------------------------------------------------
178    /// Γ₁,ℓ:t1,Γ₂ ; ptr(mut, ℓ) => Γ₁,ℓ:†t₂,Γ₂ ; &mut t2
179    /// ```
180    /// That's it, we first get the current type `t₁` at location `ℓ` and check it is a subtype
181    /// of `t₂`. Then, we update the type of `ℓ` to `t₂` and block the place.
182    ///
183    /// The bound `t₂` can be either inferred ([`PtrToRefBound::Infer`]), explicitly provided
184    /// ([`PtrToRefBound::Ty`]), or made equal to `t₁` ([`PtrToRefBound::Identity`]).
185    ///
186    /// As an example, consider the environment `x: i32[a]` and the pointer `ptr(mut, x)`.
187    /// Converting the pointer to a mutable reference with an inferred bound produces the following
188    /// derivation (roughly):
189    ///
190    /// ```text
191    ///                    i32[a] <: i32{v: $k(v)}
192    /// ----------------------------------------------------------------
193    /// x: i32[a] ; ptr(mut, x) => x:†i32{v: $k(v)} ; &mut i32{v: $k(v)}
194    /// ```
195    pub(crate) fn ptr_to_ref(
196        &mut self,
197        infcx: &mut InferCtxtAt,
198        reason: ConstrReason,
199        re: Region,
200        path: &Path,
201        bound: PtrToRefBound,
202    ) -> InferResult<Ty> {
203        // ℓ: t1
204        let t1 = self.bindings.lookup(path, infcx.span).fold(infcx)?;
205
206        // t1 <: t2
207        let t2 = match bound {
208            PtrToRefBound::Ty(t2) => {
209                let t2 = rty_match_regions(&t2, &t1);
210                infcx.subtyping(&t1, &t2, reason)?;
211                t2
212            }
213            PtrToRefBound::Infer => {
214                let t2 = t1.with_holes().replace_holes(|sorts, kind| {
215                    debug_assert_eq!(kind, HoleKind::Pred);
216                    infcx.fresh_kvar(sorts, KVarEncoding::Conj)
217                });
218                infcx.subtyping(&t1, &t2, reason)?;
219                t2
220            }
221            PtrToRefBound::Identity => t1.clone(),
222        };
223
224        // ℓ: †t2
225        self.bindings
226            .lookup(path, infcx.span)
227            .block_with(t2.clone());
228
229        Ok(Ty::mk_ref(re, t2, Mutability::Mut))
230    }
231
232    pub(crate) fn fold_local_ptrs(&mut self, infcx: &mut InferCtxtAt) -> InferResult {
233        for (loc, bound, ty) in self.bindings.local_ptrs() {
234            infcx.subtyping(&ty, &bound, ConstrReason::FoldLocal)?;
235            self.bindings.remove_local(&loc);
236        }
237        Ok(())
238    }
239
240    /// Updates the type of `place` to `new_ty`. This may involve a *strong update* if we have
241    /// ownership of `place` or a *weak update* if it's behind a reference (which fires a subtyping
242    /// constraint)
243    ///
244    /// When strong updating, the process involves recovering the original regions (lifetimes) used
245    /// in the (unrefined) Rust type of `place` and then substituting these regions in `new_ty`. For
246    /// instance, if we are assigning a value of type `S<&'?10 i32{v: v > 0}>` to a variable `x`,
247    /// and the (unrefined) Rust type of `x` is `S<&'?5 i32>`, before the assignment, we identify a
248    /// substitution that maps the region `'?10` to `'?5`. After applying this substitution, the
249    /// type of the place `x` is updated accordingly. This ensures that the lifetimes in the
250    /// assigned type are consistent with those expected by the place's original type definition.
251    pub(crate) fn assign(
252        &mut self,
253        infcx: &mut InferCtxtAt,
254        place: &Place,
255        new_ty: Ty,
256    ) -> InferResult {
257        let rustc_ty = place.ty(infcx.genv, self.local_decls)?.ty;
258        let new_ty = ty_match_regions(&new_ty, &rustc_ty);
259        let span = infcx.span;
260        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
261        if result.is_strg {
262            result.update(new_ty);
263        } else if !place.behind_raw_ptr(infcx.genv, self.local_decls)? {
264            infcx.subtyping(&new_ty, &result.ty, ConstrReason::Assign)?;
265        }
266        Ok(())
267    }
268
269    pub(crate) fn move_place(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult<Ty> {
270        let span = infcx.span;
271        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
272        if result.is_strg {
273            let uninit = Ty::uninit();
274            Ok(result.update(uninit))
275        } else {
276            // ignore the 'move' and trust rustc managed the move correctly
277            // https://github.com/flux-rs/flux/issues/725#issuecomment-2295065634
278            Ok(result.ty)
279        }
280    }
281
282    pub(crate) fn unpack(&mut self, infcx: &mut InferCtxt) {
283        self.bindings
284            .fmap_mut(|_loc, ty| infcx.hoister(true).hoist(ty));
285    }
286
287    pub(crate) fn unblock(&mut self, infcx: &mut InferCtxt, place: &Place) {
288        self.bindings.unblock(infcx, place);
289    }
290
291    pub(crate) fn check_goto(
292        self,
293        infcx: &mut InferCtxtAt,
294        bb_env: &BasicBlockEnv,
295        target: BasicBlock,
296    ) -> InferResult {
297        infcx.ensure_resolved_evars(|infcx| {
298            let bb_env = bb_env
299                .data
300                .replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
301
302            // Check constraints
303            for constr in &bb_env.constrs {
304                infcx.check_pred(constr, ConstrReason::Goto(target));
305            }
306
307            // Check subtyping
308            let bb_env = bb_env.bindings.flatten();
309            for (path, _, ty2) in bb_env {
310                let ty1 = self.bindings.get(&path);
311                infcx.subtyping(&ty1.unblocked(), &ty2.unblocked(), ConstrReason::Goto(target))?;
312            }
313            Ok(())
314        })
315    }
316
317    pub(crate) fn fold(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult {
318        let span = infcx.span;
319        self.bindings.lookup(place, span).fold(infcx)?;
320        Ok(())
321    }
322
323    pub(crate) fn unfold_local_ptr(
324        &mut self,
325        infcx: &mut InferCtxt,
326        bound: &Ty,
327    ) -> InferResult<Loc> {
328        let name = infcx.define_unknown_var(&Sort::Loc);
329        let loc = Loc::from(name);
330        let ty = infcx.unpack(bound);
331        self.bindings
332            .insert(loc, LocKind::LocalPtr(bound.clone()), ty);
333        Ok(loc)
334    }
335
336    /// ```text
337    /// -----------------------------------
338    /// Γ ; &strg <ℓ: t> => Γ,ℓ: t ; ptr(ℓ)
339    /// ```
340    pub(crate) fn unfold_strg_ref(
341        &mut self,
342        infcx: &mut InferCtxt,
343        path: &Path,
344        ty: &Ty,
345    ) -> InferResult<Loc> {
346        if let Some(loc) = path.to_loc() {
347            let ty = infcx.unpack(ty);
348            self.bindings.insert(loc, LocKind::Universal, ty);
349            Ok(loc)
350        } else {
351            bug!("unfold_strg_ref: unexpected path {path:?}")
352        }
353    }
354
355    pub(crate) fn unfold(
356        &mut self,
357        infcx: &mut InferCtxt,
358        place: &Place,
359        span: Span,
360    ) -> InferResult {
361        self.bindings.unfold(infcx, place, span)
362    }
363
364    pub(crate) fn downcast(
365        &mut self,
366        infcx: &mut InferCtxtAt,
367        place: &Place,
368        variant_idx: VariantIdx,
369    ) -> InferResult {
370        let mut down_place = place.clone();
371        let span = infcx.span;
372        down_place
373            .projection
374            .push(PlaceElem::Downcast(None, variant_idx));
375        self.bindings.unfold(infcx, &down_place, span)?;
376        Ok(())
377    }
378
379    pub fn fully_resolve_evars(&mut self, infcx: &InferCtxt) {
380        self.bindings
381            .fmap_mut(|_loc, ty| infcx.fully_resolve_evars(ty));
382    }
383
384    pub(crate) fn assume_ensures(
385        &mut self,
386        infcx: &mut InferCtxt,
387        ensures: &[Ensures],
388        span: Span,
389    ) {
390        for ensure in ensures {
391            match ensure {
392                Ensures::Type(path, updated_ty) => {
393                    let updated_ty = infcx.unpack(updated_ty);
394                    infcx.assume_invariants(&updated_ty);
395                    self.update_path(path, updated_ty, span);
396                }
397                Ensures::Pred(e) => infcx.assume_pred(e),
398            }
399        }
400    }
401
402    pub(crate) fn check_ensures(
403        &mut self,
404        at: &mut InferCtxtAt,
405        ensures: &[Ensures],
406        reason: ConstrReason,
407    ) -> InferResult {
408        for constraint in ensures {
409            match constraint {
410                Ensures::Type(path, ty) => {
411                    let actual_ty = self.get(path).unblocked(); // HACK
412                    at.subtyping(&actual_ty, ty, reason)?;
413                }
414                Ensures::Pred(e) => {
415                    at.check_pred(e, ConstrReason::Ret);
416                }
417            }
418        }
419        Ok(())
420    }
421}
422
/// How the bound `t₂` is chosen when converting a pointer to a mutable reference in
/// [`TypeEnv::ptr_to_ref`].
pub(crate) enum PtrToRefBound {
    /// Use the given type as the bound. Its regions are matched against the current type first.
    Ty(Ty),
    /// Infer the bound by replacing predicate holes in the current type with fresh kvars.
    Infer,
    /// Use the current type itself as the bound (no subtyping constraint is generated).
    Identity,
}
428
429impl flux_infer::infer::LocEnv for TypeEnv<'_> {
430    fn ptr_to_ref(
431        &mut self,
432        infcx: &mut InferCtxtAt,
433        reason: ConstrReason,
434        re: Region,
435        path: &Path,
436        bound: Ty,
437    ) -> InferResult<Ty> {
438        self.ptr_to_ref(infcx, reason, re, path, PtrToRefBound::Ty(bound))
439    }
440
441    fn get(&self, path: &Path) -> Ty {
442        self.get(path)
443    }
444
445    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc> {
446        self.unfold_strg_ref(infcx, path, ty)
447    }
448}
449
450impl BasicBlockEnvShape {
451    pub fn enter<'a>(&self, local_decls: &'a LocalDecls) -> TypeEnv<'a> {
452        TypeEnv { bindings: self.bindings.clone(), local_decls }
453    }
454
455    fn new(scope: Scope, env: TypeEnv) -> BasicBlockEnvShape {
456        let mut bindings = env.bindings;
457        bindings.fmap_mut(|_loc, ty| BasicBlockEnvShape::pack_ty(&scope, ty));
458        BasicBlockEnvShape { scope, bindings }
459    }
460
461    fn pack_ty(scope: &Scope, ty: &Ty) -> Ty {
462        match ty.kind() {
463            TyKind::Indexed(bty, idxs) => {
464                let bty = BasicBlockEnvShape::pack_bty(scope, bty);
465                if scope.has_free_vars(idxs) {
466                    Ty::exists_with_constr(bty, Expr::hole(HoleKind::Pred))
467                } else {
468                    Ty::indexed(bty, idxs.clone())
469                }
470            }
471            TyKind::Downcast(adt, args, ty, variant, fields) => {
472                debug_assert!(!scope.has_free_vars(args));
473                debug_assert!(!scope.has_free_vars(ty));
474                let fields = fields.iter().map(|ty| Self::pack_ty(scope, ty)).collect();
475                Ty::downcast(adt.clone(), args.clone(), ty.clone(), *variant, fields)
476            }
477            TyKind::Blocked(ty) => Ty::blocked(BasicBlockEnvShape::pack_ty(scope, ty)),
478            // FIXME(nilehmann) [`TyKind::Exists`] could also contain free variables.
479            TyKind::Exists(_)
480            | TyKind::Discr(..)
481            | TyKind::Ptr(..)
482            | TyKind::Uninit
483            | TyKind::Param(_)
484            | TyKind::Constr(_, _) => ty.clone(),
485            TyKind::Infer(_) => bug!("unexpected hole whecn checking function body"),
486            TyKind::StrgRef(..) => bug!("unexpected strong reference when checking function body"),
487        }
488    }
489
490    fn pack_bty(scope: &Scope, bty: &BaseTy) -> BaseTy {
491        match bty {
492            BaseTy::Adt(adt_def, args) => {
493                let args = List::from_vec(
494                    args.iter()
495                        .map(|arg| Self::pack_generic_arg(scope, arg))
496                        .collect(),
497                );
498                BaseTy::adt(adt_def.clone(), args)
499            }
500            BaseTy::FnDef(def_id, args) => {
501                let args = List::from_vec(
502                    args.iter()
503                        .map(|arg| Self::pack_generic_arg(scope, arg))
504                        .collect(),
505                );
506                BaseTy::fn_def(*def_id, args)
507            }
508            BaseTy::Tuple(tys) => {
509                let tys = tys
510                    .iter()
511                    .map(|ty| BasicBlockEnvShape::pack_ty(scope, ty))
512                    .collect();
513                BaseTy::Tuple(tys)
514            }
515            BaseTy::Slice(ty) => BaseTy::Slice(Self::pack_ty(scope, ty)),
516            BaseTy::Ref(r, ty, mutbl) => BaseTy::Ref(*r, Self::pack_ty(scope, ty), *mutbl),
517            BaseTy::Array(ty, c) => BaseTy::Array(Self::pack_ty(scope, ty), c.clone()),
518            BaseTy::Int(_)
519            | BaseTy::Param(_)
520            | BaseTy::Uint(_)
521            | BaseTy::Bool
522            | BaseTy::Float(_)
523            | BaseTy::Str
524            | BaseTy::RawPtr(_, _)
525            | BaseTy::RawPtrMetadata(_)
526            | BaseTy::Char
527            | BaseTy::Never
528            | BaseTy::Closure(..)
529            | BaseTy::Dynamic(..)
530            | BaseTy::Alias(..)
531            | BaseTy::FnPtr(..)
532            | BaseTy::Foreign(..)
533            | BaseTy::Coroutine(..) => {
534                if scope.has_free_vars(bty) {
535                    tracked_span_bug!("unexpected type with free vars")
536                } else {
537                    bty.clone()
538                }
539            }
540            BaseTy::Pat => {
541                todo!()
542            }
543            BaseTy::Infer(..) => {
544                tracked_span_bug!("unexpected infer type")
545            }
546        }
547    }
548
549    fn pack_generic_arg(scope: &Scope, arg: &GenericArg) -> GenericArg {
550        match arg {
551            GenericArg::Ty(ty) => GenericArg::Ty(Self::pack_ty(scope, ty)),
552            GenericArg::Base(arg) => {
553                assert!(!scope.has_free_vars(arg));
554                GenericArg::Base(arg.clone())
555            }
556            GenericArg::Lifetime(re) => GenericArg::Lifetime(*re),
557            GenericArg::Const(c) => GenericArg::Const(c.clone()),
558        }
559    }
560
561    fn update(&mut self, path: &Path, ty: Ty, span: Span) {
562        self.bindings.lookup(path, span).update(ty);
563    }
564
565    /// join(self, genv, other) consumes the bindings in other, to "update"
566    /// `self` in place, and returns `true` if there was an actual change
567    /// or `false` indicating no change (i.e., a fixpoint was reached).
568    pub(crate) fn join(&mut self, other: TypeEnv, span: Span) -> bool {
569        let paths = self.bindings.paths();
570
571        // Join types
572        let mut modified = false;
573        for path in &paths {
574            let ty1 = self.bindings.get(path);
575            let ty2 = other.bindings.get(path);
576            let ty = if ty1 == ty2 { ty1.clone() } else { self.join_ty(&ty1, &ty2) };
577            modified |= ty1 != ty;
578            self.update(path, ty, span);
579        }
580
581        modified
582    }
583
584    fn join_ty(&self, ty1: &Ty, ty2: &Ty) -> Ty {
585        match (ty1.kind(), ty2.kind()) {
586            (TyKind::Blocked(ty1), _) => Ty::blocked(self.join_ty(ty1, &ty2.unblocked())),
587            (_, TyKind::Blocked(ty2)) => Ty::blocked(self.join_ty(&ty1.unblocked(), ty2)),
588            (TyKind::Uninit, _) | (_, TyKind::Uninit) => Ty::uninit(),
589            (TyKind::Exists(ty1), _) => self.join_ty(ty1.as_ref().skip_binder(), ty2),
590            (_, TyKind::Exists(ty2)) => self.join_ty(ty1, ty2.as_ref().skip_binder()),
591            (TyKind::Constr(_, ty1), _) => self.join_ty(ty1, ty2),
592            (_, TyKind::Constr(_, ty2)) => self.join_ty(ty1, ty2),
593            (TyKind::Indexed(bty1, idx1), TyKind::Indexed(bty2, idx2)) => {
594                let bty = self.join_bty(bty1, bty2);
595                let mut sorts = vec![];
596                let idx = self.join_idx(idx1, idx2, &bty.sort(), &mut sorts);
597                if sorts.is_empty() {
598                    Ty::indexed(bty, idx)
599                } else {
600                    let ty = Ty::constr(Expr::hole(HoleKind::Pred), Ty::indexed(bty, idx));
601                    Ty::exists(Binder::bind_with_sorts(ty, &sorts))
602                }
603            }
604            (TyKind::Ptr(rk1, path1), TyKind::Ptr(rk2, path2)) => {
605                debug_assert_eq!(rk1, rk2);
606                debug_assert_eq!(path1, path2);
607                Ty::ptr(*rk1, path1.clone())
608            }
609            (TyKind::Param(param_ty1), TyKind::Param(param_ty2)) => {
610                debug_assert_eq!(param_ty1, param_ty2);
611                Ty::param(*param_ty1)
612            }
613            (
614                TyKind::Downcast(adt1, args1, ty1, variant1, fields1),
615                TyKind::Downcast(adt2, args2, ty2, variant2, fields2),
616            ) => {
617                debug_assert_eq!(adt1, adt2);
618                debug_assert_eq!(args1, args2);
619                debug_assert!(ty1 == ty2 && !self.scope.has_free_vars(ty2));
620                debug_assert_eq!(variant1, variant2);
621                debug_assert_eq!(fields1.len(), fields2.len());
622                let fields = iter::zip(fields1, fields2)
623                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
624                    .collect();
625                Ty::downcast(adt1.clone(), args1.clone(), ty1.clone(), *variant1, fields)
626            }
627            _ => tracked_span_bug!("unexpected types: `{ty1:?}` - `{ty2:?}`"),
628        }
629    }
630
631    fn join_idx(&self, e1: &Expr, e2: &Expr, sort: &Sort, bound_sorts: &mut Vec<Sort>) -> Expr {
632        match (e1.kind(), e2.kind(), sort) {
633            (ExprKind::Tuple(es1), ExprKind::Tuple(es2), Sort::Tuple(sorts)) => {
634                debug_assert_eq3!(es1.len(), es2.len(), sorts.len());
635                Expr::tuple(
636                    izip!(es1, es2, sorts)
637                        .map(|(e1, e2, sort)| self.join_idx(e1, e2, sort, bound_sorts))
638                        .collect(),
639                )
640            }
641            (
642                ExprKind::Ctor(Ctor::Struct(_), flds1),
643                ExprKind::Ctor(Ctor::Struct(_), flds2),
644                Sort::App(SortCtor::Adt(sort_def), args),
645            ) => {
646                let sorts = sort_def.struct_variant().field_sorts(args);
647                debug_assert_eq3!(flds1.len(), flds2.len(), sorts.len());
648
649                Expr::ctor_struct(
650                    sort_def.did(),
651                    izip!(flds1, flds2, &sorts)
652                        .map(|(f1, f2, sort)| self.join_idx(f1, f2, sort, bound_sorts))
653                        .collect(),
654                )
655            }
656            _ => {
657                let has_free_vars2 = self.scope.has_free_vars(e2);
658                let has_escaping_vars1 = e1.has_escaping_bvars();
659                let has_escaping_vars2 = e2.has_escaping_bvars();
660                if !has_free_vars2 && !has_escaping_vars1 && !has_escaping_vars2 && e1 == e2 {
661                    e1.clone()
662                } else if sort.is_pred() {
663                    // FIXME(nilehmann) we shouldn't special case predicates here. Instead, we
664                    // should differentiate between generics and indices.
665                    let fsort = sort.expect_func().expect_mono();
666                    Expr::abs(Lambda::bind_with_fsort(Expr::hole(HoleKind::Pred), fsort))
667                } else {
668                    bound_sorts.push(sort.clone());
669                    Expr::bvar(
670                        INNERMOST,
671                        BoundVar::from_usize(bound_sorts.len() - 1),
672                        BoundReftKind::Anon,
673                    )
674                }
675            }
676        }
677    }
678
679    fn join_bty(&self, bty1: &BaseTy, bty2: &BaseTy) -> BaseTy {
680        match (bty1, bty2) {
681            (BaseTy::Adt(def1, args1), BaseTy::Adt(def2, args2)) => {
682                tracked_span_dbg_assert_eq!(def1.did(), def2.did());
683                let args = iter::zip(args1, args2)
684                    .map(|(arg1, arg2)| self.join_generic_arg(arg1, arg2))
685                    .collect();
686                BaseTy::adt(def1.clone(), List::from_vec(args))
687            }
688            (BaseTy::Tuple(fields1), BaseTy::Tuple(fields2)) => {
689                let fields = iter::zip(fields1, fields2)
690                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
691                    .collect();
692                BaseTy::Tuple(fields)
693            }
694            (BaseTy::Alias(kind1, alias_ty1), BaseTy::Alias(kind2, alias_ty2)) => {
695                tracked_span_dbg_assert_eq!(kind1, kind2);
696                tracked_span_dbg_assert_eq!(alias_ty1, alias_ty2);
697                BaseTy::Alias(*kind1, alias_ty1.clone())
698            }
699            (BaseTy::Ref(r1, ty1, mutbl1), BaseTy::Ref(r2, ty2, mutbl2)) => {
700                tracked_span_dbg_assert_eq!(r1, r2);
701                tracked_span_dbg_assert_eq!(mutbl1, mutbl2);
702                BaseTy::Ref(*r1, self.join_ty(ty1, ty2), *mutbl1)
703            }
704            (BaseTy::Array(ty1, len1), BaseTy::Array(ty2, len2)) => {
705                tracked_span_dbg_assert_eq!(len1, len2);
706                BaseTy::Array(self.join_ty(ty1, ty2), len1.clone())
707            }
708            (BaseTy::Slice(ty1), BaseTy::Slice(ty2)) => BaseTy::Slice(self.join_ty(ty1, ty2)),
709            _ => {
710                tracked_span_dbg_assert_eq!(bty1, bty2);
711                bty1.clone()
712            }
713        }
714    }
715
716    fn join_generic_arg(&self, arg1: &GenericArg, arg2: &GenericArg) -> GenericArg {
717        match (arg1, arg2) {
718            (GenericArg::Ty(ty1), GenericArg::Ty(ty2)) => GenericArg::Ty(self.join_ty(ty1, ty2)),
719            (GenericArg::Base(ctor1), GenericArg::Base(ctor2)) => {
720                let sty1 = ctor1.as_ref().skip_binder();
721                let sty2 = ctor2.as_ref().skip_binder();
722                debug_assert!(sty1.idx.is_nu());
723                debug_assert!(sty2.idx.is_nu());
724
725                let bty = self.join_bty(&sty1.bty, &sty2.bty);
726                let pred = if self.scope.has_free_vars(&sty2.pred) || sty1.pred != sty2.pred {
727                    Expr::hole(HoleKind::Pred)
728                } else {
729                    sty1.pred.clone()
730                };
731                let sort = bty.sort();
732                let ctor = Binder::bind_with_sort(SubsetTy::new(bty, Expr::nu(), pred), sort);
733                GenericArg::Base(ctor)
734            }
735            (GenericArg::Lifetime(re1), GenericArg::Lifetime(_re2)) => {
736                // TODO(nilehmann) loop_abstract_refinement.rs is triggering this assertion to fail
737                // wee should fix it.
738                // debug_assert_eq!(re1, _re2);
739                GenericArg::Lifetime(*re1)
740            }
741            (GenericArg::Const(c1), GenericArg::Const(c2)) => {
742                debug_assert_eq!(c1, c2);
743                GenericArg::Const(c1.clone())
744            }
745            _ => tracked_span_bug!("unexpected generic args: `{arg1:?}` - `{arg2:?}`"),
746        }
747    }
748
    /// Turns this shape into a [`BasicBlockEnv`] whose bindings are closed under a single binder.
    ///
    /// The type of every binding is run through a `Hoister` with a `LocalHoister` delegate, which
    /// collects hoisted variables and predicates; `delegate.bind` then closes over them. Inside
    /// the binder:
    /// - the constraints are the hoisted predicates with every bare predicate hole removed, plus
    ///   one fresh kvar over all hoisted variables (this kvar is added unconditionally), and
    /// - any holes still present inside the bindings are replaced by fresh kvars.
    ///
    /// All kvars are created in `self.scope`, which also becomes the scope of the result.
    pub fn into_bb_env(self, infcx: &mut InferCtxtRoot, body: &Body) -> BasicBlockEnv {
        let mut delegate = LocalHoister::default();
        let mut hoister = Hoister::with_delegate(&mut delegate).transparent();

        let mut bindings = self.bindings;
        bindings.fmap_mut(|loc, ty| {
            // For MIR locals, hand the source-level name (if any) to the hoister delegate before
            // hoisting this binding's type. NOTE(review): presumably used to give the hoisted
            // variables readable names; confirm in `LocalHoister`.
            let name = if let Loc::Local(local) = loc {
                body.local_names.get(local).copied()
            } else {
                None
            };
            hoister.delegate.name = name;
            hoister.hoist(ty)
        });

        BasicBlockEnv {
            // We are relying on all the types in `bindings` not having escaping bvars, otherwise
            // we would have to shift them in since we are creating a new binder.
            data: delegate.bind(|vars, preds| {
                // Drop every predicate that is just a predicate hole and, in their place, add a
                // single fresh kvar over all hoisted variables. The kvar is added even if no
                // hole was dropped.
                let mut constrs = preds
                    .into_iter()
                    .filter(|pred| !matches!(pred.kind(), ExprKind::Hole(HoleKind::Pred)))
                    .collect_vec();
                let kvar = infcx.fresh_kvar_in_scope(
                    std::slice::from_ref(&vars),
                    &self.scope,
                    KVarEncoding::Conj,
                );
                constrs.push(kvar);

                // Replace the holes remaining inside the bindings with fresh kvars. Each kvar's
                // binders are the hoisted variables followed by the binders `replace_holes`
                // passes in (presumably those enclosing the hole). Only predicate holes are
                // expected here; this is checked in debug builds only.
                let mut kvar_gen = |binders: &[_], kind| {
                    debug_assert_eq!(kind, HoleKind::Pred);
                    let binders = std::iter::once(vars.clone())
                        .chain(binders.iter().cloned())
                        .collect_vec();
                    infcx.fresh_kvar_in_scope(&binders, &self.scope, KVarEncoding::Conj)
                };
                bindings.fmap_mut(|_, binding| binding.replace_holes(&mut kvar_gen));

                BasicBlockEnvData { constrs: constrs.into(), bindings }
            }),
            scope: self.scope,
        }
    }
795}
796
// NOTE(review): visiting `BasicBlockEnvData` is not supported; any call to `visit_with`, direct
// or through a generic traversal, panics. Presumably this impl exists only to satisfy a trait
// bound (e.g. `TypeFoldable` requiring `TypeVisitable`). Confirm before relying on it.
impl TypeVisitable for BasicBlockEnvData {
    fn visit_with<V: TypeVisitor>(&self, _visitor: &mut V) -> ControlFlow<V::BreakTy> {
        unimplemented!()
    }
}
802
803impl TypeFoldable for BasicBlockEnvData {
804    fn try_fold_with<F: FallibleTypeFolder>(
805        &self,
806        folder: &mut F,
807    ) -> std::result::Result<Self, F::Error> {
808        Ok(BasicBlockEnvData {
809            constrs: self.constrs.try_fold_with(folder)?,
810            bindings: self.bindings.try_fold_with(folder)?,
811        })
812    }
813}
814
815impl BasicBlockEnv {
816    pub(crate) fn enter<'a>(
817        &self,
818        infcx: &mut InferCtxt,
819        local_decls: &'a LocalDecls,
820    ) -> TypeEnv<'a> {
821        let data = self.data.replace_bound_refts_with(|sort, _, kind| {
822            Expr::fvar(infcx.define_bound_reft_var(sort, kind))
823        });
824        for constr in &data.constrs {
825            infcx.assume_pred(constr);
826        }
827        TypeEnv { bindings: data.bindings, local_decls }
828    }
829
830    pub(crate) fn scope(&self) -> &Scope {
831        &self.scope
832    }
833}
834
835mod pretty {
836    use std::fmt;
837
838    use flux_middle::pretty::*;
839
840    use super::*;
841
842    impl Pretty for TypeEnv<'_> {
843        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
844            w!(cx, f, "{:?}", &self.bindings)
845        }
846
847        fn default_cx(tcx: TyCtxt) -> PrettyCx {
848            PlacesTree::default_cx(tcx)
849        }
850    }
851
852    impl Pretty for BasicBlockEnvShape {
853        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
854            w!(cx, f, "{:?} {:?}", &self.scope, &self.bindings)
855        }
856
857        fn default_cx(tcx: TyCtxt) -> PrettyCx {
858            PlacesTree::default_cx(tcx)
859        }
860    }
861
862    impl Pretty for BasicBlockEnv {
863        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
864            w!(cx, f, "{:?} ", &self.scope)?;
865
866            let vars = self.data.vars();
867            cx.with_bound_vars(vars, || {
868                if !vars.is_empty() {
869                    cx.fmt_bound_vars(true, "for<", vars, "> ", f)?;
870                }
871                let data = self.data.as_ref().skip_binder();
872                if !data.constrs.is_empty() {
873                    w!(
874                        cx,
875                        f,
876                        "{:?} ⇒ ",
877                        join!(", ", data.constrs.iter().filter(|pred| !pred.is_trivially_true()))
878                    )?;
879                }
880                w!(cx, f, "{:?}", &data.bindings)
881            })
882        }
883
884        fn default_cx(tcx: TyCtxt) -> PrettyCx {
885            PlacesTree::default_cx(tcx)
886        }
887    }
888
889    impl_debug_with_default_cx! {
890        TypeEnv<'_> => "type_env",
891        BasicBlockEnvShape => "basic_block_env_shape",
892        BasicBlockEnv => "basic_block_env"
893    }
894}
895
/// A very explicit representation of [`TypeEnv`] for debugging/tracing/serialization ONLY.
///
/// It holds one entry per binding that is not uninitialized, ordered by location, as built by
/// `TypeEnvTrace::new`.
#[derive(Serialize, DebugAsJson)]
pub struct TypeEnvTrace(Vec<TypeEnvBind>);
899
/// A single binding of a [`TypeEnvTrace`], with every component already rendered to text.
#[derive(Serialize)]
struct TypeEnvBind {
    /// The bound location, tagged as a local or a variable.
    local: LocInfo,
    /// Source-level name of the location; `None` unless it is a named local.
    name: Option<String>,
    /// `Debug` rendering of the binding's kind.
    kind: String,
    /// The binding's type, pretty-printed in nested form.
    ty: String,
    /// Span of the local's declaration; `None` for non-locals or unknown locals.
    span: Option<SpanTrace>,
}
908
/// `Debug` rendering of a [`Loc`], tagged with the kind of location it came from.
#[derive(Serialize)]
enum LocInfo {
    /// Rendering of a `Loc::Local`.
    Local(String),
    /// Rendering of a `Loc::Var`.
    Var(String),
}
914
915fn loc_info(loc: &Loc) -> LocInfo {
916    match loc {
917        Loc::Local(local) => LocInfo::Local(format!("{local:?}")),
918        Loc::Var(var) => LocInfo::Var(format!("{var:?}")),
919    }
920}
921
922fn loc_name(local_names: &UnordMap<Local, Symbol>, loc: &Loc) -> Option<String> {
923    if let Loc::Local(local) = loc {
924        let name = local_names.get(local)?;
925        return Some(format!("{name}"));
926    }
927    None
928}
929
930fn loc_span(
931    genv: GlobalEnv,
932    local_decls: &IndexVec<Local, LocalDecl>,
933    loc: &Loc,
934) -> Option<SpanTrace> {
935    if let Loc::Local(local) = loc {
936        return local_decls
937            .get(*local)
938            .map(|local_decl| SpanTrace::new(genv.tcx(), local_decl.source_info.span));
939    }
940    None
941}
942
943impl TypeEnvTrace {
944    pub fn new(
945        genv: GlobalEnv,
946        local_names: &UnordMap<Local, Symbol>,
947        local_decls: &IndexVec<Local, LocalDecl>,
948        cx: PrettyCx,
949        env: &TypeEnv,
950    ) -> Self {
951        let mut bindings = vec![];
952        env.bindings
953            .iter()
954            .filter(|(_, binding)| !binding.ty.is_uninit())
955            .sorted_by(|(loc1, _), (loc2, _)| loc1.cmp(loc2))
956            .for_each(|(loc, binding)| {
957                let name = loc_name(local_names, loc);
958                let local = loc_info(loc);
959                let kind = format!("{:?}", binding.kind);
960                let ty = binding.ty.nested_string(&cx);
961                let span = loc_span(genv, local_decls, loc);
962                bindings.push(TypeEnvBind { name, local, kind, ty, span });
963            });
964
965        TypeEnvTrace(bindings)
966    }
967}