flux_refineck/
type_env.rs

1mod place_ty;
2
3use std::{iter, ops::ControlFlow};
4
5use flux_common::{bug, dbg::debug_assert_eq3, tracked_span_bug, tracked_span_dbg_assert_eq};
6use flux_infer::{
7    fixpoint_encoding::KVarEncoding,
8    infer::{ConstrReason, InferCtxt, InferCtxtAt, InferCtxtRoot, InferResult},
9    refine_tree::Scope,
10};
11use flux_macros::DebugAsJson;
12use flux_middle::{
13    PlaceExt as _,
14    global_env::GlobalEnv,
15    pretty::{PrettyCx, PrettyNested},
16    queries::QueryResult,
17    rty::{
18        BaseTy, Binder, BoundReftKind, Ctor, Ensures, Expr, ExprKind, FnSig, GenericArg, HoleKind,
19        INNERMOST, Lambda, List, Loc, Mutability, Path, PtrKind, Region, SortCtor, SubsetTy, Ty,
20        TyKind, VariantIdx,
21        canonicalize::{Hoister, LocalHoister},
22        fold::{FallibleTypeFolder, TypeFoldable, TypeVisitable, TypeVisitor},
23        region_matching::{rty_match_regions, ty_match_regions},
24    },
25};
26use flux_rustc_bridge::{
27    self,
28    mir::{BasicBlock, Body, Local, LocalDecl, LocalDecls, Place, PlaceElem},
29    ty,
30};
31use itertools::{Itertools, izip};
32use rustc_data_structures::unord::UnordMap;
33use rustc_index::{IndexSlice, IndexVec};
34use rustc_middle::{mir::RETURN_PLACE, ty::TyCtxt};
35use rustc_span::{Span, Symbol};
36use rustc_type_ir::BoundVar;
37use serde::Serialize;
38
39use self::place_ty::{LocKind, PlacesTree};
40use super::rty::Sort;
41
/// The refinement type environment used while checking a function body.
///
/// `bindings` maps places/locations to their current (refined) types, while `local_decls` holds
/// the MIR local declarations used to recover the unrefined Rust type of a place.
#[derive(Clone, Default)]
pub struct TypeEnv<'a> {
    bindings: PlacesTree,
    local_decls: &'a LocalDecls,
}
47
/// The shape of a basic-block environment while it is being inferred.
///
/// Its bindings are kept packed relative to `scope`, and environments reaching the block are
/// joined into it until a fixpoint is reached.
pub struct BasicBlockEnvShape {
    scope: Scope,
    bindings: PlacesTree,
}
52
/// The final environment at the entry of a basic block: constraints and bindings under a binder
/// of refinement variables, together with the scope the block was inferred in.
pub struct BasicBlockEnv {
    data: Binder<BasicBlockEnvData>,
    scope: Scope,
}
57
/// The contents of a [`BasicBlockEnv`] underneath its binder.
#[derive(Debug)]
struct BasicBlockEnvData {
    // Predicates assumed when entering the block (and checked when jumping to it).
    constrs: List<Expr>,
    bindings: PlacesTree,
}
63
64impl<'a> TypeEnv<'a> {
65    pub fn new(infcx: &mut InferCtxt, body: &'a Body, fn_sig: &FnSig) -> TypeEnv<'a> {
66        let mut env = TypeEnv { bindings: PlacesTree::default(), local_decls: &body.local_decls };
67
68        for requires in fn_sig.requires() {
69            infcx.assume_pred(requires);
70        }
71
72        for (local, ty) in body.args_iter().zip(fn_sig.inputs()) {
73            let ty = infcx.unpack(ty);
74            infcx.assume_invariants(&ty);
75            env.alloc_with_ty(local, ty);
76        }
77
78        for local in body.vars_and_temps_iter() {
79            env.alloc(local);
80        }
81
82        env.alloc(RETURN_PLACE);
83        env
84    }
85
86    pub fn empty() -> TypeEnv<'a> {
87        TypeEnv { bindings: PlacesTree::default(), local_decls: IndexSlice::empty() }
88    }
89
90    fn alloc_with_ty(&mut self, local: Local, ty: Ty) {
91        let ty = ty_match_regions(&ty, &self.local_decls[local].ty);
92        self.bindings.insert(local.into(), LocKind::Local, ty);
93    }
94
95    fn alloc(&mut self, local: Local) {
96        self.bindings
97            .insert(local.into(), LocKind::Local, Ty::uninit());
98    }
99
100    pub(crate) fn into_infer(self, scope: Scope) -> BasicBlockEnvShape {
101        BasicBlockEnvShape::new(scope, self)
102    }
103
104    pub(crate) fn lookup_rust_ty(&self, genv: GlobalEnv, place: &Place) -> QueryResult<ty::Ty> {
105        Ok(place.ty(genv, self.local_decls)?.ty)
106    }
107
108    pub(crate) fn lookup_place(
109        &mut self,
110        infcx: &mut InferCtxtAt,
111        place: &Place,
112    ) -> InferResult<Ty> {
113        let span = infcx.span;
114        Ok(self.bindings.lookup_unfolding(infcx, place, span)?.ty)
115    }
116
117    pub(crate) fn get(&self, path: &Path) -> Ty {
118        self.bindings.get(path)
119    }
120
121    pub fn update_path(&mut self, path: &Path, new_ty: Ty, span: Span) {
122        self.bindings.lookup(path, span).update(new_ty);
123    }
124
125    /// When checking a borrow in the right hand side of an assignment `x = &'?n p`, we use the
126    /// annotated region `'?n` in the type of the result. This region will only be used temporarily
127    /// and then replaced by the region in the type of `x` after the assignment. See [`TypeEnv::assign`]
128    pub(crate) fn borrow(
129        &mut self,
130        infcx: &mut InferCtxtAt,
131        re: Region,
132        mutbl: Mutability,
133        place: &Place,
134    ) -> InferResult<Ty> {
135        let span = infcx.span;
136        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
137        if result.is_strg && mutbl == Mutability::Mut {
138            Ok(Ty::ptr(PtrKind::Mut(re), result.path()))
139        } else {
140            // FIXME(nilehmann) we should block the place here. That would require a notion
141            // of shared vs mutable block types because sometimes blocked places from a shared
142            // reference never get unblocked and we should still allow reads through them.
143            Ok(Ty::mk_ref(re, result.ty, mutbl))
144        }
145    }
146
147    // FIXME(nilehmann) this is only used in a single place and we have it because [`TypeEnv`]
148    // doesn't expose a lookup without unfolding
149    pub(crate) fn ptr_to_ref_at_place(
150        &mut self,
151        infcx: &mut InferCtxtAt,
152        place: &Place,
153    ) -> InferResult {
154        let lookup = self.bindings.lookup(place, infcx.span);
155        let TyKind::Ptr(PtrKind::Mut(re), path) = lookup.ty.kind() else {
156            tracked_span_bug!("ptr_to_borrow called on non mutable pointer type")
157        };
158
159        let ref_ty =
160            self.ptr_to_ref(infcx, ConstrReason::Other, *re, path, PtrToRefBound::Infer)?;
161
162        self.bindings.lookup(place, infcx.span).update(ref_ty);
163
164        Ok(())
165    }
166
167    /// Convert a (strong) pointer to a mutable reference.
168    ///
169    /// This roughly implements the following inference rule:
170    /// ```text
171    ///                   t₁ <: t₂
172    /// -------------------------------------------------
173    /// Γ₁,ℓ:t1,Γ₂ ; ptr(mut, ℓ) => Γ₁,ℓ:†t₂,Γ₂ ; &mut t2
174    /// ```
175    /// That's it, we first get the current type `t₁` at location `ℓ` and check it is a subtype
176    /// of `t₂`. Then, we update the type of `ℓ` to `t₂` and block the place.
177    ///
178    /// The bound `t₂` can be either inferred ([`PtrToRefBound::Infer`]), explicitly provided
179    /// ([`PtrToRefBound::Ty`]), or made equal to `t₁` ([`PtrToRefBound::Identity`]).
180    ///
181    /// As an example, consider the environment `x: i32[a]` and the pointer `ptr(mut, x)`.
182    /// Converting the pointer to a mutable reference with an inferred bound produces the following
183    /// derivation (roughly):
184    ///
185    /// ```text
186    ///                    i32[a] <: i32{v: $k(v)}
187    /// ----------------------------------------------------------------
188    /// x: i32[a] ; ptr(mut, x) => x:†i32{v: $k(v)} ; &mut i32{v: $k(v)}
189    /// ```
190    pub(crate) fn ptr_to_ref(
191        &mut self,
192        infcx: &mut InferCtxtAt,
193        reason: ConstrReason,
194        re: Region,
195        path: &Path,
196        bound: PtrToRefBound,
197    ) -> InferResult<Ty> {
198        // ℓ: t1
199        let t1 = self.bindings.lookup(path, infcx.span).fold(infcx)?;
200
201        // t1 <: t2
202        let t2 = match bound {
203            PtrToRefBound::Ty(t2) => {
204                let t2 = rty_match_regions(&t2, &t1);
205                infcx.subtyping(&t1, &t2, reason)?;
206                t2
207            }
208            PtrToRefBound::Infer => {
209                let t2 = t1.with_holes().replace_holes(|sorts, kind| {
210                    debug_assert_eq!(kind, HoleKind::Pred);
211                    infcx.fresh_kvar(sorts, KVarEncoding::Conj)
212                });
213                infcx.subtyping(&t1, &t2, reason)?;
214                t2
215            }
216            PtrToRefBound::Identity => t1.clone(),
217        };
218
219        // ℓ: †t2
220        self.bindings
221            .lookup(path, infcx.span)
222            .block_with(t2.clone());
223
224        Ok(Ty::mk_ref(re, t2, Mutability::Mut))
225    }
226
227    pub(crate) fn fold_local_ptrs(&mut self, infcx: &mut InferCtxtAt) -> InferResult {
228        for (loc, bound, ty) in self.bindings.local_ptrs() {
229            infcx.subtyping(&ty, &bound, ConstrReason::FoldLocal)?;
230            self.bindings.remove_local(&loc);
231        }
232        Ok(())
233    }
234
235    /// Updates the type of `place` to `new_ty`. This may involve a *strong update* if we have
236    /// ownership of `place` or a *weak update* if it's behind a reference (which fires a subtyping
237    /// constraint)
238    ///
239    /// When strong updating, the process involves recovering the original regions (lifetimes) used
240    /// in the (unrefined) Rust type of `place` and then substituting these regions in `new_ty`. For
241    /// instance, if we are assigning a value of type `S<&'?10 i32{v: v > 0}>` to a variable `x`,
242    /// and the (unrefined) Rust type of `x` is `S<&'?5 i32>`, before the assignment, we identify a
243    /// substitution that maps the region `'?10` to `'?5`. After applying this substitution, the
244    /// type of the place `x` is updated accordingly. This ensures that the lifetimes in the
245    /// assigned type are consistent with those expected by the place's original type definition.
246    pub(crate) fn assign(
247        &mut self,
248        infcx: &mut InferCtxtAt,
249        place: &Place,
250        new_ty: Ty,
251    ) -> InferResult {
252        let rustc_ty = place.ty(infcx.genv, self.local_decls)?.ty;
253        let new_ty = ty_match_regions(&new_ty, &rustc_ty);
254        let span = infcx.span;
255        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
256        if result.is_strg {
257            result.update(new_ty);
258        } else if !place.behind_raw_ptr(infcx.genv, self.local_decls)? {
259            infcx.subtyping(&new_ty, &result.ty, ConstrReason::Assign)?;
260        }
261        Ok(())
262    }
263
264    pub(crate) fn move_place(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult<Ty> {
265        let span = infcx.span;
266        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
267        if result.is_strg {
268            let uninit = Ty::uninit();
269            Ok(result.update(uninit))
270        } else {
271            // ignore the 'move' and trust rustc managed the move correctly
272            // https://github.com/flux-rs/flux/issues/725#issuecomment-2295065634
273            Ok(result.ty)
274        }
275    }
276
277    pub(crate) fn unpack(&mut self, infcx: &mut InferCtxt) {
278        self.bindings.fmap_mut(|ty| infcx.hoister(true).hoist(ty));
279    }
280
281    pub(crate) fn unblock(&mut self, infcx: &mut InferCtxt, place: &Place) {
282        self.bindings.unblock(infcx, place);
283    }
284
285    pub(crate) fn check_goto(
286        self,
287        infcx: &mut InferCtxtAt,
288        bb_env: &BasicBlockEnv,
289        target: BasicBlock,
290    ) -> InferResult {
291        infcx.ensure_resolved_evars(|infcx| {
292            let bb_env = bb_env
293                .data
294                .replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
295
296            // Check constraints
297            for constr in &bb_env.constrs {
298                infcx.check_pred(constr, ConstrReason::Goto(target));
299            }
300
301            // Check subtyping
302            let bb_env = bb_env.bindings.flatten();
303            for (path, _, ty2) in bb_env {
304                let ty1 = self.bindings.get(&path);
305                infcx.subtyping(&ty1.unblocked(), &ty2.unblocked(), ConstrReason::Goto(target))?;
306            }
307            Ok(())
308        })
309    }
310
311    pub(crate) fn fold(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult {
312        let span = infcx.span;
313        self.bindings.lookup(place, span).fold(infcx)?;
314        Ok(())
315    }
316
317    pub(crate) fn unfold_local_ptr(
318        &mut self,
319        infcx: &mut InferCtxt,
320        bound: &Ty,
321    ) -> InferResult<Loc> {
322        let loc = Loc::from(infcx.define_var(&Sort::Loc));
323        let ty = infcx.unpack(bound);
324        self.bindings
325            .insert(loc, LocKind::LocalPtr(bound.clone()), ty);
326        Ok(loc)
327    }
328
329    /// ```text
330    /// -----------------------------------
331    /// Γ ; &strg <ℓ: t> => Γ,ℓ: t ; ptr(ℓ)
332    /// ```
333    pub(crate) fn unfold_strg_ref(
334        &mut self,
335        infcx: &mut InferCtxt,
336        path: &Path,
337        ty: &Ty,
338    ) -> InferResult<Loc> {
339        if let Some(loc) = path.to_loc() {
340            let ty = infcx.unpack(ty);
341            self.bindings.insert(loc, LocKind::Universal, ty);
342            Ok(loc)
343        } else {
344            bug!("unfold_strg_ref: unexpected path {path:?}")
345        }
346    }
347
348    pub(crate) fn unfold(
349        &mut self,
350        infcx: &mut InferCtxt,
351        place: &Place,
352        span: Span,
353    ) -> InferResult {
354        self.bindings.unfold(infcx, place, span)
355    }
356
357    pub(crate) fn downcast(
358        &mut self,
359        infcx: &mut InferCtxtAt,
360        place: &Place,
361        variant_idx: VariantIdx,
362    ) -> InferResult {
363        let mut down_place = place.clone();
364        let span = infcx.span;
365        down_place
366            .projection
367            .push(PlaceElem::Downcast(None, variant_idx));
368        self.bindings.unfold(infcx, &down_place, span)?;
369        Ok(())
370    }
371
372    pub fn fully_resolve_evars(&mut self, infcx: &InferCtxt) {
373        self.bindings.fmap_mut(|ty| infcx.fully_resolve_evars(ty));
374    }
375
376    pub(crate) fn assume_ensures(
377        &mut self,
378        infcx: &mut InferCtxt,
379        ensures: &[Ensures],
380        span: Span,
381    ) {
382        for ensure in ensures {
383            match ensure {
384                Ensures::Type(path, updated_ty) => {
385                    let updated_ty = infcx.unpack(updated_ty);
386                    infcx.assume_invariants(&updated_ty);
387                    self.update_path(path, updated_ty, span);
388                }
389                Ensures::Pred(e) => infcx.assume_pred(e),
390            }
391        }
392    }
393
394    pub(crate) fn check_ensures(
395        &mut self,
396        at: &mut InferCtxtAt,
397        ensures: &[Ensures],
398        reason: ConstrReason,
399    ) -> InferResult {
400        for constraint in ensures {
401            match constraint {
402                Ensures::Type(path, ty) => {
403                    let actual_ty = self.get(path);
404                    at.subtyping(&actual_ty, ty, reason)?;
405                }
406                Ensures::Pred(e) => {
407                    at.check_pred(e, ConstrReason::Ret);
408                }
409            }
410        }
411        Ok(())
412    }
413}
414
/// How the bound `t₂` is chosen when [`TypeEnv::ptr_to_ref`] converts a strong pointer into a
/// mutable reference.
pub(crate) enum PtrToRefBound {
    /// Use the given type as the bound, after matching its regions against the current type.
    Ty(Ty),
    /// Infer the bound by replacing the holes of the current type with fresh kvars.
    Infer,
    /// Use the current type itself as the bound (no subtyping constraint is generated).
    Identity,
}
420
421impl flux_infer::infer::LocEnv for TypeEnv<'_> {
422    fn ptr_to_ref(
423        &mut self,
424        infcx: &mut InferCtxtAt,
425        reason: ConstrReason,
426        re: Region,
427        path: &Path,
428        bound: Ty,
429    ) -> InferResult<Ty> {
430        self.ptr_to_ref(infcx, reason, re, path, PtrToRefBound::Ty(bound))
431    }
432
433    fn get(&self, path: &Path) -> Ty {
434        self.get(path)
435    }
436
437    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc> {
438        self.unfold_strg_ref(infcx, path, ty)
439    }
440}
441
442impl BasicBlockEnvShape {
443    pub fn enter<'a>(&self, local_decls: &'a LocalDecls) -> TypeEnv<'a> {
444        TypeEnv { bindings: self.bindings.clone(), local_decls }
445    }
446
447    fn new(scope: Scope, env: TypeEnv) -> BasicBlockEnvShape {
448        let mut bindings = env.bindings;
449        bindings.fmap_mut(|ty| BasicBlockEnvShape::pack_ty(&scope, ty));
450        BasicBlockEnvShape { scope, bindings }
451    }
452
453    fn pack_ty(scope: &Scope, ty: &Ty) -> Ty {
454        match ty.kind() {
455            TyKind::Indexed(bty, idxs) => {
456                let bty = BasicBlockEnvShape::pack_bty(scope, bty);
457                if scope.has_free_vars(idxs) {
458                    Ty::exists_with_constr(bty, Expr::hole(HoleKind::Pred))
459                } else {
460                    Ty::indexed(bty, idxs.clone())
461                }
462            }
463            TyKind::Downcast(adt, args, ty, variant, fields) => {
464                debug_assert!(!scope.has_free_vars(args));
465                debug_assert!(!scope.has_free_vars(ty));
466                let fields = fields.iter().map(|ty| Self::pack_ty(scope, ty)).collect();
467                Ty::downcast(adt.clone(), args.clone(), ty.clone(), *variant, fields)
468            }
469            TyKind::Blocked(ty) => Ty::blocked(BasicBlockEnvShape::pack_ty(scope, ty)),
470            // FIXME(nilehmann) [`TyKind::Exists`] could also contain free variables.
471            TyKind::Exists(_)
472            | TyKind::Discr(..)
473            | TyKind::Ptr(..)
474            | TyKind::Uninit
475            | TyKind::Param(_)
476            | TyKind::Constr(_, _) => ty.clone(),
477            TyKind::Infer(_) => bug!("unexpected hole whecn checking function body"),
478            TyKind::StrgRef(..) => bug!("unexpected strong reference when checking function body"),
479        }
480    }
481
482    fn pack_bty(scope: &Scope, bty: &BaseTy) -> BaseTy {
483        match bty {
484            BaseTy::Adt(adt_def, args) => {
485                let args = List::from_vec(
486                    args.iter()
487                        .map(|arg| Self::pack_generic_arg(scope, arg))
488                        .collect(),
489                );
490                BaseTy::adt(adt_def.clone(), args)
491            }
492            BaseTy::FnDef(def_id, args) => {
493                let args = List::from_vec(
494                    args.iter()
495                        .map(|arg| Self::pack_generic_arg(scope, arg))
496                        .collect(),
497                );
498                BaseTy::fn_def(*def_id, args)
499            }
500            BaseTy::Tuple(tys) => {
501                let tys = tys
502                    .iter()
503                    .map(|ty| BasicBlockEnvShape::pack_ty(scope, ty))
504                    .collect();
505                BaseTy::Tuple(tys)
506            }
507            BaseTy::Slice(ty) => BaseTy::Slice(Self::pack_ty(scope, ty)),
508            BaseTy::Ref(r, ty, mutbl) => BaseTy::Ref(*r, Self::pack_ty(scope, ty), *mutbl),
509            BaseTy::Array(ty, c) => BaseTy::Array(Self::pack_ty(scope, ty), c.clone()),
510            BaseTy::Int(_)
511            | BaseTy::Param(_)
512            | BaseTy::Uint(_)
513            | BaseTy::Bool
514            | BaseTy::Float(_)
515            | BaseTy::Str
516            | BaseTy::RawPtr(_, _)
517            | BaseTy::RawPtrMetadata(_)
518            | BaseTy::Char
519            | BaseTy::Never
520            | BaseTy::Closure(..)
521            | BaseTy::Dynamic(..)
522            | BaseTy::Alias(..)
523            | BaseTy::FnPtr(..)
524            | BaseTy::Foreign(..)
525            | BaseTy::Coroutine(..) => {
526                assert!(!scope.has_free_vars(bty));
527                bty.clone()
528            }
529            BaseTy::Infer(..) => {
530                tracked_span_bug!("unexpected infer type")
531            }
532        }
533    }
534
535    fn pack_generic_arg(scope: &Scope, arg: &GenericArg) -> GenericArg {
536        match arg {
537            GenericArg::Ty(ty) => GenericArg::Ty(Self::pack_ty(scope, ty)),
538            GenericArg::Base(arg) => {
539                assert!(!scope.has_free_vars(arg));
540                GenericArg::Base(arg.clone())
541            }
542            GenericArg::Lifetime(re) => GenericArg::Lifetime(*re),
543            GenericArg::Const(c) => GenericArg::Const(c.clone()),
544        }
545    }
546
547    fn update(&mut self, path: &Path, ty: Ty, span: Span) {
548        self.bindings.lookup(path, span).update(ty);
549    }
550
551    /// join(self, genv, other) consumes the bindings in other, to "update"
552    /// `self` in place, and returns `true` if there was an actual change
553    /// or `false` indicating no change (i.e., a fixpoint was reached).
554    pub(crate) fn join(&mut self, other: TypeEnv, span: Span) -> bool {
555        let paths = self.bindings.paths();
556
557        // Join types
558        let mut modified = false;
559        for path in &paths {
560            let ty1 = self.bindings.get(path);
561            let ty2 = other.bindings.get(path);
562            let ty = self.join_ty(&ty1, &ty2);
563            modified |= ty1 != ty;
564            self.update(path, ty, span);
565        }
566
567        modified
568    }
569
570    fn join_ty(&self, ty1: &Ty, ty2: &Ty) -> Ty {
571        match (ty1.kind(), ty2.kind()) {
572            (TyKind::Blocked(ty1), _) => Ty::blocked(self.join_ty(ty1, &ty2.unblocked())),
573            (_, TyKind::Blocked(ty2)) => Ty::blocked(self.join_ty(&ty1.unblocked(), ty2)),
574            (TyKind::Uninit, _) | (_, TyKind::Uninit) => Ty::uninit(),
575            (TyKind::Exists(ty1), _) => self.join_ty(ty1.as_ref().skip_binder(), ty2),
576            (_, TyKind::Exists(ty2)) => self.join_ty(ty1, ty2.as_ref().skip_binder()),
577            (TyKind::Constr(_, ty1), _) => self.join_ty(ty1, ty2),
578            (_, TyKind::Constr(_, ty2)) => self.join_ty(ty1, ty2),
579            (TyKind::Indexed(bty1, idx1), TyKind::Indexed(bty2, idx2)) => {
580                let bty = self.join_bty(bty1, bty2);
581                let mut sorts = vec![];
582                let idx = self.join_idx(idx1, idx2, &bty.sort(), &mut sorts);
583                if sorts.is_empty() {
584                    Ty::indexed(bty, idx)
585                } else {
586                    let ty = Ty::constr(Expr::hole(HoleKind::Pred), Ty::indexed(bty, idx));
587                    Ty::exists(Binder::bind_with_sorts(ty, &sorts))
588                }
589            }
590            (TyKind::Ptr(rk1, path1), TyKind::Ptr(rk2, path2)) => {
591                debug_assert_eq!(rk1, rk2);
592                debug_assert_eq!(path1, path2);
593                Ty::ptr(*rk1, path1.clone())
594            }
595            (TyKind::Param(param_ty1), TyKind::Param(param_ty2)) => {
596                debug_assert_eq!(param_ty1, param_ty2);
597                Ty::param(*param_ty1)
598            }
599            (
600                TyKind::Downcast(adt1, args1, ty1, variant1, fields1),
601                TyKind::Downcast(adt2, args2, ty2, variant2, fields2),
602            ) => {
603                debug_assert_eq!(adt1, adt2);
604                debug_assert_eq!(args1, args2);
605                debug_assert!(ty1 == ty2 && !self.scope.has_free_vars(ty2));
606                debug_assert_eq!(variant1, variant2);
607                debug_assert_eq!(fields1.len(), fields2.len());
608                let fields = iter::zip(fields1, fields2)
609                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
610                    .collect();
611                Ty::downcast(adt1.clone(), args1.clone(), ty1.clone(), *variant1, fields)
612            }
613            _ => tracked_span_bug!("unexpected types: `{ty1:?}` - `{ty2:?}`"),
614        }
615    }
616
617    fn join_idx(&self, e1: &Expr, e2: &Expr, sort: &Sort, bound_sorts: &mut Vec<Sort>) -> Expr {
618        match (e1.kind(), e2.kind(), sort) {
619            (ExprKind::Tuple(es1), ExprKind::Tuple(es2), Sort::Tuple(sorts)) => {
620                debug_assert_eq3!(es1.len(), es2.len(), sorts.len());
621                Expr::tuple(
622                    izip!(es1, es2, sorts)
623                        .map(|(e1, e2, sort)| self.join_idx(e1, e2, sort, bound_sorts))
624                        .collect(),
625                )
626            }
627            (
628                ExprKind::Ctor(Ctor::Struct(_), flds1),
629                ExprKind::Ctor(Ctor::Struct(_), flds2),
630                Sort::App(SortCtor::Adt(sort_def), args),
631            ) => {
632                let sorts = sort_def.struct_variant().field_sorts(args);
633                debug_assert_eq3!(flds1.len(), flds2.len(), sorts.len());
634
635                Expr::ctor_struct(
636                    sort_def.did(),
637                    izip!(flds1, flds2, &sorts)
638                        .map(|(f1, f2, sort)| self.join_idx(f1, f2, sort, bound_sorts))
639                        .collect(),
640                )
641            }
642            _ => {
643                let has_free_vars2 = self.scope.has_free_vars(e2);
644                let has_escaping_vars1 = e1.has_escaping_bvars();
645                let has_escaping_vars2 = e2.has_escaping_bvars();
646                if !has_free_vars2 && !has_escaping_vars1 && !has_escaping_vars2 && e1 == e2 {
647                    e1.clone()
648                } else if sort.is_pred() {
649                    // FIXME(nilehmann) we shouldn't special case predicates here. Instead, we
650                    // should differentiate between generics and indices.
651                    let fsort = sort.expect_func().expect_mono();
652                    Expr::abs(Lambda::bind_with_fsort(Expr::hole(HoleKind::Pred), fsort))
653                } else {
654                    bound_sorts.push(sort.clone());
655                    Expr::bvar(
656                        INNERMOST,
657                        BoundVar::from_usize(bound_sorts.len() - 1),
658                        BoundReftKind::Annon,
659                    )
660                }
661            }
662        }
663    }
664
665    fn join_bty(&self, bty1: &BaseTy, bty2: &BaseTy) -> BaseTy {
666        match (bty1, bty2) {
667            (BaseTy::Adt(def1, args1), BaseTy::Adt(def2, args2)) => {
668                debug_assert_eq!(def1.did(), def2.did());
669                let args = iter::zip(args1, args2)
670                    .map(|(arg1, arg2)| self.join_generic_arg(arg1, arg2))
671                    .collect();
672                BaseTy::adt(def1.clone(), List::from_vec(args))
673            }
674            (BaseTy::Tuple(fields1), BaseTy::Tuple(fields2)) => {
675                let fields = iter::zip(fields1, fields2)
676                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
677                    .collect();
678                BaseTy::Tuple(fields)
679            }
680            (BaseTy::Alias(kind1, alias_ty1), BaseTy::Alias(kind2, alias_ty2)) => {
681                debug_assert_eq!(kind1, kind2);
682                debug_assert_eq!(alias_ty1, alias_ty2);
683                BaseTy::Alias(*kind1, alias_ty1.clone())
684            }
685            (BaseTy::Ref(r1, ty1, mutbl1), BaseTy::Ref(r2, ty2, mutbl2)) => {
686                debug_assert_eq!(r1, r2);
687                debug_assert_eq!(mutbl1, mutbl2);
688                BaseTy::Ref(*r1, self.join_ty(ty1, ty2), *mutbl1)
689            }
690            (BaseTy::Array(ty1, len1), BaseTy::Array(ty2, len2)) => {
691                tracked_span_dbg_assert_eq!(len1, len2);
692                BaseTy::Array(self.join_ty(ty1, ty2), len1.clone())
693            }
694            (BaseTy::Slice(ty1), BaseTy::Slice(ty2)) => BaseTy::Slice(self.join_ty(ty1, ty2)),
695            _ => {
696                tracked_span_dbg_assert_eq!(bty1, bty2);
697                bty1.clone()
698            }
699        }
700    }
701
702    fn join_generic_arg(&self, arg1: &GenericArg, arg2: &GenericArg) -> GenericArg {
703        match (arg1, arg2) {
704            (GenericArg::Ty(ty1), GenericArg::Ty(ty2)) => GenericArg::Ty(self.join_ty(ty1, ty2)),
705            (GenericArg::Base(ctor1), GenericArg::Base(ctor2)) => {
706                let sty1 = ctor1.as_ref().skip_binder();
707                let sty2 = ctor2.as_ref().skip_binder();
708                debug_assert!(sty1.idx.is_nu());
709                debug_assert!(sty2.idx.is_nu());
710
711                let bty = self.join_bty(&sty1.bty, &sty2.bty);
712                let pred = if self.scope.has_free_vars(&sty2.pred) || sty1.pred != sty2.pred {
713                    Expr::hole(HoleKind::Pred)
714                } else {
715                    sty1.pred.clone()
716                };
717                let sort = bty.sort();
718                let ctor = Binder::bind_with_sort(SubsetTy::new(bty, Expr::nu(), pred), sort);
719                GenericArg::Base(ctor)
720            }
721            (GenericArg::Lifetime(re1), GenericArg::Lifetime(_re2)) => {
722                // TODO(nilehmann) loop_abstract_refinement.rs is triggering this assertion to fail
723                // wee should fix it.
724                // debug_assert_eq!(re1, _re2);
725                GenericArg::Lifetime(*re1)
726            }
727            (GenericArg::Const(c1), GenericArg::Const(c2)) => {
728                debug_assert_eq!(c1, c2);
729                GenericArg::Const(c1.clone())
730            }
731            _ => tracked_span_bug!("unexpected generic args: `{arg1:?}` - `{arg2:?}`"),
732        }
733    }
734
735    pub fn into_bb_env(self, infcx: &mut InferCtxtRoot) -> BasicBlockEnv {
736        let mut delegate = LocalHoister::default();
737        let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
738
739        let mut bindings = self.bindings;
740        bindings.fmap_mut(|ty| hoister.hoist(ty));
741
742        BasicBlockEnv {
743            // We are relying on all the types in `bindings` not having escaping bvars, otherwise
744            // we would have to shift them in since we are creating a new binder.
745            data: delegate.bind(|vars, preds| {
746                // Replace all holes with a single fresh kvar on all parameters
747                let mut constrs = preds
748                    .into_iter()
749                    .filter(|pred| !matches!(pred.kind(), ExprKind::Hole(HoleKind::Pred)))
750                    .collect_vec();
751                let kvar =
752                    infcx.fresh_kvar_in_scope(&[vars.clone()], &self.scope, KVarEncoding::Conj);
753                constrs.push(kvar);
754
755                // Replace remaining holes by fresh kvars
756                let mut kvar_gen = |binders: &[_], kind| {
757                    debug_assert_eq!(kind, HoleKind::Pred);
758                    let binders = std::iter::once(vars.clone())
759                        .chain(binders.iter().cloned())
760                        .collect_vec();
761                    infcx.fresh_kvar_in_scope(&binders, &self.scope, KVarEncoding::Conj)
762                };
763                bindings.fmap_mut(|binding| binding.replace_holes(&mut kvar_gen));
764
765                BasicBlockEnvData { constrs: constrs.into(), bindings }
766            }),
767            scope: self.scope,
768        }
769    }
770}
771
/// Visiting is not supported for [`BasicBlockEnvData`]: calling `visit_with` panics.
///
/// NOTE(review): this stub presumably exists only to satisfy a supertrait bound of
/// `TypeFoldable` — confirm that no caller visits a `BasicBlockEnvData`.
impl TypeVisitable for BasicBlockEnvData {
    fn visit_with<V: TypeVisitor>(&self, _visitor: &mut V) -> ControlFlow<V::BreakTy> {
        unimplemented!()
    }
}
777
778impl TypeFoldable for BasicBlockEnvData {
779    fn try_fold_with<F: FallibleTypeFolder>(
780        &self,
781        folder: &mut F,
782    ) -> std::result::Result<Self, F::Error> {
783        Ok(BasicBlockEnvData {
784            constrs: self.constrs.try_fold_with(folder)?,
785            bindings: self.bindings.try_fold_with(folder)?,
786        })
787    }
788}
789
790impl BasicBlockEnv {
791    pub(crate) fn enter<'a>(
792        &self,
793        infcx: &mut InferCtxt,
794        local_decls: &'a LocalDecls,
795    ) -> TypeEnv<'a> {
796        let data = self
797            .data
798            .replace_bound_refts_with(|sort, _, _| Expr::fvar(infcx.define_var(sort)));
799        for constr in &data.constrs {
800            infcx.assume_pred(constr);
801        }
802        TypeEnv { bindings: data.bindings, local_decls }
803    }
804
805    pub(crate) fn scope(&self) -> &Scope {
806        &self.scope
807    }
808}
809
810mod pretty {
811    use std::fmt;
812
813    use flux_middle::pretty::*;
814
815    use super::*;
816
817    impl Pretty for TypeEnv<'_> {
818        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
819            w!(cx, f, "{:?}", &self.bindings)
820        }
821
822        fn default_cx(tcx: TyCtxt) -> PrettyCx {
823            PlacesTree::default_cx(tcx)
824        }
825    }
826
827    impl Pretty for BasicBlockEnvShape {
828        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
829            w!(cx, f, "{:?} {:?}", &self.scope, &self.bindings)
830        }
831
832        fn default_cx(tcx: TyCtxt) -> PrettyCx {
833            PlacesTree::default_cx(tcx)
834        }
835    }
836
837    impl Pretty for BasicBlockEnv {
838        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
839            w!(cx, f, "{:?} ", &self.scope)?;
840
841            let vars = self.data.vars();
842            cx.with_bound_vars(vars, || {
843                if !vars.is_empty() {
844                    cx.fmt_bound_vars(true, "for<", vars, "> ", f)?;
845                }
846                let data = self.data.as_ref().skip_binder();
847                if !data.constrs.is_empty() {
848                    w!(
849                        cx,
850                        f,
851                        "{:?} ⇒ ",
852                        join!(", ", data.constrs.iter().filter(|pred| !pred.is_trivially_true()))
853                    )?;
854                }
855                w!(cx, f, "{:?}", &data.bindings)
856            })
857        }
858
859        fn default_cx(tcx: TyCtxt) -> PrettyCx {
860            PlacesTree::default_cx(tcx)
861        }
862    }
863
864    impl_debug_with_default_cx! {
865        TypeEnv<'_> => "type_env",
866        BasicBlockEnvShape => "basic_block_env_shape",
867        BasicBlockEnv => "basic_block_env"
868    }
869}
870
/// A very explicit representation of [`TypeEnv`] for debugging/tracing/serialization ONLY.
///
/// Holds one entry per initialized location of the environment, sorted by location (see
/// [`TypeEnvTrace::new`]). Types and kinds are stored as already-rendered strings, so this value
/// cannot be turned back into a `TypeEnv`.
#[derive(Serialize, DebugAsJson)]
pub struct TypeEnvTrace(Vec<TypeEnvBind>);
874
/// One serialized binding of a `TypeEnvTrace`.
///
/// Field names and order form the serialized schema; do not rename or reorder them casually.
#[derive(Serialize)]
struct TypeEnvBind {
    /// Which location this binding is for, tagged by location kind.
    local: LocInfo,
    /// Source-level name of the location; only ever `Some` for `Loc::Local` entries present in
    /// the `local_names` map.
    name: Option<String>,
    /// `Debug` rendering of the binding's kind.
    kind: String,
    /// Pretty-printed type of the binding, rendered with regions hidden.
    ty: String,
    /// Declaration span of the location; only ever `Some` for `Loc::Local` entries found in the
    /// local declarations.
    span: Option<SpanTrace>,
}
883
/// Serializable form of a [`Loc`]: the variant records which kind of location it is, and the
/// payload is the `Debug` rendering of the location's inner value.
#[derive(Serialize)]
enum LocInfo {
    /// Rendering of a `Loc::Local`.
    Local(String),
    /// Rendering of a `Loc::Var`.
    Var(String),
}
889
890fn loc_info(loc: &Loc) -> LocInfo {
891    match loc {
892        Loc::Local(local) => LocInfo::Local(format!("{local:?}")),
893        Loc::Var(var) => LocInfo::Var(format!("{var:?}")),
894    }
895}
896
897fn loc_name(local_names: &UnordMap<Local, Symbol>, loc: &Loc) -> Option<String> {
898    if let Loc::Local(local) = loc {
899        let name = local_names.get(local)?;
900        return Some(format!("{}", name));
901    }
902    None
903}
904
905fn loc_span(
906    genv: GlobalEnv,
907    local_decls: &IndexVec<Local, LocalDecl>,
908    loc: &Loc,
909) -> Option<SpanTrace> {
910    if let Loc::Local(local) = loc {
911        return local_decls
912            .get(*local)
913            .map(|local_decl| SpanTrace::new(genv, local_decl.source_info.span));
914    }
915    None
916}
917
918impl TypeEnvTrace {
919    pub fn new(
920        genv: GlobalEnv,
921        local_names: &UnordMap<Local, Symbol>,
922        local_decls: &IndexVec<Local, LocalDecl>,
923        env: &TypeEnv,
924    ) -> Self {
925        let mut bindings = vec![];
926        let cx = PrettyCx::default(genv).hide_regions(true);
927        env.bindings
928            .iter()
929            .filter(|(_, binding)| !binding.ty.is_uninit())
930            .sorted_by(|(loc1, _), (loc2, _)| loc1.cmp(loc2))
931            .for_each(|(loc, binding)| {
932                let name = loc_name(local_names, loc);
933                let local = loc_info(loc);
934                let kind = format!("{:?}", binding.kind);
935                let ty = binding.ty.nested_string(&cx);
936                let span = loc_span(genv, local_decls, loc);
937                bindings.push(TypeEnvBind { name, local, kind, ty, span });
938            });
939
940        TypeEnvTrace(bindings)
941    }
942}
943
/// Serializable source-location summary of a [`Span`], used in traces.
///
/// Line and column numbers are copied from `SourceMap::span_to_location_info` (see
/// [`SpanTrace::new`]); `file` is the file path joined onto the session working directory.
#[derive(Serialize, DebugAsJson)]
pub struct SpanTrace {
    /// Path of the span's file; `None` if the file or the working directory has no local path.
    file: Option<String>,
    start_line: usize,
    start_col: usize,
    end_line: usize,
    end_col: usize,
}
952
953impl SpanTrace {
954    fn span_file(tcx: TyCtxt, span: Span) -> Option<String> {
955        let sm = tcx.sess.source_map();
956        let current_dir = &tcx.sess.opts.working_dir;
957        let current_dir = current_dir.local_path()?;
958        if let rustc_span::FileName::Real(file_name) = sm.span_to_filename(span) {
959            let file_path = file_name.local_path()?;
960            let full_path = current_dir.join(file_path);
961            Some(full_path.display().to_string())
962        } else {
963            None
964        }
965    }
966    pub fn new(genv: GlobalEnv, span: Span) -> Self {
967        let tcx = genv.tcx();
968        let sm = tcx.sess.source_map();
969        let (_, start_line, start_col, end_line, end_col) = sm.span_to_location_info(span);
970        let file = SpanTrace::span_file(tcx, span);
971        SpanTrace { file, start_line, start_col, end_line, end_col }
972    }
973}