1use std::iter;
2
3use ena::unify::InPlaceUnificationTable;
4use flux_common::{bug, iter::IterExt, result::ResultExt, span_bug, tracked_span_bug};
5use flux_errors::{ErrorGuaranteed, Errors};
6use flux_infer::projections::NormalizeExt;
7use flux_middle::{
8 THEORY_FUNCS,
9 fhir::{self, ExprRes, FhirId, FluxOwnerId, SpecFuncKind, visit::Visitor as _},
10 global_env::GlobalEnv,
11 queries::QueryResult,
12 rty::{
13 self, AdtSortDef, FuncSort, List, WfckResults,
14 fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15 },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_middle::ty::TypingMode;
22use rustc_span::{Span, def_id::DefId, symbol::Ident};
23
24use super::errors;
25use crate::rustc_infer::infer::TyCtxtInferExt;
26
27type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
28
/// Sort-inference context for the refinements of a single Flux owner.
///
/// Holds the unification tables used to solve sort, numeric, and bit-vector-size
/// inference variables, plus side tables recording sorts assigned to nodes while
/// checking. Final results are accumulated in `wfckresults`.
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The item whose refinements are being checked.
    pub owner: FluxOwnerId,
    /// Results returned to the caller by `into_results`.
    pub wfckresults: WfckResults,
    /// Unification table for general sort variables (`rty::SortVar`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for numeric variables (`rty::NumVar`), which can be bound to
    /// int, real, or a bit-vector sort.
    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
    /// Unification table for bit-vector size variables.
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Declared refinement parameters with their (possibly unresolved) sorts.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sort of each base type, keyed by the base type's `FhirId`.
    sort_of_bty: FxHashMap<FhirId, rty::Sort>,
    /// Sorts (and spans) of integer literals; fully resolved in `into_results`.
    sort_of_literal: FxHashMap<FhirId, (rty::Sort, Span)>,
    /// Operand sorts (and spans) of comparison/arithmetic operators; fully resolved in
    /// `into_results`.
    sort_of_bin_op: FxHashMap<FhirId, (rty::Sort, Span)>,
    /// Generic arguments recorded for paths.
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sorts of alias-reft applications.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
}
43
44pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
45 match op {
46 fhir::BinOp::BitAnd | fhir::BinOp::BitOr | fhir::BinOp::BitShl | fhir::BinOp::BitShr => {
47 Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int))
48 }
49 _ => None,
50 }
51}
52
53impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
54 pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
55 let mut sort_unification_table = InPlaceUnificationTable::new();
57 sort_unification_table.new_key(None);
58 Self {
59 genv,
60 owner,
61 wfckresults: WfckResults::new(owner),
62 sort_unification_table,
63 num_unification_table: InPlaceUnificationTable::new(),
64 bv_size_unification_table: InPlaceUnificationTable::new(),
65 params: Default::default(),
66 sort_of_bty: Default::default(),
67 sort_of_literal: Default::default(),
68 sort_of_bin_op: Default::default(),
69 path_args: Default::default(),
70 sort_of_alias_reft: Default::default(),
71 }
72 }
73
74 fn check_abs(
75 &mut self,
76 expr: &fhir::Expr,
77 params: &[fhir::RefineParam],
78 body: &fhir::Expr,
79 expected: &rty::Sort,
80 ) -> Result {
81 let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
82 return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
83 };
84
85 let fsort = fsort.expect_mono();
86
87 if params.len() != fsort.inputs().len() {
88 return Err(self.emit_err(errors::ParamCountMismatch::new(
89 expr.span,
90 fsort.inputs().len(),
91 params.len(),
92 )));
93 }
94 iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
95 let found = self.param_sort(param.id);
96 if self.try_equate(&found, expected).is_none() {
97 return Err(self.emit_sort_mismatch(param.span, expected, &found));
98 }
99 Ok(())
100 })?;
101 self.check_expr(body, fsort.output())?;
102 self.wfckresults
103 .node_sorts_mut()
104 .insert(body.fhir_id, fsort.output().clone());
105 Ok(())
106 }
107
108 fn check_field_exprs(
109 &mut self,
110 expr_span: Span,
111 sort_def: &AdtSortDef,
112 sort_args: &[rty::Sort],
113 field_exprs: &[fhir::FieldExpr],
114 spread: &Option<&fhir::Spread>,
115 expected: &rty::Sort,
116 ) -> Result {
117 let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
118 let mut used_fields = FxHashMap::default();
119 for expr in field_exprs {
120 let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
122 return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
123 };
124 if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
125 return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
126 }
127 self.check_expr(&expr.expr, sort)?;
128 }
129 if let Some(spread) = spread {
130 self.check_expr(&spread.expr, expected)
132 } else if sort_by_field_name.len() != used_fields.len() {
133 let missing_fields = sort_by_field_name
135 .into_keys()
136 .filter(|x| !used_fields.contains_key(x))
137 .collect();
138 Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
139 } else {
140 Ok(())
141 }
142 }
143
144 fn check_constructor(
145 &mut self,
146 expr: &fhir::Expr,
147 field_exprs: &[fhir::FieldExpr],
148 spread: &Option<&fhir::Spread>,
149 expected: &rty::Sort,
150 ) -> Result {
151 let expected = self.resolve_vars_if_possible(expected);
152 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
153 self.wfckresults
154 .record_ctors_mut()
155 .insert(expr.fhir_id, sort_def.did());
156 self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
157 Ok(())
158 } else {
159 Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
160 }
161 }
162
163 fn check_record(
164 &mut self,
165 arg: &fhir::Expr,
166 flds: &[fhir::Expr],
167 expected: &rty::Sort,
168 ) -> Result {
169 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
170 let sorts = sort_def.struct_variant().field_sorts(sort_args);
171 if flds.len() != sorts.len() {
172 return Err(self.emit_err(errors::ArgCountMismatch::new(
173 Some(arg.span),
174 String::from("type"),
175 sorts.len(),
176 flds.len(),
177 )));
178 }
179 self.wfckresults
180 .record_ctors_mut()
181 .insert(arg.fhir_id, sort_def.did());
182
183 izip!(flds, &sorts)
184 .map(|(arg, expected)| self.check_expr(arg, expected))
185 .try_collect_exhaust()
186 } else {
187 Err(self.emit_err(errors::ArgCountMismatch::new(
188 Some(arg.span),
189 String::from("type"),
190 1,
191 flds.len(),
192 )))
193 }
194 }
195
196 pub(super) fn check_expr(&mut self, expr: &fhir::Expr, expected: &rty::Sort) -> Result {
197 match &expr.kind {
198 fhir::ExprKind::IfThenElse(p, e1, e2) => {
199 self.check_expr(p, &rty::Sort::Bool)?;
200 self.check_expr(e1, expected)?;
201 self.check_expr(e2, expected)?;
202 }
203 fhir::ExprKind::Abs(params, body) => {
204 self.check_abs(expr, params, body, expected)?;
205 }
206 fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
207 fhir::ExprKind::Constructor(None, exprs, spread) => {
208 self.check_constructor(expr, exprs, spread, expected)?;
209 }
210 fhir::ExprKind::UnaryOp(..)
211 | fhir::ExprKind::BinaryOp(..)
212 | fhir::ExprKind::Dot(..)
213 | fhir::ExprKind::App(..)
214 | fhir::ExprKind::Alias(..)
215 | fhir::ExprKind::Var(..)
216 | fhir::ExprKind::Literal(..)
217 | fhir::ExprKind::BoundedQuant(..)
218 | fhir::ExprKind::Block(..)
219 | fhir::ExprKind::Constructor(..)
220 | fhir::ExprKind::PrimApp(..) => {
221 let found = self.synth_expr(expr)?;
222 let found = self.resolve_vars_if_possible(&found);
223 let expected = self.resolve_vars_if_possible(expected);
224 if !self.is_coercible(&found, &expected, expr.fhir_id) {
225 return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
226 }
227 }
228 fhir::ExprKind::Err(_) => {
229 }
231 }
232 Ok(())
233 }
234
235 pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
236 let found = self.synth_path(loc)?;
237 if found == rty::Sort::Loc {
238 Ok(())
239 } else {
240 Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
241 }
242 }
243
244 fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr) -> rty::Sort {
245 match lit {
246 fhir::Lit::Int(_) => {
247 let sort = self.next_num_var();
248 self.sort_of_literal
249 .insert(expr.fhir_id, (sort.clone(), expr.span));
250 sort
251 }
252 fhir::Lit::Bool(_) => rty::Sort::Bool,
253 fhir::Lit::Real(_) => rty::Sort::Real,
254 fhir::Lit::Str(_) => rty::Sort::Str,
255 fhir::Lit::Char(_) => rty::Sort::Char,
256 }
257 }
258
259 fn synth_prim_app(
260 &mut self,
261 op: &fhir::BinOp,
262 e1: &fhir::Expr,
263 e2: &fhir::Expr,
264 span: Span,
265 ) -> Result<rty::Sort> {
266 let Some((inputs, output)) = prim_op_sort(op) else {
267 return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
268 };
269 let [sort1, sort2] = &inputs[..] else {
270 return Err(self.emit_err(errors::ArgCountMismatch::new(
271 Some(span),
272 String::from("primop app"),
273 inputs.len(),
274 2,
275 )));
276 };
277 self.check_expr(e1, sort1)?;
278 self.check_expr(e2, sort2)?;
279 Ok(output)
280 }
281
    /// Synthesizes (infers bottom-up) the sort of `expr`.
    ///
    /// Side tables are updated as nodes are visited; for example, `Dot` records its
    /// field projection in `wfckresults`. Forms without a case here (lambdas, records,
    /// unqualified constructors, error nodes) produce a `CannotInferSort` error; those
    /// are handled by `check_expr` against a known sort instead.
    fn synth_expr(&mut self, expr: &fhir::Expr) -> Result<rty::Sort> {
        match expr.kind {
            fhir::ExprKind::Var(var, _) => self.synth_path(&var),
            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
            fhir::ExprKind::App(callee, args) => {
                // The callee's sort must be fully known before it can be applied.
                let sort = self.ensure_resolved_path(&callee)?;
                // Accept a function sort, or a single-field struct wrapping one (which
                // records a projection coercion on the callee).
                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
                };
                // Replace the function's sort parameters with fresh inference variables.
                let fsort = self.instantiate_func_sort(poly_fsort);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::BoundedQuant(.., body) => {
                self.check_expr(body, &rty::Sort::Bool)?;
                Ok(rty::Sort::Bool)
            }
            fhir::ExprKind::Alias(_alias_reft, args) => {
                // The alias's function sort was registered earlier through
                // `insert_sort_for_alias_reft`.
                let fsort = self.sort_of_alias_reft(expr.fhir_id);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::IfThenElse(p, e1, e2) => {
                self.check_expr(p, &rty::Sort::Bool)?;
                // The `then` branch determines the sort; the `else` branch must agree.
                let sort = self.synth_expr(e1)?;
                self.check_expr(e2, &sort)?;
                Ok(sort)
            }
            fhir::ExprKind::Dot(base, fld) => {
                let sort = self.synth_expr(base)?;
                // Field lookup needs a concrete sort for the base expression.
                let sort = self
                    .fully_resolve(&sort)
                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
                match &sort {
                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
                        // Inside the closure, `sort` is still the base sort; the
                        // field sort binding only comes into scope after this statement.
                        let (proj, sort) = sort_def
                            .struct_variant()
                            .field_by_name(sort_def.did(), sort_args, fld.name)
                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
                        // Remember which field projection this `.fld` resolved to.
                        self.wfckresults
                            .field_projs_mut()
                            .insert(expr.fhir_id, proj);
                        Ok(sort)
                    }
                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
                    }
                    _ => Err(self.emit_field_not_found(&sort, fld)),
                }
            }
            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
                // A qualified constructor `Path { .. }`: the path must resolve to a
                // constructor.
                let path_def_id = match path.res {
                    ExprRes::Ctor(def_id) => def_id,
                    _ => span_bug!(expr.span, "unexpected path in constructor"),
                };
                let sort_def = self
                    .genv
                    .adt_sort_def_of(path_def_id)
                    .map_err(|e| self.emit_err(e))?;
                // The ADT's sort arguments start as fresh variables and are inferred
                // from the field expressions.
                let fresh_args: rty::List<_> = (0..sort_def.param_count())
                    .map(|_| self.next_sort_var())
                    .collect();
                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
                self.check_field_exprs(
                    expr.span,
                    &sort_def,
                    &fresh_args,
                    field_exprs,
                    &spread,
                    &sort,
                )?;
                Ok(sort)
            }
            fhir::ExprKind::Block(decls, body) => {
                // Each `let` initializer is checked against its bound parameter's sort.
                for decl in decls {
                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
                }
                self.synth_expr(body)
            }
            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
        }
    }
372
373 fn synth_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
374 match path.res {
375 ExprRes::Param(_, id) => Ok(self.param_sort(id)),
376 ExprRes::Const(def_id) => {
377 if let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? {
378 let info = self.genv.constant_info(def_id).emit(&self.genv)?;
379 if sort != rty::Sort::Int && matches!(info, rty::ConstantInfo::Uninterpreted) {
381 Err(self.emit_err(errors::ConstantAnnotationNeeded::new(path.span)))?;
382 }
383 Ok(sort)
384 } else {
385 span_bug!(path.span, "unexpected const")
386 }
387 }
388 ExprRes::Variant(def_id) => {
389 let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? else {
390 span_bug!(path.span, "unexpected variant {def_id:?}")
391 };
392 Ok(sort)
393 }
394 ExprRes::ConstGeneric(_) => Ok(rty::Sort::Int), ExprRes::NumConst(_) => Ok(rty::Sort::Int),
396 ExprRes::GlobalFunc(SpecFuncKind::Def(name) | SpecFuncKind::Uif(name)) => {
397 Ok(rty::Sort::Func(self.genv.func_sort(name)))
398 }
399 ExprRes::GlobalFunc(SpecFuncKind::Thy(itf)) => {
400 Ok(rty::Sort::Func(THEORY_FUNCS.get(&itf).unwrap().sort.clone()))
401 }
402 ExprRes::GlobalFunc(SpecFuncKind::CharToInt) => {
403 Ok(rty::Sort::Func(rty::PolyFuncSort::new(
404 List::empty(),
405 rty::FuncSort::new(vec![rty::Sort::Char], rty::Sort::Int),
406 )))
407 }
408 ExprRes::Ctor(_) => {
409 span_bug!(path.span, "unexpected constructor in var position")
410 }
411 }
412 }
413
414 fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
415 if matches!(op, fhir::BinOp::Mod) {
416 let sort = self
417 .fully_resolve(sort)
418 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
419 if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
420 span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
421 }
422 }
423 Ok(())
424 }
425
426 fn synth_binary_op(
427 &mut self,
428 expr: &fhir::Expr,
429 op: fhir::BinOp,
430 e1: &fhir::Expr,
431 e2: &fhir::Expr,
432 ) -> Result<rty::Sort> {
433 match op {
434 fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
435 self.check_expr(e1, &rty::Sort::Bool)?;
436 self.check_expr(e2, &rty::Sort::Bool)?;
437 Ok(rty::Sort::Bool)
438 }
439 fhir::BinOp::Eq | fhir::BinOp::Ne => {
440 let sort = self.next_sort_var();
441 self.check_expr(e1, &sort)?;
442 self.check_expr(e2, &sort)?;
443 Ok(rty::Sort::Bool)
444 }
445 fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
446 let sort = self.next_sort_var();
447 self.check_expr(e1, &sort)?;
448 self.check_expr(e2, &sort)?;
449 self.sort_of_bin_op
450 .insert(expr.fhir_id, (sort.clone(), expr.span));
451 Ok(rty::Sort::Bool)
452 }
453 fhir::BinOp::Add
454 | fhir::BinOp::Sub
455 | fhir::BinOp::Mul
456 | fhir::BinOp::Div
457 | fhir::BinOp::Mod => {
458 let sort = self.next_num_var();
459 self.check_expr(e1, &sort)?;
460 self.check_expr(e2, &sort)?;
461 self.sort_of_bin_op
462 .insert(expr.fhir_id, (sort.clone(), expr.span));
463 self.check_integral(op, &sort, expr.span)?;
465
466 Ok(sort)
467 }
468 fhir::BinOp::BitAnd
469 | fhir::BinOp::BitOr
470 | fhir::BinOp::BitShl
471 | fhir::BinOp::BitShr => {
472 let sort = rty::Sort::BitVec(self.next_bv_size_var());
473 self.check_expr(e1, &sort)?;
474 self.check_expr(e2, &sort)?;
475 Ok(sort)
476 }
477 }
478 }
479
480 fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr) -> Result<rty::Sort> {
481 match op {
482 fhir::UnOp::Not => {
483 self.check_expr(e, &rty::Sort::Bool)?;
484 Ok(rty::Sort::Bool)
485 }
486 fhir::UnOp::Neg => {
487 self.check_expr(e, &rty::Sort::Int)?;
488 Ok(rty::Sort::Int)
489 }
490 }
491 }
492
493 fn synth_app(&mut self, fsort: FuncSort, args: &[fhir::Expr], span: Span) -> Result<rty::Sort> {
494 if args.len() != fsort.inputs().len() {
495 return Err(self.emit_err(errors::ArgCountMismatch::new(
496 Some(span),
497 String::from("function"),
498 fsort.inputs().len(),
499 args.len(),
500 )));
501 }
502
503 iter::zip(args, fsort.inputs())
504 .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
505
506 Ok(fsort.output().clone())
507 }
508
509 fn instantiate_func_sort(&mut self, fsort: rty::PolyFuncSort) -> rty::FuncSort {
510 let args = fsort
511 .params()
512 .map(|kind| {
513 match kind {
514 rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
515 rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
516 }
517 })
518 .collect_vec();
519 fsort.instantiate(&args)
520 }
521
522 pub(crate) fn insert_sort_for_bty(&mut self, fhir_id: FhirId, sort: rty::Sort) {
523 self.sort_of_bty.insert(fhir_id, sort);
524 }
525
526 pub(crate) fn sort_of_bty(&self, fhir_id: FhirId) -> rty::Sort {
527 self.sort_of_bty
528 .get(&fhir_id)
529 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
530 .clone()
531 }
532
533 pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
534 self.path_args.insert(fhir_id, args);
535 }
536
537 pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
538 self.path_args
539 .get(&fhir_id)
540 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
541 .clone()
542 }
543
544 pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
545 self.sort_of_alias_reft.insert(fhir_id, fsort);
546 }
547
548 fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
549 self.sort_of_alias_reft
550 .get(&fhir_id)
551 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
552 .clone()
553 }
554
555 fn normalize_projection_sort(
556 genv: GlobalEnv,
557 owner: FluxOwnerId,
558 sort: rty::Sort,
559 ) -> rty::Sort {
560 let infcx = genv
561 .tcx()
562 .infer_ctxt()
563 .with_next_trait_solver(true)
564 .build(TypingMode::non_body_analysis());
565 if let Some(def_id) = owner.def_id()
566 && let def_id = genv.maybe_extern_id(def_id).resolved_id()
567 && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
568 {
569 sort
570 } else {
571 sort
572 }
573 }
574
575 pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
580 let genv = self.genv;
581 for sort in self.sort_of_bty.values_mut() {
582 *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
583 }
584 for fsort in self.sort_of_alias_reft.values_mut() {
585 *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
586 }
587 Ok(())
588 }
589}
590
591impl<'genv> InferCtxt<'genv, '_> {
592 pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
593 self.params.insert(param.id, (param, sort));
594 }
595
596 fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
601 if self.try_equate(sort1, sort2).is_some() {
602 return true;
603 }
604
605 let mut sort1 = sort1.clone();
606 let mut sort2 = sort2.clone();
607
608 let mut coercions = vec![];
609 if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
610 coercions.push(rty::Coercion::Project(def_id));
611 sort1 = sort.clone();
612 }
613 if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
614 coercions.push(rty::Coercion::Inject(def_id));
615 sort2 = sort.clone();
616 }
617 self.wfckresults.coercions_mut().insert(fhir_id, coercions);
618 self.try_equate(&sort1, &sort2).is_some()
619 }
620
621 fn is_coercible_from_func(
622 &mut self,
623 sort: &rty::Sort,
624 fhir_id: FhirId,
625 ) -> Option<rty::PolyFuncSort> {
626 if let rty::Sort::Func(fsort) = sort {
627 Some(fsort.clone())
628 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
629 self.wfckresults
630 .coercions_mut()
631 .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
632 Some(fsort.clone())
633 } else {
634 None
635 }
636 }
637
638 fn is_coercible_to_func(
639 &mut self,
640 sort: &rty::Sort,
641 fhir_id: FhirId,
642 ) -> Option<rty::PolyFuncSort> {
643 if let rty::Sort::Func(fsort) = sort {
644 Some(fsort.clone())
645 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
646 self.wfckresults
647 .coercions_mut()
648 .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
649 Some(fsort.clone())
650 } else {
651 None
652 }
653 }
654
655 fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
656 let sort1 = self.resolve_vars_if_possible(sort1);
657 let sort2 = self.resolve_vars_if_possible(sort2);
658 self.try_equate_inner(&sort1, &sort2)
659 }
660
661 fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
662 match (sort1, sort2) {
663 (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
664 self.sort_unification_table
665 .unify_var_var(*vid1, *vid2)
666 .ok()?;
667 }
668 (rty::Sort::Infer(rty::SortVar(vid)), sort)
669 | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
670 self.sort_unification_table
671 .unify_var_value(*vid, Some(sort.clone()))
672 .ok()?;
673 }
674 (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
675 self.num_unification_table
676 .unify_var_var(*vid1, *vid2)
677 .ok()?;
678 }
679 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
680 | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
681 self.num_unification_table
682 .unify_var_value(*vid, Some(rty::NumVarValue::Int))
683 .ok()?;
684 }
685 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
686 | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
687 self.num_unification_table
688 .unify_var_value(*vid, Some(rty::NumVarValue::Real))
689 .ok()?;
690 }
691 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
692 | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
693 self.num_unification_table
694 .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
695 .ok()?;
696 }
697
698 (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
699 if ctor1 != ctor2 || args1.len() != args2.len() {
700 return None;
701 }
702 let mut args = vec![];
703 for (s1, s2) in args1.iter().zip(args2.iter()) {
704 args.push(self.try_equate_inner(s1, s2)?);
705 }
706 }
707 (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
708 self.try_equate_bv_sizes(*size1, *size2)?;
709 }
710 _ if sort1 == sort2 => {}
711 _ => return None,
712 }
713 Some(sort1.clone())
714 }
715
716 fn try_equate_bv_sizes(
717 &mut self,
718 size1: rty::BvSize,
719 size2: rty::BvSize,
720 ) -> Option<rty::BvSize> {
721 match (size1, size2) {
722 (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
723 self.bv_size_unification_table
724 .unify_var_var(vid1, vid2)
725 .ok()?;
726 }
727 (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
728 self.bv_size_unification_table
729 .unify_var_value(vid, Some(size))
730 .ok()?;
731 }
732 _ if size1 == size2 => {}
733 _ => return None,
734 }
735 Some(size1)
736 }
737
738 fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
739 self.try_equate(sort1, sort2)
740 .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
741 }
742
743 pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
744 rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
745 }
746
747 fn next_num_var(&mut self) -> rty::Sort {
748 rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
749 }
750
751 pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
752 self.sort_unification_table.new_key(None)
753 }
754
755 fn next_num_vid(&mut self) -> rty::NumVid {
756 self.num_unification_table.new_key(None)
757 }
758
759 fn next_bv_size_var(&mut self) -> rty::BvSize {
760 rty::BvSize::Infer(self.next_bv_size_vid())
761 }
762
763 fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
764 self.bv_size_unification_table.new_key(None)
765 }
766
767 fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
768 let sort = self.synth_path(path)?;
769 self.fully_resolve(&sort)
770 .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
771 }
772
773 fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
774 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
775 && let Some(variant) = sort_def.opt_struct_variant()
776 && let [sort] = &variant.field_sorts(sort_args)[..]
777 {
778 Some((sort_def.did(), sort.clone()))
779 } else {
780 None
781 }
782 }
783
784 pub(crate) fn into_results(mut self) -> Result<WfckResults> {
785 for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_literal) {
787 let sort = self
788 .fully_resolve(&sort)
789 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
790 self.wfckresults.node_sorts_mut().insert(fhir_id, sort);
791 }
792
793 for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_bin_op) {
795 let sort = self
796 .fully_resolve(&sort)
797 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
798 self.wfckresults.bin_op_sorts_mut().insert(fhir_id, sort);
799 }
800
801 for (_, (param, sort)) in std::mem::take(&mut self.params) {
803 let sort = self
804 .fully_resolve(&sort)
805 .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(¶m)))?;
806 self.wfckresults
807 .node_sorts_mut()
808 .insert(param.fhir_id, sort);
809 }
810
811 Ok(self.wfckresults)
812 }
813
814 pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
815 fhir::InferMode::from_param_kind(self.params[&id].0.kind)
816 }
817
818 #[track_caller]
819 pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
820 self.params
821 .get(&id)
822 .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
823 .1
824 .clone()
825 }
826
827 fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
828 sort.fold_with(&mut ShallowResolver { infcx: self })
829 }
830
831 fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
832 sort.fold_with(&mut OpportunisticResolver { infcx: self })
833 }
834
835 pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
836 sort.try_fold_with(&mut FullResolver { infcx: self })
837 }
838}
839
/// Visitor that walks an owner node and, for each indexed type `B[idx]`, unifies the
/// sorts of parameters bound in `idx` with the sort recorded for `B`.
///
/// NOTE(review): the parameters handled are those appearing as `Var(_, Some(_))`;
/// presumably these are implicitly declared parameters — confirm against the fhir
/// definition.
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    /// Inference context whose parameter sorts are unified.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    /// Diagnostics emitted during the walk; converted into a result by `infer`.
    errors: Errors<'genv>,
}
854
855impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
856 pub(crate) fn infer(
857 infcx: &'a mut InferCtxt<'genv, 'tcx>,
858 node: &fhir::OwnerNode<'genv>,
859 ) -> Result {
860 let errors = Errors::new(infcx.genv.sess());
861 let mut vis = Self { infcx, errors };
862 vis.visit_node(node);
863 vis.errors.into_result()
864 }
865
866 fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
867 match idx.kind {
868 fhir::ExprKind::Var(var, Some(_)) => {
869 let (_, id) = var.res.expect_param();
870 let found = self.infcx.param_sort(id);
871 self.infcx.equate(&found, expected);
872 }
873 fhir::ExprKind::Record(flds) => {
874 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
875 let sorts = sort_def.struct_variant().field_sorts(sort_args);
876 if flds.len() != sorts.len() {
877 self.errors.emit(errors::ArgCountMismatch::new(
878 Some(idx.span),
879 String::from("type"),
880 sorts.len(),
881 flds.len(),
882 ));
883 return;
884 }
885 for (f, sort) in iter::zip(flds, &sorts) {
886 self.infer_implicit_params(f, sort);
887 }
888 } else {
889 self.errors.emit(errors::ArgCountMismatch::new(
890 Some(idx.span),
891 String::from("type"),
892 1,
893 flds.len(),
894 ));
895 }
896 }
897 _ => {}
898 }
899 }
900}
901
902impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
903 fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
904 if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
905 let expected = self.infcx.sort_of_bty(bty.fhir_id);
906 self.infer_implicit_params(idx, &expected);
907 }
908 fhir::visit::walk_ty(self, ty);
909 }
910}
911
912impl InferCtxt<'_, '_> {
913 #[track_caller]
914 fn emit_sort_mismatch(
915 &mut self,
916 span: Span,
917 expected: &rty::Sort,
918 found: &rty::Sort,
919 ) -> ErrorGuaranteed {
920 let expected = self.resolve_vars_if_possible(expected);
921 let found = self.resolve_vars_if_possible(found);
922 self.emit_err(errors::SortMismatch::new(span, expected, found))
923 }
924
925 fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
926 self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
927 }
928
929 #[track_caller]
930 fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
931 self.genv.sess().emit_err(err)
932 }
933}
934
935struct ShallowResolver<'a, 'genv, 'tcx> {
936 infcx: &'a mut InferCtxt<'genv, 'tcx>,
937}
938
939impl TypeFolder for ShallowResolver<'_, '_, '_> {
940 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
941 match sort {
942 rty::Sort::Infer(rty::SortVar(vid)) => {
943 self.infcx
949 .sort_unification_table
950 .probe_value(*vid)
951 .map(|sort| sort.fold_with(self))
952 .unwrap_or(sort.clone())
953 }
954 rty::Sort::Infer(rty::NumVar(vid)) => {
955 self.infcx
957 .num_unification_table
958 .probe_value(*vid)
959 .map(|val| val.to_sort().fold_with(self))
960 .unwrap_or(sort.clone())
961 }
962 rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
963 self.infcx
964 .bv_size_unification_table
965 .probe_value(*vid)
966 .map(rty::Sort::BitVec)
967 .unwrap_or(sort.clone())
968 }
969 _ => sort.clone(),
970 }
971 }
972}
973
974struct OpportunisticResolver<'a, 'genv, 'tcx> {
975 infcx: &'a mut InferCtxt<'genv, 'tcx>,
976}
977
978impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
979 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
980 let s = self.infcx.shallow_resolve(sort);
981 s.super_fold_with(self)
982 }
983}
984
985struct FullResolver<'a, 'genv, 'tcx> {
986 infcx: &'a mut InferCtxt<'genv, 'tcx>,
987}
988
989impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
990 type Error = ();
991
992 fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
993 let s = self.infcx.shallow_resolve(sort);
994 match s {
995 rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
996 _ => s.try_super_fold_with(self),
997 }
998 }
999}