1use std::{cell::RefCell, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7 FixpointQueryKind,
8 def_id::MaybeExternId,
9 global_env::GlobalEnv,
10 queries::{QueryErr, QueryResult},
11 query_bug,
12 rty::{
13 self, AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate,
14 Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj, GenericArg, HoleKind, InferMode,
15 Lambda, List, Loc, Mutability, Name, Path, PolyVariant, PtrKind, RefineArgs, RefineArgsExt,
16 Region, Sort, Ty, TyCtor, TyKind, Var,
17 canonicalize::{Hoister, HoisterDelegate},
18 fold::TypeFoldable,
19 },
20};
21use itertools::{Itertools, izip};
22use rustc_hir::def_id::{DefId, LocalDefId};
23use rustc_macros::extension;
24use rustc_middle::{
25 mir::BasicBlock,
26 ty::{TyCtxt, Variance},
27};
28use rustc_span::Span;
29
30use crate::{
31 evars::{EVarState, EVarStore},
32 fixpoint_encoding::{FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen},
33 projections::NormalizeExt as _,
34 refine_tree::{Cursor, Marker, RefineTree, Scope},
35};
36
37pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
38
39#[derive(PartialEq, Eq, Clone, Copy, Hash)]
40pub struct Tag {
41 pub reason: ConstrReason,
42 pub src_span: Span,
43 pub dst_span: Option<ESpan>,
44}
45
46impl Tag {
47 pub fn new(reason: ConstrReason, span: Span) -> Self {
48 Self { reason, src_span: span, dst_span: None }
49 }
50
51 pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
52 Self { dst_span, ..self }
53 }
54}
55
/// Sub-classification carried by [`ConstrReason::Subtype`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SubtypeReason {
    Input,
    Output,
    Requires,
    Ensures,
}
63
64#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
65pub enum ConstrReason {
66 Call,
67 Assign,
68 Ret,
69 Fold,
70 FoldLocal,
71 Predicate,
72 Assert(&'static str),
73 Div,
74 Rem,
75 Goto(BasicBlock),
76 Overflow,
77 Subtype(SubtypeReason),
78 Other,
79}
80
/// Owns the state shared by every [`InferCtxt`] created for an item: the refinement tree
/// being built and the kvar/evar stores, which are kept behind a `RefCell`.
pub struct InferCtxtRoot<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Kvar generator and evar store. Child contexts borrow it through a shared reference.
    inner: RefCell<InferCtxtInner>,
    /// Constraint tree. Contexts returned by `infcx` hold a cursor starting at its root.
    refine_tree: RefineTree,
    /// Solver, qualifier-scraping and overflow-checking options.
    opts: InferOpts,
}
87
/// Builder for [`InferCtxtRoot`]. It accumulates the variables that will be in scope at the
/// root of the refinement tree.
pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    opts: InferOpts,
    /// Root variables and their sorts; handed to `RefineTree::new` by `build`.
    params: Vec<(Var, Sort)>,
    /// Rustc inference context used when normalizing parameter sorts.
    infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
    /// Forwarded to `KVarGen::new` by `build`.
    dummy_kvars: bool,
}
95
96#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
97impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
98 fn infcx_root<'a>(
99 self,
100 infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
101 opts: InferOpts,
102 ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
103 InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
104 }
105}
106
107impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
108 pub fn with_dummy_kvars(mut self) -> Self {
109 self.dummy_kvars = true;
110 self
111 }
112
113 pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
114 self.params.extend(
115 self.genv
116 .generics_of(def_id)?
117 .const_params(self.genv)?
118 .into_iter()
119 .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
120 );
121 Ok(self)
122 }
123
124 pub fn with_refinement_generics(
125 mut self,
126 def_id: DefId,
127 args: &[GenericArg],
128 ) -> QueryResult<Self> {
129 for (index, param) in self
130 .genv
131 .refinement_generics_of(def_id)?
132 .iter_own_params()
133 .enumerate()
134 {
135 let param = param.instantiate(self.genv.tcx(), args, &[]);
136 let sort = param.sort.normalize_sorts(def_id, self.genv, self.infcx)?;
137
138 let var =
139 Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
140 self.params.push((var, sort));
141 }
142 Ok(self)
143 }
144
145 pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
146 self = self.with_const_generics(def_id)?;
147 let offset = self.params.len();
148 self.genv.refinement_generics_of(def_id)?.fill_item(
149 self.genv,
150 &mut self.params,
151 &mut |param, index| {
152 let index = (index - offset) as u32;
153 let param = param.instantiate_identity();
154 let sort = param.sort.normalize_sorts(def_id, self.genv, self.infcx)?;
155
156 let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
157 Ok((var, sort))
158 },
159 )?;
160 Ok(self)
161 }
162
163 pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
164 Ok(InferCtxtRoot {
165 genv: self.genv,
166 inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
167 refine_tree: RefineTree::new(self.params),
168 opts: self.opts,
169 })
170 }
171}
172
173impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
174 pub fn infcx<'a>(
175 &'a mut self,
176 def_id: DefId,
177 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
178 ) -> InferCtxt<'a, 'genv, 'tcx> {
179 InferCtxt {
180 genv: self.genv,
181 region_infcx,
182 def_id,
183 cursor: self.refine_tree.cursor_at_root(),
184 inner: &self.inner,
185 check_overflow: self.opts.check_overflow,
186 }
187 }
188
189 pub fn fresh_kvar_in_scope(
190 &self,
191 binders: &[BoundVariableKinds],
192 scope: &Scope,
193 encoding: KVarEncoding,
194 ) -> Expr {
195 let inner = &mut *self.inner.borrow_mut();
196 inner.kvars.fresh(binders, scope.iter(), encoding)
197 }
198
199 pub fn execute_fixpoint_query(
200 self,
201 cache: &mut FixQueryCache,
202 def_id: MaybeExternId,
203 kind: FixpointQueryKind,
204 ) -> QueryResult<Vec<Tag>> {
205 let inner = self.inner.into_inner();
206 let kvars = inner.kvars;
207 let evars = inner.evars;
208
209 let ext = kind.ext();
210
211 let mut refine_tree = self.refine_tree;
212
213 refine_tree.replace_evars(&evars).unwrap();
214
215 if config::dump_constraint() {
216 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
217 }
218 refine_tree.simplify(self.genv);
219 if config::dump_constraint() {
220 let simp_ext = format!("simp.{ext}");
221 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
222 .unwrap();
223 }
224
225 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars);
226 let cstr = refine_tree.into_fixpoint(&mut fcx)?;
227
228 let backend = match self.opts.solver {
229 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
230 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
231 };
232
233 fcx.check(cache, cstr, kind, self.opts.scrape_quals, backend)
234 }
235
236 pub fn split(self) -> (RefineTree, KVarGen) {
237 (self.refine_tree, self.inner.into_inner().kvars)
238 }
239}
240
/// Context for generating refinement constraints at one position (`cursor`) of the
/// refinement tree. Kvar/evar state is shared with the owning [`InferCtxtRoot`] via `inner`.
pub struct InferCtxt<'infcx, 'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Rustc inference context supplied by the creator (`InferCtxtRoot::infcx` / `change_item`).
    pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    /// Item the constraints are being generated for.
    pub def_id: DefId,
    /// Forwarded to `Cursor::assume_invariants` when assuming type invariants.
    pub check_overflow: bool,
    /// Position in the refinement tree where variables, assumptions and checks are added.
    cursor: Cursor<'infcx>,
    /// Kvar/evar state; copied by reference into every branched context.
    inner: &'infcx RefCell<InferCtxtInner>,
}
249
/// Mutable inference state shared, through a `RefCell`, by a root and all its contexts.
struct InferCtxtInner {
    /// Generator for fresh kvars.
    kvars: KVarGen,
    /// Evars and their solutions, organized in scopes (see `push_scope` / `pop_scope`).
    evars: EVarStore,
}
254
255impl InferCtxtInner {
256 fn new(dummy_kvars: bool) -> Self {
257 Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
258 }
259}
260
261impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
262 pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
263 InferCtxtAt { infcx: self, span }
264 }
265
266 pub fn instantiate_refine_args(
267 &mut self,
268 callee_def_id: DefId,
269 args: &[rty::GenericArg],
270 ) -> InferResult<List<Expr>> {
271 Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
272 let param = param.instantiate(self.genv.tcx(), args, &[]);
273 Ok(self.fresh_infer_var(¶m.sort, param.mode))
274 })?)
275 }
276
277 pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
278 args.iter()
279 .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
280 .collect_vec()
281 }
282
283 pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
284 match mode {
285 InferMode::KVar => {
286 let fsort = sort.expect_func().expect_mono();
287 let vars = fsort.inputs().iter().cloned().map_into().collect();
288 let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
289 Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
290 }
291 InferMode::EVar => self.fresh_evar(),
292 }
293 }
294
295 pub fn fresh_infer_var_for_hole(
296 &mut self,
297 binders: &[BoundVariableKinds],
298 kind: HoleKind,
299 ) -> Expr {
300 match kind {
301 HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
302 HoleKind::Expr(_) => {
303 self.fresh_evar()
307 }
308 }
309 }
310
311 pub fn fresh_kvar_in_scope(
313 &self,
314 binders: &[BoundVariableKinds],
315 scope: &Scope,
316 encoding: KVarEncoding,
317 ) -> Expr {
318 let inner = &mut *self.inner.borrow_mut();
319 inner.kvars.fresh(binders, scope.iter(), encoding)
320 }
321
322 pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
324 let inner = &mut *self.inner.borrow_mut();
325 inner.kvars.fresh(binders, self.cursor.vars(), encoding)
326 }
327
328 fn fresh_evar(&self) -> Expr {
329 let evars = &mut self.inner.borrow_mut().evars;
330 Expr::evar(evars.fresh(self.cursor.marker()))
331 }
332
333 pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
334 if a.has_evars() {
335 return;
336 }
337 let evars = &mut self.inner.borrow_mut().evars;
338 if let ExprKind::Var(Var::EVar(evid)) = b.kind()
339 && let EVarState::Unsolved(marker) = evars.get(*evid)
340 && !marker.has_free_vars(a)
341 {
342 evars.solve(*evid, a.clone());
343 }
344 }
345
346 fn enter_exists<T, U>(
347 &mut self,
348 t: &Binder<T>,
349 f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
350 ) -> U
351 where
352 T: TypeFoldable,
353 {
354 self.ensure_resolved_evars(|infcx| {
355 let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
356 Ok(f(infcx, t))
357 })
358 .unwrap()
359 }
360
361 pub fn push_evar_scope(&mut self) {
366 self.inner.borrow_mut().evars.push_scope();
367 }
368
369 pub fn pop_evar_scope(&mut self) -> InferResult {
372 self.inner
373 .borrow_mut()
374 .evars
375 .pop_scope()
376 .map_err(InferErr::UnsolvedEvar)
377 }
378
379 pub fn ensure_resolved_evars<R>(
381 &mut self,
382 f: impl FnOnce(&mut Self) -> InferResult<R>,
383 ) -> InferResult<R> {
384 self.push_evar_scope();
385 let r = f(self)?;
386 self.pop_evar_scope()?;
387 Ok(r)
388 }
389
390 pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
391 self.inner.borrow().evars.replace_evars(t).unwrap()
392 }
393
394 pub fn tcx(&self) -> TyCtxt<'tcx> {
395 self.genv.tcx()
396 }
397
398 pub fn cursor(&self) -> &Cursor<'infcx> {
399 &self.cursor
400 }
401}
402
403impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
405 pub fn change_item<'a>(
406 &'a mut self,
407 def_id: LocalDefId,
408 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
409 ) -> InferCtxt<'a, 'genv, 'tcx> {
410 InferCtxt {
411 def_id: def_id.to_def_id(),
412 cursor: self.cursor.branch(),
413 region_infcx,
414 ..*self
415 }
416 }
417
418 pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
419 InferCtxt {
420 cursor: self
421 .cursor
422 .move_to(marker, clear_children)
423 .unwrap_or_else(|| tracked_span_bug!()),
424 ..*self
425 }
426 }
427
428 pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
429 InferCtxt { cursor: self.cursor.branch(), ..*self }
430 }
431
432 pub fn define_var(&mut self, sort: &Sort) -> Name {
433 self.cursor.define_var(sort)
434 }
435
436 pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
437 self.cursor.check_pred(pred, tag);
438 }
439
440 pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
441 self.cursor.assume_pred(pred);
442 }
443
444 pub fn unpack(&mut self, ty: &Ty) -> Ty {
445 self.hoister(false).hoist(ty)
446 }
447
448 pub fn marker(&self) -> Marker {
449 self.cursor.marker()
450 }
451
452 pub fn hoister(
453 &mut self,
454 assume_invariants: bool,
455 ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
456 Hoister::with_delegate(Unpacker { infcx: self, assume_invariants }).transparent()
457 }
458
459 pub fn assume_invariants(&mut self, ty: &Ty) {
460 self.cursor
461 .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
462 }
463
464 fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
465 self.cursor.check_impl(pred1, pred2, tag);
466 }
467}
468
/// [`HoisterDelegate`] that introduces hoisted existentials and constraints into an
/// [`InferCtxt`] at its current cursor.
pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    /// When set, the invariants of each type produced by `hoist_exists` are assumed too.
    assume_invariants: bool,
}
473
474impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
475 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
476 let ty =
477 ty_ctor.replace_bound_refts_with(|sort, _, _| Expr::fvar(self.infcx.define_var(sort)));
478 if self.assume_invariants {
479 self.infcx.assume_invariants(&ty);
480 }
481 ty
482 }
483
484 fn hoist_constr(&mut self, pred: Expr) {
485 self.infcx.assume_pred(pred);
486 }
487}
488
489impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
490 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491 std::fmt::Debug::fmt(&self.cursor, f)
492 }
493}
494
/// An [`InferCtxt`] paired with the span used to tag the constraints generated through it.
#[derive(Debug)]
pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
    pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    /// Span recorded in every tag created by this context.
    pub span: Span,
}
500
501impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
502 fn tag(&self, reason: ConstrReason) -> Tag {
503 Tag::new(reason, self.span)
504 }
505
506 pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
507 let tag = self.tag(reason);
508 self.infcx.check_pred(pred, tag);
509 }
510
511 pub fn check_non_closure_clauses(
512 &mut self,
513 clauses: &[rty::Clause],
514 reason: ConstrReason,
515 ) -> InferResult {
516 for clause in clauses {
517 if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
518 let impl_elem = BaseTy::projection(projection_pred.projection_ty)
519 .to_ty()
520 .normalize_projections(self)?;
521 let term = projection_pred.term.to_ty().normalize_projections(self)?;
522
523 self.subtyping(&impl_elem, &term, reason)?;
525 self.subtyping(&term, &impl_elem, reason)?;
526 }
527 }
528 Ok(())
529 }
530
531 pub fn subtyping_with_env(
534 &mut self,
535 env: &mut impl LocEnv,
536 a: &Ty,
537 b: &Ty,
538 reason: ConstrReason,
539 ) -> InferResult {
540 let mut sub = Sub::new(env, reason, self.span);
541 sub.tys(self.infcx, a, b)
542 }
543
544 pub fn subtyping(
549 &mut self,
550 a: &Ty,
551 b: &Ty,
552 reason: ConstrReason,
553 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
554 let mut env = DummyEnv;
555 let mut sub = Sub::new(&mut env, reason, self.span);
556 sub.tys(self.infcx, a, b)?;
557 Ok(sub.obligations)
558 }
559
560 pub fn subtyping_generic_args(
561 &mut self,
562 variance: Variance,
563 a: &GenericArg,
564 b: &GenericArg,
565 reason: ConstrReason,
566 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
567 let mut env = DummyEnv;
568 let mut sub = Sub::new(&mut env, reason, self.span);
569 sub.generic_args(self.infcx, variance, a, b)?;
570 Ok(sub.obligations)
571 }
572
573 pub fn check_constructor(
577 &mut self,
578 variant: EarlyBinder<PolyVariant>,
579 generic_args: &[GenericArg],
580 fields: &[Ty],
581 reason: ConstrReason,
582 ) -> InferResult<Ty> {
583 let ret = self.ensure_resolved_evars(|this| {
584 let generic_args = this.instantiate_generic_args(generic_args);
586
587 let variant = variant
588 .instantiate(this.tcx(), &generic_args, &[])
589 .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
590
591 for (actual, formal) in iter::zip(fields, variant.fields()) {
593 this.subtyping(actual, formal, reason)?;
594 }
595
596 for require in &variant.requires {
598 this.check_pred(require, ConstrReason::Fold);
599 }
600
601 Ok(variant.ret())
602 })?;
603 Ok(self.fully_resolve_evars(&ret))
604 }
605
606 pub fn ensure_resolved_evars<R>(
607 &mut self,
608 f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
609 ) -> InferResult<R> {
610 self.infcx
611 .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
612 }
613}
614
615impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
616 type Target = InferCtxt<'a, 'genv, 'tcx>;
617
618 fn deref(&self) -> &Self::Target {
619 self.infcx
620 }
621}
622
623impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
624 fn deref_mut(&mut self) -> &mut Self::Target {
625 self.infcx
626 }
627}
628
629#[derive(TypeVisitable, TypeFoldable)]
633pub(crate) enum TypeTrace {
634 Types(Ty, Ty),
635 BaseTys(BaseTy, BaseTy),
636}
637
638#[expect(dead_code, reason = "we use this for debugging some time")]
639impl TypeTrace {
640 fn tys(a: &Ty, b: &Ty) -> Self {
641 Self::Types(a.clone(), b.clone())
642 }
643
644 fn btys(a: &BaseTy, b: &BaseTy) -> Self {
645 Self::BaseTys(a.clone(), b.clone())
646 }
647}
648
649impl fmt::Debug for TypeTrace {
650 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
651 match self {
652 TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
653 TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
654 }
655 }
656}
657
/// Location environment consulted by [`Sub`] when relating pointers and strong references.
pub trait LocEnv {
    /// Called when a mutable pointer to `path` is related to a `&mut` reference whose
    /// pointee type is `bound` (region `re`). The returned type is discarded by `Sub::tys`.
    fn ptr_to_ref(
        &mut self,
        infcx: &mut InferCtxtAt,
        reason: ConstrReason,
        re: Region,
        path: &Path,
        bound: Ty,
    ) -> InferResult<Ty>;

    /// Called on the left-hand strong reference (`path`, `ty`) before `path` is looked up
    /// with `get`. The returned location is discarded by `Sub::tys`.
    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;

    /// Returns the type the environment associates with `path`.
    fn get(&self, path: &Path) -> Ty;
}
672
/// Placeholder [`LocEnv`] for relations that must never consult a location environment.
struct DummyEnv;
674
675impl LocEnv for DummyEnv {
676 fn ptr_to_ref(
677 &mut self,
678 _: &mut InferCtxtAt,
679 _: ConstrReason,
680 _: Region,
681 _: &Path,
682 _: Ty,
683 ) -> InferResult<Ty> {
684 bug!("call to `ptr_to_ref` on `DummyEnv`")
685 }
686
687 fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
688 bug!("call to `unfold_str_ref` on `DummyEnv`")
689 }
690
691 fn get(&self, _: &Path) -> Ty {
692 bug!("call to `get` on `DummyEnv`")
693 }
694}
695
/// Subtyping relation that emits refinement constraints into an [`InferCtxt`].
struct Sub<'a, E> {
    /// Environment used when relating pointers and strong references.
    env: &'a mut E,
    /// Reason recorded in every tag this relation creates.
    reason: ConstrReason,
    /// Span recorded in every tag this relation creates.
    span: Span,
    /// Coroutine obligations collected while relating opaque types; returned by
    /// `InferCtxtAt::subtyping` / `subtyping_generic_args`.
    obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
}
707
impl<'a, E: LocEnv> Sub<'a, E> {
    /// Creates a relation with no collected obligations.
    fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
        Self { env, reason, span, obligations: vec![] }
    }

    /// Tag attached to every predicate this relation checks.
    fn tag(&self) -> Tag {
        Tag::new(self.reason, self.span)
    }

    /// Checks `a <: b`, emitting constraints into a new branch of `infcx`.
    ///
    /// # Errors
    /// Returns a query bug for structurally incompatible types. Also propagates errors from
    /// the location environment and from nested relations.
    fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
        // Relate inside a new branch of the refinement tree (see `InferCtxt::branch`).
        let infcx = &mut infcx.branch();
        // Unpack the lhs first. Existentials become freshly defined variables and
        // constraints become assumptions (see `Unpacker`), so the match below never sees
        // `Exists`/`Constr` on the left.
        let a = infcx.unpack(a);

        match (a.kind(), b.kind()) {
            (TyKind::Exists(..), _) => {
                bug!("existentials should have been removed by the unpacking above");
            }
            (TyKind::Constr(..), _) => {
                bug!("constraint types should have been removed by the unpacking above");
            }

            // An rhs existential is opened with fresh inference variables inside an evar
            // scope (see `InferCtxt::enter_exists`).
            (_, TyKind::Exists(ctor_b)) => {
                infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
            }
            // An rhs constraint must be proven before relating the underlying types.
            (_, TyKind::Constr(pred_b, ty_b)) => {
                infcx.check_pred(pred_b, self.tag());
                self.tys(infcx, &a, ty_b)
            }

            // Look up the type at `path_a`, try to solve an evar in the rhs path from the
            // lhs path, then relate the pointee types.
            (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            // Same as above, but the lhs strong reference is unfolded in the env first.
            (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            (
                TyKind::Ptr(PtrKind::Mut(re), path),
                TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
            ) => {
                // Equate the reference's index with unit so that any evar in it gets solved.
                self.idxs_eq(infcx, &Expr::unit(), idx);

                // Hand the pointer to the environment with the reference's pointee type as
                // the bound; the returned type is not used here.
                self.env.ptr_to_ref(
                    &mut infcx.at(self.span),
                    self.reason,
                    *re,
                    path,
                    bound.clone(),
                )?;
                Ok(())
            }

            // Relate the base types, then require the indices to be equal.
            (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
                self.btys(infcx, bty_a, bty_b)?;
                self.idxs_eq(infcx, idx_a, idx_b);
                Ok(())
            }
            // Pointers and params must already agree; this is only checked in debug builds.
            (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
                debug_assert_eq!(pk_a, pk_b);
                debug_assert_eq!(path_a, path_b);
                Ok(())
            }
            (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
                debug_assert_eq!(param_ty_a, param_ty_b);
                Ok(())
            }
            // Anything can be related to an uninitialized type.
            (_, TyKind::Uninit) => Ok(()),
            // Downcasts are related field by field.
            (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
                debug_assert_eq!(fields_a.len(), fields_b.len());
                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
        }
    }

    /// Checks `a <: b` for base types.
    ///
    /// # Errors
    /// Returns a query bug for incompatible base types. Also propagates errors from
    /// nested relations.
    fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
        match (a, b) {
            (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
                debug_assert_eq!(int_ty_a, int_ty_b);
                Ok(())
            }
            (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
                debug_assert_eq!(uint_ty_a, uint_ty_b);
                Ok(())
            }
            // Generic arguments are related according to the ADT's declared variances.
            (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
                tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
                tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
                let variances = infcx.genv.variances_of(a_adt.did());
                for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
                    self.generic_args(infcx, *variance, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Function items generate no constraints; their arguments are only compared
            // (modulo refinements) in debug builds.
            (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
                debug_assert_eq!(a_def_id, b_def_id);
                debug_assert_eq!(a_args.len(), b_args.len());
                for (arg_a, arg_b) in iter::zip(a_args, b_args) {
                    match (arg_a, arg_b) {
                        (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
                            let bty_a = ty_a.as_bty_skipping_existentials();
                            let bty_b = ty_b.as_bty_skipping_existentials();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                            let bty_a = ctor_a.as_bty_skipping_binder();
                            let bty_b = ctor_b.as_bty_skipping_binder();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
                    }
                }
                Ok(())
            }
            (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
                debug_assert_eq!(float_ty_a, float_ty_b);
                Ok(())
            }
            (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),
            // Mutable references are invariant: relate both ways.
            (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
                self.tys(infcx, ty_a, ty_b)?;
                self.tys(infcx, ty_b, ty_a)
            }
            // Shared references are covariant.
            (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Opaque rhs: if the lhs is also opaque, first try to solve evars in the rhs
            // refinement args from the lhs ones; then check `a` against the opaque type.
            (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
                if let BaseTy::Alias(AliasKind::Opaque, alias_ty_a) = a {
                    debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
                    iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
                        .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));
                }
                self.handle_opaque_type(infcx, a, alias_ty_b)
            }
            (
                BaseTy::Alias(AliasKind::Projection, alias_ty_a),
                BaseTy::Alias(AliasKind::Projection, alias_ty_b),
            ) => {
                tracked_span_dbg_assert_eq!(alias_ty_a, alias_ty_b);
                Ok(())
            }
            (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
                tracked_span_dbg_assert_eq!(len_a, len_b);
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
                debug_assert_eq!(param_a, param_b);
                Ok(())
            }
            (BaseTy::Bool, BaseTy::Bool)
            | (BaseTy::Str, BaseTy::Str)
            | (BaseTy::Char, BaseTy::Char)
            | (BaseTy::RawPtr(_, _), BaseTy::RawPtr(_, _))
            | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
            // Trait objects must have the same predicates up to regions (always asserted).
            (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
                tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
                Ok(())
            }
            // Closures with the same def id are related on their captured types.
            (BaseTy::Closure(did1, tys_a, _), BaseTy::Closure(did2, tys_b, _)) if did1 == did2 => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Function pointers must have the same signature up to regions (always asserted).
            (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
                tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
                Ok(())
            }
            (BaseTy::Never, BaseTy::Never) => Ok(()),
            _ => Err(query_bug!("incompatible base types: `{a:?}` - `{b:?}`"))?,
        }
    }

    /// Relates two generic arguments under `variance`.
    ///
    /// Lifetimes are ignored, and consts are only compared in debug builds. Type and base
    /// arguments are converted to types and related covariantly, contravariantly, in both
    /// directions (invariant), or not at all (bivariant).
    fn generic_args(
        &mut self,
        infcx: &mut InferCtxt,
        variance: Variance,
        a: &GenericArg,
        b: &GenericArg,
    ) -> InferResult {
        let (ty_a, ty_b) = match (a, b) {
            (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
            (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                tracked_span_dbg_assert_eq!(ctor_a.sort(), ctor_b.sort());
                (ctor_a.to_ty(), ctor_b.to_ty())
            }
            (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
            (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
                debug_assert_eq!(cst_a, cst_b);
                return Ok(());
            }
            _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
        };
        match variance {
            Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
            Variance::Invariant => {
                self.tys(infcx, &ty_a, &ty_b)?;
                self.tys(infcx, &ty_b, &ty_a)
            }
            Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
            Variance::Bivariant => Ok(()),
        }
    }

    /// Emits constraints requiring the index expressions `a` and `b` to be equal.
    ///
    /// Tuples and structs are decomposed field-wise (projecting the other side when only
    /// one side is an aggregate), and lambdas are compared pointwise. Along the way the
    /// relation tries to solve evars in `b` from `a` with `unify_exprs`. An equality that
    /// involves a kvar is checked as implication in both directions. Anything else is
    /// checked as a plain `a == b`.
    fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
        if a == b {
            return;
        }
        match (a.kind(), b.kind()) {
            (
                ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
                ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
            ) => {
                debug_assert_eq!(did_a, did_b);
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Only the rhs is an aggregate: project the lhs field by field.
            (_, ExprKind::Tuple(flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Only the lhs is an aggregate: try to solve the rhs as a whole first, then
            // project it field by field.
            (ExprKind::Tuple(flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            // Lambdas: eta-expand whichever side is not already an abstraction.
            (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, lam_a, lam_b);
            }
            (_, ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
            }
            (ExprKind::Abs(lam_a), _) => {
                infcx.unify_exprs(a, b);
                self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
            }
            (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
                infcx.check_impl(a, b, self.tag());
                infcx.check_impl(b, a, self.tag());
            }
            _ => {
                infcx.unify_exprs(a, b);
                let span = b.span();
                infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
            }
        }
    }

    /// Compares two lambdas by applying both to the same freshly defined variables and
    /// equating the resulting bodies.
    fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
        debug_assert_eq!(a.vars().len(), b.vars().len());
        let vars = a
            .vars()
            .iter()
            .map(|kind| Expr::fvar(infcx.define_var(kind.expect_sort())))
            .collect_vec();
        let body_a = a.apply(&vars);
        let body_b = b.apply(&vars);
        self.idxs_eq(infcx, &body_a, &body_b);
    }

    /// Checks `bty` against the opaque type `alias_ty`.
    ///
    /// A coroutine records coroutine obligations for the opaque type. Any other base type
    /// must satisfy each projection bound of the opaque type: the projection, with `bty`
    /// as self type, is normalized and related to the bound's term.
    ///
    /// # Errors
    /// Returns a query bug for bounds with bound variables and propagates query,
    /// normalization and subtyping errors.
    fn handle_opaque_type(
        &mut self,
        infcx: &mut InferCtxt,
        bty: &BaseTy,
        alias_ty: &AliasTy,
    ) -> InferResult {
        if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys) = bty {
            let obligs = mk_coroutine_obligations(
                infcx.genv,
                def_id,
                resume_ty,
                upvar_tys,
                &alias_ty.def_id,
            )?;
            self.obligations.extend(obligs);
        } else {
            let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
                infcx.tcx(),
                &alias_ty.args,
                &alias_ty.refine_args,
            );
            for clause in &bounds {
                if !clause.kind().vars().is_empty() {
                    Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
                }
                if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
                    let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
                    let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
                        .to_ty()
                        .normalize_projections(&mut infcx.at(self.span))?;
                    let ty2 = pred.term.to_ty();
                    self.tys(infcx, &ty1, &ty2)?;
                }
            }
        }
        Ok(())
    }
}
1080
1081fn mk_coroutine_obligations(
1082 genv: GlobalEnv,
1083 generator_did: &DefId,
1084 resume_ty: &Ty,
1085 upvar_tys: &List<Ty>,
1086 opaque_def_id: &DefId,
1087) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1088 let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1089 for bound in &bounds {
1090 if let Some(proj_clause) = bound.as_projection_clause() {
1091 return Ok(vec![proj_clause.map(|proj_clause| {
1092 let output = proj_clause.term;
1093 CoroutineObligPredicate {
1094 def_id: *generator_did,
1095 resume_ty: resume_ty.clone(),
1096 upvar_tys: upvar_tys.clone(),
1097 output: output.to_ty(),
1098 }
1099 })]);
1100 }
1101 }
1102 bug!("no projection predicate")
1103}
1104
/// Errors produced during inference.
#[derive(Debug)]
pub enum InferErr {
    /// Produced when popping an evar scope fails (see `InferCtxt::pop_evar_scope`). Carries
    /// the evar id reported by the store.
    UnsolvedEvar(EVid),
    /// A query failed. Created through the `From<QueryErr>` conversion.
    Query(QueryErr),
}
1110
1111impl From<QueryErr> for InferErr {
1112 fn from(v: QueryErr) -> Self {
1113 Self::Query(v)
1114 }
1115}
1116
/// Pretty-printing for inference types.
mod pretty {
    use std::fmt;

    use flux_middle::pretty::*;

    use super::*;

    // Renders a tag as its reason followed by `at` and its source span.
    impl Pretty for Tag {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)
        }
    }

    // `Debug` for `Tag` goes through the `Pretty` impl above with a default context.
    impl_debug_with_default_cx!(Tag);
}