1use std::{cell::RefCell, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts, OverflowMode, RawDerefMode};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7 FixpointQueryKind, PanicSpec,
8 def_id::MaybeExternId,
9 global_env::GlobalEnv,
10 metrics::{self, Metric},
11 queries::{QueryErr, QueryResult},
12 query_bug,
13 rty::{
14 self, AliasKind, AliasTy, BaseTy, Binder, BoundReftKind, BoundVariableKinds,
15 CoroutineObligPredicate, Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj,
16 GenericArg, HoleKind, InferMode, Lambda, List, Loc, Mutability, Name, NameProvenance, Path,
17 PolyVariant, PtrKind, RefineArgs, RefineArgsExt, Region, Sort, Ty, TyCtor, TyKind, Var,
18 canonicalize::{Hoister, HoisterDelegate},
19 fold::TypeFoldable,
20 },
21};
22use itertools::{Itertools, izip};
23use rustc_hir::def_id::{DefId, LocalDefId};
24use rustc_macros::extension;
25use rustc_middle::{
26 mir::BasicBlock,
27 ty::{TyCtxt, Variance},
28};
29use rustc_span::{Span, Symbol};
30use rustc_type_ir::Variance::Invariant;
31
32use crate::{
33 evars::{EVarState, EVarStore},
34 fixpoint_encoding::{
35 Answer, Backend, FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen, lean_task_key,
36 },
37 lean_encoding::log_proof,
38 projections::NormalizeExt as _,
39 refine_tree::{Cursor, Marker, RefineTree, Scope},
40};
41
42pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
43
/// Metadata attached to every constraint pushed into the refinement tree: why the constraint
/// was generated and the source location(s) it should be reported at.
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct Tag {
    /// Why the constraint was generated.
    pub reason: ConstrReason,
    /// Span of the code that produced the constraint.
    pub src_span: Span,
    /// Optional secondary span; `None` when created through [`Tag::new`] and only set via
    /// [`Tag::with_dst`].
    pub dst_span: Option<ESpan>,
}
50
51impl Tag {
52 pub fn new(reason: ConstrReason, span: Span) -> Self {
53 Self { reason, src_span: span, dst_span: None }
54 }
55
56 pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
57 Self { dst_span, ..self }
58 }
59}
60
/// Finer-grained reason carried by [`ConstrReason::Subtype`], distinguishing which part of a
/// signature a subtyping obligation came from.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum SubtypeReason {
    Input,
    Output,
    Requires,
    Ensures,
}
68
/// The reason a constraint was generated; stored in every [`Tag`] and used when reporting
/// failed constraints.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum ConstrReason {
    Call,
    Assign,
    Ret,
    // Used, e.g., for the `requires` clauses checked in `InferCtxtAt::check_constructor`.
    Fold,
    FoldLocal,
    Predicate,
    // Carries a static message describing the assertion.
    Assert(&'static str),
    Div,
    Rem,
    // Carries the MIR basic block being jumped to.
    Goto(BasicBlock),
    Overflow,
    Underflow,
    // Subtyping obligation; the inner value says which part of a signature it came from.
    Subtype(SubtypeReason),
    NoPanic(DefId, PanicSpec),
    Other,
}
87
/// Root of an inference session.
///
/// It owns the refinement tree and the shared kvar/evar state. Individual [`InferCtxt`]s are
/// created from it (see [`InferCtxtRoot::infcx`]). Once constraint generation is done, the root
/// is consumed to run a fixpoint or Lean query.
pub struct InferCtxtRoot<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    // Kvar generator and evar store, shared (by reference) with every `InferCtxt` created here.
    inner: RefCell<InferCtxtInner>,
    // Tree of constraints; `InferCtxt`s hold cursors into it.
    refine_tree: RefineTree,
    // Options used for overflow checking, raw derefs, solver selection, and qualifier scraping.
    opts: InferOpts,
}
94
/// Builder for [`InferCtxtRoot`]; obtained through `GlobalEnvExt::infcx_root`.
pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    opts: InferOpts,
    // Variables (const generics / early-bound refinement params) that become the initial
    // parameters of the refinement tree.
    params: Vec<(Var, Sort)>,
    // Rustc inference context used to normalize parameter sorts.
    infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
    // Forwarded to `KVarGen::new` when building the root.
    dummy_kvars: bool,
}
102
103#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
104impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
105 fn infcx_root<'a>(
106 self,
107 infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
108 opts: InferOpts,
109 ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
110 InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
111 }
112}
113
114impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
115 pub fn with_dummy_kvars(mut self) -> Self {
116 self.dummy_kvars = true;
117 self
118 }
119
120 pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
121 self.params.extend(
122 self.genv
123 .generics_of(def_id)?
124 .const_params(self.genv)?
125 .into_iter()
126 .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
127 );
128 Ok(self)
129 }
130
131 pub fn with_refinement_generics(
132 mut self,
133 def_id: DefId,
134 args: &[GenericArg],
135 ) -> QueryResult<Self> {
136 for (index, param) in self
137 .genv
138 .refinement_generics_of(def_id)?
139 .iter_own_params()
140 .enumerate()
141 {
142 let param = param.instantiate(self.genv.tcx(), args, &[]);
143 let sort = param
144 .sort
145 .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
146
147 let var =
148 Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
149 self.params.push((var, sort));
150 }
151 Ok(self)
152 }
153
154 pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
155 self = self.with_const_generics(def_id)?;
156 let offset = self.params.len();
157 self.genv.refinement_generics_of(def_id)?.fill_item(
158 self.genv,
159 &mut self.params,
160 &mut |param, index| {
161 let index = (index - offset) as u32;
162 let param = param.instantiate_identity();
163 let sort = param
164 .sort
165 .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
166
167 let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
168 Ok((var, sort))
169 },
170 )?;
171 Ok(self)
172 }
173
174 pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
175 Ok(InferCtxtRoot {
176 genv: self.genv,
177 inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
178 refine_tree: RefineTree::new(self.params),
179 opts: self.opts,
180 })
181 }
182}
183
184impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
185 pub fn infcx<'a>(
186 &'a mut self,
187 def_id: DefId,
188 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
189 ) -> InferCtxt<'a, 'genv, 'tcx> {
190 InferCtxt {
191 genv: self.genv,
192 region_infcx,
193 def_id,
194 cursor: self.refine_tree.cursor_at_root(),
195 inner: &self.inner,
196 check_overflow: self.opts.check_overflow,
197 allow_raw_deref: self.opts.allow_raw_deref,
198 }
199 }
200
201 pub fn fresh_kvar_in_scope(
202 &self,
203 binders: &[BoundVariableKinds],
204 scope: &Scope,
205 encoding: KVarEncoding,
206 ) -> Expr {
207 let inner = &mut *self.inner.borrow_mut();
208 inner.kvars.fresh(binders, scope.iter(), encoding)
209 }
210
211 pub fn execute_lean_query(
212 self,
213 cache: &mut FixQueryCache,
214 def_id: MaybeExternId,
215 ) -> QueryResult {
216 let inner = self.inner.into_inner();
217 let kvars = inner.kvars;
218 let evars = inner.evars;
219 let mut refine_tree = self.refine_tree;
220 refine_tree.replace_evars(&evars).unwrap();
221 refine_tree.simplify(self.genv);
222
223 let solver = match self.opts.solver {
224 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
225 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
226 };
227 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Lean);
228 let cstr = refine_tree.to_fixpoint(&mut fcx)?;
229 let task = fcx.create_task(def_id, cstr, self.opts.scrape_quals, solver)?;
230
231 log_proof(self.genv, def_id)?;
232 if config::is_cache_enabled() {
234 let key = lean_task_key(self.genv.tcx(), def_id.resolved_id());
235 let hash = task.hash_with_default();
236 if cache.lookup(&key, hash).is_some() {
237 return Ok(());
238 }
239 }
240
241 fcx.generate_lean_files(def_id, task)
242 }
243
244 pub fn execute_fixpoint_query(
245 self,
246 cache: &mut FixQueryCache,
247 def_id: MaybeExternId,
248 kind: FixpointQueryKind,
249 ) -> QueryResult<Answer<Tag>> {
250 let inner = self.inner.into_inner();
251 let kvars = inner.kvars;
252 let evars = inner.evars;
253
254 let ext = kind.ext();
255
256 let mut refine_tree = self.refine_tree;
257
258 refine_tree.replace_evars(&evars).unwrap();
259
260 if config::dump_constraint() {
261 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
262 }
263 refine_tree.simplify(self.genv);
264 if config::dump_constraint() {
265 let simp_ext = format!("simp.{ext}");
266 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
267 .unwrap();
268 }
269
270 let backend = match self.opts.solver {
271 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
272 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
273 };
274
275 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Fixpoint);
276 let cstr = refine_tree.to_fixpoint(&mut fcx)?;
277
278 let count = cstr.concrete_head_count();
280 metrics::incr_metric(Metric::CsTotal, count as u32);
281 if count == 0 {
282 metrics::incr_metric_if(kind.is_body(), Metric::FnTrivial);
283 return Ok(Answer::trivial());
284 }
285
286 let task = fcx.create_task(def_id, cstr, self.opts.scrape_quals, backend)?;
287 let result = fcx.run_task(cache, def_id, kind, &task)?;
288 Ok(fcx.result_to_answer(result))
289 }
290
291 pub fn split(self) -> (RefineTree, KVarGen) {
292 (self.refine_tree, self.inner.into_inner().kvars)
293 }
294}
295
/// A handle on the refinement tree at one position (its cursor).
///
/// It is used to define names, assume and check predicates, and create inference variables
/// (kvars/evars) while generating constraints for the item `def_id`.
pub struct InferCtxt<'infcx, 'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The rustc inference context this context was created with.
    pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    /// The item constraints are being generated for.
    pub def_id: DefId,
    /// Overflow-checking mode; forwarded to the cursor when assuming type invariants.
    pub check_overflow: OverflowMode,
    /// Raw-dereference mode; see [`InferCtxt::allow_raw_deref`].
    pub allow_raw_deref: flux_config::RawDerefMode,
    // Current position in the refinement tree.
    cursor: Cursor<'infcx>,
    // Kvar/evar state shared with the owning `InferCtxtRoot` and sibling contexts.
    inner: &'infcx RefCell<InferCtxtInner>,
}
305
/// Mutable inference state shared (through a `RefCell`) by all contexts of one root.
struct InferCtxtInner {
    // Generator for fresh kvars.
    kvars: KVarGen,
    // Store of existential variables and their solutions, organized in scopes.
    evars: EVarStore,
}
310
311impl InferCtxtInner {
312 fn new(dummy_kvars: bool) -> Self {
313 Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
314 }
315}
316
317impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
318 pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
319 InferCtxtAt { infcx: self, span }
320 }
321
322 pub fn instantiate_refine_args(
323 &mut self,
324 callee_def_id: DefId,
325 args: &[rty::GenericArg],
326 ) -> InferResult<List<Expr>> {
327 Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
328 let param = param.instantiate(self.genv.tcx(), args, &[]);
329 Ok(self.fresh_infer_var(¶m.sort, param.mode))
330 })?)
331 }
332
333 pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
334 args.iter()
335 .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
336 .collect_vec()
337 }
338
339 pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
340 match mode {
341 InferMode::KVar => {
342 let fsort = sort.expect_func().expect_mono();
343 let vars = fsort.inputs().iter().cloned().map_into().collect();
344 let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
345 Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
346 }
347 InferMode::EVar => self.fresh_evar(),
348 }
349 }
350
351 pub fn fresh_infer_var_for_hole(
352 &mut self,
353 binders: &[BoundVariableKinds],
354 kind: HoleKind,
355 ) -> Expr {
356 match kind {
357 HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
358 HoleKind::Expr(_) => {
359 self.fresh_evar()
363 }
364 }
365 }
366
367 pub fn fresh_kvar_in_scope(
369 &self,
370 binders: &[BoundVariableKinds],
371 scope: &Scope,
372 encoding: KVarEncoding,
373 ) -> Expr {
374 let inner = &mut *self.inner.borrow_mut();
375 inner.kvars.fresh(binders, scope.iter(), encoding)
376 }
377
378 pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
380 let inner = &mut *self.inner.borrow_mut();
381 inner.kvars.fresh(binders, self.cursor.vars(), encoding)
382 }
383
384 fn fresh_evar(&self) -> Expr {
385 let evars = &mut self.inner.borrow_mut().evars;
386 Expr::evar(evars.fresh(self.cursor.marker()))
387 }
388
389 pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
390 if a.has_evars() {
391 return;
392 }
393 let evars = &mut self.inner.borrow_mut().evars;
394 if let ExprKind::Var(Var::EVar(evid)) = b.kind()
395 && let EVarState::Unsolved(marker) = evars.get(*evid)
396 && !marker.has_free_vars(a)
397 {
398 evars.solve(*evid, a.clone());
399 }
400 }
401
402 fn enter_exists<T, U>(
403 &mut self,
404 t: &Binder<T>,
405 f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
406 ) -> U
407 where
408 T: TypeFoldable,
409 {
410 self.ensure_resolved_evars(|infcx| {
411 let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
412 Ok(f(infcx, t))
413 })
414 .unwrap()
415 }
416
417 pub fn push_evar_scope(&mut self) {
422 self.inner.borrow_mut().evars.push_scope();
423 }
424
425 pub fn pop_evar_scope(&mut self) -> InferResult {
428 self.inner
429 .borrow_mut()
430 .evars
431 .pop_scope()
432 .map_err(InferErr::UnsolvedEvar)
433 }
434
435 pub fn ensure_resolved_evars<R>(
437 &mut self,
438 f: impl FnOnce(&mut Self) -> InferResult<R>,
439 ) -> InferResult<R> {
440 self.push_evar_scope();
441 let r = f(self)?;
442 self.pop_evar_scope()?;
443 Ok(r)
444 }
445
446 pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
447 self.inner.borrow().evars.replace_evars(t).unwrap()
448 }
449
450 pub fn tcx(&self) -> TyCtxt<'tcx> {
451 self.genv.tcx()
452 }
453
454 pub fn cursor(&self) -> &Cursor<'infcx> {
455 &self.cursor
456 }
457
458 pub fn allow_raw_deref(&self) -> bool {
459 matches!(self.allow_raw_deref, RawDerefMode::Ok)
460 }
461}
462
463impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
465 pub fn change_item<'a>(
466 &'a mut self,
467 def_id: LocalDefId,
468 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
469 ) -> InferCtxt<'a, 'genv, 'tcx> {
470 InferCtxt {
471 def_id: def_id.to_def_id(),
472 cursor: self.cursor.branch(),
473 region_infcx,
474 ..*self
475 }
476 }
477
478 pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
479 InferCtxt {
480 cursor: self
481 .cursor
482 .move_to(marker, clear_children)
483 .unwrap_or_else(|| tracked_span_bug!()),
484 ..*self
485 }
486 }
487
488 pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
489 InferCtxt { cursor: self.cursor.branch(), ..*self }
490 }
491
492 fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
493 self.cursor.define_var(sort, provenance)
494 }
495
496 pub fn define_bound_reft_var(&mut self, sort: &Sort, kind: BoundReftKind) -> Name {
497 self.define_var(sort, NameProvenance::UnfoldBoundReft(kind))
498 }
499
500 pub fn define_unknown_var(&mut self, sort: &Sort) -> Name {
501 self.cursor.define_var(sort, NameProvenance::Unknown)
502 }
503
504 pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
505 self.cursor.check_pred(pred, tag);
506 }
507
508 pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
509 self.cursor.assume_pred(pred);
510 }
511
512 pub fn unpack(&mut self, ty: &Ty) -> Ty {
513 self.hoister(false).hoist(ty)
514 }
515
516 pub fn unpack_at_name(&mut self, name: Option<Symbol>, ty: &Ty) -> Ty {
517 let mut hoister = self.hoister(false);
518 hoister.delegate.name = name;
519 hoister.hoist(ty)
520 }
521
522 pub fn marker(&self) -> Marker {
523 self.cursor.marker()
524 }
525
526 pub fn hoister(
527 &mut self,
528 assume_invariants: bool,
529 ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
530 Hoister::with_delegate(Unpacker { infcx: self, assume_invariants, name: None })
531 .transparent()
532 }
533
534 pub fn assume_invariants(&mut self, ty: &Ty) {
535 self.cursor
536 .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
537 }
538
539 fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
540 self.cursor.check_impl(pred1, pred2, tag);
541 }
542}
543
/// [`HoisterDelegate`] that unpacks types into an [`InferCtxt`].
///
/// Hoisted existentials become fresh names and hoisted constraints become assumptions.
pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    // When set, the invariants of every hoisted type are assumed.
    assume_invariants: bool,
    // When set, hoisted bound refinement variables are given this name.
    name: Option<Symbol>,
}
549
550impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
551 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
552 let ty = ty_ctor.replace_bound_refts_with(|sort, _, kind| {
553 let kind = if let Some(name) = self.name { BoundReftKind::Named(name) } else { kind };
554 Expr::fvar(self.infcx.define_bound_reft_var(sort, kind))
555 });
556 if self.assume_invariants {
557 self.infcx.assume_invariants(&ty);
558 }
559 ty
560 }
561
562 fn hoist_constr(&mut self, pred: Expr) {
563 self.infcx.assume_pred(pred);
564 }
565}
566
567impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
568 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
569 std::fmt::Debug::fmt(&self.cursor, f)
570 }
571}
572
/// An [`InferCtxt`] paired with the span used to tag the constraints it generates.
#[derive(Debug)]
pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
    pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    pub span: Span,
}
578
579impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
580 fn tag(&self, reason: ConstrReason) -> Tag {
581 Tag::new(reason, self.span)
582 }
583
584 pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
585 let tag = self.tag(reason);
586 self.infcx.check_pred(pred, tag);
587 }
588
589 pub fn check_non_closure_clauses(
590 &mut self,
591 clauses: &[rty::Clause],
592 reason: ConstrReason,
593 ) -> InferResult {
594 for clause in clauses {
595 if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
596 let impl_elem = BaseTy::projection(projection_pred.projection_ty)
597 .to_ty()
598 .deeply_normalize(self)?;
599 let term = projection_pred.term.to_ty().deeply_normalize(self)?;
600
601 self.subtyping(&impl_elem, &term, reason)?;
603 self.subtyping(&term, &impl_elem, reason)?;
604 }
605 }
606 Ok(())
607 }
608
609 pub fn subtyping_with_env(
612 &mut self,
613 env: &mut impl LocEnv,
614 a: &Ty,
615 b: &Ty,
616 reason: ConstrReason,
617 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
618 let mut sub = Sub::new(env, reason, self.span);
619 sub.tys(self.infcx, a, b)?;
620 Ok(sub.obligations)
621 }
622
623 pub fn subtyping(
628 &mut self,
629 a: &Ty,
630 b: &Ty,
631 reason: ConstrReason,
632 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
633 let mut env = DummyEnv;
634 let mut sub = Sub::new(&mut env, reason, self.span);
635 sub.tys(self.infcx, a, b)?;
636 Ok(sub.obligations)
637 }
638
639 pub fn subtyping_generic_args(
640 &mut self,
641 variance: Variance,
642 a: &GenericArg,
643 b: &GenericArg,
644 reason: ConstrReason,
645 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
646 let mut env = DummyEnv;
647 let mut sub = Sub::new(&mut env, reason, self.span);
648 sub.generic_args(self.infcx, variance, a, b)?;
649 Ok(sub.obligations)
650 }
651
652 pub fn check_constructor(
656 &mut self,
657 variant: EarlyBinder<PolyVariant>,
658 generic_args: &[GenericArg],
659 fields: &[Ty],
660 reason: ConstrReason,
661 ) -> InferResult<Ty> {
662 let ret = self.ensure_resolved_evars(|this| {
663 let generic_args = this.instantiate_generic_args(generic_args);
665
666 let variant = variant
667 .instantiate(this.tcx(), &generic_args, &[])
668 .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
669
670 for (actual, formal) in iter::zip(fields, variant.fields()) {
672 this.subtyping(actual, formal, reason)?;
673 }
674
675 for require in &variant.requires {
677 this.check_pred(require, ConstrReason::Fold);
678 }
679
680 Ok(variant.ret())
681 })?;
682 Ok(self.fully_resolve_evars(&ret))
683 }
684
685 pub fn ensure_resolved_evars<R>(
686 &mut self,
687 f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
688 ) -> InferResult<R> {
689 self.infcx
690 .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
691 }
692}
693
694impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
695 type Target = InferCtxt<'a, 'genv, 'tcx>;
696
697 fn deref(&self) -> &Self::Target {
698 self.infcx
699 }
700}
701
702impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
703 fn deref_mut(&mut self) -> &mut Self::Target {
704 self.infcx
705 }
706}
707
/// A pair of (base) types being related, kept for debugging subtyping.
#[derive(TypeVisitable, TypeFoldable)]
pub(crate) enum TypeTrace {
    Types(Ty, Ty),
    BaseTys(BaseTy, BaseTy),
}
716
717#[expect(dead_code, reason = "we use this for debugging some time")]
718impl TypeTrace {
719 fn tys(a: &Ty, b: &Ty) -> Self {
720 Self::Types(a.clone(), b.clone())
721 }
722
723 fn btys(a: &BaseTy, b: &BaseTy) -> Self {
724 Self::BaseTys(a.clone(), b.clone())
725 }
726}
727
728impl fmt::Debug for TypeTrace {
729 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
730 match self {
731 TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
732 TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
733 }
734 }
735}
736
/// Environment of locations consulted by subtyping when it meets pointers and strong
/// references.
pub trait LocEnv {
    /// Converts the pointer at `path` into a reference of region `re` bounded by `bound`.
    /// NOTE(review): exact semantics are defined by implementors outside this file.
    fn ptr_to_ref(
        &mut self,
        infcx: &mut InferCtxtAt,
        reason: ConstrReason,
        re: Region,
        path: &Path,
        bound: Ty,
    ) -> InferResult<Ty>;

    /// Unfolds the strong reference at `path` of type `ty`; called before `get(path)` when
    /// relating two strong references.
    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;

    /// Returns the type currently stored at `path`.
    fn get(&self, path: &Path) -> Ty;
}
751
/// A [`LocEnv`] with no locations; calling any of its methods is a compiler bug. Used by
/// subtyping entry points that are not expected to resolve pointers.
struct DummyEnv;
753
754impl LocEnv for DummyEnv {
755 fn ptr_to_ref(
756 &mut self,
757 _: &mut InferCtxtAt,
758 _: ConstrReason,
759 _: Region,
760 _: &Path,
761 _: Ty,
762 ) -> InferResult<Ty> {
763 tracked_span_bug!("call to `ptr_to_ref` on `DummyEnv`")
764 }
765
766 fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
767 tracked_span_bug!("call to `unfold_str_ref` on `DummyEnv`")
768 }
769
770 fn get(&self, _: &Path) -> Ty {
771 tracked_span_bug!("call to `get` on `DummyEnv`")
772 }
773}
774
/// State for one subtyping query.
struct Sub<'a, E> {
    // Location environment used to resolve pointers / strong references.
    env: &'a mut E,
    // Reason and span used to tag every emitted constraint.
    reason: ConstrReason,
    span: Span,
    // Coroutine obligations collected while relating opaque types.
    obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
}
786
impl<'a, E: LocEnv> Sub<'a, E> {
    fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
        Self { env, reason, span, obligations: vec![] }
    }

    // Tag attached to every constraint emitted by this query.
    fn tag(&self) -> Tag {
        Tag::new(self.reason, self.span)
    }

    /// Generates constraints for `a <: b` in a new branch of `infcx`.
    ///
    /// `a` is first unpacked: its existentials and constraints are hoisted into the branch as
    /// fresh names and assumptions. Existentials and constraints on `b` are then instantiated
    /// with inference variables and checked, respectively.
    fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
        let infcx = &mut infcx.branch();
        let a = infcx.unpack(a);

        match (a.kind(), b.kind()) {
            (TyKind::Exists(..), _) => {
                bug!("existentials should have been removed by the unpacking above");
            }
            (TyKind::Constr(..), _) => {
                bug!("constraint types should have been removed by the unpacking above");
            }

            // Instantiate `b`'s binder with fresh inference variables (inside an evar scope
            // that must be fully resolved) and relate against the body.
            (_, TyKind::Exists(ctor_b)) => {
                infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
            }
            // `b`'s constraint must be proven.
            (_, TyKind::Constr(pred_b, ty_b)) => {
                infcx.check_pred(pred_b, self.tag());
                self.tys(infcx, &a, ty_b)
            }

            // Relate the type stored at `path_a` with the strong reference's target.
            (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            // A mutable pointer flowing into `&mut`: the reference's index must be unit, and
            // the environment converts the pointer into a reference bounded by `bound`.
            (
                TyKind::Ptr(PtrKind::Mut(re), path),
                TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
            ) => {
                self.idxs_eq(infcx, &Expr::unit(), idx);

                self.env.ptr_to_ref(
                    &mut infcx.at(self.span),
                    self.reason,
                    *re,
                    path,
                    bound.clone(),
                )?;
                Ok(())
            }

            (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
                self.btys(infcx, bty_a, bty_b)?;
                self.idxs_eq(infcx, idx_a, idx_b);
                Ok(())
            }
            (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
                debug_assert_eq!(pk_a, pk_b);
                debug_assert_eq!(path_a, path_b);
                Ok(())
            }
            (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
                debug_assert_eq!(param_ty_a, param_ty_b);
                Ok(())
            }
            // Anything may flow into an uninitialized location.
            (_, TyKind::Uninit) => Ok(()),
            (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
                debug_assert_eq!(fields_a.len(), fields_b.len());
                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
        }
    }

    /// Relates two base types. Mismatches that are checked only with (debug) assertions are
    /// assumed to have been ruled out by rustc's type checking.
    fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
        match (a, b) {
            (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
                debug_assert_eq!(int_ty_a, int_ty_b);
                Ok(())
            }
            (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
                debug_assert_eq!(uint_ty_a, uint_ty_b);
                Ok(())
            }
            // ADT arguments are related according to the ADT's declared variances.
            (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
                tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
                tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
                let variances = infcx.genv.variances_of(a_adt.did());
                for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
                    self.generic_args(infcx, *variance, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Function item arguments are only compared structurally (no constraints emitted).
            (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
                debug_assert_eq!(a_def_id, b_def_id);
                debug_assert_eq!(a_args.len(), b_args.len());
                for (arg_a, arg_b) in iter::zip(a_args, b_args) {
                    match (arg_a, arg_b) {
                        (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
                            let bty_a = ty_a.as_bty_skipping_existentials();
                            let bty_b = ty_b.as_bty_skipping_existentials();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                            let bty_a = ctor_a.as_bty_skipping_binder();
                            let bty_b = ctor_b.as_bty_skipping_binder();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
                    }
                }
                Ok(())
            }
            (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
                debug_assert_eq!(float_ty_a, float_ty_b);
                Ok(())
            }
            (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),

            // `*mut T` is invariant in `T`; `*const T` is covariant.
            (BaseTy::RawPtr(ty_a, mut_a), BaseTy::RawPtr(ty_b, mut_b)) => {
                debug_assert_eq!(mut_a, mut_b);
                self.tys(infcx, ty_a, ty_b)?;
                if matches!(mut_a, Mutability::Mut) {
                    self.tys(infcx, ty_b, ty_a)?;
                }
                Ok(())
            }

            // `&mut T` is invariant in `T`. Special case: when `a` is an indexed slice and `b`
            // is an existential, the reverse direction uses `b`'s body instantiated with
            // `a`'s index.
            (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
                if ty_a.is_slice()
                    && let TyKind::Indexed(_, idx_a) = ty_a.kind()
                    && let TyKind::Exists(bty_b) = ty_b.kind()
                {
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, &bty_b.replace_bound_reft(idx_a), ty_a)
                } else {
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, ty_b, ty_a)
                }
            }
            (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Same opaque type: generic args are related invariantly and refinement args are
            // unified (evars only, no constraints).
            (
                BaseTy::Alias(AliasKind::Opaque, alias_ty_a),
                BaseTy::Alias(AliasKind::Opaque, alias_ty_b),
            ) => {
                debug_assert_eq!(alias_ty_a.def_id, alias_ty_b.def_id);

                for (ty_a, ty_b) in izip!(alias_ty_a.args.iter(), alias_ty_b.args.iter()) {
                    self.generic_args(infcx, Invariant, ty_a, ty_b)?;
                }

                debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
                iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
                    .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));

                Ok(())
            }
            // A concrete type flowing into an opaque type must satisfy its item bounds.
            (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
                self.handle_opaque_type(infcx, a, alias_ty_b)
            }
            (
                BaseTy::Alias(AliasKind::Projection, alias_ty_a),
                BaseTy::Alias(AliasKind::Projection, alias_ty_b),
            ) => {
                tracked_span_dbg_assert_eq!(alias_ty_a.erase_regions(), alias_ty_b.erase_regions());
                Ok(())
            }
            (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
                tracked_span_dbg_assert_eq!(len_a, len_b);
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
                debug_assert_eq!(param_a, param_b);
                Ok(())
            }
            (BaseTy::Bool, BaseTy::Bool)
            | (BaseTy::Str, BaseTy::Str)
            | (BaseTy::Char, BaseTy::Char)
            | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
            (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
                tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
                Ok(())
            }
            (BaseTy::Closure(did1, tys_a, _, _), BaseTy::Closure(did2, tys_b, _, _))
                if did1 == did2 =>
            {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
                tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
                Ok(())
            }
            (BaseTy::Never, BaseTy::Never) => Ok(()),
            // Upvars are covariant; the resume type is related in both directions.
            (
                BaseTy::Coroutine(did1, resume_ty_a, tys_a, _),
                BaseTy::Coroutine(did2, resume_ty_b, tys_b, _),
            ) if did1 == did2 => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                self.tys(infcx, resume_ty_b, resume_ty_a)?;
                self.tys(infcx, resume_ty_a, resume_ty_b)?;

                Ok(())
            }
            (BaseTy::Foreign(did_a), BaseTy::Foreign(did_b)) if did_a == did_b => Ok(()),
            _ => Err(query_bug!("incompatible base types: `{a:#?}` - `{b:#?}`"))?,
        }
    }

    /// Relates two generic arguments according to `variance`.
    ///
    /// Lifetimes are ignored and consts are only compared with a debug assertion.
    fn generic_args(
        &mut self,
        infcx: &mut InferCtxt,
        variance: Variance,
        a: &GenericArg,
        b: &GenericArg,
    ) -> InferResult {
        let (ty_a, ty_b) = match (a, b) {
            (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
            (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                tracked_span_dbg_assert_eq!(
                    ctor_a.sort().erase_regions(),
                    ctor_b.sort().erase_regions()
                );
                (ctor_a.to_ty(), ctor_b.to_ty())
            }
            (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
            (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
                debug_assert_eq!(cst_a, cst_b);
                return Ok(());
            }
            _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
        };
        match variance {
            Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
            Variance::Invariant => {
                self.tys(infcx, &ty_a, &ty_b)?;
                self.tys(infcx, &ty_b, &ty_a)
            }
            Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
            Variance::Bivariant => Ok(()),
        }
    }

    /// Requires index expressions `a` and `b` to be equal.
    ///
    /// Aggregates (struct/tuple/raw-pointer constructors) are compared field by field,
    /// projecting out of the non-aggregate side when only one side is a constructor. Lambdas
    /// are compared by applying both to fresh names. Kvars are related by implication in both
    /// directions. Everything else becomes an equality check, after trying to solve an evar
    /// on the right by unification.
    fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
        if a == b {
            return;
        }
        match (a.kind(), b.kind()) {
            (
                ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
                ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
            ) => {
                debug_assert_eq!(did_a, did_b);
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            (ExprKind::Ctor(Ctor::RawPtr, flds_a), ExprKind::Ctor(Ctor::RawPtr, flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Only `b` is an aggregate: project each field out of `a`.
            (_, ExprKind::Tuple(flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }
            (_, ExprKind::Ctor(Ctor::RawPtr, flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let field = rty::RawPtrField::from_index(f as u32).unwrap();
                    let a = a.proj_and_reduce(FieldProj::RawPtr { field });
                    self.idxs_eq(infcx, &a, b);
                }
            }

            (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Only `a` is an aggregate: first try to solve `b` as an evar with the whole
            // aggregate, then project each field out of `b`.
            (ExprKind::Tuple(flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Ctor(Ctor::RawPtr, flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let field = rty::RawPtrField::from_index(f as u32).unwrap();
                    let b = b.proj_and_reduce(FieldProj::RawPtr { field });
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, lam_a, lam_b);
            }
            // Eta-expand the non-lambda side so both sides can be compared as lambdas.
            (_, ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
            }
            (ExprKind::Abs(lam_a), _) => {
                infcx.unify_exprs(a, b);
                self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
            }
            // Kvars are related by bi-implication instead of equality.
            (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
                infcx.check_impl(a, b, self.tag());
                infcx.check_impl(b, a, self.tag());
            }
            _ => {
                infcx.unify_exprs(a, b);
                let span = b.span();
                infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
            }
        }
    }

    /// Requires two lambdas to be equal by applying both to the same freshly defined names
    /// and equating the bodies.
    fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
        debug_assert_eq!(a.vars().len(), b.vars().len());
        let vars = a
            .vars()
            .iter()
            .map(|kind| {
                let (sort, _, kind) = kind.expect_refine();
                Expr::fvar(infcx.define_bound_reft_var(sort, kind))
            })
            .collect_vec();
        let body_a = a.apply(&vars);
        let body_b = b.apply(&vars);
        self.idxs_eq(infcx, &body_a, &body_b);
    }

    /// Checks that `bty` satisfies the item bounds of the opaque type `alias_ty`.
    ///
    /// - A coroutine instead records a coroutine obligation in `self.obligations`.
    /// - Otherwise, each projection bound (instantiated with the alias's args) is checked:
    ///   the projection with `bty` as self type is normalized and must be a subtype of the
    ///   bound's term.
    /// - Bounds with bound variables are reported as a query bug.
    fn handle_opaque_type(
        &mut self,
        infcx: &mut InferCtxt,
        bty: &BaseTy,
        alias_ty: &AliasTy,
    ) -> InferResult {
        if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys, args) = bty {
            let obligs = mk_coroutine_obligations(
                infcx.genv,
                def_id,
                resume_ty,
                upvar_tys,
                &alias_ty.def_id,
                args.clone(),
            )?;
            self.obligations.extend(obligs);
        } else {
            let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
                infcx.tcx(),
                &alias_ty.args,
                &alias_ty.refine_args,
            );
            for clause in &bounds {
                if !clause.kind().vars().is_empty() {
                    // NOTE(review): the message says `handle_opaque_types` although the
                    // function is `handle_opaque_type`.
                    Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
                }
                if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
                    let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
                    let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
                        .to_ty()
                        .deeply_normalize(&mut infcx.at(self.span))?;
                    let ty2 = pred.term.to_ty();
                    self.tys(infcx, &ty1, &ty2)?;
                }
            }
        }
        Ok(())
    }
}
1238
1239fn mk_coroutine_obligations(
1240 genv: GlobalEnv,
1241 generator_did: &DefId,
1242 resume_ty: &Ty,
1243 upvar_tys: &List<Ty>,
1244 opaque_def_id: &DefId,
1245 args: flux_rustc_bridge::ty::GenericArgs,
1246) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1247 let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1248 for bound in &bounds {
1249 if let Some(proj_clause) = bound.as_projection_clause() {
1250 return Ok(vec![proj_clause.map(|proj_clause| {
1251 let output = proj_clause.term;
1252 CoroutineObligPredicate {
1253 def_id: *generator_did,
1254 resume_ty: resume_ty.clone(),
1255 upvar_tys: upvar_tys.clone(),
1256 output: output.to_ty(),
1257 args,
1258 }
1259 })]);
1260 }
1261 }
1262 bug!("no projection predicate")
1263}
1264
/// Errors produced during constraint generation / inference.
#[derive(Debug)]
pub enum InferErr {
    /// An evar was still unsolved when its scope was popped.
    UnsolvedEvar(EVid),
    /// A query failed.
    Query(QueryErr),
}
1270
1271impl From<QueryErr> for InferErr {
1272 fn from(v: QueryErr) -> Self {
1273 Self::Query(v)
1274 }
1275}
1276
/// Pretty-printing support for constraint tags.
mod pretty {
    use std::fmt;

    use flux_middle::pretty::*;

    use super::*;

    // Renders a tag as `<reason> at <src_span>`, followed by ` (<dst_span>)` when a
    // destination span is present.
    impl Pretty for Tag {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)?;
            if let Some(dst_span) = self.dst_span {
                w!(cx, f, " ({:?})", ^dst_span)?;
            }
            Ok(())
        }
    }

    // `Debug` for `Tag` goes through the `Pretty` impl above with a default context.
    impl_debug_with_default_cx!(Tag);
}