1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9 fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10 global_env::GlobalEnv,
11 queries::QueryResult,
12 rty::{
13 self, AdtSortDef, FuncSort, List, WfckResults,
14 fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15 },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
/// Result of a sort-checking step. The `Err` payload is the `ErrorGuaranteed` token obtained when
/// the corresponding diagnostic was emitted (see `emit_err`).
type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;

/// Sort-inference context for the refinement expressions of a single Flux owner.
///
/// Sorts are inferred by unification over sort variables and bit-vector size variables. Side
/// tables are filled while checking, then resolved and moved into `wfckresults` by
/// `into_results`.
pub(super) struct InferCtxt<'genv, 'tcx> {
    /// Global environment, used for queries (e.g. `adt_sort_def_of`) and diagnostics.
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The item whose refinements are being checked.
    pub owner: FluxOwnerId,
    /// Results accumulated during checking; returned by `into_results`.
    pub wfckresults: WfckResults,
    /// Unification table for sort inference variables (`rty::Sort::Infer`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for bit-vector size inference variables (`rty::BvSize::Infer`).
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Declared refinement parameters with their (possibly still unresolved) sorts.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sorts keyed by node id; read by `synth_path`, `synth_type_relative_path` and `sort_of_bty`.
    node_sort: FxHashMap<FhirId, rty::Sort>,
    /// Generic arguments recorded per path node (see `insert_path_args` / `path_args`).
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sorts of alias-reft applications (see `insert_sort_for_alias_reft`).
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
    /// Unsuffixed integer literals whose sort is an inference variable; resolved in
    /// `into_results`.
    sort_of_literal: NodeMap<'genv, rty::Sort>,
    /// Sort recorded for ordering comparisons and arithmetic/bitwise binary operators (not for
    /// `==`, `!=` or the boolean connectives).
    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
    /// Sort arguments used to instantiate a polymorphic function sort at each application.
    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
}
44
45pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
46 match op {
47 fhir::BinOp::BitAnd
48 | fhir::BinOp::BitOr
49 | fhir::BinOp::BitXor
50 | fhir::BinOp::BitShl
51 | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
52 _ => None,
53 }
54}
55
56impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
57 pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
58 let mut sort_unification_table = InPlaceUnificationTable::new();
60 sort_unification_table.new_key(Default::default());
61 Self {
62 genv,
63 owner,
64 wfckresults: WfckResults::new(owner),
65 sort_unification_table,
66 bv_size_unification_table: InPlaceUnificationTable::new(),
67 params: Default::default(),
68 node_sort: Default::default(),
69 path_args: Default::default(),
70 sort_of_alias_reft: Default::default(),
71 sort_of_literal: Default::default(),
72 sort_of_bin_op: Default::default(),
73 sort_args_of_app: Default::default(),
74 }
75 }
76
77 fn check_abs(
78 &mut self,
79 expr: &fhir::Expr<'genv>,
80 params: &[fhir::RefineParam],
81 body: &fhir::Expr<'genv>,
82 expected: &rty::Sort,
83 ) -> Result {
84 let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
85 return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
86 };
87
88 let fsort = fsort.expect_mono();
89
90 if params.len() != fsort.inputs().len() {
91 return Err(self.emit_err(errors::ParamCountMismatch::new(
92 expr.span,
93 fsort.inputs().len(),
94 params.len(),
95 )));
96 }
97 iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
98 let found = self.param_sort(param.id);
99 if self.try_equate(&found, expected).is_none() {
100 return Err(self.emit_sort_mismatch(param.span, expected, &found));
101 }
102 Ok(())
103 })?;
104 self.check_expr(body, fsort.output())?;
105 self.wfckresults
106 .node_sorts_mut()
107 .insert(body.fhir_id, fsort.output().clone());
108 Ok(())
109 }
110
111 fn check_field_exprs(
112 &mut self,
113 expr_span: Span,
114 sort_def: &AdtSortDef,
115 sort_args: &[rty::Sort],
116 field_exprs: &[fhir::FieldExpr<'genv>],
117 spread: &Option<&fhir::Spread<'genv>>,
118 expected: &rty::Sort,
119 ) -> Result {
120 let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
121 let mut used_fields = FxHashMap::default();
122 for expr in field_exprs {
123 let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
125 return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
126 };
127 if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
128 return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
129 }
130 self.check_expr(&expr.expr, sort)?;
131 }
132 if let Some(spread) = spread {
133 self.check_expr(&spread.expr, expected)
135 } else if sort_by_field_name.len() != used_fields.len() {
136 let missing_fields = sort_by_field_name
138 .into_keys()
139 .filter(|x| !used_fields.contains_key(x))
140 .collect();
141 Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
142 } else {
143 Ok(())
144 }
145 }
146
147 fn check_constructor(
148 &mut self,
149 expr: &fhir::Expr,
150 field_exprs: &[fhir::FieldExpr<'genv>],
151 spread: &Option<&fhir::Spread<'genv>>,
152 expected: &rty::Sort,
153 ) -> Result {
154 let expected = self.resolve_vars_if_possible(expected);
155 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
156 self.wfckresults
157 .record_ctors_mut()
158 .insert(expr.fhir_id, sort_def.did());
159 self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
160 Ok(())
161 } else {
162 Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
163 }
164 }
165
166 fn check_record(
167 &mut self,
168 arg: &fhir::Expr<'genv>,
169 flds: &[fhir::Expr<'genv>],
170 expected: &rty::Sort,
171 ) -> Result {
172 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
173 let sorts = sort_def.struct_variant().field_sorts(sort_args);
174 if flds.len() != sorts.len() {
175 return Err(self.emit_err(errors::ArgCountMismatch::new(
176 Some(arg.span),
177 String::from("type"),
178 sorts.len(),
179 flds.len(),
180 )));
181 }
182 self.wfckresults
183 .record_ctors_mut()
184 .insert(arg.fhir_id, sort_def.did());
185
186 izip!(flds, &sorts)
187 .map(|(arg, expected)| self.check_expr(arg, expected))
188 .try_collect_exhaust()
189 } else {
190 Err(self.emit_err(errors::ArgCountMismatch::new(
191 Some(arg.span),
192 String::from("type"),
193 1,
194 flds.len(),
195 )))
196 }
197 }
198
199 pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
200 match &expr.kind {
201 fhir::ExprKind::IfThenElse(p, e1, e2) => {
202 self.check_expr(p, &rty::Sort::Bool)?;
203 self.check_expr(e1, expected)?;
204 self.check_expr(e2, expected)?;
205 }
206 fhir::ExprKind::Abs(params, body) => {
207 self.check_abs(expr, params, body, expected)?;
208 }
209 fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
210 fhir::ExprKind::Constructor(None, exprs, spread) => {
211 self.check_constructor(expr, exprs, spread, expected)?;
212 }
213 fhir::ExprKind::UnaryOp(..)
214 | fhir::ExprKind::SetLiteral(..)
215 | fhir::ExprKind::BinaryOp(..)
216 | fhir::ExprKind::Dot(..)
217 | fhir::ExprKind::App(..)
218 | fhir::ExprKind::Alias(..)
219 | fhir::ExprKind::Var(..)
220 | fhir::ExprKind::Literal(..)
221 | fhir::ExprKind::BoundedQuant(..)
222 | fhir::ExprKind::Block(..)
223 | fhir::ExprKind::Constructor(..)
224 | fhir::ExprKind::PrimApp(..) => {
225 let found = self.synth_expr(expr)?;
226 let found = self.resolve_vars_if_possible(&found);
227 let expected = self.resolve_vars_if_possible(expected);
228 if !self.is_coercible(&found, &expected, expr.fhir_id) {
229 return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
230 }
231 }
232 fhir::ExprKind::Err(_) => {
233 }
235 }
236 Ok(())
237 }
238
239 pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
240 let found = self.synth_path(loc);
241 if found == rty::Sort::Loc {
242 Ok(())
243 } else {
244 Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
245 }
246 }
247
248 fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
249 match lit {
250 fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
251 fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
252 fhir::Lit::Int(_, None) => {
253 let sort = self.next_sort_var_with_cstr(rty::SortCstr::NUMERIC);
254 self.sort_of_literal.insert(*expr, sort.clone());
255 sort
256 }
257 fhir::Lit::Bool(_) => rty::Sort::Bool,
258 fhir::Lit::Str(_) => rty::Sort::Str,
259 fhir::Lit::Char(_) => rty::Sort::Char,
260 }
261 }
262
263 fn synth_prim_app(
264 &mut self,
265 op: &fhir::BinOp,
266 e1: &fhir::Expr<'genv>,
267 e2: &fhir::Expr<'genv>,
268 span: Span,
269 ) -> Result<rty::Sort> {
270 let Some((inputs, output)) = prim_op_sort(op) else {
271 return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
272 };
273 let [sort1, sort2] = &inputs[..] else {
274 return Err(self.emit_err(errors::ArgCountMismatch::new(
275 Some(span),
276 String::from("primop app"),
277 inputs.len(),
278 2,
279 )));
280 };
281 self.check_expr(e1, sort1)?;
282 self.check_expr(e2, sort2)?;
283 Ok(output)
284 }
285
286 fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
287 match expr.kind {
288 fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
289 fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
290 Ok(self.synth_type_relative_path(expr))
291 }
292 fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
293 fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
294 fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
295 fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
296 fhir::ExprKind::App(callee, args) => {
297 let sort = self.ensure_resolved_path(&callee)?;
298 let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
299 return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
300 };
301 let fsort = self.instantiate_func_sort(expr, poly_fsort);
302 self.synth_app(fsort, args, expr.span)
303 }
304 fhir::ExprKind::BoundedQuant(.., body) => {
305 self.check_expr(body, &rty::Sort::Bool)?;
306 Ok(rty::Sort::Bool)
307 }
308 fhir::ExprKind::Alias(_alias_reft, args) => {
309 let fsort = self.sort_of_alias_reft(expr.fhir_id);
312 self.synth_app(fsort, args, expr.span)
313 }
314 fhir::ExprKind::IfThenElse(p, e1, e2) => {
315 self.check_expr(p, &rty::Sort::Bool)?;
316 let sort = self.synth_expr(e1)?;
317 self.check_expr(e2, &sort)?;
318 Ok(sort)
319 }
320 fhir::ExprKind::Dot(base, fld) => {
321 let sort = self.synth_expr(base)?;
322 let sort = self
323 .fully_resolve(&sort)
324 .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
325 match &sort {
326 rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
327 let (proj, sort) = sort_def
328 .struct_variant()
329 .field_by_name(sort_def.did(), sort_args, fld.name)
330 .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
331 self.wfckresults
332 .field_projs_mut()
333 .insert(expr.fhir_id, proj);
334 Ok(sort)
335 }
336 rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
337 Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
338 }
339 _ => Err(self.emit_field_not_found(&sort, fld)),
340 }
341 }
342 fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
343 let path_def_id = match path.res {
347 fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
348 _ => span_bug!(expr.span, "unexpected path in constructor"),
349 };
350 let sort_def = self
351 .genv
352 .adt_sort_def_of(path_def_id)
353 .map_err(|e| self.emit_err(e))?;
354 let fresh_args: rty::List<_> = (0..sort_def.param_count())
356 .map(|_| self.next_sort_var())
357 .collect();
358 let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
359 self.check_field_exprs(
361 expr.span,
362 &sort_def,
363 &fresh_args,
364 field_exprs,
365 &spread,
366 &sort,
367 )?;
368 Ok(sort)
369 }
370 fhir::ExprKind::Block(decls, body) => {
371 for decl in decls {
372 self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
373 }
374 self.synth_expr(body)
375 }
376 fhir::ExprKind::SetLiteral(elems) => {
377 let elem_sort = self.next_sort_var();
378 for elem in elems {
379 self.check_expr(elem, &elem_sort)?;
380 }
381 Ok(rty::Sort::App(rty::SortCtor::Set, List::singleton(elem_sort)))
382 }
383 _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
384 }
385 }
386
387 fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
388 self.node_sort
389 .get(&path.fhir_id)
390 .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
391 .clone()
392 }
393
394 fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
395 self.node_sort
396 .get(&expr.fhir_id)
397 .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
398 .clone()
399 }
400
401 fn synth_binary_op(
402 &mut self,
403 expr: &fhir::Expr<'genv>,
404 op: fhir::BinOp,
405 e1: &fhir::Expr<'genv>,
406 e2: &fhir::Expr<'genv>,
407 ) -> Result<rty::Sort> {
408 match op {
409 fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
410 self.check_expr(e1, &rty::Sort::Bool)?;
411 self.check_expr(e2, &rty::Sort::Bool)?;
412 Ok(rty::Sort::Bool)
413 }
414 fhir::BinOp::Eq | fhir::BinOp::Ne => {
415 let sort = self.next_sort_var();
416 self.check_expr(e1, &sort)?;
417 self.check_expr(e2, &sort)?;
418 Ok(rty::Sort::Bool)
419 }
420 fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
421 let sort = self.next_sort_var();
422 self.check_expr(e1, &sort)?;
423 self.check_expr(e2, &sort)?;
424 self.sort_of_bin_op.insert(*expr, sort.clone());
425 Ok(rty::Sort::Bool)
426 }
427 fhir::BinOp::Add
428 | fhir::BinOp::Sub
429 | fhir::BinOp::Mul
430 | fhir::BinOp::Div
431 | fhir::BinOp::Mod
432 | fhir::BinOp::BitAnd
433 | fhir::BinOp::BitOr
434 | fhir::BinOp::BitXor
435 | fhir::BinOp::BitShl
436 | fhir::BinOp::BitShr => {
437 let sort = self.next_sort_var_with_cstr(rty::SortCstr::from_bin_op(op));
438 self.check_expr(e1, &sort)?;
439 self.check_expr(e2, &sort)?;
440 self.sort_of_bin_op.insert(*expr, sort.clone());
441 Ok(sort)
442 }
443 }
444 }
445
446 fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
447 match op {
448 fhir::UnOp::Not => {
449 self.check_expr(e, &rty::Sort::Bool)?;
450 Ok(rty::Sort::Bool)
451 }
452 fhir::UnOp::Neg => {
453 self.check_expr(e, &rty::Sort::Int)?;
454 Ok(rty::Sort::Int)
455 }
456 }
457 }
458
459 fn synth_app(
460 &mut self,
461 fsort: FuncSort,
462 args: &[fhir::Expr<'genv>],
463 span: Span,
464 ) -> Result<rty::Sort> {
465 if args.len() != fsort.inputs().len() {
466 return Err(self.emit_err(errors::ArgCountMismatch::new(
467 Some(span),
468 String::from("function"),
469 fsort.inputs().len(),
470 args.len(),
471 )));
472 }
473
474 iter::zip(args, fsort.inputs())
475 .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
476
477 Ok(fsort.output().clone())
478 }
479
480 fn instantiate_func_sort(
481 &mut self,
482 app_expr: &fhir::Expr<'genv>,
483 fsort: rty::PolyFuncSort,
484 ) -> rty::FuncSort {
485 let args = fsort
486 .params()
487 .map(|kind| {
488 match kind {
489 rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
490 rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
491 }
492 })
493 .collect_vec();
494 self.sort_args_of_app
495 .insert(*app_expr, List::from_slice(&args));
496 fsort.instantiate(&args)
497 }
498
499 pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
500 self.node_sort.insert(fhir_id, sort);
501 }
502
503 pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
504 self.node_sort
505 .get(&bty.fhir_id)
506 .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
507 .clone()
508 }
509
510 pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
511 self.path_args.insert(fhir_id, args);
512 }
513
514 pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
515 self.path_args
516 .get(&fhir_id)
517 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
518 .clone()
519 }
520
521 pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
522 self.sort_of_alias_reft.insert(fhir_id, fsort);
523 }
524
525 fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
526 self.sort_of_alias_reft
527 .get(&fhir_id)
528 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
529 .clone()
530 }
531
532 fn deeply_normalize_sorts<T: TypeFoldable + Clone>(
533 genv: GlobalEnv,
534 owner: FluxOwnerId,
535 t: &T,
536 ) -> QueryResult<T> {
537 let infcx = genv
538 .tcx()
539 .infer_ctxt()
540 .with_next_trait_solver(true)
541 .build(TypingMode::non_body_analysis());
542 if let Some(def_id) = owner.def_id() {
543 let def_id = genv.maybe_extern_id(def_id).resolved_id();
544 t.deeply_normalize_sorts(def_id, genv, &infcx)
545 } else {
546 Ok(t.clone())
547 }
548 }
549
550 pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
555 for sort in self.node_sort.values_mut() {
556 *sort = Self::deeply_normalize_sorts(self.genv, self.owner, sort)?;
557 }
558 for fsort in self.sort_of_alias_reft.values_mut() {
559 *fsort = Self::deeply_normalize_sorts(self.genv, self.owner, fsort)?;
560 }
561 Ok(())
562 }
563}
564
565impl<'genv> InferCtxt<'genv, '_> {
566 pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
567 self.params.insert(param.id, (param, sort));
568 }
569
570 fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
575 if self.try_equate(sort1, sort2).is_some() {
576 return true;
577 }
578
579 let mut sort1 = sort1.clone();
580 let mut sort2 = sort2.clone();
581
582 let mut coercions = vec![];
583 if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
584 coercions.push(rty::Coercion::Project(def_id));
585 sort1 = sort.clone();
586 }
587 if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
588 coercions.push(rty::Coercion::Inject(def_id));
589 sort2 = sort.clone();
590 }
591 self.wfckresults.coercions_mut().insert(fhir_id, coercions);
592 self.try_equate(&sort1, &sort2).is_some()
593 }
594
595 fn is_coercible_from_func(
596 &mut self,
597 sort: &rty::Sort,
598 fhir_id: FhirId,
599 ) -> Option<rty::PolyFuncSort> {
600 if let rty::Sort::Func(fsort) = sort {
601 Some(fsort.clone())
602 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
603 self.wfckresults
604 .coercions_mut()
605 .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
606 Some(fsort.clone())
607 } else {
608 None
609 }
610 }
611
612 fn is_coercible_to_func(
613 &mut self,
614 sort: &rty::Sort,
615 fhir_id: FhirId,
616 ) -> Option<rty::PolyFuncSort> {
617 if let rty::Sort::Func(fsort) = sort {
618 Some(fsort.clone())
619 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
620 self.wfckresults
621 .coercions_mut()
622 .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
623 Some(fsort.clone())
624 } else {
625 None
626 }
627 }
628
629 fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
630 let sort1 = self.resolve_vars_if_possible(sort1);
631 let sort2 = self.resolve_vars_if_possible(sort2);
632 self.try_equate_inner(&sort1, &sort2)
633 }
634
635 fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
636 match (sort1, sort2) {
637 (rty::Sort::Infer(vid1), rty::Sort::Infer(vid2)) => {
638 self.sort_unification_table
639 .unify_var_var(*vid1, *vid2)
640 .ok()?;
641 }
642 (rty::Sort::Infer(vid), sort) | (sort, rty::Sort::Infer(vid)) => {
643 self.sort_unification_table
644 .unify_var_value(*vid, rty::SortVarVal::Solved(sort.clone()))
645 .ok()?;
646 }
647 (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
648 if ctor1 != ctor2 || args1.len() != args2.len() {
649 return None;
650 }
651 let mut args = vec![];
652 for (s1, s2) in args1.iter().zip(args2.iter()) {
653 args.push(self.try_equate_inner(s1, s2)?);
654 }
655 }
656 (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
657 self.try_equate_bv_sizes(*size1, *size2)?;
658 }
659 _ if sort1 == sort2 => {}
660 _ => return None,
661 }
662 Some(sort1.clone())
663 }
664
665 fn try_equate_bv_sizes(
666 &mut self,
667 size1: rty::BvSize,
668 size2: rty::BvSize,
669 ) -> Option<rty::BvSize> {
670 match (size1, size2) {
671 (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
672 self.bv_size_unification_table
673 .unify_var_var(vid1, vid2)
674 .ok()?;
675 }
676 (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
677 self.bv_size_unification_table
678 .unify_var_value(vid, Some(size))
679 .ok()?;
680 }
681 _ if size1 == size2 => {}
682 _ => return None,
683 }
684 Some(size1)
685 }
686
687 fn next_sort_var(&mut self) -> rty::Sort {
688 rty::Sort::Infer(self.next_sort_vid(Default::default()))
689 }
690
691 fn next_sort_var_with_cstr(&mut self, cstr: rty::SortCstr) -> rty::Sort {
692 rty::Sort::Infer(self.next_sort_vid(rty::SortVarVal::Unsolved(cstr)))
693 }
694
695 pub(crate) fn next_sort_vid(&mut self, val: rty::SortVarVal) -> rty::SortVid {
696 self.sort_unification_table.new_key(val)
697 }
698
699 fn next_bv_size_var(&mut self) -> rty::BvSize {
700 rty::BvSize::Infer(self.next_bv_size_vid())
701 }
702
703 fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
704 self.bv_size_unification_table.new_key(None)
705 }
706
707 fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
708 let sort = self.synth_path(path);
709 self.fully_resolve(&sort)
710 .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
711 }
712
713 fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
714 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
715 && let Some(variant) = sort_def.opt_struct_variant()
716 && let [sort] = &variant.field_sorts(sort_args)[..]
717 {
718 Some((sort_def.did(), sort.clone()))
719 } else {
720 None
721 }
722 }
723
724 pub(crate) fn into_results(mut self) -> Result<WfckResults> {
725 for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
728 if let rty::Sort::Infer(vid) = &sort {
734 let _ = self
735 .sort_unification_table
736 .unify_var_value(*vid, rty::SortVarVal::Solved(rty::Sort::Int));
737 }
738
739 let sort = self
740 .fully_resolve(&sort)
741 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
742 self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
743 }
744
745 for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
747 let sort = self
748 .fully_resolve(&sort)
749 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
750 self.wfckresults
751 .bin_op_sorts_mut()
752 .insert(node.fhir_id, sort);
753 }
754
755 let allow_uninterpreted_cast = self
756 .owner
757 .def_id()
758 .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
759 self.genv.infer_opts(def_id).allow_uninterpreted_cast
760 });
761
762 for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
764 let mut res = vec![];
765 for sort_arg in &sort_args {
766 let sort_arg = match sort_arg {
767 rty::SortArg::Sort(sort) => {
768 let sort = self
769 .fully_resolve(sort)
770 .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
771 rty::SortArg::Sort(sort)
772 }
773 rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
774 let size = self
775 .bv_size_unification_table
776 .probe_value(*vid)
777 .ok_or_else(|| {
778 self.emit_err(errors::CannotInferSort::new(node.span))
779 })?;
780 rty::SortArg::BvSize(size)
781 }
782 _ => sort_arg.clone(),
783 };
784 res.push(sort_arg);
785 }
786 if let fhir::ExprKind::App(callee, _) = node.kind
787 && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
788 {
789 let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
790 span_bug!(node.span, "invalid sort args!")
791 };
792 if !allow_uninterpreted_cast
793 && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
794 {
795 return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
796 }
797 }
798 self.wfckresults
799 .fn_app_sorts_mut()
800 .insert(node.fhir_id, res.into());
801 }
802
803 for (_, (param, sort)) in std::mem::take(&mut self.params) {
805 let sort = self
806 .fully_resolve(&sort)
807 .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(¶m)))?;
808 self.wfckresults.param_sorts_mut().insert(param.id, sort);
809 }
810
811 Ok(self.wfckresults)
812 }
813
814 pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
815 fhir::InferMode::from_param_kind(self.params[&id].0.kind)
816 }
817
818 #[track_caller]
819 pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
820 self.params
821 .get(&id)
822 .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
823 .1
824 .clone()
825 }
826
827 fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
828 sort.fold_with(&mut ShallowResolver { infcx: self })
829 }
830
831 fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
832 sort.fold_with(&mut OpportunisticResolver { infcx: self })
833 }
834
835 pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
836 sort.try_fold_with(&mut FullResolver { infcx: self })
837 }
838}
839
/// Visitor that infers the sorts of parameters occurring in the indices of indexed types by
/// unifying them with the sort of the indexed base type.
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    /// Inference context whose unification tables are updated.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    /// Diagnostics emitted during the walk; converted into the overall result at the end.
    errors: Errors<'genv>,
}
854
855impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
856 pub(crate) fn infer(
857 infcx: &'a mut InferCtxt<'genv, 'tcx>,
858 node: &fhir::OwnerNode<'genv>,
859 ) -> Result {
860 let errors = Errors::new(infcx.genv.sess());
861 let mut vis = Self { infcx, errors };
862 vis.visit_node(node);
863 vis.errors.to_result()
864 }
865
866 fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
867 match idx.kind {
868 fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
869 let (_, id) = var.res.expect_param();
870 let found = self.infcx.param_sort(id);
871 self.infcx.try_equate(&found, expected).unwrap_or_else(|| {
872 span_bug!(idx.span, "failed to equate sorts: `{found:?}` `{expected:?}`")
873 });
874 }
875 fhir::ExprKind::Record(flds) => {
876 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
877 let sorts = sort_def.struct_variant().field_sorts(sort_args);
878 if flds.len() != sorts.len() {
879 self.errors.emit(errors::ArgCountMismatch::new(
880 Some(idx.span),
881 String::from("type"),
882 sorts.len(),
883 flds.len(),
884 ));
885 return;
886 }
887 for (f, sort) in iter::zip(flds, &sorts) {
888 self.infer_implicit_params(f, sort);
889 }
890 } else {
891 self.errors.emit(errors::ArgCountMismatch::new(
892 Some(idx.span),
893 String::from("type"),
894 1,
895 flds.len(),
896 ));
897 }
898 }
899 _ => {}
900 }
901 }
902}
903
904impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
905 fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
906 if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
907 let expected = self.infcx.sort_of_bty(bty);
908 self.infer_implicit_params(idx, &expected);
909 }
910 fhir::visit::walk_ty(self, ty);
911 }
912}
913
914impl InferCtxt<'_, '_> {
915 #[track_caller]
916 fn emit_sort_mismatch(
917 &mut self,
918 span: Span,
919 expected: &rty::Sort,
920 found: &rty::Sort,
921 ) -> ErrorGuaranteed {
922 let expected = self.resolve_vars_if_possible(expected);
923 let found = self.resolve_vars_if_possible(found);
924 self.emit_err(errors::SortMismatch::new(span, expected, found))
925 }
926
927 fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
928 self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
929 }
930
931 #[track_caller]
932 fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
933 self.genv.sess().emit_err(err)
934 }
935}
936
937struct ShallowResolver<'a, 'genv, 'tcx> {
938 infcx: &'a mut InferCtxt<'genv, 'tcx>,
939}
940
941impl TypeFolder for ShallowResolver<'_, '_, '_> {
942 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
943 match sort {
944 rty::Sort::Infer(vid) => {
945 self.infcx
951 .sort_unification_table
952 .probe_value(*vid)
953 .map_solved(|sort| sort.fold_with(self))
954 .solved_or(sort)
955 }
956 rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
957 self.infcx
958 .bv_size_unification_table
959 .probe_value(*vid)
960 .map(rty::Sort::BitVec)
961 .unwrap_or(sort.clone())
962 }
963 _ => sort.clone(),
964 }
965 }
966}
967
968struct OpportunisticResolver<'a, 'genv, 'tcx> {
969 infcx: &'a mut InferCtxt<'genv, 'tcx>,
970}
971
972impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
973 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
974 let s = self.infcx.shallow_resolve(sort);
975 s.super_fold_with(self)
976 }
977}
978
979struct FullResolver<'a, 'genv, 'tcx> {
980 infcx: &'a mut InferCtxt<'genv, 'tcx>,
981}
982
983impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
984 type Error = ();
985
986 fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
987 let s = self.infcx.shallow_resolve(sort);
988 match s {
989 rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
990 _ => s.try_super_fold_with(self),
991 }
992 }
993}
994
995#[derive_where(Default)]
999struct NodeMap<'genv, T> {
1000 map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
1001}
1002
1003impl<'genv, T> NodeMap<'genv, T> {
1004 fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1006 assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1007 }
1008}
1009
1010impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1011 type Item = (fhir::Expr<'genv>, T);
1012
1013 type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1014
1015 fn into_iter(self) -> Self::IntoIter {
1016 self.map.into_values()
1017 }
1018}