1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9 fhir::{self, FhirId, FluxOwnerId, visit::Visitor as _},
10 global_env::GlobalEnv,
11 queries::QueryResult,
12 rty::{
13 self, AdtSortDef, FuncSort, List, WfckResults,
14 fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15 },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
28type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
/// Sort-inference state for the refinements of a single Flux owner.
///
/// Holds one unification table per kind of inference variable plus side tables that are
/// fully resolved and copied into [`WfckResults`] by `into_results`.
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Owner whose refinements are being sort-checked.
    pub owner: FluxOwnerId,
    /// Results accumulated while checking; returned by `into_results`.
    pub wfckresults: WfckResults,
    /// Unification table for general sort variables (`Sort::Infer(SortVar(_))`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for numeric variables (`Sort::Infer(NumVar(_))`); in this file they
    /// are only solved to `NumVarValue::{Int, Real, BitVec}`.
    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
    /// Unification table for inferred bit-vector sizes (`BvSize::Infer(_)`).
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Parameters registered with `declare_param`, with their (possibly unresolved) sorts.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sorts registered with `insert_node_sort`; read by `synth_path` and `sort_of_bty`.
    node_sort: FxHashMap<FhirId, rty::Sort>,
    /// Generic arguments registered with `insert_path_args`.
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sorts of alias-reft applications, registered with
    /// `insert_sort_for_alias_reft`.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
    /// Unsuffixed integer literals and the numeric variable created for each; unbound ones
    /// default to `int` in `into_results`.
    sort_of_literal: NodeMap<'genv, rty::Sort>,
    /// Operand sorts of comparison and arithmetic operators, resolved in `into_results`.
    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
    /// Sort arguments used to instantiate polymorphic function sorts at applications.
    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
}
45
46pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
47 match op {
48 fhir::BinOp::BitAnd
49 | fhir::BinOp::BitOr
50 | fhir::BinOp::BitXor
51 | fhir::BinOp::BitShl
52 | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
53 _ => None,
54 }
55}
56
57impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
58 pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
59 let mut sort_unification_table = InPlaceUnificationTable::new();
61 sort_unification_table.new_key(None);
62 Self {
63 genv,
64 owner,
65 wfckresults: WfckResults::new(owner),
66 sort_unification_table,
67 num_unification_table: InPlaceUnificationTable::new(),
68 bv_size_unification_table: InPlaceUnificationTable::new(),
69 params: Default::default(),
70 node_sort: Default::default(),
71 path_args: Default::default(),
72 sort_of_alias_reft: Default::default(),
73 sort_of_literal: Default::default(),
74 sort_of_bin_op: Default::default(),
75 sort_args_of_app: Default::default(),
76 }
77 }
78
79 fn check_abs(
80 &mut self,
81 expr: &fhir::Expr<'genv>,
82 params: &[fhir::RefineParam],
83 body: &fhir::Expr<'genv>,
84 expected: &rty::Sort,
85 ) -> Result {
86 let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
87 return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
88 };
89
90 let fsort = fsort.expect_mono();
91
92 if params.len() != fsort.inputs().len() {
93 return Err(self.emit_err(errors::ParamCountMismatch::new(
94 expr.span,
95 fsort.inputs().len(),
96 params.len(),
97 )));
98 }
99 iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
100 let found = self.param_sort(param.id);
101 if self.try_equate(&found, expected).is_none() {
102 return Err(self.emit_sort_mismatch(param.span, expected, &found));
103 }
104 Ok(())
105 })?;
106 self.check_expr(body, fsort.output())?;
107 self.wfckresults
108 .node_sorts_mut()
109 .insert(body.fhir_id, fsort.output().clone());
110 Ok(())
111 }
112
113 fn check_field_exprs(
114 &mut self,
115 expr_span: Span,
116 sort_def: &AdtSortDef,
117 sort_args: &[rty::Sort],
118 field_exprs: &[fhir::FieldExpr<'genv>],
119 spread: &Option<&fhir::Spread<'genv>>,
120 expected: &rty::Sort,
121 ) -> Result {
122 let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
123 let mut used_fields = FxHashMap::default();
124 for expr in field_exprs {
125 let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
127 return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
128 };
129 if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
130 return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
131 }
132 self.check_expr(&expr.expr, sort)?;
133 }
134 if let Some(spread) = spread {
135 self.check_expr(&spread.expr, expected)
137 } else if sort_by_field_name.len() != used_fields.len() {
138 let missing_fields = sort_by_field_name
140 .into_keys()
141 .filter(|x| !used_fields.contains_key(x))
142 .collect();
143 Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
144 } else {
145 Ok(())
146 }
147 }
148
149 fn check_constructor(
150 &mut self,
151 expr: &fhir::Expr,
152 field_exprs: &[fhir::FieldExpr<'genv>],
153 spread: &Option<&fhir::Spread<'genv>>,
154 expected: &rty::Sort,
155 ) -> Result {
156 let expected = self.resolve_vars_if_possible(expected);
157 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
158 self.wfckresults
159 .record_ctors_mut()
160 .insert(expr.fhir_id, sort_def.did());
161 self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
162 Ok(())
163 } else {
164 Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
165 }
166 }
167
168 fn check_record(
169 &mut self,
170 arg: &fhir::Expr<'genv>,
171 flds: &[fhir::Expr<'genv>],
172 expected: &rty::Sort,
173 ) -> Result {
174 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
175 let sorts = sort_def.struct_variant().field_sorts(sort_args);
176 if flds.len() != sorts.len() {
177 return Err(self.emit_err(errors::ArgCountMismatch::new(
178 Some(arg.span),
179 String::from("type"),
180 sorts.len(),
181 flds.len(),
182 )));
183 }
184 self.wfckresults
185 .record_ctors_mut()
186 .insert(arg.fhir_id, sort_def.did());
187
188 izip!(flds, &sorts)
189 .map(|(arg, expected)| self.check_expr(arg, expected))
190 .try_collect_exhaust()
191 } else {
192 Err(self.emit_err(errors::ArgCountMismatch::new(
193 Some(arg.span),
194 String::from("type"),
195 1,
196 flds.len(),
197 )))
198 }
199 }
200
201 pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
202 match &expr.kind {
203 fhir::ExprKind::IfThenElse(p, e1, e2) => {
204 self.check_expr(p, &rty::Sort::Bool)?;
205 self.check_expr(e1, expected)?;
206 self.check_expr(e2, expected)?;
207 }
208 fhir::ExprKind::Abs(params, body) => {
209 self.check_abs(expr, params, body, expected)?;
210 }
211 fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
212 fhir::ExprKind::Constructor(None, exprs, spread) => {
213 self.check_constructor(expr, exprs, spread, expected)?;
214 }
215 fhir::ExprKind::UnaryOp(..)
216 | fhir::ExprKind::BinaryOp(..)
217 | fhir::ExprKind::Dot(..)
218 | fhir::ExprKind::App(..)
219 | fhir::ExprKind::Alias(..)
220 | fhir::ExprKind::Var(..)
221 | fhir::ExprKind::Literal(..)
222 | fhir::ExprKind::BoundedQuant(..)
223 | fhir::ExprKind::Block(..)
224 | fhir::ExprKind::Constructor(..)
225 | fhir::ExprKind::PrimApp(..) => {
226 let found = self.synth_expr(expr)?;
227 let found = self.resolve_vars_if_possible(&found);
228 let expected = self.resolve_vars_if_possible(expected);
229 if !self.is_coercible(&found, &expected, expr.fhir_id) {
230 return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
231 }
232 }
233 fhir::ExprKind::Err(_) => {
234 }
236 }
237 Ok(())
238 }
239
240 pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
241 let found = self.synth_path(loc);
242 if found == rty::Sort::Loc {
243 Ok(())
244 } else {
245 Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
246 }
247 }
248
249 fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
250 match lit {
251 fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
252 fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
253 fhir::Lit::Int(_, None) => {
254 let sort = self.next_num_var();
255 self.sort_of_literal.insert(*expr, sort.clone());
256 sort
257 }
258 fhir::Lit::Bool(_) => rty::Sort::Bool,
259 fhir::Lit::Str(_) => rty::Sort::Str,
260 fhir::Lit::Char(_) => rty::Sort::Char,
261 }
262 }
263
264 fn synth_prim_app(
265 &mut self,
266 op: &fhir::BinOp,
267 e1: &fhir::Expr<'genv>,
268 e2: &fhir::Expr<'genv>,
269 span: Span,
270 ) -> Result<rty::Sort> {
271 let Some((inputs, output)) = prim_op_sort(op) else {
272 return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
273 };
274 let [sort1, sort2] = &inputs[..] else {
275 return Err(self.emit_err(errors::ArgCountMismatch::new(
276 Some(span),
277 String::from("primop app"),
278 inputs.len(),
279 2,
280 )));
281 };
282 self.check_expr(e1, sort1)?;
283 self.check_expr(e2, sort2)?;
284 Ok(output)
285 }
286
    /// Synthesizes (infers) the sort of `expr`, recording side information in
    /// `wfckresults` and the pending tables as it goes.
    ///
    /// # Errors
    /// Emits and returns an error when the sort cannot be synthesized or a nested check
    /// fails. Expression kinds without a synthesis rule report `CannotInferSort`.
    fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
        match expr.kind {
            fhir::ExprKind::Var(var, _) => Ok(self.synth_path(&var)),
            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
            fhir::ExprKind::App(callee, args) => {
                // The callee's sort must be fully known to pick a function sort to instantiate.
                let sort = self.ensure_resolved_path(&callee)?;
                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
                };
                // Fresh sort arguments are recorded in `sort_args_of_app` for `into_results`.
                let fsort = self.instantiate_func_sort(expr, poly_fsort);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::BoundedQuant(.., body) => {
                // Only the body is checked here; it must be a proposition.
                self.check_expr(body, &rty::Sort::Bool)?;
                Ok(rty::Sort::Bool)
            }
            fhir::ExprKind::Alias(_alias_reft, args) => {
                // The function sort was registered earlier via `insert_sort_for_alias_reft`.
                let fsort = self.sort_of_alias_reft(expr.fhir_id);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::IfThenElse(p, e1, e2) => {
                // The `then` branch determines the sort; the `else` branch is checked against it.
                self.check_expr(p, &rty::Sort::Bool)?;
                let sort = self.synth_expr(e1)?;
                self.check_expr(e2, &sort)?;
                Ok(sort)
            }
            fhir::ExprKind::Dot(base, fld) => {
                // Field access needs the base sort fully resolved to look up the field.
                let sort = self.synth_expr(base)?;
                let sort = self
                    .fully_resolve(&sort)
                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
                match &sort {
                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
                        let (proj, sort) = sort_def
                            .struct_variant()
                            .field_by_name(sort_def.did(), sort_args, fld.name)
                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
                        self.wfckresults
                            .field_projs_mut()
                            .insert(expr.fhir_id, proj);
                        Ok(sort)
                    }
                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
                    }
                    _ => Err(self.emit_field_not_found(&sort, fld)),
                }
            }
            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
                // Only enum/struct paths are expected here; anything else is a compiler bug.
                let path_def_id = match path.res {
                    fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
                    _ => span_bug!(expr.span, "unexpected path in constructor"),
                };
                let sort_def = self
                    .genv
                    .adt_sort_def_of(path_def_id)
                    .map_err(|e| self.emit_err(e))?;
                // Each sort parameter of the ADT is instantiated with a fresh sort variable.
                let fresh_args: rty::List<_> = (0..sort_def.param_count())
                    .map(|_| self.next_sort_var())
                    .collect();
                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
                self.check_field_exprs(
                    expr.span,
                    &sort_def,
                    &fresh_args,
                    field_exprs,
                    &spread,
                    &sort,
                )?;
                Ok(sort)
            }
            fhir::ExprKind::Block(decls, body) => {
                // Each `let` initializer is checked against its declared parameter's sort.
                for decl in decls {
                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
                }
                self.synth_expr(body)
            }
            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
        }
    }
377
378 fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
379 self.node_sort
380 .get(&path.fhir_id)
381 .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
382 .clone()
383 }
384
385 fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
386 if matches!(op, fhir::BinOp::Mod) {
387 let sort = self
388 .fully_resolve(sort)
389 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
390 if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
391 span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
392 }
393 }
394 Ok(())
395 }
396
397 fn synth_binary_op(
398 &mut self,
399 expr: &fhir::Expr<'genv>,
400 op: fhir::BinOp,
401 e1: &fhir::Expr<'genv>,
402 e2: &fhir::Expr<'genv>,
403 ) -> Result<rty::Sort> {
404 match op {
405 fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
406 self.check_expr(e1, &rty::Sort::Bool)?;
407 self.check_expr(e2, &rty::Sort::Bool)?;
408 Ok(rty::Sort::Bool)
409 }
410 fhir::BinOp::Eq | fhir::BinOp::Ne => {
411 let sort = self.next_sort_var();
412 self.check_expr(e1, &sort)?;
413 self.check_expr(e2, &sort)?;
414 Ok(rty::Sort::Bool)
415 }
416 fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
417 let sort = self.next_sort_var();
418 self.check_expr(e1, &sort)?;
419 self.check_expr(e2, &sort)?;
420 self.sort_of_bin_op.insert(*expr, sort.clone());
421 Ok(rty::Sort::Bool)
422 }
423 fhir::BinOp::Add
424 | fhir::BinOp::Sub
425 | fhir::BinOp::Mul
426 | fhir::BinOp::Div
427 | fhir::BinOp::Mod => {
428 let sort = self.next_num_var();
429 self.check_expr(e1, &sort)?;
430 self.check_expr(e2, &sort)?;
431 self.sort_of_bin_op.insert(*expr, sort.clone());
432 self.check_integral(op, &sort, expr.span)?;
434
435 Ok(sort)
436 }
437 fhir::BinOp::BitAnd
438 | fhir::BinOp::BitOr
439 | fhir::BinOp::BitXor
440 | fhir::BinOp::BitShl
441 | fhir::BinOp::BitShr => {
442 let sort = rty::Sort::BitVec(self.next_bv_size_var());
443 self.check_expr(e1, &sort)?;
444 self.check_expr(e2, &sort)?;
445 Ok(sort)
446 }
447 }
448 }
449
450 fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
451 match op {
452 fhir::UnOp::Not => {
453 self.check_expr(e, &rty::Sort::Bool)?;
454 Ok(rty::Sort::Bool)
455 }
456 fhir::UnOp::Neg => {
457 self.check_expr(e, &rty::Sort::Int)?;
458 Ok(rty::Sort::Int)
459 }
460 }
461 }
462
463 fn synth_app(
464 &mut self,
465 fsort: FuncSort,
466 args: &[fhir::Expr<'genv>],
467 span: Span,
468 ) -> Result<rty::Sort> {
469 if args.len() != fsort.inputs().len() {
470 return Err(self.emit_err(errors::ArgCountMismatch::new(
471 Some(span),
472 String::from("function"),
473 fsort.inputs().len(),
474 args.len(),
475 )));
476 }
477
478 iter::zip(args, fsort.inputs())
479 .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
480
481 Ok(fsort.output().clone())
482 }
483
484 fn instantiate_func_sort(
485 &mut self,
486 app_expr: &fhir::Expr<'genv>,
487 fsort: rty::PolyFuncSort,
488 ) -> rty::FuncSort {
489 let args = fsort
490 .params()
491 .map(|kind| {
492 match kind {
493 rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
494 rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
495 }
496 })
497 .collect_vec();
498 self.sort_args_of_app
499 .insert(*app_expr, List::from_slice(&args));
500 fsort.instantiate(&args)
501 }
502
503 pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
504 self.node_sort.insert(fhir_id, sort);
505 }
506
507 pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
508 self.node_sort
509 .get(&bty.fhir_id)
510 .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
511 .clone()
512 }
513
514 pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
515 self.path_args.insert(fhir_id, args);
516 }
517
518 pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
519 self.path_args
520 .get(&fhir_id)
521 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
522 .clone()
523 }
524
525 pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
526 self.sort_of_alias_reft.insert(fhir_id, fsort);
527 }
528
529 fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
530 self.sort_of_alias_reft
531 .get(&fhir_id)
532 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
533 .clone()
534 }
535
536 fn normalize_projection_sort(
537 genv: GlobalEnv,
538 owner: FluxOwnerId,
539 sort: rty::Sort,
540 ) -> rty::Sort {
541 let infcx = genv
542 .tcx()
543 .infer_ctxt()
544 .with_next_trait_solver(true)
545 .build(TypingMode::non_body_analysis());
546 if let Some(def_id) = owner.def_id()
547 && let def_id = genv.maybe_extern_id(def_id).resolved_id()
548 && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
549 {
550 sort
551 } else {
552 sort
553 }
554 }
555
556 pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
561 let genv = self.genv;
562 for sort in self.node_sort.values_mut() {
563 *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
564 }
565 for fsort in self.sort_of_alias_reft.values_mut() {
566 *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
567 }
568 Ok(())
569 }
570}
571
572impl<'genv> InferCtxt<'genv, '_> {
573 pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
574 self.params.insert(param.id, (param, sort));
575 }
576
577 fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
582 if self.try_equate(sort1, sort2).is_some() {
583 return true;
584 }
585
586 let mut sort1 = sort1.clone();
587 let mut sort2 = sort2.clone();
588
589 let mut coercions = vec![];
590 if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
591 coercions.push(rty::Coercion::Project(def_id));
592 sort1 = sort.clone();
593 }
594 if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
595 coercions.push(rty::Coercion::Inject(def_id));
596 sort2 = sort.clone();
597 }
598 self.wfckresults.coercions_mut().insert(fhir_id, coercions);
599 self.try_equate(&sort1, &sort2).is_some()
600 }
601
602 fn is_coercible_from_func(
603 &mut self,
604 sort: &rty::Sort,
605 fhir_id: FhirId,
606 ) -> Option<rty::PolyFuncSort> {
607 if let rty::Sort::Func(fsort) = sort {
608 Some(fsort.clone())
609 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
610 self.wfckresults
611 .coercions_mut()
612 .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
613 Some(fsort.clone())
614 } else {
615 None
616 }
617 }
618
619 fn is_coercible_to_func(
620 &mut self,
621 sort: &rty::Sort,
622 fhir_id: FhirId,
623 ) -> Option<rty::PolyFuncSort> {
624 if let rty::Sort::Func(fsort) = sort {
625 Some(fsort.clone())
626 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
627 self.wfckresults
628 .coercions_mut()
629 .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
630 Some(fsort.clone())
631 } else {
632 None
633 }
634 }
635
636 fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
637 let sort1 = self.resolve_vars_if_possible(sort1);
638 let sort2 = self.resolve_vars_if_possible(sort2);
639 self.try_equate_inner(&sort1, &sort2)
640 }
641
642 fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
643 match (sort1, sort2) {
644 (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
645 self.sort_unification_table
646 .unify_var_var(*vid1, *vid2)
647 .ok()?;
648 }
649 (rty::Sort::Infer(rty::SortVar(vid)), sort)
650 | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
651 self.sort_unification_table
652 .unify_var_value(*vid, Some(sort.clone()))
653 .ok()?;
654 }
655 (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
656 self.num_unification_table
657 .unify_var_var(*vid1, *vid2)
658 .ok()?;
659 }
660 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
661 | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
662 self.num_unification_table
663 .unify_var_value(*vid, Some(rty::NumVarValue::Int))
664 .ok()?;
665 }
666 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
667 | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
668 self.num_unification_table
669 .unify_var_value(*vid, Some(rty::NumVarValue::Real))
670 .ok()?;
671 }
672 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
673 | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
674 self.num_unification_table
675 .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
676 .ok()?;
677 }
678
679 (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
680 if ctor1 != ctor2 || args1.len() != args2.len() {
681 return None;
682 }
683 let mut args = vec![];
684 for (s1, s2) in args1.iter().zip(args2.iter()) {
685 args.push(self.try_equate_inner(s1, s2)?);
686 }
687 }
688 (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
689 self.try_equate_bv_sizes(*size1, *size2)?;
690 }
691 _ if sort1 == sort2 => {}
692 _ => return None,
693 }
694 Some(sort1.clone())
695 }
696
697 fn try_equate_bv_sizes(
698 &mut self,
699 size1: rty::BvSize,
700 size2: rty::BvSize,
701 ) -> Option<rty::BvSize> {
702 match (size1, size2) {
703 (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
704 self.bv_size_unification_table
705 .unify_var_var(vid1, vid2)
706 .ok()?;
707 }
708 (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
709 self.bv_size_unification_table
710 .unify_var_value(vid, Some(size))
711 .ok()?;
712 }
713 _ if size1 == size2 => {}
714 _ => return None,
715 }
716 Some(size1)
717 }
718
719 fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
720 self.try_equate(sort1, sort2)
721 .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
722 }
723
724 pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
725 rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
726 }
727
728 fn next_num_var(&mut self) -> rty::Sort {
729 rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
730 }
731
732 pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
733 self.sort_unification_table.new_key(None)
734 }
735
736 fn next_num_vid(&mut self) -> rty::NumVid {
737 self.num_unification_table.new_key(None)
738 }
739
740 fn next_bv_size_var(&mut self) -> rty::BvSize {
741 rty::BvSize::Infer(self.next_bv_size_vid())
742 }
743
744 fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
745 self.bv_size_unification_table.new_key(None)
746 }
747
748 fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
749 let sort = self.synth_path(path);
750 self.fully_resolve(&sort)
751 .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
752 }
753
754 fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
755 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
756 && let Some(variant) = sort_def.opt_struct_variant()
757 && let [sort] = &variant.field_sorts(sort_args)[..]
758 {
759 Some((sort_def.did(), sort.clone()))
760 } else {
761 None
762 }
763 }
764
    /// Resolves every pending side table and returns the completed [`WfckResults`].
    ///
    /// The steps run in this order:
    /// 1. Unsuffixed literals are defaulted to `int` if still unbound, then their sorts are
    ///    recorded as node sorts.
    /// 2. Operator operand sorts are recorded.
    /// 3. Sort arguments of applications are recorded, validating casts along the way.
    /// 4. Parameter sorts are recorded.
    ///
    /// # Errors
    /// Returns the first error encountered: `CannotInferSort` for anything left unresolved,
    /// `InvalidCast` for a cast that would be uninterpreted while that is disallowed, and
    /// `SortAnnotationNeeded` for a parameter whose sort cannot be resolved.
    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
        for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
            // Default an unbound literal to `int`. The result is deliberately ignored;
            // presumably it only fails when the variable is already bound to another value.
            if let rty::Sort::Infer(rty::SortInfer::NumVar(vid)) = &sort {
                let _ = self
                    .num_unification_table
                    .unify_var_value(*vid, Some(rty::NumVarValue::Int));
            }

            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
            self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
        }

        // Runs after literal defaulting, so operators over literals are resolved too.
        for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
            self.wfckresults
                .bin_op_sorts_mut()
                .insert(node.fhir_id, sort);
        }

        // Per-owner option if the owner has a `DefId`, otherwise the global config value.
        let allow_uninterpreted_cast = self
            .owner
            .def_id()
            .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
                self.genv.infer_opts(def_id).allow_uninterpreted_cast
            });

        for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
            // Resolve every sort argument of this application.
            let mut res = vec![];
            for sort_arg in &sort_args {
                let sort_arg = match sort_arg {
                    rty::SortArg::Sort(sort) => {
                        let sort = self
                            .fully_resolve(sort)
                            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
                        rty::SortArg::Sort(sort)
                    }
                    rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
                        let size = self
                            .bv_size_unification_table
                            .probe_value(*vid)
                            .ok_or_else(|| {
                                self.emit_err(errors::CannotInferSort::new(node.span))
                            })?;
                        rty::SortArg::BvSize(size)
                    }
                    // Already concrete (e.g. a fixed bit-vector size).
                    _ => sort_arg.clone(),
                };
                res.push(sort_arg);
            }
            // A `cast` application must have `[from, to]` sort arguments; reject
            // uninterpreted casts unless they are explicitly allowed.
            if let fhir::ExprKind::App(callee, _) = node.kind
                && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
            {
                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
                    span_bug!(node.span, "invalid sort args!")
                };
                if !allow_uninterpreted_cast
                    && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
                {
                    return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
                }
            }
            self.wfckresults
                .fn_app_sorts_mut()
                .insert(node.fhir_id, res.into());
        }

        // An unresolved parameter sort is reported as a missing annotation on the parameter.
        for (_, (param, sort)) in std::mem::take(&mut self.params) {
            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
            self.wfckresults.param_sorts_mut().insert(param.id, sort);
        }

        Ok(self.wfckresults)
    }
854
855 pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
856 fhir::InferMode::from_param_kind(self.params[&id].0.kind)
857 }
858
859 #[track_caller]
860 pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
861 self.params
862 .get(&id)
863 .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
864 .1
865 .clone()
866 }
867
868 fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
869 sort.fold_with(&mut ShallowResolver { infcx: self })
870 }
871
872 fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
873 sort.fold_with(&mut OpportunisticResolver { infcx: self })
874 }
875
876 pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
877 sort.try_fold_with(&mut FullResolver { infcx: self })
878 }
879}
880
/// Visitor that infers the sorts of implicitly bound refinement parameters by equating them
/// with the sorts of the base types they index (see `infer_implicit_params`).
///
/// NOTE(review): "implicit" here means variables matched by `ExprKind::Var(_, Some(_))`;
/// confirm the meaning of that `Some(_)` against the fhir definition.
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    /// Errors emitted while visiting; converted into the result of `infer`.
    errors: Errors<'genv>,
}
895
896impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
897 pub(crate) fn infer(
898 infcx: &'a mut InferCtxt<'genv, 'tcx>,
899 node: &fhir::OwnerNode<'genv>,
900 ) -> Result {
901 let errors = Errors::new(infcx.genv.sess());
902 let mut vis = Self { infcx, errors };
903 vis.visit_node(node);
904 vis.errors.into_result()
905 }
906
907 fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
908 match idx.kind {
909 fhir::ExprKind::Var(var, Some(_)) => {
910 let (_, id) = var.res.expect_param();
911 let found = self.infcx.param_sort(id);
912 self.infcx.equate(&found, expected);
913 }
914 fhir::ExprKind::Record(flds) => {
915 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
916 let sorts = sort_def.struct_variant().field_sorts(sort_args);
917 if flds.len() != sorts.len() {
918 self.errors.emit(errors::ArgCountMismatch::new(
919 Some(idx.span),
920 String::from("type"),
921 sorts.len(),
922 flds.len(),
923 ));
924 return;
925 }
926 for (f, sort) in iter::zip(flds, &sorts) {
927 self.infer_implicit_params(f, sort);
928 }
929 } else {
930 self.errors.emit(errors::ArgCountMismatch::new(
931 Some(idx.span),
932 String::from("type"),
933 1,
934 flds.len(),
935 ));
936 }
937 }
938 _ => {}
939 }
940 }
941}
942
943impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
944 fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
945 if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
946 let expected = self.infcx.sort_of_bty(bty);
947 self.infer_implicit_params(idx, &expected);
948 }
949 fhir::visit::walk_ty(self, ty);
950 }
951}
952
953impl InferCtxt<'_, '_> {
954 #[track_caller]
955 fn emit_sort_mismatch(
956 &mut self,
957 span: Span,
958 expected: &rty::Sort,
959 found: &rty::Sort,
960 ) -> ErrorGuaranteed {
961 let expected = self.resolve_vars_if_possible(expected);
962 let found = self.resolve_vars_if_possible(found);
963 self.emit_err(errors::SortMismatch::new(span, expected, found))
964 }
965
966 fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
967 self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
968 }
969
970 #[track_caller]
971 fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
972 self.genv.sess().emit_err(err)
973 }
974}
975
976struct ShallowResolver<'a, 'genv, 'tcx> {
977 infcx: &'a mut InferCtxt<'genv, 'tcx>,
978}
979
980impl TypeFolder for ShallowResolver<'_, '_, '_> {
981 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
982 match sort {
983 rty::Sort::Infer(rty::SortVar(vid)) => {
984 self.infcx
990 .sort_unification_table
991 .probe_value(*vid)
992 .map(|sort| sort.fold_with(self))
993 .unwrap_or(sort.clone())
994 }
995 rty::Sort::Infer(rty::NumVar(vid)) => {
996 self.infcx
998 .num_unification_table
999 .probe_value(*vid)
1000 .map(|val| val.to_sort().fold_with(self))
1001 .unwrap_or(sort.clone())
1002 }
1003 rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1004 self.infcx
1005 .bv_size_unification_table
1006 .probe_value(*vid)
1007 .map(rty::Sort::BitVec)
1008 .unwrap_or(sort.clone())
1009 }
1010 _ => sort.clone(),
1011 }
1012 }
1013}
1014
1015struct OpportunisticResolver<'a, 'genv, 'tcx> {
1016 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1017}
1018
1019impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1020 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1021 let s = self.infcx.shallow_resolve(sort);
1022 s.super_fold_with(self)
1023 }
1024}
1025
1026struct FullResolver<'a, 'genv, 'tcx> {
1027 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1028}
1029
1030impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1031 type Error = ();
1032
1033 fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1034 let s = self.infcx.shallow_resolve(sort);
1035 match s {
1036 rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1037 _ => s.try_super_fold_with(self),
1038 }
1039 }
1040}
1041
/// Map from an expression's `FhirId` to a copy of that expression plus associated data.
///
/// Keeping the expression lets consumers (e.g. `into_results`) report errors at its span
/// and inspect its kind after checking is done.
#[derive_where(Default)]
struct NodeMap<'genv, T> {
    map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
}

impl<'genv, T> NodeMap<'genv, T> {
    /// Inserts `data` for `node`, keyed by `node.fhir_id`.
    ///
    /// Panics if an entry for the same `FhirId` already exists.
    fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
        assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
    }
}

/// Consumes the map, yielding `(expr, data)` pairs in the hash map's iteration order.
impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
    type Item = (fhir::Expr<'genv>, T);

    type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.map.into_values()
    }
}