1use std::{cell::RefCell, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts, OverflowMode};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7 FixpointQueryKind,
8 def_id::MaybeExternId,
9 global_env::GlobalEnv,
10 queries::{QueryErr, QueryResult},
11 query_bug,
12 rty::{
13 self, AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate,
14 Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj, GenericArg, HoleKind, InferMode,
15 Lambda, List, Loc, Mutability, Name, Path, PolyVariant, PtrKind, RefineArgs, RefineArgsExt,
16 Region, Sort, Ty, TyCtor, TyKind, Var,
17 canonicalize::{Hoister, HoisterDelegate},
18 fold::TypeFoldable,
19 },
20};
21use itertools::{Itertools, izip};
22use rustc_hir::def_id::{DefId, LocalDefId};
23use rustc_macros::extension;
24use rustc_middle::{
25 mir::BasicBlock,
26 ty::{TyCtxt, Variance},
27};
28use rustc_span::Span;
29
30use crate::{
31 evars::{EVarState, EVarStore},
32 fixpoint_encoding::{FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen},
33 projections::NormalizeExt as _,
34 refine_tree::{Cursor, Marker, RefineTree, Scope},
35};
36
37pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
38
39#[derive(PartialEq, Eq, Clone, Copy, Hash)]
40pub struct Tag {
41 pub reason: ConstrReason,
42 pub src_span: Span,
43 pub dst_span: Option<ESpan>,
44}
45
46impl Tag {
47 pub fn new(reason: ConstrReason, span: Span) -> Self {
48 Self { reason, src_span: span, dst_span: None }
49 }
50
51 pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
52 Self { dst_span, ..self }
53 }
54}
55
/// Which part of a function signature a subtyping check originates from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SubtypeReason {
    /// An input (argument) position.
    Input,
    /// An output (return) position.
    Output,
    /// A `requires` clause.
    Requires,
    /// An `ensures` clause.
    Ensures,
}
63
64#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
65pub enum ConstrReason {
66 Call,
67 Assign,
68 Ret,
69 Fold,
70 FoldLocal,
71 Predicate,
72 Assert(&'static str),
73 Div,
74 Rem,
75 Goto(BasicBlock),
76 Overflow,
77 Underflow,
78 Subtype(SubtypeReason),
79 Other,
80}
81
82pub struct InferCtxtRoot<'genv, 'tcx> {
83 pub genv: GlobalEnv<'genv, 'tcx>,
84 inner: RefCell<InferCtxtInner>,
85 refine_tree: RefineTree,
86 opts: InferOpts,
87}
88
89pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
90 genv: GlobalEnv<'genv, 'tcx>,
91 opts: InferOpts,
92 params: Vec<(Var, Sort)>,
93 infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
94 dummy_kvars: bool,
95}
96
97#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
98impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
99 fn infcx_root<'a>(
100 self,
101 infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
102 opts: InferOpts,
103 ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
104 InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
105 }
106}
107
108impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
109 pub fn with_dummy_kvars(mut self) -> Self {
110 self.dummy_kvars = true;
111 self
112 }
113
114 pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
115 self.params.extend(
116 self.genv
117 .generics_of(def_id)?
118 .const_params(self.genv)?
119 .into_iter()
120 .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
121 );
122 Ok(self)
123 }
124
125 pub fn with_refinement_generics(
126 mut self,
127 def_id: DefId,
128 args: &[GenericArg],
129 ) -> QueryResult<Self> {
130 for (index, param) in self
131 .genv
132 .refinement_generics_of(def_id)?
133 .iter_own_params()
134 .enumerate()
135 {
136 let param = param.instantiate(self.genv.tcx(), args, &[]);
137 let sort = param
138 .sort
139 .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
140
141 let var =
142 Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
143 self.params.push((var, sort));
144 }
145 Ok(self)
146 }
147
148 pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
149 self = self.with_const_generics(def_id)?;
150 let offset = self.params.len();
151 self.genv.refinement_generics_of(def_id)?.fill_item(
152 self.genv,
153 &mut self.params,
154 &mut |param, index| {
155 let index = (index - offset) as u32;
156 let param = param.instantiate_identity();
157 let sort = param
158 .sort
159 .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
160
161 let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
162 Ok((var, sort))
163 },
164 )?;
165 Ok(self)
166 }
167
168 pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
169 Ok(InferCtxtRoot {
170 genv: self.genv,
171 inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
172 refine_tree: RefineTree::new(self.params),
173 opts: self.opts,
174 })
175 }
176}
177
178impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
179 pub fn infcx<'a>(
180 &'a mut self,
181 def_id: DefId,
182 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
183 ) -> InferCtxt<'a, 'genv, 'tcx> {
184 InferCtxt {
185 genv: self.genv,
186 region_infcx,
187 def_id,
188 cursor: self.refine_tree.cursor_at_root(),
189 inner: &self.inner,
190 check_overflow: self.opts.check_overflow,
191 }
192 }
193
194 pub fn fresh_kvar_in_scope(
195 &self,
196 binders: &[BoundVariableKinds],
197 scope: &Scope,
198 encoding: KVarEncoding,
199 ) -> Expr {
200 let inner = &mut *self.inner.borrow_mut();
201 inner.kvars.fresh(binders, scope.iter(), encoding)
202 }
203
204 pub fn execute_lean_query(self, def_id: MaybeExternId) -> QueryResult<()> {
205 let inner = self.inner.into_inner();
206 let kvars = inner.kvars;
207 let evars = inner.evars;
208 let mut refine_tree = self.refine_tree;
209 refine_tree.replace_evars(&evars).unwrap();
210 refine_tree.simplify(self.genv);
211
212 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars);
213 let cstr = refine_tree.into_fixpoint(&mut fcx)?;
214 fcx.generate_and_check_lean_lemmas(cstr)
215 }
216
217 pub fn execute_fixpoint_query(
218 self,
219 cache: &mut FixQueryCache,
220 def_id: MaybeExternId,
221 kind: FixpointQueryKind,
222 ) -> QueryResult<Vec<Tag>> {
223 let inner = self.inner.into_inner();
224 let kvars = inner.kvars;
225 let evars = inner.evars;
226
227 let ext = kind.ext();
228
229 let mut refine_tree = self.refine_tree;
230
231 refine_tree.replace_evars(&evars).unwrap();
232
233 if config::dump_constraint() {
234 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
235 }
236 refine_tree.simplify(self.genv);
237 if config::dump_constraint() {
238 let simp_ext = format!("simp.{ext}");
239 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
240 .unwrap();
241 }
242
243 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars);
244 let cstr = refine_tree.into_fixpoint(&mut fcx)?;
245
246 let backend = match self.opts.solver {
247 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
248 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
249 };
250
251 fcx.check(cache, cstr, kind, self.opts.scrape_quals, backend)
252 }
253
254 pub fn split(self) -> (RefineTree, KVarGen) {
255 (self.refine_tree, self.inner.into_inner().kvars)
256 }
257}
258
259pub struct InferCtxt<'infcx, 'genv, 'tcx> {
260 pub genv: GlobalEnv<'genv, 'tcx>,
261 pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
262 pub def_id: DefId,
263 pub check_overflow: OverflowMode,
264 cursor: Cursor<'infcx>,
265 inner: &'infcx RefCell<InferCtxtInner>,
266}
267
268struct InferCtxtInner {
269 kvars: KVarGen,
270 evars: EVarStore,
271}
272
273impl InferCtxtInner {
274 fn new(dummy_kvars: bool) -> Self {
275 Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
276 }
277}
278
279impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
280 pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
281 InferCtxtAt { infcx: self, span }
282 }
283
284 pub fn instantiate_refine_args(
285 &mut self,
286 callee_def_id: DefId,
287 args: &[rty::GenericArg],
288 ) -> InferResult<List<Expr>> {
289 Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
290 let param = param.instantiate(self.genv.tcx(), args, &[]);
291 Ok(self.fresh_infer_var(¶m.sort, param.mode))
292 })?)
293 }
294
295 pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
296 args.iter()
297 .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
298 .collect_vec()
299 }
300
301 pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
302 match mode {
303 InferMode::KVar => {
304 let fsort = sort.expect_func().expect_mono();
305 let vars = fsort.inputs().iter().cloned().map_into().collect();
306 let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
307 Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
308 }
309 InferMode::EVar => self.fresh_evar(),
310 }
311 }
312
313 pub fn fresh_infer_var_for_hole(
314 &mut self,
315 binders: &[BoundVariableKinds],
316 kind: HoleKind,
317 ) -> Expr {
318 match kind {
319 HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
320 HoleKind::Expr(_) => {
321 self.fresh_evar()
325 }
326 }
327 }
328
329 pub fn fresh_kvar_in_scope(
331 &self,
332 binders: &[BoundVariableKinds],
333 scope: &Scope,
334 encoding: KVarEncoding,
335 ) -> Expr {
336 let inner = &mut *self.inner.borrow_mut();
337 inner.kvars.fresh(binders, scope.iter(), encoding)
338 }
339
340 pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
342 let inner = &mut *self.inner.borrow_mut();
343 inner.kvars.fresh(binders, self.cursor.vars(), encoding)
344 }
345
346 fn fresh_evar(&self) -> Expr {
347 let evars = &mut self.inner.borrow_mut().evars;
348 Expr::evar(evars.fresh(self.cursor.marker()))
349 }
350
351 pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
352 if a.has_evars() {
353 return;
354 }
355 let evars = &mut self.inner.borrow_mut().evars;
356 if let ExprKind::Var(Var::EVar(evid)) = b.kind()
357 && let EVarState::Unsolved(marker) = evars.get(*evid)
358 && !marker.has_free_vars(a)
359 {
360 evars.solve(*evid, a.clone());
361 }
362 }
363
364 fn enter_exists<T, U>(
365 &mut self,
366 t: &Binder<T>,
367 f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
368 ) -> U
369 where
370 T: TypeFoldable,
371 {
372 self.ensure_resolved_evars(|infcx| {
373 let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
374 Ok(f(infcx, t))
375 })
376 .unwrap()
377 }
378
379 pub fn push_evar_scope(&mut self) {
384 self.inner.borrow_mut().evars.push_scope();
385 }
386
387 pub fn pop_evar_scope(&mut self) -> InferResult {
390 self.inner
391 .borrow_mut()
392 .evars
393 .pop_scope()
394 .map_err(InferErr::UnsolvedEvar)
395 }
396
397 pub fn ensure_resolved_evars<R>(
399 &mut self,
400 f: impl FnOnce(&mut Self) -> InferResult<R>,
401 ) -> InferResult<R> {
402 self.push_evar_scope();
403 let r = f(self)?;
404 self.pop_evar_scope()?;
405 Ok(r)
406 }
407
408 pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
409 self.inner.borrow().evars.replace_evars(t).unwrap()
410 }
411
412 pub fn tcx(&self) -> TyCtxt<'tcx> {
413 self.genv.tcx()
414 }
415
416 pub fn cursor(&self) -> &Cursor<'infcx> {
417 &self.cursor
418 }
419}
420
421impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
423 pub fn change_item<'a>(
424 &'a mut self,
425 def_id: LocalDefId,
426 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
427 ) -> InferCtxt<'a, 'genv, 'tcx> {
428 InferCtxt {
429 def_id: def_id.to_def_id(),
430 cursor: self.cursor.branch(),
431 region_infcx,
432 ..*self
433 }
434 }
435
436 pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
437 InferCtxt {
438 cursor: self
439 .cursor
440 .move_to(marker, clear_children)
441 .unwrap_or_else(|| tracked_span_bug!()),
442 ..*self
443 }
444 }
445
446 pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
447 InferCtxt { cursor: self.cursor.branch(), ..*self }
448 }
449
450 pub fn define_var(&mut self, sort: &Sort) -> Name {
451 self.cursor.define_var(sort)
452 }
453
454 pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
455 self.cursor.check_pred(pred, tag);
456 }
457
458 pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
459 self.cursor.assume_pred(pred);
460 }
461
462 pub fn unpack(&mut self, ty: &Ty) -> Ty {
463 self.hoister(false).hoist(ty)
464 }
465
466 pub fn marker(&self) -> Marker {
467 self.cursor.marker()
468 }
469
470 pub fn hoister(
471 &mut self,
472 assume_invariants: bool,
473 ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
474 Hoister::with_delegate(Unpacker { infcx: self, assume_invariants }).transparent()
475 }
476
477 pub fn assume_invariants(&mut self, ty: &Ty) {
478 self.cursor
479 .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
480 }
481
482 fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
483 self.cursor.check_impl(pred1, pred2, tag);
484 }
485}
486
487pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
488 infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
489 assume_invariants: bool,
490}
491
492impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
493 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
494 let ty =
495 ty_ctor.replace_bound_refts_with(|sort, _, _| Expr::fvar(self.infcx.define_var(sort)));
496 if self.assume_invariants {
497 self.infcx.assume_invariants(&ty);
498 }
499 ty
500 }
501
502 fn hoist_constr(&mut self, pred: Expr) {
503 self.infcx.assume_pred(pred);
504 }
505}
506
507impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
508 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509 std::fmt::Debug::fmt(&self.cursor, f)
510 }
511}
512
513#[derive(Debug)]
514pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
515 pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
516 pub span: Span,
517}
518
519impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
520 fn tag(&self, reason: ConstrReason) -> Tag {
521 Tag::new(reason, self.span)
522 }
523
524 pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
525 let tag = self.tag(reason);
526 self.infcx.check_pred(pred, tag);
527 }
528
529 pub fn check_non_closure_clauses(
530 &mut self,
531 clauses: &[rty::Clause],
532 reason: ConstrReason,
533 ) -> InferResult {
534 for clause in clauses {
535 if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
536 let impl_elem = BaseTy::projection(projection_pred.projection_ty)
537 .to_ty()
538 .deeply_normalize(self)?;
539 let term = projection_pred.term.to_ty().deeply_normalize(self)?;
540
541 self.subtyping(&impl_elem, &term, reason)?;
543 self.subtyping(&term, &impl_elem, reason)?;
544 }
545 }
546 Ok(())
547 }
548
549 pub fn subtyping_with_env(
552 &mut self,
553 env: &mut impl LocEnv,
554 a: &Ty,
555 b: &Ty,
556 reason: ConstrReason,
557 ) -> InferResult {
558 let mut sub = Sub::new(env, reason, self.span);
559 sub.tys(self.infcx, a, b)
560 }
561
562 pub fn subtyping(
567 &mut self,
568 a: &Ty,
569 b: &Ty,
570 reason: ConstrReason,
571 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
572 let mut env = DummyEnv;
573 let mut sub = Sub::new(&mut env, reason, self.span);
574 sub.tys(self.infcx, a, b)?;
575 Ok(sub.obligations)
576 }
577
578 pub fn subtyping_generic_args(
579 &mut self,
580 variance: Variance,
581 a: &GenericArg,
582 b: &GenericArg,
583 reason: ConstrReason,
584 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
585 let mut env = DummyEnv;
586 let mut sub = Sub::new(&mut env, reason, self.span);
587 sub.generic_args(self.infcx, variance, a, b)?;
588 Ok(sub.obligations)
589 }
590
591 pub fn check_constructor(
595 &mut self,
596 variant: EarlyBinder<PolyVariant>,
597 generic_args: &[GenericArg],
598 fields: &[Ty],
599 reason: ConstrReason,
600 ) -> InferResult<Ty> {
601 let ret = self.ensure_resolved_evars(|this| {
602 let generic_args = this.instantiate_generic_args(generic_args);
604
605 let variant = variant
606 .instantiate(this.tcx(), &generic_args, &[])
607 .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
608
609 for (actual, formal) in iter::zip(fields, variant.fields()) {
611 this.subtyping(actual, formal, reason)?;
612 }
613
614 for require in &variant.requires {
616 this.check_pred(require, ConstrReason::Fold);
617 }
618
619 Ok(variant.ret())
620 })?;
621 Ok(self.fully_resolve_evars(&ret))
622 }
623
624 pub fn ensure_resolved_evars<R>(
625 &mut self,
626 f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
627 ) -> InferResult<R> {
628 self.infcx
629 .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
630 }
631}
632
633impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
634 type Target = InferCtxt<'a, 'genv, 'tcx>;
635
636 fn deref(&self) -> &Self::Target {
637 self.infcx
638 }
639}
640
641impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
642 fn deref_mut(&mut self) -> &mut Self::Target {
643 self.infcx
644 }
645}
646
647#[derive(TypeVisitable, TypeFoldable)]
651pub(crate) enum TypeTrace {
652 Types(Ty, Ty),
653 BaseTys(BaseTy, BaseTy),
654}
655
656#[expect(dead_code, reason = "we use this for debugging some time")]
657impl TypeTrace {
658 fn tys(a: &Ty, b: &Ty) -> Self {
659 Self::Types(a.clone(), b.clone())
660 }
661
662 fn btys(a: &BaseTy, b: &BaseTy) -> Self {
663 Self::BaseTys(a.clone(), b.clone())
664 }
665}
666
667impl fmt::Debug for TypeTrace {
668 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
669 match self {
670 TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
671 TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
672 }
673 }
674}
675
676pub trait LocEnv {
677 fn ptr_to_ref(
678 &mut self,
679 infcx: &mut InferCtxtAt,
680 reason: ConstrReason,
681 re: Region,
682 path: &Path,
683 bound: Ty,
684 ) -> InferResult<Ty>;
685
686 fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;
687
688 fn get(&self, path: &Path) -> Ty;
689}
690
691struct DummyEnv;
692
693impl LocEnv for DummyEnv {
694 fn ptr_to_ref(
695 &mut self,
696 _: &mut InferCtxtAt,
697 _: ConstrReason,
698 _: Region,
699 _: &Path,
700 _: Ty,
701 ) -> InferResult<Ty> {
702 bug!("call to `ptr_to_ref` on `DummyEnv`")
703 }
704
705 fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
706 bug!("call to `unfold_str_ref` on `DummyEnv`")
707 }
708
709 fn get(&self, _: &Path) -> Ty {
710 bug!("call to `get` on `DummyEnv`")
711 }
712}
713
714struct Sub<'a, E> {
716 env: &'a mut E,
718 reason: ConstrReason,
719 span: Span,
720 obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
724}
725
726impl<'a, E: LocEnv> Sub<'a, E> {
727 fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
728 Self { env, reason, span, obligations: vec![] }
729 }
730
731 fn tag(&self) -> Tag {
732 Tag::new(self.reason, self.span)
733 }
734
735 fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
736 let infcx = &mut infcx.branch();
737 let a = infcx.unpack(a);
743
744 match (a.kind(), b.kind()) {
745 (TyKind::Exists(..), _) => {
746 bug!("existentials should have been removed by the unpacking above");
747 }
748 (TyKind::Constr(..), _) => {
749 bug!("constraint types should have been removed by the unpacking above");
750 }
751
752 (_, TyKind::Exists(ctor_b)) => {
753 infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
754 }
755 (_, TyKind::Constr(pred_b, ty_b)) => {
756 infcx.check_pred(pred_b, self.tag());
757 self.tys(infcx, &a, ty_b)
758 }
759
760 (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
761 let ty_a = self.env.get(path_a);
765 infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
766 self.tys(infcx, &ty_a, ty_b)
767 }
768 (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
769 self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
781 let ty_a = self.env.get(path_a);
782 infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
783 self.tys(infcx, &ty_a, ty_b)
784 }
785 (
786 TyKind::Ptr(PtrKind::Mut(re), path),
787 TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
788 ) => {
789 self.idxs_eq(infcx, &Expr::unit(), idx);
792
793 self.env.ptr_to_ref(
794 &mut infcx.at(self.span),
795 self.reason,
796 *re,
797 path,
798 bound.clone(),
799 )?;
800 Ok(())
801 }
802
803 (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
804 self.btys(infcx, bty_a, bty_b)?;
805 self.idxs_eq(infcx, idx_a, idx_b);
806 Ok(())
807 }
808 (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
809 debug_assert_eq!(pk_a, pk_b);
810 debug_assert_eq!(path_a, path_b);
811 Ok(())
812 }
813 (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
814 debug_assert_eq!(param_ty_a, param_ty_b);
815 Ok(())
816 }
817 (_, TyKind::Uninit) => Ok(()),
818 (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
819 debug_assert_eq!(fields_a.len(), fields_b.len());
820 for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
821 self.tys(infcx, ty_a, ty_b)?;
822 }
823 Ok(())
824 }
825 _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
826 }
827 }
828
829 fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
830 match (a, b) {
833 (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
834 debug_assert_eq!(int_ty_a, int_ty_b);
835 Ok(())
836 }
837 (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
838 debug_assert_eq!(uint_ty_a, uint_ty_b);
839 Ok(())
840 }
841 (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
842 tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
843 tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
844 let variances = infcx.genv.variances_of(a_adt.did());
845 for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
846 self.generic_args(infcx, *variance, ty_a, ty_b)?;
847 }
848 Ok(())
849 }
850 (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
851 debug_assert_eq!(a_def_id, b_def_id);
852 debug_assert_eq!(a_args.len(), b_args.len());
853 for (arg_a, arg_b) in iter::zip(a_args, b_args) {
861 match (arg_a, arg_b) {
862 (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
863 let bty_a = ty_a.as_bty_skipping_existentials();
864 let bty_b = ty_b.as_bty_skipping_existentials();
865 tracked_span_dbg_assert_eq!(bty_a, bty_b);
866 }
867 (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
868 let bty_a = ctor_a.as_bty_skipping_binder();
869 let bty_b = ctor_b.as_bty_skipping_binder();
870 tracked_span_dbg_assert_eq!(bty_a, bty_b);
871 }
872 (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
873 }
874 }
875 Ok(())
876 }
877 (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
878 debug_assert_eq!(float_ty_a, float_ty_b);
879 Ok(())
880 }
881 (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),
882 (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
883 if ty_a.is_slice()
884 && let TyKind::Indexed(_, idx_a) = ty_a.kind()
885 && let TyKind::Exists(bty_b) = ty_b.kind()
886 {
887 self.tys(infcx, ty_a, ty_b)?;
892 self.tys(infcx, &bty_b.replace_bound_reft(idx_a), ty_a)
893 } else {
894 self.tys(infcx, ty_a, ty_b)?;
895 self.tys(infcx, ty_b, ty_a)
896 }
897 }
898 (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
899 self.tys(infcx, ty_a, ty_b)
900 }
901 (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
902 debug_assert_eq!(tys_a.len(), tys_b.len());
903 for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
904 self.tys(infcx, ty_a, ty_b)?;
905 }
906 Ok(())
907 }
908 (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
909 if let BaseTy::Alias(AliasKind::Opaque, alias_ty_a) = a {
910 debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
911 iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
912 .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));
913 }
914 self.handle_opaque_type(infcx, a, alias_ty_b)
915 }
916 (
917 BaseTy::Alias(AliasKind::Projection, alias_ty_a),
918 BaseTy::Alias(AliasKind::Projection, alias_ty_b),
919 ) => {
920 tracked_span_dbg_assert_eq!(alias_ty_a, alias_ty_b);
921 Ok(())
922 }
923 (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
924 tracked_span_dbg_assert_eq!(len_a, len_b);
925 self.tys(infcx, ty_a, ty_b)
926 }
927 (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
928 debug_assert_eq!(param_a, param_b);
929 Ok(())
930 }
931 (BaseTy::Bool, BaseTy::Bool)
932 | (BaseTy::Str, BaseTy::Str)
933 | (BaseTy::Char, BaseTy::Char)
934 | (BaseTy::RawPtr(_, _), BaseTy::RawPtr(_, _))
935 | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
936 (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
937 tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
938 Ok(())
939 }
940 (BaseTy::Closure(did1, tys_a, _), BaseTy::Closure(did2, tys_b, _)) if did1 == did2 => {
941 debug_assert_eq!(tys_a.len(), tys_b.len());
942 for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
943 self.tys(infcx, ty_a, ty_b)?;
944 }
945 Ok(())
946 }
947 (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
948 tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
949 Ok(())
950 }
951 (BaseTy::Never, BaseTy::Never) => Ok(()),
952 _ => Err(query_bug!("incompatible base types: `{a:?}` - `{b:?}`"))?,
953 }
954 }
955
956 fn generic_args(
957 &mut self,
958 infcx: &mut InferCtxt,
959 variance: Variance,
960 a: &GenericArg,
961 b: &GenericArg,
962 ) -> InferResult {
963 let (ty_a, ty_b) = match (a, b) {
964 (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
965 (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
966 tracked_span_dbg_assert_eq!(ctor_a.sort(), ctor_b.sort());
967 (ctor_a.to_ty(), ctor_b.to_ty())
968 }
969 (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
970 (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
971 debug_assert_eq!(cst_a, cst_b);
972 return Ok(());
973 }
974 _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
975 };
976 match variance {
977 Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
978 Variance::Invariant => {
979 self.tys(infcx, &ty_a, &ty_b)?;
980 self.tys(infcx, &ty_b, &ty_a)
981 }
982 Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
983 Variance::Bivariant => Ok(()),
984 }
985 }
986
987 fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
988 if a == b {
989 return;
990 }
991 match (a.kind(), b.kind()) {
992 (
993 ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
994 ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
995 ) => {
996 debug_assert_eq!(did_a, did_b);
997 for (a, b) in iter::zip(flds_a, flds_b) {
998 self.idxs_eq(infcx, a, b);
999 }
1000 }
1001 (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
1002 for (a, b) in iter::zip(flds_a, flds_b) {
1003 self.idxs_eq(infcx, a, b);
1004 }
1005 }
1006 (_, ExprKind::Tuple(flds_b)) => {
1007 for (f, b) in flds_b.iter().enumerate() {
1008 let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
1009 let a = a.proj_and_reduce(proj);
1010 self.idxs_eq(infcx, &a, b);
1011 }
1012 }
1013
1014 (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
1015 for (f, b) in flds_b.iter().enumerate() {
1016 let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
1017 let a = a.proj_and_reduce(proj);
1018 self.idxs_eq(infcx, &a, b);
1019 }
1020 }
1021
1022 (ExprKind::Tuple(flds_a), _) => {
1023 infcx.unify_exprs(a, b);
1024 for (f, a) in flds_a.iter().enumerate() {
1025 let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
1026 let b = b.proj_and_reduce(proj);
1027 self.idxs_eq(infcx, a, &b);
1028 }
1029 }
1030 (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
1031 infcx.unify_exprs(a, b);
1032 for (f, a) in flds_a.iter().enumerate() {
1033 let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
1034 let b = b.proj_and_reduce(proj);
1035 self.idxs_eq(infcx, a, &b);
1036 }
1037 }
1038 (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
1039 self.abs_eq(infcx, lam_a, lam_b);
1040 }
1041 (_, ExprKind::Abs(lam_b)) => {
1042 self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
1043 }
1044 (ExprKind::Abs(lam_a), _) => {
1045 infcx.unify_exprs(a, b);
1046 self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
1047 }
1048 (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
1049 infcx.check_impl(a, b, self.tag());
1050 infcx.check_impl(b, a, self.tag());
1051 }
1052 _ => {
1053 infcx.unify_exprs(a, b);
1054 let span = b.span();
1055 infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
1056 }
1057 }
1058 }
1059
1060 fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
1061 debug_assert_eq!(a.vars().len(), b.vars().len());
1062 let vars = a
1063 .vars()
1064 .iter()
1065 .map(|kind| Expr::fvar(infcx.define_var(kind.expect_sort())))
1066 .collect_vec();
1067 let body_a = a.apply(&vars);
1068 let body_b = b.apply(&vars);
1069 self.idxs_eq(infcx, &body_a, &body_b);
1070 }
1071
1072 fn handle_opaque_type(
1073 &mut self,
1074 infcx: &mut InferCtxt,
1075 bty: &BaseTy,
1076 alias_ty: &AliasTy,
1077 ) -> InferResult {
1078 if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys) = bty {
1079 let obligs = mk_coroutine_obligations(
1080 infcx.genv,
1081 def_id,
1082 resume_ty,
1083 upvar_tys,
1084 &alias_ty.def_id,
1085 )?;
1086 self.obligations.extend(obligs);
1087 } else {
1088 let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
1089 infcx.tcx(),
1090 &alias_ty.args,
1091 &alias_ty.refine_args,
1092 );
1093 for clause in &bounds {
1094 if !clause.kind().vars().is_empty() {
1095 Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
1096 }
1097 if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
1098 let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
1099 let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
1100 .to_ty()
1101 .deeply_normalize(&mut infcx.at(self.span))?;
1102 let ty2 = pred.term.to_ty();
1103 self.tys(infcx, &ty1, &ty2)?;
1104 }
1105 }
1106 }
1107 Ok(())
1108 }
1109}
1110
1111fn mk_coroutine_obligations(
1112 genv: GlobalEnv,
1113 generator_did: &DefId,
1114 resume_ty: &Ty,
1115 upvar_tys: &List<Ty>,
1116 opaque_def_id: &DefId,
1117) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1118 let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1119 for bound in &bounds {
1120 if let Some(proj_clause) = bound.as_projection_clause() {
1121 return Ok(vec![proj_clause.map(|proj_clause| {
1122 let output = proj_clause.term;
1123 CoroutineObligPredicate {
1124 def_id: *generator_did,
1125 resume_ty: resume_ty.clone(),
1126 upvar_tys: upvar_tys.clone(),
1127 output: output.to_ty(),
1128 }
1129 })]);
1130 }
1131 }
1132 bug!("no projection predicate")
1133}
1134
1135#[derive(Debug)]
1136pub enum InferErr {
1137 UnsolvedEvar(EVid),
1138 Query(QueryErr),
1139}
1140
1141impl From<QueryErr> for InferErr {
1142 fn from(v: QueryErr) -> Self {
1143 Self::Query(v)
1144 }
1145}
1146
1147mod pretty {
1148 use std::fmt;
1149
1150 use flux_middle::pretty::*;
1151
1152 use super::*;
1153
1154 impl Pretty for Tag {
1155 fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1156 w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)?;
1157 if let Some(dst_span) = self.dst_span {
1158 w!(cx, f, " ({:?})", ^dst_span)?;
1159 }
1160 Ok(())
1161 }
1162 }
1163
1164 impl_debug_with_default_cx!(Tag);
1165}