1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9 fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10 global_env::GlobalEnv,
11 queries::QueryResult,
12 rty::{
13 self, AdtSortDef, FuncSort, List, RecordCtor, WfckResults,
14 fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15 },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
28type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
/// Sort-inference context for a single flux owner (`FluxOwnerId`).
///
/// Sorts are inferred by unification: `rty::Sort::Infer` and `rty::BvSize::Infer` variables are
/// created on demand and solved through the two unification tables below. Information that can
/// only be finalized once every constraint has been seen is stashed in the `sort_of_*` /
/// `sort_args_of_app` side tables and written into `wfckresults` by `into_results`.
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The owner whose refinements are being sort-checked.
    pub owner: FluxOwnerId,
    /// Accumulated results (node sorts, coercions, field projections, record constructors, ...).
    pub wfckresults: WfckResults,
    /// Unification table for sort variables.
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for bit-vector size variables.
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Declared refinement parameters together with their (possibly still unresolved) sort.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sorts registered through `insert_node_sort` (read back by `synth_path`, `sort_of_bty`, ...).
    node_sort: FxHashMap<FhirId, rty::Sort>,
    /// Generic arguments registered through `insert_path_args`.
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sorts of alias-reft applications, registered through `insert_sort_for_alias_reft`.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
    /// Numeric sort variables created for integer literals; `into_results` tries to default them
    /// to `int` before resolving.
    sort_of_literal: NodeMap<'genv, rty::Sort>,
    /// Operand sorts of ordering comparisons and arithmetic/bitwise operators, resolved in
    /// `into_results`.
    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
    /// Sort arguments used to instantiate polymorphic function sorts at application sites.
    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
}
44
45pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
46 match op {
47 fhir::BinOp::BitAnd
48 | fhir::BinOp::BitOr
49 | fhir::BinOp::BitXor
50 | fhir::BinOp::BitShl
51 | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
52 _ => None,
53 }
54}
55
56impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
57 pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
58 let mut sort_unification_table = InPlaceUnificationTable::new();
60 sort_unification_table.new_key(Default::default());
61 Self {
62 genv,
63 owner,
64 wfckresults: WfckResults::new(owner),
65 sort_unification_table,
66 bv_size_unification_table: InPlaceUnificationTable::new(),
67 params: Default::default(),
68 node_sort: Default::default(),
69 path_args: Default::default(),
70 sort_of_alias_reft: Default::default(),
71 sort_of_literal: Default::default(),
72 sort_of_bin_op: Default::default(),
73 sort_args_of_app: Default::default(),
74 }
75 }
76
77 fn check_abs(
78 &mut self,
79 expr: &fhir::Expr<'genv>,
80 params: &[fhir::RefineParam],
81 body: &fhir::Expr<'genv>,
82 expected: &rty::Sort,
83 ) -> Result {
84 let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
85 return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
86 };
87
88 let fsort = fsort.expect_mono();
89
90 if params.len() != fsort.inputs().len() {
91 return Err(self.emit_err(errors::ParamCountMismatch::new(
92 expr.span,
93 fsort.inputs().len(),
94 params.len(),
95 )));
96 }
97 iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
98 let found = self.param_sort(param.id);
99 if self.try_equate(&found, expected).is_none() {
100 return Err(self.emit_sort_mismatch(param.span, expected, &found));
101 }
102 Ok(())
103 })?;
104 self.check_expr(body, fsort.output())?;
105 self.wfckresults
106 .node_sorts_mut()
107 .insert(body.fhir_id, fsort.output().clone());
108 Ok(())
109 }
110
111 fn check_field_exprs(
112 &mut self,
113 expr_span: Span,
114 sort_def: &AdtSortDef,
115 sort_args: &[rty::Sort],
116 field_exprs: &[fhir::FieldExpr<'genv>],
117 spread: Option<&fhir::Spread<'genv>>,
118 expected: &rty::Sort,
119 ) -> Result {
120 let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
121 let mut used_fields = FxHashMap::default();
122 for expr in field_exprs {
123 let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
125 return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
126 };
127 if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
128 return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
129 }
130 self.check_expr(&expr.expr, sort)?;
131 }
132 if let Some(spread) = spread {
133 self.check_expr(&spread.expr, expected)
135 } else if sort_by_field_name.len() != used_fields.len() {
136 let missing_fields = sort_by_field_name
138 .into_keys()
139 .filter(|x| !used_fields.contains_key(x))
140 .collect();
141 Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
142 } else {
143 Ok(())
144 }
145 }
146
147 fn check_constructor(
148 &mut self,
149 expr: &fhir::Expr,
150 field_exprs: &[fhir::FieldExpr<'genv>],
151 spread: Option<&fhir::Spread<'genv>>,
152 expected: &rty::Sort,
153 ) -> Result {
154 let expected = self.resolve_vars_if_possible(expected);
155 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
156 self.wfckresults
157 .record_ctors_mut()
158 .insert(expr.fhir_id, RecordCtor::Struct(sort_def.did()));
159 self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
160 Ok(())
161 } else {
162 Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
163 }
164 }
165
166 fn check_record(
167 &mut self,
168 arg: &fhir::Expr<'genv>,
169 flds: &[fhir::Expr<'genv>],
170 expected: &rty::Sort,
171 ) -> Result {
172 if let Some(sorts) = expected.field_sorts() {
173 if flds.len() != sorts.len() {
174 return Err(self.emit_err(errors::ArgCountMismatch::new(
175 Some(arg.span),
176 String::from("type"),
177 sorts.len(),
178 flds.len(),
179 )));
180 }
181 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), _) = expected {
182 self.wfckresults
183 .record_ctors_mut()
184 .insert(arg.fhir_id, RecordCtor::Struct(sort_def.did()));
185 } else if matches!(expected, rty::Sort::RawPtr) {
186 self.wfckresults
187 .record_ctors_mut()
188 .insert(arg.fhir_id, RecordCtor::RawPtr);
189 }
190
191 izip!(flds, &sorts)
192 .map(|(arg, expected)| self.check_expr(arg, expected))
193 .try_collect_exhaust()
194 } else {
195 Err(self.emit_err(errors::ArgCountMismatch::new(
196 Some(arg.span),
197 String::from("type"),
198 1,
199 flds.len(),
200 )))
201 }
202 }
203
    /// Checks that `expr` has, or can be coerced to, the sort `expected`.
    ///
    /// Some forms are checked directly against `expected`:
    /// - if-then-else: the condition against `bool`, both branches against `expected`;
    /// - lambdas: `check_abs`;
    /// - positional records: `check_record`;
    /// - unqualified constructors: `check_constructor`;
    /// - tuples, when `expected` is a tuple sort of the same arity;
    /// - error nodes, which are accepted as-is.
    ///
    /// Every other form falls through. Its sort is synthesized with `synth_expr` and must then be
    /// coercible to `expected` (see `is_coercible`); otherwise a sort mismatch is reported.
    pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
        match expr.kind {
            fhir::ExprKind::IfThenElse(p, e1, e2) => {
                self.check_expr(p, &rty::Sort::Bool)?;
                self.check_expr(e1, expected)?;
                self.check_expr(e2, expected)?;
                return Ok(());
            }
            fhir::ExprKind::Abs(params, body) => {
                self.check_abs(expr, params, body, expected)?;
                return Ok(());
            }
            fhir::ExprKind::Record(fields) => {
                self.check_record(expr, fields, expected)?;
                return Ok(());
            }
            fhir::ExprKind::Constructor(None, exprs, spread) => {
                self.check_constructor(expr, exprs, spread, expected)?;
                return Ok(());
            }
            fhir::ExprKind::Tuple(exprs) => {
                // Element-wise checking only when `expected` is a tuple of the same arity;
                // otherwise fall through and compare the synthesized tuple sort as a whole.
                if let rty::Sort::Tuple(expected_sorts) = expected
                    && exprs.len() == expected_sorts.len()
                {
                    for (expr, expected_sort) in iter::zip(exprs, expected_sorts) {
                        self.check_expr(expr, expected_sort)?;
                    }
                    return Ok(());
                }
            }
            fhir::ExprKind::Err(_) => {
                // NOTE(review): presumably an error was already reported when this node was
                // produced, so it is accepted without further checking.
                return Ok(());
            }
            fhir::ExprKind::UnaryOp(..)
            | fhir::ExprKind::SetLiteral(..)
            | fhir::ExprKind::BinaryOp(..)
            | fhir::ExprKind::Dot(..)
            | fhir::ExprKind::App(..)
            | fhir::ExprKind::Alias(..)
            | fhir::ExprKind::Var(..)
            | fhir::ExprKind::Literal(..)
            | fhir::ExprKind::Quant(..)
            | fhir::ExprKind::Block(..)
            | fhir::ExprKind::Constructor(..)
            | fhir::ExprKind::PrimApp(..) => {}
        }
        // Synthesize-then-coerce fallback. Both sorts are resolved first so the coercion check
        // (and any error message) sees solved variables substituted.
        let found = self.synth_expr(expr)?;
        let found = self.resolve_vars_if_possible(&found);
        let expected = self.resolve_vars_if_possible(expected);
        if !self.is_coercible(&found, &expected, expr.fhir_id) {
            return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
        }
        Ok(())
    }
261
262 pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
263 let found = self.synth_path(loc);
264 if found == rty::Sort::Loc {
265 Ok(())
266 } else {
267 Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
268 }
269 }
270
271 fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
272 match lit {
273 fhir::Lit::Int(_) => {
274 let sort = self.next_sort_var_with_cstr(rty::SortCstr::NUMERIC);
275 self.sort_of_literal.insert(*expr, sort.clone());
276 sort
277 }
278 fhir::Lit::Real(_) => rty::Sort::Real,
279 fhir::Lit::Bool(_) => rty::Sort::Bool,
280 fhir::Lit::Str(_) => rty::Sort::Str,
281 fhir::Lit::Char(_) => rty::Sort::Char,
282 }
283 }
284
285 fn synth_prim_app(
286 &mut self,
287 op: &fhir::BinOp,
288 e1: &fhir::Expr<'genv>,
289 e2: &fhir::Expr<'genv>,
290 span: Span,
291 ) -> Result<rty::Sort> {
292 let Some((inputs, output)) = prim_op_sort(op) else {
293 return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
294 };
295 let [sort1, sort2] = &inputs[..] else {
296 return Err(self.emit_err(errors::ArgCountMismatch::new(
297 Some(span),
298 String::from("primop app"),
299 inputs.len(),
300 2,
301 )));
302 };
303 self.check_expr(e1, sort1)?;
304 self.check_expr(e2, sort2)?;
305 Ok(output)
306 }
307
    /// Synthesizes (infers) the sort of `expr`.
    ///
    /// Side information is recorded along the way:
    /// - field projections for `.` accesses;
    /// - sort arguments of function applications (via `instantiate_func_sort`);
    /// - literal and operator sorts (via `synth_lit` and `synth_binary_op`).
    ///
    /// Forms without a synthesis rule here report `CannotInferSort`.
    fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
        match expr.kind {
            fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
            fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
                Ok(self.synth_type_relative_path(expr))
            }
            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
            fhir::ExprKind::App(callee, args) => {
                // The callee's sort must be fully known to decide whether it is (coercible to) a
                // function sort; its sort parameters are then instantiated with fresh variables.
                let sort = self.ensure_resolved_path(&callee)?;
                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
                };
                let fsort = self.instantiate_func_sort(expr, poly_fsort);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::Quant(_, param, dom, body) => {
                // The variable of a bounded quantifier must have sort `int`; the body is a
                // boolean formula and so is the whole quantifier.
                let sort = self.param_sort(param.id);
                if let fhir::QuantDom::Bounded { .. } = dom
                    && self.try_equate(&sort, &rty::Sort::Int).is_none()
                {
                    return Err(self.emit_err(errors::IllSortedQuantifier::new(param.span, sort)));
                }
                self.check_expr(body, &rty::Sort::Bool)?;
                Ok(rty::Sort::Bool)
            }
            fhir::ExprKind::Alias(_alias_reft, args) => {
                // The alias' function sort was registered beforehand via
                // `insert_sort_for_alias_reft`.
                let fsort = self.sort_of_alias_reft(expr.fhir_id);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::IfThenElse(p, e1, e2) => {
                // The then-branch determines the sort; the else-branch is checked against it.
                self.check_expr(p, &rty::Sort::Bool)?;
                let sort = self.synth_expr(e1)?;
                self.check_expr(e2, &sort)?;
                Ok(sort)
            }
            fhir::ExprKind::Dot(base, fld) => {
                // Field access needs the base sort fully resolved to know which fields exist.
                let sort = self.synth_expr(base)?;
                let sort = self
                    .fully_resolve(&sort)
                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
                match &sort {
                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
                        let (proj, sort) = sort_def
                            .struct_variant()
                            .field_by_name(sort_def.did(), sort_args, fld.name)
                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
                        self.wfckresults
                            .field_projs_mut()
                            .insert(expr.fhir_id, proj);
                        Ok(sort)
                    }
                    rty::Sort::Tuple(sorts) => {
                        // Tuple fields are accessed by numeric name (`t.0`, `t.1`, ...).
                        if let Ok(field_idx) = fld.name.as_str().parse::<usize>()
                            && field_idx < sorts.len()
                        {
                            let proj = rty::FieldProj::Tuple {
                                arity: sorts.len(),
                                field: field_idx as u32,
                            };
                            self.wfckresults
                                .field_projs_mut()
                                .insert(expr.fhir_id, proj);
                            Ok(sorts[field_idx].clone())
                        } else {
                            Err(self.emit_field_not_found(&sort, fld))
                        }
                    }
                    rty::Sort::RawPtr => {
                        // Every known raw-pointer field has sort `int`.
                        let field = rty::RawPtrField::from_name(fld.name)
                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
                        self.wfckresults
                            .field_projs_mut()
                            .insert(expr.fhir_id, rty::FieldProj::RawPtr { field });
                        Ok(rty::Sort::Int)
                    }
                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
                    }
                    _ => Err(self.emit_field_not_found(&sort, fld)),
                }
            }
            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
                // Qualified constructor: the ADT comes from the path, and its sort parameters are
                // instantiated with fresh variables to be solved by checking the fields.
                let path_def_id = match path.res {
                    fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
                    _ => span_bug!(expr.span, "unexpected path in constructor"),
                };
                let sort_def = self
                    .genv
                    .adt_sort_def_of(path_def_id)
                    .map_err(|e| self.emit_err(e))?;
                let fresh_args: rty::List<_> = (0..sort_def.param_count())
                    .map(|_| self.next_sort_var())
                    .collect();
                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
                self.check_field_exprs(
                    expr.span,
                    &sort_def,
                    &fresh_args,
                    field_exprs,
                    spread,
                    &sort,
                )?;
                Ok(sort)
            }
            fhir::ExprKind::Block(decls, body) => {
                // Each local binding's initializer must match its parameter's sort; the block
                // has the sort of its body.
                for decl in decls {
                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
                }
                self.synth_expr(body)
            }
            fhir::ExprKind::SetLiteral(elems) => {
                // All elements share a single fresh element sort.
                let elem_sort = self.next_sort_var();
                for elem in elems {
                    self.check_expr(elem, &elem_sort)?;
                }
                Ok(rty::Sort::App(rty::SortCtor::Set, List::singleton(elem_sort)))
            }
            fhir::ExprKind::Tuple(exprs) => {
                let sorts = exprs
                    .iter()
                    .map(|expr| self.synth_expr(expr))
                    .try_collect()?;
                Ok(rty::Sort::Tuple(sorts))
            }
            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
        }
    }
446
447 fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
448 self.node_sort
449 .get(&path.fhir_id)
450 .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
451 .clone()
452 }
453
454 fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
455 self.node_sort
456 .get(&expr.fhir_id)
457 .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
458 .clone()
459 }
460
461 fn synth_binary_op(
462 &mut self,
463 expr: &fhir::Expr<'genv>,
464 op: fhir::BinOp,
465 e1: &fhir::Expr<'genv>,
466 e2: &fhir::Expr<'genv>,
467 ) -> Result<rty::Sort> {
468 match op {
469 fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
470 self.check_expr(e1, &rty::Sort::Bool)?;
471 self.check_expr(e2, &rty::Sort::Bool)?;
472 Ok(rty::Sort::Bool)
473 }
474 fhir::BinOp::Eq | fhir::BinOp::Ne => {
475 let sort = self.next_sort_var();
476 self.check_expr(e1, &sort)?;
477 self.check_expr(e2, &sort)?;
478 Ok(rty::Sort::Bool)
479 }
480 fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
481 let sort = self.next_sort_var();
482 self.check_expr(e1, &sort)?;
483 self.check_expr(e2, &sort)?;
484 self.sort_of_bin_op.insert(*expr, sort.clone());
485 Ok(rty::Sort::Bool)
486 }
487 fhir::BinOp::Add
488 | fhir::BinOp::Sub
489 | fhir::BinOp::Mul
490 | fhir::BinOp::Div
491 | fhir::BinOp::Mod
492 | fhir::BinOp::BitAnd
493 | fhir::BinOp::BitOr
494 | fhir::BinOp::BitXor
495 | fhir::BinOp::BitShl
496 | fhir::BinOp::BitShr => {
497 let sort = self.next_sort_var_with_cstr(rty::SortCstr::from_bin_op(op));
498 self.check_expr(e1, &sort)?;
499 self.check_expr(e2, &sort)?;
500 self.sort_of_bin_op.insert(*expr, sort.clone());
501 Ok(sort)
502 }
503 }
504 }
505
506 fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
507 match op {
508 fhir::UnOp::Not => {
509 self.check_expr(e, &rty::Sort::Bool)?;
510 Ok(rty::Sort::Bool)
511 }
512 fhir::UnOp::Neg => {
513 self.check_expr(e, &rty::Sort::Int)?;
514 Ok(rty::Sort::Int)
515 }
516 }
517 }
518
519 fn synth_app(
520 &mut self,
521 fsort: FuncSort,
522 args: &[fhir::Expr<'genv>],
523 span: Span,
524 ) -> Result<rty::Sort> {
525 if args.len() != fsort.inputs().len() {
526 return Err(self.emit_err(errors::ArgCountMismatch::new(
527 Some(span),
528 String::from("function"),
529 fsort.inputs().len(),
530 args.len(),
531 )));
532 }
533
534 iter::zip(args, fsort.inputs())
535 .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
536
537 Ok(fsort.output().clone())
538 }
539
540 fn instantiate_func_sort(
541 &mut self,
542 app_expr: &fhir::Expr<'genv>,
543 fsort: rty::PolyFuncSort,
544 ) -> rty::FuncSort {
545 let args = fsort
546 .params()
547 .map(|kind| {
548 match kind {
549 rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
550 rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
551 }
552 })
553 .collect_vec();
554 self.sort_args_of_app
555 .insert(*app_expr, List::from_slice(&args));
556 fsort.instantiate(&args)
557 }
558
559 pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
560 self.node_sort.insert(fhir_id, sort);
561 }
562
563 pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
564 self.node_sort
565 .get(&bty.fhir_id)
566 .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
567 .clone()
568 }
569
570 pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
571 self.path_args.insert(fhir_id, args);
572 }
573
574 pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
575 self.path_args
576 .get(&fhir_id)
577 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
578 .clone()
579 }
580
581 pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
582 self.sort_of_alias_reft.insert(fhir_id, fsort);
583 }
584
585 fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
586 self.sort_of_alias_reft
587 .get(&fhir_id)
588 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
589 .clone()
590 }
591
592 fn deeply_normalize_sorts<T: TypeFoldable + Clone>(
593 genv: GlobalEnv,
594 owner: FluxOwnerId,
595 t: &T,
596 ) -> QueryResult<T> {
597 let infcx = genv
598 .tcx()
599 .infer_ctxt()
600 .with_next_trait_solver(true)
601 .build(TypingMode::non_body_analysis());
602 if let Some(def_id) = owner.resolved_id() {
603 t.deeply_normalize_sorts(def_id, genv, &infcx)
604 } else {
605 Ok(t.clone())
606 }
607 }
608
609 pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
614 for sort in self.node_sort.values_mut() {
615 *sort = Self::deeply_normalize_sorts(self.genv, self.owner, sort)?;
616 }
617 for fsort in self.sort_of_alias_reft.values_mut() {
618 *fsort = Self::deeply_normalize_sorts(self.genv, self.owner, fsort)?;
619 }
620 Ok(())
621 }
622}
623
624impl<'genv> InferCtxt<'genv, '_> {
625 pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
626 self.params.insert(param.id, (param, sort));
627 }
628
629 fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
634 if self.try_equate(sort1, sort2).is_some() {
635 return true;
636 }
637
638 let mut sort1 = sort1.clone();
639 let mut sort2 = sort2.clone();
640
641 let mut coercions = vec![];
642 if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
643 coercions.push(rty::Coercion::Project(def_id));
644 sort1 = sort.clone();
645 }
646 if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
647 coercions.push(rty::Coercion::Inject(def_id));
648 sort2 = sort.clone();
649 }
650 self.wfckresults.coercions_mut().insert(fhir_id, coercions);
651 self.try_equate(&sort1, &sort2).is_some()
652 }
653
654 fn is_coercible_from_func(
655 &mut self,
656 sort: &rty::Sort,
657 fhir_id: FhirId,
658 ) -> Option<rty::PolyFuncSort> {
659 if let rty::Sort::Func(fsort) = sort {
660 Some(fsort.clone())
661 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
662 self.wfckresults
663 .coercions_mut()
664 .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
665 Some(fsort.clone())
666 } else {
667 None
668 }
669 }
670
671 fn is_coercible_to_func(
672 &mut self,
673 sort: &rty::Sort,
674 fhir_id: FhirId,
675 ) -> Option<rty::PolyFuncSort> {
676 if let rty::Sort::Func(fsort) = sort {
677 Some(fsort.clone())
678 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
679 self.wfckresults
680 .coercions_mut()
681 .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
682 Some(fsort.clone())
683 } else {
684 None
685 }
686 }
687
688 fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
689 let sort1 = self.resolve_vars_if_possible(sort1);
690 let sort2 = self.resolve_vars_if_possible(sort2);
691 self.try_equate_inner(&sort1, &sort2)
692 }
693
    /// Structurally unifies two sorts whose solved variables the caller has already substituted
    /// (see `try_equate`). Returns `sort1` on success.
    ///
    /// - Two variables are unified with each other; a variable and a non-variable are unified by
    ///   solving the variable.
    /// - `App` sorts need the same constructor and arity and are unified argument-wise.
    /// - `Tuple` sorts need the same arity and are unified element-wise.
    /// - Bit-vector sorts unify their sizes (`try_equate_bv_sizes`).
    /// - Anything else must be equal.
    ///
    /// NOTE(review): unifications already performed for earlier arguments are not rolled back
    /// when a later argument fails to unify.
    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
        match (sort1, sort2) {
            (rty::Sort::Infer(vid1), rty::Sort::Infer(vid2)) => {
                self.sort_unification_table
                    .unify_var_var(*vid1, *vid2)
                    .ok()?;
            }
            (rty::Sort::Infer(vid), sort) | (sort, rty::Sort::Infer(vid)) => {
                self.sort_unification_table
                    .unify_var_value(*vid, rty::SortVarVal::Solved(sort.clone()))
                    .ok()?;
            }
            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
                if ctor1 != ctor2 || args1.len() != args2.len() {
                    return None;
                }
                for (s1, s2) in iter::zip(args1, args2) {
                    self.try_equate_inner(s1, s2)?;
                }
            }
            (rty::Sort::Tuple(sorts1), rty::Sort::Tuple(sorts2)) => {
                if sorts1.len() != sorts2.len() {
                    return None;
                }
                for (s1, s2) in iter::zip(sorts1, sorts2) {
                    self.try_equate_inner(s1, s2)?;
                }
            }
            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
                self.try_equate_bv_sizes(*size1, *size2)?;
            }
            _ if sort1 == sort2 => {}
            _ => return None,
        }
        Some(sort1.clone())
    }
730
731 fn try_equate_bv_sizes(
732 &mut self,
733 size1: rty::BvSize,
734 size2: rty::BvSize,
735 ) -> Option<rty::BvSize> {
736 match (size1, size2) {
737 (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
738 self.bv_size_unification_table
739 .unify_var_var(vid1, vid2)
740 .ok()?;
741 }
742 (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
743 self.bv_size_unification_table
744 .unify_var_value(vid, Some(size))
745 .ok()?;
746 }
747 _ if size1 == size2 => {}
748 _ => return None,
749 }
750 Some(size1)
751 }
752
753 fn next_sort_var(&mut self) -> rty::Sort {
754 rty::Sort::Infer(self.next_sort_vid(Default::default()))
755 }
756
757 fn next_sort_var_with_cstr(&mut self, cstr: rty::SortCstr) -> rty::Sort {
758 rty::Sort::Infer(self.next_sort_vid(rty::SortVarVal::Unsolved(cstr)))
759 }
760
761 pub(crate) fn next_sort_vid(&mut self, val: rty::SortVarVal) -> rty::SortVid {
762 self.sort_unification_table.new_key(val)
763 }
764
765 fn next_bv_size_var(&mut self) -> rty::BvSize {
766 rty::BvSize::Infer(self.next_bv_size_vid())
767 }
768
769 fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
770 self.bv_size_unification_table.new_key(None)
771 }
772
773 fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
774 let sort = self.synth_path(path);
775 self.fully_resolve(&sort)
776 .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
777 }
778
779 fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
780 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
781 && let Some(variant) = sort_def.opt_struct_variant()
782 && let [sort] = &variant.field_sorts(sort_args)[..]
783 {
784 Some((sort_def.did(), sort.clone()))
785 } else {
786 None
787 }
788 }
789
790 pub(crate) fn into_results(mut self) -> Result<WfckResults> {
791 for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
794 if let rty::Sort::Infer(vid) = &sort {
800 let _ = self
801 .sort_unification_table
802 .unify_var_value(*vid, rty::SortVarVal::Solved(rty::Sort::Int));
803 }
804
805 let sort = self
806 .fully_resolve(&sort)
807 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
808 self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
809 }
810
811 for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
813 let sort = self
814 .fully_resolve(&sort)
815 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
816 self.wfckresults
817 .bin_op_sorts_mut()
818 .insert(node.fhir_id, sort);
819 }
820
821 let allow_uninterpreted_cast =
822 self.owner
823 .as_rust()
824 .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
825 self.genv
826 .infer_opts(def_id.local_id().def_id)
827 .allow_uninterpreted_cast
828 });
829
830 for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
832 let mut res = vec![];
833 for sort_arg in &sort_args {
834 let sort_arg = match sort_arg {
835 rty::SortArg::Sort(sort) => {
836 let sort = self
837 .fully_resolve(sort)
838 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
839 rty::SortArg::Sort(sort)
840 }
841 rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
842 let size = self
843 .bv_size_unification_table
844 .probe_value(*vid)
845 .ok_or_else(|| {
846 self.emit_err(errors::CannotInferSort::new(node.span))
847 })?;
848 rty::SortArg::BvSize(size)
849 }
850 _ => sort_arg.clone(),
851 };
852 res.push(sort_arg);
853 }
854 if let fhir::ExprKind::App(callee, _) = node.kind
855 && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
856 {
857 let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
858 span_bug!(node.span, "invalid sort args!")
859 };
860 if !allow_uninterpreted_cast
861 && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
862 {
863 return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
864 }
865 }
866 self.wfckresults
867 .fn_app_sorts_mut()
868 .insert(node.fhir_id, res.into());
869 }
870
871 for (_, (param, sort)) in std::mem::take(&mut self.params) {
873 let sort = self
874 .fully_resolve(&sort)
875 .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(¶m)))?;
876 self.wfckresults.param_sorts_mut().insert(param.id, sort);
877 }
878
879 Ok(self.wfckresults)
880 }
881
882 pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
883 fhir::InferMode::from_param_kind(self.params[&id].0.kind)
884 }
885
886 #[track_caller]
887 pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
888 self.params
889 .get(&id)
890 .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
891 .1
892 .clone()
893 }
894
895 fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
896 sort.fold_with(&mut ShallowResolver { infcx: self })
897 }
898
899 fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
900 sort.fold_with(&mut OpportunisticResolver { infcx: self })
901 }
902
903 pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
904 sort.try_fold_with(&mut FullResolver { infcx: self })
905 }
906}
907
/// Visitor that constrains parameter sorts from the positions where parameters are bound.
///
/// For every indexed type `bty[idx]`, parameter binders inside `idx` are equated with the
/// corresponding part of `bty`'s sort. Binders are resolved variable paths carrying `Some(_)`,
/// presumably the implicit-parameter syntax (TODO confirm).
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    /// Inference context whose parameter sorts are being constrained.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    /// Diagnostics emitted during the walk; turned into the result of `infer`.
    errors: Errors<'genv>,
}
922
923impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
924 pub(crate) fn infer(
925 infcx: &'a mut InferCtxt<'genv, 'tcx>,
926 node: &fhir::OwnerNode<'genv>,
927 ) -> Result {
928 let errors = Errors::new(infcx.genv.sess());
929 let mut vis = Self { infcx, errors };
930 vis.visit_node(node);
931 vis.errors.to_result()
932 }
933
934 fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
935 match idx.kind {
936 fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
937 let (_, id) = var.res.expect_param();
938 let found = self.infcx.param_sort(id);
939 self.infcx.try_equate(&found, expected).unwrap_or_else(|| {
940 span_bug!(idx.span, "failed to equate sorts: `{found:?}` `{expected:?}`")
941 });
942 }
943 fhir::ExprKind::Record(flds) => {
944 if let Some(sorts) = expected.field_sorts() {
945 if flds.len() != sorts.len() {
946 self.errors.emit(errors::ArgCountMismatch::new(
947 Some(idx.span),
948 String::from("type"),
949 sorts.len(),
950 flds.len(),
951 ));
952 return;
953 }
954 for (f, sort) in iter::zip(flds, &sorts) {
955 self.infer_implicit_params(f, sort);
956 }
957 } else {
958 self.errors.emit(errors::ArgCountMismatch::new(
959 Some(idx.span),
960 String::from("type"),
961 1,
962 flds.len(),
963 ));
964 }
965 }
966 _ => {}
967 }
968 }
969}
970
971impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
972 fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
973 if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
974 let expected = self.infcx.sort_of_bty(bty);
975 self.infer_implicit_params(idx, &expected);
976 }
977 fhir::visit::walk_ty(self, ty);
978 }
979}
980
981impl InferCtxt<'_, '_> {
982 #[track_caller]
983 fn emit_sort_mismatch(
984 &mut self,
985 span: Span,
986 expected: &rty::Sort,
987 found: &rty::Sort,
988 ) -> ErrorGuaranteed {
989 let expected = self.resolve_vars_if_possible(expected);
990 let found = self.resolve_vars_if_possible(found);
991 self.emit_err(errors::SortMismatch::new(span, expected, found))
992 }
993
994 fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
995 self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
996 }
997
998 #[track_caller]
999 fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
1000 self.genv.sess().emit_err(err)
1001 }
1002}
1003
1004struct ShallowResolver<'a, 'genv, 'tcx> {
1005 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1006}
1007
1008impl TypeFolder for ShallowResolver<'_, '_, '_> {
1009 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1010 match sort {
1011 rty::Sort::Infer(vid) => {
1012 self.infcx
1018 .sort_unification_table
1019 .probe_value(*vid)
1020 .map_solved(|sort| sort.fold_with(self))
1021 .solved_or(sort)
1022 }
1023 rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1024 self.infcx
1025 .bv_size_unification_table
1026 .probe_value(*vid)
1027 .map(rty::Sort::BitVec)
1028 .unwrap_or(sort.clone())
1029 }
1030 _ => sort.clone(),
1031 }
1032 }
1033}
1034
1035struct OpportunisticResolver<'a, 'genv, 'tcx> {
1036 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1037}
1038
1039impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1040 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1041 let s = self.infcx.shallow_resolve(sort);
1042 s.super_fold_with(self)
1043 }
1044}
1045
1046struct FullResolver<'a, 'genv, 'tcx> {
1047 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1048}
1049
1050impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1051 type Error = ();
1052
1053 fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1054 let s = self.infcx.shallow_resolve(sort);
1055 match s {
1056 rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1057 _ => s.try_super_fold_with(self),
1058 }
1059 }
1060}
1061
1062#[derive_where(Default)]
1066struct NodeMap<'genv, T> {
1067 map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
1068}
1069
1070impl<'genv, T> NodeMap<'genv, T> {
1071 fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1073 assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1074 }
1075}
1076
1077impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1078 type Item = (fhir::Expr<'genv>, T);
1079
1080 type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1081
1082 fn into_iter(self) -> Self::IntoIter {
1083 self.map.into_values()
1084 }
1085}