1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9 fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10 global_env::GlobalEnv,
11 queries::QueryResult,
12 rty::{
13 self, AdtSortDef, FuncSort, List, WfckResults,
14 fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15 },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
28type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
/// Sort-checking / sort-inference context for a single flux owner.
///
/// Holds the unification tables for sort and bit-vector-size inference variables, the sorts
/// recorded for nodes and refinement parameters while checking, and the accumulated
/// [`WfckResults`], which are finalized by `into_results`.
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The item whose refinements are being checked.
    pub owner: FluxOwnerId,
    /// Results accumulated while checking.
    pub wfckresults: WfckResults,
    /// Unification table for sort inference variables (`rty::Sort::Infer`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for bit-vector size inference variables (`rty::BvSize::Infer`).
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Declared refinement parameters together with their (possibly not yet inferred) sorts.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sorts of paths, base types and type-relative path expressions, keyed by node id.
    node_sort: FxHashMap<FhirId, rty::Sort>,
    /// Generic arguments recorded for paths via `insert_path_args`.
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sorts of alias-reft applications, keyed by the application's node id.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
    /// Sorts of integer literals without an explicit kind; defaulted to `int` in `into_results`.
    sort_of_literal: NodeMap<'genv, rty::Sort>,
    /// Operand sorts of ordering comparisons and of arithmetic/bitwise binary operators.
    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
    /// Sort arguments used to instantiate polymorphic function sorts at applications.
    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
}
44
45pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
46 match op {
47 fhir::BinOp::BitAnd
48 | fhir::BinOp::BitOr
49 | fhir::BinOp::BitXor
50 | fhir::BinOp::BitShl
51 | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
52 _ => None,
53 }
54}
55
56impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
57 pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
58 let mut sort_unification_table = InPlaceUnificationTable::new();
60 sort_unification_table.new_key(Default::default());
61 Self {
62 genv,
63 owner,
64 wfckresults: WfckResults::new(owner),
65 sort_unification_table,
66 bv_size_unification_table: InPlaceUnificationTable::new(),
67 params: Default::default(),
68 node_sort: Default::default(),
69 path_args: Default::default(),
70 sort_of_alias_reft: Default::default(),
71 sort_of_literal: Default::default(),
72 sort_of_bin_op: Default::default(),
73 sort_args_of_app: Default::default(),
74 }
75 }
76
77 fn check_abs(
78 &mut self,
79 expr: &fhir::Expr<'genv>,
80 params: &[fhir::RefineParam],
81 body: &fhir::Expr<'genv>,
82 expected: &rty::Sort,
83 ) -> Result {
84 let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
85 return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
86 };
87
88 let fsort = fsort.expect_mono();
89
90 if params.len() != fsort.inputs().len() {
91 return Err(self.emit_err(errors::ParamCountMismatch::new(
92 expr.span,
93 fsort.inputs().len(),
94 params.len(),
95 )));
96 }
97 iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
98 let found = self.param_sort(param.id);
99 if self.try_equate(&found, expected).is_none() {
100 return Err(self.emit_sort_mismatch(param.span, expected, &found));
101 }
102 Ok(())
103 })?;
104 self.check_expr(body, fsort.output())?;
105 self.wfckresults
106 .node_sorts_mut()
107 .insert(body.fhir_id, fsort.output().clone());
108 Ok(())
109 }
110
111 fn check_field_exprs(
112 &mut self,
113 expr_span: Span,
114 sort_def: &AdtSortDef,
115 sort_args: &[rty::Sort],
116 field_exprs: &[fhir::FieldExpr<'genv>],
117 spread: Option<&fhir::Spread<'genv>>,
118 expected: &rty::Sort,
119 ) -> Result {
120 let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
121 let mut used_fields = FxHashMap::default();
122 for expr in field_exprs {
123 let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
125 return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
126 };
127 if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
128 return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
129 }
130 self.check_expr(&expr.expr, sort)?;
131 }
132 if let Some(spread) = spread {
133 self.check_expr(&spread.expr, expected)
135 } else if sort_by_field_name.len() != used_fields.len() {
136 let missing_fields = sort_by_field_name
138 .into_keys()
139 .filter(|x| !used_fields.contains_key(x))
140 .collect();
141 Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
142 } else {
143 Ok(())
144 }
145 }
146
147 fn check_constructor(
148 &mut self,
149 expr: &fhir::Expr,
150 field_exprs: &[fhir::FieldExpr<'genv>],
151 spread: Option<&fhir::Spread<'genv>>,
152 expected: &rty::Sort,
153 ) -> Result {
154 let expected = self.resolve_vars_if_possible(expected);
155 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
156 self.wfckresults
157 .record_ctors_mut()
158 .insert(expr.fhir_id, sort_def.did());
159 self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
160 Ok(())
161 } else {
162 Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
163 }
164 }
165
166 fn check_record(
167 &mut self,
168 arg: &fhir::Expr<'genv>,
169 flds: &[fhir::Expr<'genv>],
170 expected: &rty::Sort,
171 ) -> Result {
172 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
173 let sorts = sort_def.struct_variant().field_sorts(sort_args);
174 if flds.len() != sorts.len() {
175 return Err(self.emit_err(errors::ArgCountMismatch::new(
176 Some(arg.span),
177 String::from("type"),
178 sorts.len(),
179 flds.len(),
180 )));
181 }
182 self.wfckresults
183 .record_ctors_mut()
184 .insert(arg.fhir_id, sort_def.did());
185
186 izip!(flds, &sorts)
187 .map(|(arg, expected)| self.check_expr(arg, expected))
188 .try_collect_exhaust()
189 } else {
190 Err(self.emit_err(errors::ArgCountMismatch::new(
191 Some(arg.span),
192 String::from("type"),
193 1,
194 flds.len(),
195 )))
196 }
197 }
198
/// Checks that `expr` has (or can be coerced to) the sort `expected`.
///
/// Some expression kinds are checked structurally against `expected`:
/// - lambdas, positional records and unqualified constructors;
/// - `if-then-else`;
/// - tuples, when `expected` is a tuple sort of the same arity.
///
/// Every other expression has its sort synthesized with `synth_expr` and is accepted if that
/// sort is coercible to `expected`. Otherwise a sort-mismatch error is emitted.
pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
    match expr.kind {
        fhir::ExprKind::IfThenElse(p, e1, e2) => {
            // The guard must be boolean; both branches are checked against `expected`.
            self.check_expr(p, &rty::Sort::Bool)?;
            self.check_expr(e1, expected)?;
            self.check_expr(e2, expected)?;
            return Ok(());
        }
        fhir::ExprKind::Abs(params, body) => {
            self.check_abs(expr, params, body, expected)?;
            return Ok(());
        }
        fhir::ExprKind::Record(fields) => {
            self.check_record(expr, fields, expected)?;
            return Ok(());
        }
        fhir::ExprKind::Constructor(None, exprs, spread) => {
            // Unqualified constructor: the struct is determined by `expected`.
            self.check_constructor(expr, exprs, spread, expected)?;
            return Ok(());
        }
        fhir::ExprKind::Tuple(exprs) => {
            // Check element-wise only when `expected` is syntactically a tuple of the same
            // arity. Otherwise fall through to synthesis + coercion below.
            if let rty::Sort::Tuple(expected_sorts) = expected
                && exprs.len() == expected_sorts.len()
            {
                for (expr, expected_sort) in iter::zip(exprs, expected_sorts) {
                    self.check_expr(expr, expected_sort)?;
                }
                return Ok(());
            }
        }
        fhir::ExprKind::Err(_) => {
            // `Err` nodes are accepted without checking. Presumably their error was already
            // reported when the node was created; confirm.
            return Ok(());
        }
        // All remaining kinds are handled by synthesis + coercion below.
        fhir::ExprKind::UnaryOp(..)
        | fhir::ExprKind::SetLiteral(..)
        | fhir::ExprKind::BinaryOp(..)
        | fhir::ExprKind::Dot(..)
        | fhir::ExprKind::App(..)
        | fhir::ExprKind::Alias(..)
        | fhir::ExprKind::Var(..)
        | fhir::ExprKind::Literal(..)
        | fhir::ExprKind::BoundedQuant(..)
        | fhir::ExprKind::Block(..)
        | fhir::ExprKind::Constructor(..)
        | fhir::ExprKind::PrimApp(..) => {}
    }
    let found = self.synth_expr(expr)?;
    // Resolve both sides so the coercion check sees through already-solved variables.
    let found = self.resolve_vars_if_possible(&found);
    let expected = self.resolve_vars_if_possible(expected);
    if !self.is_coercible(&found, &expected, expr.fhir_id) {
        return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
    }
    Ok(())
}
256
257 pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
258 let found = self.synth_path(loc);
259 if found == rty::Sort::Loc {
260 Ok(())
261 } else {
262 Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
263 }
264 }
265
266 fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
267 match lit {
268 fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
269 fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
270 fhir::Lit::Int(_, None) => {
271 let sort = self.next_sort_var_with_cstr(rty::SortCstr::NUMERIC);
272 self.sort_of_literal.insert(*expr, sort.clone());
273 sort
274 }
275 fhir::Lit::Bool(_) => rty::Sort::Bool,
276 fhir::Lit::Str(_) => rty::Sort::Str,
277 fhir::Lit::Char(_) => rty::Sort::Char,
278 }
279 }
280
281 fn synth_prim_app(
282 &mut self,
283 op: &fhir::BinOp,
284 e1: &fhir::Expr<'genv>,
285 e2: &fhir::Expr<'genv>,
286 span: Span,
287 ) -> Result<rty::Sort> {
288 let Some((inputs, output)) = prim_op_sort(op) else {
289 return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
290 };
291 let [sort1, sort2] = &inputs[..] else {
292 return Err(self.emit_err(errors::ArgCountMismatch::new(
293 Some(span),
294 String::from("primop app"),
295 inputs.len(),
296 2,
297 )));
298 };
299 self.check_expr(e1, sort1)?;
300 self.check_expr(e2, sort2)?;
301 Ok(output)
302 }
303
/// Synthesizes (infers) the sort of `expr`.
///
/// # Errors
/// Emits a diagnostic and returns `Err` when a sub-expression fails to check. Also errors when
/// `expr` is of a kind handled only by the final catch-all arm, whose sort cannot be
/// synthesized without an expected sort (e.g. lambdas and positional records).
fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
    match expr.kind {
        // Variables and type-relative paths use the sort previously recorded for the node.
        fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
        fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
            Ok(self.synth_type_relative_path(expr))
        }
        fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
        fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
        fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
        fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
        fhir::ExprKind::App(callee, args) => {
            // The callee's sort must be fully known here. It may be a function sort or a
            // single-field struct wrapping one (recorded as a `Project` coercion).
            let sort = self.ensure_resolved_path(&callee)?;
            let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
                return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
            };
            // Sort parameters are instantiated with fresh inference variables.
            let fsort = self.instantiate_func_sort(expr, poly_fsort);
            self.synth_app(fsort, args, expr.span)
        }
        fhir::ExprKind::BoundedQuant(.., body) => {
            self.check_expr(body, &rty::Sort::Bool)?;
            Ok(rty::Sort::Bool)
        }
        fhir::ExprKind::Alias(_alias_reft, args) => {
            // The alias's function sort was recorded earlier via `insert_sort_for_alias_reft`.
            let fsort = self.sort_of_alias_reft(expr.fhir_id);
            self.synth_app(fsort, args, expr.span)
        }
        fhir::ExprKind::IfThenElse(p, e1, e2) => {
            // The then-branch determines the sort; the else-branch is checked against it.
            self.check_expr(p, &rty::Sort::Bool)?;
            let sort = self.synth_expr(e1)?;
            self.check_expr(e2, &sort)?;
            Ok(sort)
        }
        fhir::ExprKind::Dot(base, fld) => {
            // The base sort must be fully resolved to know which fields it has.
            let sort = self.synth_expr(base)?;
            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
            match &sort {
                rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
                    let (proj, sort) = sort_def
                        .struct_variant()
                        .field_by_name(sort_def.did(), sort_args, fld.name)
                        .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
                    self.wfckresults
                        .field_projs_mut()
                        .insert(expr.fhir_id, proj);
                    Ok(sort)
                }
                rty::Sort::Tuple(sorts) => {
                    // Tuple fields are accessed by an in-bounds numeric index, e.g. `t.0`.
                    if let Ok(field_idx) = fld.name.as_str().parse::<usize>()
                        && field_idx < sorts.len()
                    {
                        let proj = rty::FieldProj::Tuple {
                            arity: sorts.len(),
                            field: field_idx as u32,
                        };
                        self.wfckresults
                            .field_projs_mut()
                            .insert(expr.fhir_id, proj);
                        Ok(sorts[field_idx].clone())
                    } else {
                        Err(self.emit_field_not_found(&sort, fld))
                    }
                }
                // Dedicated diagnostic for field access on primitive sorts.
                rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
                    Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
                }
                _ => Err(self.emit_field_not_found(&sort, fld)),
            }
        }
        fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
            // A qualified constructor `Path { .. }`: the path must resolve to a struct or enum.
            let path_def_id = match path.res {
                fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
                _ => span_bug!(expr.span, "unexpected path in constructor"),
            };
            let sort_def = self
                .genv
                .adt_sort_def_of(path_def_id)
                .map_err(|e| self.emit_err(e))?;
            // The ADT's sort parameters are instantiated with fresh inference variables.
            let fresh_args: rty::List<_> = (0..sort_def.param_count())
                .map(|_| self.next_sort_var())
                .collect();
            let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
            self.check_field_exprs(
                expr.span,
                &sort_def,
                &fresh_args,
                field_exprs,
                spread,
                &sort,
            )?;
            Ok(sort)
        }
        fhir::ExprKind::Block(decls, body) => {
            // Each `let`-style declaration's initializer must match its parameter's sort.
            for decl in decls {
                self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
            }
            self.synth_expr(body)
        }
        fhir::ExprKind::SetLiteral(elems) => {
            // All elements share one fresh element sort.
            let elem_sort = self.next_sort_var();
            for elem in elems {
                self.check_expr(elem, &elem_sort)?;
            }
            Ok(rty::Sort::App(rty::SortCtor::Set, List::singleton(elem_sort)))
        }
        fhir::ExprKind::Tuple(exprs) => {
            // Stops at the first element that fails to synthesize.
            let sorts = exprs
                .iter()
                .map(|expr| self.synth_expr(expr))
                .try_collect()?;
            Ok(rty::Sort::Tuple(sorts))
        }
        _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
    }
}
428
429 fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
430 self.node_sort
431 .get(&path.fhir_id)
432 .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
433 .clone()
434 }
435
436 fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
437 self.node_sort
438 .get(&expr.fhir_id)
439 .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
440 .clone()
441 }
442
443 fn synth_binary_op(
444 &mut self,
445 expr: &fhir::Expr<'genv>,
446 op: fhir::BinOp,
447 e1: &fhir::Expr<'genv>,
448 e2: &fhir::Expr<'genv>,
449 ) -> Result<rty::Sort> {
450 match op {
451 fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
452 self.check_expr(e1, &rty::Sort::Bool)?;
453 self.check_expr(e2, &rty::Sort::Bool)?;
454 Ok(rty::Sort::Bool)
455 }
456 fhir::BinOp::Eq | fhir::BinOp::Ne => {
457 let sort = self.next_sort_var();
458 self.check_expr(e1, &sort)?;
459 self.check_expr(e2, &sort)?;
460 Ok(rty::Sort::Bool)
461 }
462 fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
463 let sort = self.next_sort_var();
464 self.check_expr(e1, &sort)?;
465 self.check_expr(e2, &sort)?;
466 self.sort_of_bin_op.insert(*expr, sort.clone());
467 Ok(rty::Sort::Bool)
468 }
469 fhir::BinOp::Add
470 | fhir::BinOp::Sub
471 | fhir::BinOp::Mul
472 | fhir::BinOp::Div
473 | fhir::BinOp::Mod
474 | fhir::BinOp::BitAnd
475 | fhir::BinOp::BitOr
476 | fhir::BinOp::BitXor
477 | fhir::BinOp::BitShl
478 | fhir::BinOp::BitShr => {
479 let sort = self.next_sort_var_with_cstr(rty::SortCstr::from_bin_op(op));
480 self.check_expr(e1, &sort)?;
481 self.check_expr(e2, &sort)?;
482 self.sort_of_bin_op.insert(*expr, sort.clone());
483 Ok(sort)
484 }
485 }
486 }
487
488 fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
489 match op {
490 fhir::UnOp::Not => {
491 self.check_expr(e, &rty::Sort::Bool)?;
492 Ok(rty::Sort::Bool)
493 }
494 fhir::UnOp::Neg => {
495 self.check_expr(e, &rty::Sort::Int)?;
496 Ok(rty::Sort::Int)
497 }
498 }
499 }
500
501 fn synth_app(
502 &mut self,
503 fsort: FuncSort,
504 args: &[fhir::Expr<'genv>],
505 span: Span,
506 ) -> Result<rty::Sort> {
507 if args.len() != fsort.inputs().len() {
508 return Err(self.emit_err(errors::ArgCountMismatch::new(
509 Some(span),
510 String::from("function"),
511 fsort.inputs().len(),
512 args.len(),
513 )));
514 }
515
516 iter::zip(args, fsort.inputs())
517 .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
518
519 Ok(fsort.output().clone())
520 }
521
522 fn instantiate_func_sort(
523 &mut self,
524 app_expr: &fhir::Expr<'genv>,
525 fsort: rty::PolyFuncSort,
526 ) -> rty::FuncSort {
527 let args = fsort
528 .params()
529 .map(|kind| {
530 match kind {
531 rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
532 rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
533 }
534 })
535 .collect_vec();
536 self.sort_args_of_app
537 .insert(*app_expr, List::from_slice(&args));
538 fsort.instantiate(&args)
539 }
540
541 pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
542 self.node_sort.insert(fhir_id, sort);
543 }
544
545 pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
546 self.node_sort
547 .get(&bty.fhir_id)
548 .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
549 .clone()
550 }
551
552 pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
553 self.path_args.insert(fhir_id, args);
554 }
555
556 pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
557 self.path_args
558 .get(&fhir_id)
559 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
560 .clone()
561 }
562
563 pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
564 self.sort_of_alias_reft.insert(fhir_id, fsort);
565 }
566
567 fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
568 self.sort_of_alias_reft
569 .get(&fhir_id)
570 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
571 .clone()
572 }
573
574 fn deeply_normalize_sorts<T: TypeFoldable + Clone>(
575 genv: GlobalEnv,
576 owner: FluxOwnerId,
577 t: &T,
578 ) -> QueryResult<T> {
579 let infcx = genv
580 .tcx()
581 .infer_ctxt()
582 .with_next_trait_solver(true)
583 .build(TypingMode::non_body_analysis());
584 if let Some(def_id) = owner.resolved_id() {
585 t.deeply_normalize_sorts(def_id, genv, &infcx)
586 } else {
587 Ok(t.clone())
588 }
589 }
590
591 pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
596 for sort in self.node_sort.values_mut() {
597 *sort = Self::deeply_normalize_sorts(self.genv, self.owner, sort)?;
598 }
599 for fsort in self.sort_of_alias_reft.values_mut() {
600 *fsort = Self::deeply_normalize_sorts(self.genv, self.owner, fsort)?;
601 }
602 Ok(())
603 }
604}
605
606impl<'genv> InferCtxt<'genv, '_> {
607 pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
608 self.params.insert(param.id, (param, sort));
609 }
610
/// Returns whether a value of sort `sort1` can be used where `sort2` is expected. Inference
/// variables are unified as a side effect.
///
/// The sorts are first equated directly. If that fails, a single-field struct on either side
/// is unwrapped to its field sort, and the unwrapped sorts are equated instead. The unwrapping
/// is recorded at `fhir_id` as a `Project` coercion for `sort1` and/or an `Inject` coercion
/// for `sort2`.
///
/// NOTE(review): the coercion list is inserted before the second equation is attempted. It is
/// therefore recorded, possibly empty, even when that equation fails. Also, variables unified
/// by a failed first attempt (e.g. earlier arguments of an `App`) are not rolled back.
fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
    if self.try_equate(sort1, sort2).is_some() {
        return true;
    }

    let mut sort1 = sort1.clone();
    let mut sort2 = sort2.clone();

    let mut coercions = vec![];
    // Unwrap a single-field struct on the "found" side (project out its field).
    if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
        coercions.push(rty::Coercion::Project(def_id));
        sort1 = sort.clone();
    }
    // Unwrap a single-field struct on the "expected" side (inject into it).
    if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
        coercions.push(rty::Coercion::Inject(def_id));
        sort2 = sort.clone();
    }
    self.wfckresults.coercions_mut().insert(fhir_id, coercions);
    self.try_equate(&sort1, &sort2).is_some()
}
635
636 fn is_coercible_from_func(
637 &mut self,
638 sort: &rty::Sort,
639 fhir_id: FhirId,
640 ) -> Option<rty::PolyFuncSort> {
641 if let rty::Sort::Func(fsort) = sort {
642 Some(fsort.clone())
643 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
644 self.wfckresults
645 .coercions_mut()
646 .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
647 Some(fsort.clone())
648 } else {
649 None
650 }
651 }
652
653 fn is_coercible_to_func(
654 &mut self,
655 sort: &rty::Sort,
656 fhir_id: FhirId,
657 ) -> Option<rty::PolyFuncSort> {
658 if let rty::Sort::Func(fsort) = sort {
659 Some(fsort.clone())
660 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
661 self.wfckresults
662 .coercions_mut()
663 .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
664 Some(fsort.clone())
665 } else {
666 None
667 }
668 }
669
670 fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
671 let sort1 = self.resolve_vars_if_possible(sort1);
672 let sort2 = self.resolve_vars_if_possible(sort2);
673 self.try_equate_inner(&sort1, &sort2)
674 }
675
/// Structurally unifies two sorts, returning `sort1` on success and `None` on failure.
///
/// `try_equate` resolves both sorts before calling this.
/// - Two inference variables are unified with each other.
/// - A variable paired with any other sort is solved to that sort. The table may reject this,
///   e.g. because of a constraint on the variable.
/// - Sort applications and tuples are unified component-wise.
/// - Bit-vectors unify their sizes.
/// - Everything else must be syntactically equal.
///
/// NOTE(review): there is no occurs check, and unifications performed on earlier components
/// are not undone when a later component fails.
fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
    match (sort1, sort2) {
        (rty::Sort::Infer(vid1), rty::Sort::Infer(vid2)) => {
            self.sort_unification_table
                .unify_var_var(*vid1, *vid2)
                .ok()?;
        }
        (rty::Sort::Infer(vid), sort) | (sort, rty::Sort::Infer(vid)) => {
            self.sort_unification_table
                .unify_var_value(*vid, rty::SortVarVal::Solved(sort.clone()))
                .ok()?;
        }
        (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
            if ctor1 != ctor2 || args1.len() != args2.len() {
                return None;
            }
            for (s1, s2) in iter::zip(args1, args2) {
                self.try_equate_inner(s1, s2)?;
            }
        }
        (rty::Sort::Tuple(sorts1), rty::Sort::Tuple(sorts2)) => {
            if sorts1.len() != sorts2.len() {
                return None;
            }
            for (s1, s2) in iter::zip(sorts1, sorts2) {
                self.try_equate_inner(s1, s2)?;
            }
        }
        (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
            self.try_equate_bv_sizes(*size1, *size2)?;
        }
        _ if sort1 == sort2 => {}
        _ => return None,
    }
    Some(sort1.clone())
}
712
713 fn try_equate_bv_sizes(
714 &mut self,
715 size1: rty::BvSize,
716 size2: rty::BvSize,
717 ) -> Option<rty::BvSize> {
718 match (size1, size2) {
719 (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
720 self.bv_size_unification_table
721 .unify_var_var(vid1, vid2)
722 .ok()?;
723 }
724 (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
725 self.bv_size_unification_table
726 .unify_var_value(vid, Some(size))
727 .ok()?;
728 }
729 _ if size1 == size2 => {}
730 _ => return None,
731 }
732 Some(size1)
733 }
734
735 fn next_sort_var(&mut self) -> rty::Sort {
736 rty::Sort::Infer(self.next_sort_vid(Default::default()))
737 }
738
739 fn next_sort_var_with_cstr(&mut self, cstr: rty::SortCstr) -> rty::Sort {
740 rty::Sort::Infer(self.next_sort_vid(rty::SortVarVal::Unsolved(cstr)))
741 }
742
743 pub(crate) fn next_sort_vid(&mut self, val: rty::SortVarVal) -> rty::SortVid {
744 self.sort_unification_table.new_key(val)
745 }
746
747 fn next_bv_size_var(&mut self) -> rty::BvSize {
748 rty::BvSize::Infer(self.next_bv_size_vid())
749 }
750
751 fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
752 self.bv_size_unification_table.new_key(None)
753 }
754
755 fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
756 let sort = self.synth_path(path);
757 self.fully_resolve(&sort)
758 .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
759 }
760
761 fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
762 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
763 && let Some(variant) = sort_def.opt_struct_variant()
764 && let [sort] = &variant.field_sorts(sort_args)[..]
765 {
766 Some((sort_def.did(), sort.clone()))
767 } else {
768 None
769 }
770 }
771
/// Finalizes inference and produces the [`WfckResults`] for this owner.
///
/// In order:
/// 1. Integer literals without an explicit kind are defaulted to `int`, and their fully
///    resolved sorts are recorded as node sorts.
/// 2. Binary-operator sorts are fully resolved and recorded.
/// 3. Sort arguments of function applications are fully resolved and recorded. Casts whose
///    sorts would be uninterpreted are rejected unless allowed by the owner's options.
/// 4. Refinement-parameter sorts are fully resolved and recorded.
///
/// # Errors
/// Returns `Err` at the first sort that cannot be fully resolved, or at the first disallowed
/// uninterpreted cast, after emitting a diagnostic.
pub(crate) fn into_results(mut self) -> Result<WfckResults> {
    for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
        // Default the literal's variable to `int`. The result is deliberately ignored: the
        // unification fails (and nothing changes) if the variable can no longer be `int`,
        // e.g. because it was already solved to another sort.
        if let rty::Sort::Infer(vid) = &sort {
            let _ = self
                .sort_unification_table
                .unify_var_value(*vid, rty::SortVarVal::Solved(rty::Sort::Int));
        }

        let sort = self
            .fully_resolve(&sort)
            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
        self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
    }

    for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
        let sort = self
            .fully_resolve(&sort)
            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
        self.wfckresults
            .bin_op_sorts_mut()
            .insert(node.fhir_id, sort);
    }

    // Per-item inference options for Rust owners; otherwise the global config default.
    let allow_uninterpreted_cast =
        self.owner
            .as_rust()
            .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
                self.genv
                    .infer_opts(def_id.local_id().def_id)
                    .allow_uninterpreted_cast
            });

    for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
        // Resolve every sort argument of this application.
        let mut res = vec![];
        for sort_arg in &sort_args {
            let sort_arg = match sort_arg {
                rty::SortArg::Sort(sort) => {
                    let sort = self
                        .fully_resolve(sort)
                        .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
                    rty::SortArg::Sort(sort)
                }
                rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
                    let size = self
                        .bv_size_unification_table
                        .probe_value(*vid)
                        .ok_or_else(|| {
                            self.emit_err(errors::CannotInferSort::new(node.span))
                        })?;
                    rty::SortArg::BvSize(size)
                }
                _ => sort_arg.clone(),
            };
            res.push(sort_arg);
        }
        // The built-in `cast` takes exactly two sort arguments (`from`, `to`). Reject casts
        // whose kind is uninterpreted unless the owner allows them.
        if let fhir::ExprKind::App(callee, _) = node.kind
            && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
        {
            let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
                span_bug!(node.span, "invalid sort args!")
            };
            if !allow_uninterpreted_cast
                && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
            {
                return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
            }
        }
        self.wfckresults
            .fn_app_sorts_mut()
            .insert(node.fhir_id, res.into());
    }

    // Parameters whose sort is still unknown need an explicit sort annotation.
    for (_, (param, sort)) in std::mem::take(&mut self.params) {
        let sort = self
            .fully_resolve(&sort)
            .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
        self.wfckresults.param_sorts_mut().insert(param.id, sort);
    }

    Ok(self.wfckresults)
}
863
864 pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
865 fhir::InferMode::from_param_kind(self.params[&id].0.kind)
866 }
867
868 #[track_caller]
869 pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
870 self.params
871 .get(&id)
872 .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
873 .1
874 .clone()
875 }
876
877 fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
878 sort.fold_with(&mut ShallowResolver { infcx: self })
879 }
880
881 fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
882 sort.fold_with(&mut OpportunisticResolver { infcx: self })
883 }
884
885 pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
886 sort.try_fold_with(&mut FullResolver { infcx: self })
887 }
888}
889
/// Visitor over an owner node that infers the sorts of parameters bound inside the indices of
/// indexed types.
///
/// For every indexed type `bty[idx]`, the sorts of the matching parameter variables in `idx`
/// are equated with the sort of `bty`, descending into positional records.
///
/// NOTE(review): which variables are treated as implicit parameters is decided by the
/// `Some(_)` in `QPathExpr::Resolved(var, Some(_))`. Presumably that marks an implicit binder;
/// confirm.
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    /// The sort-inference context whose parameter sorts get unified.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    /// Errors emitted during the walk; turned into the overall result by `infer`.
    errors: Errors<'genv>,
}
904
905impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
906 pub(crate) fn infer(
907 infcx: &'a mut InferCtxt<'genv, 'tcx>,
908 node: &fhir::OwnerNode<'genv>,
909 ) -> Result {
910 let errors = Errors::new(infcx.genv.sess());
911 let mut vis = Self { infcx, errors };
912 vis.visit_node(node);
913 vis.errors.to_result()
914 }
915
916 fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
917 match idx.kind {
918 fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
919 let (_, id) = var.res.expect_param();
920 let found = self.infcx.param_sort(id);
921 self.infcx.try_equate(&found, expected).unwrap_or_else(|| {
922 span_bug!(idx.span, "failed to equate sorts: `{found:?}` `{expected:?}`")
923 });
924 }
925 fhir::ExprKind::Record(flds) => {
926 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
927 let sorts = sort_def.struct_variant().field_sorts(sort_args);
928 if flds.len() != sorts.len() {
929 self.errors.emit(errors::ArgCountMismatch::new(
930 Some(idx.span),
931 String::from("type"),
932 sorts.len(),
933 flds.len(),
934 ));
935 return;
936 }
937 for (f, sort) in iter::zip(flds, &sorts) {
938 self.infer_implicit_params(f, sort);
939 }
940 } else {
941 self.errors.emit(errors::ArgCountMismatch::new(
942 Some(idx.span),
943 String::from("type"),
944 1,
945 flds.len(),
946 ));
947 }
948 }
949 _ => {}
950 }
951 }
952}
953
954impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
955 fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
956 if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
957 let expected = self.infcx.sort_of_bty(bty);
958 self.infer_implicit_params(idx, &expected);
959 }
960 fhir::visit::walk_ty(self, ty);
961 }
962}
963
964impl InferCtxt<'_, '_> {
965 #[track_caller]
966 fn emit_sort_mismatch(
967 &mut self,
968 span: Span,
969 expected: &rty::Sort,
970 found: &rty::Sort,
971 ) -> ErrorGuaranteed {
972 let expected = self.resolve_vars_if_possible(expected);
973 let found = self.resolve_vars_if_possible(found);
974 self.emit_err(errors::SortMismatch::new(span, expected, found))
975 }
976
977 fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
978 self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
979 }
980
981 #[track_caller]
982 fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
983 self.genv.sess().emit_err(err)
984 }
985}
986
987struct ShallowResolver<'a, 'genv, 'tcx> {
988 infcx: &'a mut InferCtxt<'genv, 'tcx>,
989}
990
991impl TypeFolder for ShallowResolver<'_, '_, '_> {
992 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
993 match sort {
994 rty::Sort::Infer(vid) => {
995 self.infcx
1001 .sort_unification_table
1002 .probe_value(*vid)
1003 .map_solved(|sort| sort.fold_with(self))
1004 .solved_or(sort)
1005 }
1006 rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1007 self.infcx
1008 .bv_size_unification_table
1009 .probe_value(*vid)
1010 .map(rty::Sort::BitVec)
1011 .unwrap_or(sort.clone())
1012 }
1013 _ => sort.clone(),
1014 }
1015 }
1016}
1017
1018struct OpportunisticResolver<'a, 'genv, 'tcx> {
1019 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1020}
1021
1022impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1023 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1024 let s = self.infcx.shallow_resolve(sort);
1025 s.super_fold_with(self)
1026 }
1027}
1028
1029struct FullResolver<'a, 'genv, 'tcx> {
1030 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1031}
1032
1033impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1034 type Error = ();
1035
1036 fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1037 let s = self.infcx.shallow_resolve(sort);
1038 match s {
1039 rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1040 _ => s.try_super_fold_with(self),
1041 }
1042 }
1043}
1044
1045#[derive_where(Default)]
1049struct NodeMap<'genv, T> {
1050 map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
1051}
1052
1053impl<'genv, T> NodeMap<'genv, T> {
1054 fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1056 assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1057 }
1058}
1059
1060impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1061 type Item = (fhir::Expr<'genv>, T);
1062
1063 type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1064
1065 fn into_iter(self) -> Self::IntoIter {
1066 self.map.into_values()
1067 }
1068}