flux_refineck/
type_env.rs

1mod place_ty;
2
3use std::{iter, ops::ControlFlow};
4
5use flux_common::{
6    bug,
7    dbg::{SpanTrace, debug_assert_eq3},
8    tracked_span_bug, tracked_span_dbg_assert_eq,
9};
10use flux_infer::{
11    fixpoint_encoding::KVarEncoding,
12    infer::{ConstrReason, InferCtxt, InferCtxtAt, InferCtxtRoot, InferResult},
13    refine_tree::Scope,
14};
15use flux_macros::DebugAsJson;
16use flux_middle::{
17    PlaceExt as _,
18    global_env::GlobalEnv,
19    pretty::{PrettyCx, PrettyNested},
20    queries::QueryResult,
21    rty::{
22        BaseTy, Binder, BoundReftKind, Ctor, Ensures, Expr, ExprKind, FnSig, GenericArg, HoleKind,
23        INNERMOST, Lambda, List, Loc, Mutability, Path, PtrKind, Region, SortCtor, SubsetTy, Ty,
24        TyKind, VariantIdx,
25        canonicalize::{Hoister, LocalHoister},
26        fold::{FallibleTypeFolder, TypeFoldable, TypeVisitable, TypeVisitor},
27        region_matching::{rty_match_regions, ty_match_regions},
28    },
29};
30use flux_rustc_bridge::{
31    self,
32    mir::{BasicBlock, Body, Local, LocalDecl, LocalDecls, Place, PlaceElem},
33    ty,
34};
35use itertools::{Itertools, izip};
36use rustc_data_structures::unord::UnordMap;
37use rustc_index::{IndexSlice, IndexVec};
38use rustc_middle::{mir::RETURN_PLACE, ty::TyCtxt};
39use rustc_span::{Span, Symbol};
40use rustc_type_ir::BoundVar;
41use serde::Serialize;
42
43use self::place_ty::{LocKind, PlacesTree};
44use super::rty::Sort;
45
/// Typing environment used while checking a function body: the current refined type of every
/// tracked location (locals and other locations introduced during checking).
#[derive(Clone, Default)]
pub struct TypeEnv<'a> {
    /// Current type bound to each location.
    bindings: PlacesTree,
    /// Unrefined Rust declarations of the body's locals; used to compute Rust types of places
    /// and to recover the regions of a local's declared type.
    local_decls: &'a LocalDecls,
}
51
/// Shape of a basic block's entry environment while it is being inferred. It is created from a
/// [`TypeEnv`] via [`TypeEnv::into_infer`] and widened with [`BasicBlockEnvShape::join`].
pub struct BasicBlockEnvShape {
    /// Scope the bindings were packed against.
    scope: Scope,
    /// Packed bindings.
    bindings: PlacesTree,
}
56
/// Final entry environment of a basic block: bindings and constraints under a binder of
/// refinement variables (built by [`BasicBlockEnvShape::into_bb_env`]).
pub struct BasicBlockEnv {
    /// Bindings and constraints, abstracted over the hoisted refinement variables.
    data: Binder<BasicBlockEnvData>,
    /// Scope recorded when the environment was built.
    scope: Scope,
}
61
/// Contents of a [`BasicBlockEnv`] under its binder.
#[derive(Debug)]
struct BasicBlockEnvData {
    /// Predicates assumed when entering the block and checked when jumping to it.
    constrs: List<Expr>,
    /// Types of the locations on entry to the block.
    bindings: PlacesTree,
}
67
68impl<'a> TypeEnv<'a> {
69    pub fn new(infcx: &mut InferCtxt, body: &'a Body, fn_sig: &FnSig) -> TypeEnv<'a> {
70        let mut env = TypeEnv { bindings: PlacesTree::default(), local_decls: &body.local_decls };
71
72        for requires in fn_sig.requires() {
73            infcx.assume_pred(requires);
74        }
75
76        for (local, ty) in body.args_iter().zip(fn_sig.inputs()) {
77            let ty = infcx.unpack(ty);
78            infcx.assume_invariants(&ty);
79            env.alloc_with_ty(local, ty);
80        }
81
82        for local in body.vars_and_temps_iter() {
83            env.alloc(local);
84        }
85
86        env.alloc(RETURN_PLACE);
87        env
88    }
89
90    pub fn empty() -> TypeEnv<'a> {
91        TypeEnv { bindings: PlacesTree::default(), local_decls: IndexSlice::empty() }
92    }
93
94    fn alloc_with_ty(&mut self, local: Local, ty: Ty) {
95        let ty = ty_match_regions(&ty, &self.local_decls[local].ty);
96        self.bindings.insert(local.into(), LocKind::Local, ty);
97    }
98
99    fn alloc(&mut self, local: Local) {
100        self.bindings
101            .insert(local.into(), LocKind::Local, Ty::uninit());
102    }
103
104    pub(crate) fn into_infer(self, scope: Scope) -> BasicBlockEnvShape {
105        BasicBlockEnvShape::new(scope, self)
106    }
107
108    pub(crate) fn lookup_rust_ty(&self, genv: GlobalEnv, place: &Place) -> QueryResult<ty::Ty> {
109        Ok(place.ty(genv, self.local_decls)?.ty)
110    }
111
112    pub(crate) fn lookup_place(
113        &mut self,
114        infcx: &mut InferCtxtAt,
115        place: &Place,
116    ) -> InferResult<Ty> {
117        let span = infcx.span;
118        Ok(self.bindings.lookup_unfolding(infcx, place, span)?.ty)
119    }
120
121    pub(crate) fn get(&self, path: &Path) -> Ty {
122        self.bindings.get(path)
123    }
124
125    pub fn update_path(&mut self, path: &Path, new_ty: Ty, span: Span) {
126        self.bindings.lookup(path, span).update(new_ty);
127    }
128
129    /// When checking a borrow in the right hand side of an assignment `x = &'?n p`, we use the
130    /// annotated region `'?n` in the type of the result. This region will only be used temporarily
131    /// and then replaced by the region in the type of `x` after the assignment. See [`TypeEnv::assign`]
132    pub(crate) fn borrow(
133        &mut self,
134        infcx: &mut InferCtxtAt,
135        re: Region,
136        mutbl: Mutability,
137        place: &Place,
138    ) -> InferResult<Ty> {
139        let span = infcx.span;
140        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
141        if result.is_strg && mutbl == Mutability::Mut {
142            Ok(Ty::ptr(PtrKind::Mut(re), result.path()))
143        } else {
144            // FIXME(nilehmann) we should block the place here. That would require a notion
145            // of shared vs mutable block types because sometimes blocked places from a shared
146            // reference never get unblocked and we should still allow reads through them.
147            Ok(Ty::mk_ref(re, result.ty, mutbl))
148        }
149    }
150
151    // FIXME(nilehmann) this is only used in a single place and we have it because [`TypeEnv`]
152    // doesn't expose a lookup without unfolding
153    pub(crate) fn ptr_to_ref_at_place(
154        &mut self,
155        infcx: &mut InferCtxtAt,
156        place: &Place,
157    ) -> InferResult {
158        let lookup = self.bindings.lookup(place, infcx.span);
159        let TyKind::Ptr(PtrKind::Mut(re), path) = lookup.ty.kind() else {
160            tracked_span_bug!("ptr_to_borrow called on non mutable pointer type")
161        };
162
163        let ref_ty =
164            self.ptr_to_ref(infcx, ConstrReason::Other, *re, path, PtrToRefBound::Infer)?;
165
166        self.bindings.lookup(place, infcx.span).update(ref_ty);
167
168        Ok(())
169    }
170
171    /// Convert a (strong) pointer to a mutable reference.
172    ///
173    /// This roughly implements the following inference rule:
174    /// ```text
175    ///                   t₁ <: t₂
176    /// -------------------------------------------------
177    /// Γ₁,ℓ:t1,Γ₂ ; ptr(mut, ℓ) => Γ₁,ℓ:†t₂,Γ₂ ; &mut t2
178    /// ```
179    /// That's it, we first get the current type `t₁` at location `ℓ` and check it is a subtype
180    /// of `t₂`. Then, we update the type of `ℓ` to `t₂` and block the place.
181    ///
182    /// The bound `t₂` can be either inferred ([`PtrToRefBound::Infer`]), explicitly provided
183    /// ([`PtrToRefBound::Ty`]), or made equal to `t₁` ([`PtrToRefBound::Identity`]).
184    ///
185    /// As an example, consider the environment `x: i32[a]` and the pointer `ptr(mut, x)`.
186    /// Converting the pointer to a mutable reference with an inferred bound produces the following
187    /// derivation (roughly):
188    ///
189    /// ```text
190    ///                    i32[a] <: i32{v: $k(v)}
191    /// ----------------------------------------------------------------
192    /// x: i32[a] ; ptr(mut, x) => x:†i32{v: $k(v)} ; &mut i32{v: $k(v)}
193    /// ```
194    pub(crate) fn ptr_to_ref(
195        &mut self,
196        infcx: &mut InferCtxtAt,
197        reason: ConstrReason,
198        re: Region,
199        path: &Path,
200        bound: PtrToRefBound,
201    ) -> InferResult<Ty> {
202        // ℓ: t1
203        let t1 = self.bindings.lookup(path, infcx.span).fold(infcx)?;
204
205        // t1 <: t2
206        let t2 = match bound {
207            PtrToRefBound::Ty(t2) => {
208                let t2 = rty_match_regions(&t2, &t1);
209                infcx.subtyping(&t1, &t2, reason)?;
210                t2
211            }
212            PtrToRefBound::Infer => {
213                let t2 = t1.with_holes().replace_holes(|sorts, kind| {
214                    debug_assert_eq!(kind, HoleKind::Pred);
215                    infcx.fresh_kvar(sorts, KVarEncoding::Conj)
216                });
217                infcx.subtyping(&t1, &t2, reason)?;
218                t2
219            }
220            PtrToRefBound::Identity => t1.clone(),
221        };
222
223        // ℓ: †t2
224        self.bindings
225            .lookup(path, infcx.span)
226            .block_with(t2.clone());
227
228        Ok(Ty::mk_ref(re, t2, Mutability::Mut))
229    }
230
231    pub(crate) fn fold_local_ptrs(&mut self, infcx: &mut InferCtxtAt) -> InferResult {
232        for (loc, bound, ty) in self.bindings.local_ptrs() {
233            infcx.subtyping(&ty, &bound, ConstrReason::FoldLocal)?;
234            self.bindings.remove_local(&loc);
235        }
236        Ok(())
237    }
238
239    /// Updates the type of `place` to `new_ty`. This may involve a *strong update* if we have
240    /// ownership of `place` or a *weak update* if it's behind a reference (which fires a subtyping
241    /// constraint)
242    ///
243    /// When strong updating, the process involves recovering the original regions (lifetimes) used
244    /// in the (unrefined) Rust type of `place` and then substituting these regions in `new_ty`. For
245    /// instance, if we are assigning a value of type `S<&'?10 i32{v: v > 0}>` to a variable `x`,
246    /// and the (unrefined) Rust type of `x` is `S<&'?5 i32>`, before the assignment, we identify a
247    /// substitution that maps the region `'?10` to `'?5`. After applying this substitution, the
248    /// type of the place `x` is updated accordingly. This ensures that the lifetimes in the
249    /// assigned type are consistent with those expected by the place's original type definition.
250    pub(crate) fn assign(
251        &mut self,
252        infcx: &mut InferCtxtAt,
253        place: &Place,
254        new_ty: Ty,
255    ) -> InferResult {
256        let rustc_ty = place.ty(infcx.genv, self.local_decls)?.ty;
257        let new_ty = ty_match_regions(&new_ty, &rustc_ty);
258        let span = infcx.span;
259        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
260        if result.is_strg {
261            result.update(new_ty);
262        } else if !place.behind_raw_ptr(infcx.genv, self.local_decls)? {
263            infcx.subtyping(&new_ty, &result.ty, ConstrReason::Assign)?;
264        }
265        Ok(())
266    }
267
268    pub(crate) fn move_place(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult<Ty> {
269        let span = infcx.span;
270        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
271        if result.is_strg {
272            let uninit = Ty::uninit();
273            Ok(result.update(uninit))
274        } else {
275            // ignore the 'move' and trust rustc managed the move correctly
276            // https://github.com/flux-rs/flux/issues/725#issuecomment-2295065634
277            Ok(result.ty)
278        }
279    }
280
281    pub(crate) fn unpack(&mut self, infcx: &mut InferCtxt) {
282        self.bindings
283            .fmap_mut(|_loc, ty| infcx.hoister(true).hoist(ty));
284    }
285
286    pub(crate) fn unblock(&mut self, infcx: &mut InferCtxt, place: &Place) {
287        self.bindings.unblock(infcx, place);
288    }
289
290    pub(crate) fn check_goto(
291        self,
292        infcx: &mut InferCtxtAt,
293        bb_env: &BasicBlockEnv,
294        target: BasicBlock,
295    ) -> InferResult {
296        infcx.ensure_resolved_evars(|infcx| {
297            let bb_env = bb_env
298                .data
299                .replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
300
301            // Check constraints
302            for constr in &bb_env.constrs {
303                infcx.check_pred(constr, ConstrReason::Goto(target));
304            }
305
306            // Check subtyping
307            let bb_env = bb_env.bindings.flatten();
308            for (path, _, ty2) in bb_env {
309                let ty1 = self.bindings.get(&path);
310                infcx.subtyping(&ty1.unblocked(), &ty2.unblocked(), ConstrReason::Goto(target))?;
311            }
312            Ok(())
313        })
314    }
315
316    pub(crate) fn fold(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult {
317        let span = infcx.span;
318        self.bindings.lookup(place, span).fold(infcx)?;
319        Ok(())
320    }
321
322    pub(crate) fn unfold_local_ptr(
323        &mut self,
324        infcx: &mut InferCtxt,
325        bound: &Ty,
326    ) -> InferResult<Loc> {
327        let name = infcx.define_unknown_var(&Sort::Loc);
328        let loc = Loc::from(name);
329        let ty = infcx.unpack(bound);
330        self.bindings
331            .insert(loc, LocKind::LocalPtr(bound.clone()), ty);
332        Ok(loc)
333    }
334
335    /// ```text
336    /// -----------------------------------
337    /// Γ ; &strg <ℓ: t> => Γ,ℓ: t ; ptr(ℓ)
338    /// ```
339    pub(crate) fn unfold_strg_ref(
340        &mut self,
341        infcx: &mut InferCtxt,
342        path: &Path,
343        ty: &Ty,
344    ) -> InferResult<Loc> {
345        if let Some(loc) = path.to_loc() {
346            let ty = infcx.unpack(ty);
347            self.bindings.insert(loc, LocKind::Universal, ty);
348            Ok(loc)
349        } else {
350            bug!("unfold_strg_ref: unexpected path {path:?}")
351        }
352    }
353
354    pub(crate) fn unfold(
355        &mut self,
356        infcx: &mut InferCtxt,
357        place: &Place,
358        span: Span,
359    ) -> InferResult {
360        self.bindings.unfold(infcx, place, span)
361    }
362
363    pub(crate) fn downcast(
364        &mut self,
365        infcx: &mut InferCtxtAt,
366        place: &Place,
367        variant_idx: VariantIdx,
368    ) -> InferResult {
369        let mut down_place = place.clone();
370        let span = infcx.span;
371        down_place
372            .projection
373            .push(PlaceElem::Downcast(None, variant_idx));
374        self.bindings.unfold(infcx, &down_place, span)?;
375        Ok(())
376    }
377
378    pub fn fully_resolve_evars(&mut self, infcx: &InferCtxt) {
379        self.bindings
380            .fmap_mut(|_loc, ty| infcx.fully_resolve_evars(ty));
381    }
382
383    pub(crate) fn assume_ensures(
384        &mut self,
385        infcx: &mut InferCtxt,
386        ensures: &[Ensures],
387        span: Span,
388    ) {
389        for ensure in ensures {
390            match ensure {
391                Ensures::Type(path, updated_ty) => {
392                    let updated_ty = infcx.unpack(updated_ty);
393                    infcx.assume_invariants(&updated_ty);
394                    self.update_path(path, updated_ty, span);
395                }
396                Ensures::Pred(e) => infcx.assume_pred(e),
397            }
398        }
399    }
400
401    pub(crate) fn check_ensures(
402        &mut self,
403        at: &mut InferCtxtAt,
404        ensures: &[Ensures],
405        reason: ConstrReason,
406    ) -> InferResult {
407        for constraint in ensures {
408            match constraint {
409                Ensures::Type(path, ty) => {
410                    let actual_ty = self.get(path).unblocked(); // HACK
411                    at.subtyping(&actual_ty, ty, reason)?;
412                }
413                Ensures::Pred(e) => {
414                    at.check_pred(e, ConstrReason::Ret);
415                }
416            }
417        }
418        Ok(())
419    }
420}
421
/// How the bound `t₂` is chosen when [`TypeEnv::ptr_to_ref`] converts a pointer into a mutable
/// reference.
pub(crate) enum PtrToRefBound {
    /// Use the given type as the bound. Its regions are matched against the current type `t₁`,
    /// and `t₁ <: t₂` is checked.
    Ty(Ty),
    /// Infer the bound by filling the holes of `t₁.with_holes()` with fresh kvars, then check
    /// `t₁ <: t₂`.
    Infer,
    /// Use the current type `t₁` itself as the bound. No subtyping check is emitted.
    Identity,
}
427
428impl flux_infer::infer::LocEnv for TypeEnv<'_> {
429    fn ptr_to_ref(
430        &mut self,
431        infcx: &mut InferCtxtAt,
432        reason: ConstrReason,
433        re: Region,
434        path: &Path,
435        bound: Ty,
436    ) -> InferResult<Ty> {
437        self.ptr_to_ref(infcx, reason, re, path, PtrToRefBound::Ty(bound))
438    }
439
440    fn get(&self, path: &Path) -> Ty {
441        self.get(path)
442    }
443
444    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc> {
445        self.unfold_strg_ref(infcx, path, ty)
446    }
447}
448
449impl BasicBlockEnvShape {
450    pub fn enter<'a>(&self, local_decls: &'a LocalDecls) -> TypeEnv<'a> {
451        TypeEnv { bindings: self.bindings.clone(), local_decls }
452    }
453
454    fn new(scope: Scope, env: TypeEnv) -> BasicBlockEnvShape {
455        let mut bindings = env.bindings;
456        bindings.fmap_mut(|_loc, ty| BasicBlockEnvShape::pack_ty(&scope, ty));
457        BasicBlockEnvShape { scope, bindings }
458    }
459
460    fn pack_ty(scope: &Scope, ty: &Ty) -> Ty {
461        match ty.kind() {
462            TyKind::Indexed(bty, idxs) => {
463                let bty = BasicBlockEnvShape::pack_bty(scope, bty);
464                if scope.has_free_vars(idxs) {
465                    Ty::exists_with_constr(bty, Expr::hole(HoleKind::Pred))
466                } else {
467                    Ty::indexed(bty, idxs.clone())
468                }
469            }
470            TyKind::Downcast(adt, args, ty, variant, fields) => {
471                debug_assert!(!scope.has_free_vars(args));
472                debug_assert!(!scope.has_free_vars(ty));
473                let fields = fields.iter().map(|ty| Self::pack_ty(scope, ty)).collect();
474                Ty::downcast(adt.clone(), args.clone(), ty.clone(), *variant, fields)
475            }
476            TyKind::Blocked(ty) => Ty::blocked(BasicBlockEnvShape::pack_ty(scope, ty)),
477            // FIXME(nilehmann) [`TyKind::Exists`] could also contain free variables.
478            TyKind::Exists(_)
479            | TyKind::Discr(..)
480            | TyKind::Ptr(..)
481            | TyKind::Uninit
482            | TyKind::Param(_)
483            | TyKind::Constr(_, _) => ty.clone(),
484            TyKind::Infer(_) => bug!("unexpected hole whecn checking function body"),
485            TyKind::StrgRef(..) => bug!("unexpected strong reference when checking function body"),
486        }
487    }
488
489    fn pack_bty(scope: &Scope, bty: &BaseTy) -> BaseTy {
490        match bty {
491            BaseTy::Adt(adt_def, args) => {
492                let args = List::from_vec(
493                    args.iter()
494                        .map(|arg| Self::pack_generic_arg(scope, arg))
495                        .collect(),
496                );
497                BaseTy::adt(adt_def.clone(), args)
498            }
499            BaseTy::FnDef(def_id, args) => {
500                let args = List::from_vec(
501                    args.iter()
502                        .map(|arg| Self::pack_generic_arg(scope, arg))
503                        .collect(),
504                );
505                BaseTy::fn_def(*def_id, args)
506            }
507            BaseTy::Tuple(tys) => {
508                let tys = tys
509                    .iter()
510                    .map(|ty| BasicBlockEnvShape::pack_ty(scope, ty))
511                    .collect();
512                BaseTy::Tuple(tys)
513            }
514            BaseTy::Slice(ty) => BaseTy::Slice(Self::pack_ty(scope, ty)),
515            BaseTy::Ref(r, ty, mutbl) => BaseTy::Ref(*r, Self::pack_ty(scope, ty), *mutbl),
516            BaseTy::Array(ty, c) => BaseTy::Array(Self::pack_ty(scope, ty), c.clone()),
517            BaseTy::Int(_)
518            | BaseTy::Param(_)
519            | BaseTy::Uint(_)
520            | BaseTy::Bool
521            | BaseTy::Float(_)
522            | BaseTy::Str
523            | BaseTy::RawPtr(_, _)
524            | BaseTy::RawPtrMetadata(_)
525            | BaseTy::Char
526            | BaseTy::Never
527            | BaseTy::Closure(..)
528            | BaseTy::Dynamic(..)
529            | BaseTy::Alias(..)
530            | BaseTy::FnPtr(..)
531            | BaseTy::Foreign(..)
532            | BaseTy::Coroutine(..) => {
533                if scope.has_free_vars(bty) {
534                    tracked_span_bug!("unexpected type with free vars")
535                } else {
536                    bty.clone()
537                }
538            }
539            BaseTy::Infer(..) => {
540                tracked_span_bug!("unexpected infer type")
541            }
542        }
543    }
544
545    fn pack_generic_arg(scope: &Scope, arg: &GenericArg) -> GenericArg {
546        match arg {
547            GenericArg::Ty(ty) => GenericArg::Ty(Self::pack_ty(scope, ty)),
548            GenericArg::Base(arg) => {
549                assert!(!scope.has_free_vars(arg));
550                GenericArg::Base(arg.clone())
551            }
552            GenericArg::Lifetime(re) => GenericArg::Lifetime(*re),
553            GenericArg::Const(c) => GenericArg::Const(c.clone()),
554        }
555    }
556
557    fn update(&mut self, path: &Path, ty: Ty, span: Span) {
558        self.bindings.lookup(path, span).update(ty);
559    }
560
561    /// join(self, genv, other) consumes the bindings in other, to "update"
562    /// `self` in place, and returns `true` if there was an actual change
563    /// or `false` indicating no change (i.e., a fixpoint was reached).
564    pub(crate) fn join(&mut self, other: TypeEnv, span: Span) -> bool {
565        let paths = self.bindings.paths();
566
567        // Join types
568        let mut modified = false;
569        for path in &paths {
570            let ty1 = self.bindings.get(path);
571            let ty2 = other.bindings.get(path);
572            let ty = if ty1 == ty2 { ty1.clone() } else { self.join_ty(&ty1, &ty2) };
573            modified |= ty1 != ty;
574            self.update(path, ty, span);
575        }
576
577        modified
578    }
579
580    fn join_ty(&self, ty1: &Ty, ty2: &Ty) -> Ty {
581        match (ty1.kind(), ty2.kind()) {
582            (TyKind::Blocked(ty1), _) => Ty::blocked(self.join_ty(ty1, &ty2.unblocked())),
583            (_, TyKind::Blocked(ty2)) => Ty::blocked(self.join_ty(&ty1.unblocked(), ty2)),
584            (TyKind::Uninit, _) | (_, TyKind::Uninit) => Ty::uninit(),
585            (TyKind::Exists(ty1), _) => self.join_ty(ty1.as_ref().skip_binder(), ty2),
586            (_, TyKind::Exists(ty2)) => self.join_ty(ty1, ty2.as_ref().skip_binder()),
587            (TyKind::Constr(_, ty1), _) => self.join_ty(ty1, ty2),
588            (_, TyKind::Constr(_, ty2)) => self.join_ty(ty1, ty2),
589            (TyKind::Indexed(bty1, idx1), TyKind::Indexed(bty2, idx2)) => {
590                let bty = self.join_bty(bty1, bty2);
591                let mut sorts = vec![];
592                let idx = self.join_idx(idx1, idx2, &bty.sort(), &mut sorts);
593                if sorts.is_empty() {
594                    Ty::indexed(bty, idx)
595                } else {
596                    let ty = Ty::constr(Expr::hole(HoleKind::Pred), Ty::indexed(bty, idx));
597                    Ty::exists(Binder::bind_with_sorts(ty, &sorts))
598                }
599            }
600            (TyKind::Ptr(rk1, path1), TyKind::Ptr(rk2, path2)) => {
601                debug_assert_eq!(rk1, rk2);
602                debug_assert_eq!(path1, path2);
603                Ty::ptr(*rk1, path1.clone())
604            }
605            (TyKind::Param(param_ty1), TyKind::Param(param_ty2)) => {
606                debug_assert_eq!(param_ty1, param_ty2);
607                Ty::param(*param_ty1)
608            }
609            (
610                TyKind::Downcast(adt1, args1, ty1, variant1, fields1),
611                TyKind::Downcast(adt2, args2, ty2, variant2, fields2),
612            ) => {
613                debug_assert_eq!(adt1, adt2);
614                debug_assert_eq!(args1, args2);
615                debug_assert!(ty1 == ty2 && !self.scope.has_free_vars(ty2));
616                debug_assert_eq!(variant1, variant2);
617                debug_assert_eq!(fields1.len(), fields2.len());
618                let fields = iter::zip(fields1, fields2)
619                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
620                    .collect();
621                Ty::downcast(adt1.clone(), args1.clone(), ty1.clone(), *variant1, fields)
622            }
623            _ => tracked_span_bug!("unexpected types: `{ty1:?}` - `{ty2:?}`"),
624        }
625    }
626
627    fn join_idx(&self, e1: &Expr, e2: &Expr, sort: &Sort, bound_sorts: &mut Vec<Sort>) -> Expr {
628        match (e1.kind(), e2.kind(), sort) {
629            (ExprKind::Tuple(es1), ExprKind::Tuple(es2), Sort::Tuple(sorts)) => {
630                debug_assert_eq3!(es1.len(), es2.len(), sorts.len());
631                Expr::tuple(
632                    izip!(es1, es2, sorts)
633                        .map(|(e1, e2, sort)| self.join_idx(e1, e2, sort, bound_sorts))
634                        .collect(),
635                )
636            }
637            (
638                ExprKind::Ctor(Ctor::Struct(_), flds1),
639                ExprKind::Ctor(Ctor::Struct(_), flds2),
640                Sort::App(SortCtor::Adt(sort_def), args),
641            ) => {
642                let sorts = sort_def.struct_variant().field_sorts(args);
643                debug_assert_eq3!(flds1.len(), flds2.len(), sorts.len());
644
645                Expr::ctor_struct(
646                    sort_def.did(),
647                    izip!(flds1, flds2, &sorts)
648                        .map(|(f1, f2, sort)| self.join_idx(f1, f2, sort, bound_sorts))
649                        .collect(),
650                )
651            }
652            _ => {
653                let has_free_vars2 = self.scope.has_free_vars(e2);
654                let has_escaping_vars1 = e1.has_escaping_bvars();
655                let has_escaping_vars2 = e2.has_escaping_bvars();
656                if !has_free_vars2 && !has_escaping_vars1 && !has_escaping_vars2 && e1 == e2 {
657                    e1.clone()
658                } else if sort.is_pred() {
659                    // FIXME(nilehmann) we shouldn't special case predicates here. Instead, we
660                    // should differentiate between generics and indices.
661                    let fsort = sort.expect_func().expect_mono();
662                    Expr::abs(Lambda::bind_with_fsort(Expr::hole(HoleKind::Pred), fsort))
663                } else {
664                    bound_sorts.push(sort.clone());
665                    Expr::bvar(
666                        INNERMOST,
667                        BoundVar::from_usize(bound_sorts.len() - 1),
668                        BoundReftKind::Anon,
669                    )
670                }
671            }
672        }
673    }
674
675    fn join_bty(&self, bty1: &BaseTy, bty2: &BaseTy) -> BaseTy {
676        match (bty1, bty2) {
677            (BaseTy::Adt(def1, args1), BaseTy::Adt(def2, args2)) => {
678                tracked_span_dbg_assert_eq!(def1.did(), def2.did());
679                let args = iter::zip(args1, args2)
680                    .map(|(arg1, arg2)| self.join_generic_arg(arg1, arg2))
681                    .collect();
682                BaseTy::adt(def1.clone(), List::from_vec(args))
683            }
684            (BaseTy::Tuple(fields1), BaseTy::Tuple(fields2)) => {
685                let fields = iter::zip(fields1, fields2)
686                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
687                    .collect();
688                BaseTy::Tuple(fields)
689            }
690            (BaseTy::Alias(kind1, alias_ty1), BaseTy::Alias(kind2, alias_ty2)) => {
691                tracked_span_dbg_assert_eq!(kind1, kind2);
692                tracked_span_dbg_assert_eq!(alias_ty1, alias_ty2);
693                BaseTy::Alias(*kind1, alias_ty1.clone())
694            }
695            (BaseTy::Ref(r1, ty1, mutbl1), BaseTy::Ref(r2, ty2, mutbl2)) => {
696                tracked_span_dbg_assert_eq!(r1, r2);
697                tracked_span_dbg_assert_eq!(mutbl1, mutbl2);
698                BaseTy::Ref(*r1, self.join_ty(ty1, ty2), *mutbl1)
699            }
700            (BaseTy::Array(ty1, len1), BaseTy::Array(ty2, len2)) => {
701                tracked_span_dbg_assert_eq!(len1, len2);
702                BaseTy::Array(self.join_ty(ty1, ty2), len1.clone())
703            }
704            (BaseTy::Slice(ty1), BaseTy::Slice(ty2)) => BaseTy::Slice(self.join_ty(ty1, ty2)),
705            _ => {
706                tracked_span_dbg_assert_eq!(bty1, bty2);
707                bty1.clone()
708            }
709        }
710    }
711
712    fn join_generic_arg(&self, arg1: &GenericArg, arg2: &GenericArg) -> GenericArg {
713        match (arg1, arg2) {
714            (GenericArg::Ty(ty1), GenericArg::Ty(ty2)) => GenericArg::Ty(self.join_ty(ty1, ty2)),
715            (GenericArg::Base(ctor1), GenericArg::Base(ctor2)) => {
716                let sty1 = ctor1.as_ref().skip_binder();
717                let sty2 = ctor2.as_ref().skip_binder();
718                debug_assert!(sty1.idx.is_nu());
719                debug_assert!(sty2.idx.is_nu());
720
721                let bty = self.join_bty(&sty1.bty, &sty2.bty);
722                let pred = if self.scope.has_free_vars(&sty2.pred) || sty1.pred != sty2.pred {
723                    Expr::hole(HoleKind::Pred)
724                } else {
725                    sty1.pred.clone()
726                };
727                let sort = bty.sort();
728                let ctor = Binder::bind_with_sort(SubsetTy::new(bty, Expr::nu(), pred), sort);
729                GenericArg::Base(ctor)
730            }
731            (GenericArg::Lifetime(re1), GenericArg::Lifetime(_re2)) => {
732                // TODO(nilehmann) loop_abstract_refinement.rs is triggering this assertion to fail
733                // wee should fix it.
734                // debug_assert_eq!(re1, _re2);
735                GenericArg::Lifetime(*re1)
736            }
737            (GenericArg::Const(c1), GenericArg::Const(c2)) => {
738                debug_assert_eq!(c1, c2);
739                GenericArg::Const(c1.clone())
740            }
741            _ => tracked_span_bug!("unexpected generic args: `{arg1:?}` - `{arg2:?}`"),
742        }
743    }
744
745    pub fn into_bb_env(self, infcx: &mut InferCtxtRoot, body: &Body) -> BasicBlockEnv {
746        let mut delegate = LocalHoister::default();
747        let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
748
749        let mut bindings = self.bindings;
750        bindings.fmap_mut(|loc, ty| {
751            let name = if let Loc::Local(local) = loc {
752                body.local_names.get(local).copied()
753            } else {
754                None
755            };
756            hoister.delegate.name = name;
757            hoister.hoist(ty)
758        });
759
760        BasicBlockEnv {
761            // We are relying on all the types in `bindings` not having escaping bvars, otherwise
762            // we would have to shift them in since we are creating a new binder.
763            data: delegate.bind(|vars, preds| {
764                // Replace all holes with a single fresh kvar on all parameters
765                let mut constrs = preds
766                    .into_iter()
767                    .filter(|pred| !matches!(pred.kind(), ExprKind::Hole(HoleKind::Pred)))
768                    .collect_vec();
769                let kvar = infcx.fresh_kvar_in_scope(
770                    std::slice::from_ref(&vars),
771                    &self.scope,
772                    KVarEncoding::Conj,
773                );
774                constrs.push(kvar);
775
776                // Replace remaining holes by fresh kvars
777                let mut kvar_gen = |binders: &[_], kind| {
778                    debug_assert_eq!(kind, HoleKind::Pred);
779                    let binders = std::iter::once(vars.clone())
780                        .chain(binders.iter().cloned())
781                        .collect_vec();
782                    infcx.fresh_kvar_in_scope(&binders, &self.scope, KVarEncoding::Conj)
783                };
784                bindings.fmap_mut(|_, binding| binding.replace_holes(&mut kvar_gen));
785
786                BasicBlockEnvData { constrs: constrs.into(), bindings }
787            }),
788            scope: self.scope,
789        }
790    }
791}
792
/// Visiting a [`BasicBlockEnvData`] is not supported. The impl presumably exists only so the type
/// can also implement [`TypeFoldable`] (NOTE(review): confirm no visitor reaches it).
///
/// # Panics
///
/// Always, via `unimplemented!()`.
impl TypeVisitable for BasicBlockEnvData {
    fn visit_with<V: TypeVisitor>(&self, _visitor: &mut V) -> ControlFlow<V::BreakTy> {
        unimplemented!()
    }
}
798
799impl TypeFoldable for BasicBlockEnvData {
800    fn try_fold_with<F: FallibleTypeFolder>(
801        &self,
802        folder: &mut F,
803    ) -> std::result::Result<Self, F::Error> {
804        Ok(BasicBlockEnvData {
805            constrs: self.constrs.try_fold_with(folder)?,
806            bindings: self.bindings.try_fold_with(folder)?,
807        })
808    }
809}
810
811impl BasicBlockEnv {
812    pub(crate) fn enter<'a>(
813        &self,
814        infcx: &mut InferCtxt,
815        local_decls: &'a LocalDecls,
816    ) -> TypeEnv<'a> {
817        let data = self.data.replace_bound_refts_with(|sort, _, kind| {
818            Expr::fvar(infcx.define_bound_reft_var(sort, kind))
819        });
820        for constr in &data.constrs {
821            infcx.assume_pred(constr);
822        }
823        TypeEnv { bindings: data.bindings, local_decls }
824    }
825
826    pub(crate) fn scope(&self) -> &Scope {
827        &self.scope
828    }
829}
830
831mod pretty {
832    use std::fmt;
833
834    use flux_middle::pretty::*;
835
836    use super::*;
837
838    impl Pretty for TypeEnv<'_> {
839        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
840            w!(cx, f, "{:?}", &self.bindings)
841        }
842
843        fn default_cx(tcx: TyCtxt) -> PrettyCx {
844            PlacesTree::default_cx(tcx)
845        }
846    }
847
848    impl Pretty for BasicBlockEnvShape {
849        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
850            w!(cx, f, "{:?} {:?}", &self.scope, &self.bindings)
851        }
852
853        fn default_cx(tcx: TyCtxt) -> PrettyCx {
854            PlacesTree::default_cx(tcx)
855        }
856    }
857
858    impl Pretty for BasicBlockEnv {
859        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
860            w!(cx, f, "{:?} ", &self.scope)?;
861
862            let vars = self.data.vars();
863            cx.with_bound_vars(vars, || {
864                if !vars.is_empty() {
865                    cx.fmt_bound_vars(true, "for<", vars, "> ", f)?;
866                }
867                let data = self.data.as_ref().skip_binder();
868                if !data.constrs.is_empty() {
869                    w!(
870                        cx,
871                        f,
872                        "{:?} ⇒ ",
873                        join!(", ", data.constrs.iter().filter(|pred| !pred.is_trivially_true()))
874                    )?;
875                }
876                w!(cx, f, "{:?}", &data.bindings)
877            })
878        }
879
880        fn default_cx(tcx: TyCtxt) -> PrettyCx {
881            PlacesTree::default_cx(tcx)
882        }
883    }
884
885    impl_debug_with_default_cx! {
886        TypeEnv<'_> => "type_env",
887        BasicBlockEnvShape => "basic_block_env_shape",
888        BasicBlockEnv => "basic_block_env"
889    }
890}
891
/// A very explicit representation of [`TypeEnv`] for debugging/tracing/serialization ONLY.
///
/// Built by [`TypeEnvTrace::new`]. It holds one `TypeEnvBind` entry for every binding whose
/// type is not uninitialized, ordered by location.
#[derive(Serialize, DebugAsJson)]
pub struct TypeEnvTrace(Vec<TypeEnvBind>);
895
/// One serialized binding of a `TypeEnvTrace`; every field is pre-rendered to text.
#[derive(Serialize)]
struct TypeEnvBind {
    // The bound location, tagged as a MIR local or a variable (see `loc_info`).
    local: LocInfo,
    // Source-level name of the local, when one is recorded; `None` for non-local locations.
    name: Option<String>,
    // `Debug` rendering of the binding's kind.
    kind: String,
    // Pretty-printed (nested) rendering of the binding's type.
    ty: String,
    // Span of the local's declaration, when the location is a local with a declaration.
    span: Option<SpanTrace>,
}
904
/// Serialized form of a [`Loc`], keeping which kind of location it was.
#[derive(Serialize)]
enum LocInfo {
    // `Debug` rendering of a `Loc::Local`.
    Local(String),
    // `Debug` rendering of a `Loc::Var`.
    Var(String),
}
910
911fn loc_info(loc: &Loc) -> LocInfo {
912    match loc {
913        Loc::Local(local) => LocInfo::Local(format!("{local:?}")),
914        Loc::Var(var) => LocInfo::Var(format!("{var:?}")),
915    }
916}
917
918fn loc_name(local_names: &UnordMap<Local, Symbol>, loc: &Loc) -> Option<String> {
919    if let Loc::Local(local) = loc {
920        let name = local_names.get(local)?;
921        return Some(format!("{name}"));
922    }
923    None
924}
925
926fn loc_span(
927    genv: GlobalEnv,
928    local_decls: &IndexVec<Local, LocalDecl>,
929    loc: &Loc,
930) -> Option<SpanTrace> {
931    if let Loc::Local(local) = loc {
932        return local_decls
933            .get(*local)
934            .map(|local_decl| SpanTrace::new(genv.tcx(), local_decl.source_info.span));
935    }
936    None
937}
938
939impl TypeEnvTrace {
940    pub fn new(
941        genv: GlobalEnv,
942        local_names: &UnordMap<Local, Symbol>,
943        local_decls: &IndexVec<Local, LocalDecl>,
944        cx: PrettyCx,
945        env: &TypeEnv,
946    ) -> Self {
947        let mut bindings = vec![];
948        env.bindings
949            .iter()
950            .filter(|(_, binding)| !binding.ty.is_uninit())
951            .sorted_by(|(loc1, _), (loc2, _)| loc1.cmp(loc2))
952            .for_each(|(loc, binding)| {
953                let name = loc_name(local_names, loc);
954                let local = loc_info(loc);
955                let kind = format!("{:?}", binding.kind);
956                let ty = binding.ty.nested_string(&cx);
957                let span = loc_span(genv, local_decls, loc);
958                bindings.push(TypeEnvBind { name, local, kind, ty, span });
959            });
960
961        TypeEnvTrace(bindings)
962    }
963}