flux_infer/
infer.rs

1use std::{cell::RefCell, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7    FixpointQueryKind,
8    def_id::MaybeExternId,
9    global_env::GlobalEnv,
10    queries::{QueryErr, QueryResult},
11    query_bug,
12    rty::{
13        self, AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate,
14        Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj, GenericArg, HoleKind, InferMode,
15        Lambda, List, Loc, Mutability, Name, Path, PolyVariant, PtrKind, RefineArgs, RefineArgsExt,
16        Region, Sort, Ty, TyCtor, TyKind, Var,
17        canonicalize::{Hoister, HoisterDelegate},
18        fold::TypeFoldable,
19    },
20};
21use itertools::{Itertools, izip};
22use rustc_hir::def_id::{DefId, LocalDefId};
23use rustc_macros::extension;
24use rustc_middle::{
25    mir::BasicBlock,
26    ty::{TyCtxt, Variance},
27};
28use rustc_span::Span;
29
30use crate::{
31    evars::{EVarState, EVarStore},
32    fixpoint_encoding::{FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen},
33    projections::NormalizeExt as _,
34    refine_tree::{Cursor, Marker, RefineTree, Scope},
35};
36
/// Result type used by the inference API in this module. The success value defaults to `()`
/// and the error is always [`InferErr`].
pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
38
/// Label passed along with every predicate that is checked (see [`InferCtxt::check_pred`]),
/// recording why the obligation was generated and which spans it relates to.
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct Tag {
    /// Why the obligation was generated.
    pub reason: ConstrReason,
    /// Span supplied when the tag was created with [`Tag::new`].
    pub src_span: Span,
    /// Optional second span; `None` unless set through [`Tag::with_dst`].
    pub dst_span: Option<ESpan>,
}
45
46impl Tag {
47    pub fn new(reason: ConstrReason, span: Span) -> Self {
48        Self { reason, src_span: span, dst_span: None }
49    }
50
51    pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
52        Self { dst_span, ..self }
53    }
54}
55
/// Payload of [`ConstrReason::Subtype`], distinguishing the kinds of subtyping obligations.
// NOTE(review): the variant names suggest the part of a signature being related (inputs,
// output, `requires`/`ensures` clauses) — confirm against the code that constructs them.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum SubtypeReason {
    Input,
    Output,
    Requires,
    Ensures,
}
63
/// Reason stored in a [`Tag`] describing why a constraint was generated.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum ConstrReason {
    Call,
    Assign,
    Ret,
    Fold,
    FoldLocal,
    Predicate,
    /// Carries a static string.
    // NOTE(review): presumably the assertion message — confirm at the construction site.
    Assert(&'static str),
    Div,
    Rem,
    /// Carries a MIR basic block.
    // NOTE(review): presumably the jump target — confirm at the construction site.
    Goto(BasicBlock),
    Overflow,
    /// A subtyping obligation, further classified by [`SubtypeReason`].
    Subtype(SubtypeReason),
    Other,
}
80
/// Root of an inference session. It owns the [`RefineTree`] under construction, the kvar/evar
/// state shared (through a [`RefCell`]) with every [`InferCtxt`] created from it, and the
/// options consulted when the resulting constraint is checked.
pub struct InferCtxtRoot<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Kvar generator and evar store; lent to the contexts returned by [`InferCtxtRoot::infcx`].
    inner: RefCell<InferCtxtInner>,
    /// Tree that cursors add assumptions and obligations to.
    refine_tree: RefineTree,
    /// Options; this file reads `check_overflow`, `solver` and `scrape_quals`.
    opts: InferOpts,
}
87
/// Builder for [`InferCtxtRoot`], obtained from [`GlobalEnvExt::infcx_root`].
pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    opts: InferOpts,
    /// Variables, with their sorts, handed to [`RefineTree::new`] when the root is built.
    params: Vec<(Var, Sort)>,
    /// rustc inference context, used when normalizing the sorts of refinement parameters.
    infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
    /// Flag forwarded to the kvar generator when the root is built.
    dummy_kvars: bool,
}
95
96#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
97impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
98    fn infcx_root<'a>(
99        self,
100        infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
101        opts: InferOpts,
102    ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
103        InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
104    }
105}
106
107impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
108    pub fn with_dummy_kvars(mut self) -> Self {
109        self.dummy_kvars = true;
110        self
111    }
112
113    pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
114        self.params.extend(
115            self.genv
116                .generics_of(def_id)?
117                .const_params(self.genv)?
118                .into_iter()
119                .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
120        );
121        Ok(self)
122    }
123
124    pub fn with_refinement_generics(
125        mut self,
126        def_id: DefId,
127        args: &[GenericArg],
128    ) -> QueryResult<Self> {
129        for (index, param) in self
130            .genv
131            .refinement_generics_of(def_id)?
132            .iter_own_params()
133            .enumerate()
134        {
135            let param = param.instantiate(self.genv.tcx(), args, &[]);
136            let sort = param.sort.normalize_sorts(def_id, self.genv, self.infcx)?;
137
138            let var =
139                Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
140            self.params.push((var, sort));
141        }
142        Ok(self)
143    }
144
145    pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
146        self = self.with_const_generics(def_id)?;
147        let offset = self.params.len();
148        self.genv.refinement_generics_of(def_id)?.fill_item(
149            self.genv,
150            &mut self.params,
151            &mut |param, index| {
152                let index = (index - offset) as u32;
153                let param = param.instantiate_identity();
154                let sort = param.sort.normalize_sorts(def_id, self.genv, self.infcx)?;
155
156                let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
157                Ok((var, sort))
158            },
159        )?;
160        Ok(self)
161    }
162
163    pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
164        Ok(InferCtxtRoot {
165            genv: self.genv,
166            inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
167            refine_tree: RefineTree::new(self.params),
168            opts: self.opts,
169        })
170    }
171}
172
173impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
174    pub fn infcx<'a>(
175        &'a mut self,
176        def_id: DefId,
177        region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
178    ) -> InferCtxt<'a, 'genv, 'tcx> {
179        InferCtxt {
180            genv: self.genv,
181            region_infcx,
182            def_id,
183            cursor: self.refine_tree.cursor_at_root(),
184            inner: &self.inner,
185            check_overflow: self.opts.check_overflow,
186        }
187    }
188
189    pub fn fresh_kvar_in_scope(
190        &self,
191        binders: &[BoundVariableKinds],
192        scope: &Scope,
193        encoding: KVarEncoding,
194    ) -> Expr {
195        let inner = &mut *self.inner.borrow_mut();
196        inner.kvars.fresh(binders, scope.iter(), encoding)
197    }
198
199    pub fn execute_fixpoint_query(
200        self,
201        cache: &mut FixQueryCache,
202        def_id: MaybeExternId,
203        kind: FixpointQueryKind,
204    ) -> QueryResult<Vec<Tag>> {
205        let inner = self.inner.into_inner();
206        let kvars = inner.kvars;
207        let evars = inner.evars;
208
209        let ext = kind.ext();
210
211        let mut refine_tree = self.refine_tree;
212
213        refine_tree.replace_evars(&evars).unwrap();
214
215        if config::dump_constraint() {
216            dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
217        }
218        refine_tree.simplify(self.genv);
219        if config::dump_constraint() {
220            let simp_ext = format!("simp.{ext}");
221            dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
222                .unwrap();
223        }
224
225        let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars);
226        let cstr = refine_tree.into_fixpoint(&mut fcx)?;
227
228        let backend = match self.opts.solver {
229            flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
230            flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
231        };
232
233        fcx.check(cache, cstr, kind, self.opts.scrape_quals, backend)
234    }
235
236    pub fn split(self) -> (RefineTree, KVarGen) {
237        (self.refine_tree, self.inner.into_inner().kvars)
238    }
239}
240
/// Handle used to emit constraints at a particular position of the refinement tree. It pairs a
/// [`Cursor`] with the kvar/evar state shared with the [`InferCtxtRoot`] it was created from.
pub struct InferCtxt<'infcx, 'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    /// Item this context was created for (set by `InferCtxtRoot::infcx` and `change_item`).
    pub def_id: DefId,
    /// Copied from the inference options; forwarded to the cursor when assuming invariants.
    pub check_overflow: bool,
    /// Current position in the refinement tree.
    cursor: Cursor<'infcx>,
    /// Kvar generator and evar store shared with the root and with sibling contexts.
    inner: &'infcx RefCell<InferCtxtInner>,
}
249
/// Mutable inference state shared, behind a [`RefCell`], by an [`InferCtxtRoot`] and all the
/// [`InferCtxt`]s derived from it.
struct InferCtxtInner {
    /// Generator of fresh kvars.
    kvars: KVarGen,
    /// Store of evars together with their scopes and solutions.
    evars: EVarStore,
}
254
255impl InferCtxtInner {
256    fn new(dummy_kvars: bool) -> Self {
257        Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
258    }
259}
260
261impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
262    pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
263        InferCtxtAt { infcx: self, span }
264    }
265
266    pub fn instantiate_refine_args(
267        &mut self,
268        callee_def_id: DefId,
269        args: &[rty::GenericArg],
270    ) -> InferResult<List<Expr>> {
271        Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
272            let param = param.instantiate(self.genv.tcx(), args, &[]);
273            Ok(self.fresh_infer_var(&param.sort, param.mode))
274        })?)
275    }
276
277    pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
278        args.iter()
279            .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
280            .collect_vec()
281    }
282
283    pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
284        match mode {
285            InferMode::KVar => {
286                let fsort = sort.expect_func().expect_mono();
287                let vars = fsort.inputs().iter().cloned().map_into().collect();
288                let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
289                Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
290            }
291            InferMode::EVar => self.fresh_evar(),
292        }
293    }
294
295    pub fn fresh_infer_var_for_hole(
296        &mut self,
297        binders: &[BoundVariableKinds],
298        kind: HoleKind,
299    ) -> Expr {
300        match kind {
301            HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
302            HoleKind::Expr(_) => {
303                // We only use expression holes to infer early param arguments for opaque types
304                // at function calls. These should be well-scoped in the current scope, so we ignore
305                // the extra `binders` around the hole.
306                self.fresh_evar()
307            }
308        }
309    }
310
311    /// Generate a fresh kvar in the _given_ [`Scope`] (similar method in [`InferCtxtRoot`]).
312    pub fn fresh_kvar_in_scope(
313        &self,
314        binders: &[BoundVariableKinds],
315        scope: &Scope,
316        encoding: KVarEncoding,
317    ) -> Expr {
318        let inner = &mut *self.inner.borrow_mut();
319        inner.kvars.fresh(binders, scope.iter(), encoding)
320    }
321
322    /// Generate a fresh kvar in the current scope. See [`KVarGen::fresh`].
323    pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
324        let inner = &mut *self.inner.borrow_mut();
325        inner.kvars.fresh(binders, self.cursor.vars(), encoding)
326    }
327
328    fn fresh_evar(&self) -> Expr {
329        let evars = &mut self.inner.borrow_mut().evars;
330        Expr::evar(evars.fresh(self.cursor.marker()))
331    }
332
333    pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
334        if a.has_evars() {
335            return;
336        }
337        let evars = &mut self.inner.borrow_mut().evars;
338        if let ExprKind::Var(Var::EVar(evid)) = b.kind()
339            && let EVarState::Unsolved(marker) = evars.get(*evid)
340            && !marker.has_free_vars(a)
341        {
342            evars.solve(*evid, a.clone());
343        }
344    }
345
346    fn enter_exists<T, U>(
347        &mut self,
348        t: &Binder<T>,
349        f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
350    ) -> U
351    where
352        T: TypeFoldable,
353    {
354        self.ensure_resolved_evars(|infcx| {
355            let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
356            Ok(f(infcx, t))
357        })
358        .unwrap()
359    }
360
361    /// Used in conjunction with [`InferCtxt::pop_evar_scope`] to ensure evars are solved at the end
362    /// of some scope, for example, to ensure all evars generated during a function call are solved
363    /// after checking argument subtyping. These functions can be used in a stack-like fashion to
364    /// create nested scopes.
365    pub fn push_evar_scope(&mut self) {
366        self.inner.borrow_mut().evars.push_scope();
367    }
368
369    /// Pop a scope and check all evars have been solved. This only check evars generated from the
370    /// last call to [`InferCtxt::push_evar_scope`].
371    pub fn pop_evar_scope(&mut self) -> InferResult {
372        self.inner
373            .borrow_mut()
374            .evars
375            .pop_scope()
376            .map_err(InferErr::UnsolvedEvar)
377    }
378
379    /// Convenience method pairing [`InferCtxt::push_evar_scope`] and [`InferCtxt::pop_evar_scope`].
380    pub fn ensure_resolved_evars<R>(
381        &mut self,
382        f: impl FnOnce(&mut Self) -> InferResult<R>,
383    ) -> InferResult<R> {
384        self.push_evar_scope();
385        let r = f(self)?;
386        self.pop_evar_scope()?;
387        Ok(r)
388    }
389
390    pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
391        self.inner.borrow().evars.replace_evars(t).unwrap()
392    }
393
394    pub fn tcx(&self) -> TyCtxt<'tcx> {
395        self.genv.tcx()
396    }
397
398    pub fn cursor(&self) -> &Cursor<'infcx> {
399        &self.cursor
400    }
401}
402
403/// Methods that interact with the underlying [`Cursor`]
404impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
405    pub fn change_item<'a>(
406        &'a mut self,
407        def_id: LocalDefId,
408        region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
409    ) -> InferCtxt<'a, 'genv, 'tcx> {
410        InferCtxt {
411            def_id: def_id.to_def_id(),
412            cursor: self.cursor.branch(),
413            region_infcx,
414            ..*self
415        }
416    }
417
418    pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
419        InferCtxt {
420            cursor: self
421                .cursor
422                .move_to(marker, clear_children)
423                .unwrap_or_else(|| tracked_span_bug!()),
424            ..*self
425        }
426    }
427
428    pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
429        InferCtxt { cursor: self.cursor.branch(), ..*self }
430    }
431
432    pub fn define_var(&mut self, sort: &Sort) -> Name {
433        self.cursor.define_var(sort)
434    }
435
436    pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
437        self.cursor.check_pred(pred, tag);
438    }
439
440    pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
441        self.cursor.assume_pred(pred);
442    }
443
444    pub fn unpack(&mut self, ty: &Ty) -> Ty {
445        self.hoister(false).hoist(ty)
446    }
447
448    pub fn marker(&self) -> Marker {
449        self.cursor.marker()
450    }
451
452    pub fn hoister(
453        &mut self,
454        assume_invariants: bool,
455    ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
456        Hoister::with_delegate(Unpacker { infcx: self, assume_invariants }).transparent()
457    }
458
459    pub fn assume_invariants(&mut self, ty: &Ty) {
460        self.cursor
461            .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
462    }
463
464    fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
465        self.cursor.check_impl(pred1, pred2, tag);
466    }
467}
468
/// [`HoisterDelegate`] created by [`InferCtxt::hoister`]. Existentials are opened with fresh
/// variables defined through `infcx`, and hoisted constraints are assumed in `infcx`.
pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    /// Whether the invariants of each opened type are assumed as well.
    assume_invariants: bool,
}
473
474impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
475    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
476        let ty =
477            ty_ctor.replace_bound_refts_with(|sort, _, _| Expr::fvar(self.infcx.define_var(sort)));
478        if self.assume_invariants {
479            self.infcx.assume_invariants(&ty);
480        }
481        ty
482    }
483
484    fn hoist_constr(&mut self, pred: Expr) {
485        self.infcx.assume_pred(pred);
486    }
487}
488
489impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
490    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491        std::fmt::Debug::fmt(&self.cursor, f)
492    }
493}
494
/// An [`InferCtxt`] paired with the [`Span`] used to tag the constraints generated through it.
/// It dereferences to the underlying [`InferCtxt`].
#[derive(Debug)]
pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
    pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    /// Span used as the source span of the [`Tag`]s created by this context.
    pub span: Span,
}
500
501impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
502    fn tag(&self, reason: ConstrReason) -> Tag {
503        Tag::new(reason, self.span)
504    }
505
506    pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
507        let tag = self.tag(reason);
508        self.infcx.check_pred(pred, tag);
509    }
510
511    pub fn check_non_closure_clauses(
512        &mut self,
513        clauses: &[rty::Clause],
514        reason: ConstrReason,
515    ) -> InferResult {
516        for clause in clauses {
517            if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
518                let impl_elem = BaseTy::projection(projection_pred.projection_ty)
519                    .to_ty()
520                    .normalize_projections(self)?;
521                let term = projection_pred.term.to_ty().normalize_projections(self)?;
522
523                // TODO: does this really need to be invariant? https://github.com/flux-rs/flux/pull/478#issuecomment-1654035374
524                self.subtyping(&impl_elem, &term, reason)?;
525                self.subtyping(&term, &impl_elem, reason)?;
526            }
527        }
528        Ok(())
529    }
530
531    /// Relate types via subtyping. This is the same as [`InferCtxtAt::subtyping`] except that we
532    /// also require a [`LocEnv`] to handle pointers and strong references
533    pub fn subtyping_with_env(
534        &mut self,
535        env: &mut impl LocEnv,
536        a: &Ty,
537        b: &Ty,
538        reason: ConstrReason,
539    ) -> InferResult {
540        let mut sub = Sub::new(env, reason, self.span);
541        sub.tys(self.infcx, a, b)
542    }
543
544    /// Relate types via subtyping and returns coroutine obligations. This doesn't handle subtyping
545    /// when strong references are involved.
546    ///
547    /// See comment for [`Sub::obligations`].
548    pub fn subtyping(
549        &mut self,
550        a: &Ty,
551        b: &Ty,
552        reason: ConstrReason,
553    ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
554        let mut env = DummyEnv;
555        let mut sub = Sub::new(&mut env, reason, self.span);
556        sub.tys(self.infcx, a, b)?;
557        Ok(sub.obligations)
558    }
559
560    pub fn subtyping_generic_args(
561        &mut self,
562        variance: Variance,
563        a: &GenericArg,
564        b: &GenericArg,
565        reason: ConstrReason,
566    ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
567        let mut env = DummyEnv;
568        let mut sub = Sub::new(&mut env, reason, self.span);
569        sub.generic_args(self.infcx, variance, a, b)?;
570        Ok(sub.obligations)
571    }
572
573    // FIXME(nilehmann) this is similar to `Checker::check_call`, but since is used from
574    // `place_ty::fold` we cannot use that directly. We should try to unify them, because
575    // there are a couple of things missing here (e.g., checking clauses on the struct definition).
576    pub fn check_constructor(
577        &mut self,
578        variant: EarlyBinder<PolyVariant>,
579        generic_args: &[GenericArg],
580        fields: &[Ty],
581        reason: ConstrReason,
582    ) -> InferResult<Ty> {
583        let ret = self.ensure_resolved_evars(|this| {
584            // Replace holes in generic arguments with fresh inference variables
585            let generic_args = this.instantiate_generic_args(generic_args);
586
587            let variant = variant
588                .instantiate(this.tcx(), &generic_args, &[])
589                .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
590
591            // Check arguments
592            for (actual, formal) in iter::zip(fields, variant.fields()) {
593                this.subtyping(actual, formal, reason)?;
594            }
595
596            // Check requires predicates
597            for require in &variant.requires {
598                this.check_pred(require, ConstrReason::Fold);
599            }
600
601            Ok(variant.ret())
602        })?;
603        Ok(self.fully_resolve_evars(&ret))
604    }
605
606    pub fn ensure_resolved_evars<R>(
607        &mut self,
608        f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
609    ) -> InferResult<R> {
610        self.infcx
611            .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
612    }
613}
614
615impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
616    type Target = InferCtxt<'a, 'genv, 'tcx>;
617
618    fn deref(&self) -> &Self::Target {
619        self.infcx
620    }
621}
622
623impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
624    fn deref_mut(&mut self) -> &mut Self::Target {
625        self.infcx
626    }
627}
628
/// Used for debugging to attach a "trace" to the [`RefineTree`] that can be used to print information
/// to recover the derivation when relating types via subtyping. The code that attaches the trace is
/// currently commented out because the output is too verbose.
#[derive(TypeVisitable, TypeFoldable)]
pub(crate) enum TypeTrace {
    /// A pair of types being related.
    Types(Ty, Ty),
    /// A pair of base types being related.
    BaseTys(BaseTy, BaseTy),
}
637
638#[expect(dead_code, reason = "we use this for debugging some time")]
639impl TypeTrace {
640    fn tys(a: &Ty, b: &Ty) -> Self {
641        Self::Types(a.clone(), b.clone())
642    }
643
644    fn btys(a: &BaseTy, b: &BaseTy) -> Self {
645        Self::BaseTys(a.clone(), b.clone())
646    }
647}
648
649impl fmt::Debug for TypeTrace {
650    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
651        match self {
652            TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
653            TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
654        }
655    }
656}
657
/// Environment of locations consulted by subtyping when pointers and strong references are
/// involved (see [`InferCtxtAt::subtyping_with_env`]).
pub trait LocEnv {
    /// Called when a mutable pointer to `path` is related to a mutable reference whose pointee
    /// type is `bound`; `re` is the region carried by the pointer.
    // NOTE(review): the returned type is discarded by the caller visible in this file; its
    // exact meaning depends on the implementor — confirm.
    fn ptr_to_ref(
        &mut self,
        infcx: &mut InferCtxtAt,
        reason: ConstrReason,
        re: Region,
        path: &Path,
        bound: Ty,
    ) -> InferResult<Ty>;

    /// Called before two strong references are related, with the path and type of the
    /// left-hand side reference. After the call the type at `path` is read back with `get`.
    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;

    /// Looks up the type associated with `path`.
    fn get(&self, path: &Path) -> Ty;
}
672
/// [`LocEnv`] used by the subtyping entry points that take no environment; every method
/// reports a bug when called.
struct DummyEnv;
674
675impl LocEnv for DummyEnv {
676    fn ptr_to_ref(
677        &mut self,
678        _: &mut InferCtxtAt,
679        _: ConstrReason,
680        _: Region,
681        _: &Path,
682        _: Ty,
683    ) -> InferResult<Ty> {
684        bug!("call to `ptr_to_ref` on `DummyEnv`")
685    }
686
687    fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
688        bug!("call to `unfold_str_ref` on `DummyEnv`")
689    }
690
691    fn get(&self, _: &Path) -> Ty {
692        bug!("call to `get` on `DummyEnv`")
693    }
694}
695
/// Context used to relate two types `a` and `b` via subtyping
struct Sub<'a, E> {
    /// The environment to lookup locations pointed to by [`TyKind::Ptr`].
    env: &'a mut E,
    /// Reason used for the [`Tag`] of every obligation generated while relating the types.
    reason: ConstrReason,
    /// Span used for the [`Tag`] of every obligation generated while relating the types.
    span: Span,
    /// FIXME(nilehmann) This is used to store coroutine obligations generated during subtyping when
    /// relating an opaque type. Other obligations related to relating opaque types are resolved
    /// directly here. The implementation is really messy and we may be missing some obligations.
    obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
}
707
708impl<'a, E: LocEnv> Sub<'a, E> {
709    fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
710        Self { env, reason, span, obligations: vec![] }
711    }
712
713    fn tag(&self) -> Tag {
714        Tag::new(self.reason, self.span)
715    }
716
    /// Relates `a` and `b` via subtyping (`a <: b`), generating the resulting constraints in a
    /// fresh branch of `infcx`.
    ///
    /// The left-hand side is unpacked first; the match below then dispatches on the shape of
    /// both types. Existentials and constraints on the right-hand side are peeled off before
    /// recursing.
    ///
    /// # Errors
    ///
    /// Returns a `query_bug!` error if no arm applies to the pair of types, and propagates
    /// errors from the environment and from recursive calls.
    fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
        let infcx = &mut infcx.branch();
        // infcx.cursor.push_trace(TypeTrace::tys(a, b));

        // We *fully* unpack the lhs before continuing to be able to prove goals like this
        // ∃a. (i32[a], ∃b. {i32[b] | a > b}) <: ∃a,b. ({i32[a] | b < a}, i32[b])
        // See S4.5 in https://arxiv.org/pdf/2209.13000v1.pdf
        let a = infcx.unpack(a);

        match (a.kind(), b.kind()) {
            (TyKind::Exists(..), _) => {
                bug!("existentials should have been removed by the unpacking above");
            }
            (TyKind::Constr(..), _) => {
                bug!("constraint types should have been removed by the unpacking above");
            }

            // Existential on the rhs: open it with fresh inference variables and recurse.
            (_, TyKind::Exists(ctor_b)) => {
                infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
            }
            // Constraint on the rhs: the predicate becomes an obligation, then recurse.
            (_, TyKind::Constr(pred_b, ty_b)) => {
                infcx.check_pred(pred_b, self.tag());
                self.tys(infcx, &a, ty_b)
            }

            (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                // We should technically remove `path_a` from `env`, but we are assuming that functions
                // always give back ownership of the location so `path_a` is going to be overwritten
                // after the call anyways.
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                // We have to unfold strong references prior to a subtyping check. Normally, when
                // checking a function body, a `StrgRef` is automatically unfolded i.e. `x:&strg T`
                // is turned into a `x:ptr(l); l: T` where `l` is some fresh location. However, we
                // need the below to do a similar unfolding during function subtyping where we just
                // have the super-type signature that needs to be unfolded. We also add the binding
                // to the environment so that we can:
                // (1) UPDATE the location after the call, and
                // (2) CHECK the relevant `ensures` clauses of the super-sig.
                // Same as the `Ptr` case above we should remove the location from the environment
                // after unfolding to consume it, but we are assuming functions always give back
                // ownership.
                self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            (
                TyKind::Ptr(PtrKind::Mut(re), path),
                TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
            ) => {
                // We sometimes generate evars for the index of references so we need to make sure
                // we solve them.
                self.idxs_eq(infcx, &Expr::unit(), idx);

                // The type returned by the environment is not needed here; only errors matter.
                self.env.ptr_to_ref(
                    &mut infcx.at(self.span),
                    self.reason,
                    *re,
                    path,
                    bound.clone(),
                )?;
                Ok(())
            }

            (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
                self.btys(infcx, bty_a, bty_b)?;
                self.idxs_eq(infcx, idx_a, idx_b);
                Ok(())
            }
            // The remaining structural cases only sanity-check (in debug builds) that both
            // sides agree; no constraint is generated for them.
            (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
                debug_assert_eq!(pk_a, pk_b);
                debug_assert_eq!(path_a, path_b);
                Ok(())
            }
            (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
                debug_assert_eq!(param_ty_a, param_ty_b);
                Ok(())
            }
            // Anything is accepted where an uninitialized type is expected.
            (_, TyKind::Uninit) => Ok(()),
            (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
                debug_assert_eq!(fields_a.len(), fields_b.len());
                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
        }
    }
810
811    fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
812        // infcx.push_trace(TypeTrace::btys(a, b));
813
814        match (a, b) {
815            (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
816                debug_assert_eq!(int_ty_a, int_ty_b);
817                Ok(())
818            }
819            (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
820                debug_assert_eq!(uint_ty_a, uint_ty_b);
821                Ok(())
822            }
823            (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
824                tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
825                tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
826                let variances = infcx.genv.variances_of(a_adt.did());
827                for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
828                    self.generic_args(infcx, *variance, ty_a, ty_b)?;
829                }
830                Ok(())
831            }
832            (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
833                debug_assert_eq!(a_def_id, b_def_id);
834                debug_assert_eq!(a_args.len(), b_args.len());
835                // NOTE: we don't check subtyping here because the RHS is *really*
836                // the function type, the LHS is just generated by rustc.
837                // we could generate a subtyping constraint but those would
838                // just be trivial (but might cause useless cycles in fixpoint).
839                // Nico: (This is probably ok because) We never do function
840                // subtyping between `FnDef` *except* when (the def_id) is
841                // passed as an argument to a function.
842                for (arg_a, arg_b) in iter::zip(a_args, b_args) {
843                    match (arg_a, arg_b) {
844                        (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
845                            let bty_a = ty_a.as_bty_skipping_existentials();
846                            let bty_b = ty_b.as_bty_skipping_existentials();
847                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
848                        }
849                        (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
850                            let bty_a = ctor_a.as_bty_skipping_binder();
851                            let bty_b = ctor_b.as_bty_skipping_binder();
852                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
853                        }
854                        (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
855                    }
856                }
857                Ok(())
858            }
859            (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
860                debug_assert_eq!(float_ty_a, float_ty_b);
861                Ok(())
862            }
863            (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),
864            (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
865                self.tys(infcx, ty_a, ty_b)?;
866                self.tys(infcx, ty_b, ty_a)
867            }
868            (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
869                self.tys(infcx, ty_a, ty_b)
870            }
871            (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
872                debug_assert_eq!(tys_a.len(), tys_b.len());
873                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
874                    self.tys(infcx, ty_a, ty_b)?;
875                }
876                Ok(())
877            }
878            (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
879                if let BaseTy::Alias(AliasKind::Opaque, alias_ty_a) = a {
880                    debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
881                    iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
882                        .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));
883                }
884                self.handle_opaque_type(infcx, a, alias_ty_b)
885            }
886            (
887                BaseTy::Alias(AliasKind::Projection, alias_ty_a),
888                BaseTy::Alias(AliasKind::Projection, alias_ty_b),
889            ) => {
890                tracked_span_dbg_assert_eq!(alias_ty_a, alias_ty_b);
891                Ok(())
892            }
893            (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
894                tracked_span_dbg_assert_eq!(len_a, len_b);
895                self.tys(infcx, ty_a, ty_b)
896            }
897            (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
898                debug_assert_eq!(param_a, param_b);
899                Ok(())
900            }
901            (BaseTy::Bool, BaseTy::Bool)
902            | (BaseTy::Str, BaseTy::Str)
903            | (BaseTy::Char, BaseTy::Char)
904            | (BaseTy::RawPtr(_, _), BaseTy::RawPtr(_, _))
905            | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
906            (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
907                tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
908                Ok(())
909            }
910            (BaseTy::Closure(did1, tys_a, _), BaseTy::Closure(did2, tys_b, _)) if did1 == did2 => {
911                debug_assert_eq!(tys_a.len(), tys_b.len());
912                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
913                    self.tys(infcx, ty_a, ty_b)?;
914                }
915                Ok(())
916            }
917            (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
918                tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
919                Ok(())
920            }
921            (BaseTy::Never, BaseTy::Never) => Ok(()),
922            _ => Err(query_bug!("incompatible base types: `{a:?}` - `{b:?}`"))?,
923        }
924    }
925
926    fn generic_args(
927        &mut self,
928        infcx: &mut InferCtxt,
929        variance: Variance,
930        a: &GenericArg,
931        b: &GenericArg,
932    ) -> InferResult {
933        let (ty_a, ty_b) = match (a, b) {
934            (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
935            (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
936                tracked_span_dbg_assert_eq!(ctor_a.sort(), ctor_b.sort());
937                (ctor_a.to_ty(), ctor_b.to_ty())
938            }
939            (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
940            (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
941                debug_assert_eq!(cst_a, cst_b);
942                return Ok(());
943            }
944            _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
945        };
946        match variance {
947            Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
948            Variance::Invariant => {
949                self.tys(infcx, &ty_a, &ty_b)?;
950                self.tys(infcx, &ty_b, &ty_a)
951            }
952            Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
953            Variance::Bivariant => Ok(()),
954        }
955    }
956
957    fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
958        if a == b {
959            return;
960        }
961        match (a.kind(), b.kind()) {
962            (
963                ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
964                ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
965            ) => {
966                debug_assert_eq!(did_a, did_b);
967                for (a, b) in iter::zip(flds_a, flds_b) {
968                    self.idxs_eq(infcx, a, b);
969                }
970            }
971            (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
972                for (a, b) in iter::zip(flds_a, flds_b) {
973                    self.idxs_eq(infcx, a, b);
974                }
975            }
976            (_, ExprKind::Tuple(flds_b)) => {
977                for (f, b) in flds_b.iter().enumerate() {
978                    let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
979                    let a = a.proj_and_reduce(proj);
980                    self.idxs_eq(infcx, &a, b);
981                }
982            }
983
984            (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
985                for (f, b) in flds_b.iter().enumerate() {
986                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
987                    let a = a.proj_and_reduce(proj);
988                    self.idxs_eq(infcx, &a, b);
989                }
990            }
991
992            (ExprKind::Tuple(flds_a), _) => {
993                infcx.unify_exprs(a, b);
994                for (f, a) in flds_a.iter().enumerate() {
995                    let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
996                    let b = b.proj_and_reduce(proj);
997                    self.idxs_eq(infcx, a, &b);
998                }
999            }
1000            (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
1001                infcx.unify_exprs(a, b);
1002                for (f, a) in flds_a.iter().enumerate() {
1003                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
1004                    let b = b.proj_and_reduce(proj);
1005                    self.idxs_eq(infcx, a, &b);
1006                }
1007            }
1008            (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
1009                self.abs_eq(infcx, lam_a, lam_b);
1010            }
1011            (_, ExprKind::Abs(lam_b)) => {
1012                self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
1013            }
1014            (ExprKind::Abs(lam_a), _) => {
1015                infcx.unify_exprs(a, b);
1016                self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
1017            }
1018            (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
1019                infcx.check_impl(a, b, self.tag());
1020                infcx.check_impl(b, a, self.tag());
1021            }
1022            _ => {
1023                infcx.unify_exprs(a, b);
1024                let span = b.span();
1025                infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
1026            }
1027        }
1028    }
1029
1030    fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
1031        debug_assert_eq!(a.vars().len(), b.vars().len());
1032        let vars = a
1033            .vars()
1034            .iter()
1035            .map(|kind| Expr::fvar(infcx.define_var(kind.expect_sort())))
1036            .collect_vec();
1037        let body_a = a.apply(&vars);
1038        let body_b = b.apply(&vars);
1039        self.idxs_eq(infcx, &body_a, &body_b);
1040    }
1041
1042    fn handle_opaque_type(
1043        &mut self,
1044        infcx: &mut InferCtxt,
1045        bty: &BaseTy,
1046        alias_ty: &AliasTy,
1047    ) -> InferResult {
1048        if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys) = bty {
1049            let obligs = mk_coroutine_obligations(
1050                infcx.genv,
1051                def_id,
1052                resume_ty,
1053                upvar_tys,
1054                &alias_ty.def_id,
1055            )?;
1056            self.obligations.extend(obligs);
1057        } else {
1058            let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
1059                infcx.tcx(),
1060                &alias_ty.args,
1061                &alias_ty.refine_args,
1062            );
1063            for clause in &bounds {
1064                if !clause.kind().vars().is_empty() {
1065                    Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
1066                }
1067                if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
1068                    let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
1069                    let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
1070                        .to_ty()
1071                        .normalize_projections(&mut infcx.at(self.span))?;
1072                    let ty2 = pred.term.to_ty();
1073                    self.tys(infcx, &ty1, &ty2)?;
1074                }
1075            }
1076        }
1077        Ok(())
1078    }
1079}
1080
1081fn mk_coroutine_obligations(
1082    genv: GlobalEnv,
1083    generator_did: &DefId,
1084    resume_ty: &Ty,
1085    upvar_tys: &List<Ty>,
1086    opaque_def_id: &DefId,
1087) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1088    let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1089    for bound in &bounds {
1090        if let Some(proj_clause) = bound.as_projection_clause() {
1091            return Ok(vec![proj_clause.map(|proj_clause| {
1092                let output = proj_clause.term;
1093                CoroutineObligPredicate {
1094                    def_id: *generator_did,
1095                    resume_ty: resume_ty.clone(),
1096                    upvar_tys: upvar_tys.clone(),
1097                    output: output.to_ty(),
1098                }
1099            })]);
1100        }
1101    }
1102    bug!("no projection predicate")
1103}
1104
/// Error type carried by [`InferResult`].
1105#[derive(Debug)]
1106pub enum InferErr {
    /// Carries the id of an existential variable.
    // NOTE(review): presumably reported when that evar could not be solved; the code that
    // produces this variant is not in this part of the file, so confirm there.
1107    UnsolvedEvar(EVid),
    /// Wraps a [`QueryErr`]; this is the variant built by the `From<QueryErr>` conversion.
1108    Query(QueryErr),
1109}
1110
1111impl From<QueryErr> for InferErr {
1112    fn from(v: QueryErr) -> Self {
1113        Self::Query(v)
1114    }
1115}
1116
// Pretty-printing support for `Tag`, kept in its own module so the glob import of
// `flux_middle::pretty` stays local to it.
1117mod pretty {
1118    use std::fmt;
1119
1120    use flux_middle::pretty::*;
1121
1122    use super::*;
1123
    // A `Tag` is rendered from its `reason` and `src_span` only, with the literal text
    // " at " between them; `dst_span` is not part of the output.
1124    impl Pretty for Tag {
1125        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // NOTE(review): the `^` prefix is syntax of the project's `w!` macro; its exact
            // meaning is defined in `flux_middle::pretty` — confirm there.
1126            w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)
1127        }
1128    }
1129
    // NOTE(review): presumably implements `Debug` for `Tag` on top of the `Pretty` impl
    // above using a default `PrettyCx` — confirm against the macro definition.
1130    impl_debug_with_default_cx!(Tag);
1131}