1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9 fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10 global_env::GlobalEnv,
11 queries::QueryResult,
12 rty::{
13 self, AdtSortDef, FuncSort, List, WfckResults,
14 fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15 },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
/// Result type used throughout sort checking.
///
/// Every `Err` in this module is produced by emitting a diagnostic (see `emit_err`), so an
/// `ErrorGuaranteed` means the error has already been reported to the user.
type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
/// Sort-inference context for the refinements of a single owner item.
///
/// Holds the unification tables for the three kinds of inference variables (general sorts,
/// numeric sorts, and bit-vector sizes) plus the side tables that are resolved into
/// [`WfckResults`] by `into_results`.
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The item whose refinements are being checked.
    pub owner: FluxOwnerId,
    /// Results accumulated while checking; finalized and returned by `into_results`.
    pub wfckresults: WfckResults,
    /// Unification table for general sort variables (`rty::SortVar`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for numeric sort variables (`rty::NumVar`), which can only be bound
    /// to `int`, `real`, or a bit-vector sort.
    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
    /// Unification table for bit-vector size variables.
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Declared refinement parameters and their (possibly not yet inferred) sorts.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sorts recorded for individual nodes (paths, base types, type-relative paths).
    node_sort: FxHashMap<FhirId, rty::Sort>,
    /// Generic arguments recorded for paths.
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sort of each alias-reft application.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
    /// Unsuffixed integer literals and the numeric variable assigned to each.
    sort_of_literal: NodeMap<'genv, rty::Sort>,
    /// Operand sort of comparison and arithmetic operators.
    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
    /// Sort arguments used to instantiate each polymorphic function application.
    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
}
45
46pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
47 match op {
48 fhir::BinOp::BitAnd
49 | fhir::BinOp::BitOr
50 | fhir::BinOp::BitXor
51 | fhir::BinOp::BitShl
52 | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
53 _ => None,
54 }
55}
56
57impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
58 pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
59 let mut sort_unification_table = InPlaceUnificationTable::new();
61 sort_unification_table.new_key(None);
62 Self {
63 genv,
64 owner,
65 wfckresults: WfckResults::new(owner),
66 sort_unification_table,
67 num_unification_table: InPlaceUnificationTable::new(),
68 bv_size_unification_table: InPlaceUnificationTable::new(),
69 params: Default::default(),
70 node_sort: Default::default(),
71 path_args: Default::default(),
72 sort_of_alias_reft: Default::default(),
73 sort_of_literal: Default::default(),
74 sort_of_bin_op: Default::default(),
75 sort_args_of_app: Default::default(),
76 }
77 }
78
79 fn check_abs(
80 &mut self,
81 expr: &fhir::Expr<'genv>,
82 params: &[fhir::RefineParam],
83 body: &fhir::Expr<'genv>,
84 expected: &rty::Sort,
85 ) -> Result {
86 let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
87 return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
88 };
89
90 let fsort = fsort.expect_mono();
91
92 if params.len() != fsort.inputs().len() {
93 return Err(self.emit_err(errors::ParamCountMismatch::new(
94 expr.span,
95 fsort.inputs().len(),
96 params.len(),
97 )));
98 }
99 iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
100 let found = self.param_sort(param.id);
101 if self.try_equate(&found, expected).is_none() {
102 return Err(self.emit_sort_mismatch(param.span, expected, &found));
103 }
104 Ok(())
105 })?;
106 self.check_expr(body, fsort.output())?;
107 self.wfckresults
108 .node_sorts_mut()
109 .insert(body.fhir_id, fsort.output().clone());
110 Ok(())
111 }
112
113 fn check_field_exprs(
114 &mut self,
115 expr_span: Span,
116 sort_def: &AdtSortDef,
117 sort_args: &[rty::Sort],
118 field_exprs: &[fhir::FieldExpr<'genv>],
119 spread: &Option<&fhir::Spread<'genv>>,
120 expected: &rty::Sort,
121 ) -> Result {
122 let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
123 let mut used_fields = FxHashMap::default();
124 for expr in field_exprs {
125 let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
127 return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
128 };
129 if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
130 return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
131 }
132 self.check_expr(&expr.expr, sort)?;
133 }
134 if let Some(spread) = spread {
135 self.check_expr(&spread.expr, expected)
137 } else if sort_by_field_name.len() != used_fields.len() {
138 let missing_fields = sort_by_field_name
140 .into_keys()
141 .filter(|x| !used_fields.contains_key(x))
142 .collect();
143 Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
144 } else {
145 Ok(())
146 }
147 }
148
149 fn check_constructor(
150 &mut self,
151 expr: &fhir::Expr,
152 field_exprs: &[fhir::FieldExpr<'genv>],
153 spread: &Option<&fhir::Spread<'genv>>,
154 expected: &rty::Sort,
155 ) -> Result {
156 let expected = self.resolve_vars_if_possible(expected);
157 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
158 self.wfckresults
159 .record_ctors_mut()
160 .insert(expr.fhir_id, sort_def.did());
161 self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
162 Ok(())
163 } else {
164 Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
165 }
166 }
167
168 fn check_record(
169 &mut self,
170 arg: &fhir::Expr<'genv>,
171 flds: &[fhir::Expr<'genv>],
172 expected: &rty::Sort,
173 ) -> Result {
174 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
175 let sorts = sort_def.struct_variant().field_sorts(sort_args);
176 if flds.len() != sorts.len() {
177 return Err(self.emit_err(errors::ArgCountMismatch::new(
178 Some(arg.span),
179 String::from("type"),
180 sorts.len(),
181 flds.len(),
182 )));
183 }
184 self.wfckresults
185 .record_ctors_mut()
186 .insert(arg.fhir_id, sort_def.did());
187
188 izip!(flds, &sorts)
189 .map(|(arg, expected)| self.check_expr(arg, expected))
190 .try_collect_exhaust()
191 } else {
192 Err(self.emit_err(errors::ArgCountMismatch::new(
193 Some(arg.span),
194 String::from("type"),
195 1,
196 flds.len(),
197 )))
198 }
199 }
200
201 pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
202 match &expr.kind {
203 fhir::ExprKind::IfThenElse(p, e1, e2) => {
204 self.check_expr(p, &rty::Sort::Bool)?;
205 self.check_expr(e1, expected)?;
206 self.check_expr(e2, expected)?;
207 }
208 fhir::ExprKind::Abs(params, body) => {
209 self.check_abs(expr, params, body, expected)?;
210 }
211 fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
212 fhir::ExprKind::Constructor(None, exprs, spread) => {
213 self.check_constructor(expr, exprs, spread, expected)?;
214 }
215 fhir::ExprKind::UnaryOp(..)
216 | fhir::ExprKind::BinaryOp(..)
217 | fhir::ExprKind::Dot(..)
218 | fhir::ExprKind::App(..)
219 | fhir::ExprKind::Alias(..)
220 | fhir::ExprKind::Var(..)
221 | fhir::ExprKind::Literal(..)
222 | fhir::ExprKind::BoundedQuant(..)
223 | fhir::ExprKind::Block(..)
224 | fhir::ExprKind::Constructor(..)
225 | fhir::ExprKind::PrimApp(..) => {
226 let found = self.synth_expr(expr)?;
227 let found = self.resolve_vars_if_possible(&found);
228 let expected = self.resolve_vars_if_possible(expected);
229 if !self.is_coercible(&found, &expected, expr.fhir_id) {
230 return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
231 }
232 }
233 fhir::ExprKind::Err(_) => {
234 }
236 }
237 Ok(())
238 }
239
240 pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
241 let found = self.synth_path(loc);
242 if found == rty::Sort::Loc {
243 Ok(())
244 } else {
245 Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
246 }
247 }
248
249 fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
250 match lit {
251 fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
252 fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
253 fhir::Lit::Int(_, None) => {
254 let sort = self.next_num_var();
255 self.sort_of_literal.insert(*expr, sort.clone());
256 sort
257 }
258 fhir::Lit::Bool(_) => rty::Sort::Bool,
259 fhir::Lit::Str(_) => rty::Sort::Str,
260 fhir::Lit::Char(_) => rty::Sort::Char,
261 }
262 }
263
264 fn synth_prim_app(
265 &mut self,
266 op: &fhir::BinOp,
267 e1: &fhir::Expr<'genv>,
268 e2: &fhir::Expr<'genv>,
269 span: Span,
270 ) -> Result<rty::Sort> {
271 let Some((inputs, output)) = prim_op_sort(op) else {
272 return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
273 };
274 let [sort1, sort2] = &inputs[..] else {
275 return Err(self.emit_err(errors::ArgCountMismatch::new(
276 Some(span),
277 String::from("primop app"),
278 inputs.len(),
279 2,
280 )));
281 };
282 self.check_expr(e1, sort1)?;
283 self.check_expr(e2, sort2)?;
284 Ok(output)
285 }
286
287 fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
288 match expr.kind {
289 fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
290 fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
291 Ok(self.synth_type_relative_path(expr))
292 }
293 fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
294 fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
295 fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
296 fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
297 fhir::ExprKind::App(callee, args) => {
298 let sort = self.ensure_resolved_path(&callee)?;
299 let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
300 return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
301 };
302 let fsort = self.instantiate_func_sort(expr, poly_fsort);
303 self.synth_app(fsort, args, expr.span)
304 }
305 fhir::ExprKind::BoundedQuant(.., body) => {
306 self.check_expr(body, &rty::Sort::Bool)?;
307 Ok(rty::Sort::Bool)
308 }
309 fhir::ExprKind::Alias(_alias_reft, args) => {
310 let fsort = self.sort_of_alias_reft(expr.fhir_id);
313 self.synth_app(fsort, args, expr.span)
314 }
315 fhir::ExprKind::IfThenElse(p, e1, e2) => {
316 self.check_expr(p, &rty::Sort::Bool)?;
317 let sort = self.synth_expr(e1)?;
318 self.check_expr(e2, &sort)?;
319 Ok(sort)
320 }
321 fhir::ExprKind::Dot(base, fld) => {
322 let sort = self.synth_expr(base)?;
323 let sort = self
324 .fully_resolve(&sort)
325 .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
326 match &sort {
327 rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
328 let (proj, sort) = sort_def
329 .struct_variant()
330 .field_by_name(sort_def.did(), sort_args, fld.name)
331 .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
332 self.wfckresults
333 .field_projs_mut()
334 .insert(expr.fhir_id, proj);
335 Ok(sort)
336 }
337 rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
338 Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
339 }
340 _ => Err(self.emit_field_not_found(&sort, fld)),
341 }
342 }
343 fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
344 let path_def_id = match path.res {
348 fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
349 _ => span_bug!(expr.span, "unexpected path in constructor"),
350 };
351 let sort_def = self
352 .genv
353 .adt_sort_def_of(path_def_id)
354 .map_err(|e| self.emit_err(e))?;
355 let fresh_args: rty::List<_> = (0..sort_def.param_count())
357 .map(|_| self.next_sort_var())
358 .collect();
359 let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
360 self.check_field_exprs(
362 expr.span,
363 &sort_def,
364 &fresh_args,
365 field_exprs,
366 &spread,
367 &sort,
368 )?;
369 Ok(sort)
370 }
371 fhir::ExprKind::Block(decls, body) => {
372 for decl in decls {
373 self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
374 }
375 self.synth_expr(body)
376 }
377 _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
378 }
379 }
380
381 fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
382 self.node_sort
383 .get(&path.fhir_id)
384 .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
385 .clone()
386 }
387
388 fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
389 self.node_sort
390 .get(&expr.fhir_id)
391 .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
392 .clone()
393 }
394
395 fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
396 if matches!(op, fhir::BinOp::Mod) {
397 let sort = self
398 .fully_resolve(sort)
399 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
400 if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
401 span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
402 }
403 }
404 Ok(())
405 }
406
407 fn synth_binary_op(
408 &mut self,
409 expr: &fhir::Expr<'genv>,
410 op: fhir::BinOp,
411 e1: &fhir::Expr<'genv>,
412 e2: &fhir::Expr<'genv>,
413 ) -> Result<rty::Sort> {
414 match op {
415 fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
416 self.check_expr(e1, &rty::Sort::Bool)?;
417 self.check_expr(e2, &rty::Sort::Bool)?;
418 Ok(rty::Sort::Bool)
419 }
420 fhir::BinOp::Eq | fhir::BinOp::Ne => {
421 let sort = self.next_sort_var();
422 self.check_expr(e1, &sort)?;
423 self.check_expr(e2, &sort)?;
424 Ok(rty::Sort::Bool)
425 }
426 fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
427 let sort = self.next_sort_var();
428 self.check_expr(e1, &sort)?;
429 self.check_expr(e2, &sort)?;
430 self.sort_of_bin_op.insert(*expr, sort.clone());
431 Ok(rty::Sort::Bool)
432 }
433 fhir::BinOp::Add
434 | fhir::BinOp::Sub
435 | fhir::BinOp::Mul
436 | fhir::BinOp::Div
437 | fhir::BinOp::Mod => {
438 let sort = self.next_num_var();
439 self.check_expr(e1, &sort)?;
440 self.check_expr(e2, &sort)?;
441 self.sort_of_bin_op.insert(*expr, sort.clone());
442 self.check_integral(op, &sort, expr.span)?;
444
445 Ok(sort)
446 }
447 fhir::BinOp::BitAnd
448 | fhir::BinOp::BitOr
449 | fhir::BinOp::BitXor
450 | fhir::BinOp::BitShl
451 | fhir::BinOp::BitShr => {
452 let sort = rty::Sort::BitVec(self.next_bv_size_var());
453 self.check_expr(e1, &sort)?;
454 self.check_expr(e2, &sort)?;
455 Ok(sort)
456 }
457 }
458 }
459
460 fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
461 match op {
462 fhir::UnOp::Not => {
463 self.check_expr(e, &rty::Sort::Bool)?;
464 Ok(rty::Sort::Bool)
465 }
466 fhir::UnOp::Neg => {
467 self.check_expr(e, &rty::Sort::Int)?;
468 Ok(rty::Sort::Int)
469 }
470 }
471 }
472
473 fn synth_app(
474 &mut self,
475 fsort: FuncSort,
476 args: &[fhir::Expr<'genv>],
477 span: Span,
478 ) -> Result<rty::Sort> {
479 if args.len() != fsort.inputs().len() {
480 return Err(self.emit_err(errors::ArgCountMismatch::new(
481 Some(span),
482 String::from("function"),
483 fsort.inputs().len(),
484 args.len(),
485 )));
486 }
487
488 iter::zip(args, fsort.inputs())
489 .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
490
491 Ok(fsort.output().clone())
492 }
493
494 fn instantiate_func_sort(
495 &mut self,
496 app_expr: &fhir::Expr<'genv>,
497 fsort: rty::PolyFuncSort,
498 ) -> rty::FuncSort {
499 let args = fsort
500 .params()
501 .map(|kind| {
502 match kind {
503 rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
504 rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
505 }
506 })
507 .collect_vec();
508 self.sort_args_of_app
509 .insert(*app_expr, List::from_slice(&args));
510 fsort.instantiate(&args)
511 }
512
513 pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
514 self.node_sort.insert(fhir_id, sort);
515 }
516
517 pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
518 self.node_sort
519 .get(&bty.fhir_id)
520 .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
521 .clone()
522 }
523
524 pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
525 self.path_args.insert(fhir_id, args);
526 }
527
528 pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
529 self.path_args
530 .get(&fhir_id)
531 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
532 .clone()
533 }
534
535 pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
536 self.sort_of_alias_reft.insert(fhir_id, fsort);
537 }
538
539 fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
540 self.sort_of_alias_reft
541 .get(&fhir_id)
542 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
543 .clone()
544 }
545
546 fn normalize_projection_sort(
547 genv: GlobalEnv,
548 owner: FluxOwnerId,
549 sort: rty::Sort,
550 ) -> rty::Sort {
551 let infcx = genv
552 .tcx()
553 .infer_ctxt()
554 .with_next_trait_solver(true)
555 .build(TypingMode::non_body_analysis());
556 if let Some(def_id) = owner.def_id()
557 && let def_id = genv.maybe_extern_id(def_id).resolved_id()
558 && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
559 {
560 sort
561 } else {
562 sort
563 }
564 }
565
566 pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
571 let genv = self.genv;
572 for sort in self.node_sort.values_mut() {
573 *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
574 }
575 for fsort in self.sort_of_alias_reft.values_mut() {
576 *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
577 }
578 Ok(())
579 }
580}
581
582impl<'genv> InferCtxt<'genv, '_> {
583 pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
584 self.params.insert(param.id, (param, sort));
585 }
586
587 fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
592 if self.try_equate(sort1, sort2).is_some() {
593 return true;
594 }
595
596 let mut sort1 = sort1.clone();
597 let mut sort2 = sort2.clone();
598
599 let mut coercions = vec![];
600 if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
601 coercions.push(rty::Coercion::Project(def_id));
602 sort1 = sort.clone();
603 }
604 if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
605 coercions.push(rty::Coercion::Inject(def_id));
606 sort2 = sort.clone();
607 }
608 self.wfckresults.coercions_mut().insert(fhir_id, coercions);
609 self.try_equate(&sort1, &sort2).is_some()
610 }
611
612 fn is_coercible_from_func(
613 &mut self,
614 sort: &rty::Sort,
615 fhir_id: FhirId,
616 ) -> Option<rty::PolyFuncSort> {
617 if let rty::Sort::Func(fsort) = sort {
618 Some(fsort.clone())
619 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
620 self.wfckresults
621 .coercions_mut()
622 .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
623 Some(fsort.clone())
624 } else {
625 None
626 }
627 }
628
629 fn is_coercible_to_func(
630 &mut self,
631 sort: &rty::Sort,
632 fhir_id: FhirId,
633 ) -> Option<rty::PolyFuncSort> {
634 if let rty::Sort::Func(fsort) = sort {
635 Some(fsort.clone())
636 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
637 self.wfckresults
638 .coercions_mut()
639 .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
640 Some(fsort.clone())
641 } else {
642 None
643 }
644 }
645
646 fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
647 let sort1 = self.resolve_vars_if_possible(sort1);
648 let sort2 = self.resolve_vars_if_possible(sort2);
649 self.try_equate_inner(&sort1, &sort2)
650 }
651
652 fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
653 match (sort1, sort2) {
654 (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
655 self.sort_unification_table
656 .unify_var_var(*vid1, *vid2)
657 .ok()?;
658 }
659 (rty::Sort::Infer(rty::SortVar(vid)), sort)
660 | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
661 self.sort_unification_table
662 .unify_var_value(*vid, Some(sort.clone()))
663 .ok()?;
664 }
665 (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
666 self.num_unification_table
667 .unify_var_var(*vid1, *vid2)
668 .ok()?;
669 }
670 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
671 | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
672 self.num_unification_table
673 .unify_var_value(*vid, Some(rty::NumVarValue::Int))
674 .ok()?;
675 }
676 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
677 | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
678 self.num_unification_table
679 .unify_var_value(*vid, Some(rty::NumVarValue::Real))
680 .ok()?;
681 }
682 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
683 | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
684 self.num_unification_table
685 .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
686 .ok()?;
687 }
688
689 (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
690 if ctor1 != ctor2 || args1.len() != args2.len() {
691 return None;
692 }
693 let mut args = vec![];
694 for (s1, s2) in args1.iter().zip(args2.iter()) {
695 args.push(self.try_equate_inner(s1, s2)?);
696 }
697 }
698 (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
699 self.try_equate_bv_sizes(*size1, *size2)?;
700 }
701 _ if sort1 == sort2 => {}
702 _ => return None,
703 }
704 Some(sort1.clone())
705 }
706
707 fn try_equate_bv_sizes(
708 &mut self,
709 size1: rty::BvSize,
710 size2: rty::BvSize,
711 ) -> Option<rty::BvSize> {
712 match (size1, size2) {
713 (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
714 self.bv_size_unification_table
715 .unify_var_var(vid1, vid2)
716 .ok()?;
717 }
718 (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
719 self.bv_size_unification_table
720 .unify_var_value(vid, Some(size))
721 .ok()?;
722 }
723 _ if size1 == size2 => {}
724 _ => return None,
725 }
726 Some(size1)
727 }
728
729 fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
730 self.try_equate(sort1, sort2)
731 .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
732 }
733
734 pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
735 rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
736 }
737
738 fn next_num_var(&mut self) -> rty::Sort {
739 rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
740 }
741
742 pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
743 self.sort_unification_table.new_key(None)
744 }
745
746 fn next_num_vid(&mut self) -> rty::NumVid {
747 self.num_unification_table.new_key(None)
748 }
749
750 fn next_bv_size_var(&mut self) -> rty::BvSize {
751 rty::BvSize::Infer(self.next_bv_size_vid())
752 }
753
754 fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
755 self.bv_size_unification_table.new_key(None)
756 }
757
758 fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
759 let sort = self.synth_path(path);
760 self.fully_resolve(&sort)
761 .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
762 }
763
764 fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
765 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
766 && let Some(variant) = sort_def.opt_struct_variant()
767 && let [sort] = &variant.field_sorts(sort_args)[..]
768 {
769 Some((sort_def.did(), sort.clone()))
770 } else {
771 None
772 }
773 }
774
775 pub(crate) fn into_results(mut self) -> Result<WfckResults> {
776 for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
779 if let rty::Sort::Infer(rty::SortInfer::NumVar(vid)) = &sort {
785 let _ = self
786 .num_unification_table
787 .unify_var_value(*vid, Some(rty::NumVarValue::Int));
788 }
789
790 let sort = self
791 .fully_resolve(&sort)
792 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
793 self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
794 }
795
796 for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
798 let sort = self
799 .fully_resolve(&sort)
800 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
801 self.wfckresults
802 .bin_op_sorts_mut()
803 .insert(node.fhir_id, sort);
804 }
805
806 let allow_uninterpreted_cast = self
807 .owner
808 .def_id()
809 .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
810 self.genv.infer_opts(def_id).allow_uninterpreted_cast
811 });
812
813 for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
815 let mut res = vec![];
816 for sort_arg in &sort_args {
817 let sort_arg = match sort_arg {
818 rty::SortArg::Sort(sort) => {
819 let sort = self
820 .fully_resolve(sort)
821 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
822 rty::SortArg::Sort(sort)
823 }
824 rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
825 let size = self
826 .bv_size_unification_table
827 .probe_value(*vid)
828 .ok_or_else(|| {
829 self.emit_err(errors::CannotInferSort::new(node.span))
830 })?;
831 rty::SortArg::BvSize(size)
832 }
833 _ => sort_arg.clone(),
834 };
835 res.push(sort_arg);
836 }
837 if let fhir::ExprKind::App(callee, _) = node.kind
838 && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
839 {
840 let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
841 span_bug!(node.span, "invalid sort args!")
842 };
843 if !allow_uninterpreted_cast
844 && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
845 {
846 return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
847 }
848 }
849 self.wfckresults
850 .fn_app_sorts_mut()
851 .insert(node.fhir_id, res.into());
852 }
853
854 for (_, (param, sort)) in std::mem::take(&mut self.params) {
856 let sort = self
857 .fully_resolve(&sort)
858 .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(¶m)))?;
859 self.wfckresults.param_sorts_mut().insert(param.id, sort);
860 }
861
862 Ok(self.wfckresults)
863 }
864
865 pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
866 fhir::InferMode::from_param_kind(self.params[&id].0.kind)
867 }
868
869 #[track_caller]
870 pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
871 self.params
872 .get(&id)
873 .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
874 .1
875 .clone()
876 }
877
878 fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
879 sort.fold_with(&mut ShallowResolver { infcx: self })
880 }
881
882 fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
883 sort.fold_with(&mut OpportunisticResolver { infcx: self })
884 }
885
886 pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
887 sort.try_fold_with(&mut FullResolver { infcx: self })
888 }
889}
890
/// Visitor that walks an owner node and, for every indexed type `B[idx]`, unifies the sorts of
/// the parameters appearing in `idx` with the sort of the base type `B`.
///
/// Only variables whose path carries a `Some(_)` annotation are unified; presumably these are
/// the implicitly-declared parameters — confirm.
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    /// Errors collected while walking; turned into the result of `infer`.
    errors: Errors<'genv>,
}
905
906impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
907 pub(crate) fn infer(
908 infcx: &'a mut InferCtxt<'genv, 'tcx>,
909 node: &fhir::OwnerNode<'genv>,
910 ) -> Result {
911 let errors = Errors::new(infcx.genv.sess());
912 let mut vis = Self { infcx, errors };
913 vis.visit_node(node);
914 vis.errors.into_result()
915 }
916
917 fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
918 match idx.kind {
919 fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
920 let (_, id) = var.res.expect_param();
921 let found = self.infcx.param_sort(id);
922 self.infcx.equate(&found, expected);
923 }
924 fhir::ExprKind::Record(flds) => {
925 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
926 let sorts = sort_def.struct_variant().field_sorts(sort_args);
927 if flds.len() != sorts.len() {
928 self.errors.emit(errors::ArgCountMismatch::new(
929 Some(idx.span),
930 String::from("type"),
931 sorts.len(),
932 flds.len(),
933 ));
934 return;
935 }
936 for (f, sort) in iter::zip(flds, &sorts) {
937 self.infer_implicit_params(f, sort);
938 }
939 } else {
940 self.errors.emit(errors::ArgCountMismatch::new(
941 Some(idx.span),
942 String::from("type"),
943 1,
944 flds.len(),
945 ));
946 }
947 }
948 _ => {}
949 }
950 }
951}
952
953impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
954 fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
955 if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
956 let expected = self.infcx.sort_of_bty(bty);
957 self.infer_implicit_params(idx, &expected);
958 }
959 fhir::visit::walk_ty(self, ty);
960 }
961}
962
963impl InferCtxt<'_, '_> {
964 #[track_caller]
965 fn emit_sort_mismatch(
966 &mut self,
967 span: Span,
968 expected: &rty::Sort,
969 found: &rty::Sort,
970 ) -> ErrorGuaranteed {
971 let expected = self.resolve_vars_if_possible(expected);
972 let found = self.resolve_vars_if_possible(found);
973 self.emit_err(errors::SortMismatch::new(span, expected, found))
974 }
975
976 fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
977 self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
978 }
979
980 #[track_caller]
981 fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
982 self.genv.sess().emit_err(err)
983 }
984}
985
/// Folder that replaces a top-level inference variable with its current value, if it has one.
struct ShallowResolver<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}
989
990impl TypeFolder for ShallowResolver<'_, '_, '_> {
991 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
992 match sort {
993 rty::Sort::Infer(rty::SortVar(vid)) => {
994 self.infcx
1000 .sort_unification_table
1001 .probe_value(*vid)
1002 .map(|sort| sort.fold_with(self))
1003 .unwrap_or(sort.clone())
1004 }
1005 rty::Sort::Infer(rty::NumVar(vid)) => {
1006 self.infcx
1008 .num_unification_table
1009 .probe_value(*vid)
1010 .map(|val| val.to_sort().fold_with(self))
1011 .unwrap_or(sort.clone())
1012 }
1013 rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1014 self.infcx
1015 .bv_size_unification_table
1016 .probe_value(*vid)
1017 .map(rty::Sort::BitVec)
1018 .unwrap_or(sort.clone())
1019 }
1020 _ => sort.clone(),
1021 }
1022 }
1023}
1024
/// Folder that resolves bound inference variables at every depth of a sort.
///
/// Unbound variables are left in place.
struct OpportunisticResolver<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}
1028
1029impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1030 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1031 let s = self.infcx.shallow_resolve(sort);
1032 s.super_fold_with(self)
1033 }
1034}
1035
/// Fallible folder that resolves every inference variable in a sort.
///
/// It fails if any variable, including a bit-vector size variable, is still unbound.
struct FullResolver<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}
1039
1040impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1041 type Error = ();
1042
1043 fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1044 let s = self.infcx.shallow_resolve(sort);
1045 match s {
1046 rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1047 _ => s.try_super_fold_with(self),
1048 }
1049 }
1050}
1051
/// Map from expression nodes to per-node data.
///
/// Entries are keyed by `FhirId`, and each entry also keeps the node itself so its span and
/// kind are available when the data is processed (e.g. for error reporting).
///
/// `derive_where` is used so that `Default` does not require `T: Default`.
#[derive_where(Default)]
struct NodeMap<'genv, T> {
    map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
}
1059
1060impl<'genv, T> NodeMap<'genv, T> {
1061 fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1063 assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1064 }
1065}
1066
1067impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1068 type Item = (fhir::Expr<'genv>, T);
1069
1070 type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1071
1072 fn into_iter(self) -> Self::IntoIter {
1073 self.map.into_values()
1074 }
1075}