flux_infer/
infer.rs

1use std::{cell::RefCell, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts, OverflowMode, RawDerefMode};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7    FixpointQueryKind, PanicSpec,
8    def_id::MaybeExternId,
9    global_env::GlobalEnv,
10    metrics::{self, Metric},
11    queries::{QueryErr, QueryResult},
12    query_bug,
13    rty::{
14        self, AliasKind, AliasTy, BaseTy, Binder, BoundReftKind, BoundVariableKinds,
15        CoroutineObligPredicate, Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj,
16        GenericArg, HoleKind, InferMode, Lambda, List, Loc, Mutability, Name, NameProvenance, Path,
17        PolyVariant, PtrKind, RefineArgs, RefineArgsExt, Region, Sort, Ty, TyCtor, TyKind, Var,
18        canonicalize::{Hoister, HoisterDelegate},
19        fold::TypeFoldable,
20    },
21};
22use itertools::{Itertools, izip};
23use rustc_hir::def_id::{DefId, LocalDefId};
24use rustc_macros::extension;
25use rustc_middle::{
26    mir::BasicBlock,
27    ty::{TyCtxt, Variance},
28};
29use rustc_span::{Span, Symbol};
30use rustc_type_ir::Variance::Invariant;
31
32use crate::{
33    evars::{EVarState, EVarStore},
34    fixpoint_encoding::{
35        Answer, Backend, FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen, lean_task_key,
36    },
37    lean_encoding::log_proof,
38    projections::NormalizeExt as _,
39    refine_tree::{Cursor, Marker, RefineTree, Scope},
40};
41
42pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
43
/// Records where and why a constraint was generated, so that a failing constraint can be traced
/// back to the program (presumably for error reporting — confirm against consumers).
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct Tag {
    /// Why the constraint was generated.
    pub reason: ConstrReason,
    /// Span of the program point that generated the constraint.
    pub src_span: Span,
    /// Optional secondary span; `None` when built with [`Tag::new`], set with [`Tag::with_dst`].
    pub dst_span: Option<ESpan>,
}
50
51impl Tag {
52    pub fn new(reason: ConstrReason, span: Span) -> Self {
53        Self { reason, src_span: span, dst_span: None }
54    }
55
56    pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
57        Self { dst_span, ..self }
58    }
59}
60
/// Which part of a signature a subtyping constraint comes from; wrapped by
/// [`ConstrReason::Subtype`]. (Presumably used when checking one function signature against
/// another — confirm against callers.)
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum SubtypeReason {
    Input,
    Output,
    Requires,
    Ensures,
}
68
/// The reason a constraint was generated; stored in [`Tag::reason`].
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum ConstrReason {
    Call,
    Assign,
    Ret,
    /// Used, among other places, for the `requires` predicates checked by
    /// [`InferCtxtAt::check_constructor`].
    Fold,
    FoldLocal,
    Predicate,
    /// Carries a static message describing the assertion.
    Assert(&'static str),
    Div,
    Rem,
    /// Carries the basic block associated with the jump.
    Goto(BasicBlock),
    Overflow,
    Underflow,
    /// A subtyping constraint, refined by which part of the signature produced it.
    Subtype(SubtypeReason),
    NoPanic(DefId, PanicSpec),
    Other,
}
87
/// Root of a refinement inference session.
///
/// Owns the [`RefineTree`] that accumulates constraints and the kvar/evar state that every
/// [`InferCtxt`] created through [`InferCtxtRoot::infcx`] shares.
pub struct InferCtxtRoot<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Kvar generator and evar store, shared (behind a `RefCell`) with derived contexts.
    inner: RefCell<InferCtxtInner>,
    /// Constraint tree; derived contexts write into it through a cursor.
    refine_tree: RefineTree,
    /// Options read for overflow checking, raw derefs, solver choice and qualifier scraping.
    opts: InferOpts,
}
94
/// Builder for [`InferCtxtRoot`], obtained from [`GlobalEnvExt::infcx_root`].
pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    opts: InferOpts,
    /// Variables (with their sorts) placed at the root of the refinement tree by `build`.
    params: Vec<(Var, Sort)>,
    /// Rustc inference context used to normalize parameter sorts.
    infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
    /// Forwarded to `KVarGen::new` when the root is built.
    dummy_kvars: bool,
}
102
103#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
104impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
105    fn infcx_root<'a>(
106        self,
107        infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
108        opts: InferOpts,
109    ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
110        InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
111    }
112}
113
114impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
115    pub fn with_dummy_kvars(mut self) -> Self {
116        self.dummy_kvars = true;
117        self
118    }
119
120    pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
121        self.params.extend(
122            self.genv
123                .generics_of(def_id)?
124                .const_params(self.genv)?
125                .into_iter()
126                .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
127        );
128        Ok(self)
129    }
130
131    pub fn with_refinement_generics(
132        mut self,
133        def_id: DefId,
134        args: &[GenericArg],
135    ) -> QueryResult<Self> {
136        for (index, param) in self
137            .genv
138            .refinement_generics_of(def_id)?
139            .iter_own_params()
140            .enumerate()
141        {
142            let param = param.instantiate(self.genv.tcx(), args, &[]);
143            let sort = param
144                .sort
145                .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
146
147            let var =
148                Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
149            self.params.push((var, sort));
150        }
151        Ok(self)
152    }
153
154    pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
155        self = self.with_const_generics(def_id)?;
156        let offset = self.params.len();
157        self.genv.refinement_generics_of(def_id)?.fill_item(
158            self.genv,
159            &mut self.params,
160            &mut |param, index| {
161                let index = (index - offset) as u32;
162                let param = param.instantiate_identity();
163                let sort = param
164                    .sort
165                    .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
166
167                let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
168                Ok((var, sort))
169            },
170        )?;
171        Ok(self)
172    }
173
174    pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
175        Ok(InferCtxtRoot {
176            genv: self.genv,
177            inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
178            refine_tree: RefineTree::new(self.params),
179            opts: self.opts,
180        })
181    }
182}
183
184impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
185    pub fn infcx<'a>(
186        &'a mut self,
187        def_id: DefId,
188        region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
189    ) -> InferCtxt<'a, 'genv, 'tcx> {
190        InferCtxt {
191            genv: self.genv,
192            region_infcx,
193            def_id,
194            cursor: self.refine_tree.cursor_at_root(),
195            inner: &self.inner,
196            check_overflow: self.opts.check_overflow,
197            allow_raw_deref: self.opts.allow_raw_deref,
198        }
199    }
200
201    pub fn fresh_kvar_in_scope(
202        &self,
203        binders: &[BoundVariableKinds],
204        scope: &Scope,
205        encoding: KVarEncoding,
206    ) -> Expr {
207        let inner = &mut *self.inner.borrow_mut();
208        inner.kvars.fresh(binders, scope.iter(), encoding)
209    }
210
211    pub fn execute_lean_query(
212        self,
213        cache: &mut FixQueryCache,
214        def_id: MaybeExternId,
215    ) -> QueryResult {
216        let inner = self.inner.into_inner();
217        let kvars = inner.kvars;
218        let evars = inner.evars;
219        let mut refine_tree = self.refine_tree;
220        refine_tree.replace_evars(&evars).unwrap();
221        refine_tree.simplify(self.genv);
222
223        let solver = match self.opts.solver {
224            flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
225            flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
226        };
227        let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Lean);
228        let cstr = refine_tree.to_fixpoint(&mut fcx)?;
229        let task = fcx.create_task(def_id, cstr, self.opts.scrape_quals, solver)?;
230
231        log_proof(self.genv, def_id)?;
232        // Skip re-generation if task is already cached (same hash → same lean files on disk).
233        if config::is_cache_enabled() {
234            let key = lean_task_key(self.genv.tcx(), def_id.resolved_id());
235            let hash = task.hash_with_default();
236            if cache.lookup(&key, hash).is_some() {
237                return Ok(());
238            }
239        }
240
241        fcx.generate_lean_files(def_id, task)
242    }
243
244    pub fn execute_fixpoint_query(
245        self,
246        cache: &mut FixQueryCache,
247        def_id: MaybeExternId,
248        kind: FixpointQueryKind,
249    ) -> QueryResult<Answer<Tag>> {
250        let inner = self.inner.into_inner();
251        let kvars = inner.kvars;
252        let evars = inner.evars;
253
254        let ext = kind.ext();
255
256        let mut refine_tree = self.refine_tree;
257
258        refine_tree.replace_evars(&evars).unwrap();
259
260        if config::dump_constraint() {
261            dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
262        }
263        refine_tree.simplify(self.genv);
264        if config::dump_constraint() {
265            let simp_ext = format!("simp.{ext}");
266            dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
267                .unwrap();
268        }
269
270        let backend = match self.opts.solver {
271            flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
272            flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
273        };
274
275        let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Fixpoint);
276        let cstr = refine_tree.to_fixpoint(&mut fcx)?;
277
278        // skip checking trivial constraints
279        let count = cstr.concrete_head_count();
280        metrics::incr_metric(Metric::CsTotal, count as u32);
281        if count == 0 {
282            metrics::incr_metric_if(kind.is_body(), Metric::FnTrivial);
283            return Ok(Answer::trivial());
284        }
285
286        let task = fcx.create_task(def_id, cstr, self.opts.scrape_quals, backend)?;
287        let result = fcx.run_task(cache, def_id, kind, &task)?;
288        Ok(fcx.result_to_answer(result))
289    }
290
291    pub fn split(self) -> (RefineTree, KVarGen) {
292        (self.refine_tree, self.inner.into_inner().kvars)
293    }
294}
295
/// Context for generating refinement constraints at one position (a cursor) of the
/// [`RefineTree`] owned by an [`InferCtxtRoot`].
pub struct InferCtxt<'infcx, 'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Rustc inference context associated with this context.
    pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    /// The item this context was created for (see [`InferCtxtRoot::infcx`] and
    /// [`InferCtxt::change_item`]).
    pub def_id: DefId,
    /// Overflow mode copied from the root's options; passed along when assuming invariants.
    pub check_overflow: OverflowMode,
    /// Raw-deref mode copied from the root's options; see [`InferCtxt::allow_raw_deref`].
    pub allow_raw_deref: flux_config::RawDerefMode,
    /// Position in the refinement tree where new names and constraints are added.
    cursor: Cursor<'infcx>,
    /// Kvar/evar state shared with the root and with every other derived context.
    inner: &'infcx RefCell<InferCtxtInner>,
}
305
/// Mutable inference state shared, behind a `RefCell`, by an [`InferCtxtRoot`] and the
/// [`InferCtxt`]s derived from it.
struct InferCtxtInner {
    /// Generator for fresh kvars.
    kvars: KVarGen,
    /// Evars and their solutions, organized in scopes (see `InferCtxt::push_evar_scope`).
    evars: EVarStore,
}
310
311impl InferCtxtInner {
312    fn new(dummy_kvars: bool) -> Self {
313        Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
314    }
315}
316
317impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
318    pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
319        InferCtxtAt { infcx: self, span }
320    }
321
322    pub fn instantiate_refine_args(
323        &mut self,
324        callee_def_id: DefId,
325        args: &[rty::GenericArg],
326    ) -> InferResult<List<Expr>> {
327        Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
328            let param = param.instantiate(self.genv.tcx(), args, &[]);
329            Ok(self.fresh_infer_var(&param.sort, param.mode))
330        })?)
331    }
332
333    pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
334        args.iter()
335            .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
336            .collect_vec()
337    }
338
339    pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
340        match mode {
341            InferMode::KVar => {
342                let fsort = sort.expect_func().expect_mono();
343                let vars = fsort.inputs().iter().cloned().map_into().collect();
344                let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
345                Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
346            }
347            InferMode::EVar => self.fresh_evar(),
348        }
349    }
350
351    pub fn fresh_infer_var_for_hole(
352        &mut self,
353        binders: &[BoundVariableKinds],
354        kind: HoleKind,
355    ) -> Expr {
356        match kind {
357            HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
358            HoleKind::Expr(_) => {
359                // We only use expression holes to infer early param arguments for opaque types
360                // at function calls. These should be well-scoped in the current scope, so we ignore
361                // the extra `binders` around the hole.
362                self.fresh_evar()
363            }
364        }
365    }
366
367    /// Generate a fresh kvar in the _given_ [`Scope`] (similar method in [`InferCtxtRoot`]).
368    pub fn fresh_kvar_in_scope(
369        &self,
370        binders: &[BoundVariableKinds],
371        scope: &Scope,
372        encoding: KVarEncoding,
373    ) -> Expr {
374        let inner = &mut *self.inner.borrow_mut();
375        inner.kvars.fresh(binders, scope.iter(), encoding)
376    }
377
378    /// Generate a fresh kvar in the current scope. See [`KVarGen::fresh`].
379    pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
380        let inner = &mut *self.inner.borrow_mut();
381        inner.kvars.fresh(binders, self.cursor.vars(), encoding)
382    }
383
384    fn fresh_evar(&self) -> Expr {
385        let evars = &mut self.inner.borrow_mut().evars;
386        Expr::evar(evars.fresh(self.cursor.marker()))
387    }
388
389    pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
390        if a.has_evars() {
391            return;
392        }
393        let evars = &mut self.inner.borrow_mut().evars;
394        if let ExprKind::Var(Var::EVar(evid)) = b.kind()
395            && let EVarState::Unsolved(marker) = evars.get(*evid)
396            && !marker.has_free_vars(a)
397        {
398            evars.solve(*evid, a.clone());
399        }
400    }
401
402    fn enter_exists<T, U>(
403        &mut self,
404        t: &Binder<T>,
405        f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
406    ) -> U
407    where
408        T: TypeFoldable,
409    {
410        self.ensure_resolved_evars(|infcx| {
411            let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
412            Ok(f(infcx, t))
413        })
414        .unwrap()
415    }
416
417    /// Used in conjunction with [`InferCtxt::pop_evar_scope`] to ensure evars are solved at the end
418    /// of some scope, for example, to ensure all evars generated during a function call are solved
419    /// after checking argument subtyping. These functions can be used in a stack-like fashion to
420    /// create nested scopes.
421    pub fn push_evar_scope(&mut self) {
422        self.inner.borrow_mut().evars.push_scope();
423    }
424
425    /// Pop a scope and check all evars have been solved. This only check evars generated from the
426    /// last call to [`InferCtxt::push_evar_scope`].
427    pub fn pop_evar_scope(&mut self) -> InferResult {
428        self.inner
429            .borrow_mut()
430            .evars
431            .pop_scope()
432            .map_err(InferErr::UnsolvedEvar)
433    }
434
435    /// Convenience method pairing [`InferCtxt::push_evar_scope`] and [`InferCtxt::pop_evar_scope`].
436    pub fn ensure_resolved_evars<R>(
437        &mut self,
438        f: impl FnOnce(&mut Self) -> InferResult<R>,
439    ) -> InferResult<R> {
440        self.push_evar_scope();
441        let r = f(self)?;
442        self.pop_evar_scope()?;
443        Ok(r)
444    }
445
446    pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
447        self.inner.borrow().evars.replace_evars(t).unwrap()
448    }
449
450    pub fn tcx(&self) -> TyCtxt<'tcx> {
451        self.genv.tcx()
452    }
453
454    pub fn cursor(&self) -> &Cursor<'infcx> {
455        &self.cursor
456    }
457
458    pub fn allow_raw_deref(&self) -> bool {
459        matches!(self.allow_raw_deref, RawDerefMode::Ok)
460    }
461}
462
463/// Methods that interact with the underlying [`Cursor`]
464impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
465    pub fn change_item<'a>(
466        &'a mut self,
467        def_id: LocalDefId,
468        region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
469    ) -> InferCtxt<'a, 'genv, 'tcx> {
470        InferCtxt {
471            def_id: def_id.to_def_id(),
472            cursor: self.cursor.branch(),
473            region_infcx,
474            ..*self
475        }
476    }
477
478    pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
479        InferCtxt {
480            cursor: self
481                .cursor
482                .move_to(marker, clear_children)
483                .unwrap_or_else(|| tracked_span_bug!()),
484            ..*self
485        }
486    }
487
488    pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
489        InferCtxt { cursor: self.cursor.branch(), ..*self }
490    }
491
492    fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
493        self.cursor.define_var(sort, provenance)
494    }
495
496    pub fn define_bound_reft_var(&mut self, sort: &Sort, kind: BoundReftKind) -> Name {
497        self.define_var(sort, NameProvenance::UnfoldBoundReft(kind))
498    }
499
500    pub fn define_unknown_var(&mut self, sort: &Sort) -> Name {
501        self.cursor.define_var(sort, NameProvenance::Unknown)
502    }
503
504    pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
505        self.cursor.check_pred(pred, tag);
506    }
507
508    pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
509        self.cursor.assume_pred(pred);
510    }
511
512    pub fn unpack(&mut self, ty: &Ty) -> Ty {
513        self.hoister(false).hoist(ty)
514    }
515
516    pub fn unpack_at_name(&mut self, name: Option<Symbol>, ty: &Ty) -> Ty {
517        let mut hoister = self.hoister(false);
518        hoister.delegate.name = name;
519        hoister.hoist(ty)
520    }
521
522    pub fn marker(&self) -> Marker {
523        self.cursor.marker()
524    }
525
526    pub fn hoister(
527        &mut self,
528        assume_invariants: bool,
529    ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
530        Hoister::with_delegate(Unpacker { infcx: self, assume_invariants, name: None })
531            .transparent()
532    }
533
534    pub fn assume_invariants(&mut self, ty: &Ty) {
535        self.cursor
536            .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
537    }
538
539    fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
540        self.cursor.check_impl(pred1, pred2, tag);
541    }
542}
543
/// [`HoisterDelegate`] used by [`InferCtxt::hoister`]. Hoisted existentials are replaced by
/// freshly defined names, and hoisted constraints are assumed.
pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    /// When true, the invariants of each unpacked type are also assumed.
    assume_invariants: bool,
    /// When `Some`, every unpacked bound refinement is named `BoundReftKind::Named(name)`.
    name: Option<Symbol>,
}
549
550impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
551    fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
552        let ty = ty_ctor.replace_bound_refts_with(|sort, _, kind| {
553            let kind = if let Some(name) = self.name { BoundReftKind::Named(name) } else { kind };
554            Expr::fvar(self.infcx.define_bound_reft_var(sort, kind))
555        });
556        if self.assume_invariants {
557            self.infcx.assume_invariants(&ty);
558        }
559        ty
560    }
561
562    fn hoist_constr(&mut self, pred: Expr) {
563        self.infcx.assume_pred(pred);
564    }
565}
566
567impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
568    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
569        std::fmt::Debug::fmt(&self.cursor, f)
570    }
571}
572
/// An [`InferCtxt`] paired with a span; constraints generated through it are tagged with that
/// span. Dereferences to the wrapped [`InferCtxt`].
#[derive(Debug)]
pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
    pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    /// Source span attached to the tags of generated constraints.
    pub span: Span,
}
578
579impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
580    fn tag(&self, reason: ConstrReason) -> Tag {
581        Tag::new(reason, self.span)
582    }
583
584    pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
585        let tag = self.tag(reason);
586        self.infcx.check_pred(pred, tag);
587    }
588
589    pub fn check_non_closure_clauses(
590        &mut self,
591        clauses: &[rty::Clause],
592        reason: ConstrReason,
593    ) -> InferResult {
594        for clause in clauses {
595            if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
596                let impl_elem = BaseTy::projection(projection_pred.projection_ty)
597                    .to_ty()
598                    .deeply_normalize(self)?;
599                let term = projection_pred.term.to_ty().deeply_normalize(self)?;
600
601                // TODO: does this really need to be invariant? https://github.com/flux-rs/flux/pull/478#issuecomment-1654035374
602                self.subtyping(&impl_elem, &term, reason)?;
603                self.subtyping(&term, &impl_elem, reason)?;
604            }
605        }
606        Ok(())
607    }
608
609    /// Relate types via subtyping. This is the same as [`InferCtxtAt::subtyping`] except that we
610    /// also require a [`LocEnv`] to handle pointers and strong references
611    pub fn subtyping_with_env(
612        &mut self,
613        env: &mut impl LocEnv,
614        a: &Ty,
615        b: &Ty,
616        reason: ConstrReason,
617    ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
618        let mut sub = Sub::new(env, reason, self.span);
619        sub.tys(self.infcx, a, b)?;
620        Ok(sub.obligations)
621    }
622
623    /// Relate types via subtyping and returns coroutine obligations. This doesn't handle subtyping
624    /// when strong references are involved.
625    ///
626    /// See comment for [`Sub::obligations`].
627    pub fn subtyping(
628        &mut self,
629        a: &Ty,
630        b: &Ty,
631        reason: ConstrReason,
632    ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
633        let mut env = DummyEnv;
634        let mut sub = Sub::new(&mut env, reason, self.span);
635        sub.tys(self.infcx, a, b)?;
636        Ok(sub.obligations)
637    }
638
639    pub fn subtyping_generic_args(
640        &mut self,
641        variance: Variance,
642        a: &GenericArg,
643        b: &GenericArg,
644        reason: ConstrReason,
645    ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
646        let mut env = DummyEnv;
647        let mut sub = Sub::new(&mut env, reason, self.span);
648        sub.generic_args(self.infcx, variance, a, b)?;
649        Ok(sub.obligations)
650    }
651
652    // FIXME(nilehmann) this is similar to `Checker::check_call`, but since is used from
653    // `place_ty::fold` we cannot use that directly. We should try to unify them, because
654    // there are a couple of things missing here (e.g., checking clauses on the struct definition).
655    pub fn check_constructor(
656        &mut self,
657        variant: EarlyBinder<PolyVariant>,
658        generic_args: &[GenericArg],
659        fields: &[Ty],
660        reason: ConstrReason,
661    ) -> InferResult<Ty> {
662        let ret = self.ensure_resolved_evars(|this| {
663            // Replace holes in generic arguments with fresh inference variables
664            let generic_args = this.instantiate_generic_args(generic_args);
665
666            let variant = variant
667                .instantiate(this.tcx(), &generic_args, &[])
668                .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
669
670            // Check arguments
671            for (actual, formal) in iter::zip(fields, variant.fields()) {
672                this.subtyping(actual, formal, reason)?;
673            }
674
675            // Check requires predicates
676            for require in &variant.requires {
677                this.check_pred(require, ConstrReason::Fold);
678            }
679
680            Ok(variant.ret())
681        })?;
682        Ok(self.fully_resolve_evars(&ret))
683    }
684
685    pub fn ensure_resolved_evars<R>(
686        &mut self,
687        f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
688    ) -> InferResult<R> {
689        self.infcx
690            .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
691    }
692}
693
694impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
695    type Target = InferCtxt<'a, 'genv, 'tcx>;
696
697    fn deref(&self) -> &Self::Target {
698        self.infcx
699    }
700}
701
702impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
703    fn deref_mut(&mut self) -> &mut Self::Target {
704        self.infcx
705    }
706}
707
/// Used for debugging to attach a "trace" to the [`RefineTree`] that can be used to print information
/// to recover the derivation when relating types via subtyping. The code that attaches the trace is
/// currently commented out because the output is too verbose.
#[derive(TypeVisitable, TypeFoldable)]
pub(crate) enum TypeTrace {
    /// Two types being related (`a`, `b`).
    Types(Ty, Ty),
    /// Two base types being related (`a`, `b`).
    BaseTys(BaseTy, BaseTy),
}
716
717#[expect(dead_code, reason = "we use this for debugging some time")]
718impl TypeTrace {
719    fn tys(a: &Ty, b: &Ty) -> Self {
720        Self::Types(a.clone(), b.clone())
721    }
722
723    fn btys(a: &BaseTy, b: &BaseTy) -> Self {
724        Self::BaseTys(a.clone(), b.clone())
725    }
726}
727
728impl fmt::Debug for TypeTrace {
729    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
730        match self {
731            TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
732            TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
733        }
734    }
735}
736
/// Environment of locations consulted by subtyping when pointers or strong references are
/// involved (see [`InferCtxtAt::subtyping_with_env`]).
pub trait LocEnv {
    /// Called when relating a `ptr(mut)` on the left with a `&mut` reference bounded by `bound`
    /// on the right.
    fn ptr_to_ref(
        &mut self,
        infcx: &mut InferCtxtAt,
        reason: ConstrReason,
        re: Region,
        path: &Path,
        bound: Ty,
    ) -> InferResult<Ty>;

    /// Called when relating two strong references; unfolds the reference at `path` with type
    /// `ty` so that the location can afterwards be looked up with [`LocEnv::get`].
    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;

    /// Returns the type currently stored at `path`.
    fn get(&self, path: &Path) -> Ty;
}
751
752struct DummyEnv;
753
754impl LocEnv for DummyEnv {
755    fn ptr_to_ref(
756        &mut self,
757        _: &mut InferCtxtAt,
758        _: ConstrReason,
759        _: Region,
760        _: &Path,
761        _: Ty,
762    ) -> InferResult<Ty> {
763        tracked_span_bug!("call to `ptr_to_ref` on `DummyEnv`")
764    }
765
766    fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
767        tracked_span_bug!("call to `unfold_str_ref` on `DummyEnv`")
768    }
769
770    fn get(&self, _: &Path) -> Ty {
771        tracked_span_bug!("call to `get` on `DummyEnv`")
772    }
773}
774
/// Context used to relate two types `a` and `b` via subtyping
struct Sub<'a, E> {
    /// The environment to lookup locations pointed to by [`TyKind::Ptr`].
    env: &'a mut E,
    /// Reason attached to the tags of generated constraints.
    reason: ConstrReason,
    /// Span attached to the tags of generated constraints.
    span: Span,
    /// FIXME(nilehmann) This is used to store coroutine obligations generated during subtyping when
    /// relating an opaque type. Other obligations related to relating opaque types are resolved
    /// directly here. The implementation is really messy and we may be missing some obligations.
    obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
}
786
787impl<'a, E: LocEnv> Sub<'a, E> {
788    fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
789        Self { env, reason, span, obligations: vec![] }
790    }
791
792    fn tag(&self) -> Tag {
793        Tag::new(self.reason, self.span)
794    }
795
    /// Relates `a <: b`, recording the generated constraints in a new branch of `infcx`'s cursor.
    ///
    /// - The left-hand side is fully unpacked first, so its existentials become fresh names and
    ///   its constraints become assumptions.
    /// - Existentials on the right are opened with fresh inference variables, which must be
    ///   solved before leaving that scope.
    /// - Constraints on the right become checked predicates.
    /// - Pointer and strong-reference cases consult `self.env`.
    /// - Mismatched type shapes produce a `query_bug!` error.
    ///
    /// The order of the match arms matters: the `Exists`/`Constr` cases on the right must be
    /// handled before the structural cases.
    fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
        let infcx = &mut infcx.branch();
        // infcx.cursor.push_trace(TypeTrace::tys(a, b));

        // We *fully* unpack the lhs before continuing to be able to prove goals like this
        // ∃a. (i32[a], ∃b. {i32[b] | a > b}) <: ∃a,b. ({i32[a] | b < a}, i32[b])
        // See S4.5 in https://arxiv.org/pdf/2209.13000v1.pdf
        let a = infcx.unpack(a);

        match (a.kind(), b.kind()) {
            (TyKind::Exists(..), _) => {
                bug!("existentials should have been removed by the unpacking above");
            }
            (TyKind::Constr(..), _) => {
                bug!("constraint types should have been removed by the unpacking above");
            }

            (_, TyKind::Exists(ctor_b)) => {
                infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
            }
            (_, TyKind::Constr(pred_b, ty_b)) => {
                infcx.check_pred(pred_b, self.tag());
                self.tys(infcx, &a, ty_b)
            }

            (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                // We should technically remove `path_a` from `env`, but we are assuming that functions
                // always give back ownership of the location so `path_a` is going to be overwritten
                // after the call anyways.
                let ty_a = self.env.get(path_a);
                // Try to solve `path_b` (if it is an unsolved evar) with `path_a`.
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                // We have to unfold strong references prior to a subtyping check. Normally, when
                // checking a function body, a `StrgRef` is automatically unfolded i.e. `x:&strg T`
                // is turned into a `x:ptr(l); l: T` where `l` is some fresh location. However, we
                // need the below to do a similar unfolding during function subtyping where we just
                // have the super-type signature that needs to be unfolded. We also add the binding
                // to the environment so that we can:
                // (1) UPDATE the location after the call, and
                // (2) CHECK the relevant `ensures` clauses of the super-sig.
                // Same as the `Ptr` case above we should remove the location from the environment
                // after unfolding to consume it, but we are assuming functions always give back
                // ownership.
                self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            (
                TyKind::Ptr(PtrKind::Mut(re), path),
                TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
            ) => {
                // We sometimes generate evars for the index of references so we need to make sure
                // we solve them.
                self.idxs_eq(infcx, &Expr::unit(), idx);

                self.env.ptr_to_ref(
                    &mut infcx.at(self.span),
                    self.reason,
                    *re,
                    path,
                    bound.clone(),
                )?;
                Ok(())
            }

            (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
                self.btys(infcx, bty_a, bty_b)?;
                self.idxs_eq(infcx, idx_a, idx_b);
                Ok(())
            }
            (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
                debug_assert_eq!(pk_a, pk_b);
                debug_assert_eq!(path_a, path_b);
                Ok(())
            }
            (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
                debug_assert_eq!(param_ty_a, param_ty_b);
                Ok(())
            }
            // Anything is a subtype of an uninitialized location.
            (_, TyKind::Uninit) => Ok(()),
            (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
                debug_assert_eq!(fields_a.len(), fields_b.len());
                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
        }
    }
889
    /// Checks that base type `a` is a subtype of base type `b`, recursing into component
    /// types via `self.tys` and into generic arguments via `self.generic_args`.
    ///
    /// Pairs whose Rust-level structure must already agree (integer kinds, parameters, ...) are
    /// only checked with (debug) assertions. Any pair of base types not covered by an arm is
    /// reported as a bug through `query_bug!`.
    ///
    /// NOTE: arm order matters. In particular the `(_, Alias(Opaque, _))` arm must come after
    /// the `(Alias(Opaque, _), Alias(Opaque, _))` arm.
    fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
        // infcx.push_trace(TypeTrace::btys(a, b));

        match (a, b) {
            // Scalars: no refinement-level constraint, only a debug sanity check.
            (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
                debug_assert_eq!(int_ty_a, int_ty_b);
                Ok(())
            }
            (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
                debug_assert_eq!(uint_ty_a, uint_ty_b);
                Ok(())
            }
            // ADTs: relate each generic argument according to the ADT's declared variances.
            (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
                tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
                tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
                let variances = infcx.genv.variances_of(a_adt.did());
                for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
                    self.generic_args(infcx, *variance, ty_a, ty_b)?;
                }
                Ok(())
            }
            (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
                debug_assert_eq!(a_def_id, b_def_id);
                debug_assert_eq!(a_args.len(), b_args.len());
                // NOTE: we don't check subtyping here because the RHS is *really*
                // the function type, the LHS is just generated by rustc.
                // we could generate a subtyping constraint but those would
                // just be trivial (but might cause useless cycles in fixpoint).
                // Nico: (This is probably ok because) We never do function
                // subtyping between `FnDef` *except* when (the def_id) is
                // passed as an argument to a function.
                for (arg_a, arg_b) in iter::zip(a_args, b_args) {
                    match (arg_a, arg_b) {
                        (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
                            let bty_a = ty_a.as_bty_skipping_existentials();
                            let bty_b = ty_b.as_bty_skipping_existentials();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                            let bty_a = ctor_a.as_bty_skipping_binder();
                            let bty_b = ctor_b.as_bty_skipping_binder();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
                    }
                }
                Ok(())
            }
            (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
                debug_assert_eq!(float_ty_a, float_ty_b);
                Ok(())
            }
            // Slices are covariant in their element type.
            (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),

            // Raw pointers: `*const T` is covariant, `*mut T` is invariant (both directions).
            (BaseTy::RawPtr(ty_a, mut_a), BaseTy::RawPtr(ty_b, mut_b)) => {
                debug_assert_eq!(mut_a, mut_b);
                self.tys(infcx, ty_a, ty_b)?;
                if matches!(mut_a, Mutability::Mut) {
                    self.tys(infcx, ty_b, ty_a)?;
                }
                Ok(())
            }

            // Mutable references are invariant, except for the slice special case below.
            (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
                if ty_a.is_slice()
                    && let TyKind::Indexed(_, idx_a) = ty_a.kind()
                    && let TyKind::Exists(bty_b) = ty_b.kind()
                {
                    // For `&mut [T1][e] <: &mut ∃v[T2][v]`, we can hoist out the existential on the right because we know
                    // the index is immutable. This means we have to prove `&mut [T1][e] <: ∃v. &mut [T2][v]`
                    // This will in turn require proving `&mut [T1][e] <: &mut [T2][?v]` for a fresh evar `?v`.
                    // We know the evar will solve to `e`, so subtyping simplifies to the below.
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, &bty_b.replace_bound_reft(idx_a), ty_a)
                } else {
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, ty_b, ty_a)
                }
            }
            // Shared references are covariant.
            (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
                self.tys(infcx, ty_a, ty_b)
            }
            // Tuples are covariant component-wise.
            (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Two occurrences of the same opaque type: type arguments are related invariantly
            // and refinement arguments are unified pairwise.
            (
                BaseTy::Alias(AliasKind::Opaque, alias_ty_a),
                BaseTy::Alias(AliasKind::Opaque, alias_ty_b),
            ) => {
                debug_assert_eq!(alias_ty_a.def_id, alias_ty_b.def_id);

                // handle type-args
                for (ty_a, ty_b) in izip!(alias_ty_a.args.iter(), alias_ty_b.args.iter()) {
                    self.generic_args(infcx, Invariant, ty_a, ty_b)?;
                }

                // handle refine-args
                debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
                iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
                    .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));

                Ok(())
            }
            (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
                // only for when concrete type on LHS and impl-with-bounds on RHS
                self.handle_opaque_type(infcx, a, alias_ty_b)
            }
            // Projections: only a debug check that both sides are the same modulo regions.
            (
                BaseTy::Alias(AliasKind::Projection, alias_ty_a),
                BaseTy::Alias(AliasKind::Projection, alias_ty_b),
            ) => {
                tracked_span_dbg_assert_eq!(alias_ty_a.erase_regions(), alias_ty_b.erase_regions());
                Ok(())
            }
            (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
                tracked_span_dbg_assert_eq!(len_a, len_b);
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
                debug_assert_eq!(param_a, param_b);
                Ok(())
            }
            (BaseTy::Bool, BaseTy::Bool)
            | (BaseTy::Str, BaseTy::Str)
            | (BaseTy::Char, BaseTy::Char)
            | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
            // Trait objects and fn pointers are compared modulo regions with the non-`dbg`
            // assertion macro (unlike most arms above); no refinement constraint is generated.
            (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
                tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
                Ok(())
            }
            // Same closure: relate its component types covariantly.
            (BaseTy::Closure(did1, tys_a, _, _), BaseTy::Closure(did2, tys_b, _, _))
                if did1 == did2 =>
            {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
                tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
                Ok(())
            }
            (BaseTy::Never, BaseTy::Never) => Ok(()),
            // Same coroutine: component types covariantly, resume type in both directions.
            (
                BaseTy::Coroutine(did1, resume_ty_a, tys_a, _),
                BaseTy::Coroutine(did2, resume_ty_b, tys_b, _),
            ) if did1 == did2 => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                // TODO(RJ): Treating resume type as invariant...but I think they should be contravariant(?)
                self.tys(infcx, resume_ty_b, resume_ty_a)?;
                self.tys(infcx, resume_ty_a, resume_ty_b)?;

                Ok(())
            }
            (BaseTy::Foreign(did_a), BaseTy::Foreign(did_b)) if did_a == did_b => Ok(()),
            // Anything else (including closures/coroutines/foreign types with different
            // def-ids) indicates an internal inconsistency.
            _ => Err(query_bug!("incompatible base types: `{a:#?}` - `{b:#?}`"))?,
        }
    }
1056
1057    fn generic_args(
1058        &mut self,
1059        infcx: &mut InferCtxt,
1060        variance: Variance,
1061        a: &GenericArg,
1062        b: &GenericArg,
1063    ) -> InferResult {
1064        let (ty_a, ty_b) = match (a, b) {
1065            (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
1066            (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
1067                tracked_span_dbg_assert_eq!(
1068                    ctor_a.sort().erase_regions(),
1069                    ctor_b.sort().erase_regions()
1070                );
1071                (ctor_a.to_ty(), ctor_b.to_ty())
1072            }
1073            (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
1074            (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
1075                debug_assert_eq!(cst_a, cst_b);
1076                return Ok(());
1077            }
1078            _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
1079        };
1080        match variance {
1081            Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
1082            Variance::Invariant => {
1083                self.tys(infcx, &ty_a, &ty_b)?;
1084                self.tys(infcx, &ty_b, &ty_a)
1085            }
1086            Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
1087            Variance::Bivariant => Ok(()),
1088        }
1089    }
1090
    /// Emits constraints requiring the index expressions `a` and `b` to be equal.
    ///
    /// Syntactically equal expressions need no constraint at all. Aggregates (struct and
    /// raw-pointer constructors, tuples) are compared field by field. When only one side is an
    /// aggregate, the other side is projected on each field with `proj_and_reduce`. Lambdas are
    /// compared with `abs_eq`, eta-expanding the non-lambda side when needed. If either side is a
    /// kvar, implications are checked in both directions. Otherwise the two expressions are
    /// unified and an equality predicate `a == b` is checked.
    ///
    /// NOTE: arm order matters. The arms where both sides are aggregates/lambdas must come
    /// before the one-sided arms.
    fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
        if a == b {
            return;
        }
        match (a.kind(), b.kind()) {
            // Both sides are aggregates of the same shape: equate field-wise.
            (
                ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
                ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
            ) => {
                debug_assert_eq!(did_a, did_b);
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            (ExprKind::Ctor(Ctor::RawPtr, flds_a), ExprKind::Ctor(Ctor::RawPtr, flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Only the right side is an aggregate: project the left side on each field.
            (_, ExprKind::Tuple(flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }
            (_, ExprKind::Ctor(Ctor::RawPtr, flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let field = rty::RawPtrField::from_index(f as u32).unwrap();
                    let a = a.proj_and_reduce(FieldProj::RawPtr { field });
                    self.idxs_eq(infcx, &a, b);
                }
            }

            (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Only the left side is an aggregate: unify first, then project the right side on
            // each field. Unlike the mirrored arms above, these call `unify_exprs(a, b)`;
            // presumably so that an evar on the right gets solved to `a` before `b` is
            // projected. TODO confirm.
            (ExprKind::Tuple(flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Ctor(Ctor::RawPtr, flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let field = rty::RawPtrField::from_index(f as u32).unwrap();
                    let b = b.proj_and_reduce(FieldProj::RawPtr { field });
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            // Lambdas: compare bodies. A non-lambda side is eta-expanded using the other
            // lambda's parameters and output sort.
            (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, lam_a, lam_b);
            }
            (_, ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
            }
            (ExprKind::Abs(lam_a), _) => {
                infcx.unify_exprs(a, b);
                self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
            }
            // Equality involving a kvar is encoded as implication in both directions.
            (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
                infcx.check_impl(a, b, self.tag());
                infcx.check_impl(b, a, self.tag());
            }
            // General case: unify, then check `a == b`, reported at `b`'s span if it has one.
            _ => {
                infcx.unify_exprs(a, b);
                let span = b.span();
                infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
            }
        }
    }
1183
1184    fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
1185        debug_assert_eq!(a.vars().len(), b.vars().len());
1186        let vars = a
1187            .vars()
1188            .iter()
1189            .map(|kind| {
1190                let (sort, _, kind) = kind.expect_refine();
1191                Expr::fvar(infcx.define_bound_reft_var(sort, kind))
1192            })
1193            .collect_vec();
1194        let body_a = a.apply(&vars);
1195        let body_b = b.apply(&vars);
1196        self.idxs_eq(infcx, &body_a, &body_b);
1197    }
1198
1199    fn handle_opaque_type(
1200        &mut self,
1201        infcx: &mut InferCtxt,
1202        bty: &BaseTy,
1203        alias_ty: &AliasTy,
1204    ) -> InferResult {
1205        if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys, args) = bty {
1206            let obligs = mk_coroutine_obligations(
1207                infcx.genv,
1208                def_id,
1209                resume_ty,
1210                upvar_tys,
1211                &alias_ty.def_id,
1212                args.clone(),
1213            )?;
1214            self.obligations.extend(obligs);
1215        } else {
1216            let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
1217                infcx.tcx(),
1218                &alias_ty.args,
1219                &alias_ty.refine_args,
1220            );
1221            for clause in &bounds {
1222                if !clause.kind().vars().is_empty() {
1223                    Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
1224                }
1225                if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
1226                    let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
1227                    let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
1228                        .to_ty()
1229                        .deeply_normalize(&mut infcx.at(self.span))?;
1230                    let ty2 = pred.term.to_ty();
1231                    self.tys(infcx, &ty1, &ty2)?;
1232                }
1233            }
1234        }
1235        Ok(())
1236    }
1237}
1238
1239fn mk_coroutine_obligations(
1240    genv: GlobalEnv,
1241    generator_did: &DefId,
1242    resume_ty: &Ty,
1243    upvar_tys: &List<Ty>,
1244    opaque_def_id: &DefId,
1245    args: flux_rustc_bridge::ty::GenericArgs,
1246) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1247    let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1248    for bound in &bounds {
1249        if let Some(proj_clause) = bound.as_projection_clause() {
1250            return Ok(vec![proj_clause.map(|proj_clause| {
1251                let output = proj_clause.term;
1252                CoroutineObligPredicate {
1253                    def_id: *generator_did,
1254                    resume_ty: resume_ty.clone(),
1255                    upvar_tys: upvar_tys.clone(),
1256                    output: output.to_ty(),
1257                    args,
1258                }
1259            })]);
1260        }
1261    }
1262    bug!("no projection predicate")
1263}
1264
1265#[derive(Debug)]
1266pub enum InferErr {
1267    UnsolvedEvar(EVid),
1268    Query(QueryErr),
1269}
1270
1271impl From<QueryErr> for InferErr {
1272    fn from(v: QueryErr) -> Self {
1273        Self::Query(v)
1274    }
1275}
1276
1277mod pretty {
1278    use std::fmt;
1279
1280    use flux_middle::pretty::*;
1281
1282    use super::*;
1283
1284    impl Pretty for Tag {
1285        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1286            w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)?;
1287            if let Some(dst_span) = self.dst_span {
1288                w!(cx, f, " ({:?})", ^dst_span)?;
1289            }
1290            Ok(())
1291        }
1292    }
1293
1294    impl_debug_with_default_cx!(Tag);
1295}