1use std::{cell::RefCell, fmt, iter};
2
3use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_config::{self as config, InferOpts, OverflowMode};
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_middle::{
7 FixpointQueryKind,
8 def_id::MaybeExternId,
9 global_env::GlobalEnv,
10 metrics::{self, Metric},
11 queries::{QueryErr, QueryResult},
12 query_bug,
13 rty::{
14 self, AliasKind, AliasTy, BaseTy, Binder, BoundReftKind, BoundVariableKinds,
15 CoroutineObligPredicate, Ctor, ESpan, EVid, EarlyBinder, Expr, ExprKind, FieldProj,
16 GenericArg, HoleKind, InferMode, Lambda, List, Loc, Mutability, Name, NameProvenance, Path,
17 PolyVariant, PtrKind, RefineArgs, RefineArgsExt, Region, Sort, Ty, TyCtor, TyKind, Var,
18 canonicalize::{Hoister, HoisterDelegate},
19 fold::TypeFoldable,
20 },
21};
22use itertools::{Itertools, izip};
23use rustc_hir::def_id::{DefId, LocalDefId};
24use rustc_macros::extension;
25use rustc_middle::{
26 mir::BasicBlock,
27 ty::{TyCtxt, Variance},
28};
29use rustc_span::{Span, Symbol};
30use rustc_type_ir::Variance::Invariant;
31
32use crate::{
33 evars::{EVarState, EVarStore},
34 fixpoint_encoding::{
35 Answer, Backend, FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen, KVarSolutions,
36 },
37 projections::NormalizeExt as _,
38 refine_tree::{Cursor, Marker, RefineTree, Scope},
39};
40
/// Result type used throughout inference: `Ok` defaults to `()` and errors are [`InferErr`].
pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
42
/// Metadata attached to every constraint emitted during inference: why it was generated
/// and the source location(s) it relates to.
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct Tag {
    /// Why the constraint was generated.
    pub reason: ConstrReason,
    /// Span of the source location that produced the constraint.
    pub src_span: Span,
    /// Optional additional ("destination") span; `None` unless set with [`Tag::with_dst`].
    /// NOTE(review): presumably the span of the expression being checked — confirm at use sites.
    pub dst_span: Option<ESpan>,
}
49
50impl Tag {
51 pub fn new(reason: ConstrReason, span: Span) -> Self {
52 Self { reason, src_span: span, dst_span: None }
53 }
54
55 pub fn with_dst(self, dst_span: Option<ESpan>) -> Self {
56 Self { dst_span, ..self }
57 }
58}
59
/// Sub-classification carried by [`ConstrReason::Subtype`].
///
/// NOTE(review): the variant names suggest which part of a function signature (inputs,
/// output, `requires`/`ensures` clauses) produced the subtyping constraint — confirm
/// against the call sites.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum SubtypeReason {
    Input,
    Output,
    Requires,
    Ensures,
}
67
/// Why a constraint was generated; stored in [`Tag::reason`].
///
/// Variant meanings are inferred from their names — NOTE(review): confirm against the
/// places that construct them.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum ConstrReason {
    Call,
    Assign,
    Ret,
    /// Also used in this file for the `requires` clauses checked by
    /// `InferCtxtAt::check_constructor`.
    Fold,
    FoldLocal,
    Predicate,
    /// Carries a static message (presumably describing the failed assertion).
    Assert(&'static str),
    Div,
    Rem,
    /// Carries the target basic block of the jump.
    Goto(BasicBlock),
    Overflow,
    Underflow,
    /// Subtyping constraint, further classified by [`SubtypeReason`].
    Subtype(SubtypeReason),
    /// Carries a `DefId` (presumably the item whose no-panic property is being checked).
    NoPanic(DefId),
    Other,
}
86
/// Root of an inference session. It owns the refinement tree and the kvar/evar state that
/// every [`InferCtxt`] created from it (via [`InferCtxtRoot::infcx`]) shares.
pub struct InferCtxtRoot<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Kvar generator and evar store. It is wrapped in a `RefCell` because derived
    /// `InferCtxt`s only hold a shared reference to it.
    inner: RefCell<InferCtxtInner>,
    /// Constraint tree that derived contexts write into through their cursors.
    refine_tree: RefineTree,
    /// Session options: overflow mode, SMT solver and qualifier scraping.
    opts: InferOpts,
}
93
/// Builder for [`InferCtxtRoot`], obtained from `GlobalEnvExt::infcx_root`.
pub struct InferCtxtRootBuilder<'a, 'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    opts: InferOpts,
    /// Variables and their sorts, handed to `RefineTree::new` when building; i.e. the
    /// variables available at the root of the refinement tree.
    params: Vec<(Var, Sort)>,
    /// Rustc inference context used when deeply normalizing parameter sorts.
    infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
    /// Forwarded to `KVarGen::new` when building.
    dummy_kvars: bool,
}
101
102#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
103impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
104 fn infcx_root<'a>(
105 self,
106 infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
107 opts: InferOpts,
108 ) -> InferCtxtRootBuilder<'a, 'genv, 'tcx> {
109 InferCtxtRootBuilder { genv: self, infcx, params: vec![], opts, dummy_kvars: false }
110 }
111}
112
impl<'genv, 'tcx> InferCtxtRootBuilder<'_, 'genv, 'tcx> {
    /// Makes the built root use dummy kvars (the flag is forwarded to `KVarGen::new`).
    pub fn with_dummy_kvars(mut self) -> Self {
        self.dummy_kvars = true;
        self
    }

    /// Appends the const generic parameters of `def_id` to the root parameters, as
    /// `Var::ConstGeneric` with the sorts reported by `const_params`.
    ///
    /// # Errors
    /// Propagates query errors from `generics_of` and `const_params`.
    pub fn with_const_generics(mut self, def_id: DefId) -> QueryResult<Self> {
        self.params.extend(
            self.genv
                .generics_of(def_id)?
                .const_params(self.genv)?
                .into_iter()
                .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)),
        );
        Ok(self)
    }

    /// Appends the refinement generics *owned* by `def_id`, instantiated with `args`, as
    /// `Var::EarlyParam`s. Each sort is deeply normalized in the context of `def_id`.
    ///
    /// The `EarlyReftParam` index is the position among `def_id`'s own params, starting at
    /// 0, regardless of how many params were already present.
    ///
    /// # Errors
    /// Propagates query errors from `refinement_generics_of` and sort normalization.
    pub fn with_refinement_generics(
        mut self,
        def_id: DefId,
        args: &[GenericArg],
    ) -> QueryResult<Self> {
        for (index, param) in self
            .genv
            .refinement_generics_of(def_id)?
            .iter_own_params()
            .enumerate()
        {
            let param = param.instantiate(self.genv.tcx(), args, &[]);
            let sort = param
                .sort
                .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;

            let var =
                Var::EarlyParam(rty::EarlyReftParam { index: index as u32, name: param.name });
            self.params.push((var, sort));
        }
        Ok(self)
    }

    /// Sets up the root parameters as seen from inside `def_id`. It adds the item's const
    /// generics, then all of its refinement generics instantiated with the identity.
    ///
    /// # Errors
    /// Propagates query errors from the generics queries and sort normalization.
    pub fn identity_for_item(mut self, def_id: DefId) -> QueryResult<Self> {
        self = self.with_const_generics(def_id)?;
        // Number of params already present (the const generics pushed above).
        // NOTE(review): `fill_item` presumably reports indices relative to the `params`
        // vector it appends to; subtracting this offset makes the `EarlyReftParam` index
        // count refinement params only.
        let offset = self.params.len();
        self.genv.refinement_generics_of(def_id)?.fill_item(
            self.genv,
            &mut self.params,
            &mut |param, index| {
                let index = (index - offset) as u32;
                let param = param.instantiate_identity();
                let sort = param
                    .sort
                    .deeply_normalize_sorts(def_id, self.genv, self.infcx)?;

                let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
                Ok((var, sort))
            },
        )?;
        Ok(self)
    }

    /// Finishes the builder. The refinement tree is rooted at the collected params, and the
    /// kvar/evar state is freshly created. The current implementation never returns `Err`.
    pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
        Ok(InferCtxtRoot {
            genv: self.genv,
            inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
            refine_tree: RefineTree::new(self.params),
            opts: self.opts,
        })
    }
}
182
183impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
184 pub fn infcx<'a>(
185 &'a mut self,
186 def_id: DefId,
187 region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
188 ) -> InferCtxt<'a, 'genv, 'tcx> {
189 InferCtxt {
190 genv: self.genv,
191 region_infcx,
192 def_id,
193 cursor: self.refine_tree.cursor_at_root(),
194 inner: &self.inner,
195 check_overflow: self.opts.check_overflow,
196 }
197 }
198
199 pub fn fresh_kvar_in_scope(
200 &self,
201 binders: &[BoundVariableKinds],
202 scope: &Scope,
203 encoding: KVarEncoding,
204 ) -> Expr {
205 let inner = &mut *self.inner.borrow_mut();
206 inner.kvars.fresh(binders, scope.iter(), encoding)
207 }
208
209 pub fn execute_lean_query(
210 self,
211 cache: &mut FixQueryCache,
212 def_id: MaybeExternId,
213 ) -> QueryResult {
214 let inner = self.inner.into_inner();
215 let kvars = inner.kvars;
216 let evars = inner.evars;
217 let mut refine_tree = self.refine_tree;
218 refine_tree.replace_evars(&evars).unwrap();
219 refine_tree.simplify(self.genv);
220
221 let solver = match self.opts.solver {
222 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
223 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
224 };
225 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Lean);
226 let cstr = refine_tree.to_fixpoint(&mut fcx)?;
227 let cstr_variable_sorts = cstr.variable_sorts();
228 let task = fcx.create_task(def_id, cstr, self.opts.scrape_quals, solver)?;
229 let result = fcx.run_task(cache, def_id, FixpointQueryKind::Body, &task)?;
230
231 fcx.generate_lean_files(
232 def_id,
233 task,
234 KVarSolutions::closed_solutions(
235 cstr_variable_sorts,
236 result.solution,
237 result.non_cut_solution,
238 ),
239 )
240 }
241
242 pub fn execute_fixpoint_query(
243 self,
244 cache: &mut FixQueryCache,
245 def_id: MaybeExternId,
246 kind: FixpointQueryKind,
247 ) -> QueryResult<Answer<Tag>> {
248 let inner = self.inner.into_inner();
249 let kvars = inner.kvars;
250 let evars = inner.evars;
251
252 let ext = kind.ext();
253
254 let mut refine_tree = self.refine_tree;
255
256 refine_tree.replace_evars(&evars).unwrap();
257
258 if config::dump_constraint() {
259 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
260 }
261 refine_tree.simplify(self.genv);
262 if config::dump_constraint() {
263 let simp_ext = format!("simp.{ext}");
264 dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
265 .unwrap();
266 }
267
268 let backend = match self.opts.solver {
269 flux_config::SmtSolver::Z3 => liquid_fixpoint::SmtSolver::Z3,
270 flux_config::SmtSolver::CVC5 => liquid_fixpoint::SmtSolver::CVC5,
271 };
272
273 let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars, Backend::Fixpoint);
274 let cstr = refine_tree.to_fixpoint(&mut fcx)?;
275
276 let count = cstr.concrete_head_count();
278 metrics::incr_metric(Metric::CsTotal, count as u32);
279 if count == 0 {
280 metrics::incr_metric_if(kind.is_body(), Metric::FnTrivial);
281 return Ok(Answer::trivial());
282 }
283
284 let task = fcx.create_task(def_id, cstr, self.opts.scrape_quals, backend)?;
285 let result = fcx.run_task(cache, def_id, kind, &task)?;
286 Ok(fcx.result_to_answer(result))
287 }
288
289 pub fn split(self) -> (RefineTree, KVarGen) {
290 (self.refine_tree, self.inner.into_inner().kvars)
291 }
292}
293
/// Inference context for one item. It pairs a cursor into the refinement tree with shared
/// access to the kvar/evar state owned by an [`InferCtxtRoot`].
pub struct InferCtxt<'infcx, 'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Rustc inference context associated with the item being checked.
    pub region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    /// Item being checked.
    pub def_id: DefId,
    /// Overflow mode forwarded when assuming type invariants.
    pub check_overflow: OverflowMode,
    /// Current position in the refinement tree; new names and predicates are recorded here.
    cursor: Cursor<'infcx>,
    /// Shared kvar/evar state; each method borrows it only for the duration of one call.
    inner: &'infcx RefCell<InferCtxtInner>,
}
302
/// Mutable state shared by an [`InferCtxtRoot`] and every [`InferCtxt`] derived from it.
struct InferCtxtInner {
    /// Generator of fresh kvars.
    kvars: KVarGen,
    /// Existential variables, their scopes and their solutions.
    evars: EVarStore,
}
307
308impl InferCtxtInner {
309 fn new(dummy_kvars: bool) -> Self {
310 Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
311 }
312}
313
314impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
315 pub fn at(&mut self, span: Span) -> InferCtxtAt<'_, 'infcx, 'genv, 'tcx> {
316 InferCtxtAt { infcx: self, span }
317 }
318
319 pub fn instantiate_refine_args(
320 &mut self,
321 callee_def_id: DefId,
322 args: &[rty::GenericArg],
323 ) -> InferResult<List<Expr>> {
324 Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
325 let param = param.instantiate(self.genv.tcx(), args, &[]);
326 Ok(self.fresh_infer_var(¶m.sort, param.mode))
327 })?)
328 }
329
330 pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
331 args.iter()
332 .map(|a| a.replace_holes(|binders, kind| self.fresh_infer_var_for_hole(binders, kind)))
333 .collect_vec()
334 }
335
336 pub fn fresh_infer_var(&self, sort: &Sort, mode: InferMode) -> Expr {
337 match mode {
338 InferMode::KVar => {
339 let fsort = sort.expect_func().expect_mono();
340 let vars = fsort.inputs().iter().cloned().map_into().collect();
341 let kvar = self.fresh_kvar(&[vars], KVarEncoding::Single);
342 Expr::abs(Lambda::bind_with_fsort(kvar, fsort))
343 }
344 InferMode::EVar => self.fresh_evar(),
345 }
346 }
347
348 pub fn fresh_infer_var_for_hole(
349 &mut self,
350 binders: &[BoundVariableKinds],
351 kind: HoleKind,
352 ) -> Expr {
353 match kind {
354 HoleKind::Pred => self.fresh_kvar(binders, KVarEncoding::Conj),
355 HoleKind::Expr(_) => {
356 self.fresh_evar()
360 }
361 }
362 }
363
364 pub fn fresh_kvar_in_scope(
366 &self,
367 binders: &[BoundVariableKinds],
368 scope: &Scope,
369 encoding: KVarEncoding,
370 ) -> Expr {
371 let inner = &mut *self.inner.borrow_mut();
372 inner.kvars.fresh(binders, scope.iter(), encoding)
373 }
374
375 pub fn fresh_kvar(&self, binders: &[BoundVariableKinds], encoding: KVarEncoding) -> Expr {
377 let inner = &mut *self.inner.borrow_mut();
378 inner.kvars.fresh(binders, self.cursor.vars(), encoding)
379 }
380
381 fn fresh_evar(&self) -> Expr {
382 let evars = &mut self.inner.borrow_mut().evars;
383 Expr::evar(evars.fresh(self.cursor.marker()))
384 }
385
386 pub fn unify_exprs(&self, a: &Expr, b: &Expr) {
387 if a.has_evars() {
388 return;
389 }
390 let evars = &mut self.inner.borrow_mut().evars;
391 if let ExprKind::Var(Var::EVar(evid)) = b.kind()
392 && let EVarState::Unsolved(marker) = evars.get(*evid)
393 && !marker.has_free_vars(a)
394 {
395 evars.solve(*evid, a.clone());
396 }
397 }
398
399 fn enter_exists<T, U>(
400 &mut self,
401 t: &Binder<T>,
402 f: impl FnOnce(&mut InferCtxt<'_, 'genv, 'tcx>, T) -> U,
403 ) -> U
404 where
405 T: TypeFoldable,
406 {
407 self.ensure_resolved_evars(|infcx| {
408 let t = t.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
409 Ok(f(infcx, t))
410 })
411 .unwrap()
412 }
413
414 pub fn push_evar_scope(&mut self) {
419 self.inner.borrow_mut().evars.push_scope();
420 }
421
422 pub fn pop_evar_scope(&mut self) -> InferResult {
425 self.inner
426 .borrow_mut()
427 .evars
428 .pop_scope()
429 .map_err(InferErr::UnsolvedEvar)
430 }
431
432 pub fn ensure_resolved_evars<R>(
434 &mut self,
435 f: impl FnOnce(&mut Self) -> InferResult<R>,
436 ) -> InferResult<R> {
437 self.push_evar_scope();
438 let r = f(self)?;
439 self.pop_evar_scope()?;
440 Ok(r)
441 }
442
443 pub fn fully_resolve_evars<T: TypeFoldable>(&self, t: &T) -> T {
444 self.inner.borrow().evars.replace_evars(t).unwrap()
445 }
446
447 pub fn tcx(&self) -> TyCtxt<'tcx> {
448 self.genv.tcx()
449 }
450
451 pub fn cursor(&self) -> &Cursor<'infcx> {
452 &self.cursor
453 }
454}
455
impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
    /// Returns a context for another item (`def_id`, with its own `region_infcx`).
    /// It works on a new branch of the current cursor and shares the same kvar/evar state
    /// and settings; the remaining fields are copied from `self`.
    pub fn change_item<'a>(
        &'a mut self,
        def_id: LocalDefId,
        region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
    ) -> InferCtxt<'a, 'genv, 'tcx> {
        InferCtxt {
            def_id: def_id.to_def_id(),
            cursor: self.cursor.branch(),
            region_infcx,
            ..*self
        }
    }

    /// Returns a context whose cursor is moved to `marker` (see `Cursor::move_to` for the
    /// meaning of `clear_children`). Reports a bug if the cursor cannot be moved there.
    pub fn move_to(&mut self, marker: &Marker, clear_children: bool) -> InferCtxt<'_, 'genv, 'tcx> {
        InferCtxt {
            cursor: self
                .cursor
                .move_to(marker, clear_children)
                .unwrap_or_else(|| tracked_span_bug!()),
            ..*self
        }
    }

    /// Returns a context working on a new branch of the current cursor.
    pub fn branch(&mut self) -> InferCtxt<'_, 'genv, 'tcx> {
        InferCtxt { cursor: self.cursor.branch(), ..*self }
    }

    /// Declares a fresh name of sort `sort` at the cursor, recording its provenance.
    fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
        self.cursor.define_var(sort, provenance)
    }

    /// Declares a fresh name that comes from unfolding a bound refinement of kind `kind`.
    pub fn define_bound_reft_var(&mut self, sort: &Sort, kind: BoundReftKind) -> Name {
        self.define_var(sort, NameProvenance::UnfoldBoundReft(kind))
    }

    /// Declares a fresh name with unknown provenance.
    pub fn define_unknown_var(&mut self, sort: &Sort) -> Name {
        self.cursor.define_var(sort, NameProvenance::Unknown)
    }

    /// Records an obligation that `pred` holds at the cursor, tagged with `tag`.
    pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
        self.cursor.check_pred(pred, tag);
    }

    /// Records `pred` as an assumption at the cursor.
    pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
        self.cursor.assume_pred(pred);
    }

    /// Hoists the existentials and constraints out of `ty`: bound refinements become fresh
    /// names and constraint predicates become assumptions (see `Unpacker`). Type
    /// invariants are *not* assumed.
    pub fn unpack(&mut self, ty: &Ty) -> Ty {
        self.hoister(false).hoist(ty)
    }

    /// Like [`Self::unpack`], but when `name` is `Some`, every hoisted bound refinement is
    /// recorded as `BoundReftKind::Named(name)`.
    pub fn unpack_at_name(&mut self, name: Option<Symbol>, ty: &Ty) -> Ty {
        let mut hoister = self.hoister(false);
        hoister.delegate.name = name;
        hoister.hoist(ty)
    }

    /// Marker for the current cursor position.
    pub fn marker(&self) -> Marker {
        self.cursor.marker()
    }

    /// Builds a transparent hoister backed by an `Unpacker` over this context. When
    /// `assume_invariants` is set, the invariants of each hoisted type are also assumed.
    pub fn hoister(
        &mut self,
        assume_invariants: bool,
    ) -> Hoister<Unpacker<'_, 'infcx, 'genv, 'tcx>> {
        Hoister::with_delegate(Unpacker { infcx: self, assume_invariants, name: None })
            .transparent()
    }

    /// Assumes the invariants of `ty` at the cursor, using this context's overflow mode.
    pub fn assume_invariants(&mut self, ty: &Ty) {
        self.cursor
            .assume_invariants(self.genv.tcx(), ty, self.check_overflow);
    }

    /// Records an obligation that `pred1` implies `pred2`, tagged with `tag`.
    fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
        self.cursor.check_impl(pred1, pred2, tag);
    }
}
536
/// [`HoisterDelegate`] that hoists existentials into fresh names and constraints into
/// assumptions of the wrapped [`InferCtxt`].
pub struct Unpacker<'a, 'infcx, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    /// Whether to assume the invariants of each type produced by `hoist_exists`.
    assume_invariants: bool,
    /// When set, every hoisted bound refinement is recorded as `BoundReftKind::Named(name)`.
    name: Option<Symbol>,
}
542
543impl HoisterDelegate for Unpacker<'_, '_, '_, '_> {
544 fn hoist_exists(&mut self, ty_ctor: &TyCtor) -> Ty {
545 let ty = ty_ctor.replace_bound_refts_with(|sort, _, kind| {
546 let kind = if let Some(name) = self.name { BoundReftKind::Named(name) } else { kind };
547 Expr::fvar(self.infcx.define_bound_reft_var(sort, kind))
548 });
549 if self.assume_invariants {
550 self.infcx.assume_invariants(&ty);
551 }
552 ty
553 }
554
555 fn hoist_constr(&mut self, pred: Expr) {
556 self.infcx.assume_pred(pred);
557 }
558}
559
560impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
561 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
562 std::fmt::Debug::fmt(&self.cursor, f)
563 }
564}
565
/// An [`InferCtxt`] paired with the span attached to the constraints emitted through it.
/// It dereferences to the underlying `InferCtxt`.
#[derive(Debug)]
pub struct InferCtxtAt<'a, 'infcx, 'genv, 'tcx> {
    pub infcx: &'a mut InferCtxt<'infcx, 'genv, 'tcx>,
    pub span: Span,
}
571
572impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
573 fn tag(&self, reason: ConstrReason) -> Tag {
574 Tag::new(reason, self.span)
575 }
576
577 pub fn check_pred(&mut self, pred: impl Into<Expr>, reason: ConstrReason) {
578 let tag = self.tag(reason);
579 self.infcx.check_pred(pred, tag);
580 }
581
582 pub fn check_non_closure_clauses(
583 &mut self,
584 clauses: &[rty::Clause],
585 reason: ConstrReason,
586 ) -> InferResult {
587 for clause in clauses {
588 if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
589 let impl_elem = BaseTy::projection(projection_pred.projection_ty)
590 .to_ty()
591 .deeply_normalize(self)?;
592 let term = projection_pred.term.to_ty().deeply_normalize(self)?;
593
594 self.subtyping(&impl_elem, &term, reason)?;
596 self.subtyping(&term, &impl_elem, reason)?;
597 }
598 }
599 Ok(())
600 }
601
602 pub fn subtyping_with_env(
605 &mut self,
606 env: &mut impl LocEnv,
607 a: &Ty,
608 b: &Ty,
609 reason: ConstrReason,
610 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
611 let mut sub = Sub::new(env, reason, self.span);
612 sub.tys(self.infcx, a, b)?;
613 Ok(sub.obligations)
614 }
615
616 pub fn subtyping(
621 &mut self,
622 a: &Ty,
623 b: &Ty,
624 reason: ConstrReason,
625 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
626 let mut env = DummyEnv;
627 let mut sub = Sub::new(&mut env, reason, self.span);
628 sub.tys(self.infcx, a, b)?;
629 Ok(sub.obligations)
630 }
631
632 pub fn subtyping_generic_args(
633 &mut self,
634 variance: Variance,
635 a: &GenericArg,
636 b: &GenericArg,
637 reason: ConstrReason,
638 ) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
639 let mut env = DummyEnv;
640 let mut sub = Sub::new(&mut env, reason, self.span);
641 sub.generic_args(self.infcx, variance, a, b)?;
642 Ok(sub.obligations)
643 }
644
645 pub fn check_constructor(
649 &mut self,
650 variant: EarlyBinder<PolyVariant>,
651 generic_args: &[GenericArg],
652 fields: &[Ty],
653 reason: ConstrReason,
654 ) -> InferResult<Ty> {
655 let ret = self.ensure_resolved_evars(|this| {
656 let generic_args = this.instantiate_generic_args(generic_args);
658
659 let variant = variant
660 .instantiate(this.tcx(), &generic_args, &[])
661 .replace_bound_refts_with(|sort, mode, _| this.fresh_infer_var(sort, mode));
662
663 for (actual, formal) in iter::zip(fields, variant.fields()) {
665 this.subtyping(actual, formal, reason)?;
666 }
667
668 for require in &variant.requires {
670 this.check_pred(require, ConstrReason::Fold);
671 }
672
673 Ok(variant.ret())
674 })?;
675 Ok(self.fully_resolve_evars(&ret))
676 }
677
678 pub fn ensure_resolved_evars<R>(
679 &mut self,
680 f: impl FnOnce(&mut InferCtxtAt<'_, '_, 'genv, 'tcx>) -> InferResult<R>,
681 ) -> InferResult<R> {
682 self.infcx
683 .ensure_resolved_evars(|infcx| f(&mut infcx.at(self.span)))
684 }
685}
686
687impl<'a, 'genv, 'tcx> std::ops::Deref for InferCtxtAt<'_, 'a, 'genv, 'tcx> {
688 type Target = InferCtxt<'a, 'genv, 'tcx>;
689
690 fn deref(&self) -> &Self::Target {
691 self.infcx
692 }
693}
694
695impl std::ops::DerefMut for InferCtxtAt<'_, '_, '_, '_> {
696 fn deref_mut(&mut self) -> &mut Self::Target {
697 self.infcx
698 }
699}
700
/// A pair of types (or base types) being related; kept as a debugging aid.
#[derive(TypeVisitable, TypeFoldable)]
pub(crate) enum TypeTrace {
    /// A pair of full types.
    Types(Ty, Ty),
    /// A pair of base types.
    BaseTys(BaseTy, BaseTy),
}
709
710#[expect(dead_code, reason = "we use this for debugging some time")]
711impl TypeTrace {
712 fn tys(a: &Ty, b: &Ty) -> Self {
713 Self::Types(a.clone(), b.clone())
714 }
715
716 fn btys(a: &BaseTy, b: &BaseTy) -> Self {
717 Self::BaseTys(a.clone(), b.clone())
718 }
719}
720
721impl fmt::Debug for TypeTrace {
722 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
723 match self {
724 TypeTrace::Types(a, b) => write!(f, "{a:?} - {b:?}"),
725 TypeTrace::BaseTys(a, b) => write!(f, "{a:?} - {b:?}"),
726 }
727 }
728}
729
/// Location environment consulted by subtyping for pointers and strong references.
pub trait LocEnv {
    /// Called when a mutable pointer to `path` (with region `re`) is compared against a
    /// mutable reference whose pointee type is `bound`.
    /// NOTE(review): presumably converts the location into a borrow — confirm in the
    /// implementors.
    fn ptr_to_ref(
        &mut self,
        infcx: &mut InferCtxtAt,
        reason: ConstrReason,
        re: Region,
        path: &Path,
        bound: Ty,
    ) -> InferResult<Ty>;

    /// Called on a strong reference to `path` with type `ty` before `path` is read with
    /// [`LocEnv::get`] during subtyping.
    fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;

    /// Returns the type currently stored at `path`.
    fn get(&self, path: &Path) -> Ty;
}
744
745struct DummyEnv;
746
747impl LocEnv for DummyEnv {
748 fn ptr_to_ref(
749 &mut self,
750 _: &mut InferCtxtAt,
751 _: ConstrReason,
752 _: Region,
753 _: &Path,
754 _: Ty,
755 ) -> InferResult<Ty> {
756 tracked_span_bug!("call to `ptr_to_ref` on `DummyEnv`")
757 }
758
759 fn unfold_strg_ref(&mut self, _: &mut InferCtxt, _: &Path, _: &Ty) -> InferResult<Loc> {
760 tracked_span_bug!("call to `unfold_str_ref` on `DummyEnv`")
761 }
762
763 fn get(&self, _: &Path) -> Ty {
764 tracked_span_bug!("call to `get` on `DummyEnv`")
765 }
766}
767
/// State of a single subtyping check. It holds the location environment, the reason and
/// span attached to every emitted constraint, and the coroutine obligations collected
/// along the way.
struct Sub<'a, E> {
    /// Environment used to resolve pointers and strong references.
    env: &'a mut E,
    /// Reason attached to every constraint emitted by this check.
    reason: ConstrReason,
    /// Span attached to every constraint emitted by this check.
    span: Span,
    /// Coroutine obligations deferred by `handle_opaque_type`; returned to the caller.
    obligations: Vec<Binder<rty::CoroutineObligPredicate>>,
}
779
impl<'a, E: LocEnv> Sub<'a, E> {
    /// Creates a check with no collected obligations.
    fn new(env: &'a mut E, reason: ConstrReason, span: Span) -> Self {
        Self { env, reason, span, obligations: vec![] }
    }

    /// Tag attached to every constraint emitted by this check.
    fn tag(&self) -> Tag {
        Tag::new(self.reason, self.span)
    }

    /// Checks `a <: b`, emitting refinement constraints at the (branched) cursor.
    ///
    /// # Errors
    /// Returns a query bug for incompatible type shapes and propagates errors from the
    /// environment and nested checks.
    fn tys(&mut self, infcx: &mut InferCtxt, a: &Ty, b: &Ty) -> InferResult {
        // Work on a new branch of the cursor. The names and assumptions introduced by
        // unpacking `a` are added there rather than at the caller's position.
        // NOTE(review): relies on `Cursor::branch` semantics defined in `refine_tree`.
        let infcx = &mut infcx.branch();
        // Hoist `a`'s existentials into fresh names and its constraints into assumptions.
        let a = infcx.unpack(a);

        match (a.kind(), b.kind()) {
            (TyKind::Exists(..), _) => {
                bug!("existentials should have been removed by the unpacking above");
            }
            (TyKind::Constr(..), _) => {
                bug!("constraint types should have been removed by the unpacking above");
            }

            // Existential on the right: instantiate it with fresh inference variables
            // (in a new evar scope) and recurse.
            (_, TyKind::Exists(ctor_b)) => {
                infcx.enter_exists(ctor_b, |infcx, ty_b| self.tys(infcx, &a, &ty_b))
            }
            // Constraint on the right: its predicate must hold, then compare the inner type.
            (_, TyKind::Constr(pred_b, ty_b)) => {
                infcx.check_pred(pred_b, self.tag());
                self.tys(infcx, &a, ty_b)
            }

            // Mutable pointer vs. strong reference: compare the type stored at `path_a`
            // with the reference's type, after trying to unify the two paths.
            (TyKind::Ptr(PtrKind::Mut(_), path_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            // Strong reference vs. strong reference: first let the environment unfold
            // `path_a`, then proceed as above.
            (TyKind::StrgRef(_, path_a, ty_a), TyKind::StrgRef(_, path_b, ty_b)) => {
                self.env.unfold_strg_ref(infcx, path_a, ty_a)?;
                let ty_a = self.env.get(path_a);
                infcx.unify_exprs(&path_a.to_expr(), &path_b.to_expr());
                self.tys(infcx, &ty_a, ty_b)
            }
            // Mutable pointer vs. `&mut` reference. The reference's index is equated with
            // unit, and the environment is asked to turn the pointer into a reference with
            // pointee type `bound`.
            (
                TyKind::Ptr(PtrKind::Mut(re), path),
                TyKind::Indexed(BaseTy::Ref(_, bound, Mutability::Mut), idx),
            ) => {
                self.idxs_eq(infcx, &Expr::unit(), idx);

                self.env.ptr_to_ref(
                    &mut infcx.at(self.span),
                    self.reason,
                    *re,
                    path,
                    bound.clone(),
                )?;
                Ok(())
            }

            // Indexed types: relate the base types, then require equal indices.
            (TyKind::Indexed(bty_a, idx_a), TyKind::Indexed(bty_b, idx_b)) => {
                self.btys(infcx, bty_a, bty_b)?;
                self.idxs_eq(infcx, idx_a, idx_b);
                Ok(())
            }
            // Pointers and type params must already agree (checked in debug builds only).
            (TyKind::Ptr(pk_a, path_a), TyKind::Ptr(pk_b, path_b)) => {
                debug_assert_eq!(pk_a, pk_b);
                debug_assert_eq!(path_a, path_b);
                Ok(())
            }
            (TyKind::Param(param_ty_a), TyKind::Param(param_ty_b)) => {
                debug_assert_eq!(param_ty_a, param_ty_b);
                Ok(())
            }
            // Anything is a subtype of an uninitialized slot.
            (_, TyKind::Uninit) => Ok(()),
            // Downcasts: relate fields pairwise.
            (TyKind::Downcast(.., fields_a), TyKind::Downcast(.., fields_b)) => {
                debug_assert_eq!(fields_a.len(), fields_b.len());
                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            _ => Err(query_bug!("incompatible types: `{a:?}` - `{b:?}`"))?,
        }
    }

    /// Checks `a <: b` for base types.
    ///
    /// # Errors
    /// Returns a query bug for incompatible base types and propagates nested errors.
    fn btys(&mut self, infcx: &mut InferCtxt, a: &BaseTy, b: &BaseTy) -> InferResult {
        match (a, b) {
            (BaseTy::Int(int_ty_a), BaseTy::Int(int_ty_b)) => {
                debug_assert_eq!(int_ty_a, int_ty_b);
                Ok(())
            }
            (BaseTy::Uint(uint_ty_a), BaseTy::Uint(uint_ty_b)) => {
                debug_assert_eq!(uint_ty_a, uint_ty_b);
                Ok(())
            }
            // ADTs: relate generic arguments according to the ADT's declared variances.
            (BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
                tracked_span_dbg_assert_eq!(a_adt.did(), b_adt.did());
                tracked_span_dbg_assert_eq!(a_args.len(), b_args.len());
                let variances = infcx.genv.variances_of(a_adt.did());
                for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
                    self.generic_args(infcx, *variance, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Function items: only the unrefined shapes of the generic arguments are
            // compared, and only in debug builds. No refinement constraints are emitted.
            (BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
                debug_assert_eq!(a_def_id, b_def_id);
                debug_assert_eq!(a_args.len(), b_args.len());
                for (arg_a, arg_b) in iter::zip(a_args, b_args) {
                    match (arg_a, arg_b) {
                        (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => {
                            let bty_a = ty_a.as_bty_skipping_existentials();
                            let bty_b = ty_b.as_bty_skipping_existentials();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                            let bty_a = ctor_a.as_bty_skipping_binder();
                            let bty_b = ctor_b.as_bty_skipping_binder();
                            tracked_span_dbg_assert_eq!(bty_a, bty_b);
                        }
                        (_, _) => tracked_span_dbg_assert_eq!(arg_a, arg_b),
                    }
                }
                Ok(())
            }
            (BaseTy::Float(float_ty_a), BaseTy::Float(float_ty_b)) => {
                debug_assert_eq!(float_ty_a, float_ty_b);
                Ok(())
            }
            (BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),
            // `&mut` is invariant, so the pointees are related in both directions. Special
            // case: when `a` points to an indexed slice and `b` to an existential, the
            // reverse check instantiates `b`'s binder with `a`'s index instead of with fresh
            // inference variables (presumably to keep the slice's index fixed).
            (BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
                if ty_a.is_slice()
                    && let TyKind::Indexed(_, idx_a) = ty_a.kind()
                    && let TyKind::Exists(bty_b) = ty_b.kind()
                {
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, &bty_b.replace_bound_reft(idx_a), ty_a)
                } else {
                    self.tys(infcx, ty_a, ty_b)?;
                    self.tys(infcx, ty_b, ty_a)
                }
            }
            // Shared references are covariant.
            (BaseTy::Ref(_, ty_a, Mutability::Not), BaseTy::Ref(_, ty_b, Mutability::Not)) => {
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Tuple(tys_a), BaseTy::Tuple(tys_b)) => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            // Same opaque type on both sides: generic args are invariant, and refinement
            // args are only unified (best effort, no constraints emitted).
            (
                BaseTy::Alias(AliasKind::Opaque, alias_ty_a),
                BaseTy::Alias(AliasKind::Opaque, alias_ty_b),
            ) => {
                debug_assert_eq!(alias_ty_a.def_id, alias_ty_b.def_id);

                for (ty_a, ty_b) in izip!(alias_ty_a.args.iter(), alias_ty_b.args.iter()) {
                    self.generic_args(infcx, Invariant, ty_a, ty_b)?;
                }

                debug_assert_eq!(alias_ty_a.refine_args.len(), alias_ty_b.refine_args.len());
                iter::zip(alias_ty_a.refine_args.iter(), alias_ty_b.refine_args.iter())
                    .for_each(|(expr_a, expr_b)| infcx.unify_exprs(expr_a, expr_b));

                Ok(())
            }
            // Concrete type flowing into an opaque type: check it against the opaque
            // type's bounds.
            (_, BaseTy::Alias(AliasKind::Opaque, alias_ty_b)) => {
                self.handle_opaque_type(infcx, a, alias_ty_b)
            }
            (
                BaseTy::Alias(AliasKind::Projection, alias_ty_a),
                BaseTy::Alias(AliasKind::Projection, alias_ty_b),
            ) => {
                tracked_span_dbg_assert_eq!(alias_ty_a, alias_ty_b);
                Ok(())
            }
            (BaseTy::Array(ty_a, len_a), BaseTy::Array(ty_b, len_b)) => {
                tracked_span_dbg_assert_eq!(len_a, len_b);
                self.tys(infcx, ty_a, ty_b)
            }
            (BaseTy::Param(param_a), BaseTy::Param(param_b)) => {
                debug_assert_eq!(param_a, param_b);
                Ok(())
            }
            (BaseTy::Bool, BaseTy::Bool)
            | (BaseTy::Str, BaseTy::Str)
            | (BaseTy::Char, BaseTy::Char)
            | (BaseTy::RawPtr(_, _), BaseTy::RawPtr(_, _))
            | (BaseTy::RawPtrMetadata(_), BaseTy::RawPtrMetadata(_)) => Ok(()),
            // Trait objects and fn pointers must agree up to regions (asserted, not refined).
            (BaseTy::Dynamic(preds_a, _), BaseTy::Dynamic(preds_b, _)) => {
                tracked_span_assert_eq!(preds_a.erase_regions(), preds_b.erase_regions());
                Ok(())
            }
            (BaseTy::Closure(did1, tys_a, _), BaseTy::Closure(did2, tys_b, _)) if did1 == did2 => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                Ok(())
            }
            (BaseTy::FnPtr(sig_a), BaseTy::FnPtr(sig_b)) => {
                tracked_span_assert_eq!(sig_a.erase_regions(), sig_b.erase_regions());
                Ok(())
            }
            (BaseTy::Never, BaseTy::Never) => Ok(()),
            // Coroutines: upvars are covariant and the resume type is related in both
            // directions (invariant).
            (
                BaseTy::Coroutine(did1, resume_ty_a, tys_a, _),
                BaseTy::Coroutine(did2, resume_ty_b, tys_b, _),
            ) if did1 == did2 => {
                debug_assert_eq!(tys_a.len(), tys_b.len());
                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
                    self.tys(infcx, ty_a, ty_b)?;
                }
                self.tys(infcx, resume_ty_b, resume_ty_a)?;
                self.tys(infcx, resume_ty_a, resume_ty_b)?;

                Ok(())
            }
            _ => Err(query_bug!("incompatible base types: `{a:?}` - `{b:?}`"))?,
        }
    }

    /// Relates two generic arguments according to `variance`. Lifetimes are ignored,
    /// constants must already be equal (debug-asserted), and types / base constructors are
    /// compared as types.
    ///
    /// # Errors
    /// Returns a query bug for mismatched argument kinds and propagates nested errors.
    fn generic_args(
        &mut self,
        infcx: &mut InferCtxt,
        variance: Variance,
        a: &GenericArg,
        b: &GenericArg,
    ) -> InferResult {
        let (ty_a, ty_b) = match (a, b) {
            (GenericArg::Ty(ty_a), GenericArg::Ty(ty_b)) => (ty_a.clone(), ty_b.clone()),
            (GenericArg::Base(ctor_a), GenericArg::Base(ctor_b)) => {
                tracked_span_dbg_assert_eq!(ctor_a.sort(), ctor_b.sort());
                (ctor_a.to_ty(), ctor_b.to_ty())
            }
            (GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
            (GenericArg::Const(cst_a), GenericArg::Const(cst_b)) => {
                debug_assert_eq!(cst_a, cst_b);
                return Ok(());
            }
            _ => Err(query_bug!("incompatible generic args: `{a:?}` `{b:?}`"))?,
        };
        match variance {
            Variance::Covariant => self.tys(infcx, &ty_a, &ty_b),
            Variance::Invariant => {
                self.tys(infcx, &ty_a, &ty_b)?;
                self.tys(infcx, &ty_b, &ty_a)
            }
            Variance::Contravariant => self.tys(infcx, &ty_b, &ty_a),
            Variance::Bivariant => Ok(()),
        }
    }

    /// Requires the refinement indices `a` and `b` to be equal. Where possible, evars are
    /// solved by unification and constructors are compared structurally; any remaining
    /// differences become emitted constraints.
    fn idxs_eq(&mut self, infcx: &mut InferCtxt, a: &Expr, b: &Expr) {
        // Syntactically equal: nothing to check.
        if a == b {
            return;
        }
        match (a.kind(), b.kind()) {
            // Same struct constructor on both sides: compare fields pairwise.
            (
                ExprKind::Ctor(Ctor::Struct(did_a), flds_a),
                ExprKind::Ctor(Ctor::Struct(did_b), flds_b),
            ) => {
                debug_assert_eq!(did_a, did_b);
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Tuples on both sides: compare fields pairwise (zip stops at the shorter one).
            (ExprKind::Tuple(flds_a), ExprKind::Tuple(flds_b)) => {
                for (a, b) in iter::zip(flds_a, flds_b) {
                    self.idxs_eq(infcx, a, b);
                }
            }
            // Tuple only on the right: project `a` onto each field and compare.
            (_, ExprKind::Tuple(flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_b.len(), field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Struct constructor only on the right: project `a` onto each field.
            (_, ExprKind::Ctor(Ctor::Struct(def_id), flds_b)) => {
                for (f, b) in flds_b.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let a = a.proj_and_reduce(proj);
                    self.idxs_eq(infcx, &a, b);
                }
            }

            // Tuple only on the left: first try to solve `b` (if it is an evar) with the
            // whole tuple, then compare against `b`'s projections.
            (ExprKind::Tuple(flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Tuple { arity: flds_a.len(), field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            // Struct constructor only on the left: same strategy as for tuples.
            (ExprKind::Ctor(Ctor::Struct(def_id), flds_a), _) => {
                infcx.unify_exprs(a, b);
                for (f, a) in flds_a.iter().enumerate() {
                    let proj = FieldProj::Adt { def_id: *def_id, field: f as u32 };
                    let b = b.proj_and_reduce(proj);
                    self.idxs_eq(infcx, a, &b);
                }
            }
            // Lambdas: compare bodies under shared fresh names, eta-expanding the
            // non-lambda side when only one side is a lambda.
            (ExprKind::Abs(lam_a), ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, lam_a, lam_b);
            }
            (_, ExprKind::Abs(lam_b)) => {
                self.abs_eq(infcx, &a.eta_expand_abs(lam_b.vars(), lam_b.output()), lam_b);
            }
            (ExprKind::Abs(lam_a), _) => {
                infcx.unify_exprs(a, b);
                self.abs_eq(infcx, lam_a, &b.eta_expand_abs(lam_a.vars(), lam_a.output()));
            }
            // A kvar on either side: require implication in both directions.
            (ExprKind::KVar(_), _) | (_, ExprKind::KVar(_)) => {
                infcx.check_impl(a, b, self.tag());
                infcx.check_impl(b, a, self.tag());
            }
            // Fallback: try to solve an evar `b`, then require `a == b` at `b`'s span.
            _ => {
                infcx.unify_exprs(a, b);
                let span = b.span();
                infcx.check_pred(Expr::binary_op(rty::BinOp::Eq, a, b).at_opt(span), self.tag());
            }
        }
    }

    /// Requires two lambdas to be equal. Both are applied to the same fresh names, one per
    /// bound variable (expected to be refinement variables), and the bodies are compared.
    fn abs_eq(&mut self, infcx: &mut InferCtxt, a: &Lambda, b: &Lambda) {
        debug_assert_eq!(a.vars().len(), b.vars().len());
        let vars = a
            .vars()
            .iter()
            .map(|kind| {
                let (sort, _, kind) = kind.expect_refine();
                Expr::fvar(infcx.define_bound_reft_var(sort, kind))
            })
            .collect_vec();
        let body_a = a.apply(&vars);
        let body_b = b.apply(&vars);
        self.idxs_eq(infcx, &body_a, &body_b);
    }

    /// Checks that `bty` satisfies the opaque type `alias_ty`.
    ///
    /// - Coroutines: a coroutine obligation is built and deferred in `self.obligations`.
    /// - Otherwise: for every projection bound `<Self as Trait>::Assoc == Term` of the
    ///   opaque type, `Self` is replaced with `bty` and the projection is normalized. The
    ///   result must then be a subtype of `Term`.
    ///
    /// # Errors
    /// Returns a query bug if a bound has bound variables, and propagates query,
    /// normalization and subtyping errors.
    fn handle_opaque_type(
        &mut self,
        infcx: &mut InferCtxt,
        bty: &BaseTy,
        alias_ty: &AliasTy,
    ) -> InferResult {
        if let BaseTy::Coroutine(def_id, resume_ty, upvar_tys, args) = bty {
            let obligs = mk_coroutine_obligations(
                infcx.genv,
                def_id,
                resume_ty,
                upvar_tys,
                &alias_ty.def_id,
                args.clone(),
            )?;
            self.obligations.extend(obligs);
        } else {
            let bounds = infcx.genv.item_bounds(alias_ty.def_id)?.instantiate(
                infcx.tcx(),
                &alias_ty.args,
                &alias_ty.refine_args,
            );
            for clause in &bounds {
                if !clause.kind().vars().is_empty() {
                    Err(query_bug!("handle_opaque_types: clause with bound vars: `{clause:?}`"))?;
                }
                if let rty::ClauseKind::Projection(pred) = clause.kind_skipping_binder() {
                    let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
                    let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
                        .to_ty()
                        .deeply_normalize(&mut infcx.at(self.span))?;
                    let ty2 = pred.term.to_ty();
                    self.tys(infcx, &ty1, &ty2)?;
                }
            }
        }
        Ok(())
    }
}
1196
1197fn mk_coroutine_obligations(
1198 genv: GlobalEnv,
1199 generator_did: &DefId,
1200 resume_ty: &Ty,
1201 upvar_tys: &List<Ty>,
1202 opaque_def_id: &DefId,
1203 args: flux_rustc_bridge::ty::GenericArgs,
1204) -> InferResult<Vec<Binder<rty::CoroutineObligPredicate>>> {
1205 let bounds = genv.item_bounds(*opaque_def_id)?.skip_binder();
1206 for bound in &bounds {
1207 if let Some(proj_clause) = bound.as_projection_clause() {
1208 return Ok(vec![proj_clause.map(|proj_clause| {
1209 let output = proj_clause.term;
1210 CoroutineObligPredicate {
1211 def_id: *generator_did,
1212 resume_ty: resume_ty.clone(),
1213 upvar_tys: upvar_tys.clone(),
1214 output: output.to_ty(),
1215 args,
1216 }
1217 })]);
1218 }
1219 }
1220 bug!("no projection predicate")
1221}
1222
/// Errors produced during inference.
#[derive(Debug)]
pub enum InferErr {
    /// An existential variable was still unsolved when its scope was popped
    /// (see `InferCtxt::pop_evar_scope`).
    UnsolvedEvar(EVid),
    /// A query failed.
    Query(QueryErr),
}
1228
1229impl From<QueryErr> for InferErr {
1230 fn from(v: QueryErr) -> Self {
1231 Self::Query(v)
1232 }
1233}
1234
/// Pretty-printing support for the types defined in this module.
mod pretty {
    use std::fmt;

    use flux_middle::pretty::*;

    use super::*;

    /// Prints a tag as `<reason> at <src_span>`, followed by ` (<dst_span>)` when a
    /// destination span is present.
    impl Pretty for Tag {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?} at {:?}", ^self.reason, self.src_span)?;
            if let Some(dst_span) = self.dst_span {
                w!(cx, f, " ({:?})", ^dst_span)?;
            }
            Ok(())
        }
    }

    // NOTE(review): presumably implements `Debug` for `Tag` by delegating to the `Pretty`
    // impl above with a default `PrettyCx`.
    impl_debug_with_default_cx!(Tag);
}