1use std::{cell::RefCell, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts, OverflowMode, RawDerefMode};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7 FixpointQueryKind,
8 def_id::MaybeExternId,
9 global_env::GlobalEnv,
10 metrics::{self, Metric},
11 queries::{QueryErr, QueryResult},
12 query_bug,
13 rty::{
14 self, AliasKind, AliasTy, BaseTy, Binder, BoundReftKind, BoundVariableKinds,
15 CoroutineObligPredicate, Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj,
16 GenericArg, HoleKind, InferMode, Lambda, List, Loc, Mutability, Name, NameProvenance, Path,
17 PolyVariant, PtrKind, RefineArgs, RefineArgsExt, Region, Sort, Ty, TyCtor, TyKind, Var,
18 canonicalize::{Hoister, HoisterDelegate},
19 fold::TypeFoldable,
20 },
21};
22use itertools::{Itertools, izip};
23use rustc_hir::def_id::{DefId, LocalDefId};
24use rustc_macros::extension;
25use rustc_middle::{
26 mir::BasicBlock,
27 ty::{TyCtxt, Variance},
28};
29use rustc_span::{Span, Symbol};
30use rustc_type_ir::Variance::Invariant;
31
32use crate::{
33 evars::{EVarState, EVarStore},
34 fixpoint_encoding::{
35 Answer, Backend, FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen, KVarSolutions,
36 lean_task_key,
37 },
38 lean_encoding::log_proof,
39 projections::NormalizeExt as _,
40 refine_tree::{Cursor, Marker, RefineTree, Scope},
41};
42
43pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
44
45#[derive(PartialEq, Eq, Clone, Copy, Hash)]
46pub struct Tag {
47 pub reason: ConstrReason,
48 pub src_span: Span,
49 pub dst_span: Option<ESpan>,
50}
51
52impl Tag {
53 pub fn new(reason: ConstrReason, span: Span) -> Self {
54 Self { reason, src_span: span, dst_span: None }
55 }
56
57 pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
58 Self { dst_span, ..self }
59 }
60}
61
/// Sub-classification carried by [`ConstrReason::Subtype`].
///
/// NOTE(review): the variant names suggest function inputs/outputs and `requires`/`ensures`
/// clauses; their exact producers are not visible in this file — confirm at the use sites.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum SubtypeReason {
    Input,
    Output,
    Requires,
    Ensures,
}

/// Why a constraint was generated; stored in every [`Tag`].
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum ConstrReason {
    Call,
    Assign,
    Ret,
    /// Used (among possibly other places) by `InferCtxtAt::check_constructor` for a variant's
    /// `requires` predicates.
    Fold,
    FoldLocal,
    Predicate,
    /// Carries a static string; presumably the assertion's message — confirm at the use sites.
    Assert(&'static str),
    Div,
    Rem,
    /// Carries a MIR basic block; presumably the jump target — confirm at the use sites.
    Goto(BasicBlock),
    Overflow,
    Underflow,
    Subtype(SubtypeReason),
    /// Carries a `DefId`; presumably the item whose no-panic requirement is checked — confirm.
    NoPanic(DefId),
    Other,
}
88
/// Root of refinement inference: owns the refinement tree and the kvar/evar state shared by
/// every [`InferCtxt`] created from it (see `InferCtxtRoot::infcx`).
pub struct InferCtxtRoot<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    // Kvar generator and evar store; `InferCtxt`s hold a shared reference to this cell.
    inner: RefCell<InferCtxtInner>,
    // Tree into which checks and assumptions are recorded; later lowered with `to_fixpoint`.
    refine_tree: RefineTree,
    opts: InferOpts,
}

/// Builder for [`InferCtxtRoot`], obtained from `GlobalEnvExt::infcx_root`.
pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    opts: InferOpts,
    // Variables (with sorts) passed to `RefineTree::new` as the root-level names.
    params: Vec<(Var, Sort)>,
    // Used when normalizing parameter sorts (`deeply_normalize_sorts`).
    infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
    // Forwarded to `KVarGen::new` when the root is built.
    dummy_kvars: bool,
}
103
104#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
105impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
106 fn infcx_root<'a>(
107 self,
108 infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
109 opts: InferOpts,
110 ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
111 InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
112 }
113}
114
115impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
116 pub fn with_dummy_kvars(mut self) -> Self {
117 self.dummy_kvars = true;
118 self
119 }
120
121 pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
122 self.params.extend(
123 self.genv
124 .generics_of(def_id)?
125 .const_params(self.genv)?
126 .into_iter()
127 .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
128 );
129 Ok(self)
130 }
131
132 pub fn with_refinement_generics(
133 mut self,
134 def_id: DefId,
135 args: &[GenericArg],
136 ) -> QueryResult<Self> {
137 for (index, param) in self
138 .genv
139 .refinement_generics_of(def_id)?
140 .iter_own_params()
141 .enumerate()
142 {
143 let param = param.instantiate(self.genv.tcx(), args, &[]);
144 let sort = param
145 .sort
146 .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
147
148 let var =
149 Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
150 self.params.push((var, sort));
151 }
152 Ok(self)
153 }
154
155 pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
156 self = self.with_const_generics(def_id)?;
157 let offset = self.params.len();
158 self.genv.refinement_generics_of(def_id)?.fill_item(
159 self.genv,
160 &mut self.params,
161 &mut |param, index| {
162 let index = (index - offset) as u32;
163 let param = param.instantiate_identity();
164 let sort = param
165 .sort
166 .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;
167
168 let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
169 Ok((var, sort))
170 },
171 )?;
172 Ok(self)
173 }
174
175 pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
176 Ok(InferCtxtRoot {
177 genv: self.genv,
178 inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
179 refine_tree: RefineTree::new(self.params),
180 opts: self.opts,
181 })
182 }
183}
184
185impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
186 pub fn infcx<'a>(
187 &'a mut self,
188 def_id: DefId,
189 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
190 ) -> InferCtxt<'a, 'genv, 'tcx> {
191 InferCtxt {
192 genv: self.genv,
193 region_infcx,
194 def_id,
195 cursor: self.refine_tree.cursor_at_root(),
196 inner: &self.inner,
197 check_overflow: self.opts.check_overflow,
198 allow_raw_deref: self.opts.allow_raw_deref,
199 }
200 }
201
202 pub fn fresh_kvar_in_scope(
203 &self,
204 binders: &[BoundVariableKinds],
205 scope: &Scope,
206 encoding: KVarEncoding,
207 ) -> Expr {
208 let inner = &mut *self.inner.borrow_mut();
209 inner.kvars.fresh(binders, scope.iter(), encoding)
210 }
211
212 pub fn execute_lean_query(
213 self,
214 cache: &mut FixQueryCache,
215 def_id: MaybeExternId,
216 ) -> QueryResult {
217 let inner = self.inner.into_inner();
218 let kvars = inner.kvars;
219 let evars = inner.evars;
220 let mut refine_tree = self.refine_tree;
221 refine_tree.replace_evars(&evars).unwrap();
222 refine_tree.simplify(self.genv);
223
224 let solver = match self.opts.solver {
225 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
226 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
227 };
228 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Lean);
229 let cstr = refine_tree.to_fixpoint(&mut fcx)?;
230 let cstr_variable_sorts = cstr.variable_sorts();
231 let task = fcx.create_task(def_id, cstr, self.opts.scrape_quals, solver)?;
232
233 log_proof(self.genv, def_id)?;
234 if config::is_cache_enabled() {
236 let key = lean_task_key(self.genv.tcx(), def_id.resolved_id());
237 let hash = task.hash_with_default();
238 if cache.lookup(&key, hash).is_some() {
239 return Ok(());
240 }
241 }
242
243 let result = fcx.run_task(cache, def_id, FixpointQueryKind::Body, &task)?;
244
245 fcx.generate_lean_files(
246 def_id,
247 task,
248 KVarSolutions::closed_solutions(
249 cstr_variable_sorts,
250 result.solution,
251 result.non_cut_solution,
252 ),
253 )
254 }
255
256 pub fn execute_fixpoint_query(
257 self,
258 cache: &mut FixQueryCache,
259 def_id: MaybeExternId,
260 kind: FixpointQueryKind,
261 ) -> QueryResult<Answer<Tag>> {
262 let inner = self.inner.into_inner();
263 let kvars = inner.kvars;
264 let evars = inner.evars;
265
266 let ext = kind.ext();
267
268 let mut refine_tree = self.refine_tree;
269
270 refine_tree.replace_evars(&evars).unwrap();
271
272 if config::dump_constraint() {
273 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
274 }
275 refine_tree.simplify(self.genv);
276 if config::dump_constraint() {
277 let simp_ext = format!("simp.{ext}");
278 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
279 .unwrap();
280 }
281
282 let backend = match self.opts.solver {
283 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
284 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
285 };
286
287 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Fixpoint);
288 let cstr = refine_tree.to_fixpoint(&mut fcx)?;
289
290 let count = cstr.concrete_head_count();
292 metrics::incr_metric(Metric::CsTotal, count as u32);
293 if count == 0 {
294 metrics::incr_metric_if(kind.is_body(), Metric::FnTrivial);
295 return Ok(Answer::trivial());
296 }
297
298 let task = fcx.create_task(def_id, cstr, self.opts.scrape_quals, backend)?;
299 let result = fcx.run_task(cache, def_id, kind, &task)?;
300 Ok(fcx.result_to_answer(result))
301 }
302
303 pub fn split(self) -> (RefineTree, KVarGen) {
304 (self.refine_tree, self.inner.into_inner().kvars)
305 }
306}
307
/// Context for generating refinement constraints at one position (the cursor) of the
/// refinement tree.
///
/// Contexts derived from the same [`InferCtxtRoot`] share its kvar/evar state through `inner`.
pub struct InferCtxt<'infcx, 'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    /// Item this context was created for (see `InferCtxtRoot::infcx` and `change_item`).
    pub def_id: DefId,
    /// Copied from the root's `InferOpts`.
    pub check_overflow: OverflowMode,
    /// Copied from the root's `InferOpts`; queried through `allow_raw_deref()`.
    pub allow_raw_deref: flux_config::RawDerefMode,
    // Position in the refinement tree where names, assumptions and checks are added.
    cursor: Cursor<'infcx>,
    // Shared kvar generator and evar store.
    inner: &'infcx RefCell<InferCtxtInner>,
}

/// Mutable inference state shared by all contexts derived from one root.
struct InferCtxtInner {
    // Generator of fresh kvars.
    kvars: KVarGen,
    // Existential variables, their scopes and solutions.
    evars: EVarStore,
}
322
323impl InferCtxtInner {
324 fn new(dummy_kvars: bool) -> Self {
325 Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
326 }
327}
328
329impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
330 pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
331 InferCtxtAt { infcx: self, span }
332 }
333
334 pub fn instantiate_refine_args(
335 &mut self,
336 callee_def_id: DefId,
337 args: &[rty::GenericArg],
338 ) -> InferResult<List<Expr>> {
339 Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
340 let param = param.instantiate(self.genv.tcx(), args, &[]);
341 Ok(self.fresh_infer_var(¶m.sort, param.mode))
342 })?)
343 }
344
345 pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
346 args.iter()
347 .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
348 .collect_vec()
349 }
350
351 pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
352 match mode {
353 InferMode::KVar => {
354 let fsort = sort.expect_func().expect_mono();
355 let vars = fsort.inputs().iter().cloned().map_into().collect();
356 let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
357 Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
358 }
359 InferMode::EVar => self.fresh_evar(),
360 }
361 }
362
363 pub fn fresh_infer_var_for_hole(
364 &mut self,
365 binders: &[BoundVariableKinds],
366 kind: HoleKind,
367 ) -> Expr {
368 match kind {
369 HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
370 HoleKind::Expr(_) => {
371 self.fresh_evar()
375 }
376 }
377 }
378
379 pub fn fresh_kvar_in_scope(
381 &self,
382 binders: &[BoundVariableKinds],
383 scope: &Scope,
384 encoding: KVarEncoding,
385 ) -> Expr {
386 let inner = &mut *self.inner.borrow_mut();
387 inner.kvars.fresh(binders, scope.iter(), encoding)
388 }
389
390 pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
392 let inner = &mut *self.inner.borrow_mut();
393 inner.kvars.fresh(binders, self.cursor.vars(), encoding)
394 }
395
396 fn fresh_evar(&self) -> Expr {
397 let evars = &mut self.inner.borrow_mut().evars;
398 Expr::evar(evars.fresh(self.cursor.marker()))
399 }
400
401 pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
402 if a.has_evars() {
403 return;
404 }
405 let evars = &mut self.inner.borrow_mut().evars;
406 if let ExprKind::Var(Var::EVar(evid)) = b.kind()
407 && let EVarState::Unsolved(marker) = evars.get(*evid)
408 && !marker.has_free_vars(a)
409 {
410 evars.solve(*evid, a.clone());
411 }
412 }
413
414 fn enter_exists<T, U>(
415 &mut self,
416 t: &Binder<T>,
417 f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
418 ) -> U
419 where
420 T: TypeFoldable,
421 {
422 self.ensure_resolved_evars(|infcx| {
423 let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
424 Ok(f(infcx, t))
425 })
426 .unwrap()
427 }
428
429 pub fn push_evar_scope(&mut self) {
434 self.inner.borrow_mut().evars.push_scope();
435 }
436
437 pub fn pop_evar_scope(&mut self) -> InferResult {
440 self.inner
441 .borrow_mut()
442 .evars
443 .pop_scope()
444 .map_err(InferErr::UnsolvedEvar)
445 }
446
447 pub fn ensure_resolved_evars<R>(
449 &mut self,
450 f: impl FnOnce(&mut Self) -> InferResult<R>,
451 ) -> InferResult<R> {
452 self.push_evar_scope();
453 let r = f(self)?;
454 self.pop_evar_scope()?;
455 Ok(r)
456 }
457
458 pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
459 self.inner.borrow().evars.replace_evars(t).unwrap()
460 }
461
462 pub fn tcx(&self) -> TyCtxt<'tcx> {
463 self.genv.tcx()
464 }
465
466 pub fn cursor(&self) -> &Cursor<'infcx> {
467 &self.cursor
468 }
469
470 pub fn allow_raw_deref(&self) -> bool {
471 matches!(self.allow_raw_deref, RawDerefMode::Ok)
472 }
473}
474
475impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
477 pub fn change_item<'a>(
478 &'a mut self,
479 def_id: LocalDefId,
480 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
481 ) -> InferCtxt<'a, 'genv, 'tcx> {
482 InferCtxt {
483 def_id: def_id.to_def_id(),
484 cursor: self.cursor.branch(),
485 region_infcx,
486 ..*self
487 }
488 }
489
490 pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
491 InferCtxt {
492 cursor: self
493 .cursor
494 .move_to(marker, clear_children)
495 .unwrap_or_else(|| tracked_span_bug!()),
496 ..*self
497 }
498 }
499
500 pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
501 InferCtxt { cursor: self.cursor.branch(), ..*self }
502 }
503
504 fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
505 self.cursor.define_var(sort, provenance)
506 }
507
508 pub fn define_bound_reft_var(&mut self, sort: &Sort, kind: BoundReftKind) -> Name {
509 self.define_var(sort, NameProvenance::UnfoldBoundReft(kind))
510 }
511
512 pub fn define_unknown_var(&mut self, sort: &Sort) -> Name {
513 self.cursor.define_var(sort, NameProvenance::Unknown)
514 }
515
516 pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
517 self.cursor.check_pred(pred, tag);
518 }
519
520 pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
521 self.cursor.assume_pred(pred);
522 }
523
524 pub fn unpack(&mut self, ty: &Ty) -> Ty {
525 self.hoister(false).hoist(ty)
526 }
527
528 pub fn unpack_at_name(&mut self, name: Option<Symbol>, ty: &Ty) -> Ty {
529 let mut hoister = self.hoister(false);
530 hoister.delegate.name = name;
531 hoister.hoist(ty)
532 }
533
534 pub fn marker(&self) -> Marker {
535 self.cursor.marker()
536 }
537
538 pub fn hoister(
539 &mut self,
540 assume_invariants: bool,
541 ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
542 Hoister::with_delegate(Unpacker { infcx: self, assume_invariants, name: None })
543 .transparent()
544 }
545
546 pub fn assume_invariants(&mut self, ty: &Ty) {
547 self.cursor
548 .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
549 }
550
551 fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
552 self.cursor.check_impl(pred1, pred2, tag);
553 }
554}
555
556pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
557 infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
558 assume_invariants: bool,
559 name: Option<Symbol>,
560}
561
562impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
563 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
564 let ty = ty_ctor.replace_bound_refts_with(|sort, _, kind| {
565 let kind = if let Some(name) = self.name { BoundReftKind::Named(name) } else { kind };
566 Expr::fvar(self.infcx.define_bound_reft_var(sort, kind))
567 });
568 if self.assume_invariants {
569 self.infcx.assume_invariants(&ty);
570 }
571 ty
572 }
573
574 fn hoist_constr(&mut self, pred: Expr) {
575 self.infcx.assume_pred(pred);
576 }
577}
578
579impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
580 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 std::fmt::Debug::fmt(&self.cursor, f)
582 }
583}
584
585#[derive(Debug)]
586pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
587 pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
588 pub span: Span,
589}
590
591impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
592 fn tag(&self, reason: ConstrReason) -> Tag {
593 Tag::new(reason, self.span)
594 }
595
596 pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
597 let tag = self.tag(reason);
598 self.infcx.check_pred(pred, tag);
599 }
600
601 pub fn check_non_closure_clauses(
602 &mut self,
603 clauses: &[rty::Clause],
604 reason: ConstrReason,
605 ) -> InferResult {
606 for clause in clauses {
607 if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
608 let impl_elem = BaseTy::projection(projection_pred.projection_ty)
609 .to_ty()
610 .deeply_normalize(self)?;
611 let term = projection_pred.term.to_ty().deeply_normalize(self)?;
612
613 self.subtyping(&impl_elem, &term, reason)?;
615 self.subtyping(&term, &impl_elem, reason)?;
616 }
617 }
618 Ok(())
619 }
620
621 pub fn subtyping_with_env(
624 &mut self,
625 env: &mut impl LocEnv,
626 a: &Ty,
627 b: &Ty,
628 reason: ConstrReason,
629 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
630 let mut sub = Sub::new(env, reason, self.span);
631 sub.tys(self.infcx, a, b)?;
632 Ok(sub.obligations)
633 }
634
635 pub fn subtyping(
640 &mut self,
641 a: &Ty,
642 b: &Ty,
643 reason: ConstrReason,
644 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
645 let mut env = DummyEnv;
646 let mut sub = Sub::new(&mut env, reason, self.span);
647 sub.tys(self.infcx, a, b)?;
648 Ok(sub.obligations)
649 }
650
651 pub fn subtyping_generic_args(
652 &mut self,
653 variance: Variance,
654 a: &GenericArg,
655 b: &GenericArg,
656 reason: ConstrReason,
657 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
658 let mut env = DummyEnv;
659 let mut sub = Sub::new(&mut env, reason, self.span);
660 sub.generic_args(self.infcx, variance, a, b)?;
661 Ok(sub.obligations)
662 }
663
664 pub fn check_constructor(
668 &mut self,
669 variant: EarlyBinder<PolyVariant>,
670 generic_args: &[GenericArg],
671 fields: &[Ty],
672 reason: ConstrReason,
673 ) -> InferResult<Ty> {
674 let ret = self.ensure_resolved_evars(|this| {
675 let generic_args = this.instantiate_generic_args(generic_args);
677
678 let variant = variant
679 .instantiate(this.tcx(), &generic_args, &[])
680 .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
681
682 for (actual, formal) in iter::zip(fields, variant.fields()) {
684 this.subtyping(actual, formal, reason)?;
685 }
686
687 for require in &variant.requires {
689 this.check_pred(require, ConstrReason::Fold);
690 }
691
692 Ok(variant.ret())
693 })?;
694 Ok(self.fully_resolve_evars(&ret))
695 }
696
697 pub fn ensure_resolved_evars<R>(
698 &mut self,
699 f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
700 ) -> InferResult<R> {
701 self.infcx
702 .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
703 }
704}
705
706impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
707 type Target = InferCtxt<'a, 'genv, 'tcx>;
708
709 fn deref(&self) -> &Self::Target {
710 self.infcx
711 }
712}
713
714impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
715 fn deref_mut(&mut self) -> &mut Self::Target {
716 self.infcx
717 }
718}
719
720#[derive(TypeVisitable, TypeFoldable)]
724pub(crate) enum TypeTrace {
725 Types(Ty, Ty),
726 BaseTys(BaseTy, BaseTy),
727}
728
729#[expect(dead_code, reason = "we use this for debugging some time")]
730impl TypeTrace {
731 fn tys(a: &Ty, b: &Ty) -> Self {
732 Self::Types(a.clone(), b.clone())
733 }
734
735 fn btys(a: &BaseTy, b: &BaseTy) -> Self {
736 Self::BaseTys(a.clone(), b.clone())
737 }
738}
739
740impl fmt::Debug for TypeTrace {
741 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
742 match self {
743 TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
744 TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
745 }
746 }
747}
748
/// Location environment consulted by the subtyping relation (`Sub`) when pointer-like types
/// (mutable pointers, strong references) are related.
pub trait LocEnv {
    /// Invoked when a `ptr(mut, path)` with region `re` is related to a `&mut bound` reference.
    ///
    /// `Sub` ignores the returned type and only propagates the error.
    fn ptr_to_ref(
        &mut self,
        infcx: &mut InferCtxtAt,
        reason: ConstrReason,
        re: Region,
        path: &Path,
        bound: Ty,
    ) -> InferResult<Ty>;

    /// Invoked on the left-hand strong reference (`path`, `ty`) before its type is read back
    /// with [`LocEnv::get`].
    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;

    /// Returns the type this environment associates with `path`.
    fn get(&self, path: &Path) -> Ty;
}
763
764struct DummyEnv;
765
766impl LocEnv for DummyEnv {
767 fn ptr_to_ref(
768 &mut self,
769 _: &mut InferCtxtAt,
770 _: ConstrReason,
771 _: Region,
772 _: &Path,
773 _: Ty,
774 ) -> InferResult<Ty> {
775 tracked_span_bug!("call to `ptr_to_ref` on `DummyEnv`")
776 }
777
778 fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
779 tracked_span_bug!("call to `unfold_str_ref` on `DummyEnv`")
780 }
781
782 fn get(&self, _: &Path) -> Ty {
783 tracked_span_bug!("call to `get` on `DummyEnv`")
784 }
785}
786
/// Subtyping relation between refined types.
///
/// `E` is the location environment used to resolve pointers and strong references.
struct Sub<'a, E> {
    /// Location environment (see [`LocEnv`]).
    env: &'a mut E,
    /// Reason attached to every constraint generated by this relation.
    reason: ConstrReason,
    /// Span attached to every constraint generated by this relation.
    span: Span,
    /// Coroutine obligations collected while relating coroutines to opaque types; returned
    /// to the caller instead of being checked here.
    obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
}

impl<'a, E: LocEnv> Sub<'a, E> {
    fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
        Self { env, reason, span, obligations: vec![] }
    }

    fn tag(&self) -> Tag {
        Tag::new(self.reason, self.span)
    }

    /// Checks `a <: b`.
    ///
    /// Runs on a fresh branch of the refinement tree so that names and assumptions introduced
    /// by unpacking `a` are confined to this check. Because `a` is unpacked first, only `b`
    /// can still be an existential or a constrained type.
    fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
        let infcx = &mut infcx.branch();
        let a = infcx.unpack(a);

        match (a.kind(), b.kind()) {
            (TyKind::Exists(..), _) => {
                bug!("existentials should have been removed by the unpacking above");
            }
            (TyKind::Constr(..), _) => {
                bug!("constraint types should have been removed by the unpacking above");
            }

            // Instantiate the existential on the right with fresh inference variables; they
            // must all be solved before `enter_exists` returns.
            (_, TyKind::Exists(ctor_b)) => {
                infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
            }
            // A constraint on the right becomes a proof obligation.
            (_, TyKind::Constr(pred_b, ty_b)) => {
                infcx.check_pred(pred_b, self.tag());
                self.tys(infcx, &a, ty_b)
            }

            // Relate the type stored at `path_a` with the strong reference's pointee, and try
            // to solve path evars in `path_b` with `path_a`.
            (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            // Same as above, after unfolding the left-hand strong reference in the environment.
            (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            // A mutable pointer against a `&mut` reference: the reference index must be unit;
            // the conversion itself is delegated to the environment.
            (
                TyKind::Ptr(PtrKind::Mut(re), path),
                TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
            ) => {
                self.idxs_eq(infcx, &Expr::unit(), idx);

                self.env.ptr_to_ref(
                    &mut infcx.at(self.span),
                    self.reason,
                    *re,
                    path,
                    bound.clone(),
                )?;
                Ok(())
            }

            // Indexed types: relate the base types and require equal indices.
            (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
                self.btys(infcx, bty_a, bty_b)?;
                self.idxs_eq(infcx, idx_a, idx_b);
                Ok(())
            }
            (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
                debug_assert_eq!(pk_a, pk_b);
                debug_assert_eq!(path_a, path_b);
                Ok(())
            }
            (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
                debug_assert_eq!(param_ty_a, param_ty_b);
                Ok(())
            }
            // Every type is accepted as a subtype of `Uninit`.
            (_, TyKind::Uninit) => Ok(()),
            (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
                debug_assert_eq!(fields_a.len(), fields_b.len());
                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
        }
    }

    /// Checks base type `a <: b` structurally. Mismatched shapes are reported as a query bug.
    fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
        match (a, b) {
            (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
                debug_assert_eq!(int_ty_a, int_ty_b);
                Ok(())
            }
            (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
                debug_assert_eq!(uint_ty_a, uint_ty_b);
                Ok(())
            }
            // ADT arguments are related according to the declared variances of the ADT.
            (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
                tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
                tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
                let variances = infcx.genv.variances_of(a_adt.did());
                for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
                    self.generic_args(infcx, *variance, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Function items: arguments are only compared (in debug builds) up to their
            // unrefined base types; no constraints are generated.
            (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
                debug_assert_eq!(a_def_id, b_def_id);
                debug_assert_eq!(a_args.len(), b_args.len());
                for (arg_a, arg_b) in iter::zip(a_args, b_args) {
                    match (arg_a, arg_b) {
                        (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
                            let bty_a = ty_a.as_bty_skipping_existentials();
                            let bty_b = ty_b.as_bty_skipping_existentials();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                            let bty_a = ctor_a.as_bty_skipping_binder();
                            let bty_b = ctor_b.as_bty_skipping_binder();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
                    }
                }
                Ok(())
            }
            (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
                debug_assert_eq!(float_ty_a, float_ty_b);
                Ok(())
            }
            (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),

            // Raw pointers: covariant for `*const`, invariant (both directions) for `*mut`.
            (BaseTy::RawPtr(ty_a, mut_a), BaseTy::RawPtr(ty_b, mut_b)) => {
                debug_assert_eq!(mut_a, mut_b);
                self.tys(infcx, ty_a, ty_b)?;
                if matches!(mut_a, Mutability::Mut) {
                    self.tys(infcx, ty_b, ty_a)?;
                }
                Ok(())
            }

            // `&mut` is invariant. Special case: an indexed slice on the left against an
            // existential on the right checks `a <: b` and then `b[idx_a] <: a`, i.e. the
            // reverse direction uses `b` instantiated with `a`'s index.
            (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
                if ty_a.is_slice()
                    && let TyKind::Indexed(_, idx_a) = ty_a.kind()
                    && let TyKind::Exists(bty_b) = ty_b.kind()
                {
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, &bty_b.replace_bound_reft(idx_a), ty_a)
                } else {
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, ty_b, ty_a)
                }
            }
            // Shared references are covariant.
            (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Two occurrences of the same opaque type: generic args are invariant, and
            // refinement args are unified (no constraint is emitted for them).
            (
                BaseTy::Alias(AliasKind::Opaque, alias_ty_a),
                BaseTy::Alias(AliasKind::Opaque, alias_ty_b),
            ) => {
                debug_assert_eq!(alias_ty_a.def_id, alias_ty_b.def_id);

                for (ty_a, ty_b) in izip!(alias_ty_a.args.iter(), alias_ty_b.args.iter()) {
                    self.generic_args(infcx, Invariant, ty_a, ty_b)?;
                }

                debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
                iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
                    .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));

                Ok(())
            }
            // A concrete type against an opaque type: check it satisfies the opaque's bounds.
            (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
                self.handle_opaque_type(infcx, a, alias_ty_b)
            }
            (
                BaseTy::Alias(AliasKind::Projection, alias_ty_a),
                BaseTy::Alias(AliasKind::Projection, alias_ty_b),
            ) => {
                tracked_span_dbg_assert_eq!(alias_ty_a.erase_regions(), alias_ty_b.erase_regions());
                Ok(())
            }
            (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
                tracked_span_dbg_assert_eq!(len_a, len_b);
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
                debug_assert_eq!(param_a, param_b);
                Ok(())
            }
            (BaseTy::Bool, BaseTy::Bool)
            | (BaseTy::Str, BaseTy::Str)
            | (BaseTy::Char, BaseTy::Char)
            | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
            // Trait objects and fn pointers must match exactly up to regions; this is
            // asserted in all builds, not only debug builds.
            (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
                tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
                Ok(())
            }
            (BaseTy::Closure(did1, tys_a, _, _), BaseTy::Closure(did2, tys_b, _, _))
                if did1 == did2 =>
            {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
                tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
                Ok(())
            }
            (BaseTy::Never, BaseTy::Never) => Ok(()),
            // Coroutines: upvars are covariant; the resume type is related in both directions.
            (
                BaseTy::Coroutine(did1, resume_ty_a, tys_a, _),
                BaseTy::Coroutine(did2, resume_ty_b, tys_b, _),
            ) if did1 == did2 => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                self.tys(infcx, resume_ty_b, resume_ty_a)?;
                self.tys(infcx, resume_ty_a, resume_ty_b)?;

                Ok(())
            }
            (BaseTy::Foreign(did_a), BaseTy::Foreign(did_b)) if did_a == did_b => Ok(()),
            _ => Err(query_bug!("incompatible base types: `{a:#?}` - `{b:#?}`"))?,
        }
    }

    /// Relates two generic arguments according to `variance`.
    ///
    /// Lifetimes are ignored and consts are only compared in debug builds.
    fn generic_args(
        &mut self,
        infcx: &mut InferCtxt,
        variance: Variance,
        a: &GenericArg,
        b: &GenericArg,
    ) -> InferResult {
        let (ty_a, ty_b) = match (a, b) {
            (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
            (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                tracked_span_dbg_assert_eq!(
                    ctor_a.sort().erase_regions(),
                    ctor_b.sort().erase_regions()
                );
                (ctor_a.to_ty(), ctor_b.to_ty())
            }
            (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
            (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
                debug_assert_eq!(cst_a, cst_b);
                return Ok(());
            }
            _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
        };
        match variance {
            Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
            Variance::Invariant => {
                self.tys(infcx, &ty_a, &ty_b)?;
                self.tys(infcx, &ty_b, &ty_a)
            }
            Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
            Variance::Bivariant => Ok(()),
        }
    }

    /// Requires index expressions `a` and `b` to be equal.
    ///
    /// Struct and tuple constructors are compared field by field (projecting the other side
    /// when only one side is a constructor), and lambdas are compared by applying both to
    /// the same fresh names. Kvars yield implications in both directions. Otherwise, an evar
    /// in `b` may be solved with `a` before an equality obligation is recorded.
    fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
        if a == b {
            return;
        }
        match (a.kind(), b.kind()) {
            (
                ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
                ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
            ) => {
                debug_assert_eq!(did_a, did_b);
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Only `b` is a constructor: compare each field of `b` with the matching
            // projection of `a`.
            (_, ExprKind::Tuple(flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Only `a` is a constructor: first try to solve an evar `b` with the whole of `a`,
            // then compare fieldwise against projections of `b`.
            (ExprKind::Tuple(flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, lam_a, lam_b);
            }
            // Eta-expand the non-lambda side so both sides can be compared as lambdas.
            (_, ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
            }
            (ExprKind::Abs(lam_a), _) => {
                infcx.unify_exprs(a, b);
                self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
            }
            (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
                infcx.check_impl(a, b, self.tag());
                infcx.check_impl(b, a, self.tag());
            }
            _ => {
                infcx.unify_exprs(a, b);
                let span = b.span();
                infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
            }
        }
    }

    /// Requires two lambdas to be equal by applying both to the same freshly defined names
    /// and comparing the bodies with [`Self::idxs_eq`].
    fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
        debug_assert_eq!(a.vars().len(), b.vars().len());
        let vars = a
            .vars()
            .iter()
            .map(|kind| {
                let (sort, _, kind) = kind.expect_refine();
                Expr::fvar(infcx.define_bound_reft_var(sort, kind))
            })
            .collect_vec();
        let body_a = a.apply(&vars);
        let body_b = b.apply(&vars);
        self.idxs_eq(infcx, &body_a, &body_b);
    }

    /// Checks `bty` against the opaque type `alias_ty`.
    ///
    /// A coroutine produces a coroutine obligation that is stored in `self.obligations`.
    /// For any other type, each projection bound of the opaque type is checked: the
    /// projection with `bty` as self type is normalized and must be a subtype of the bound's
    /// term. Bounds with bound variables are reported as a query bug.
    fn handle_opaque_type(
        &mut self,
        infcx: &mut InferCtxt,
        bty: &BaseTy,
        alias_ty: &AliasTy,
    ) -> InferResult {
        if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys, args) = bty {
            let obligs = mk_coroutine_obligations(
                infcx.genv,
                def_id,
                resume_ty,
                upvar_tys,
                &alias_ty.def_id,
                args.clone(),
            )?;
            self.obligations.extend(obligs);
        } else {
            let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
                infcx.tcx(),
                &alias_ty.args,
                &alias_ty.refine_args,
            );
            for clause in &bounds {
                if !clause.kind().vars().is_empty() {
                    Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
                }
                if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
                    let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
                    let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
                        .to_ty()
                        .deeply_normalize(&mut infcx.at(self.span))?;
                    let ty2 = pred.term.to_ty();
                    self.tys(infcx, &ty1, &ty2)?;
                }
            }
        }
        Ok(())
    }
}
1230
1231fn mk_coroutine_obligations(
1232 genv: GlobalEnv,
1233 generator_did: &DefId,
1234 resume_ty: &Ty,
1235 upvar_tys: &List<Ty>,
1236 opaque_def_id: &DefId,
1237 args: flux_rustc_bridge::ty::GenericArgs,
1238) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1239 let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1240 for bound in &bounds {
1241 if let Some(proj_clause) = bound.as_projection_clause() {
1242 return Ok(vec![proj_clause.map(|proj_clause| {
1243 let output = proj_clause.term;
1244 CoroutineObligPredicate {
1245 def_id: *generator_did,
1246 resume_ty: resume_ty.clone(),
1247 upvar_tys: upvar_tys.clone(),
1248 output: output.to_ty(),
1249 args,
1250 }
1251 })]);
1252 }
1253 }
1254 bug!("no projection predicate")
1255}
1256
1257#[derive(Debug)]
1258pub enum InferErr {
1259 UnsolvedEvar(EVid),
1260 Query(QueryErr),
1261}
1262
1263impl From<QueryErr> for InferErr {
1264 fn from(v: QueryErr) -> Self {
1265 Self::Query(v)
1266 }
1267}
1268
/// Pretty-printing support for [`Tag`].
mod pretty {
    use std::fmt;

    use flux_middle::pretty::*;

    use super::*;

    /// Prints `<reason> at <src_span>`, followed by ` (<dst_span>)` when a destination span
    /// is present.
    ///
    /// NOTE(review): `^` is syntax of the `w!` macro from `flux_middle::pretty`; presumably it
    /// formats the argument with its plain `Debug` impl instead of `Pretty` — confirm.
    impl Pretty for Tag {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)?;
            if let Some(dst_span) = self.dst_span {
                w!(cx, f, " ({:?})", ^dst_span)?;
            }
            Ok(())
        }
    }

    impl_debug_with_default_cx!(Tag);
}