flux_infer/
infer.rs

1use std::{cell::RefCell, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts, OverflowMode};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7    FixpointQueryKind,
8    def_id::MaybeExternId,
9    global_env::GlobalEnv,
10    queries::{QueryErr, QueryResult},
11    query_bug,
12    rty::{
13        self, AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate,
14        Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj, GenericArg, HoleKind, InferMode,
15        Lambda, List, Loc, Mutability, Name, Path, PolyVariant, PtrKind, RefineArgs, RefineArgsExt,
16        Region, Sort, Ty, TyCtor, TyKind, Var,
17        canonicalize::{Hoister, HoisterDelegate},
18        fold::TypeFoldable,
19    },
20};
21use itertools::{Itertools, izip};
22use rustc_hir::def_id::{DefId, LocalDefId};
23use rustc_macros::extension;
24use rustc_middle::{
25    mir::BasicBlock,
26    ty::{TyCtxt, Variance},
27};
28use rustc_span::Span;
29
30use crate::{
31    evars::{EVarState, EVarStore},
32    fixpoint_encoding::{FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen},
33    projections::NormalizeExt as _,
34    refine_tree::{Cursor, Marker, RefineTree, Scope},
35};
36
/// Result type used throughout inference. Defaults to `()` on success and carries an
/// [`InferErr`] on failure.
pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
38
/// Identifies where a constraint came from: the reason it was generated and the source span(s)
/// it relates to.
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct Tag {
    /// Why the constraint was generated.
    pub reason: ConstrReason,
    /// Span of the code that produced the constraint.
    pub src_span: Span,
    /// Optional secondary span. It is `None` when built with [`Tag::new`] and is set through
    /// [`Tag::with_dst`].
    pub dst_span: Option<ESpan>,
}
45
46impl Tag {
47    pub fn new(reason: ConstrReason, span: Span) -> Self {
48        Self { reason, src_span: span, dst_span: None }
49    }
50
51    pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
52        Self { dst_span, ..self }
53    }
54}
55
/// Payload of [`ConstrReason::Subtype`], giving more detail about a subtyping constraint.
///
/// NOTE(review): the variant names suggest the part of a function signature that produced the
/// constraint (inputs, output, `requires`, `ensures`). Confirm against the construction sites.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum SubtypeReason {
    Input,
    Output,
    Requires,
    Ensures,
}
63
/// The reason a constraint was generated. It is stored in a [`Tag`] so that a failing
/// constraint can be traced back to the construct that produced it.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum ConstrReason {
    Call,
    Assign,
    Ret,
    /// Used, for example, for the `requires` predicates checked by
    /// `InferCtxtAt::check_constructor`.
    Fold,
    FoldLocal,
    Predicate,
    /// Carries a static message describing the assertion.
    Assert(&'static str),
    Div,
    Rem,
    /// Carries a MIR [`BasicBlock`], presumably the jump target.
    Goto(BasicBlock),
    Overflow,
    Underflow,
    /// A subtyping constraint, refined by a [`SubtypeReason`].
    Subtype(SubtypeReason),
    Other,
}
81
/// Root of inference for an item.
///
/// It owns the [`RefineTree`] that accumulates constraints, together with the kvar/evar stores.
/// [`InferCtxt`]s obtained through [`InferCtxtRoot::infcx`] start at the root of the tree and
/// share these stores.
pub struct InferCtxtRoot<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Kvar generator and evar store. Held in a `RefCell` so that every `InferCtxt` borrowed from
    /// this root can create variables through a shared reference.
    inner: RefCell<InferCtxtInner>,
    /// Tree of constraints generated during inference.
    refine_tree: RefineTree,
    /// Options: overflow mode, SMT solver choice and qualifier scraping.
    opts: InferOpts,
}
88
/// Builder for an [`InferCtxtRoot`]. Create it with `GlobalEnvExt::infcx_root` and finish it with
/// [`InferCtxtRootBuilder::build`].
pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    opts: InferOpts,
    /// Variables (with their sorts) handed to `RefineTree::new` when the root is built.
    params: Vec<(Var, Sort)>,
    /// Rustc inference context used to normalize the sorts of refinement parameters.
    infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
    /// Forwarded to `KVarGen::new`.
    dummy_kvars: bool,
}
96
97#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
98impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
99    fn infcx_root<'a>(
100        self,
101        infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
102        opts: InferOpts,
103    ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
104        InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
105    }
106}
107
108impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
109    pub fn with_dummy_kvars(mut self) -> Self {
110        self.dummy_kvars = true;
111        self
112    }
113
114    pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
115        self.params.extend(
116            self.genv
117                .generics_of(def_id)?
118                .const_params(self.genv)?
119                .into_iter()
120                .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
121        );
122        Ok(self)
123    }
124
125    pub fn with_refinement_generics(
126        mut self,
127        def_id: DefId,
128        args: &[GenericArg],
129    ) -> QueryResult<Self> {
130        for (index, param) in self
131            .genv
132            .refinement_generics_of(def_id)?
133            .iter_own_params()
134            .enumerate()
135        {
136            let param = param.instantiate(self.genv.tcx(), args, &[]);
137            let sort = param
138                .sort
139                .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
140
141            let var =
142                Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
143            self.params.push((var, sort));
144        }
145        Ok(self)
146    }
147
148    pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
149        self = self.with_const_generics(def_id)?;
150        let offset = self.params.len();
151        self.genv.refinement_generics_of(def_id)?.fill_item(
152            self.genv,
153            &mut self.params,
154            &mut |param, index| {
155                let index = (index - offset) as u32;
156                let param = param.instantiate_identity();
157                let sort = param
158                    .sort
159                    .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
160
161                let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
162                Ok((var, sort))
163            },
164        )?;
165        Ok(self)
166    }
167
168    pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
169        Ok(InferCtxtRoot {
170            genv: self.genv,
171            inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
172            refine_tree: RefineTree::new(self.params),
173            opts: self.opts,
174        })
175    }
176}
177
178impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
179    pub fn infcx<'a>(
180        &'a mut self,
181        def_id: DefId,
182        region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
183    ) -> InferCtxt<'a, 'genv, 'tcx> {
184        InferCtxt {
185            genv: self.genv,
186            region_infcx,
187            def_id,
188            cursor: self.refine_tree.cursor_at_root(),
189            inner: &self.inner,
190            check_overflow: self.opts.check_overflow,
191        }
192    }
193
194    pub fn fresh_kvar_in_scope(
195        &self,
196        binders: &[BoundVariableKinds],
197        scope: &Scope,
198        encoding: KVarEncoding,
199    ) -> Expr {
200        let inner = &mut *self.inner.borrow_mut();
201        inner.kvars.fresh(binders, scope.iter(), encoding)
202    }
203
204    pub fn execute_lean_query(self, def_id: MaybeExternId) -> QueryResult<()> {
205        let inner = self.inner.into_inner();
206        let kvars = inner.kvars;
207        let evars = inner.evars;
208        let mut refine_tree = self.refine_tree;
209        refine_tree.replace_evars(&evars).unwrap();
210        refine_tree.simplify(self.genv);
211
212        let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars);
213        let cstr = refine_tree.into_fixpoint(&mut fcx)?;
214        fcx.generate_and_check_lean_lemmas(cstr)
215    }
216
217    pub fn execute_fixpoint_query(
218        self,
219        cache: &mut FixQueryCache,
220        def_id: MaybeExternId,
221        kind: FixpointQueryKind,
222    ) -> QueryResult<Vec<Tag>> {
223        let inner = self.inner.into_inner();
224        let kvars = inner.kvars;
225        let evars = inner.evars;
226
227        let ext = kind.ext();
228
229        let mut refine_tree = self.refine_tree;
230
231        refine_tree.replace_evars(&evars).unwrap();
232
233        if config::dump_constraint() {
234            dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
235        }
236        refine_tree.simplify(self.genv);
237        if config::dump_constraint() {
238            let simp_ext = format!("simp.{ext}");
239            dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
240                .unwrap();
241        }
242
243        let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars);
244        let cstr = refine_tree.into_fixpoint(&mut fcx)?;
245
246        let backend = match self.opts.solver {
247            flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
248            flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
249        };
250
251        fcx.check(cache, cstr, kind, self.opts.scrape_quals, backend)
252    }
253
254    pub fn split(self) -> (RefineTree, KVarGen) {
255        (self.refine_tree, self.inner.into_inner().kvars)
256    }
257}
258
/// Inference context for the item `def_id`, positioned at some node of a [`RefineTree`]
/// through its [`Cursor`].
///
/// Kvar and evar stores are shared (through `inner`) with the [`InferCtxtRoot`] it came from.
pub struct InferCtxt<'infcx, 'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The rustc inference context associated with the item being checked.
    pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    pub def_id: DefId,
    /// Overflow mode passed along when assuming type invariants.
    pub check_overflow: OverflowMode,
    /// Current position in the refinement tree; new variables and predicates are added here.
    cursor: Cursor<'infcx>,
    /// Shared kvar generator and evar store.
    inner: &'infcx RefCell<InferCtxtInner>,
}
267
/// Mutable state shared by every [`InferCtxt`] created from the same [`InferCtxtRoot`].
struct InferCtxtInner {
    /// Generator for fresh kvars.
    kvars: KVarGen,
    /// Store of evars and their solutions.
    evars: EVarStore,
}
272
273impl InferCtxtInner {
274    fn new(dummy_kvars: bool) -> Self {
275        Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
276    }
277}
278
impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
    /// Pairs this context with `span`. Constraints generated through the returned
    /// [`InferCtxtAt`] are tagged with that span.
    pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
        InferCtxtAt { infcx: self, span }
    }

    /// Creates one fresh inference variable per refinement parameter of `callee_def_id`.
    ///
    /// Each parameter is first instantiated with the generic `args`. Its [`InferMode`] then
    /// decides whether a kvar or an evar is created (see [`InferCtxt::fresh_infer_var`]).
    pub fn instantiate_refine_args(
        &mut self,
        callee_def_id: DefId,
        args: &[rty::GenericArg],
    ) -> InferResult<List<Expr>> {
        Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
            let param = param.instantiate(self.genv.tcx(), args, &[]);
            Ok(self.fresh_infer_var(&param.sort, param.mode))
        })?)
    }

    /// Replaces every hole in `args` with a fresh inference variable
    /// (see [`InferCtxt::fresh_infer_var_for_hole`]).
    pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
        args.iter()
            .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
            .collect_vec()
    }

    /// Creates a fresh inference variable for a value of sort `sort`.
    ///
    /// - [`InferMode::KVar`]: `sort` is expected to be a monomorphic function sort (via
    ///   `expect_func`/`expect_mono`). A kvar over the function's input sorts is created in the
    ///   current scope and wrapped in a lambda with that sort.
    /// - [`InferMode::EVar`]: a fresh evar in the current scope.
    pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
        match mode {
            InferMode::KVar => {
                let fsort = sort.expect_func().expect_mono();
                let vars = fsort.inputs().iter().cloned().map_into().collect();
                let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
                Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
            }
            InferMode::EVar => self.fresh_evar(),
        }
    }

    /// Creates the inference variable that fills a hole.
    ///
    /// Predicate holes become a kvar (with [`KVarEncoding::Conj`]) under `binders`. Expression
    /// holes become an evar.
    pub fn fresh_infer_var_for_hole(
        &mut self,
        binders: &[BoundVariableKinds],
        kind: HoleKind,
    ) -> Expr {
        match kind {
            HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
            HoleKind::Expr(_) => {
                // We only use expression holes to infer early param arguments for opaque types
                // at function calls. These should be well-scoped in the current scope, so we ignore
                // the extra `binders` around the hole.
                self.fresh_evar()
            }
        }
    }

    /// Generate a fresh kvar in the _given_ [`Scope`] (similar method in [`InferCtxtRoot`]).
    pub fn fresh_kvar_in_scope(
        &self,
        binders: &[BoundVariableKinds],
        scope: &Scope,
        encoding: KVarEncoding,
    ) -> Expr {
        let inner = &mut *self.inner.borrow_mut();
        inner.kvars.fresh(binders, scope.iter(), encoding)
    }

    /// Generate a fresh kvar in the current scope (the variables visible at the cursor).
    /// See [`KVarGen::fresh`].
    pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
        let inner = &mut *self.inner.borrow_mut();
        inner.kvars.fresh(binders, self.cursor.vars(), encoding)
    }

    /// Creates a fresh evar tied to the cursor's current marker. `unify_exprs` later uses that
    /// marker to decide whether a candidate solution is admissible.
    fn fresh_evar(&self) -> Expr {
        let evars = &mut self.inner.borrow_mut().evars;
        Expr::evar(evars.fresh(self.cursor.marker()))
    }

    /// Best-effort unification used to solve evars.
    ///
    /// If `b` is an unsolved evar and `marker.has_free_vars(a)` is false for the evar's marker,
    /// the evar is solved to `a`. Presumably that check means `a` only mentions variables in
    /// scope where the evar was created. Nothing happens if `a` itself contains evars, or if `b`
    /// is not an unsolved evar.
    pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
        if a.has_evars() {
            return;
        }
        let evars = &mut self.inner.borrow_mut().evars;
        if let ExprKind::Var(Var::EVar(evid)) = b.kind()
            && let EVarState::Unsolved(marker) = evars.get(*evid)
            && !marker.has_free_vars(a)
        {
            evars.solve(*evid, a.clone());
        }
    }

    /// Instantiates the bound refinement variables of `t` with fresh inference variables and
    /// runs `f` on the result, inside a new evar scope.
    ///
    /// Panics (through `unwrap`) if an evar created inside that scope is left unsolved.
    fn enter_exists<T, U>(
        &mut self,
        t: &Binder<T>,
        f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
    ) -> U
    where
        T: TypeFoldable,
    {
        self.ensure_resolved_evars(|infcx| {
            let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
            Ok(f(infcx, t))
        })
        .unwrap()
    }

    /// Used in conjunction with [`InferCtxt::pop_evar_scope`] to ensure evars are solved at the end
    /// of some scope, for example, to ensure all evars generated during a function call are solved
    /// after checking argument subtyping. These functions can be used in a stack-like fashion to
    /// create nested scopes.
    pub fn push_evar_scope(&mut self) {
        self.inner.borrow_mut().evars.push_scope();
    }

    /// Pop a scope and check that all its evars have been solved. This only checks evars
    /// generated since the last call to [`InferCtxt::push_evar_scope`]. An unsolved evar is
    /// reported as [`InferErr::UnsolvedEvar`].
    pub fn pop_evar_scope(&mut self) -> InferResult {
        self.inner
            .borrow_mut()
            .evars
            .pop_scope()
            .map_err(InferErr::UnsolvedEvar)
    }

    /// Convenience method pairing [`InferCtxt::push_evar_scope`] and [`InferCtxt::pop_evar_scope`].
    ///
    /// NOTE(review): if `f` returns an error, the `?` returns early and the pushed scope is
    /// never popped. This looks harmless only if callers abandon the context after an error.
    /// Confirm.
    pub fn ensure_resolved_evars<R>(
        &mut self,
        f: impl FnOnce(&mut Self) -> InferResult<R>,
    ) -> InferResult<R> {
        self.push_evar_scope();
        let r = f(self)?;
        self.pop_evar_scope()?;
        Ok(r)
    }

    /// Substitutes solved evars into `t`.
    ///
    /// The result of `replace_evars` is unwrapped, so this panics if it fails (presumably when
    /// some evar in `t` is still unsolved).
    pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
        self.inner.borrow().evars.replace_evars(t).unwrap()
    }

    /// Shorthand for `self.genv.tcx()`.
    pub fn tcx(&self) -> TyCtxt<'tcx> {
        self.genv.tcx()
    }

    /// The cursor marking this context's position in the refinement tree.
    pub fn cursor(&self) -> &Cursor<'infcx> {
        &self.cursor
    }
}
420
421/// Methods that interact with the underlying [`Cursor`]
422impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
423    pub fn change_item<'a>(
424        &'a mut self,
425        def_id: LocalDefId,
426        region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
427    ) -> InferCtxt<'a, 'genv, 'tcx> {
428        InferCtxt {
429            def_id: def_id.to_def_id(),
430            cursor: self.cursor.branch(),
431            region_infcx,
432            ..*self
433        }
434    }
435
436    pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
437        InferCtxt {
438            cursor: self
439                .cursor
440                .move_to(marker, clear_children)
441                .unwrap_or_else(|| tracked_span_bug!()),
442            ..*self
443        }
444    }
445
446    pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
447        InferCtxt { cursor: self.cursor.branch(), ..*self }
448    }
449
450    pub fn define_var(&mut self, sort: &Sort) -> Name {
451        self.cursor.define_var(sort)
452    }
453
454    pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
455        self.cursor.check_pred(pred, tag);
456    }
457
458    pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
459        self.cursor.assume_pred(pred);
460    }
461
462    pub fn unpack(&mut self, ty: &Ty) -> Ty {
463        self.hoister(false).hoist(ty)
464    }
465
466    pub fn marker(&self) -> Marker {
467        self.cursor.marker()
468    }
469
470    pub fn hoister(
471        &mut self,
472        assume_invariants: bool,
473    ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
474        Hoister::with_delegate(Unpacker { infcx: self, assume_invariants }).transparent()
475    }
476
477    pub fn assume_invariants(&mut self, ty: &Ty) {
478        self.cursor
479            .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
480    }
481
482    fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
483        self.cursor.check_impl(pred1, pred2, tag);
484    }
485}
486
/// [`HoisterDelegate`] that unpacks types into an [`InferCtxt`].
///
/// Hoisted existential binders become fresh variables at the cursor, and hoisted constraints
/// become assumptions.
pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    /// Whether to also assume the invariants of each type produced by opening an existential.
    assume_invariants: bool,
}
491
492impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
493    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
494        let ty =
495            ty_ctor.replace_bound_refts_with(|sort, _, _| Expr::fvar(self.infcx.define_var(sort)));
496        if self.assume_invariants {
497            self.infcx.assume_invariants(&ty);
498        }
499        ty
500    }
501
502    fn hoist_constr(&mut self, pred: Expr) {
503        self.infcx.assume_pred(pred);
504    }
505}
506
507impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
508    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509        std::fmt::Debug::fmt(&self.cursor, f)
510    }
511}
512
/// An [`InferCtxt`] paired with a source span.
///
/// Constraints generated through it are tagged with `span`. It dereferences to the underlying
/// [`InferCtxt`].
#[derive(Debug)]
pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
    pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    pub span: Span,
}
518
519impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
520    fn tag(&self, reason: ConstrReason) -> Tag {
521        Tag::new(reason, self.span)
522    }
523
524    pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
525        let tag = self.tag(reason);
526        self.infcx.check_pred(pred, tag);
527    }
528
529    pub fn check_non_closure_clauses(
530        &mut self,
531        clauses: &[rty::Clause],
532        reason: ConstrReason,
533    ) -> InferResult {
534        for clause in clauses {
535            if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
536                let impl_elem = BaseTy::projection(projection_pred.projection_ty)
537                    .to_ty()
538                    .deeply_normalize(self)?;
539                let term = projection_pred.term.to_ty().deeply_normalize(self)?;
540
541                // TODO: does this really need to be invariant? https://github.com/flux-rs/flux/pull/478#issuecomment-1654035374
542                self.subtyping(&impl_elem, &term, reason)?;
543                self.subtyping(&term, &impl_elem, reason)?;
544            }
545        }
546        Ok(())
547    }
548
549    /// Relate types via subtyping. This is the same as [`InferCtxtAt::subtyping`] except that we
550    /// also require a [`LocEnv`] to handle pointers and strong references
551    pub fn subtyping_with_env(
552        &mut self,
553        env: &mut impl LocEnv,
554        a: &Ty,
555        b: &Ty,
556        reason: ConstrReason,
557    ) -> InferResult {
558        let mut sub = Sub::new(env, reason, self.span);
559        sub.tys(self.infcx, a, b)
560    }
561
562    /// Relate types via subtyping and returns coroutine obligations. This doesn't handle subtyping
563    /// when strong references are involved.
564    ///
565    /// See comment for [`Sub::obligations`].
566    pub fn subtyping(
567        &mut self,
568        a: &Ty,
569        b: &Ty,
570        reason: ConstrReason,
571    ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
572        let mut env = DummyEnv;
573        let mut sub = Sub::new(&mut env, reason, self.span);
574        sub.tys(self.infcx, a, b)?;
575        Ok(sub.obligations)
576    }
577
578    pub fn subtyping_generic_args(
579        &mut self,
580        variance: Variance,
581        a: &GenericArg,
582        b: &GenericArg,
583        reason: ConstrReason,
584    ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
585        let mut env = DummyEnv;
586        let mut sub = Sub::new(&mut env, reason, self.span);
587        sub.generic_args(self.infcx, variance, a, b)?;
588        Ok(sub.obligations)
589    }
590
591    // FIXME(nilehmann) this is similar to `Checker::check_call`, but since is used from
592    // `place_ty::fold` we cannot use that directly. We should try to unify them, because
593    // there are a couple of things missing here (e.g., checking clauses on the struct definition).
594    pub fn check_constructor(
595        &mut self,
596        variant: EarlyBinder<PolyVariant>,
597        generic_args: &[GenericArg],
598        fields: &[Ty],
599        reason: ConstrReason,
600    ) -> InferResult<Ty> {
601        let ret = self.ensure_resolved_evars(|this| {
602            // Replace holes in generic arguments with fresh inference variables
603            let generic_args = this.instantiate_generic_args(generic_args);
604
605            let variant = variant
606                .instantiate(this.tcx(), &generic_args, &[])
607                .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
608
609            // Check arguments
610            for (actual, formal) in iter::zip(fields, variant.fields()) {
611                this.subtyping(actual, formal, reason)?;
612            }
613
614            // Check requires predicates
615            for require in &variant.requires {
616                this.check_pred(require, ConstrReason::Fold);
617            }
618
619            Ok(variant.ret())
620        })?;
621        Ok(self.fully_resolve_evars(&ret))
622    }
623
624    pub fn ensure_resolved_evars<R>(
625        &mut self,
626        f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
627    ) -> InferResult<R> {
628        self.infcx
629            .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
630    }
631}
632
/// Lets an [`InferCtxtAt`] be used wherever an [`InferCtxt`] is expected.
impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
    type Target = InferCtxt<'a, 'genv, 'tcx>;

    fn deref(&self) -> &Self::Target {
        self.infcx
    }
}

/// Mutable counterpart of the `Deref` impl: gives mutable access to the wrapped [`InferCtxt`].
impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.infcx
    }
}
646
/// Used for debugging to attach a "trace" to the [`RefineTree`] that can be used to print information
/// to recover the derivation when relating types via subtyping. The code that attaches the trace is
/// currently commented out because the output is too verbose.
#[derive(TypeVisitable, TypeFoldable)]
pub(crate) enum TypeTrace {
    /// A pair of types being related.
    Types(Ty, Ty),
    /// A pair of base types being related.
    BaseTys(BaseTy, BaseTy),
}
655
656#[expect(dead_code, reason = "we use this for debugging some time")]
657impl TypeTrace {
658    fn tys(a: &Ty, b: &Ty) -> Self {
659        Self::Types(a.clone(), b.clone())
660    }
661
662    fn btys(a: &BaseTy, b: &BaseTy) -> Self {
663        Self::BaseTys(a.clone(), b.clone())
664    }
665}
666
667impl fmt::Debug for TypeTrace {
668    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
669        match self {
670            TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
671            TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
672        }
673    }
674}
675
/// Environment of locations consulted during subtyping to handle pointers
/// ([`TyKind::Ptr`]) and strong references ([`TyKind::StrgRef`]).
pub trait LocEnv {
    /// Called when a mutable pointer to `path` (with region `re`) is related to a mutable
    /// reference whose pointee type is `bound`.
    ///
    /// NOTE(review): presumably this checks the type stored at `path` against `bound` and
    /// returns the resulting reference type. Confirm with the implementors.
    fn ptr_to_ref(
        &mut self,
        infcx: &mut InferCtxtAt,
        reason: ConstrReason,
        re: Region,
        path: &Path,
        bound: Ty,
    ) -> InferResult<Ty>;

    /// Unfolds a strong reference at `path` with pointee type `ty`, binding the location in the
    /// environment so it can be looked up with [`LocEnv::get`].
    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;

    /// Returns the type currently stored at `path`.
    fn get(&self, path: &Path) -> Ty;
}
690
691struct DummyEnv;
692
693impl LocEnv for DummyEnv {
694    fn ptr_to_ref(
695        &mut self,
696        _: &mut InferCtxtAt,
697        _: ConstrReason,
698        _: Region,
699        _: &Path,
700        _: Ty,
701    ) -> InferResult<Ty> {
702        bug!("call to `ptr_to_ref` on `DummyEnv`")
703    }
704
705    fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
706        bug!("call to `unfold_str_ref` on `DummyEnv`")
707    }
708
709    fn get(&self, _: &Path) -> Ty {
710        bug!("call to `get` on `DummyEnv`")
711    }
712}
713
/// Context used to relate two types `a` and `b` via subtyping (`a <: b`).
struct Sub<'a, E> {
    /// The environment used to look up locations pointed to by [`TyKind::Ptr`] and to unfold
    /// strong references.
    env: &'a mut E,
    /// Reason attached (through a [`Tag`]) to every constraint generated while relating types.
    reason: ConstrReason,
    /// Span attached (through a [`Tag`]) to every constraint generated while relating types.
    span: Span,
    /// FIXME(nilehmann) This is used to store coroutine obligations generated during subtyping when
    /// relating an opaque type. Other obligations related to relating opaque types are resolved
    /// directly here. The implementation is really messy and we may be missing some obligations.
    obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
}
725
726impl<'a, E: LocEnv> Sub<'a, E> {
727    fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
728        Self { env, reason, span, obligations: vec![] }
729    }
730
731    fn tag(&self) -> Tag {
732        Tag::new(self.reason, self.span)
733    }
734
    /// Relates `a <: b`, adding constraints in a fresh branch of the refinement tree.
    ///
    /// The left-hand side is fully unpacked first: its existentials become fresh variables and
    /// its constraints become assumptions. Existentials on the right-hand side are opened with
    /// fresh inference variables (which must be solved inside that scope). Constraints on the
    /// right-hand side become checked predicates. Pointers and strong references are resolved
    /// through `self.env`.
    ///
    /// Returns a query-bug error when the two types are structurally incompatible.
    fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
        let infcx = &mut infcx.branch();
        // infcx.cursor.push_trace(TypeTrace::tys(a, b));

        // We *fully* unpack the lhs before continuing to be able to prove goals like this
        // ∃a. (i32[a], ∃b. {i32[b] | a > b}) <: ∃a,b. ({i32[a] | b < a}, i32[b])
        // See S4.5 in https://arxiv.org/pdf/2209.13000v1.pdf
        let a = infcx.unpack(a);

        match (a.kind(), b.kind()) {
            (TyKind::Exists(..), _) => {
                bug!("existentials should have been removed by the unpacking above");
            }
            (TyKind::Constr(..), _) => {
                bug!("constraint types should have been removed by the unpacking above");
            }

            (_, TyKind::Exists(ctor_b)) => {
                infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
            }
            (_, TyKind::Constr(pred_b, ty_b)) => {
                infcx.check_pred(pred_b, self.tag());
                self.tys(infcx, &a, ty_b)
            }

            (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                // We should technically remove `path_a` from `env`, but we are assuming that functions
                // always give back ownership of the location so `path_a` is going to be overwritten
                // after the call anyway.
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                // We have to unfold strong references prior to a subtyping check. Normally, when
                // checking a function body, a `StrgRef` is automatically unfolded i.e. `x:&strg T`
                // is turned into a `x:ptr(l); l: T` where `l` is some fresh location. However, we
                // need the below to do a similar unfolding during function subtyping where we just
                // have the super-type signature that needs to be unfolded. We also add the binding
                // to the environment so that we can:
                // (1) UPDATE the location after the call, and
                // (2) CHECK the relevant `ensures` clauses of the super-sig.
                // Same as the `Ptr` case above we should remove the location from the environment
                // after unfolding to consume it, but we are assuming functions always give back
                // ownership.
                self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            (
                TyKind::Ptr(PtrKind::Mut(re), path),
                TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
            ) => {
                // We sometimes generate evars for the index of references so we need to make sure
                // we solve them.
                self.idxs_eq(infcx, &Expr::unit(), idx);

                self.env.ptr_to_ref(
                    &mut infcx.at(self.span),
                    self.reason,
                    *re,
                    path,
                    bound.clone(),
                )?;
                Ok(())
            }

            (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
                self.btys(infcx, bty_a, bty_b)?;
                self.idxs_eq(infcx, idx_a, idx_b);
                Ok(())
            }
            (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
                debug_assert_eq!(pk_a, pk_b);
                debug_assert_eq!(path_a, path_b);
                Ok(())
            }
            (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
                debug_assert_eq!(param_ty_a, param_ty_b);
                Ok(())
            }
            // Anything is a subtype of an uninitialized slot.
            (_, TyKind::Uninit) => Ok(()),
            (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
                debug_assert_eq!(fields_a.len(), fields_b.len());
                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
        }
    }
828
829    fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
830        // infcx.push_trace(TypeTrace::btys(a, b));
831
832        match (a, b) {
833            (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
834                debug_assert_eq!(int_ty_a, int_ty_b);
835                Ok(())
836            }
837            (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
838                debug_assert_eq!(uint_ty_a, uint_ty_b);
839                Ok(())
840            }
841            (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
842                tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
843                tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
844                let variances = infcx.genv.variances_of(a_adt.did());
845                for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
846                    self.generic_args(infcx, *variance, ty_a, ty_b)?;
847                }
848                Ok(())
849            }
850            (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
851                debug_assert_eq!(a_def_id, b_def_id);
852                debug_assert_eq!(a_args.len(), b_args.len());
853                // NOTE: we don't check subtyping here because the RHS is *really*
854                // the function type, the LHS is just generated by rustc.
855                // we could generate a subtyping constraint but those would
856                // just be trivial (but might cause useless cycles in fixpoint).
857                // Nico: (This is probably ok because) We never do function
858                // subtyping between `FnDef` *except* when (the def_id) is
859                // passed as an argument to a function.
860                for (arg_a, arg_b) in iter::zip(a_args, b_args) {
861                    match (arg_a, arg_b) {
862                        (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
863                            let bty_a = ty_a.as_bty_skipping_existentials();
864                            let bty_b = ty_b.as_bty_skipping_existentials();
865                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
866                        }
867                        (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
868                            let bty_a = ctor_a.as_bty_skipping_binder();
869                            let bty_b = ctor_b.as_bty_skipping_binder();
870                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
871                        }
872                        (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
873                    }
874                }
875                Ok(())
876            }
877            (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
878                debug_assert_eq!(float_ty_a, float_ty_b);
879                Ok(())
880            }
881            (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),
882            (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
883                if ty_a.is_slice()
884                    && let TyKind::Indexed(_, idx_a) = ty_a.kind()
885                    && let TyKind::Exists(bty_b) = ty_b.kind()
886                {
887                    // For `&mut [T1][e] <: &mut ∃v[T2][v]`, we can hoist out the existential on the right because we know
888                    // the index is immutable. This means we have to prove `&mut [T1][e] <: ∃v. &mut [T2][v]`
889                    // This will in turn require proving `&mut [T1][e1] <: &mut [T2][?v]` for a fresh evar `?v`.
890                    // We know the evar will solve to `e`, so subtyping simplifies to the bellow.
891                    self.tys(infcx, ty_a, ty_b)?;
892                    self.tys(infcx, &bty_b.replace_bound_reft(idx_a), ty_a)
893                } else {
894                    self.tys(infcx, ty_a, ty_b)?;
895                    self.tys(infcx, ty_b, ty_a)
896                }
897            }
898            (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
899                self.tys(infcx, ty_a, ty_b)
900            }
901            (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
902                debug_assert_eq!(tys_a.len(), tys_b.len());
903                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
904                    self.tys(infcx, ty_a, ty_b)?;
905                }
906                Ok(())
907            }
908            (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
909                if let BaseTy::Alias(AliasKind::Opaque, alias_ty_a) = a {
910                    debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
911                    iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
912                        .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));
913                }
914                self.handle_opaque_type(infcx, a, alias_ty_b)
915            }
916            (
917                BaseTy::Alias(AliasKind::Projection, alias_ty_a),
918                BaseTy::Alias(AliasKind::Projection, alias_ty_b),
919            ) => {
920                tracked_span_dbg_assert_eq!(alias_ty_a, alias_ty_b);
921                Ok(())
922            }
923            (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
924                tracked_span_dbg_assert_eq!(len_a, len_b);
925                self.tys(infcx, ty_a, ty_b)
926            }
927            (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
928                debug_assert_eq!(param_a, param_b);
929                Ok(())
930            }
931            (BaseTy::Bool, BaseTy::Bool)
932            | (BaseTy::Str, BaseTy::Str)
933            | (BaseTy::Char, BaseTy::Char)
934            | (BaseTy::RawPtr(_, _), BaseTy::RawPtr(_, _))
935            | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
936            (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
937                tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
938                Ok(())
939            }
940            (BaseTy::Closure(did1, tys_a, _), BaseTy::Closure(did2, tys_b, _)) if did1 == did2 => {
941                debug_assert_eq!(tys_a.len(), tys_b.len());
942                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
943                    self.tys(infcx, ty_a, ty_b)?;
944                }
945                Ok(())
946            }
947            (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
948                tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
949                Ok(())
950            }
951            (BaseTy::Never, BaseTy::Never) => Ok(()),
952            _ => Err(query_bug!("incompatible base types: `{a:?}` - `{b:?}`"))?,
953        }
954    }
955
956    fn generic_args(
957        &mut self,
958        infcx: &mut InferCtxt,
959        variance: Variance,
960        a: &GenericArg,
961        b: &GenericArg,
962    ) -> InferResult {
963        let (ty_a, ty_b) = match (a, b) {
964            (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
965            (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
966                tracked_span_dbg_assert_eq!(ctor_a.sort(), ctor_b.sort());
967                (ctor_a.to_ty(), ctor_b.to_ty())
968            }
969            (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
970            (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
971                debug_assert_eq!(cst_a, cst_b);
972                return Ok(());
973            }
974            _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
975        };
976        match variance {
977            Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
978            Variance::Invariant => {
979                self.tys(infcx, &ty_a, &ty_b)?;
980                self.tys(infcx, &ty_b, &ty_a)
981            }
982            Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
983            Variance::Bivariant => Ok(()),
984        }
985    }
986
    /// Emits constraints requiring the index expressions `a` and `b` to be equal.
    ///
    /// Syntactically equal expressions produce no constraints. Otherwise the comparison is
    /// structural where possible:
    /// - struct constructors and tuples are compared field by field; when only one side is a
    ///   constructor/tuple, the other side is projected onto each field (`proj_and_reduce`);
    /// - lambdas are compared with `abs_eq`, eta-expanding the non-lambda side if needed;
    /// - if either side is a kvar, implications are emitted in both directions (`check_impl`);
    /// - anything else becomes an `a == b` predicate checked at `b`'s span.
    ///
    /// `infcx.unify_exprs(a, b)` is called first when the left side (`a`) is the structured one
    /// (tuple, struct ctor, or lambda), and in the fallback case. It is not called when only
    /// `b` is structured. NOTE(review): presumably this solves evars in `b` from `a`; confirm
    /// against `InferCtxt::unify_exprs`.
    fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
        // Syntactic equality: nothing to prove.
        if a == b {
            return;
        }
        match (a.kind(), b.kind()) {
            // Same struct on both sides (checked only in debug builds): equate fields pairwise.
            (
                ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
                ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
            ) => {
                debug_assert_eq!(did_a, did_b);
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Both tuples: equate fields pairwise.
            // NOTE(review): `iter::zip` stops at the shorter tuple; arities are assumed equal.
            (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Only `b` is a tuple: equate each field of `b` with the matching projection of `a`.
            (_, ExprKind::Tuple(flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Only `b` is a struct ctor: equate each field of `b` with the projection of `a`.
            (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Only `a` is a tuple: unify first, then equate each field of `a` with the
            // matching projection of `b`.
            (ExprKind::Tuple(flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            // Only `a` is a struct ctor: unify first, then equate fields with projections of `b`.
            (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            // Both lambdas: compare bodies under shared fresh variables.
            (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, lam_a, lam_b);
            }
            // Only `b` is a lambda: eta-expand `a` to the same binders and output sort.
            (_, ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
            }
            // Only `a` is a lambda: unify first, then eta-expand `b`.
            (ExprKind::Abs(lam_a), _) => {
                infcx.unify_exprs(a, b);
                self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
            }
            // A kvar on either side: encode equality as implications in both directions.
            (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
                infcx.check_impl(a, b, self.tag());
                infcx.check_impl(b, a, self.tag());
            }
            // Fallback: unify, then emit an `a == b` predicate located at `b`'s span (if any).
            _ => {
                infcx.unify_exprs(a, b);
                let span = b.span();
                infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
            }
        }
    }
1059
1060    fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
1061        debug_assert_eq!(a.vars().len(), b.vars().len());
1062        let vars = a
1063            .vars()
1064            .iter()
1065            .map(|kind| Expr::fvar(infcx.define_var(kind.expect_sort())))
1066            .collect_vec();
1067        let body_a = a.apply(&vars);
1068        let body_b = b.apply(&vars);
1069        self.idxs_eq(infcx, &body_a, &body_b);
1070    }
1071
1072    fn handle_opaque_type(
1073        &mut self,
1074        infcx: &mut InferCtxt,
1075        bty: &BaseTy,
1076        alias_ty: &AliasTy,
1077    ) -> InferResult {
1078        if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys) = bty {
1079            let obligs = mk_coroutine_obligations(
1080                infcx.genv,
1081                def_id,
1082                resume_ty,
1083                upvar_tys,
1084                &alias_ty.def_id,
1085            )?;
1086            self.obligations.extend(obligs);
1087        } else {
1088            let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
1089                infcx.tcx(),
1090                &alias_ty.args,
1091                &alias_ty.refine_args,
1092            );
1093            for clause in &bounds {
1094                if !clause.kind().vars().is_empty() {
1095                    Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
1096                }
1097                if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
1098                    let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
1099                    let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
1100                        .to_ty()
1101                        .deeply_normalize(&mut infcx.at(self.span))?;
1102                    let ty2 = pred.term.to_ty();
1103                    self.tys(infcx, &ty1, &ty2)?;
1104                }
1105            }
1106        }
1107        Ok(())
1108    }
1109}
1110
1111fn mk_coroutine_obligations(
1112    genv: GlobalEnv,
1113    generator_did: &DefId,
1114    resume_ty: &Ty,
1115    upvar_tys: &List<Ty>,
1116    opaque_def_id: &DefId,
1117) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1118    let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1119    for bound in &bounds {
1120        if let Some(proj_clause) = bound.as_projection_clause() {
1121            return Ok(vec![proj_clause.map(|proj_clause| {
1122                let output = proj_clause.term;
1123                CoroutineObligPredicate {
1124                    def_id: *generator_did,
1125                    resume_ty: resume_ty.clone(),
1126                    upvar_tys: upvar_tys.clone(),
1127                    output: output.to_ty(),
1128                }
1129            })]);
1130        }
1131    }
1132    bug!("no projection predicate")
1133}
1134
/// Errors that can arise during refinement inference.
#[derive(Debug)]
pub enum InferErr {
    /// An existential variable (`EVid`) was left unsolved.
    UnsolvedEvar(EVid),
    /// A wrapped `QueryErr`, e.g. from a failing `GlobalEnv` query (such as `item_bounds`)
    /// or from an internal `query_bug!`.
    Query(QueryErr),
}
1140
1141impl From<QueryErr> for InferErr {
1142    fn from(v: QueryErr) -> Self {
1143        Self::Query(v)
1144    }
1145}
1146
1147mod pretty {
1148    use std::fmt;
1149
1150    use flux_middle::pretty::*;
1151
1152    use super::*;
1153
1154    impl Pretty for Tag {
1155        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1156            w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)?;
1157            if let Some(dst_span) = self.dst_span {
1158                w!(cx, f, " ({:?})", ^dst_span)?;
1159            }
1160            Ok(())
1161        }
1162    }
1163
1164    impl_debug_with_default_cx!(Tag);
1165}