flux_refineck/
type_env.rs

1mod place_ty;
2
3use std::{iter, ops::ControlFlow};
4
5use flux_common::{
6    bug,
7    dbg::{SpanTrace, debug_assert_eq3},
8    tracked_span_bug, tracked_span_dbg_assert_eq,
9};
10use flux_infer::{
11    fixpoint_encoding::KVarEncoding,
12    infer::{ConstrReason, InferCtxt, InferCtxtAt, InferCtxtRoot, InferResult},
13    refine_tree::Scope,
14};
15use flux_macros::DebugAsJson;
16use flux_middle::{
17    PlaceExt as _,
18    global_env::GlobalEnv,
19    pretty::{PrettyCx, PrettyNested},
20    queries::QueryResult,
21    rty::{
22        BaseTy, Binder, BoundReftKind, Ctor, Ensures, Expr, ExprKind, FnSig, GenericArg, HoleKind,
23        INNERMOST, Lambda, List, Loc, Mutability, Path, PtrKind, Region, SortCtor, SubsetTy, Ty,
24        TyKind, VariantIdx,
25        canonicalize::{Hoister, LocalHoister},
26        fold::{FallibleTypeFolder, TypeFoldable, TypeVisitable, TypeVisitor},
27        region_matching::{rty_match_regions, ty_match_regions},
28    },
29};
30use flux_rustc_bridge::{
31    self,
32    mir::{BasicBlock, Body, Local, LocalDecl, LocalDecls, Place, PlaceElem},
33    ty,
34};
35use itertools::{Itertools, izip};
36use rustc_data_structures::unord::UnordMap;
37use rustc_index::{IndexSlice, IndexVec};
38use rustc_middle::{mir::RETURN_PLACE, ty::TyCtxt};
39use rustc_span::{Span, Symbol};
40use rustc_type_ir::BoundVar;
41use serde::Serialize;
42
43use self::place_ty::{LocKind, PlacesTree};
44use super::rty::Sort;
45
/// The typing environment used while checking a function body: a tree of refined types bound
/// to places, plus the declarations of the body's locals.
#[derive(Clone, Default)]
pub struct TypeEnv<'a> {
    /// Refined types currently bound to each location.
    bindings: PlacesTree,
    /// Declarations of the locals of the body being checked; consulted to recover the
    /// (unrefined) Rust type of a local or place.
    local_decls: &'a LocalDecls,
}
51
/// The shape of the environment at the entry of a basic block while it is still being inferred.
///
/// It is created from a [`TypeEnv`] by [`TypeEnv::into_infer`], widened with
/// [`BasicBlockEnvShape::join`], and finally turned into a [`BasicBlockEnv`] by
/// [`BasicBlockEnvShape::into_bb_env`].
pub struct BasicBlockEnvShape {
    /// Scope against which types are checked for free variables when packing and joining.
    scope: Scope,
    /// Bindings whose types have been packed with respect to `scope`.
    bindings: PlacesTree,
}
56
/// The environment at the entry of a basic block: constraints and bindings closed under a
/// binder, together with the scope they were created in.
pub struct BasicBlockEnv {
    /// Constraints and bindings; the bound refinement variables are instantiated when the
    /// block is entered ([`BasicBlockEnv::enter`]) or jumped to ([`TypeEnv::check_goto`]).
    data: Binder<BasicBlockEnvData>,
    /// Scope returned by [`BasicBlockEnv::scope`].
    scope: Scope,
}
61
/// Payload of a [`BasicBlockEnv`], living under its binder.
#[derive(Debug)]
struct BasicBlockEnvData {
    /// Predicates assumed when entering the block and checked when jumping to it.
    constrs: List<Expr>,
    /// Types bound to each location at the entry of the block.
    bindings: PlacesTree,
}
67
68impl<'a> TypeEnv<'a> {
69    pub fn new(infcx: &mut InferCtxt, body: &'a Body, fn_sig: &FnSig) -> TypeEnv<'a> {
70        let mut env = TypeEnv { bindings: PlacesTree::default(), local_decls: &body.local_decls };
71
72        for requires in fn_sig.requires() {
73            infcx.assume_pred(requires);
74        }
75
76        for (local, ty) in body.args_iter().zip(fn_sig.inputs()) {
77            let ty = infcx.unpack(ty);
78            infcx.assume_invariants(&ty);
79            env.alloc_with_ty(local, ty);
80        }
81
82        for local in body.vars_and_temps_iter() {
83            env.alloc(local);
84        }
85
86        env.alloc(RETURN_PLACE);
87        env
88    }
89
90    pub fn empty() -> TypeEnv<'a> {
91        TypeEnv { bindings: PlacesTree::default(), local_decls: IndexSlice::empty() }
92    }
93
94    fn alloc_with_ty(&mut self, local: Local, ty: Ty) {
95        let ty = ty_match_regions(&ty, &self.local_decls[local].ty);
96        self.bindings.insert(local.into(), LocKind::Local, ty);
97    }
98
99    fn alloc(&mut self, local: Local) {
100        self.bindings
101            .insert(local.into(), LocKind::Local, Ty::uninit());
102    }
103
104    pub(crate) fn into_infer(self, scope: Scope) -> BasicBlockEnvShape {
105        BasicBlockEnvShape::new(scope, self)
106    }
107
108    pub(crate) fn lookup_rust_ty(&self, genv: GlobalEnv, place: &Place) -> QueryResult<ty::Ty> {
109        Ok(place.ty(genv, self.local_decls)?.ty)
110    }
111
112    pub(crate) fn lookup_place(
113        &mut self,
114        infcx: &mut InferCtxtAt,
115        place: &Place,
116    ) -> InferResult<Ty> {
117        let span = infcx.span;
118        Ok(self.bindings.lookup_unfolding(infcx, place, span)?.ty)
119    }
120
121    pub(crate) fn get(&self, path: &Path) -> Ty {
122        self.bindings.get(path)
123    }
124
125    pub fn update_path(&mut self, path: &Path, new_ty: Ty, span: Span) {
126        self.bindings.lookup(path, span).update(new_ty);
127    }
128
129    /// When checking a borrow in the right hand side of an assignment `x = &'?n p`, we use the
130    /// annotated region `'?n` in the type of the result. This region will only be used temporarily
131    /// and then replaced by the region in the type of `x` after the assignment. See [`TypeEnv::assign`]
132    pub(crate) fn borrow(
133        &mut self,
134        infcx: &mut InferCtxtAt,
135        re: Region,
136        mutbl: Mutability,
137        place: &Place,
138    ) -> InferResult<Ty> {
139        let span = infcx.span;
140        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
141        if result.is_strg && mutbl == Mutability::Mut {
142            Ok(Ty::ptr(PtrKind::Mut(re), result.path()))
143        } else {
144            // FIXME(nilehmann) we should block the place here. That would require a notion
145            // of shared vs mutable block types because sometimes blocked places from a shared
146            // reference never get unblocked and we should still allow reads through them.
147            Ok(Ty::mk_ref(re, result.ty, mutbl))
148        }
149    }
150
151    // FIXME(nilehmann) this is only used in a single place and we have it because [`TypeEnv`]
152    // doesn't expose a lookup without unfolding
153    pub(crate) fn ptr_to_ref_at_place(
154        &mut self,
155        infcx: &mut InferCtxtAt,
156        place: &Place,
157    ) -> InferResult {
158        let lookup = self.bindings.lookup(place, infcx.span);
159        let TyKind::Ptr(PtrKind::Mut(re), path) = lookup.ty.kind() else {
160            tracked_span_bug!("ptr_to_borrow called on non mutable pointer type")
161        };
162
163        let ref_ty =
164            self.ptr_to_ref(infcx, ConstrReason::Other, *re, path, PtrToRefBound::Infer)?;
165
166        self.bindings.lookup(place, infcx.span).update(ref_ty);
167
168        Ok(())
169    }
170
171    /// Convert a (strong) pointer to a mutable reference.
172    ///
173    /// This roughly implements the following inference rule:
174    /// ```text
175    ///                   t₁ <: t₂
176    /// -------------------------------------------------
177    /// Γ₁,ℓ:t1,Γ₂ ; ptr(mut, ℓ) => Γ₁,ℓ:†t₂,Γ₂ ; &mut t2
178    /// ```
179    /// That's it, we first get the current type `t₁` at location `ℓ` and check it is a subtype
180    /// of `t₂`. Then, we update the type of `ℓ` to `t₂` and block the place.
181    ///
182    /// The bound `t₂` can be either inferred ([`PtrToRefBound::Infer`]), explicitly provided
183    /// ([`PtrToRefBound::Ty`]), or made equal to `t₁` ([`PtrToRefBound::Identity`]).
184    ///
185    /// As an example, consider the environment `x: i32[a]` and the pointer `ptr(mut, x)`.
186    /// Converting the pointer to a mutable reference with an inferred bound produces the following
187    /// derivation (roughly):
188    ///
189    /// ```text
190    ///                    i32[a] <: i32{v: $k(v)}
191    /// ----------------------------------------------------------------
192    /// x: i32[a] ; ptr(mut, x) => x:†i32{v: $k(v)} ; &mut i32{v: $k(v)}
193    /// ```
194    pub(crate) fn ptr_to_ref(
195        &mut self,
196        infcx: &mut InferCtxtAt,
197        reason: ConstrReason,
198        re: Region,
199        path: &Path,
200        bound: PtrToRefBound,
201    ) -> InferResult<Ty> {
202        // ℓ: t1
203        let t1 = self.bindings.lookup(path, infcx.span).fold(infcx)?;
204
205        // t1 <: t2
206        let t2 = match bound {
207            PtrToRefBound::Ty(t2) => {
208                let t2 = rty_match_regions(&t2, &t1);
209                infcx.subtyping(&t1, &t2, reason)?;
210                t2
211            }
212            PtrToRefBound::Infer => {
213                let t2 = t1.with_holes().replace_holes(|sorts, kind| {
214                    debug_assert_eq!(kind, HoleKind::Pred);
215                    infcx.fresh_kvar(sorts, KVarEncoding::Conj)
216                });
217                infcx.subtyping(&t1, &t2, reason)?;
218                t2
219            }
220            PtrToRefBound::Identity => t1.clone(),
221        };
222
223        // ℓ: †t2
224        self.bindings
225            .lookup(path, infcx.span)
226            .block_with(t2.clone());
227
228        Ok(Ty::mk_ref(re, t2, Mutability::Mut))
229    }
230
231    pub(crate) fn fold_local_ptrs(&mut self, infcx: &mut InferCtxtAt) -> InferResult {
232        for (loc, bound, ty) in self.bindings.local_ptrs() {
233            infcx.subtyping(&ty, &bound, ConstrReason::FoldLocal)?;
234            self.bindings.remove_local(&loc);
235        }
236        Ok(())
237    }
238
239    /// Updates the type of `place` to `new_ty`. This may involve a *strong update* if we have
240    /// ownership of `place` or a *weak update* if it's behind a reference (which fires a subtyping
241    /// constraint)
242    ///
243    /// When strong updating, the process involves recovering the original regions (lifetimes) used
244    /// in the (unrefined) Rust type of `place` and then substituting these regions in `new_ty`. For
245    /// instance, if we are assigning a value of type `S<&'?10 i32{v: v > 0}>` to a variable `x`,
246    /// and the (unrefined) Rust type of `x` is `S<&'?5 i32>`, before the assignment, we identify a
247    /// substitution that maps the region `'?10` to `'?5`. After applying this substitution, the
248    /// type of the place `x` is updated accordingly. This ensures that the lifetimes in the
249    /// assigned type are consistent with those expected by the place's original type definition.
250    pub(crate) fn assign(
251        &mut self,
252        infcx: &mut InferCtxtAt,
253        place: &Place,
254        new_ty: Ty,
255    ) -> InferResult {
256        let rustc_ty = place.ty(infcx.genv, self.local_decls)?.ty;
257        let new_ty = ty_match_regions(&new_ty, &rustc_ty);
258        let span = infcx.span;
259        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
260        if result.is_strg {
261            result.update(new_ty);
262        } else if !place.behind_raw_ptr(infcx.genv, self.local_decls)? {
263            infcx.subtyping(&new_ty, &result.ty, ConstrReason::Assign)?;
264        }
265        Ok(())
266    }
267
268    pub(crate) fn move_place(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult<Ty> {
269        let span = infcx.span;
270        let result = self.bindings.lookup_unfolding(infcx, place, span)?;
271        if result.is_strg {
272            let uninit = Ty::uninit();
273            Ok(result.update(uninit))
274        } else {
275            // ignore the 'move' and trust rustc managed the move correctly
276            // https://github.com/flux-rs/flux/issues/725#issuecomment-2295065634
277            Ok(result.ty)
278        }
279    }
280
281    pub(crate) fn unpack(&mut self, infcx: &mut InferCtxt) {
282        self.bindings
283            .fmap_mut(|_loc, ty| infcx.hoister(true).hoist(ty));
284    }
285
286    pub(crate) fn unblock(&mut self, infcx: &mut InferCtxt, place: &Place) {
287        self.bindings.unblock(infcx, place);
288    }
289
290    pub(crate) fn check_goto(
291        self,
292        infcx: &mut InferCtxtAt,
293        bb_env: &BasicBlockEnv,
294        target: BasicBlock,
295    ) -> InferResult {
296        infcx.ensure_resolved_evars(|infcx| {
297            let bb_env = bb_env
298                .data
299                .replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
300
301            // Check constraints
302            for constr in &bb_env.constrs {
303                infcx.check_pred(constr, ConstrReason::Goto(target));
304            }
305
306            // Check subtyping
307            let bb_env = bb_env.bindings.flatten();
308            for (path, _, ty2) in bb_env {
309                let ty1 = self.bindings.get(&path);
310                infcx.subtyping(&ty1.unblocked(), &ty2.unblocked(), ConstrReason::Goto(target))?;
311            }
312            Ok(())
313        })
314    }
315
316    pub(crate) fn fold(&mut self, infcx: &mut InferCtxtAt, place: &Place) -> InferResult {
317        let span = infcx.span;
318        self.bindings.lookup(place, span).fold(infcx)?;
319        Ok(())
320    }
321
322    pub(crate) fn unfold_local_ptr(
323        &mut self,
324        infcx: &mut InferCtxt,
325        bound: &Ty,
326    ) -> InferResult<Loc> {
327        let name = infcx.define_unknown_var(&Sort::Loc);
328        let loc = Loc::from(name);
329        let ty = infcx.unpack(bound);
330        self.bindings
331            .insert(loc, LocKind::LocalPtr(bound.clone()), ty);
332        Ok(loc)
333    }
334
335    /// ```text
336    /// -----------------------------------
337    /// Γ ; &strg <ℓ: t> => Γ,ℓ: t ; ptr(ℓ)
338    /// ```
339    pub(crate) fn unfold_strg_ref(
340        &mut self,
341        infcx: &mut InferCtxt,
342        path: &Path,
343        ty: &Ty,
344    ) -> InferResult<Loc> {
345        if let Some(loc) = path.to_loc() {
346            let ty = infcx.unpack(ty);
347            self.bindings.insert(loc, LocKind::Universal, ty);
348            Ok(loc)
349        } else {
350            bug!("unfold_strg_ref: unexpected path {path:?}")
351        }
352    }
353
354    pub(crate) fn unfold(
355        &mut self,
356        infcx: &mut InferCtxt,
357        place: &Place,
358        span: Span,
359    ) -> InferResult {
360        self.bindings.unfold(infcx, place, span)
361    }
362
363    pub(crate) fn downcast(
364        &mut self,
365        infcx: &mut InferCtxtAt,
366        place: &Place,
367        variant_idx: VariantIdx,
368    ) -> InferResult {
369        let mut down_place = place.clone();
370        let span = infcx.span;
371        down_place
372            .projection
373            .push(PlaceElem::Downcast(None, variant_idx));
374        self.bindings.unfold(infcx, &down_place, span)?;
375        Ok(())
376    }
377
378    pub fn fully_resolve_evars(&mut self, infcx: &InferCtxt) {
379        self.bindings
380            .fmap_mut(|_loc, ty| infcx.fully_resolve_evars(ty));
381    }
382
383    pub(crate) fn assume_ensures(
384        &mut self,
385        infcx: &mut InferCtxt,
386        ensures: &[Ensures],
387        span: Span,
388    ) {
389        for ensure in ensures {
390            match ensure {
391                Ensures::Type(path, updated_ty) => {
392                    let updated_ty = infcx.unpack(updated_ty);
393                    infcx.assume_invariants(&updated_ty);
394                    self.update_path(path, updated_ty, span);
395                }
396                Ensures::Pred(e) => infcx.assume_pred(e),
397            }
398        }
399    }
400
401    pub(crate) fn check_ensures(
402        &mut self,
403        at: &mut InferCtxtAt,
404        ensures: &[Ensures],
405        reason: ConstrReason,
406    ) -> InferResult {
407        for constraint in ensures {
408            match constraint {
409                Ensures::Type(path, ty) => {
410                    let actual_ty = self.get(path).unblocked(); // HACK
411                    at.subtyping(&actual_ty, ty, reason)?;
412                }
413                Ensures::Pred(e) => {
414                    at.check_pred(e, ConstrReason::Ret);
415                }
416            }
417        }
418        Ok(())
419    }
420}
421
/// How the bound `t₂` is chosen when converting a strong pointer into a mutable reference with
/// [`TypeEnv::ptr_to_ref`].
pub(crate) enum PtrToRefBound {
    /// Use the given type (after matching its regions against the current type).
    Ty(Ty),
    /// Infer the bound: the current type with its refinements replaced by fresh kvars.
    Infer,
    /// Use the current type itself as the bound.
    Identity,
}
427
428impl flux_infer::infer::LocEnv for TypeEnv<'_> {
429    fn ptr_to_ref(
430        &mut self,
431        infcx: &mut InferCtxtAt,
432        reason: ConstrReason,
433        re: Region,
434        path: &Path,
435        bound: Ty,
436    ) -> InferResult<Ty> {
437        self.ptr_to_ref(infcx, reason, re, path, PtrToRefBound::Ty(bound))
438    }
439
440    fn get(&self, path: &Path) -> Ty {
441        self.get(path)
442    }
443
444    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc> {
445        self.unfold_strg_ref(infcx, path, ty)
446    }
447}
448
449impl BasicBlockEnvShape {
450    pub fn enter<'a>(&self, local_decls: &'a LocalDecls) -> TypeEnv<'a> {
451        TypeEnv { bindings: self.bindings.clone(), local_decls }
452    }
453
454    fn new(scope: Scope, env: TypeEnv) -> BasicBlockEnvShape {
455        let mut bindings = env.bindings;
456        bindings.fmap_mut(|_loc, ty| BasicBlockEnvShape::pack_ty(&scope, ty));
457        BasicBlockEnvShape { scope, bindings }
458    }
459
460    fn pack_ty(scope: &Scope, ty: &Ty) -> Ty {
461        match ty.kind() {
462            TyKind::Indexed(bty, idxs) => {
463                let bty = BasicBlockEnvShape::pack_bty(scope, bty);
464                if scope.has_free_vars(idxs) {
465                    Ty::exists_with_constr(bty, Expr::hole(HoleKind::Pred))
466                } else {
467                    Ty::indexed(bty, idxs.clone())
468                }
469            }
470            TyKind::Downcast(adt, args, ty, variant, fields) => {
471                debug_assert!(!scope.has_free_vars(args));
472                debug_assert!(!scope.has_free_vars(ty));
473                let fields = fields.iter().map(|ty| Self::pack_ty(scope, ty)).collect();
474                Ty::downcast(adt.clone(), args.clone(), ty.clone(), *variant, fields)
475            }
476            TyKind::Blocked(ty) => Ty::blocked(BasicBlockEnvShape::pack_ty(scope, ty)),
477            // FIXME(nilehmann) [`TyKind::Exists`] could also contain free variables.
478            TyKind::Exists(_)
479            | TyKind::Discr(..)
480            | TyKind::Ptr(..)
481            | TyKind::Uninit
482            | TyKind::Param(_)
483            | TyKind::Constr(_, _) => ty.clone(),
484            TyKind::Infer(_) => bug!("unexpected hole whecn checking function body"),
485            TyKind::StrgRef(..) => bug!("unexpected strong reference when checking function body"),
486        }
487    }
488
489    fn pack_bty(scope: &Scope, bty: &BaseTy) -> BaseTy {
490        match bty {
491            BaseTy::Adt(adt_def, args) => {
492                let args = List::from_vec(
493                    args.iter()
494                        .map(|arg| Self::pack_generic_arg(scope, arg))
495                        .collect(),
496                );
497                BaseTy::adt(adt_def.clone(), args)
498            }
499            BaseTy::FnDef(def_id, args) => {
500                let args = List::from_vec(
501                    args.iter()
502                        .map(|arg| Self::pack_generic_arg(scope, arg))
503                        .collect(),
504                );
505                BaseTy::fn_def(*def_id, args)
506            }
507            BaseTy::Tuple(tys) => {
508                let tys = tys
509                    .iter()
510                    .map(|ty| BasicBlockEnvShape::pack_ty(scope, ty))
511                    .collect();
512                BaseTy::Tuple(tys)
513            }
514            BaseTy::Slice(ty) => BaseTy::Slice(Self::pack_ty(scope, ty)),
515            BaseTy::Ref(r, ty, mutbl) => BaseTy::Ref(*r, Self::pack_ty(scope, ty), *mutbl),
516            BaseTy::Array(ty, c) => BaseTy::Array(Self::pack_ty(scope, ty), c.clone()),
517            BaseTy::Int(_)
518            | BaseTy::Param(_)
519            | BaseTy::Uint(_)
520            | BaseTy::Bool
521            | BaseTy::Float(_)
522            | BaseTy::Str
523            | BaseTy::RawPtr(_, _)
524            | BaseTy::RawPtrMetadata(_)
525            | BaseTy::Char
526            | BaseTy::Never
527            | BaseTy::Closure(..)
528            | BaseTy::Dynamic(..)
529            | BaseTy::Alias(..)
530            | BaseTy::FnPtr(..)
531            | BaseTy::Foreign(..)
532            | BaseTy::Coroutine(..) => {
533                if scope.has_free_vars(bty) {
534                    tracked_span_bug!("unexpected type with free vars")
535                } else {
536                    bty.clone()
537                }
538            }
539            BaseTy::Pat => {
540                todo!()
541            }
542            BaseTy::Infer(..) => {
543                tracked_span_bug!("unexpected infer type")
544            }
545        }
546    }
547
548    fn pack_generic_arg(scope: &Scope, arg: &GenericArg) -> GenericArg {
549        match arg {
550            GenericArg::Ty(ty) => GenericArg::Ty(Self::pack_ty(scope, ty)),
551            GenericArg::Base(arg) => {
552                assert!(!scope.has_free_vars(arg));
553                GenericArg::Base(arg.clone())
554            }
555            GenericArg::Lifetime(re) => GenericArg::Lifetime(*re),
556            GenericArg::Const(c) => GenericArg::Const(c.clone()),
557        }
558    }
559
560    fn update(&mut self, path: &Path, ty: Ty, span: Span) {
561        self.bindings.lookup(path, span).update(ty);
562    }
563
564    /// join(self, genv, other) consumes the bindings in other, to "update"
565    /// `self` in place, and returns `true` if there was an actual change
566    /// or `false` indicating no change (i.e., a fixpoint was reached).
567    pub(crate) fn join(&mut self, other: TypeEnv, span: Span) -> bool {
568        let paths = self.bindings.paths();
569
570        // Join types
571        let mut modified = false;
572        for path in &paths {
573            let ty1 = self.bindings.get(path);
574            let ty2 = other.bindings.get(path);
575            let ty = if ty1 == ty2 { ty1.clone() } else { self.join_ty(&ty1, &ty2) };
576            modified |= ty1 != ty;
577            self.update(path, ty, span);
578        }
579
580        modified
581    }
582
583    fn join_ty(&self, ty1: &Ty, ty2: &Ty) -> Ty {
584        match (ty1.kind(), ty2.kind()) {
585            (TyKind::Blocked(ty1), _) => Ty::blocked(self.join_ty(ty1, &ty2.unblocked())),
586            (_, TyKind::Blocked(ty2)) => Ty::blocked(self.join_ty(&ty1.unblocked(), ty2)),
587            (TyKind::Uninit, _) | (_, TyKind::Uninit) => Ty::uninit(),
588            (TyKind::Exists(ty1), _) => self.join_ty(ty1.as_ref().skip_binder(), ty2),
589            (_, TyKind::Exists(ty2)) => self.join_ty(ty1, ty2.as_ref().skip_binder()),
590            (TyKind::Constr(_, ty1), _) => self.join_ty(ty1, ty2),
591            (_, TyKind::Constr(_, ty2)) => self.join_ty(ty1, ty2),
592            (TyKind::Indexed(bty1, idx1), TyKind::Indexed(bty2, idx2)) => {
593                let bty = self.join_bty(bty1, bty2);
594                let mut sorts = vec![];
595                let idx = self.join_idx(idx1, idx2, &bty.sort(), &mut sorts);
596                if sorts.is_empty() {
597                    Ty::indexed(bty, idx)
598                } else {
599                    let ty = Ty::constr(Expr::hole(HoleKind::Pred), Ty::indexed(bty, idx));
600                    Ty::exists(Binder::bind_with_sorts(ty, &sorts))
601                }
602            }
603            (TyKind::Ptr(rk1, path1), TyKind::Ptr(rk2, path2)) => {
604                debug_assert_eq!(rk1, rk2);
605                debug_assert_eq!(path1, path2);
606                Ty::ptr(*rk1, path1.clone())
607            }
608            (TyKind::Param(param_ty1), TyKind::Param(param_ty2)) => {
609                debug_assert_eq!(param_ty1, param_ty2);
610                Ty::param(*param_ty1)
611            }
612            (
613                TyKind::Downcast(adt1, args1, ty1, variant1, fields1),
614                TyKind::Downcast(adt2, args2, ty2, variant2, fields2),
615            ) => {
616                debug_assert_eq!(adt1, adt2);
617                debug_assert_eq!(args1, args2);
618                debug_assert!(ty1 == ty2 && !self.scope.has_free_vars(ty2));
619                debug_assert_eq!(variant1, variant2);
620                debug_assert_eq!(fields1.len(), fields2.len());
621                let fields = iter::zip(fields1, fields2)
622                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
623                    .collect();
624                Ty::downcast(adt1.clone(), args1.clone(), ty1.clone(), *variant1, fields)
625            }
626            _ => tracked_span_bug!("unexpected types: `{ty1:?}` - `{ty2:?}`"),
627        }
628    }
629
630    fn join_idx(&self, e1: &Expr, e2: &Expr, sort: &Sort, bound_sorts: &mut Vec<Sort>) -> Expr {
631        match (e1.kind(), e2.kind(), sort) {
632            (ExprKind::Tuple(es1), ExprKind::Tuple(es2), Sort::Tuple(sorts)) => {
633                debug_assert_eq3!(es1.len(), es2.len(), sorts.len());
634                Expr::tuple(
635                    izip!(es1, es2, sorts)
636                        .map(|(e1, e2, sort)| self.join_idx(e1, e2, sort, bound_sorts))
637                        .collect(),
638                )
639            }
640            (
641                ExprKind::Ctor(Ctor::Struct(_), flds1),
642                ExprKind::Ctor(Ctor::Struct(_), flds2),
643                Sort::App(SortCtor::Adt(sort_def), args),
644            ) => {
645                let sorts = sort_def.struct_variant().field_sorts(args);
646                debug_assert_eq3!(flds1.len(), flds2.len(), sorts.len());
647
648                Expr::ctor_struct(
649                    sort_def.did(),
650                    izip!(flds1, flds2, &sorts)
651                        .map(|(f1, f2, sort)| self.join_idx(f1, f2, sort, bound_sorts))
652                        .collect(),
653                )
654            }
655            _ => {
656                let has_free_vars2 = self.scope.has_free_vars(e2);
657                let has_escaping_vars1 = e1.has_escaping_bvars();
658                let has_escaping_vars2 = e2.has_escaping_bvars();
659                if !has_free_vars2 && !has_escaping_vars1 && !has_escaping_vars2 && e1 == e2 {
660                    e1.clone()
661                } else if sort.is_pred() {
662                    // FIXME(nilehmann) we shouldn't special case predicates here. Instead, we
663                    // should differentiate between generics and indices.
664                    let fsort = sort.expect_func().expect_mono();
665                    Expr::abs(Lambda::bind_with_fsort(Expr::hole(HoleKind::Pred), fsort))
666                } else {
667                    bound_sorts.push(sort.clone());
668                    Expr::bvar(
669                        INNERMOST,
670                        BoundVar::from_usize(bound_sorts.len() - 1),
671                        BoundReftKind::Anon,
672                    )
673                }
674            }
675        }
676    }
677
678    fn join_bty(&self, bty1: &BaseTy, bty2: &BaseTy) -> BaseTy {
679        match (bty1, bty2) {
680            (BaseTy::Adt(def1, args1), BaseTy::Adt(def2, args2)) => {
681                tracked_span_dbg_assert_eq!(def1.did(), def2.did());
682                let args = iter::zip(args1, args2)
683                    .map(|(arg1, arg2)| self.join_generic_arg(arg1, arg2))
684                    .collect();
685                BaseTy::adt(def1.clone(), List::from_vec(args))
686            }
687            (BaseTy::Tuple(fields1), BaseTy::Tuple(fields2)) => {
688                let fields = iter::zip(fields1, fields2)
689                    .map(|(ty1, ty2)| self.join_ty(ty1, ty2))
690                    .collect();
691                BaseTy::Tuple(fields)
692            }
693            (BaseTy::Alias(kind1, alias_ty1), BaseTy::Alias(kind2, alias_ty2)) => {
694                tracked_span_dbg_assert_eq!(kind1, kind2);
695                tracked_span_dbg_assert_eq!(alias_ty1, alias_ty2);
696                BaseTy::Alias(*kind1, alias_ty1.clone())
697            }
698            (BaseTy::Ref(r1, ty1, mutbl1), BaseTy::Ref(r2, ty2, mutbl2)) => {
699                tracked_span_dbg_assert_eq!(r1, r2);
700                tracked_span_dbg_assert_eq!(mutbl1, mutbl2);
701                BaseTy::Ref(*r1, self.join_ty(ty1, ty2), *mutbl1)
702            }
703            (BaseTy::Array(ty1, len1), BaseTy::Array(ty2, len2)) => {
704                tracked_span_dbg_assert_eq!(len1, len2);
705                BaseTy::Array(self.join_ty(ty1, ty2), len1.clone())
706            }
707            (BaseTy::Slice(ty1), BaseTy::Slice(ty2)) => BaseTy::Slice(self.join_ty(ty1, ty2)),
708            _ => {
709                tracked_span_dbg_assert_eq!(bty1, bty2);
710                bty1.clone()
711            }
712        }
713    }
714
715    fn join_generic_arg(&self, arg1: &GenericArg, arg2: &GenericArg) -> GenericArg {
716        match (arg1, arg2) {
717            (GenericArg::Ty(ty1), GenericArg::Ty(ty2)) => GenericArg::Ty(self.join_ty(ty1, ty2)),
718            (GenericArg::Base(ctor1), GenericArg::Base(ctor2)) => {
719                let sty1 = ctor1.as_ref().skip_binder();
720                let sty2 = ctor2.as_ref().skip_binder();
721                debug_assert!(sty1.idx.is_nu());
722                debug_assert!(sty2.idx.is_nu());
723
724                let bty = self.join_bty(&sty1.bty, &sty2.bty);
725                let pred = if self.scope.has_free_vars(&sty2.pred) || sty1.pred != sty2.pred {
726                    Expr::hole(HoleKind::Pred)
727                } else {
728                    sty1.pred.clone()
729                };
730                let sort = bty.sort();
731                let ctor = Binder::bind_with_sort(SubsetTy::new(bty, Expr::nu(), pred), sort);
732                GenericArg::Base(ctor)
733            }
734            (GenericArg::Lifetime(re1), GenericArg::Lifetime(_re2)) => {
735                // TODO(nilehmann) loop_abstract_refinement.rs is triggering this assertion to fail
736                // wee should fix it.
737                // debug_assert_eq!(re1, _re2);
738                GenericArg::Lifetime(*re1)
739            }
740            (GenericArg::Const(c1), GenericArg::Const(c2)) => {
741                debug_assert_eq!(c1, c2);
742                GenericArg::Const(c1.clone())
743            }
744            _ => tracked_span_bug!("unexpected generic args: `{arg1:?}` - `{arg2:?}`"),
745        }
746    }
747
748    pub fn into_bb_env(self, infcx: &mut InferCtxtRoot, body: &Body) -> BasicBlockEnv {
749        let mut delegate = LocalHoister::default();
750        let mut hoister = Hoister::with_delegate(&mut delegate).transparent();
751
752        let mut bindings = self.bindings;
753        bindings.fmap_mut(|loc, ty| {
754            let name = if let Loc::Local(local) = loc {
755                body.local_names.get(local).copied()
756            } else {
757                None
758            };
759            hoister.delegate.name = name;
760            hoister.hoist(ty)
761        });
762
763        BasicBlockEnv {
764            // We are relying on all the types in `bindings` not having escaping bvars, otherwise
765            // we would have to shift them in since we are creating a new binder.
766            data: delegate.bind(|vars, preds| {
767                // Replace all holes with a single fresh kvar on all parameters
768                let mut constrs = preds
769                    .into_iter()
770                    .filter(|pred| !matches!(pred.kind(), ExprKind::Hole(HoleKind::Pred)))
771                    .collect_vec();
772                let kvar = infcx.fresh_kvar_in_scope(
773                    std::slice::from_ref(&vars),
774                    &self.scope,
775                    KVarEncoding::Conj,
776                );
777                constrs.push(kvar);
778
779                // Replace remaining holes by fresh kvars
780                let mut kvar_gen = |binders: &[_], kind| {
781                    debug_assert_eq!(kind, HoleKind::Pred);
782                    let binders = std::iter::once(vars.clone())
783                        .chain(binders.iter().cloned())
784                        .collect_vec();
785                    infcx.fresh_kvar_in_scope(&binders, &self.scope, KVarEncoding::Conj)
786                };
787                bindings.fmap_mut(|_, binding| binding.replace_holes(&mut kvar_gen));
788
789                BasicBlockEnvData { constrs: constrs.into(), bindings }
790            }),
791            scope: self.scope,
792        }
793    }
794}
795
796impl TypeVisitable for BasicBlockEnvData {
797    fn visit_with<V: TypeVisitor>(&self, _visitor: &mut V) -> ControlFlow<V::BreakTy> {
798        unimplemented!()
799    }
800}
801
802impl TypeFoldable for BasicBlockEnvData {
803    fn try_fold_with<F: FallibleTypeFolder>(
804        &self,
805        folder: &mut F,
806    ) -> std::result::Result<Self, F::Error> {
807        Ok(BasicBlockEnvData {
808            constrs: self.constrs.try_fold_with(folder)?,
809            bindings: self.bindings.try_fold_with(folder)?,
810        })
811    }
812}
813
814impl BasicBlockEnv {
815    pub(crate) fn enter<'a>(
816        &self,
817        infcx: &mut InferCtxt,
818        local_decls: &'a LocalDecls,
819    ) -> TypeEnv<'a> {
820        let data = self.data.replace_bound_refts_with(|sort, _, kind| {
821            Expr::fvar(infcx.define_bound_reft_var(sort, kind))
822        });
823        for constr in &data.constrs {
824            infcx.assume_pred(constr);
825        }
826        TypeEnv { bindings: data.bindings, local_decls }
827    }
828
829    pub(crate) fn scope(&self) -> &Scope {
830        &self.scope
831    }
832}
833
834mod pretty {
835    use std::fmt;
836
837    use flux_middle::pretty::*;
838
839    use super::*;
840
841    impl Pretty for TypeEnv<'_> {
842        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
843            w!(cx, f, "{:?}", &self.bindings)
844        }
845
846        fn default_cx(tcx: TyCtxt) -> PrettyCx {
847            PlacesTree::default_cx(tcx)
848        }
849    }
850
851    impl Pretty for BasicBlockEnvShape {
852        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
853            w!(cx, f, "{:?} {:?}", &self.scope, &self.bindings)
854        }
855
856        fn default_cx(tcx: TyCtxt) -> PrettyCx {
857            PlacesTree::default_cx(tcx)
858        }
859    }
860
861    impl Pretty for BasicBlockEnv {
862        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
863            w!(cx, f, "{:?} ", &self.scope)?;
864
865            let vars = self.data.vars();
866            cx.with_bound_vars(vars, || {
867                if !vars.is_empty() {
868                    cx.fmt_bound_vars(true, "for<", vars, "> ", f)?;
869                }
870                let data = self.data.as_ref().skip_binder();
871                if !data.constrs.is_empty() {
872                    w!(
873                        cx,
874                        f,
875                        "{:?} ⇒ ",
876                        join!(", ", data.constrs.iter().filter(|pred| !pred.is_trivially_true()))
877                    )?;
878                }
879                w!(cx, f, "{:?}", &data.bindings)
880            })
881        }
882
883        fn default_cx(tcx: TyCtxt) -> PrettyCx {
884            PlacesTree::default_cx(tcx)
885        }
886    }
887
888    impl_debug_with_default_cx! {
889        TypeEnv<'_> => "type_env",
890        BasicBlockEnvShape => "basic_block_env_shape",
891        BasicBlockEnv => "basic_block_env"
892    }
893}
894
895/// A very explicit representation of [`TypeEnv`] for debugging/tracing/serialization ONLY.
896#[derive(Serialize, DebugAsJson)]
897pub struct TypeEnvTrace(Vec<TypeEnvBind>);
898
899#[derive(Serialize)]
900struct TypeEnvBind {
901    local: LocInfo,
902    name: Option<String>,
903    kind: String,
904    ty: String,
905    span: Option<SpanTrace>,
906}
907
908#[derive(Serialize)]
909enum LocInfo {
910    Local(String),
911    Var(String),
912}
913
914fn loc_info(loc: &Loc) -> LocInfo {
915    match loc {
916        Loc::Local(local) => LocInfo::Local(format!("{local:?}")),
917        Loc::Var(var) => LocInfo::Var(format!("{var:?}")),
918    }
919}
920
921fn loc_name(local_names: &UnordMap<Local, Symbol>, loc: &Loc) -> Option<String> {
922    if let Loc::Local(local) = loc {
923        let name = local_names.get(local)?;
924        return Some(format!("{name}"));
925    }
926    None
927}
928
929fn loc_span(
930    genv: GlobalEnv,
931    local_decls: &IndexVec<Local, LocalDecl>,
932    loc: &Loc,
933) -> Option<SpanTrace> {
934    if let Loc::Local(local) = loc {
935        return local_decls
936            .get(*local)
937            .map(|local_decl| SpanTrace::new(genv.tcx(), local_decl.source_info.span));
938    }
939    None
940}
941
942impl TypeEnvTrace {
943    pub fn new(
944        genv: GlobalEnv,
945        local_names: &UnordMap<Local, Symbol>,
946        local_decls: &IndexVec<Local, LocalDecl>,
947        cx: PrettyCx,
948        env: &TypeEnv,
949    ) -> Self {
950        let mut bindings = vec![];
951        env.bindings
952            .iter()
953            .filter(|(_, binding)| !binding.ty.is_uninit())
954            .sorted_by(|(loc1, _), (loc2, _)| loc1.cmp(loc2))
955            .for_each(|(loc, binding)| {
956                let name = loc_name(local_names, loc);
957                let local = loc_info(loc);
958                let kind = format!("{:?}", binding.kind);
959                let ty = binding.ty.nested_string(&cx);
960                let span = loc_span(genv, local_decls, loc);
961                bindings.push(TypeEnvBind { name, local, kind, ty, span });
962            });
963
964        TypeEnvTrace(bindings)
965    }
966}