1use std::iter;
2
3use ena::unify::InPlaceUnificationTable;
4use flux_common::{bug, iter::IterExt, result::ResultExt, span_bug, tracked_span_bug};
5use flux_errors::{ErrorGuaranteed, Errors};
6use flux_infer::projections::NormalizeExt;
7use flux_middle::{
8 THEORY_FUNCS,
9 fhir::{self, ExprRes, FhirId, FluxOwnerId, SpecFuncKind, visit::Visitor as _},
10 global_env::GlobalEnv,
11 queries::QueryResult,
12 rty::{
13 self, AdtSortDef, FuncSort, WfckResults,
14 fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15 },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_middle::ty::TypingMode;
22use rustc_span::{Span, def_id::DefId, symbol::Ident};
23
24use super::errors;
25use crate::rustc_infer::infer::TyCtxtInferExt;
26
/// Result type of the sort checker. An `Err` carries an `ErrorGuaranteed`, i.e. the
/// corresponding diagnostic has already been emitted (see `emit_err`), so callers only
/// need to propagate it.
type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
28
/// Sort-inference context for the refinements of a single owner.
///
/// Sorts are inferred by unification over three kinds of variables: general sort variables,
/// numeric variables, and bit-vector size variables. Intermediate facts are collected in the
/// maps below and turned into final results by `into_results`.
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The item whose refinements are being checked.
    pub owner: FluxOwnerId,
    /// Results accumulated while checking (node sorts, coercions, record constructors, ...).
    pub wfckresults: WfckResults,
    /// Unification table for general sort variables (`rty::SortVar`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for numeric variables (`rty::NumVar`); these are only ever bound to
    /// `int`, `real`, or a bit-vector sort.
    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
    /// Unification table for bit-vector size variables.
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Declared refinement parameters together with their (possibly not yet inferred) sorts.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sort of each base type, keyed by the base type's `FhirId`.
    sort_of_bty: FxHashMap<FhirId, rty::Sort>,
    /// Sort of each integer literal plus its span; fully resolved in `into_results`.
    sort_of_literal: FxHashMap<FhirId, (rty::Sort, Span)>,
    /// Operand sort of arithmetic and ordering operators plus their span; fully resolved in
    /// `into_results`.
    sort_of_bin_op: FxHashMap<FhirId, (rty::Sort, Span)>,
    /// Generic arguments recorded for paths.
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sort of each alias-reft application.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
}
43
44impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
45 pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
46 let mut sort_unification_table = InPlaceUnificationTable::new();
48 sort_unification_table.new_key(None);
49 Self {
50 genv,
51 owner,
52 wfckresults: WfckResults::new(owner),
53 sort_unification_table,
54 num_unification_table: InPlaceUnificationTable::new(),
55 bv_size_unification_table: InPlaceUnificationTable::new(),
56 params: Default::default(),
57 sort_of_bty: Default::default(),
58 sort_of_literal: Default::default(),
59 sort_of_bin_op: Default::default(),
60 path_args: Default::default(),
61 sort_of_alias_reft: Default::default(),
62 }
63 }
64
65 fn check_abs(
66 &mut self,
67 expr: &fhir::Expr,
68 params: &[fhir::RefineParam],
69 body: &fhir::Expr,
70 expected: &rty::Sort,
71 ) -> Result {
72 let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
73 return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
74 };
75
76 let fsort = fsort.expect_mono();
77
78 if params.len() != fsort.inputs().len() {
79 return Err(self.emit_err(errors::ParamCountMismatch::new(
80 expr.span,
81 fsort.inputs().len(),
82 params.len(),
83 )));
84 }
85 iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
86 let found = self.param_sort(param.id);
87 if self.try_equate(&found, expected).is_none() {
88 return Err(self.emit_sort_mismatch(param.span, expected, &found));
89 }
90 Ok(())
91 })?;
92 self.check_expr(body, fsort.output())?;
93 self.wfckresults
94 .node_sorts_mut()
95 .insert(body.fhir_id, fsort.output().clone());
96 Ok(())
97 }
98
99 fn check_field_exprs(
100 &mut self,
101 expr_span: Span,
102 sort_def: &AdtSortDef,
103 sort_args: &[rty::Sort],
104 field_exprs: &[fhir::FieldExpr],
105 spread: &Option<&fhir::Spread>,
106 expected: &rty::Sort,
107 ) -> Result {
108 let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
109 let mut used_fields = FxHashMap::default();
110 for expr in field_exprs {
111 let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
113 return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
114 };
115 if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
116 return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
117 }
118 self.check_expr(&expr.expr, sort)?;
119 }
120 if let Some(spread) = spread {
121 self.check_expr(&spread.expr, expected)
123 } else if sort_by_field_name.len() != used_fields.len() {
124 let missing_fields = sort_by_field_name
126 .into_keys()
127 .filter(|x| !used_fields.contains_key(x))
128 .collect();
129 Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
130 } else {
131 Ok(())
132 }
133 }
134
135 fn check_constructor(
136 &mut self,
137 expr: &fhir::Expr,
138 field_exprs: &[fhir::FieldExpr],
139 spread: &Option<&fhir::Spread>,
140 expected: &rty::Sort,
141 ) -> Result {
142 let expected = self.resolve_vars_if_possible(expected);
143 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
144 self.wfckresults
145 .record_ctors_mut()
146 .insert(expr.fhir_id, sort_def.did());
147 self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
148 Ok(())
149 } else {
150 Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
151 }
152 }
153
154 fn check_record(
155 &mut self,
156 arg: &fhir::Expr,
157 flds: &[fhir::Expr],
158 expected: &rty::Sort,
159 ) -> Result {
160 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
161 let sorts = sort_def.struct_variant().field_sorts(sort_args);
162 if flds.len() != sorts.len() {
163 return Err(self.emit_err(errors::ArgCountMismatch::new(
164 Some(arg.span),
165 String::from("type"),
166 sorts.len(),
167 flds.len(),
168 )));
169 }
170 self.wfckresults
171 .record_ctors_mut()
172 .insert(arg.fhir_id, sort_def.did());
173
174 izip!(flds, &sorts)
175 .map(|(arg, expected)| self.check_expr(arg, expected))
176 .try_collect_exhaust()
177 } else {
178 Err(self.emit_err(errors::ArgCountMismatch::new(
179 Some(arg.span),
180 String::from("type"),
181 1,
182 flds.len(),
183 )))
184 }
185 }
186
187 pub(super) fn check_expr(&mut self, expr: &fhir::Expr, expected: &rty::Sort) -> Result {
188 match &expr.kind {
189 fhir::ExprKind::IfThenElse(p, e1, e2) => {
190 self.check_expr(p, &rty::Sort::Bool)?;
191 self.check_expr(e1, expected)?;
192 self.check_expr(e2, expected)?;
193 }
194 fhir::ExprKind::Abs(params, body) => {
195 self.check_abs(expr, params, body, expected)?;
196 }
197 fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
198 fhir::ExprKind::Constructor(None, exprs, spread) => {
199 self.check_constructor(expr, exprs, spread, expected)?;
200 }
201 fhir::ExprKind::UnaryOp(..)
202 | fhir::ExprKind::BinaryOp(..)
203 | fhir::ExprKind::Dot(..)
204 | fhir::ExprKind::App(..)
205 | fhir::ExprKind::Alias(..)
206 | fhir::ExprKind::Var(..)
207 | fhir::ExprKind::Literal(..)
208 | fhir::ExprKind::BoundedQuant(..)
209 | fhir::ExprKind::Block(..)
210 | fhir::ExprKind::Constructor(..) => {
211 let found = self.synth_expr(expr)?;
212 let found = self.resolve_vars_if_possible(&found);
213 let expected = self.resolve_vars_if_possible(expected);
214 if !self.is_coercible(&found, &expected, expr.fhir_id) {
215 return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
216 }
217 }
218 fhir::ExprKind::Err(_) => {
219 }
221 }
222 Ok(())
223 }
224
225 pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
226 let found = self.synth_path(loc)?;
227 if found == rty::Sort::Loc {
228 Ok(())
229 } else {
230 Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
231 }
232 }
233
234 fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr) -> rty::Sort {
235 match lit {
236 fhir::Lit::Int(_) => {
237 let sort = self.next_num_var();
238 self.sort_of_literal
239 .insert(expr.fhir_id, (sort.clone(), expr.span));
240 sort
241 }
242 fhir::Lit::Bool(_) => rty::Sort::Bool,
243 fhir::Lit::Real(_) => rty::Sort::Real,
244 fhir::Lit::Str(_) => rty::Sort::Str,
245 fhir::Lit::Char(_) => rty::Sort::Char,
246 }
247 }
248
249 fn synth_expr(&mut self, expr: &fhir::Expr) -> Result<rty::Sort> {
250 match expr.kind {
251 fhir::ExprKind::Var(var, _) => self.synth_path(&var),
252 fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
253 fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
254 fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
255 fhir::ExprKind::App(callee, args) => {
256 let sort = self.ensure_resolved_path(&callee)?;
257 let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
258 return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
259 };
260 let fsort = self.instantiate_func_sort(poly_fsort);
261 self.synth_app(fsort, args, expr.span)
262 }
263 fhir::ExprKind::BoundedQuant(.., body) => {
264 self.check_expr(body, &rty::Sort::Bool)?;
265 Ok(rty::Sort::Bool)
266 }
267 fhir::ExprKind::Alias(_alias_reft, args) => {
268 let fsort = self.sort_of_alias_reft(expr.fhir_id);
271 self.synth_app(fsort, args, expr.span)
272 }
273 fhir::ExprKind::IfThenElse(p, e1, e2) => {
274 self.check_expr(p, &rty::Sort::Bool)?;
275 let sort = self.synth_expr(e1)?;
276 self.check_expr(e2, &sort)?;
277 Ok(sort)
278 }
279 fhir::ExprKind::Dot(base, fld) => {
280 let sort = self.synth_expr(base)?;
281 let sort = self
282 .fully_resolve(&sort)
283 .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
284 match &sort {
285 rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
286 let (proj, sort) = sort_def
287 .struct_variant()
288 .field_by_name(sort_def.did(), sort_args, fld.name)
289 .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
290 self.wfckresults
291 .field_projs_mut()
292 .insert(expr.fhir_id, proj);
293 Ok(sort)
294 }
295 rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
296 Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
297 }
298 _ => Err(self.emit_field_not_found(&sort, fld)),
299 }
300 }
301 fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
302 let path_def_id = match path.res {
306 ExprRes::Ctor(def_id) => def_id,
307 _ => span_bug!(expr.span, "unexpected path in constructor"),
308 };
309 let sort_def = self
310 .genv
311 .adt_sort_def_of(path_def_id)
312 .map_err(|e| self.emit_err(e))?;
313 let fresh_args: rty::List<_> = (0..sort_def.param_count())
315 .map(|_| self.next_sort_var())
316 .collect();
317 let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
318 self.check_field_exprs(
320 expr.span,
321 &sort_def,
322 &fresh_args,
323 field_exprs,
324 &spread,
325 &sort,
326 )?;
327 Ok(sort)
328 }
329 fhir::ExprKind::Block(decls, body) => {
330 for decl in decls {
331 self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
332 }
333 self.synth_expr(body)
334 }
335 _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
336 }
337 }
338
339 fn synth_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
340 match path.res {
341 ExprRes::Param(_, id) => Ok(self.param_sort(id)),
342 ExprRes::Const(def_id) => {
343 if let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? {
344 let info = self.genv.constant_info(def_id).emit(&self.genv)?;
345 if sort != rty::Sort::Int && matches!(info, rty::ConstantInfo::Uninterpreted) {
347 Err(self.emit_err(errors::ConstantAnnotationNeeded::new(path.span)))?;
348 }
349 Ok(sort)
350 } else {
351 span_bug!(path.span, "unexpected const")
352 }
353 }
354 ExprRes::Variant(def_id) => {
355 let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? else {
356 span_bug!(path.span, "unexpected variant {def_id:?}")
357 };
358 Ok(sort)
359 }
360 ExprRes::ConstGeneric(_) => Ok(rty::Sort::Int), ExprRes::NumConst(_) => Ok(rty::Sort::Int),
362 ExprRes::GlobalFunc(SpecFuncKind::Def(name) | SpecFuncKind::Uif(name)) => {
363 Ok(rty::Sort::Func(self.genv.func_sort(name)))
364 }
365 ExprRes::GlobalFunc(SpecFuncKind::Thy(itf)) => {
366 Ok(rty::Sort::Func(THEORY_FUNCS.get(&itf).unwrap().sort.clone()))
367 }
368 ExprRes::Ctor(_) => {
369 span_bug!(path.span, "unexpected constructor in var position")
370 }
371 }
372 }
373
374 fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
375 if matches!(op, fhir::BinOp::Mod) {
376 let sort = self
377 .fully_resolve(sort)
378 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
379 if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
380 span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
381 }
382 }
383 Ok(())
384 }
385
386 fn synth_binary_op(
387 &mut self,
388 expr: &fhir::Expr,
389 op: fhir::BinOp,
390 e1: &fhir::Expr,
391 e2: &fhir::Expr,
392 ) -> Result<rty::Sort> {
393 match op {
394 fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
395 self.check_expr(e1, &rty::Sort::Bool)?;
396 self.check_expr(e2, &rty::Sort::Bool)?;
397 Ok(rty::Sort::Bool)
398 }
399 fhir::BinOp::Eq | fhir::BinOp::Ne => {
400 let sort = self.next_sort_var();
401 self.check_expr(e1, &sort)?;
402 self.check_expr(e2, &sort)?;
403 Ok(rty::Sort::Bool)
404 }
405 fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
406 let sort = self.next_sort_var();
407 self.check_expr(e1, &sort)?;
408 self.check_expr(e2, &sort)?;
409 self.sort_of_bin_op
410 .insert(expr.fhir_id, (sort.clone(), expr.span));
411 Ok(rty::Sort::Bool)
412 }
413 fhir::BinOp::Add
414 | fhir::BinOp::Sub
415 | fhir::BinOp::Mul
416 | fhir::BinOp::Div
417 | fhir::BinOp::Mod => {
418 let sort = self.next_num_var();
419 self.check_expr(e1, &sort)?;
420 self.check_expr(e2, &sort)?;
421 self.sort_of_bin_op
422 .insert(expr.fhir_id, (sort.clone(), expr.span));
423 self.check_integral(op, &sort, expr.span)?;
425
426 Ok(sort)
427 }
428 fhir::BinOp::BitAnd
429 | fhir::BinOp::BitOr
430 | fhir::BinOp::BitShl
431 | fhir::BinOp::BitShr => {
432 let sort = rty::Sort::BitVec(self.next_bv_size_var());
433 self.check_expr(e1, &sort)?;
434 self.check_expr(e2, &sort)?;
435 Ok(sort)
436 }
437 }
438 }
439
440 fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr) -> Result<rty::Sort> {
441 match op {
442 fhir::UnOp::Not => {
443 self.check_expr(e, &rty::Sort::Bool)?;
444 Ok(rty::Sort::Bool)
445 }
446 fhir::UnOp::Neg => {
447 self.check_expr(e, &rty::Sort::Int)?;
448 Ok(rty::Sort::Int)
449 }
450 }
451 }
452
453 fn synth_app(&mut self, fsort: FuncSort, args: &[fhir::Expr], span: Span) -> Result<rty::Sort> {
454 if args.len() != fsort.inputs().len() {
455 return Err(self.emit_err(errors::ArgCountMismatch::new(
456 Some(span),
457 String::from("function"),
458 fsort.inputs().len(),
459 args.len(),
460 )));
461 }
462
463 iter::zip(args, fsort.inputs())
464 .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
465
466 Ok(fsort.output().clone())
467 }
468
469 fn instantiate_func_sort(&mut self, fsort: rty::PolyFuncSort) -> rty::FuncSort {
470 let args = fsort
471 .params()
472 .map(|kind| {
473 match kind {
474 rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
475 rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
476 }
477 })
478 .collect_vec();
479 fsort.instantiate(&args)
480 }
481
482 pub(crate) fn insert_sort_for_bty(&mut self, fhir_id: FhirId, sort: rty::Sort) {
483 self.sort_of_bty.insert(fhir_id, sort);
484 }
485
486 pub(crate) fn sort_of_bty(&self, fhir_id: FhirId) -> rty::Sort {
487 self.sort_of_bty
488 .get(&fhir_id)
489 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
490 .clone()
491 }
492
493 pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
494 self.path_args.insert(fhir_id, args);
495 }
496
497 pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
498 self.path_args
499 .get(&fhir_id)
500 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
501 .clone()
502 }
503
504 pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
505 self.sort_of_alias_reft.insert(fhir_id, fsort);
506 }
507
508 fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
509 self.sort_of_alias_reft
510 .get(&fhir_id)
511 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
512 .clone()
513 }
514
515 fn normalize_projection_sort(
516 genv: GlobalEnv,
517 owner: FluxOwnerId,
518 sort: rty::Sort,
519 ) -> rty::Sort {
520 let infcx = genv
521 .tcx()
522 .infer_ctxt()
523 .with_next_trait_solver(true)
524 .build(TypingMode::non_body_analysis());
525 if let Some(def_id) = owner.def_id()
526 && let def_id = genv.maybe_extern_id(def_id).resolved_id()
527 && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
528 {
529 sort
530 } else {
531 sort
532 }
533 }
534
535 pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
540 let genv = self.genv;
541 for sort in self.sort_of_bty.values_mut() {
542 *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
543 }
544 for fsort in self.sort_of_alias_reft.values_mut() {
545 *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
546 }
547 Ok(())
548 }
549}
550
551impl<'genv> InferCtxt<'genv, '_> {
552 pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
553 self.params.insert(param.id, (param, sort));
554 }
555
556 fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
561 if self.try_equate(sort1, sort2).is_some() {
562 return true;
563 }
564
565 let mut sort1 = sort1.clone();
566 let mut sort2 = sort2.clone();
567
568 let mut coercions = vec![];
569 if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
570 coercions.push(rty::Coercion::Project(def_id));
571 sort1 = sort.clone();
572 }
573 if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
574 coercions.push(rty::Coercion::Inject(def_id));
575 sort2 = sort.clone();
576 }
577 self.wfckresults.coercions_mut().insert(fhir_id, coercions);
578 self.try_equate(&sort1, &sort2).is_some()
579 }
580
581 fn is_coercible_from_func(
582 &mut self,
583 sort: &rty::Sort,
584 fhir_id: FhirId,
585 ) -> Option<rty::PolyFuncSort> {
586 if let rty::Sort::Func(fsort) = sort {
587 Some(fsort.clone())
588 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
589 self.wfckresults
590 .coercions_mut()
591 .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
592 Some(fsort.clone())
593 } else {
594 None
595 }
596 }
597
598 fn is_coercible_to_func(
599 &mut self,
600 sort: &rty::Sort,
601 fhir_id: FhirId,
602 ) -> Option<rty::PolyFuncSort> {
603 if let rty::Sort::Func(fsort) = sort {
604 Some(fsort.clone())
605 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
606 self.wfckresults
607 .coercions_mut()
608 .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
609 Some(fsort.clone())
610 } else {
611 None
612 }
613 }
614
615 fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
616 let sort1 = self.resolve_vars_if_possible(sort1);
617 let sort2 = self.resolve_vars_if_possible(sort2);
618 self.try_equate_inner(&sort1, &sort2)
619 }
620
621 fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
622 match (sort1, sort2) {
623 (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
624 self.sort_unification_table
625 .unify_var_var(*vid1, *vid2)
626 .ok()?;
627 }
628 (rty::Sort::Infer(rty::SortVar(vid)), sort)
629 | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
630 self.sort_unification_table
631 .unify_var_value(*vid, Some(sort.clone()))
632 .ok()?;
633 }
634 (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
635 self.num_unification_table
636 .unify_var_var(*vid1, *vid2)
637 .ok()?;
638 }
639 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
640 | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
641 self.num_unification_table
642 .unify_var_value(*vid, Some(rty::NumVarValue::Int))
643 .ok()?;
644 }
645 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
646 | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
647 self.num_unification_table
648 .unify_var_value(*vid, Some(rty::NumVarValue::Real))
649 .ok()?;
650 }
651 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
652 | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
653 self.num_unification_table
654 .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
655 .ok()?;
656 }
657
658 (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
659 if ctor1 != ctor2 || args1.len() != args2.len() {
660 return None;
661 }
662 let mut args = vec![];
663 for (s1, s2) in args1.iter().zip(args2.iter()) {
664 args.push(self.try_equate_inner(s1, s2)?);
665 }
666 }
667 (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
668 self.try_equate_bv_sizes(*size1, *size2)?;
669 }
670 _ if sort1 == sort2 => {}
671 _ => return None,
672 }
673 Some(sort1.clone())
674 }
675
676 fn try_equate_bv_sizes(
677 &mut self,
678 size1: rty::BvSize,
679 size2: rty::BvSize,
680 ) -> Option<rty::BvSize> {
681 match (size1, size2) {
682 (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
683 self.bv_size_unification_table
684 .unify_var_var(vid1, vid2)
685 .ok()?;
686 }
687 (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
688 self.bv_size_unification_table
689 .unify_var_value(vid, Some(size))
690 .ok()?;
691 }
692 _ if size1 == size2 => {}
693 _ => return None,
694 }
695 Some(size1)
696 }
697
698 fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
699 self.try_equate(sort1, sort2)
700 .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
701 }
702
703 pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
704 rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
705 }
706
707 fn next_num_var(&mut self) -> rty::Sort {
708 rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
709 }
710
711 pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
712 self.sort_unification_table.new_key(None)
713 }
714
715 fn next_num_vid(&mut self) -> rty::NumVid {
716 self.num_unification_table.new_key(None)
717 }
718
719 fn next_bv_size_var(&mut self) -> rty::BvSize {
720 rty::BvSize::Infer(self.next_bv_size_vid())
721 }
722
723 fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
724 self.bv_size_unification_table.new_key(None)
725 }
726
727 fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
728 let sort = self.synth_path(path)?;
729 self.fully_resolve(&sort)
730 .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
731 }
732
733 fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
734 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
735 && let Some(variant) = sort_def.opt_struct_variant()
736 && let [sort] = &variant.field_sorts(sort_args)[..]
737 {
738 Some((sort_def.did(), sort.clone()))
739 } else {
740 None
741 }
742 }
743
744 pub(crate) fn into_results(mut self) -> Result<WfckResults> {
745 for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_literal) {
747 let sort = self
748 .fully_resolve(&sort)
749 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
750 self.wfckresults.node_sorts_mut().insert(fhir_id, sort);
751 }
752
753 for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_bin_op) {
755 let sort = self
756 .fully_resolve(&sort)
757 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
758 self.wfckresults.bin_op_sorts_mut().insert(fhir_id, sort);
759 }
760
761 for (_, (param, sort)) in std::mem::take(&mut self.params) {
763 let sort = self
764 .fully_resolve(&sort)
765 .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(¶m)))?;
766 self.wfckresults
767 .node_sorts_mut()
768 .insert(param.fhir_id, sort);
769 }
770
771 Ok(self.wfckresults)
772 }
773
774 pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
775 fhir::InferMode::from_param_kind(self.params[&id].0.kind)
776 }
777
778 #[track_caller]
779 pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
780 self.params
781 .get(&id)
782 .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
783 .1
784 .clone()
785 }
786
787 fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
788 sort.fold_with(&mut ShallowResolver { infcx: self })
789 }
790
791 fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
792 sort.fold_with(&mut OpportunisticResolver { infcx: self })
793 }
794
795 pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
796 sort.try_fold_with(&mut FullResolver { infcx: self })
797 }
798}
799
800pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
811 infcx: &'a mut InferCtxt<'genv, 'tcx>,
812 errors: Errors<'genv>,
813}
814
815impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
816 pub(crate) fn infer(
817 infcx: &'a mut InferCtxt<'genv, 'tcx>,
818 node: &fhir::OwnerNode<'genv>,
819 ) -> Result {
820 let errors = Errors::new(infcx.genv.sess());
821 let mut vis = Self { infcx, errors };
822 vis.visit_node(node);
823 vis.errors.into_result()
824 }
825
826 fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
827 match idx.kind {
828 fhir::ExprKind::Var(var, Some(_)) => {
829 let (_, id) = var.res.expect_param();
830 let found = self.infcx.param_sort(id);
831 self.infcx.equate(&found, expected);
832 }
833 fhir::ExprKind::Record(flds) => {
834 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
835 let sorts = sort_def.struct_variant().field_sorts(sort_args);
836 if flds.len() != sorts.len() {
837 self.errors.emit(errors::ArgCountMismatch::new(
838 Some(idx.span),
839 String::from("type"),
840 sorts.len(),
841 flds.len(),
842 ));
843 return;
844 }
845 for (f, sort) in iter::zip(flds, &sorts) {
846 self.infer_implicit_params(f, sort);
847 }
848 } else {
849 self.errors.emit(errors::ArgCountMismatch::new(
850 Some(idx.span),
851 String::from("type"),
852 1,
853 flds.len(),
854 ));
855 }
856 }
857 _ => {}
858 }
859 }
860}
861
862impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
863 fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
864 if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
865 let expected = self.infcx.sort_of_bty(bty.fhir_id);
866 self.infer_implicit_params(idx, &expected);
867 }
868 fhir::visit::walk_ty(self, ty);
869 }
870}
871
872impl InferCtxt<'_, '_> {
873 #[track_caller]
874 fn emit_sort_mismatch(
875 &mut self,
876 span: Span,
877 expected: &rty::Sort,
878 found: &rty::Sort,
879 ) -> ErrorGuaranteed {
880 let expected = self.resolve_vars_if_possible(expected);
881 let found = self.resolve_vars_if_possible(found);
882 self.emit_err(errors::SortMismatch::new(span, expected, found))
883 }
884
885 fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
886 self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
887 }
888
889 #[track_caller]
890 fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
891 self.genv.sess().emit_err(err)
892 }
893}
894
895struct ShallowResolver<'a, 'genv, 'tcx> {
896 infcx: &'a mut InferCtxt<'genv, 'tcx>,
897}
898
899impl TypeFolder for ShallowResolver<'_, '_, '_> {
900 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
901 match sort {
902 rty::Sort::Infer(rty::SortVar(vid)) => {
903 self.infcx
909 .sort_unification_table
910 .probe_value(*vid)
911 .map(|sort| sort.fold_with(self))
912 .unwrap_or(sort.clone())
913 }
914 rty::Sort::Infer(rty::NumVar(vid)) => {
915 self.infcx
917 .num_unification_table
918 .probe_value(*vid)
919 .map(|val| val.to_sort().fold_with(self))
920 .unwrap_or(sort.clone())
921 }
922 rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
923 self.infcx
924 .bv_size_unification_table
925 .probe_value(*vid)
926 .map(rty::Sort::BitVec)
927 .unwrap_or(sort.clone())
928 }
929 _ => sort.clone(),
930 }
931 }
932}
933
934struct OpportunisticResolver<'a, 'genv, 'tcx> {
935 infcx: &'a mut InferCtxt<'genv, 'tcx>,
936}
937
938impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
939 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
940 let s = self.infcx.shallow_resolve(sort);
941 s.super_fold_with(self)
942 }
943}
944
945struct FullResolver<'a, 'genv, 'tcx> {
946 infcx: &'a mut InferCtxt<'genv, 'tcx>,
947}
948
949impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
950 type Error = ();
951
952 fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
953 let s = self.infcx.shallow_resolve(sort);
954 match s {
955 rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
956 _ => s.try_super_fold_with(self),
957 }
958 }
959}