flux_refineck/
type_env.rs

1mod place_ty;
2
3use std::{iter, ops::ControlFlow};
4
5use flux_common::{
6    bug,
7    dbg::{SpanTrace, debug_assert_eq3},
8    tracked_span_bug, tracked_span_dbg_assert_eq,
9};
10use flux_infer::{
11    fixpoint_encoding::KVarEncoding,
12    infer::{ConstrReason, InferCtxt, InferCtxtAt, InferCtxtRoot, InferResult},
13    refine_tree::Scope,
14};
15use flux_macros::DebugAsJson;
16use flux_middle::{
17    PlaceExt as _,
18    global_env::GlobalEnv,
19    pretty::{PrettyCx, PrettyNested},
20    queries::QueryResult,
21    rty::{
22        BaseTy, Binder, BoundReftKind, Ctor, Ensures, Expr, ExprKind, FnSig, GenericArg, HoleKind,
23        INNERMOST, Lambda, List, Loc, Mutability, Path, PtrKind, Region, SortCtor, SubsetTy, Ty,
24        TyKind, VariantIdx,
25        canonicalize::{Hoister, LocalHoister},
26        fold::{FallibleTypeFolder, TypeFoldable, TypeVisitable, TypeVisitor},
27        region_matching::{rty_match_regions, ty_match_regions},
28    },
29};
30use flux_rustc_bridge::{
31    self,
32    mir::{BasicBlock, Body, Local, LocalDecl, LocalDecls, Place, PlaceElem},
33    ty,
34};
35use itertools::{Itertools, izip};
36use rustc_data_structures::unord::UnordMap;
37use rustc_index::{IndexSlice, IndexVec};
38use rustc_middle::{mir::RETURN_PLACE, ty::TyCtxt};
39use rustc_span::{Span, Symbol};
40use rustc_type_ir::BoundVar;
41use serde::Serialize;
42
43use self::place_ty::{LocKind, PlacesTree};
44use super::rty::Sort;
45
46#[derive(Clone, Default)]
47pub struct TypeEnv<'a> {
48    bindings: PlacesTree,
49    local_decls: &'a LocalDecls,
50}
51
/// Environment shape inferred for a basic block before it is turned into a [`BasicBlockEnv`].
///
/// It is built from a [`TypeEnv`] by packing its bindings (types whose indices have free
/// variables with respect to `scope` are generalized), and it is widened with further
/// environments through `BasicBlockEnvShape::join`.
pub struct BasicBlockEnvShape {
    /// Scope the basic block is checked in.
    scope: Scope,
    /// Packed type of each place on entry to the block.
    bindings: PlacesTree,
}
56
/// The environment assumed on entry to a basic block.
///
/// Holds the bindings and entry constraints abstracted over a set of refinement variables,
/// together with the scope in which the environment was created.
pub struct BasicBlockEnv {
    /// Entry constraints and bindings under a binder of refinement variables.
    data: Binder<BasicBlockEnvData>,
    /// Scope in which this environment was created.
    scope: Scope,
}
61
/// Contents of a [`BasicBlockEnv`] under its binder.
#[derive(Debug)]
struct BasicBlockEnvData {
    /// Predicates that must hold on entry: checked on `goto`, assumed when entering the block.
    constrs: List<Expr>,
    /// Type of each place on entry to the block.
    bindings: PlacesTree,
}
67
68impl<'a> TypeEnv<'a> {
69    pub fn new(infcx: &mut InferCtxt, body: &'a Body, fn_sig: &FnSig) -> TypeEnv<'a> {
70        let mut env = TypeEnv { bindings: PlacesTree::default(), local_decls: &body.local_decls };
71
72        for requires in fn_sig.requires() {
73            infcx.assume_pred(requires);
74        }
75
76        for (local, ty) in body.args_iter().zip(fn_sig.inputs()) {
77            let ty = infcx.unpack(ty);
78            infcx.assume_invariants(&ty);
79            env.alloc_with_ty(local, ty);
80        }
81
82        for local in body.vars_and_temps_iter() {
83            env.alloc(local);
84        }
85
86        env.alloc(RETURN_PLACE);
87        env
88    }
89
90    pub fn empty() -> TypeEnv<'a> {
91        TypeEnv { bindings: PlacesTree::default(), local_decls: IndexSlice::empty() }
92    }
93
94    fn alloc_with_ty(&mut self, local: Local, ty: Ty) {
95        let ty = ty_match_regions(&ty, &self.local_decls[local].ty);
96        self.bindings.insert(local.into(), LocKind::Local, ty);
97    }
98
99    fn alloc(&mut self, local: Local) {
100        self.bindings
101            .insert(local.into(), LocKind::Local, Ty::uninit());
102    }
103
104    pub(crate) fn into_infer(self, scope: Scope) -> BasicBlockEnvShape {
105        BasicBlockEnvShape::new(scope, self)
106    }
107
108    pub(crate) fn lookup_rust_ty(&self, genv: GlobalEnv, place: &Place) -> QueryResult<ty::Ty> {
109        Ok(place.ty(genv, self.local_decls)?.ty)
110    }
111
112    pub(crate) fn lookup_place(
113        &mut self,
114        infcx: &mut InferCtxtAt,
115        place: &Place,
116    ) -> InferResult<Ty> {
117        let span = infcx.span;
118        Ok(self.bindings.lookup_unfolding(infcx, place, span)?.ty)
119    }
120
121    pub(crate) fn get(&self, path: &Path) -> Ty {
122        self.bindings.get(path)
123    }
124
125    pub fn update_path(&mut self, path: &Path, new_ty: Ty, span: Span) {
126        self.bindings.lookup(path, span).update(new_ty);
127    }
128
129    /// When checking a borrow in the right hand side of an assignment `x = &'?n p`, we use the
130    /// annotated region `'?n` in the type of the result. This region will only be used temporarily
131    /// and then replaced by the region in the type of `x` after the assignment. See [`TypeEnv::assign`]
132    pub(crate) fn borrow(
133        &mut self,
134        infcx: &mut InferCtxtAt,
135        re: Region,
136        mutbl: Mutability,
137        place: &Place,
138    ) -> InferResult<Ty> {
139        let span = infcx.span;
140        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
141        if result.is_strg && mutbl == Mutability::Mut {
142            Ok(Ty::ptr(PtrKind::Mut(re), result.path()))
143        } else {
144            // FIXME(nilehmann) we should block the place here. That would require a notion
145            // of shared vs mutable block types because sometimes blocked places from a shared
146            // reference never get unblocked and we should still allow reads through them.
147            Ok(Ty::mk_ref(re, result.ty, mutbl))
148        }
149    }
150
151    // FIXME(nilehmann) this is only used in a single place and we have it because [`TypeEnv`]
152    // doesn't expose a lookup without unfolding
153    pub(crate) fn ptr_to_ref_at_place(
154        &mut self,
155        infcx: &mut InferCtxtAt,
156        place: &Place,
157    ) -> InferResult {
158        let lookup = self.bindings.lookup(place, infcx.span);
159        let TyKind::Ptr(PtrKind::Mut(re), path) = lookup.ty.kind() else {
160            tracked_span_bug!("ptr_to_borrow called on non mutable pointer type")
161        };
162
163        let ref_ty =
164            self.ptr_to_ref(infcx, ConstrReason::Other, *re, path, PtrToRefBound::Infer)?;
165
166        self.bindings.lookup(place, infcx.span).update(ref_ty);
167
168        Ok(())
169    }
170
171    /// Convert a (strong) pointer to a mutable reference.
172    ///
173    /// This roughly implements the following inference rule:
174    /// ```text
175    ///                   t₁ <: t₂
176    /// -------------------------------------------------
177    /// Γ₁,ℓ:t1,Γ₂ ; ptr(mut, ℓ) => Γ₁,ℓ:†t₂,Γ₂ ; &mut t2
178    /// ```
179    /// That's it, we first get the current type `t₁` at location `ℓ` and check it is a subtype
180    /// of `t₂`. Then, we update the type of `ℓ` to `t₂` and block the place.
181    ///
182    /// The bound `t₂` can be either inferred ([`PtrToRefBound::Infer`]), explicitly provided
183    /// ([`PtrToRefBound::Ty`]), or made equal to `t₁` ([`PtrToRefBound::Identity`]).
184    ///
185    /// As an example, consider the environment `x: i32[a]` and the pointer `ptr(mut, x)`.
186    /// Converting the pointer to a mutable reference with an inferred bound produces the following
187    /// derivation (roughly):
188    ///
189    /// ```text
190    ///                    i32[a] <: i32{v: $k(v)}
191    /// ----------------------------------------------------------------
192    /// x: i32[a] ; ptr(mut, x) => x:†i32{v: $k(v)} ; &mut i32{v: $k(v)}
193    /// ```
194    pub(crate) fn ptr_to_ref(
195        &mut self,
196        infcx: &mut InferCtxtAt,
197        reason: ConstrReason,
198        re: Region,
199        path: &Path,
200        bound: PtrToRefBound,
201    ) -> InferResult<Ty> {
202        // ℓ: t1
203        let t1 = self.bindings.lookup(path, infcx.span).fold(infcx)?;
204
205        // t1 <: t2
206        let t2 = match bound {
207            PtrToRefBound::Ty(t2) => {
208                let t2 = rty_match_regions(&t2, &t1);
209                infcx.subtyping(&t1, &t2, reason)?;
210                t2
211            }
212            PtrToRefBound::Infer => {
213                let t2 = t1.with_holes().replace_holes(|sorts, kind| {
214                    debug_assert_eq!(kind, HoleKind::Pred);
215                    infcx.fresh_kvar(sorts, KVarEncoding::Conj)
216                });
217                infcx.subtyping(&t1, &t2, reason)?;
218                t2
219            }
220            PtrToRefBound::Identity => t1.clone(),
221        };
222
223        // ℓ: †t2
224        self.bindings
225            .lookup(path, infcx.span)
226            .block_with(t2.clone());
227
228        Ok(Ty::mk_ref(re, t2, Mutability::Mut))
229    }
230
231    pub(crate) fn fold_local_ptrs(&mut self, infcx: &mut InferCtxtAt) -> InferResult {
232        for (loc, bound, ty) in self.bindings.local_ptrs() {
233            infcx.subtyping(&ty, &bound, ConstrReason::FoldLocal)?;
234            self.bindings.remove_local(&loc);
235        }
236        Ok(())
237    }
238
239    /// Updates the type of `place` to `new_ty`. This may involve a *strong update* if we have
240    /// ownership of `place` or a *weak update* if it's behind a reference (which fires a subtyping
241    /// constraint)
242    ///
243    /// When strong updating, the process involves recovering the original regions (lifetimes) used
244    /// in the (unrefined) Rust type of `place` and then substituting these regions in `new_ty`. For
245    /// instance, if we are assigning a value of type `S<&'?10 i32{v: v > 0}>` to a variable `x`,
246    /// and the (unrefined) Rust type of `x` is `S<&'?5 i32>`, before the assignment, we identify a
247    /// substitution that maps the region `'?10` to `'?5`. After applying this substitution, the
248    /// type of the place `x` is updated accordingly. This ensures that the lifetimes in the
249    /// assigned type are consistent with those expected by the place's original type definition.
250    pub(crate) fn assign(
251        &mut self,
252        infcx: &mut InferCtxtAt,
253        place: &Place,
254        new_ty: Ty,
255    ) -> InferResult {
256        let rustc_ty = place.ty(infcx.genv, self.local_decls)?.ty;
257        let new_ty = ty_match_regions(&new_ty, &rustc_ty);
258        let span = infcx.span;
259        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
260        if result.is_strg {
261            result.update(new_ty);
262        } else if !place.behind_raw_ptr(infcx.genv, self.local_decls)? {
263            infcx.subtyping(&new_ty, &result.ty, ConstrReason::Assign)?;
264        }
265        Ok(())
266    }
267
268    pub(crate) fn move_place(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult<Ty> {
269        let span = infcx.span;
270        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
271        if result.is_strg {
272            let uninit = Ty::uninit();
273            Ok(result.update(uninit))
274        } else {
275            // ignore the 'move' and trust rustc managed the move correctly
276            // https://github.com/flux-rs/flux/issues/725#issuecomment-2295065634
277            Ok(result.ty)
278        }
279    }
280
281    pub(crate) fn unpack(&mut self, infcx: &mut InferCtxt) {
282        self.bindings.fmap_mut(|ty| infcx.hoister(true).hoist(ty));
283    }
284
285    pub(crate) fn unblock(&mut self, infcx: &mut InferCtxt, place: &Place) {
286        self.bindings.unblock(infcx, place);
287    }
288
289    pub(crate) fn check_goto(
290        self,
291        infcx: &mut InferCtxtAt,
292        bb_env: &BasicBlockEnv,
293        target: BasicBlock,
294    ) -> InferResult {
295        infcx.ensure_resolved_evars(|infcx| {
296            let bb_env = bb_env
297                .data
298                .replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
299
300            // Check constraints
301            for constr in &bb_env.constrs {
302                infcx.check_pred(constr, ConstrReason::Goto(target));
303            }
304
305            // Check subtyping
306            let bb_env = bb_env.bindings.flatten();
307            for (path, _, ty2) in bb_env {
308                let ty1 = self.bindings.get(&path);
309                infcx.subtyping(&ty1.unblocked(), &ty2.unblocked(), ConstrReason::Goto(target))?;
310            }
311            Ok(())
312        })
313    }
314
315    pub(crate) fn fold(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult {
316        let span = infcx.span;
317        self.bindings.lookup(place, span).fold(infcx)?;
318        Ok(())
319    }
320
321    pub(crate) fn unfold_local_ptr(
322        &mut self,
323        infcx: &mut InferCtxt,
324        bound: &Ty,
325    ) -> InferResult<Loc> {
326        let loc = Loc::from(infcx.define_var(&Sort::Loc));
327        let ty = infcx.unpack(bound);
328        self.bindings
329            .insert(loc, LocKind::LocalPtr(bound.clone()), ty);
330        Ok(loc)
331    }
332
333    /// ```text
334    /// -----------------------------------
335    /// Γ ; &strg <ℓ: t> => Γ,ℓ: t ; ptr(ℓ)
336    /// ```
337    pub(crate) fn unfold_strg_ref(
338        &mut self,
339        infcx: &mut InferCtxt,
340        path: &Path,
341        ty: &Ty,
342    ) -> InferResult<Loc> {
343        if let Some(loc) = path.to_loc() {
344            let ty = infcx.unpack(ty);
345            self.bindings.insert(loc, LocKind::Universal, ty);
346            Ok(loc)
347        } else {
348            bug!("unfold_strg_ref: unexpected path {path:?}")
349        }
350    }
351
352    pub(crate) fn unfold(
353        &mut self,
354        infcx: &mut InferCtxt,
355        place: &Place,
356        span: Span,
357    ) -> InferResult {
358        self.bindings.unfold(infcx, place, span)
359    }
360
361    pub(crate) fn downcast(
362        &mut self,
363        infcx: &mut InferCtxtAt,
364        place: &Place,
365        variant_idx: VariantIdx,
366    ) -> InferResult {
367        let mut down_place = place.clone();
368        let span = infcx.span;
369        down_place
370            .projection
371            .push(PlaceElem::Downcast(None, variant_idx));
372        self.bindings.unfold(infcx, &down_place, span)?;
373        Ok(())
374    }
375
376    pub fn fully_resolve_evars(&mut self, infcx: &InferCtxt) {
377        self.bindings.fmap_mut(|ty| infcx.fully_resolve_evars(ty));
378    }
379
380    pub(crate) fn assume_ensures(
381        &mut self,
382        infcx: &mut InferCtxt,
383        ensures: &[Ensures],
384        span: Span,
385    ) {
386        for ensure in ensures {
387            match ensure {
388                Ensures::Type(path, updated_ty) => {
389                    let updated_ty = infcx.unpack(updated_ty);
390                    infcx.assume_invariants(&updated_ty);
391                    self.update_path(path, updated_ty, span);
392                }
393                Ensures::Pred(e) => infcx.assume_pred(e),
394            }
395        }
396    }
397
398    pub(crate) fn check_ensures(
399        &mut self,
400        at: &mut InferCtxtAt,
401        ensures: &[Ensures],
402        reason: ConstrReason,
403    ) -> InferResult {
404        for constraint in ensures {
405            match constraint {
406                Ensures::Type(path, ty) => {
407                    let actual_ty = self.get(path);
408                    at.subtyping(&actual_ty, ty, reason)?;
409                }
410                Ensures::Pred(e) => {
411                    at.check_pred(e, ConstrReason::Ret);
412                }
413            }
414        }
415        Ok(())
416    }
417}
418
/// How the bound `t₂` is chosen when [`TypeEnv::ptr_to_ref`] converts a pointer into a mutable
/// reference.
pub(crate) enum PtrToRefBound {
    /// Use the given type as the bound (after matching its regions against the current type);
    /// the current type must be a subtype of it.
    Ty(Ty),
    /// Infer the bound by replacing the holes of the current type with fresh kvars; the current
    /// type must be a subtype of it.
    Infer,
    /// Use the current type itself as the bound; no subtyping constraint is generated.
    Identity,
}
424
425impl flux_infer::infer::LocEnv for TypeEnv<'_> {
426    fn ptr_to_ref(
427        &mut self,
428        infcx: &mut InferCtxtAt,
429        reason: ConstrReason,
430        re: Region,
431        path: &Path,
432        bound: Ty,
433    ) -> InferResult<Ty> {
434        self.ptr_to_ref(infcx, reason, re, path, PtrToRefBound::Ty(bound))
435    }
436
437    fn get(&self, path: &Path) -> Ty {
438        self.get(path)
439    }
440
441    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc> {
442        self.unfold_strg_ref(infcx, path, ty)
443    }
444}
445
446impl BasicBlockEnvShape {
447    pub fn enter<'a>(&self, local_decls: &'a LocalDecls) -> TypeEnv<'a> {
448        TypeEnv { bindings: self.bindings.clone(), local_decls }
449    }
450
451    fn new(scope: Scope, env: TypeEnv) -> BasicBlockEnvShape {
452        let mut bindings = env.bindings;
453        bindings.fmap_mut(|ty| BasicBlockEnvShape::pack_ty(&scope, ty));
454        BasicBlockEnvShape { scope, bindings }
455    }
456
457    fn pack_ty(scope: &Scope, ty: &Ty) -> Ty {
458        match ty.kind() {
459            TyKind::Indexed(bty, idxs) => {
460                let bty = BasicBlockEnvShape::pack_bty(scope, bty);
461                if scope.has_free_vars(idxs) {
462                    Ty::exists_with_constr(bty, Expr::hole(HoleKind::Pred))
463                } else {
464                    Ty::indexed(bty, idxs.clone())
465                }
466            }
467            TyKind::Downcast(adt, args, ty, variant, fields) => {
468                debug_assert!(!scope.has_free_vars(args));
469                debug_assert!(!scope.has_free_vars(ty));
470                let fields = fields.iter().map(|ty| Self::pack_ty(scope, ty)).collect();
471                Ty::downcast(adt.clone(), args.clone(), ty.clone(), *variant, fields)
472            }
473            TyKind::Blocked(ty) => Ty::blocked(BasicBlockEnvShape::pack_ty(scope, ty)),
474            // FIXME(nilehmann) [`TyKind::Exists`] could also contain free variables.
475            TyKind::Exists(_)
476            | TyKind::Discr(..)
477            | TyKind::Ptr(..)
478            | TyKind::Uninit
479            | TyKind::Param(_)
480            | TyKind::Constr(_, _) => ty.clone(),
481            TyKind::Infer(_) => bug!("unexpected hole whecn checking function body"),
482            TyKind::StrgRef(..) => bug!("unexpected strong reference when checking function body"),
483        }
484    }
485
486    fn pack_bty(scope: &Scope, bty: &BaseTy) -> BaseTy {
487        match bty {
488            BaseTy::Adt(adt_def, args) => {
489                let args = List::from_vec(
490                    args.iter()
491                        .map(|arg| Self::pack_generic_arg(scope, arg))
492                        .collect(),
493                );
494                BaseTy::adt(adt_def.clone(), args)
495            }
496            BaseTy::FnDef(def_id, args) => {
497                let args = List::from_vec(
498                    args.iter()
499                        .map(|arg| Self::pack_generic_arg(scope, arg))
500                        .collect(),
501                );
502                BaseTy::fn_def(*def_id, args)
503            }
504            BaseTy::Tuple(tys) => {
505                let tys = tys
506                    .iter()
507                    .map(|ty| BasicBlockEnvShape::pack_ty(scope, ty))
508                    .collect();
509                BaseTy::Tuple(tys)
510            }
511            BaseTy::Slice(ty) => BaseTy::Slice(Self::pack_ty(scope, ty)),
512            BaseTy::Ref(r, ty, mutbl) => BaseTy::Ref(*r, Self::pack_ty(scope, ty), *mutbl),
513            BaseTy::Array(ty, c) => BaseTy::Array(Self::pack_ty(scope, ty), c.clone()),
514            BaseTy::Int(_)
515            | BaseTy::Param(_)
516            | BaseTy::Uint(_)
517            | BaseTy::Bool
518            | BaseTy::Float(_)
519            | BaseTy::Str
520            | BaseTy::RawPtr(_, _)
521            | BaseTy::RawPtrMetadata(_)
522            | BaseTy::Char
523            | BaseTy::Never
524            | BaseTy::Closure(..)
525            | BaseTy::Dynamic(..)
526            | BaseTy::Alias(..)
527            | BaseTy::FnPtr(..)
528            | BaseTy::Foreign(..)
529            | BaseTy::Coroutine(..) => {
530                if scope.has_free_vars(bty) {
531                    tracked_span_bug!("unexpected type with free vars")
532                } else {
533                    bty.clone()
534                }
535            }
536            BaseTy::Infer(..) => {
537                tracked_span_bug!("unexpected infer type")
538            }
539        }
540    }
541
542    fn pack_generic_arg(scope: &Scope, arg: &GenericArg) -> GenericArg {
543        match arg {
544            GenericArg::Ty(ty) => GenericArg::Ty(Self::pack_ty(scope, ty)),
545            GenericArg::Base(arg) => {
546                assert!(!scope.has_free_vars(arg));
547                GenericArg::Base(arg.clone())
548            }
549            GenericArg::Lifetime(re) => GenericArg::Lifetime(*re),
550            GenericArg::Const(c) => GenericArg::Const(c.clone()),
551        }
552    }
553
554    fn update(&mut self, path: &Path, ty: Ty, span: Span) {
555        self.bindings.lookup(path, span).update(ty);
556    }
557
558    /// join(self, genv, other) consumes the bindings in other, to "update"
559    /// `self` in place, and returns `true` if there was an actual change
560    /// or `false` indicating no change (i.e., a fixpoint was reached).
561    pub(crate) fn join(&mut self, other: TypeEnv, span: Span) -> bool {
562        let paths = self.bindings.paths();
563
564        // Join types
565        let mut modified = false;
566        for path in &paths {
567            let ty1 = self.bindings.get(path);
568            let ty2 = other.bindings.get(path);
569            let ty = self.join_ty(&ty1, &ty2);
570            modified |= ty1 != ty;
571            self.update(path, ty, span);
572        }
573
574        modified
575    }
576
577    fn join_ty(&self, ty1: &Ty, ty2: &Ty) -> Ty {
578        match (ty1.kind(), ty2.kind()) {
579            (TyKind::Blocked(ty1), _) => Ty::blocked(self.join_ty(ty1, &ty2.unblocked())),
580            (_, TyKind::Blocked(ty2)) => Ty::blocked(self.join_ty(&ty1.unblocked(), ty2)),
581            (TyKind::Uninit, _) | (_, TyKind::Uninit) => Ty::uninit(),
582            (TyKind::Exists(ty1), _) => self.join_ty(ty1.as_ref().skip_binder(), ty2),
583            (_, TyKind::Exists(ty2)) => self.join_ty(ty1, ty2.as_ref().skip_binder()),
584            (TyKind::Constr(_, ty1), _) => self.join_ty(ty1, ty2),
585            (_, TyKind::Constr(_, ty2)) => self.join_ty(ty1, ty2),
586            (TyKind::Indexed(bty1, idx1), TyKind::Indexed(bty2, idx2)) => {
587                let bty = self.join_bty(bty1, bty2);
588                let mut sorts = vec![];
589                let idx = self.join_idx(idx1, idx2, &bty.sort(), &mut sorts);
590                if sorts.is_empty() {
591                    Ty::indexed(bty, idx)
592                } else {
593                    let ty = Ty::constr(Expr::hole(HoleKind::Pred), Ty::indexed(bty, idx));
594                    Ty::exists(Binder::bind_with_sorts(ty, &sorts))
595                }
596            }
597            (TyKind::Ptr(rk1, path1), TyKind::Ptr(rk2, path2)) => {
598                debug_assert_eq!(rk1, rk2);
599                debug_assert_eq!(path1, path2);
600                Ty::ptr(*rk1, path1.clone())
601            }
602            (TyKind::Param(param_ty1), TyKind::Param(param_ty2)) => {
603                debug_assert_eq!(param_ty1, param_ty2);
604                Ty::param(*param_ty1)
605            }
606            (
607                TyKind::Downcast(adt1, args1, ty1, variant1, fields1),
608                TyKind::Downcast(adt2, args2, ty2, variant2, fields2),
609            ) => {
610                debug_assert_eq!(adt1, adt2);
611                debug_assert_eq!(args1, args2);
612                debug_assert!(ty1 == ty2 && !self.scope.has_free_vars(ty2));
613                debug_assert_eq!(variant1, variant2);
614                debug_assert_eq!(fields1.len(), fields2.len());
615                let fields = iter::zip(fields1, fields2)
616                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
617                    .collect();
618                Ty::downcast(adt1.clone(), args1.clone(), ty1.clone(), *variant1, fields)
619            }
620            _ => tracked_span_bug!("unexpected types: `{ty1:?}` - `{ty2:?}`"),
621        }
622    }
623
624    fn join_idx(&self, e1: &Expr, e2: &Expr, sort: &Sort, bound_sorts: &mut Vec<Sort>) -> Expr {
625        match (e1.kind(), e2.kind(), sort) {
626            (ExprKind::Tuple(es1), ExprKind::Tuple(es2), Sort::Tuple(sorts)) => {
627                debug_assert_eq3!(es1.len(), es2.len(), sorts.len());
628                Expr::tuple(
629                    izip!(es1, es2, sorts)
630                        .map(|(e1, e2, sort)| self.join_idx(e1, e2, sort, bound_sorts))
631                        .collect(),
632                )
633            }
634            (
635                ExprKind::Ctor(Ctor::Struct(_), flds1),
636                ExprKind::Ctor(Ctor::Struct(_), flds2),
637                Sort::App(SortCtor::Adt(sort_def), args),
638            ) => {
639                let sorts = sort_def.struct_variant().field_sorts(args);
640                debug_assert_eq3!(flds1.len(), flds2.len(), sorts.len());
641
642                Expr::ctor_struct(
643                    sort_def.did(),
644                    izip!(flds1, flds2, &sorts)
645                        .map(|(f1, f2, sort)| self.join_idx(f1, f2, sort, bound_sorts))
646                        .collect(),
647                )
648            }
649            _ => {
650                let has_free_vars2 = self.scope.has_free_vars(e2);
651                let has_escaping_vars1 = e1.has_escaping_bvars();
652                let has_escaping_vars2 = e2.has_escaping_bvars();
653                if !has_free_vars2 && !has_escaping_vars1 && !has_escaping_vars2 && e1 == e2 {
654                    e1.clone()
655                } else if sort.is_pred() {
656                    // FIXME(nilehmann) we shouldn't special case predicates here. Instead, we
657                    // should differentiate between generics and indices.
658                    let fsort = sort.expect_func().expect_mono();
659                    Expr::abs(Lambda::bind_with_fsort(Expr::hole(HoleKind::Pred), fsort))
660                } else {
661                    bound_sorts.push(sort.clone());
662                    Expr::bvar(
663                        INNERMOST,
664                        BoundVar::from_usize(bound_sorts.len() - 1),
665                        BoundReftKind::Anon,
666                    )
667                }
668            }
669        }
670    }
671
672    fn join_bty(&self, bty1: &BaseTy, bty2: &BaseTy) -> BaseTy {
673        match (bty1, bty2) {
674            (BaseTy::Adt(def1, args1), BaseTy::Adt(def2, args2)) => {
675                tracked_span_dbg_assert_eq!(def1.did(), def2.did());
676                let args = iter::zip(args1, args2)
677                    .map(|(arg1, arg2)| self.join_generic_arg(arg1, arg2))
678                    .collect();
679                BaseTy::adt(def1.clone(), List::from_vec(args))
680            }
681            (BaseTy::Tuple(fields1), BaseTy::Tuple(fields2)) => {
682                let fields = iter::zip(fields1, fields2)
683                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
684                    .collect();
685                BaseTy::Tuple(fields)
686            }
687            (BaseTy::Alias(kind1, alias_ty1), BaseTy::Alias(kind2, alias_ty2)) => {
688                tracked_span_dbg_assert_eq!(kind1, kind2);
689                tracked_span_dbg_assert_eq!(alias_ty1, alias_ty2);
690                BaseTy::Alias(*kind1, alias_ty1.clone())
691            }
692            (BaseTy::Ref(r1, ty1, mutbl1), BaseTy::Ref(r2, ty2, mutbl2)) => {
693                tracked_span_dbg_assert_eq!(r1, r2);
694                tracked_span_dbg_assert_eq!(mutbl1, mutbl2);
695                BaseTy::Ref(*r1, self.join_ty(ty1, ty2), *mutbl1)
696            }
697            (BaseTy::Array(ty1, len1), BaseTy::Array(ty2, len2)) => {
698                tracked_span_dbg_assert_eq!(len1, len2);
699                BaseTy::Array(self.join_ty(ty1, ty2), len1.clone())
700            }
701            (BaseTy::Slice(ty1), BaseTy::Slice(ty2)) => BaseTy::Slice(self.join_ty(ty1, ty2)),
702            _ => {
703                tracked_span_dbg_assert_eq!(bty1, bty2);
704                bty1.clone()
705            }
706        }
707    }
708
709    fn join_generic_arg(&self, arg1: &GenericArg, arg2: &GenericArg) -> GenericArg {
710        match (arg1, arg2) {
711            (GenericArg::Ty(ty1), GenericArg::Ty(ty2)) => GenericArg::Ty(self.join_ty(ty1, ty2)),
712            (GenericArg::Base(ctor1), GenericArg::Base(ctor2)) => {
713                let sty1 = ctor1.as_ref().skip_binder();
714                let sty2 = ctor2.as_ref().skip_binder();
715                debug_assert!(sty1.idx.is_nu());
716                debug_assert!(sty2.idx.is_nu());
717
718                let bty = self.join_bty(&sty1.bty, &sty2.bty);
719                let pred = if self.scope.has_free_vars(&sty2.pred) || sty1.pred != sty2.pred {
720                    Expr::hole(HoleKind::Pred)
721                } else {
722                    sty1.pred.clone()
723                };
724                let sort = bty.sort();
725                let ctor = Binder::bind_with_sort(SubsetTy::new(bty, Expr::nu(), pred), sort);
726                GenericArg::Base(ctor)
727            }
728            (GenericArg::Lifetime(re1), GenericArg::Lifetime(_re2)) => {
729                // TODO(nilehmann) loop_abstract_refinement.rs is triggering this assertion to fail
730                // wee should fix it.
731                // debug_assert_eq!(re1, _re2);
732                GenericArg::Lifetime(*re1)
733            }
734            (GenericArg::Const(c1), GenericArg::Const(c2)) => {
735                debug_assert_eq!(c1, c2);
736                GenericArg::Const(c1.clone())
737            }
738            _ => tracked_span_bug!("unexpected generic args: `{arg1:?}` - `{arg2:?}`"),
739        }
740    }
741
742    pub fn into_bb_env(self, infcx: &mut InferCtxtRoot) -> BasicBlockEnv {
743        let mut delegate = LocalHoister::default();
744        let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
745
746        let mut bindings = self.bindings;
747        bindings.fmap_mut(|ty| hoister.hoist(ty));
748
749        BasicBlockEnv {
750            // We are relying on all the types in `bindings` not having escaping bvars, otherwise
751            // we would have to shift them in since we are creating a new binder.
752            data: delegate.bind(|vars, preds| {
753                // Replace all holes with a single fresh kvar on all parameters
754                let mut constrs = preds
755                    .into_iter()
756                    .filter(|pred| !matches!(pred.kind(), ExprKind::Hole(HoleKind::Pred)))
757                    .collect_vec();
758                let kvar = infcx.fresh_kvar_in_scope(
759                    std::slice::from_ref(&vars),
760                    &self.scope,
761                    KVarEncoding::Conj,
762                );
763                constrs.push(kvar);
764
765                // Replace remaining holes by fresh kvars
766                let mut kvar_gen = |binders: &[_], kind| {
767                    debug_assert_eq!(kind, HoleKind::Pred);
768                    let binders = std::iter::once(vars.clone())
769                        .chain(binders.iter().cloned())
770                        .collect_vec();
771                    infcx.fresh_kvar_in_scope(&binders, &self.scope, KVarEncoding::Conj)
772                };
773                bindings.fmap_mut(|binding| binding.replace_holes(&mut kvar_gen));
774
775                BasicBlockEnvData { constrs: constrs.into(), bindings }
776            }),
777            scope: self.scope,
778        }
779    }
780}
781
782impl TypeVisitable for BasicBlockEnvData {
783    fn visit_with<V: TypeVisitor>(&self, _visitor: &mut V) -> ControlFlow<V::BreakTy> {
784        unimplemented!()
785    }
786}
787
788impl TypeFoldable for BasicBlockEnvData {
789    fn try_fold_with<F: FallibleTypeFolder>(
790        &self,
791        folder: &mut F,
792    ) -> std::result::Result<Self, F::Error> {
793        Ok(BasicBlockEnvData {
794            constrs: self.constrs.try_fold_with(folder)?,
795            bindings: self.bindings.try_fold_with(folder)?,
796        })
797    }
798}
799
800impl BasicBlockEnv {
801    pub(crate) fn enter<'a>(
802        &self,
803        infcx: &mut InferCtxt,
804        local_decls: &'a LocalDecls,
805    ) -> TypeEnv<'a> {
806        let data = self
807            .data
808            .replace_bound_refts_with(|sort, _, _| Expr::fvar(infcx.define_var(sort)));
809        for constr in &data.constrs {
810            infcx.assume_pred(constr);
811        }
812        TypeEnv { bindings: data.bindings, local_decls }
813    }
814
815    pub(crate) fn scope(&self) -> &Scope {
816        &self.scope
817    }
818}
819
820mod pretty {
821    use std::fmt;
822
823    use flux_middle::pretty::*;
824
825    use super::*;
826
827    impl Pretty for TypeEnv<'_> {
828        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
829            w!(cx, f, "{:?}", &self.bindings)
830        }
831
832        fn default_cx(tcx: TyCtxt) -> PrettyCx {
833            PlacesTree::default_cx(tcx)
834        }
835    }
836
837    impl Pretty for BasicBlockEnvShape {
838        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
839            w!(cx, f, "{:?} {:?}", &self.scope, &self.bindings)
840        }
841
842        fn default_cx(tcx: TyCtxt) -> PrettyCx {
843            PlacesTree::default_cx(tcx)
844        }
845    }
846
847    impl Pretty for BasicBlockEnv {
848        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
849            w!(cx, f, "{:?} ", &self.scope)?;
850
851            let vars = self.data.vars();
852            cx.with_bound_vars(vars, || {
853                if !vars.is_empty() {
854                    cx.fmt_bound_vars(true, "for<", vars, "> ", f)?;
855                }
856                let data = self.data.as_ref().skip_binder();
857                if !data.constrs.is_empty() {
858                    w!(
859                        cx,
860                        f,
861                        "{:?} ⇒ ",
862                        join!(", ", data.constrs.iter().filter(|pred| !pred.is_trivially_true()))
863                    )?;
864                }
865                w!(cx, f, "{:?}", &data.bindings)
866            })
867        }
868
869        fn default_cx(tcx: TyCtxt) -> PrettyCx {
870            PlacesTree::default_cx(tcx)
871        }
872    }
873
874    impl_debug_with_default_cx! {
875        TypeEnv<'_> => "type_env",
876        BasicBlockEnvShape => "basic_block_env_shape",
877        BasicBlockEnv => "basic_block_env"
878    }
879}
880
/// A very explicit representation of [`TypeEnv`] for debugging/tracing/serialization ONLY.
///
/// Holds one [`TypeEnvBind`] per binding whose type is not uninitialized, ordered by location
/// (see [`TypeEnvTrace::new`]).
#[derive(Serialize, DebugAsJson)]
pub struct TypeEnvTrace(Vec<TypeEnvBind>);
884
/// A single binding of a [`TypeEnvTrace`], with every component pre-rendered to a string.
///
/// Field names and order form the serialized shape, so they should not be renamed casually.
#[derive(Serialize)]
struct TypeEnvBind {
    /// The location the binding is attached to.
    local: LocInfo,
    /// Source-level name of the local, when the location is a local with a known name.
    name: Option<String>,
    /// `Debug` rendering of the binding's kind.
    kind: String,
    /// The binding's type, pretty-printed with regions hidden.
    ty: String,
    /// Span of the local's declaration, when the location is a local present in the
    /// declarations.
    span: Option<SpanTrace>,
}
893
/// Serializable form of a [`Loc`]. The variant is kept and the payload is rendered with its
/// `Debug` implementation.
#[derive(Serialize)]
enum LocInfo {
    /// A MIR local.
    Local(String),
    /// A variable location.
    Var(String),
}
899
900fn loc_info(loc: &Loc) -> LocInfo {
901    match loc {
902        Loc::Local(local) => LocInfo::Local(format!("{local:?}")),
903        Loc::Var(var) => LocInfo::Var(format!("{var:?}")),
904    }
905}
906
907fn loc_name(local_names: &UnordMap<Local, Symbol>, loc: &Loc) -> Option<String> {
908    if let Loc::Local(local) = loc {
909        let name = local_names.get(local)?;
910        return Some(format!("{name}"));
911    }
912    None
913}
914
915fn loc_span(
916    genv: GlobalEnv,
917    local_decls: &IndexVec<Local, LocalDecl>,
918    loc: &Loc,
919) -> Option<SpanTrace> {
920    if let Loc::Local(local) = loc {
921        return local_decls
922            .get(*local)
923            .map(|local_decl| SpanTrace::new(genv.tcx(), local_decl.source_info.span));
924    }
925    None
926}
927
928impl TypeEnvTrace {
929    pub fn new(
930        genv: GlobalEnv,
931        local_names: &UnordMap<Local, Symbol>,
932        local_decls: &IndexVec<Local, LocalDecl>,
933        env: &TypeEnv,
934    ) -> Self {
935        let mut bindings = vec![];
936        let cx = PrettyCx::default(genv).hide_regions(true);
937        env.bindings
938            .iter()
939            .filter(|(_, binding)| !binding.ty.is_uninit())
940            .sorted_by(|(loc1, _), (loc2, _)| loc1.cmp(loc2))
941            .for_each(|(loc, binding)| {
942                let name = loc_name(local_names, loc);
943                let local = loc_info(loc);
944                let kind = format!("{:?}", binding.kind);
945                let ty = binding.ty.nested_string(&cx);
946                let span = loc_span(genv, local_decls, loc);
947                bindings.push(TypeEnvBind { name, local, kind, ty, span });
948            });
949
950        TypeEnvTrace(bindings)
951    }
952}