1use std::{cell::RefCell, collections::HashMap, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts, OverflowMode};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7 FixpointQueryKind,
8 def_id::MaybeExternId,
9 global_env::GlobalEnv,
10 queries::{QueryErr, QueryResult},
11 query_bug,
12 rty::{
13 self, AliasKind, AliasTy, BaseTy, Binder, BoundReftKind, BoundVariableKinds,
14 CoroutineObligPredicate, Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj,
15 GenericArg, HoleKind, InferMode, Lambda, List, Loc, Mutability, Name, NameProvenance, Path,
16 PolyVariant, PtrKind, RefineArgs, RefineArgsExt, Region, Sort, Ty, TyCtor, TyKind, Var,
17 canonicalize::{Hoister, HoisterDelegate},
18 fold::TypeFoldable,
19 },
20};
21use itertools::{Itertools, izip};
22use rustc_hir::def_id::{DefId, LocalDefId};
23use rustc_macros::extension;
24use rustc_middle::{
25 mir::BasicBlock,
26 ty::{TyCtxt, Variance},
27};
28use rustc_span::{Span, Symbol};
29use rustc_type_ir::Variance::Invariant;
30
31use crate::{
32 evars::{EVarState, EVarStore},
33 fixpoint_encoding::{Answer, Backend, FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen},
34 projections::NormalizeExt as _,
35 refine_tree::{Cursor, Marker, RefineTree, Scope},
36};
37
38pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
39
40#[derive(PartialEq, Eq, Clone, Copy, Hash)]
41pub struct Tag {
42 pub reason: ConstrReason,
43 pub src_span: Span,
44 pub dst_span: Option<ESpan>,
45}
46
47impl Tag {
48 pub fn new(reason: ConstrReason, span: Span) -> Self {
49 Self { reason, src_span: span, dst_span: None }
50 }
51
52 pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
53 Self { dst_span, ..self }
54 }
55}
56
/// Which component a subtyping constraint stems from; carried by
/// `ConstrReason::Subtype`.
///
/// The variant names suggest the parts of a signature (inputs, output, requires and
/// ensures clauses) — NOTE(review): confirm against the callers that construct it.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub enum SubtypeReason {
    Input,
    Output,
    Requires,
    Ensures,
}
64
65#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
66pub enum ConstrReason {
67 Call,
68 Assign,
69 Ret,
70 Fold,
71 FoldLocal,
72 Predicate,
73 Assert(&'static str),
74 Div,
75 Rem,
76 Goto(BasicBlock),
77 Overflow,
78 Underflow,
79 Subtype(SubtypeReason),
80 Other,
81}
82
83pub struct InferCtxtRoot<'genv, 'tcx> {
84 pub genv: GlobalEnv<'genv, 'tcx>,
85 inner: RefCell<InferCtxtInner>,
86 refine_tree: RefineTree,
87 opts: InferOpts,
88}
89
90pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
91 genv: GlobalEnv<'genv, 'tcx>,
92 opts: InferOpts,
93 params: Vec<(Var, Sort)>,
94 infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
95 dummy_kvars: bool,
96}
97
98#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
99impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
100 fn infcx_root<'a>(
101 self,
102 infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
103 opts: InferOpts,
104 ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
105 InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
106 }
107}
108
109impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
110 pub fn with_dummy_kvars(mut self) -> Self {
111 self.dummy_kvars = true;
112 self
113 }
114
115 pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
116 self.params.extend(
117 self.genv
118 .generics_of(def_id)?
119 .const_params(self.genv)?
120 .into_iter()
121 .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
122 );
123 Ok(self)
124 }
125
126 pub fn with_refinement_generics(
127 mut self,
128 def_id: DefId,
129 args: &[GenericArg],
130 ) -> QueryResult<Self> {
131 for (index, param) in self
132 .genv
133 .refinement_generics_of(def_id)?
134 .iter_own_params()
135 .enumerate()
136 {
137 let param = param.instantiate(self.genv.tcx(), args, &[]);
138 let sort = param
139 .sort
140 .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
141
142 let var =
143 Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
144 self.params.push((var, sort));
145 }
146 Ok(self)
147 }
148
149 pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
150 self = self.with_const_generics(def_id)?;
151 let offset = self.params.len();
152 self.genv.refinement_generics_of(def_id)?.fill_item(
153 self.genv,
154 &mut self.params,
155 &mut |param, index| {
156 let index = (index - offset) as u32;
157 let param = param.instantiate_identity();
158 let sort = param
159 .sort
160 .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
161
162 let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
163 Ok((var, sort))
164 },
165 )?;
166 Ok(self)
167 }
168
169 pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
170 Ok(InferCtxtRoot {
171 genv: self.genv,
172 inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
173 refine_tree: RefineTree::new(self.params),
174 opts: self.opts,
175 })
176 }
177}
178
179impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
180 pub fn infcx<'a>(
181 &'a mut self,
182 def_id: DefId,
183 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
184 ) -> InferCtxt<'a, 'genv, 'tcx> {
185 InferCtxt {
186 genv: self.genv,
187 region_infcx,
188 def_id,
189 cursor: self.refine_tree.cursor_at_root(),
190 inner: &self.inner,
191 check_overflow: self.opts.check_overflow,
192 }
193 }
194
195 pub fn fresh_kvar_in_scope(
196 &self,
197 binders: &[BoundVariableKinds],
198 scope: &Scope,
199 encoding: KVarEncoding,
200 ) -> Expr {
201 let inner = &mut *self.inner.borrow_mut();
202 inner.kvars.fresh(binders, scope.iter(), encoding)
203 }
204
205 pub fn execute_lean_query(
206 self,
207 cache: &mut FixQueryCache,
208 def_id: MaybeExternId,
209 ) -> QueryResult<()> {
210 let inner = self.inner.into_inner();
211 let kvars = inner.kvars;
212 let evars = inner.evars;
213 let mut refine_tree = self.refine_tree;
214 refine_tree.replace_evars(&evars).unwrap();
215 refine_tree.simplify(self.genv);
216
217 let mut fcx_for_solver =
218 FixpointCtxt::new(self.genv, def_id, kvars.clone(), Backend::Fixpoint);
219 let cstr_for_solver = refine_tree.to_fixpoint(&mut fcx_for_solver)?;
220 let solver = match self.opts.solver {
221 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
222 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
223 };
224 let kvar_solutions = fcx_for_solver
225 .check(
226 cache,
227 def_id,
228 cstr_for_solver,
229 FixpointQueryKind::Body,
230 self.opts.scrape_quals,
231 solver,
232 )
233 .map(|answer| answer.solution)
234 .unwrap_or_default();
235
236 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Lean);
237 let cstr = refine_tree.to_fixpoint(&mut fcx)?;
238 let kvar_sol_funcs: HashMap<_, _> = kvar_solutions
239 .iter()
240 .map(|(kvid, sol)| fcx.kvar_solution_for_lean(*kvid, sol))
241 .collect::<Result<_, _>>()?;
242 fcx.generate_and_check_lean_lemmas(cstr, kvar_sol_funcs)
243 }
244
245 pub fn execute_fixpoint_query(
246 self,
247 cache: &mut FixQueryCache,
248 def_id: MaybeExternId,
249 kind: FixpointQueryKind,
250 ) -> QueryResult<Answer<Tag>> {
251 let inner = self.inner.into_inner();
252 let kvars = inner.kvars;
253 let evars = inner.evars;
254
255 let ext = kind.ext();
256
257 let mut refine_tree = self.refine_tree;
258
259 refine_tree.replace_evars(&evars).unwrap();
260
261 if config::dump_constraint() {
262 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
263 }
264 refine_tree.simplify(self.genv);
265 if config::dump_constraint() {
266 let simp_ext = format!("simp.{ext}");
267 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
268 .unwrap();
269 }
270
271 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Fixpoint);
272 let cstr = refine_tree.to_fixpoint(&mut fcx)?;
273
274 let backend = match self.opts.solver {
275 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
276 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
277 };
278
279 fcx.check(cache, def_id, cstr, kind, self.opts.scrape_quals, backend)
280 }
281
282 pub fn split(self) -> (RefineTree, KVarGen) {
283 (self.refine_tree, self.inner.into_inner().kvars)
284 }
285}
286
287pub struct InferCtxt<'infcx, 'genv, 'tcx> {
288 pub genv: GlobalEnv<'genv, 'tcx>,
289 pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
290 pub def_id: DefId,
291 pub check_overflow: OverflowMode,
292 cursor: Cursor<'infcx>,
293 inner: &'infcx RefCell<InferCtxtInner>,
294}
295
296struct InferCtxtInner {
297 kvars: KVarGen,
298 evars: EVarStore,
299}
300
301impl InferCtxtInner {
302 fn new(dummy_kvars: bool) -> Self {
303 Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
304 }
305}
306
307impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
308 pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
309 InferCtxtAt { infcx: self, span }
310 }
311
312 pub fn instantiate_refine_args(
313 &mut self,
314 callee_def_id: DefId,
315 args: &[rty::GenericArg],
316 ) -> InferResult<List<Expr>> {
317 Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
318 let param = param.instantiate(self.genv.tcx(), args, &[]);
319 Ok(self.fresh_infer_var(¶m.sort, param.mode))
320 })?)
321 }
322
323 pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
324 args.iter()
325 .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
326 .collect_vec()
327 }
328
329 pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
330 match mode {
331 InferMode::KVar => {
332 let fsort = sort.expect_func().expect_mono();
333 let vars = fsort.inputs().iter().cloned().map_into().collect();
334 let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
335 Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
336 }
337 InferMode::EVar => self.fresh_evar(),
338 }
339 }
340
341 pub fn fresh_infer_var_for_hole(
342 &mut self,
343 binders: &[BoundVariableKinds],
344 kind: HoleKind,
345 ) -> Expr {
346 match kind {
347 HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
348 HoleKind::Expr(_) => {
349 self.fresh_evar()
353 }
354 }
355 }
356
357 pub fn fresh_kvar_in_scope(
359 &self,
360 binders: &[BoundVariableKinds],
361 scope: &Scope,
362 encoding: KVarEncoding,
363 ) -> Expr {
364 let inner = &mut *self.inner.borrow_mut();
365 inner.kvars.fresh(binders, scope.iter(), encoding)
366 }
367
368 pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
370 let inner = &mut *self.inner.borrow_mut();
371 inner.kvars.fresh(binders, self.cursor.vars(), encoding)
372 }
373
374 fn fresh_evar(&self) -> Expr {
375 let evars = &mut self.inner.borrow_mut().evars;
376 Expr::evar(evars.fresh(self.cursor.marker()))
377 }
378
379 pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
380 if a.has_evars() {
381 return;
382 }
383 let evars = &mut self.inner.borrow_mut().evars;
384 if let ExprKind::Var(Var::EVar(evid)) = b.kind()
385 && let EVarState::Unsolved(marker) = evars.get(*evid)
386 && !marker.has_free_vars(a)
387 {
388 evars.solve(*evid, a.clone());
389 }
390 }
391
392 fn enter_exists<T, U>(
393 &mut self,
394 t: &Binder<T>,
395 f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
396 ) -> U
397 where
398 T: TypeFoldable,
399 {
400 self.ensure_resolved_evars(|infcx| {
401 let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
402 Ok(f(infcx, t))
403 })
404 .unwrap()
405 }
406
407 pub fn push_evar_scope(&mut self) {
412 self.inner.borrow_mut().evars.push_scope();
413 }
414
415 pub fn pop_evar_scope(&mut self) -> InferResult {
418 self.inner
419 .borrow_mut()
420 .evars
421 .pop_scope()
422 .map_err(InferErr::UnsolvedEvar)
423 }
424
425 pub fn ensure_resolved_evars<R>(
427 &mut self,
428 f: impl FnOnce(&mut Self) -> InferResult<R>,
429 ) -> InferResult<R> {
430 self.push_evar_scope();
431 let r = f(self)?;
432 self.pop_evar_scope()?;
433 Ok(r)
434 }
435
436 pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
437 self.inner.borrow().evars.replace_evars(t).unwrap()
438 }
439
440 pub fn tcx(&self) -> TyCtxt<'tcx> {
441 self.genv.tcx()
442 }
443
444 pub fn cursor(&self) -> &Cursor<'infcx> {
445 &self.cursor
446 }
447}
448
449impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
451 pub fn change_item<'a>(
452 &'a mut self,
453 def_id: LocalDefId,
454 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
455 ) -> InferCtxt<'a, 'genv, 'tcx> {
456 InferCtxt {
457 def_id: def_id.to_def_id(),
458 cursor: self.cursor.branch(),
459 region_infcx,
460 ..*self
461 }
462 }
463
464 pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
465 InferCtxt {
466 cursor: self
467 .cursor
468 .move_to(marker, clear_children)
469 .unwrap_or_else(|| tracked_span_bug!()),
470 ..*self
471 }
472 }
473
474 pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
475 InferCtxt { cursor: self.cursor.branch(), ..*self }
476 }
477
478 fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
479 self.cursor.define_var(sort, provenance)
480 }
481
482 pub fn define_bound_reft_var(&mut self, sort: &Sort, kind: BoundReftKind) -> Name {
483 self.define_var(sort, NameProvenance::UnfoldBoundReft(kind))
484 }
485
486 pub fn define_unknown_var(&mut self, sort: &Sort) -> Name {
487 self.cursor.define_var(sort, NameProvenance::Unknown)
488 }
489
490 pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
491 self.cursor.check_pred(pred, tag);
492 }
493
494 pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
495 self.cursor.assume_pred(pred);
496 }
497
498 pub fn unpack(&mut self, ty: &Ty) -> Ty {
499 self.hoister(false).hoist(ty)
500 }
501
502 pub fn unpack_at_name(&mut self, name: Option<Symbol>, ty: &Ty) -> Ty {
503 let mut hoister = self.hoister(false);
504 hoister.delegate.name = name;
505 hoister.hoist(ty)
506 }
507
508 pub fn marker(&self) -> Marker {
509 self.cursor.marker()
510 }
511
512 pub fn hoister(
513 &mut self,
514 assume_invariants: bool,
515 ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
516 Hoister::with_delegate(Unpacker { infcx: self, assume_invariants, name: None })
517 .transparent()
518 }
519
520 pub fn assume_invariants(&mut self, ty: &Ty) {
521 self.cursor
522 .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
523 }
524
525 fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
526 self.cursor.check_impl(pred1, pred2, tag);
527 }
528}
529
530pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
531 infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
532 assume_invariants: bool,
533 name: Option<Symbol>,
534}
535
536impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
537 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
538 let ty = ty_ctor.replace_bound_refts_with(|sort, _, kind| {
539 let kind = if let Some(name) = self.name { BoundReftKind::Named(name) } else { kind };
540 Expr::fvar(self.infcx.define_bound_reft_var(sort, kind))
541 });
542 if self.assume_invariants {
543 self.infcx.assume_invariants(&ty);
544 }
545 ty
546 }
547
548 fn hoist_constr(&mut self, pred: Expr) {
549 self.infcx.assume_pred(pred);
550 }
551}
552
553impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 std::fmt::Debug::fmt(&self.cursor, f)
556 }
557}
558
559#[derive(Debug)]
560pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
561 pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
562 pub span: Span,
563}
564
565impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
566 fn tag(&self, reason: ConstrReason) -> Tag {
567 Tag::new(reason, self.span)
568 }
569
570 pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
571 let tag = self.tag(reason);
572 self.infcx.check_pred(pred, tag);
573 }
574
575 pub fn check_non_closure_clauses(
576 &mut self,
577 clauses: &[rty::Clause],
578 reason: ConstrReason,
579 ) -> InferResult {
580 for clause in clauses {
581 if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
582 let impl_elem = BaseTy::projection(projection_pred.projection_ty)
583 .to_ty()
584 .deeply_normalize(self)?;
585 let term = projection_pred.term.to_ty().deeply_normalize(self)?;
586
587 self.subtyping(&impl_elem, &term, reason)?;
589 self.subtyping(&term, &impl_elem, reason)?;
590 }
591 }
592 Ok(())
593 }
594
595 pub fn subtyping_with_env(
598 &mut self,
599 env: &mut impl LocEnv,
600 a: &Ty,
601 b: &Ty,
602 reason: ConstrReason,
603 ) -> InferResult {
604 let mut sub = Sub::new(env, reason, self.span);
605 sub.tys(self.infcx, a, b)
606 }
607
608 pub fn subtyping(
613 &mut self,
614 a: &Ty,
615 b: &Ty,
616 reason: ConstrReason,
617 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
618 let mut env = DummyEnv;
619 let mut sub = Sub::new(&mut env, reason, self.span);
620 sub.tys(self.infcx, a, b)?;
621 Ok(sub.obligations)
622 }
623
624 pub fn subtyping_generic_args(
625 &mut self,
626 variance: Variance,
627 a: &GenericArg,
628 b: &GenericArg,
629 reason: ConstrReason,
630 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
631 let mut env = DummyEnv;
632 let mut sub = Sub::new(&mut env, reason, self.span);
633 sub.generic_args(self.infcx, variance, a, b)?;
634 Ok(sub.obligations)
635 }
636
637 pub fn check_constructor(
641 &mut self,
642 variant: EarlyBinder<PolyVariant>,
643 generic_args: &[GenericArg],
644 fields: &[Ty],
645 reason: ConstrReason,
646 ) -> InferResult<Ty> {
647 let ret = self.ensure_resolved_evars(|this| {
648 let generic_args = this.instantiate_generic_args(generic_args);
650
651 let variant = variant
652 .instantiate(this.tcx(), &generic_args, &[])
653 .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
654
655 for (actual, formal) in iter::zip(fields, variant.fields()) {
657 this.subtyping(actual, formal, reason)?;
658 }
659
660 for require in &variant.requires {
662 this.check_pred(require, ConstrReason::Fold);
663 }
664
665 Ok(variant.ret())
666 })?;
667 Ok(self.fully_resolve_evars(&ret))
668 }
669
670 pub fn ensure_resolved_evars<R>(
671 &mut self,
672 f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
673 ) -> InferResult<R> {
674 self.infcx
675 .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
676 }
677}
678
679impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
680 type Target = InferCtxt<'a, 'genv, 'tcx>;
681
682 fn deref(&self) -> &Self::Target {
683 self.infcx
684 }
685}
686
687impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
688 fn deref_mut(&mut self) -> &mut Self::Target {
689 self.infcx
690 }
691}
692
693#[derive(TypeVisitable, TypeFoldable)]
697pub(crate) enum TypeTrace {
698 Types(Ty, Ty),
699 BaseTys(BaseTy, BaseTy),
700}
701
702#[expect(dead_code, reason = "we use this for debugging some time")]
703impl TypeTrace {
704 fn tys(a: &Ty, b: &Ty) -> Self {
705 Self::Types(a.clone(), b.clone())
706 }
707
708 fn btys(a: &BaseTy, b: &BaseTy) -> Self {
709 Self::BaseTys(a.clone(), b.clone())
710 }
711}
712
713impl fmt::Debug for TypeTrace {
714 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
715 match self {
716 TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
717 TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
718 }
719 }
720}
721
722pub trait LocEnv {
723 fn ptr_to_ref(
724 &mut self,
725 infcx: &mut InferCtxtAt,
726 reason: ConstrReason,
727 re: Region,
728 path: &Path,
729 bound: Ty,
730 ) -> InferResult<Ty>;
731
732 fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;
733
734 fn get(&self, path: &Path) -> Ty;
735}
736
737struct DummyEnv;
738
739impl LocEnv for DummyEnv {
740 fn ptr_to_ref(
741 &mut self,
742 _: &mut InferCtxtAt,
743 _: ConstrReason,
744 _: Region,
745 _: &Path,
746 _: Ty,
747 ) -> InferResult<Ty> {
748 bug!("call to `ptr_to_ref` on `DummyEnv`")
749 }
750
751 fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
752 bug!("call to `unfold_str_ref` on `DummyEnv`")
753 }
754
755 fn get(&self, _: &Path) -> Ty {
756 bug!("call to `get` on `DummyEnv`")
757 }
758}
759
/// Structural subtyping checker.
///
/// `tys`, `btys` and `generic_args` walk two types in parallel and emit constraints into
/// the given [`InferCtxt`]. Each constraint is tagged with `reason` and `span` (see `tag`).
/// Coroutine obligations found while relating a coroutine to an opaque type are collected in
/// `obligations` instead of being checked here.
struct Sub<'a, E> {
    /// Location environment used when pointers or strong references are involved.
    env: &'a mut E,
    /// Reason recorded in the tag of every emitted constraint.
    reason: ConstrReason,
    /// Span recorded in the tag of every emitted constraint.
    span: Span,
    /// Coroutine obligations collected by `handle_opaque_type`.
    obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
}

impl<'a, E: LocEnv> Sub<'a, E> {
    /// Creates a checker with no collected obligations.
    fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
        Self { env, reason, span, obligations: vec![] }
    }

    /// Tag attached to every constraint emitted by this checker.
    fn tag(&self) -> Tag {
        Tag::new(self.reason, self.span)
    }

    /// Checks `a <: b`.
    ///
    /// The work happens in a fresh branch of `infcx`. `a` is unpacked first, which hoists its
    /// existentials and constraints into that branch. Afterwards only `b` can still be an
    /// existential or constrained type.
    fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
        let infcx = &mut infcx.branch();
        let a = infcx.unpack(a);

        match (a.kind(), b.kind()) {
            (TyKind::Exists(..), _) => {
                bug!("existentials should have been removed by the unpacking above");
            }
            (TyKind::Constr(..), _) => {
                bug!("constraint types should have been removed by the unpacking above");
            }

            // Open the rhs binder with fresh inference variables; `enter_exists` requires
            // them to be solved before leaving its scope.
            (_, TyKind::Exists(ctor_b)) => {
                infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
            }
            // Prove the rhs predicate, then relate against the inner type.
            (_, TyKind::Constr(pred_b, ty_b)) => {
                infcx.check_pred(pred_b, self.tag());
                self.tys(infcx, &a, ty_b)
            }

            // Compare the type currently stored at `path_a` with the strong reference's type,
            // after trying to unify the two paths (as expressions).
            (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            // Unfold the lhs strong reference in the environment first, then proceed as above
            // with the type read back from `path_a`.
            (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            // A mutable pointer against an `&mut` reference: the reference's index must equal
            // unit, and the environment converts the pointer into a reference bounded by
            // `bound`.
            (
                TyKind::Ptr(PtrKind::Mut(re), path),
                TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
            ) => {
                self.idxs_eq(infcx, &Expr::unit(), idx);

                self.env.ptr_to_ref(
                    &mut infcx.at(self.span),
                    self.reason,
                    *re,
                    path,
                    bound.clone(),
                )?;
                Ok(())
            }

            // Relate the base types, then require the indices to be equal.
            (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
                self.btys(infcx, bty_a, bty_b)?;
                self.idxs_eq(infcx, idx_a, idx_b);
                Ok(())
            }
            // Pointers and type params are only sanity-checked (debug builds).
            (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
                debug_assert_eq!(pk_a, pk_b);
                debug_assert_eq!(path_a, path_b);
                Ok(())
            }
            (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
                debug_assert_eq!(param_ty_a, param_ty_b);
                Ok(())
            }
            // Anything may flow into an uninitialized slot.
            (_, TyKind::Uninit) => Ok(()),
            // Downcasts are related field by field (covariantly).
            (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
                debug_assert_eq!(fields_a.len(), fields_b.len());
                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
        }
    }

    /// Checks that base type `a` is a subtype of base type `b`.
    fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
        match (a, b) {
            (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
                debug_assert_eq!(int_ty_a, int_ty_b);
                Ok(())
            }
            (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
                debug_assert_eq!(uint_ty_a, uint_ty_b);
                Ok(())
            }
            // ADT arguments are related according to the ADT's declared variances.
            (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
                tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
                tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
                let variances = infcx.genv.variances_of(a_adt.did());
                for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
                    self.generic_args(infcx, *variance, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Function items: arguments are only compared structurally (ignoring
            // refinements) in debug builds; no constraints are emitted.
            (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
                debug_assert_eq!(a_def_id, b_def_id);
                debug_assert_eq!(a_args.len(), b_args.len());
                for (arg_a, arg_b) in iter::zip(a_args, b_args) {
                    match (arg_a, arg_b) {
                        (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
                            let bty_a = ty_a.as_bty_skipping_existentials();
                            let bty_b = ty_b.as_bty_skipping_existentials();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                            let bty_a = ctor_a.as_bty_skipping_binder();
                            let bty_b = ctor_b.as_bty_skipping_binder();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
                    }
                }
                Ok(())
            }
            (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
                debug_assert_eq!(float_ty_a, float_ty_b);
                Ok(())
            }
            (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),
            // Mutable references are invariant: relate the pointees in both directions.
            (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
                if ty_a.is_slice()
                    && let TyKind::Indexed(_, idx_a) = ty_a.kind()
                    && let TyKind::Exists(bty_b) = ty_b.kind()
                {
                    // Indexed slice on the left, existential on the right: check
                    // `ty_a <: ty_b`, then check the rhs body instantiated at `idx_a` against
                    // `ty_a` for the reverse direction.
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, &bty_b.replace_bound_reft(idx_a), ty_a)
                } else {
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, ty_b, ty_a)
                }
            }
            // Shared references are covariant.
            (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Two uses of the same opaque type: generic args are invariant; refinement args
            // are only unified (no constraint is emitted for them here).
            (
                BaseTy::Alias(AliasKind::Opaque, alias_ty_a),
                BaseTy::Alias(AliasKind::Opaque, alias_ty_b),
            ) => {
                debug_assert_eq!(alias_ty_a.def_id, alias_ty_b.def_id);

                for (ty_a, ty_b) in izip!(alias_ty_a.args.iter(), alias_ty_b.args.iter()) {
                    self.generic_args(infcx, Invariant, ty_a, ty_b)?;
                }

                debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
                iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
                    .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));

                Ok(())
            }
            // A concrete type flowing into an opaque type: check it against the opaque
            // type's bounds.
            (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
                self.handle_opaque_type(infcx, a, alias_ty_b)
            }
            (
                BaseTy::Alias(AliasKind::Projection, alias_ty_a),
                BaseTy::Alias(AliasKind::Projection, alias_ty_b),
            ) => {
                tracked_span_dbg_assert_eq!(alias_ty_a, alias_ty_b);
                Ok(())
            }
            (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
                tracked_span_dbg_assert_eq!(len_a, len_b);
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
                debug_assert_eq!(param_a, param_b);
                Ok(())
            }
            (BaseTy::Bool, BaseTy::Bool)
            | (BaseTy::Str, BaseTy::Str)
            | (BaseTy::Char, BaseTy::Char)
            | (BaseTy::RawPtr(_, _), BaseTy::RawPtr(_, _))
            | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
            // Trait objects and fn pointers must match exactly (modulo regions).
            (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
                tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
                Ok(())
            }
            (BaseTy::Closure(did1, tys_a, _), BaseTy::Closure(did2, tys_b, _)) if did1 == did2 => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
                tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
                Ok(())
            }
            (BaseTy::Never, BaseTy::Never) => Ok(()),
            _ => Err(query_bug!("incompatible base types: `{a:?}` - `{b:?}`"))?,
        }
    }

    /// Relates two generic arguments according to `variance`.
    ///
    /// Lifetimes are ignored, and consts are only compared in debug builds. Types and base
    /// constructors are related as types in the direction(s) dictated by `variance`.
    fn generic_args(
        &mut self,
        infcx: &mut InferCtxt,
        variance: Variance,
        a: &GenericArg,
        b: &GenericArg,
    ) -> InferResult {
        let (ty_a, ty_b) = match (a, b) {
            (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
            (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                tracked_span_dbg_assert_eq!(ctor_a.sort(), ctor_b.sort());
                (ctor_a.to_ty(), ctor_b.to_ty())
            }
            (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
            (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
                debug_assert_eq!(cst_a, cst_b);
                return Ok(());
            }
            _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
        };
        match variance {
            Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
            Variance::Invariant => {
                self.tys(infcx, &ty_a, &ty_b)?;
                self.tys(infcx, &ty_b, &ty_a)
            }
            Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
            Variance::Bivariant => Ok(()),
        }
    }

    /// Requires the index expressions `a` and `b` to be equal.
    ///
    /// Tuples and struct constructors are decomposed field by field. When only one side is a
    /// constructor, the other side is projected with `proj_and_reduce`. Lambdas are compared
    /// pointwise (eta-expanding a non-lambda side), and kvars are related by implication in
    /// both directions. Otherwise an equality predicate is checked.
    ///
    /// Where noted, `unify_exprs(a, b)` is tried first so that an evar `b` can be solved by `a`.
    fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
        if a == b {
            return;
        }
        match (a.kind(), b.kind()) {
            (
                ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
                ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
            ) => {
                debug_assert_eq!(did_a, did_b);
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Only the rhs is a constructor: project each field out of `a`.
            (_, ExprKind::Tuple(flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Only the lhs is a constructor: try to solve `b` with `a` first, then project
            // each field out of `b`.
            (ExprKind::Tuple(flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, lam_a, lam_b);
            }
            (_, ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
            }
            (ExprKind::Abs(lam_a), _) => {
                infcx.unify_exprs(a, b);
                self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
            }
            // kvars are related by mutual implication instead of an equality predicate.
            (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
                infcx.check_impl(a, b, self.tag());
                infcx.check_impl(b, a, self.tag());
            }
            _ => {
                infcx.unify_exprs(a, b);
                let span = b.span();
                infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
            }
        }
    }

    /// Requires two lambdas to be equal by applying both to the same freshly defined
    /// variables and comparing the resulting bodies.
    fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
        debug_assert_eq!(a.vars().len(), b.vars().len());
        let vars = a
            .vars()
            .iter()
            .map(|kind| {
                let (sort, _, kind) = kind.expect_refine();
                Expr::fvar(infcx.define_bound_reft_var(sort, kind))
            })
            .collect_vec();
        let body_a = a.apply(&vars);
        let body_b = b.apply(&vars);
        self.idxs_eq(infcx, &body_a, &body_b);
    }

    /// Checks a concrete base type `bty` against the opaque type `alias_ty`.
    ///
    /// A coroutine produces a coroutine obligation, which is appended to `obligations`.
    /// For any other type, each projection bound of the opaque type is checked: the
    /// projection is re-rooted at `bty`, normalized, and required to be a subtype of the
    /// bound's term.
    ///
    /// # Errors
    /// Reports a query bug if an instantiated bound still has bound variables.
    fn handle_opaque_type(
        &mut self,
        infcx: &mut InferCtxt,
        bty: &BaseTy,
        alias_ty: &AliasTy,
    ) -> InferResult {
        if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys) = bty {
            let obligs = mk_coroutine_obligations(
                infcx.genv,
                def_id,
                resume_ty,
                upvar_tys,
                &alias_ty.def_id,
            )?;
            self.obligations.extend(obligs);
        } else {
            let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
                infcx.tcx(),
                &alias_ty.args,
                &alias_ty.refine_args,
            );
            for clause in &bounds {
                if !clause.kind().vars().is_empty() {
                    Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
                }
                if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
                    let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
                    let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
                        .to_ty()
                        .deeply_normalize(&mut infcx.at(self.span))?;
                    let ty2 = pred.term.to_ty();
                    self.tys(infcx, &ty1, &ty2)?;
                }
            }
        }
        Ok(())
    }
}
1173
1174fn mk_coroutine_obligations(
1175 genv: GlobalEnv,
1176 generator_did: &DefId,
1177 resume_ty: &Ty,
1178 upvar_tys: &List<Ty>,
1179 opaque_def_id: &DefId,
1180) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1181 let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1182 for bound in &bounds {
1183 if let Some(proj_clause) = bound.as_projection_clause() {
1184 return Ok(vec![proj_clause.map(|proj_clause| {
1185 let output = proj_clause.term;
1186 CoroutineObligPredicate {
1187 def_id: *generator_did,
1188 resume_ty: resume_ty.clone(),
1189 upvar_tys: upvar_tys.clone(),
1190 output: output.to_ty(),
1191 }
1192 })]);
1193 }
1194 }
1195 bug!("no projection predicate")
1196}
1197
1198#[derive(Debug)]
1199pub enum InferErr {
1200 UnsolvedEvar(EVid),
1201 Query(QueryErr),
1202}
1203
1204impl From<QueryErr> for InferErr {
1205 fn from(v: QueryErr) -> Self {
1206 Self::Query(v)
1207 }
1208}
1209
1210mod pretty {
1211 use std::fmt;
1212
1213 use flux_middle::pretty::*;
1214
1215 use super::*;
1216
1217 impl Pretty for Tag {
1218 fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1219 w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)?;
1220 if let Some(dst_span) = self.dst_span {
1221 w!(cx, f, " ({:?})", ^dst_span)?;
1222 }
1223 Ok(())
1224 }
1225 }
1226
1227 impl_debug_with_default_cx!(Tag);
1228}