1use std::iter;
2
3use ena::unify::InPlaceUnificationTable;
4use flux_common::{bug, iter::IterExt, result::ResultExt, span_bug, tracked_span_bug};
5use flux_errors::{ErrorGuaranteed, Errors};
6use flux_infer::projections::NormalizeExt;
7use flux_middle::{
8 THEORY_FUNCS,
9 fhir::{self, ExprRes, FhirId, FluxOwnerId, visit::Visitor as _},
10 global_env::GlobalEnv,
11 queries::QueryResult,
12 rty::{
13 self, AdtSortDef, FuncSort, List, WfckResults,
14 fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15 },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::{FxHashMap, FxHashSet};
21use rustc_middle::ty::TypingMode;
22use rustc_span::{Span, def_id::DefId, symbol::Ident};
23
24use super::errors;
25use crate::rustc_infer::infer::TyCtxtInferExt;
26
/// Result type used throughout sort checking. The `Err` side carries the `ErrorGuaranteed`
/// token returned when a diagnostic is emitted (see `emit_err`).
type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
28
/// Inference context used while sort checking a single Flux owner.
///
/// It holds unification tables for the three kinds of inference variables (general sorts,
/// numeric sorts and bit-vector sizes) and several side tables. `into_results` fully resolves
/// those tables into a [`WfckResults`].
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The owner whose refinements are being checked.
    pub owner: FluxOwnerId,
    /// Results accumulated so far (node sorts, coercions, field projections, ...).
    pub wfckresults: WfckResults,
    /// Unification table for general sort variables (`rty::SortVar`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for numeric variables (`rty::NumVar`). These can only be bound to
    /// int, real or bit-vector sorts.
    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
    /// Unification table for bit-vector size variables.
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Declared refinement parameters with their (possibly still unresolved) sorts.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sort of each base type, keyed by the base type's `FhirId`.
    sort_of_bty: FxHashMap<FhirId, rty::Sort>,
    /// Numeric variable created for each integer literal, plus its span for error reporting.
    sort_of_literal: FxHashMap<FhirId, (rty::Sort, Span)>,
    /// Operand sort of each ordering/arithmetic binary operator, plus its span.
    sort_of_bin_op: FxHashMap<FhirId, (rty::Sort, Span)>,
    /// Sort arguments used to instantiate a polymorphic function sort at each application.
    sort_args_of_app: FxHashMap<FhirId, (List<rty::SortArg>, Span)>,
    /// Applications whose callee is the built-in `cast`.
    casts: FxHashSet<FhirId>,
    /// Generic arguments recorded for paths.
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sort of each alias-reft expression.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
}
45
46pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
47 match op {
48 fhir::BinOp::BitAnd | fhir::BinOp::BitOr | fhir::BinOp::BitShl | fhir::BinOp::BitShr => {
49 Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int))
50 }
51 _ => None,
52 }
53}
54
55impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
56 pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
57 let mut sort_unification_table = InPlaceUnificationTable::new();
59 sort_unification_table.new_key(None);
60 Self {
61 genv,
62 owner,
63 wfckresults: WfckResults::new(owner),
64 sort_unification_table,
65 num_unification_table: InPlaceUnificationTable::new(),
66 bv_size_unification_table: InPlaceUnificationTable::new(),
67 params: Default::default(),
68 sort_of_bty: Default::default(),
69 sort_of_literal: Default::default(),
70 sort_of_bin_op: Default::default(),
71 path_args: Default::default(),
72 sort_of_alias_reft: Default::default(),
73 sort_args_of_app: Default::default(),
74 casts: Default::default(),
75 }
76 }
77
78 fn check_abs(
79 &mut self,
80 expr: &fhir::Expr,
81 params: &[fhir::RefineParam],
82 body: &fhir::Expr,
83 expected: &rty::Sort,
84 ) -> Result {
85 let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
86 return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
87 };
88
89 let fsort = fsort.expect_mono();
90
91 if params.len() != fsort.inputs().len() {
92 return Err(self.emit_err(errors::ParamCountMismatch::new(
93 expr.span,
94 fsort.inputs().len(),
95 params.len(),
96 )));
97 }
98 iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
99 let found = self.param_sort(param.id);
100 if self.try_equate(&found, expected).is_none() {
101 return Err(self.emit_sort_mismatch(param.span, expected, &found));
102 }
103 Ok(())
104 })?;
105 self.check_expr(body, fsort.output())?;
106 self.wfckresults
107 .node_sorts_mut()
108 .insert(body.fhir_id, fsort.output().clone());
109 Ok(())
110 }
111
112 fn check_field_exprs(
113 &mut self,
114 expr_span: Span,
115 sort_def: &AdtSortDef,
116 sort_args: &[rty::Sort],
117 field_exprs: &[fhir::FieldExpr],
118 spread: &Option<&fhir::Spread>,
119 expected: &rty::Sort,
120 ) -> Result {
121 let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
122 let mut used_fields = FxHashMap::default();
123 for expr in field_exprs {
124 let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
126 return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
127 };
128 if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
129 return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
130 }
131 self.check_expr(&expr.expr, sort)?;
132 }
133 if let Some(spread) = spread {
134 self.check_expr(&spread.expr, expected)
136 } else if sort_by_field_name.len() != used_fields.len() {
137 let missing_fields = sort_by_field_name
139 .into_keys()
140 .filter(|x| !used_fields.contains_key(x))
141 .collect();
142 Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
143 } else {
144 Ok(())
145 }
146 }
147
148 fn check_constructor(
149 &mut self,
150 expr: &fhir::Expr,
151 field_exprs: &[fhir::FieldExpr],
152 spread: &Option<&fhir::Spread>,
153 expected: &rty::Sort,
154 ) -> Result {
155 let expected = self.resolve_vars_if_possible(expected);
156 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
157 self.wfckresults
158 .record_ctors_mut()
159 .insert(expr.fhir_id, sort_def.did());
160 self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
161 Ok(())
162 } else {
163 Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
164 }
165 }
166
167 fn check_record(
168 &mut self,
169 arg: &fhir::Expr,
170 flds: &[fhir::Expr],
171 expected: &rty::Sort,
172 ) -> Result {
173 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
174 let sorts = sort_def.struct_variant().field_sorts(sort_args);
175 if flds.len() != sorts.len() {
176 return Err(self.emit_err(errors::ArgCountMismatch::new(
177 Some(arg.span),
178 String::from("type"),
179 sorts.len(),
180 flds.len(),
181 )));
182 }
183 self.wfckresults
184 .record_ctors_mut()
185 .insert(arg.fhir_id, sort_def.did());
186
187 izip!(flds, &sorts)
188 .map(|(arg, expected)| self.check_expr(arg, expected))
189 .try_collect_exhaust()
190 } else {
191 Err(self.emit_err(errors::ArgCountMismatch::new(
192 Some(arg.span),
193 String::from("type"),
194 1,
195 flds.len(),
196 )))
197 }
198 }
199
200 pub(super) fn check_expr(&mut self, expr: &fhir::Expr, expected: &rty::Sort) -> Result {
201 match &expr.kind {
202 fhir::ExprKind::IfThenElse(p, e1, e2) => {
203 self.check_expr(p, &rty::Sort::Bool)?;
204 self.check_expr(e1, expected)?;
205 self.check_expr(e2, expected)?;
206 }
207 fhir::ExprKind::Abs(params, body) => {
208 self.check_abs(expr, params, body, expected)?;
209 }
210 fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
211 fhir::ExprKind::Constructor(None, exprs, spread) => {
212 self.check_constructor(expr, exprs, spread, expected)?;
213 }
214 fhir::ExprKind::UnaryOp(..)
215 | fhir::ExprKind::BinaryOp(..)
216 | fhir::ExprKind::Dot(..)
217 | fhir::ExprKind::App(..)
218 | fhir::ExprKind::Alias(..)
219 | fhir::ExprKind::Var(..)
220 | fhir::ExprKind::Literal(..)
221 | fhir::ExprKind::BoundedQuant(..)
222 | fhir::ExprKind::Block(..)
223 | fhir::ExprKind::Constructor(..)
224 | fhir::ExprKind::PrimApp(..) => {
225 let found = self.synth_expr(expr)?;
226 let found = self.resolve_vars_if_possible(&found);
227 let expected = self.resolve_vars_if_possible(expected);
228 if !self.is_coercible(&found, &expected, expr.fhir_id) {
229 return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
230 }
231 }
232 fhir::ExprKind::Err(_) => {
233 }
235 }
236 Ok(())
237 }
238
239 pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
240 let found = self.synth_path(loc)?;
241 if found == rty::Sort::Loc {
242 Ok(())
243 } else {
244 Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
245 }
246 }
247
248 fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr) -> rty::Sort {
249 match lit {
250 fhir::Lit::Int(_) => {
251 let sort = self.next_num_var();
252 self.sort_of_literal
253 .insert(expr.fhir_id, (sort.clone(), expr.span));
254 sort
255 }
256 fhir::Lit::Bool(_) => rty::Sort::Bool,
257 fhir::Lit::Real(_) => rty::Sort::Real,
258 fhir::Lit::Str(_) => rty::Sort::Str,
259 fhir::Lit::Char(_) => rty::Sort::Char,
260 }
261 }
262
263 fn synth_prim_app(
264 &mut self,
265 op: &fhir::BinOp,
266 e1: &fhir::Expr,
267 e2: &fhir::Expr,
268 span: Span,
269 ) -> Result<rty::Sort> {
270 let Some((inputs, output)) = prim_op_sort(op) else {
271 return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
272 };
273 let [sort1, sort2] = &inputs[..] else {
274 return Err(self.emit_err(errors::ArgCountMismatch::new(
275 Some(span),
276 String::from("primop app"),
277 inputs.len(),
278 2,
279 )));
280 };
281 self.check_expr(e1, sort1)?;
282 self.check_expr(e2, sort2)?;
283 Ok(output)
284 }
285
    /// Synthesizes (infers) the sort of `expr`.
    ///
    /// Some forms can only be checked against a known sort: lambdas, records, anonymous
    /// constructors and error nodes. They fall through to the last arm and report
    /// `CannotInferSort`.
    fn synth_expr(&mut self, expr: &fhir::Expr) -> Result<rty::Sort> {
        match expr.kind {
            fhir::ExprKind::Var(var, _) => self.synth_path(&var),
            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
            fhir::ExprKind::App(callee, args) => {
                // The callee's sort must be fully known (no inference variables) before it can
                // be applied.
                let sort = self.ensure_resolved_path(&callee)?;
                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
                };
                let fhir_id = expr.fhir_id;
                let span = expr.span;
                // Instantiate the sort parameters with fresh variables. The chosen arguments
                // are recorded for this application.
                let fsort = self.instantiate_func_sort(poly_fsort, fhir_id, span);
                // Remember `cast` applications so `into_results` can validate their
                // instantiation.
                if matches!(callee.res, ExprRes::GlobalFunc(fhir::SpecFuncKind::Cast)) {
                    self.casts.insert(fhir_id);
                }
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::BoundedQuant(.., body) => {
                self.check_expr(body, &rty::Sort::Bool)?;
                Ok(rty::Sort::Bool)
            }
            fhir::ExprKind::Alias(_alias_reft, args) => {
                // The function sort is looked up from the table filled by
                // `insert_sort_for_alias_reft`.
                let fsort = self.sort_of_alias_reft(expr.fhir_id);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::IfThenElse(p, e1, e2) => {
                self.check_expr(p, &rty::Sort::Bool)?;
                let sort = self.synth_expr(e1)?;
                self.check_expr(e2, &sort)?;
                Ok(sort)
            }
            fhir::ExprKind::Dot(base, fld) => {
                let sort = self.synth_expr(base)?;
                // Field lookup needs the concrete sort of the base expression.
                let sort = self
                    .fully_resolve(&sort)
                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
                match &sort {
                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
                        let (proj, sort) = sort_def
                            .struct_variant()
                            .field_by_name(sort_def.did(), sort_args, fld.name)
                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
                        self.wfckresults
                            .field_projs_mut()
                            .insert(expr.fhir_id, proj);
                        Ok(sort)
                    }
                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
                    }
                    _ => Err(self.emit_field_not_found(&sort, fld)),
                }
            }
            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
                let path_def_id = match path.res {
                    ExprRes::Ctor(def_id) => def_id,
                    _ => span_bug!(expr.span, "unexpected path in constructor"),
                };
                let sort_def = self
                    .genv
                    .adt_sort_def_of(path_def_id)
                    .map_err(|e| self.emit_err(e))?;
                // One fresh sort variable per sort parameter of the ADT.
                let fresh_args: rty::List<_> = (0..sort_def.param_count())
                    .map(|_| self.next_sort_var())
                    .collect();
                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
                self.check_field_exprs(
                    expr.span,
                    &sort_def,
                    &fresh_args,
                    field_exprs,
                    &spread,
                    &sort,
                )?;
                Ok(sort)
            }
            fhir::ExprKind::Block(decls, body) => {
                // Each `let` initializer is checked against the sort of the parameter it binds.
                for decl in decls {
                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
                }
                self.synth_expr(body)
            }
            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
        }
    }
381
382 fn synth_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
383 match path.res {
384 ExprRes::Param(_, id) => Ok(self.param_sort(id)),
385 ExprRes::Const(def_id) => {
386 if let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? {
387 let info = self.genv.constant_info(def_id).emit(&self.genv)?;
388 if sort != rty::Sort::Int && matches!(info, rty::ConstantInfo::Uninterpreted) {
390 Err(self.emit_err(errors::ConstantAnnotationNeeded::new(path.span)))?;
391 }
392 Ok(sort)
393 } else {
394 span_bug!(path.span, "unexpected const")
395 }
396 }
397 ExprRes::Variant(def_id) => {
398 let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? else {
399 span_bug!(path.span, "unexpected variant {def_id:?}")
400 };
401 Ok(sort)
402 }
403 ExprRes::ConstGeneric(_) => Ok(rty::Sort::Int), ExprRes::NumConst(_) => Ok(rty::Sort::Int),
405 ExprRes::GlobalFunc(spec_func) => {
406 let fsort = match spec_func {
407 fhir::SpecFuncKind::Def(name) | fhir::SpecFuncKind::Uif(name) => {
408 self.genv.func_sort(name)
409 }
410 fhir::SpecFuncKind::Thy(itf) => THEORY_FUNCS.get(&itf).unwrap().sort.clone(),
411 fhir::SpecFuncKind::Cast => {
412 rty::PolyFuncSort::new(
413 List::from_arr([rty::SortParamKind::Sort, rty::SortParamKind::Sort]),
414 rty::FuncSort::new(
415 vec![rty::Sort::Var(rty::ParamSort::from(0_usize))],
416 rty::Sort::Var(rty::ParamSort::from(1_usize)),
417 ),
418 )
419 }
420 };
421 Ok(rty::Sort::Func(fsort))
422 }
423 ExprRes::Ctor(_) => {
424 span_bug!(path.span, "unexpected constructor in var position")
425 }
426 }
427 }
428
429 fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
430 if matches!(op, fhir::BinOp::Mod) {
431 let sort = self
432 .fully_resolve(sort)
433 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
434 if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
435 span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
436 }
437 }
438 Ok(())
439 }
440
441 fn synth_binary_op(
442 &mut self,
443 expr: &fhir::Expr,
444 op: fhir::BinOp,
445 e1: &fhir::Expr,
446 e2: &fhir::Expr,
447 ) -> Result<rty::Sort> {
448 match op {
449 fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
450 self.check_expr(e1, &rty::Sort::Bool)?;
451 self.check_expr(e2, &rty::Sort::Bool)?;
452 Ok(rty::Sort::Bool)
453 }
454 fhir::BinOp::Eq | fhir::BinOp::Ne => {
455 let sort = self.next_sort_var();
456 self.check_expr(e1, &sort)?;
457 self.check_expr(e2, &sort)?;
458 Ok(rty::Sort::Bool)
459 }
460 fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
461 let sort = self.next_sort_var();
462 self.check_expr(e1, &sort)?;
463 self.check_expr(e2, &sort)?;
464 self.sort_of_bin_op
465 .insert(expr.fhir_id, (sort.clone(), expr.span));
466 Ok(rty::Sort::Bool)
467 }
468 fhir::BinOp::Add
469 | fhir::BinOp::Sub
470 | fhir::BinOp::Mul
471 | fhir::BinOp::Div
472 | fhir::BinOp::Mod => {
473 let sort = self.next_num_var();
474 self.check_expr(e1, &sort)?;
475 self.check_expr(e2, &sort)?;
476 self.sort_of_bin_op
477 .insert(expr.fhir_id, (sort.clone(), expr.span));
478 self.check_integral(op, &sort, expr.span)?;
480
481 Ok(sort)
482 }
483 fhir::BinOp::BitAnd
484 | fhir::BinOp::BitOr
485 | fhir::BinOp::BitShl
486 | fhir::BinOp::BitShr => {
487 let sort = rty::Sort::BitVec(self.next_bv_size_var());
488 self.check_expr(e1, &sort)?;
489 self.check_expr(e2, &sort)?;
490 Ok(sort)
491 }
492 }
493 }
494
495 fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr) -> Result<rty::Sort> {
496 match op {
497 fhir::UnOp::Not => {
498 self.check_expr(e, &rty::Sort::Bool)?;
499 Ok(rty::Sort::Bool)
500 }
501 fhir::UnOp::Neg => {
502 self.check_expr(e, &rty::Sort::Int)?;
503 Ok(rty::Sort::Int)
504 }
505 }
506 }
507
508 fn synth_app(&mut self, fsort: FuncSort, args: &[fhir::Expr], span: Span) -> Result<rty::Sort> {
509 if args.len() != fsort.inputs().len() {
510 return Err(self.emit_err(errors::ArgCountMismatch::new(
511 Some(span),
512 String::from("function"),
513 fsort.inputs().len(),
514 args.len(),
515 )));
516 }
517
518 iter::zip(args, fsort.inputs())
519 .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
520
521 Ok(fsort.output().clone())
522 }
523
524 fn instantiate_func_sort(
525 &mut self,
526 fsort: rty::PolyFuncSort,
527 fhir_id: FhirId,
528 span: Span,
529 ) -> rty::FuncSort {
530 let args = fsort
531 .params()
532 .map(|kind| {
533 match kind {
534 rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
535 rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
536 }
537 })
538 .collect_vec();
539 self.sort_args_of_app
540 .insert(fhir_id, (List::from_vec(args.clone()), span));
541 fsort.instantiate(&args)
542 }
543
544 pub(crate) fn insert_sort_for_bty(&mut self, fhir_id: FhirId, sort: rty::Sort) {
545 self.sort_of_bty.insert(fhir_id, sort);
546 }
547
548 pub(crate) fn sort_of_bty(&self, fhir_id: FhirId) -> rty::Sort {
549 self.sort_of_bty
550 .get(&fhir_id)
551 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
552 .clone()
553 }
554
555 pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
556 self.path_args.insert(fhir_id, args);
557 }
558
559 pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
560 self.path_args
561 .get(&fhir_id)
562 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
563 .clone()
564 }
565
566 pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
567 self.sort_of_alias_reft.insert(fhir_id, fsort);
568 }
569
570 fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
571 self.sort_of_alias_reft
572 .get(&fhir_id)
573 .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
574 .clone()
575 }
576
577 fn normalize_projection_sort(
578 genv: GlobalEnv,
579 owner: FluxOwnerId,
580 sort: rty::Sort,
581 ) -> rty::Sort {
582 let infcx = genv
583 .tcx()
584 .infer_ctxt()
585 .with_next_trait_solver(true)
586 .build(TypingMode::non_body_analysis());
587 if let Some(def_id) = owner.def_id()
588 && let def_id = genv.maybe_extern_id(def_id).resolved_id()
589 && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
590 {
591 sort
592 } else {
593 sort
594 }
595 }
596
597 pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
602 let genv = self.genv;
603 for sort in self.sort_of_bty.values_mut() {
604 *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
605 }
606 for fsort in self.sort_of_alias_reft.values_mut() {
607 *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
608 }
609 Ok(())
610 }
611}
612
613impl<'genv> InferCtxt<'genv, '_> {
614 pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
615 self.params.insert(param.id, (param, sort));
616 }
617
618 fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
623 if self.try_equate(sort1, sort2).is_some() {
624 return true;
625 }
626
627 let mut sort1 = sort1.clone();
628 let mut sort2 = sort2.clone();
629
630 let mut coercions = vec![];
631 if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
632 coercions.push(rty::Coercion::Project(def_id));
633 sort1 = sort.clone();
634 }
635 if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
636 coercions.push(rty::Coercion::Inject(def_id));
637 sort2 = sort.clone();
638 }
639 self.wfckresults.coercions_mut().insert(fhir_id, coercions);
640 self.try_equate(&sort1, &sort2).is_some()
641 }
642
643 fn is_coercible_from_func(
644 &mut self,
645 sort: &rty::Sort,
646 fhir_id: FhirId,
647 ) -> Option<rty::PolyFuncSort> {
648 if let rty::Sort::Func(fsort) = sort {
649 Some(fsort.clone())
650 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
651 self.wfckresults
652 .coercions_mut()
653 .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
654 Some(fsort.clone())
655 } else {
656 None
657 }
658 }
659
660 fn is_coercible_to_func(
661 &mut self,
662 sort: &rty::Sort,
663 fhir_id: FhirId,
664 ) -> Option<rty::PolyFuncSort> {
665 if let rty::Sort::Func(fsort) = sort {
666 Some(fsort.clone())
667 } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
668 self.wfckresults
669 .coercions_mut()
670 .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
671 Some(fsort.clone())
672 } else {
673 None
674 }
675 }
676
677 fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
678 let sort1 = self.resolve_vars_if_possible(sort1);
679 let sort2 = self.resolve_vars_if_possible(sort2);
680 self.try_equate_inner(&sort1, &sort2)
681 }
682
683 fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
684 match (sort1, sort2) {
685 (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
686 self.sort_unification_table
687 .unify_var_var(*vid1, *vid2)
688 .ok()?;
689 }
690 (rty::Sort::Infer(rty::SortVar(vid)), sort)
691 | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
692 self.sort_unification_table
693 .unify_var_value(*vid, Some(sort.clone()))
694 .ok()?;
695 }
696 (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
697 self.num_unification_table
698 .unify_var_var(*vid1, *vid2)
699 .ok()?;
700 }
701 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
702 | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
703 self.num_unification_table
704 .unify_var_value(*vid, Some(rty::NumVarValue::Int))
705 .ok()?;
706 }
707 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
708 | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
709 self.num_unification_table
710 .unify_var_value(*vid, Some(rty::NumVarValue::Real))
711 .ok()?;
712 }
713 (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
714 | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
715 self.num_unification_table
716 .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
717 .ok()?;
718 }
719
720 (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
721 if ctor1 != ctor2 || args1.len() != args2.len() {
722 return None;
723 }
724 let mut args = vec![];
725 for (s1, s2) in args1.iter().zip(args2.iter()) {
726 args.push(self.try_equate_inner(s1, s2)?);
727 }
728 }
729 (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
730 self.try_equate_bv_sizes(*size1, *size2)?;
731 }
732 _ if sort1 == sort2 => {}
733 _ => return None,
734 }
735 Some(sort1.clone())
736 }
737
738 fn try_equate_bv_sizes(
739 &mut self,
740 size1: rty::BvSize,
741 size2: rty::BvSize,
742 ) -> Option<rty::BvSize> {
743 match (size1, size2) {
744 (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
745 self.bv_size_unification_table
746 .unify_var_var(vid1, vid2)
747 .ok()?;
748 }
749 (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
750 self.bv_size_unification_table
751 .unify_var_value(vid, Some(size))
752 .ok()?;
753 }
754 _ if size1 == size2 => {}
755 _ => return None,
756 }
757 Some(size1)
758 }
759
760 fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
761 self.try_equate(sort1, sort2)
762 .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
763 }
764
765 pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
766 rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
767 }
768
769 fn next_num_var(&mut self) -> rty::Sort {
770 rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
771 }
772
773 pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
774 self.sort_unification_table.new_key(None)
775 }
776
777 fn next_num_vid(&mut self) -> rty::NumVid {
778 self.num_unification_table.new_key(None)
779 }
780
781 fn next_bv_size_var(&mut self) -> rty::BvSize {
782 rty::BvSize::Infer(self.next_bv_size_vid())
783 }
784
785 fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
786 self.bv_size_unification_table.new_key(None)
787 }
788
789 fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
790 let sort = self.synth_path(path)?;
791 self.fully_resolve(&sort)
792 .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
793 }
794
795 fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
796 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
797 && let Some(variant) = sort_def.opt_struct_variant()
798 && let [sort] = &variant.field_sorts(sort_args)[..]
799 {
800 Some((sort_def.did(), sort.clone()))
801 } else {
802 None
803 }
804 }
805
806 pub(crate) fn into_results(mut self) -> Result<WfckResults> {
807 for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_literal) {
809 let sort = self
810 .fully_resolve(&sort)
811 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
812 self.wfckresults.node_sorts_mut().insert(fhir_id, sort);
813 }
814
815 for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_bin_op) {
817 let sort = self
818 .fully_resolve(&sort)
819 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
820 self.wfckresults.bin_op_sorts_mut().insert(fhir_id, sort);
821 }
822
823 let allow_uninterpreted_cast = self
824 .owner
825 .def_id()
826 .map_or(false, |def_id| self.genv.infer_opts(def_id).allow_uninterpreted_cast);
827
828 for (fhir_id, (sort_args, span)) in std::mem::take(&mut self.sort_args_of_app) {
830 let mut res = vec![];
831 for sort_arg in &sort_args {
832 let sort_arg = if let rty::SortArg::Sort(sort) = sort_arg {
833 let sort = self
834 .fully_resolve(sort)
835 .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
836 rty::SortArg::Sort(sort)
837 } else {
838 sort_arg.clone()
839 };
840 res.push(sort_arg);
841 }
842 if self.casts.contains(&fhir_id) {
843 let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
844 span_bug!(span, "invalid sort args!")
845 };
846 if !allow_uninterpreted_cast
847 && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
848 {
849 return Err(self.emit_err(errors::InvalidCast::new(span, from, to)));
850 }
851 }
852 self.wfckresults
853 .fn_app_sorts_mut()
854 .insert(fhir_id, res.into());
855 }
856
857 for (_, (param, sort)) in std::mem::take(&mut self.params) {
859 let sort = self
860 .fully_resolve(&sort)
861 .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(¶m)))?;
862 self.wfckresults
863 .node_sorts_mut()
864 .insert(param.fhir_id, sort);
865 }
866
867 Ok(self.wfckresults)
868 }
869
870 pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
871 fhir::InferMode::from_param_kind(self.params[&id].0.kind)
872 }
873
874 #[track_caller]
875 pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
876 self.params
877 .get(&id)
878 .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
879 .1
880 .clone()
881 }
882
883 fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
884 sort.fold_with(&mut ShallowResolver { infcx: self })
885 }
886
887 fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
888 sort.fold_with(&mut OpportunisticResolver { infcx: self })
889 }
890
891 pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
892 sort.try_fold_with(&mut FullResolver { infcx: self })
893 }
894}
895
896pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
907 infcx: &'a mut InferCtxt<'genv, 'tcx>,
908 errors: Errors<'genv>,
909}
910
911impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
912 pub(crate) fn infer(
913 infcx: &'a mut InferCtxt<'genv, 'tcx>,
914 node: &fhir::OwnerNode<'genv>,
915 ) -> Result {
916 let errors = Errors::new(infcx.genv.sess());
917 let mut vis = Self { infcx, errors };
918 vis.visit_node(node);
919 vis.errors.into_result()
920 }
921
922 fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
923 match idx.kind {
924 fhir::ExprKind::Var(var, Some(_)) => {
925 let (_, id) = var.res.expect_param();
926 let found = self.infcx.param_sort(id);
927 self.infcx.equate(&found, expected);
928 }
929 fhir::ExprKind::Record(flds) => {
930 if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
931 let sorts = sort_def.struct_variant().field_sorts(sort_args);
932 if flds.len() != sorts.len() {
933 self.errors.emit(errors::ArgCountMismatch::new(
934 Some(idx.span),
935 String::from("type"),
936 sorts.len(),
937 flds.len(),
938 ));
939 return;
940 }
941 for (f, sort) in iter::zip(flds, &sorts) {
942 self.infer_implicit_params(f, sort);
943 }
944 } else {
945 self.errors.emit(errors::ArgCountMismatch::new(
946 Some(idx.span),
947 String::from("type"),
948 1,
949 flds.len(),
950 ));
951 }
952 }
953 _ => {}
954 }
955 }
956}
957
958impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
959 fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
960 if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
961 let expected = self.infcx.sort_of_bty(bty.fhir_id);
962 self.infer_implicit_params(idx, &expected);
963 }
964 fhir::visit::walk_ty(self, ty);
965 }
966}
967
968impl InferCtxt<'_, '_> {
969 #[track_caller]
970 fn emit_sort_mismatch(
971 &mut self,
972 span: Span,
973 expected: &rty::Sort,
974 found: &rty::Sort,
975 ) -> ErrorGuaranteed {
976 let expected = self.resolve_vars_if_possible(expected);
977 let found = self.resolve_vars_if_possible(found);
978 self.emit_err(errors::SortMismatch::new(span, expected, found))
979 }
980
981 fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
982 self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
983 }
984
985 #[track_caller]
986 fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
987 self.genv.sess().emit_err(err)
988 }
989}
990
991struct ShallowResolver<'a, 'genv, 'tcx> {
992 infcx: &'a mut InferCtxt<'genv, 'tcx>,
993}
994
995impl TypeFolder for ShallowResolver<'_, '_, '_> {
996 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
997 match sort {
998 rty::Sort::Infer(rty::SortVar(vid)) => {
999 self.infcx
1005 .sort_unification_table
1006 .probe_value(*vid)
1007 .map(|sort| sort.fold_with(self))
1008 .unwrap_or(sort.clone())
1009 }
1010 rty::Sort::Infer(rty::NumVar(vid)) => {
1011 self.infcx
1013 .num_unification_table
1014 .probe_value(*vid)
1015 .map(|val| val.to_sort().fold_with(self))
1016 .unwrap_or(sort.clone())
1017 }
1018 rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1019 self.infcx
1020 .bv_size_unification_table
1021 .probe_value(*vid)
1022 .map(rty::Sort::BitVec)
1023 .unwrap_or(sort.clone())
1024 }
1025 _ => sort.clone(),
1026 }
1027 }
1028}
1029
1030struct OpportunisticResolver<'a, 'genv, 'tcx> {
1031 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1032}
1033
1034impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1035 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1036 let s = self.infcx.shallow_resolve(sort);
1037 s.super_fold_with(self)
1038 }
1039}
1040
1041struct FullResolver<'a, 'genv, 'tcx> {
1042 infcx: &'a mut InferCtxt<'genv, 'tcx>,
1043}
1044
1045impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1046 type Error = ();
1047
1048 fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1049 let s = self.infcx.shallow_resolve(sort);
1050 match s {
1051 rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1052 _ => s.try_super_fold_with(self),
1053 }
1054 }
1055}