// flux_fhir_analysis/wf/sortck.rs

1use std::iter;
2
3use ena::unify::InPlaceUnificationTable;
4use flux_common::{bug, iter::IterExt, result::ResultExt, span_bug, tracked_span_bug};
5use flux_errors::{ErrorGuaranteed, Errors};
6use flux_infer::projections::NormalizeExt;
7use flux_middle::{
8    THEORY_FUNCS,
9    fhir::{self, ExprRes, FhirId, FluxOwnerId, SpecFuncKind, visit::Visitor as _},
10    global_env::GlobalEnv,
11    queries::QueryResult,
12    rty::{
13        self, AdtSortDef, FuncSort, WfckResults,
14        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15    },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_middle::ty::TypingMode;
22use rustc_span::{Span, def_id::DefId, symbol::Ident};
23
24use super::errors;
25use crate::rustc_infer::infer::TyCtxtInferExt;
26
/// Result of a sort-checking step.
///
/// The error is an [`ErrorGuaranteed`] token. Every such error produced in this file comes from
/// emitting a diagnostic (`emit_err`, `emit_sort_mismatch`, or `.emit(&self.genv)`).
type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;

/// Inference context used to sort check refinement expressions.
///
/// Sorts are inferred with three unification tables:
/// - general sort variables;
/// - numeric variables;
/// - bitvector-size variables.
///
/// Side tables collect information that can only be finalized once inference is complete. See
/// `into_results`, which fully resolves them into `wfckresults`.
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// Owner the results are built for (`WfckResults::new(owner)`). Its `def_id`, when present,
    /// is also the item in which projection sorts are normalized.
    pub owner: FluxOwnerId,
    /// Results accumulated while checking: node sorts, coercions, record constructors, field
    /// projections and binary-operator sorts.
    pub wfckresults: WfckResults,
    /// Unification table for general sort variables (`rty::SortVar`). Key 0 is reserved (see
    /// `new`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Unification table for numeric variables (`rty::NumVar`). These can be bound to `int`,
    /// `real` or a bitvector sort.
    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
    /// Unification table for bitvector sizes (`rty::BvSize::Infer`).
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Declared refinement parameters with their sorts. The sorts may still contain inference
    /// variables until `into_results`.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sort associated with each base-type node (filled by `insert_sort_for_bty`).
    sort_of_bty: FxHashMap<FhirId, rty::Sort>,
    /// Sort of each integer literal, plus its span for error reporting. Resolved in
    /// `into_results`.
    sort_of_literal: FxHashMap<FhirId, (rty::Sort, Span)>,
    /// Operand sort of each comparison/arithmetic operator, plus its span for error reporting.
    /// Resolved in `into_results`.
    sort_of_bin_op: FxHashMap<FhirId, (rty::Sort, Span)>,
    /// Generic arguments recorded per path node (filled by `insert_path_args`).
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sort of each alias-reft application, collected during early conv.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
}
43
44impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
45    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
46        // We skip 0 because that's used for sort dummy self types during conv.
47        let mut sort_unification_table = InPlaceUnificationTable::new();
48        sort_unification_table.new_key(None);
49        Self {
50            genv,
51            owner,
52            wfckresults: WfckResults::new(owner),
53            sort_unification_table,
54            num_unification_table: InPlaceUnificationTable::new(),
55            bv_size_unification_table: InPlaceUnificationTable::new(),
56            params: Default::default(),
57            sort_of_bty: Default::default(),
58            sort_of_literal: Default::default(),
59            sort_of_bin_op: Default::default(),
60            path_args: Default::default(),
61            sort_of_alias_reft: Default::default(),
62        }
63    }
64
65    fn check_abs(
66        &mut self,
67        expr: &fhir::Expr,
68        params: &[fhir::RefineParam],
69        body: &fhir::Expr,
70        expected: &rty::Sort,
71    ) -> Result {
72        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
73            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
74        };
75
76        let fsort = fsort.expect_mono();
77
78        if params.len() != fsort.inputs().len() {
79            return Err(self.emit_err(errors::ParamCountMismatch::new(
80                expr.span,
81                fsort.inputs().len(),
82                params.len(),
83            )));
84        }
85        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
86            let found = self.param_sort(param.id);
87            if self.try_equate(&found, expected).is_none() {
88                return Err(self.emit_sort_mismatch(param.span, expected, &found));
89            }
90            Ok(())
91        })?;
92        self.check_expr(body, fsort.output())?;
93        self.wfckresults
94            .node_sorts_mut()
95            .insert(body.fhir_id, fsort.output().clone());
96        Ok(())
97    }
98
99    fn check_field_exprs(
100        &mut self,
101        expr_span: Span,
102        sort_def: &AdtSortDef,
103        sort_args: &[rty::Sort],
104        field_exprs: &[fhir::FieldExpr],
105        spread: &Option<&fhir::Spread>,
106        expected: &rty::Sort,
107    ) -> Result {
108        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
109        let mut used_fields = FxHashMap::default();
110        for expr in field_exprs {
111            // make sure that the field is actually a field
112            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
113                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
114            };
115            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
116                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
117            }
118            self.check_expr(&expr.expr, sort)?;
119        }
120        if let Some(spread) = spread {
121            // must check that the spread is of the same sort as the constructor
122            self.check_expr(&spread.expr, expected)
123        } else if sort_by_field_name.len() != used_fields.len() {
124            // emit an error because all fields are not used
125            let missing_fields = sort_by_field_name
126                .into_keys()
127                .filter(|x| !used_fields.contains_key(x))
128                .collect();
129            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
130        } else {
131            Ok(())
132        }
133    }
134
135    fn check_constructor(
136        &mut self,
137        expr: &fhir::Expr,
138        field_exprs: &[fhir::FieldExpr],
139        spread: &Option<&fhir::Spread>,
140        expected: &rty::Sort,
141    ) -> Result {
142        let expected = self.resolve_vars_if_possible(expected);
143        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
144            self.wfckresults
145                .record_ctors_mut()
146                .insert(expr.fhir_id, sort_def.did());
147            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
148            Ok(())
149        } else {
150            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
151        }
152    }
153
154    fn check_record(
155        &mut self,
156        arg: &fhir::Expr,
157        flds: &[fhir::Expr],
158        expected: &rty::Sort,
159    ) -> Result {
160        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
161            let sorts = sort_def.struct_variant().field_sorts(sort_args);
162            if flds.len() != sorts.len() {
163                return Err(self.emit_err(errors::ArgCountMismatch::new(
164                    Some(arg.span),
165                    String::from("type"),
166                    sorts.len(),
167                    flds.len(),
168                )));
169            }
170            self.wfckresults
171                .record_ctors_mut()
172                .insert(arg.fhir_id, sort_def.did());
173
174            izip!(flds, &sorts)
175                .map(|(arg, expected)| self.check_expr(arg, expected))
176                .try_collect_exhaust()
177        } else {
178            Err(self.emit_err(errors::ArgCountMismatch::new(
179                Some(arg.span),
180                String::from("type"),
181                1,
182                flds.len(),
183            )))
184        }
185    }
186
187    pub(super) fn check_expr(&mut self, expr: &fhir::Expr, expected: &rty::Sort) -> Result {
188        match &expr.kind {
189            fhir::ExprKind::IfThenElse(p, e1, e2) => {
190                self.check_expr(p, &rty::Sort::Bool)?;
191                self.check_expr(e1, expected)?;
192                self.check_expr(e2, expected)?;
193            }
194            fhir::ExprKind::Abs(params, body) => {
195                self.check_abs(expr, params, body, expected)?;
196            }
197            fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
198            fhir::ExprKind::Constructor(None, exprs, spread) => {
199                self.check_constructor(expr, exprs, spread, expected)?;
200            }
201            fhir::ExprKind::UnaryOp(..)
202            | fhir::ExprKind::BinaryOp(..)
203            | fhir::ExprKind::Dot(..)
204            | fhir::ExprKind::App(..)
205            | fhir::ExprKind::Alias(..)
206            | fhir::ExprKind::Var(..)
207            | fhir::ExprKind::Literal(..)
208            | fhir::ExprKind::BoundedQuant(..)
209            | fhir::ExprKind::Block(..)
210            | fhir::ExprKind::Constructor(..) => {
211                let found = self.synth_expr(expr)?;
212                let found = self.resolve_vars_if_possible(&found);
213                let expected = self.resolve_vars_if_possible(expected);
214                if !self.is_coercible(&found, &expected, expr.fhir_id) {
215                    return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
216                }
217            }
218            fhir::ExprKind::Err(_) => {
219                // an error has already been reported so we can just skip
220            }
221        }
222        Ok(())
223    }
224
225    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
226        let found = self.synth_path(loc)?;
227        if found == rty::Sort::Loc {
228            Ok(())
229        } else {
230            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
231        }
232    }
233
234    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr) -> rty::Sort {
235        match lit {
236            fhir::Lit::Int(_) => {
237                let sort = self.next_num_var();
238                self.sort_of_literal
239                    .insert(expr.fhir_id, (sort.clone(), expr.span));
240                sort
241            }
242            fhir::Lit::Bool(_) => rty::Sort::Bool,
243            fhir::Lit::Real(_) => rty::Sort::Real,
244            fhir::Lit::Str(_) => rty::Sort::Str,
245            fhir::Lit::Char(_) => rty::Sort::Char,
246        }
247    }
248
249    fn synth_expr(&mut self, expr: &fhir::Expr) -> Result<rty::Sort> {
250        match expr.kind {
251            fhir::ExprKind::Var(var, _) => self.synth_path(&var),
252            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
253            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
254            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
255            fhir::ExprKind::App(callee, args) => {
256                let sort = self.ensure_resolved_path(&callee)?;
257                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
258                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
259                };
260                let fsort = self.instantiate_func_sort(poly_fsort);
261                self.synth_app(fsort, args, expr.span)
262            }
263            fhir::ExprKind::BoundedQuant(.., body) => {
264                self.check_expr(body, &rty::Sort::Bool)?;
265                Ok(rty::Sort::Bool)
266            }
267            fhir::ExprKind::Alias(_alias_reft, args) => {
268                // To check the application we only need the sort of `_alias_reft` which we collected
269                // during early conv, but should we do any extra checks on `_alias_reft`?
270                let fsort = self.sort_of_alias_reft(expr.fhir_id);
271                self.synth_app(fsort, args, expr.span)
272            }
273            fhir::ExprKind::IfThenElse(p, e1, e2) => {
274                self.check_expr(p, &rty::Sort::Bool)?;
275                let sort = self.synth_expr(e1)?;
276                self.check_expr(e2, &sort)?;
277                Ok(sort)
278            }
279            fhir::ExprKind::Dot(base, fld) => {
280                let sort = self.synth_expr(base)?;
281                let sort = self
282                    .fully_resolve(&sort)
283                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
284                match &sort {
285                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
286                        let (proj, sort) = sort_def
287                            .struct_variant()
288                            .field_by_name(sort_def.did(), sort_args, fld.name)
289                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
290                        self.wfckresults
291                            .field_projs_mut()
292                            .insert(expr.fhir_id, proj);
293                        Ok(sort)
294                    }
295                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
296                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
297                    }
298                    _ => Err(self.emit_field_not_found(&sort, fld)),
299                }
300            }
301            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
302                // if we have S then we can synthesize otherwise we fail
303                // first get the sort based on the path - for example S { ... } => S
304                // and we should expect sort to be a struct or enum app
305                let path_def_id = match path.res {
306                    ExprRes::Ctor(def_id) => def_id,
307                    _ => span_bug!(expr.span, "unexpected path in constructor"),
308                };
309                let sort_def = self
310                    .genv
311                    .adt_sort_def_of(path_def_id)
312                    .map_err(|e| self.emit_err(e))?;
313                // generate fresh inferred sorts for each param
314                let fresh_args: rty::List<_> = (0..sort_def.param_count())
315                    .map(|_| self.next_sort_var())
316                    .collect();
317                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
318                // check fields & spread against fresh args
319                self.check_field_exprs(
320                    expr.span,
321                    &sort_def,
322                    &fresh_args,
323                    field_exprs,
324                    &spread,
325                    &sort,
326                )?;
327                Ok(sort)
328            }
329            fhir::ExprKind::Block(decls, body) => {
330                for decl in decls {
331                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
332                }
333                self.synth_expr(body)
334            }
335            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
336        }
337    }
338
339    fn synth_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
340        match path.res {
341            ExprRes::Param(_, id) => Ok(self.param_sort(id)),
342            ExprRes::Const(def_id) => {
343                if let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? {
344                    let info = self.genv.constant_info(def_id).emit(&self.genv)?;
345                    // non-integral constant
346                    if sort != rty::Sort::Int && matches!(info, rty::ConstantInfo::Uninterpreted) {
347                        Err(self.emit_err(errors::ConstantAnnotationNeeded::new(path.span)))?;
348                    }
349                    Ok(sort)
350                } else {
351                    span_bug!(path.span, "unexpected const")
352                }
353            }
354            ExprRes::Variant(def_id) => {
355                let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? else {
356                    span_bug!(path.span, "unexpected variant {def_id:?}")
357                };
358                Ok(sort)
359            }
360            ExprRes::ConstGeneric(_) => Ok(rty::Sort::Int), // TODO: generalize generic-const sorts
361            ExprRes::NumConst(_) => Ok(rty::Sort::Int),
362            ExprRes::GlobalFunc(SpecFuncKind::Def(name) | SpecFuncKind::Uif(name)) => {
363                Ok(rty::Sort::Func(self.genv.func_sort(name)))
364            }
365            ExprRes::GlobalFunc(SpecFuncKind::Thy(itf)) => {
366                Ok(rty::Sort::Func(THEORY_FUNCS.get(&itf).unwrap().sort.clone()))
367            }
368            ExprRes::Ctor(_) => {
369                span_bug!(path.span, "unexpected constructor in var position")
370            }
371        }
372    }
373
374    fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
375        if matches!(op, fhir::BinOp::Mod) {
376            let sort = self
377                .fully_resolve(sort)
378                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
379            if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
380                span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
381            }
382        }
383        Ok(())
384    }
385
386    fn synth_binary_op(
387        &mut self,
388        expr: &fhir::Expr,
389        op: fhir::BinOp,
390        e1: &fhir::Expr,
391        e2: &fhir::Expr,
392    ) -> Result<rty::Sort> {
393        match op {
394            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
395                self.check_expr(e1, &rty::Sort::Bool)?;
396                self.check_expr(e2, &rty::Sort::Bool)?;
397                Ok(rty::Sort::Bool)
398            }
399            fhir::BinOp::Eq | fhir::BinOp::Ne => {
400                let sort = self.next_sort_var();
401                self.check_expr(e1, &sort)?;
402                self.check_expr(e2, &sort)?;
403                Ok(rty::Sort::Bool)
404            }
405            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
406                let sort = self.next_sort_var();
407                self.check_expr(e1, &sort)?;
408                self.check_expr(e2, &sort)?;
409                self.sort_of_bin_op
410                    .insert(expr.fhir_id, (sort.clone(), expr.span));
411                Ok(rty::Sort::Bool)
412            }
413            fhir::BinOp::Add
414            | fhir::BinOp::Sub
415            | fhir::BinOp::Mul
416            | fhir::BinOp::Div
417            | fhir::BinOp::Mod => {
418                let sort = self.next_num_var();
419                self.check_expr(e1, &sort)?;
420                self.check_expr(e2, &sort)?;
421                self.sort_of_bin_op
422                    .insert(expr.fhir_id, (sort.clone(), expr.span));
423                // check that the sort is integral for mod
424                self.check_integral(op, &sort, expr.span)?;
425
426                Ok(sort)
427            }
428            fhir::BinOp::BitAnd
429            | fhir::BinOp::BitOr
430            | fhir::BinOp::BitShl
431            | fhir::BinOp::BitShr => {
432                let sort = rty::Sort::BitVec(self.next_bv_size_var());
433                self.check_expr(e1, &sort)?;
434                self.check_expr(e2, &sort)?;
435                Ok(sort)
436            }
437        }
438    }
439
440    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr) -> Result<rty::Sort> {
441        match op {
442            fhir::UnOp::Not => {
443                self.check_expr(e, &rty::Sort::Bool)?;
444                Ok(rty::Sort::Bool)
445            }
446            fhir::UnOp::Neg => {
447                self.check_expr(e, &rty::Sort::Int)?;
448                Ok(rty::Sort::Int)
449            }
450        }
451    }
452
453    fn synth_app(&mut self, fsort: FuncSort, args: &[fhir::Expr], span: Span) -> Result<rty::Sort> {
454        if args.len() != fsort.inputs().len() {
455            return Err(self.emit_err(errors::ArgCountMismatch::new(
456                Some(span),
457                String::from("function"),
458                fsort.inputs().len(),
459                args.len(),
460            )));
461        }
462
463        iter::zip(args, fsort.inputs())
464            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
465
466        Ok(fsort.output().clone())
467    }
468
469    fn instantiate_func_sort(&mut self, fsort: rty::PolyFuncSort) -> rty::FuncSort {
470        let args = fsort
471            .params()
472            .map(|kind| {
473                match kind {
474                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
475                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
476                }
477            })
478            .collect_vec();
479        fsort.instantiate(&args)
480    }
481
482    pub(crate) fn insert_sort_for_bty(&mut self, fhir_id: FhirId, sort: rty::Sort) {
483        self.sort_of_bty.insert(fhir_id, sort);
484    }
485
486    pub(crate) fn sort_of_bty(&self, fhir_id: FhirId) -> rty::Sort {
487        self.sort_of_bty
488            .get(&fhir_id)
489            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
490            .clone()
491    }
492
493    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
494        self.path_args.insert(fhir_id, args);
495    }
496
497    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
498        self.path_args
499            .get(&fhir_id)
500            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
501            .clone()
502    }
503
504    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
505        self.sort_of_alias_reft.insert(fhir_id, fsort);
506    }
507
508    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
509        self.sort_of_alias_reft
510            .get(&fhir_id)
511            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
512            .clone()
513    }
514
515    fn normalize_projection_sort(
516        genv: GlobalEnv,
517        owner: FluxOwnerId,
518        sort: rty::Sort,
519    ) -> rty::Sort {
520        let infcx = genv
521            .tcx()
522            .infer_ctxt()
523            .with_next_trait_solver(true)
524            .build(TypingMode::non_body_analysis());
525        if let Some(def_id) = owner.def_id()
526            && let def_id = genv.maybe_extern_id(def_id).resolved_id()
527            && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
528        {
529            sort
530        } else {
531            sort
532        }
533    }
534
535    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
536    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
537    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
538    // expanding aliases in `ConvCtxt::conv_sort`.
539    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
540        let genv = self.genv;
541        for sort in self.sort_of_bty.values_mut() {
542            *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
543        }
544        for fsort in self.sort_of_alias_reft.values_mut() {
545            *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
546        }
547        Ok(())
548    }
549}
550
551impl<'genv> InferCtxt<'genv, '_> {
552    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
553        self.params.insert(param.id, (param, sort));
554    }
555
556    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
557    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
558    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
559    /// a value into a record.
560    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
561        if self.try_equate(sort1, sort2).is_some() {
562            return true;
563        }
564
565        let mut sort1 = sort1.clone();
566        let mut sort2 = sort2.clone();
567
568        let mut coercions = vec![];
569        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
570            coercions.push(rty::Coercion::Project(def_id));
571            sort1 = sort.clone();
572        }
573        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
574            coercions.push(rty::Coercion::Inject(def_id));
575            sort2 = sort.clone();
576        }
577        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
578        self.try_equate(&sort1, &sort2).is_some()
579    }
580
581    fn is_coercible_from_func(
582        &mut self,
583        sort: &rty::Sort,
584        fhir_id: FhirId,
585    ) -> Option<rty::PolyFuncSort> {
586        if let rty::Sort::Func(fsort) = sort {
587            Some(fsort.clone())
588        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
589            self.wfckresults
590                .coercions_mut()
591                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
592            Some(fsort.clone())
593        } else {
594            None
595        }
596    }
597
598    fn is_coercible_to_func(
599        &mut self,
600        sort: &rty::Sort,
601        fhir_id: FhirId,
602    ) -> Option<rty::PolyFuncSort> {
603        if let rty::Sort::Func(fsort) = sort {
604            Some(fsort.clone())
605        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
606            self.wfckresults
607                .coercions_mut()
608                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
609            Some(fsort.clone())
610        } else {
611            None
612        }
613    }
614
615    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
616        let sort1 = self.resolve_vars_if_possible(sort1);
617        let sort2 = self.resolve_vars_if_possible(sort2);
618        self.try_equate_inner(&sort1, &sort2)
619    }
620
621    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
622        match (sort1, sort2) {
623            (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
624                self.sort_unification_table
625                    .unify_var_var(*vid1, *vid2)
626                    .ok()?;
627            }
628            (rty::Sort::Infer(rty::SortVar(vid)), sort)
629            | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
630                self.sort_unification_table
631                    .unify_var_value(*vid, Some(sort.clone()))
632                    .ok()?;
633            }
634            (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
635                self.num_unification_table
636                    .unify_var_var(*vid1, *vid2)
637                    .ok()?;
638            }
639            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
640            | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
641                self.num_unification_table
642                    .unify_var_value(*vid, Some(rty::NumVarValue::Int))
643                    .ok()?;
644            }
645            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
646            | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
647                self.num_unification_table
648                    .unify_var_value(*vid, Some(rty::NumVarValue::Real))
649                    .ok()?;
650            }
651            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
652            | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
653                self.num_unification_table
654                    .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
655                    .ok()?;
656            }
657
658            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
659                if ctor1 != ctor2 || args1.len() != args2.len() {
660                    return None;
661                }
662                let mut args = vec![];
663                for (s1, s2) in args1.iter().zip(args2.iter()) {
664                    args.push(self.try_equate_inner(s1, s2)?);
665                }
666            }
667            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
668                self.try_equate_bv_sizes(*size1, *size2)?;
669            }
670            _ if sort1 == sort2 => {}
671            _ => return None,
672        }
673        Some(sort1.clone())
674    }
675
676    fn try_equate_bv_sizes(
677        &mut self,
678        size1: rty::BvSize,
679        size2: rty::BvSize,
680    ) -> Option<rty::BvSize> {
681        match (size1, size2) {
682            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
683                self.bv_size_unification_table
684                    .unify_var_var(vid1, vid2)
685                    .ok()?;
686            }
687            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
688                self.bv_size_unification_table
689                    .unify_var_value(vid, Some(size))
690                    .ok()?;
691            }
692            _ if size1 == size2 => {}
693            _ => return None,
694        }
695        Some(size1)
696    }
697
698    fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
699        self.try_equate(sort1, sort2)
700            .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
701    }
702
703    pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
704        rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
705    }
706
707    fn next_num_var(&mut self) -> rty::Sort {
708        rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
709    }
710
711    pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
712        self.sort_unification_table.new_key(None)
713    }
714
715    fn next_num_vid(&mut self) -> rty::NumVid {
716        self.num_unification_table.new_key(None)
717    }
718
719    fn next_bv_size_var(&mut self) -> rty::BvSize {
720        rty::BvSize::Infer(self.next_bv_size_vid())
721    }
722
723    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
724        self.bv_size_unification_table.new_key(None)
725    }
726
727    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
728        let sort = self.synth_path(path)?;
729        self.fully_resolve(&sort)
730            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
731    }
732
733    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
734        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
735            && let Some(variant) = sort_def.opt_struct_variant()
736            && let [sort] = &variant.field_sorts(sort_args)[..]
737        {
738            Some((sort_def.did(), sort.clone()))
739        } else {
740            None
741        }
742    }
743
744    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
745        // Make sure the int literals are fully resolved
746        for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_literal) {
747            let sort = self
748                .fully_resolve(&sort)
749                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
750            self.wfckresults.node_sorts_mut().insert(fhir_id, sort);
751        }
752
753        // Make sure the binary operators are fully resolved
754        for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_bin_op) {
755            let sort = self
756                .fully_resolve(&sort)
757                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
758            self.wfckresults.bin_op_sorts_mut().insert(fhir_id, sort);
759        }
760
761        // Make sure all parameters are fully resolved
762        for (_, (param, sort)) in std::mem::take(&mut self.params) {
763            let sort = self
764                .fully_resolve(&sort)
765                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
766            self.wfckresults
767                .node_sorts_mut()
768                .insert(param.fhir_id, sort);
769        }
770
771        Ok(self.wfckresults)
772    }
773
774    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
775        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
776    }
777
778    #[track_caller]
779    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
780        self.params
781            .get(&id)
782            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
783            .1
784            .clone()
785    }
786
787    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
788        sort.fold_with(&mut ShallowResolver { infcx: self })
789    }
790
791    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
792        sort.fold_with(&mut OpportunisticResolver { infcx: self })
793    }
794
795    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
796        sort.try_fold_with(&mut FullResolver { infcx: self })
797    }
798}
799
/// Before the main sort inference, we do a first traversal checking all implicitly scoped parameters
/// declared with `@` or `#` and infer their sort based on the type they are indexing, e.g., if `n` was
/// declared as `i32[@n]`, we infer `int` for `n`.
///
/// This prepass is necessary because sometimes the order in which we traverse expressions can
/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
/// complete inference regardless of how expressions are later traversed.
///
/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
/// function position) such that we don't need this anymore.
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    /// Inference context whose parameter sorts are equated with the sorts of the indexed types.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    /// Collects diagnostics emitted during the prepass (e.g. record arity mismatches).
    errors: Errors<'genv>,
}
814
815impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
816    pub(crate) fn infer(
817        infcx: &'a mut InferCtxt<'genv, 'tcx>,
818        node: &fhir::OwnerNode<'genv>,
819    ) -> Result {
820        let errors = Errors::new(infcx.genv.sess());
821        let mut vis = Self { infcx, errors };
822        vis.visit_node(node);
823        vis.errors.into_result()
824    }
825
826    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
827        match idx.kind {
828            fhir::ExprKind::Var(var, Some(_)) => {
829                let (_, id) = var.res.expect_param();
830                let found = self.infcx.param_sort(id);
831                self.infcx.equate(&found, expected);
832            }
833            fhir::ExprKind::Record(flds) => {
834                if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
835                    let sorts = sort_def.struct_variant().field_sorts(sort_args);
836                    if flds.len() != sorts.len() {
837                        self.errors.emit(errors::ArgCountMismatch::new(
838                            Some(idx.span),
839                            String::from("type"),
840                            sorts.len(),
841                            flds.len(),
842                        ));
843                        return;
844                    }
845                    for (f, sort) in iter::zip(flds, &sorts) {
846                        self.infer_implicit_params(f, sort);
847                    }
848                } else {
849                    self.errors.emit(errors::ArgCountMismatch::new(
850                        Some(idx.span),
851                        String::from("type"),
852                        1,
853                        flds.len(),
854                    ));
855                }
856            }
857            _ => {}
858        }
859    }
860}
861
862impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
863    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
864        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
865            let expected = self.infcx.sort_of_bty(bty.fhir_id);
866            self.infer_implicit_params(idx, &expected);
867        }
868        fhir::visit::walk_ty(self, ty);
869    }
870}
871
872impl InferCtxt<'_, '_> {
873    #[track_caller]
874    fn emit_sort_mismatch(
875        &mut self,
876        span: Span,
877        expected: &rty::Sort,
878        found: &rty::Sort,
879    ) -> ErrorGuaranteed {
880        let expected = self.resolve_vars_if_possible(expected);
881        let found = self.resolve_vars_if_possible(found);
882        self.emit_err(errors::SortMismatch::new(span, expected, found))
883    }
884
885    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
886        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
887    }
888
889    #[track_caller]
890    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
891        self.genv.sess().emit_err(err)
892    }
893}
894
/// Folder that resolves the inference variable at the head of a sort (following chains of
/// bindings) without descending into its sub-sorts.
struct ShallowResolver<'a, 'genv, 'tcx> {
    /// Inference context whose unification tables are probed.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}
898
899impl TypeFolder for ShallowResolver<'_, '_, '_> {
900    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
901        match sort {
902            rty::Sort::Infer(rty::SortVar(vid)) => {
903                // if `sort` is a sort variable, it can be resolved to an num/bit-vec variable,
904                // which can then be recursively resolved, hence the recursion. Note though that
905                // we prevent sort variables from unifying to other sort variables directly (though
906                // they may be embedded structurally), so this recursion should always be of very
907                // limited depth.
908                self.infcx
909                    .sort_unification_table
910                    .probe_value(*vid)
911                    .map(|sort| sort.fold_with(self))
912                    .unwrap_or(sort.clone())
913            }
914            rty::Sort::Infer(rty::NumVar(vid)) => {
915                // same here, a num var could had been unified with a bitvector
916                self.infcx
917                    .num_unification_table
918                    .probe_value(*vid)
919                    .map(|val| val.to_sort().fold_with(self))
920                    .unwrap_or(sort.clone())
921            }
922            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
923                self.infcx
924                    .bv_size_unification_table
925                    .probe_value(*vid)
926                    .map(rty::Sort::BitVec)
927                    .unwrap_or(sort.clone())
928            }
929            _ => sort.clone(),
930        }
931    }
932}
933
/// Folder that substitutes every inference variable that already has a value, at any depth,
/// and leaves unresolved variables in place.
struct OpportunisticResolver<'a, 'genv, 'tcx> {
    /// Inference context used to shallow-resolve each visited sort.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}
937
938impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
939    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
940        let s = self.infcx.shallow_resolve(sort);
941        s.super_fold_with(self)
942    }
943}
944
/// Fallible folder that resolves every inference variable in a sort. It fails with `()` as
/// soon as it meets a variable (or bit-vector size variable) that has no value.
struct FullResolver<'a, 'genv, 'tcx> {
    /// Inference context used to shallow-resolve each visited sort.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}
948
949impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
950    type Error = ();
951
952    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
953        let s = self.infcx.shallow_resolve(sort);
954        match s {
955            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
956            _ => s.try_super_fold_with(self),
957        }
958    }
959}