flux_fhir_analysis/wf/sortck.rs

1use std::iter;
2
3use ena::unify::InPlaceUnificationTable;
4use flux_common::{bug, iter::IterExt, result::ResultExt, span_bug, tracked_span_bug};
5use flux_errors::{ErrorGuaranteed, Errors};
6use flux_infer::projections::NormalizeExt;
7use flux_middle::{
8    THEORY_FUNCS,
9    fhir::{self, ExprRes, FhirId, FluxOwnerId, SpecFuncKind, visit::Visitor as _},
10    global_env::GlobalEnv,
11    queries::QueryResult,
12    rty::{
13        self, AdtSortDef, FuncSort, List, WfckResults,
14        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15    },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_middle::ty::TypingMode;
22use rustc_span::{Span, def_id::DefId, symbol::Ident};
23
24use super::errors;
25use crate::rustc_infer::infer::TyCtxtInferExt;
26
27type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
28
29pub(super) struct InferCtxt<'genv, 'tcx> {
30    pub genv: GlobalEnv<'genv, 'tcx>,
31    pub owner: FluxOwnerId,
32    pub wfckresults: WfckResults,
33    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
34    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
35    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
36    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
37    sort_of_bty: FxHashMap<FhirId, rty::Sort>,
38    sort_of_literal: FxHashMap<FhirId, (rty::Sort, Span)>,
39    sort_of_bin_op: FxHashMap<FhirId, (rty::Sort, Span)>,
40    path_args: UnordMap<FhirId, rty::GenericArgs>,
41    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
42}
43
44pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
45    match op {
46        fhir::BinOp::BitAnd | fhir::BinOp::BitOr | fhir::BinOp::BitShl | fhir::BinOp::BitShr => {
47            Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int))
48        }
49        _ => None,
50    }
51}
52
53impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
54    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
55        // We skip 0 because that's used for sort dummy self types during conv.
56        let mut sort_unification_table = InPlaceUnificationTable::new();
57        sort_unification_table.new_key(None);
58        Self {
59            genv,
60            owner,
61            wfckresults: WfckResults::new(owner),
62            sort_unification_table,
63            num_unification_table: InPlaceUnificationTable::new(),
64            bv_size_unification_table: InPlaceUnificationTable::new(),
65            params: Default::default(),
66            sort_of_bty: Default::default(),
67            sort_of_literal: Default::default(),
68            sort_of_bin_op: Default::default(),
69            path_args: Default::default(),
70            sort_of_alias_reft: Default::default(),
71        }
72    }
73
74    fn check_abs(
75        &mut self,
76        expr: &fhir::Expr,
77        params: &[fhir::RefineParam],
78        body: &fhir::Expr,
79        expected: &rty::Sort,
80    ) -> Result {
81        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
82            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
83        };
84
85        let fsort = fsort.expect_mono();
86
87        if params.len() != fsort.inputs().len() {
88            return Err(self.emit_err(errors::ParamCountMismatch::new(
89                expr.span,
90                fsort.inputs().len(),
91                params.len(),
92            )));
93        }
94        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
95            let found = self.param_sort(param.id);
96            if self.try_equate(&found, expected).is_none() {
97                return Err(self.emit_sort_mismatch(param.span, expected, &found));
98            }
99            Ok(())
100        })?;
101        self.check_expr(body, fsort.output())?;
102        self.wfckresults
103            .node_sorts_mut()
104            .insert(body.fhir_id, fsort.output().clone());
105        Ok(())
106    }
107
108    fn check_field_exprs(
109        &mut self,
110        expr_span: Span,
111        sort_def: &AdtSortDef,
112        sort_args: &[rty::Sort],
113        field_exprs: &[fhir::FieldExpr],
114        spread: &Option<&fhir::Spread>,
115        expected: &rty::Sort,
116    ) -> Result {
117        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
118        let mut used_fields = FxHashMap::default();
119        for expr in field_exprs {
120            // make sure that the field is actually a field
121            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
122                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
123            };
124            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
125                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
126            }
127            self.check_expr(&expr.expr, sort)?;
128        }
129        if let Some(spread) = spread {
130            // must check that the spread is of the same sort as the constructor
131            self.check_expr(&spread.expr, expected)
132        } else if sort_by_field_name.len() != used_fields.len() {
133            // emit an error because all fields are not used
134            let missing_fields = sort_by_field_name
135                .into_keys()
136                .filter(|x| !used_fields.contains_key(x))
137                .collect();
138            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
139        } else {
140            Ok(())
141        }
142    }
143
144    fn check_constructor(
145        &mut self,
146        expr: &fhir::Expr,
147        field_exprs: &[fhir::FieldExpr],
148        spread: &Option<&fhir::Spread>,
149        expected: &rty::Sort,
150    ) -> Result {
151        let expected = self.resolve_vars_if_possible(expected);
152        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
153            self.wfckresults
154                .record_ctors_mut()
155                .insert(expr.fhir_id, sort_def.did());
156            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
157            Ok(())
158        } else {
159            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
160        }
161    }
162
163    fn check_record(
164        &mut self,
165        arg: &fhir::Expr,
166        flds: &[fhir::Expr],
167        expected: &rty::Sort,
168    ) -> Result {
169        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
170            let sorts = sort_def.struct_variant().field_sorts(sort_args);
171            if flds.len() != sorts.len() {
172                return Err(self.emit_err(errors::ArgCountMismatch::new(
173                    Some(arg.span),
174                    String::from("type"),
175                    sorts.len(),
176                    flds.len(),
177                )));
178            }
179            self.wfckresults
180                .record_ctors_mut()
181                .insert(arg.fhir_id, sort_def.did());
182
183            izip!(flds, &sorts)
184                .map(|(arg, expected)| self.check_expr(arg, expected))
185                .try_collect_exhaust()
186        } else {
187            Err(self.emit_err(errors::ArgCountMismatch::new(
188                Some(arg.span),
189                String::from("type"),
190                1,
191                flds.len(),
192            )))
193        }
194    }
195
196    pub(super) fn check_expr(&mut self, expr: &fhir::Expr, expected: &rty::Sort) -> Result {
197        match &expr.kind {
198            fhir::ExprKind::IfThenElse(p, e1, e2) => {
199                self.check_expr(p, &rty::Sort::Bool)?;
200                self.check_expr(e1, expected)?;
201                self.check_expr(e2, expected)?;
202            }
203            fhir::ExprKind::Abs(params, body) => {
204                self.check_abs(expr, params, body, expected)?;
205            }
206            fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
207            fhir::ExprKind::Constructor(None, exprs, spread) => {
208                self.check_constructor(expr, exprs, spread, expected)?;
209            }
210            fhir::ExprKind::UnaryOp(..)
211            | fhir::ExprKind::BinaryOp(..)
212            | fhir::ExprKind::Dot(..)
213            | fhir::ExprKind::App(..)
214            | fhir::ExprKind::Alias(..)
215            | fhir::ExprKind::Var(..)
216            | fhir::ExprKind::Literal(..)
217            | fhir::ExprKind::BoundedQuant(..)
218            | fhir::ExprKind::Block(..)
219            | fhir::ExprKind::Constructor(..)
220            | fhir::ExprKind::PrimApp(..) => {
221                let found = self.synth_expr(expr)?;
222                let found = self.resolve_vars_if_possible(&found);
223                let expected = self.resolve_vars_if_possible(expected);
224                if !self.is_coercible(&found, &expected, expr.fhir_id) {
225                    return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
226                }
227            }
228            fhir::ExprKind::Err(_) => {
229                // an error has already been reported so we can just skip
230            }
231        }
232        Ok(())
233    }
234
235    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
236        let found = self.synth_path(loc)?;
237        if found == rty::Sort::Loc {
238            Ok(())
239        } else {
240            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
241        }
242    }
243
244    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr) -> rty::Sort {
245        match lit {
246            fhir::Lit::Int(_) => {
247                let sort = self.next_num_var();
248                self.sort_of_literal
249                    .insert(expr.fhir_id, (sort.clone(), expr.span));
250                sort
251            }
252            fhir::Lit::Bool(_) => rty::Sort::Bool,
253            fhir::Lit::Real(_) => rty::Sort::Real,
254            fhir::Lit::Str(_) => rty::Sort::Str,
255            fhir::Lit::Char(_) => rty::Sort::Char,
256        }
257    }
258
259    fn synth_prim_app(
260        &mut self,
261        op: &fhir::BinOp,
262        e1: &fhir::Expr,
263        e2: &fhir::Expr,
264        span: Span,
265    ) -> Result<rty::Sort> {
266        let Some((inputs, output)) = prim_op_sort(op) else {
267            return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
268        };
269        let [sort1, sort2] = &inputs[..] else {
270            return Err(self.emit_err(errors::ArgCountMismatch::new(
271                Some(span),
272                String::from("primop app"),
273                inputs.len(),
274                2,
275            )));
276        };
277        self.check_expr(e1, sort1)?;
278        self.check_expr(e2, sort2)?;
279        Ok(output)
280    }
281
282    fn synth_expr(&mut self, expr: &fhir::Expr) -> Result<rty::Sort> {
283        match expr.kind {
284            fhir::ExprKind::Var(var, _) => self.synth_path(&var),
285            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
286            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
287            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
288            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
289            fhir::ExprKind::App(callee, args) => {
290                let sort = self.ensure_resolved_path(&callee)?;
291                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
292                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
293                };
294                let fsort = self.instantiate_func_sort(poly_fsort);
295                self.synth_app(fsort, args, expr.span)
296            }
297            fhir::ExprKind::BoundedQuant(.., body) => {
298                self.check_expr(body, &rty::Sort::Bool)?;
299                Ok(rty::Sort::Bool)
300            }
301            fhir::ExprKind::Alias(_alias_reft, args) => {
302                // To check the application we only need the sort of `_alias_reft` which we collected
303                // during early conv, but should we do any extra checks on `_alias_reft`?
304                let fsort = self.sort_of_alias_reft(expr.fhir_id);
305                self.synth_app(fsort, args, expr.span)
306            }
307            fhir::ExprKind::IfThenElse(p, e1, e2) => {
308                self.check_expr(p, &rty::Sort::Bool)?;
309                let sort = self.synth_expr(e1)?;
310                self.check_expr(e2, &sort)?;
311                Ok(sort)
312            }
313            fhir::ExprKind::Dot(base, fld) => {
314                let sort = self.synth_expr(base)?;
315                let sort = self
316                    .fully_resolve(&sort)
317                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
318                match &sort {
319                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
320                        let (proj, sort) = sort_def
321                            .struct_variant()
322                            .field_by_name(sort_def.did(), sort_args, fld.name)
323                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
324                        self.wfckresults
325                            .field_projs_mut()
326                            .insert(expr.fhir_id, proj);
327                        Ok(sort)
328                    }
329                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
330                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
331                    }
332                    _ => Err(self.emit_field_not_found(&sort, fld)),
333                }
334            }
335            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
336                // if we have S then we can synthesize otherwise we fail
337                // first get the sort based on the path - for example S { ... } => S
338                // and we should expect sort to be a struct or enum app
339                let path_def_id = match path.res {
340                    ExprRes::Ctor(def_id) => def_id,
341                    _ => span_bug!(expr.span, "unexpected path in constructor"),
342                };
343                let sort_def = self
344                    .genv
345                    .adt_sort_def_of(path_def_id)
346                    .map_err(|e| self.emit_err(e))?;
347                // generate fresh inferred sorts for each param
348                let fresh_args: rty::List<_> = (0..sort_def.param_count())
349                    .map(|_| self.next_sort_var())
350                    .collect();
351                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
352                // check fields & spread against fresh args
353                self.check_field_exprs(
354                    expr.span,
355                    &sort_def,
356                    &fresh_args,
357                    field_exprs,
358                    &spread,
359                    &sort,
360                )?;
361                Ok(sort)
362            }
363            fhir::ExprKind::Block(decls, body) => {
364                for decl in decls {
365                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
366                }
367                self.synth_expr(body)
368            }
369            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
370        }
371    }
372
373    fn synth_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
374        match path.res {
375            ExprRes::Param(_, id) => Ok(self.param_sort(id)),
376            ExprRes::Const(def_id) => {
377                if let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? {
378                    let info = self.genv.constant_info(def_id).emit(&self.genv)?;
379                    // non-integral constant
380                    if sort != rty::Sort::Int && matches!(info, rty::ConstantInfo::Uninterpreted) {
381                        Err(self.emit_err(errors::ConstantAnnotationNeeded::new(path.span)))?;
382                    }
383                    Ok(sort)
384                } else {
385                    span_bug!(path.span, "unexpected const")
386                }
387            }
388            ExprRes::Variant(def_id) => {
389                let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? else {
390                    span_bug!(path.span, "unexpected variant {def_id:?}")
391                };
392                Ok(sort)
393            }
394            ExprRes::ConstGeneric(_) => Ok(rty::Sort::Int), // TODO: generalize generic-const sorts
395            ExprRes::NumConst(_) => Ok(rty::Sort::Int),
396            ExprRes::GlobalFunc(SpecFuncKind::Def(name) | SpecFuncKind::Uif(name)) => {
397                Ok(rty::Sort::Func(self.genv.func_sort(name)))
398            }
399            ExprRes::GlobalFunc(SpecFuncKind::Thy(itf)) => {
400                Ok(rty::Sort::Func(THEORY_FUNCS.get(&itf).unwrap().sort.clone()))
401            }
402            ExprRes::GlobalFunc(SpecFuncKind::CharToInt) => {
403                Ok(rty::Sort::Func(rty::PolyFuncSort::new(
404                    List::empty(),
405                    rty::FuncSort::new(vec![rty::Sort::Char], rty::Sort::Int),
406                )))
407            }
408            ExprRes::Ctor(_) => {
409                span_bug!(path.span, "unexpected constructor in var position")
410            }
411        }
412    }
413
414    fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
415        if matches!(op, fhir::BinOp::Mod) {
416            let sort = self
417                .fully_resolve(sort)
418                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
419            if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
420                span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
421            }
422        }
423        Ok(())
424    }
425
426    fn synth_binary_op(
427        &mut self,
428        expr: &fhir::Expr,
429        op: fhir::BinOp,
430        e1: &fhir::Expr,
431        e2: &fhir::Expr,
432    ) -> Result<rty::Sort> {
433        match op {
434            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
435                self.check_expr(e1, &rty::Sort::Bool)?;
436                self.check_expr(e2, &rty::Sort::Bool)?;
437                Ok(rty::Sort::Bool)
438            }
439            fhir::BinOp::Eq | fhir::BinOp::Ne => {
440                let sort = self.next_sort_var();
441                self.check_expr(e1, &sort)?;
442                self.check_expr(e2, &sort)?;
443                Ok(rty::Sort::Bool)
444            }
445            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
446                let sort = self.next_sort_var();
447                self.check_expr(e1, &sort)?;
448                self.check_expr(e2, &sort)?;
449                self.sort_of_bin_op
450                    .insert(expr.fhir_id, (sort.clone(), expr.span));
451                Ok(rty::Sort::Bool)
452            }
453            fhir::BinOp::Add
454            | fhir::BinOp::Sub
455            | fhir::BinOp::Mul
456            | fhir::BinOp::Div
457            | fhir::BinOp::Mod => {
458                let sort = self.next_num_var();
459                self.check_expr(e1, &sort)?;
460                self.check_expr(e2, &sort)?;
461                self.sort_of_bin_op
462                    .insert(expr.fhir_id, (sort.clone(), expr.span));
463                // check that the sort is integral for mod
464                self.check_integral(op, &sort, expr.span)?;
465
466                Ok(sort)
467            }
468            fhir::BinOp::BitAnd
469            | fhir::BinOp::BitOr
470            | fhir::BinOp::BitShl
471            | fhir::BinOp::BitShr => {
472                let sort = rty::Sort::BitVec(self.next_bv_size_var());
473                self.check_expr(e1, &sort)?;
474                self.check_expr(e2, &sort)?;
475                Ok(sort)
476            }
477        }
478    }
479
480    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr) -> Result<rty::Sort> {
481        match op {
482            fhir::UnOp::Not => {
483                self.check_expr(e, &rty::Sort::Bool)?;
484                Ok(rty::Sort::Bool)
485            }
486            fhir::UnOp::Neg => {
487                self.check_expr(e, &rty::Sort::Int)?;
488                Ok(rty::Sort::Int)
489            }
490        }
491    }
492
493    fn synth_app(&mut self, fsort: FuncSort, args: &[fhir::Expr], span: Span) -> Result<rty::Sort> {
494        if args.len() != fsort.inputs().len() {
495            return Err(self.emit_err(errors::ArgCountMismatch::new(
496                Some(span),
497                String::from("function"),
498                fsort.inputs().len(),
499                args.len(),
500            )));
501        }
502
503        iter::zip(args, fsort.inputs())
504            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
505
506        Ok(fsort.output().clone())
507    }
508
509    fn instantiate_func_sort(&mut self, fsort: rty::PolyFuncSort) -> rty::FuncSort {
510        let args = fsort
511            .params()
512            .map(|kind| {
513                match kind {
514                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
515                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
516                }
517            })
518            .collect_vec();
519        fsort.instantiate(&args)
520    }
521
522    pub(crate) fn insert_sort_for_bty(&mut self, fhir_id: FhirId, sort: rty::Sort) {
523        self.sort_of_bty.insert(fhir_id, sort);
524    }
525
526    pub(crate) fn sort_of_bty(&self, fhir_id: FhirId) -> rty::Sort {
527        self.sort_of_bty
528            .get(&fhir_id)
529            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
530            .clone()
531    }
532
533    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
534        self.path_args.insert(fhir_id, args);
535    }
536
537    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
538        self.path_args
539            .get(&fhir_id)
540            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
541            .clone()
542    }
543
544    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
545        self.sort_of_alias_reft.insert(fhir_id, fsort);
546    }
547
548    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
549        self.sort_of_alias_reft
550            .get(&fhir_id)
551            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
552            .clone()
553    }
554
555    fn normalize_projection_sort(
556        genv: GlobalEnv,
557        owner: FluxOwnerId,
558        sort: rty::Sort,
559    ) -> rty::Sort {
560        let infcx = genv
561            .tcx()
562            .infer_ctxt()
563            .with_next_trait_solver(true)
564            .build(TypingMode::non_body_analysis());
565        if let Some(def_id) = owner.def_id()
566            && let def_id = genv.maybe_extern_id(def_id).resolved_id()
567            && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
568        {
569            sort
570        } else {
571            sort
572        }
573    }
574
575    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
576    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
577    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
578    // expanding aliases in `ConvCtxt::conv_sort`.
579    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
580        let genv = self.genv;
581        for sort in self.sort_of_bty.values_mut() {
582            *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
583        }
584        for fsort in self.sort_of_alias_reft.values_mut() {
585            *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
586        }
587        Ok(())
588    }
589}
590
591impl<'genv> InferCtxt<'genv, '_> {
592    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
593        self.params.insert(param.id, (param, sort));
594    }
595
596    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
597    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
598    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
599    /// a value into a record.
600    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
601        if self.try_equate(sort1, sort2).is_some() {
602            return true;
603        }
604
605        let mut sort1 = sort1.clone();
606        let mut sort2 = sort2.clone();
607
608        let mut coercions = vec![];
609        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
610            coercions.push(rty::Coercion::Project(def_id));
611            sort1 = sort.clone();
612        }
613        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
614            coercions.push(rty::Coercion::Inject(def_id));
615            sort2 = sort.clone();
616        }
617        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
618        self.try_equate(&sort1, &sort2).is_some()
619    }
620
621    fn is_coercible_from_func(
622        &mut self,
623        sort: &rty::Sort,
624        fhir_id: FhirId,
625    ) -> Option<rty::PolyFuncSort> {
626        if let rty::Sort::Func(fsort) = sort {
627            Some(fsort.clone())
628        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
629            self.wfckresults
630                .coercions_mut()
631                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
632            Some(fsort.clone())
633        } else {
634            None
635        }
636    }
637
638    fn is_coercible_to_func(
639        &mut self,
640        sort: &rty::Sort,
641        fhir_id: FhirId,
642    ) -> Option<rty::PolyFuncSort> {
643        if let rty::Sort::Func(fsort) = sort {
644            Some(fsort.clone())
645        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
646            self.wfckresults
647                .coercions_mut()
648                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
649            Some(fsort.clone())
650        } else {
651            None
652        }
653    }
654
655    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
656        let sort1 = self.resolve_vars_if_possible(sort1);
657        let sort2 = self.resolve_vars_if_possible(sort2);
658        self.try_equate_inner(&sort1, &sort2)
659    }
660
661    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
662        match (sort1, sort2) {
663            (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
664                self.sort_unification_table
665                    .unify_var_var(*vid1, *vid2)
666                    .ok()?;
667            }
668            (rty::Sort::Infer(rty::SortVar(vid)), sort)
669            | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
670                self.sort_unification_table
671                    .unify_var_value(*vid, Some(sort.clone()))
672                    .ok()?;
673            }
674            (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
675                self.num_unification_table
676                    .unify_var_var(*vid1, *vid2)
677                    .ok()?;
678            }
679            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
680            | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
681                self.num_unification_table
682                    .unify_var_value(*vid, Some(rty::NumVarValue::Int))
683                    .ok()?;
684            }
685            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
686            | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
687                self.num_unification_table
688                    .unify_var_value(*vid, Some(rty::NumVarValue::Real))
689                    .ok()?;
690            }
691            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
692            | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
693                self.num_unification_table
694                    .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
695                    .ok()?;
696            }
697
698            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
699                if ctor1 != ctor2 || args1.len() != args2.len() {
700                    return None;
701                }
702                let mut args = vec![];
703                for (s1, s2) in args1.iter().zip(args2.iter()) {
704                    args.push(self.try_equate_inner(s1, s2)?);
705                }
706            }
707            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
708                self.try_equate_bv_sizes(*size1, *size2)?;
709            }
710            _ if sort1 == sort2 => {}
711            _ => return None,
712        }
713        Some(sort1.clone())
714    }
715
716    fn try_equate_bv_sizes(
717        &mut self,
718        size1: rty::BvSize,
719        size2: rty::BvSize,
720    ) -> Option<rty::BvSize> {
721        match (size1, size2) {
722            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
723                self.bv_size_unification_table
724                    .unify_var_var(vid1, vid2)
725                    .ok()?;
726            }
727            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
728                self.bv_size_unification_table
729                    .unify_var_value(vid, Some(size))
730                    .ok()?;
731            }
732            _ if size1 == size2 => {}
733            _ => return None,
734        }
735        Some(size1)
736    }
737
738    fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
739        self.try_equate(sort1, sort2)
740            .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
741    }
742
743    pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
744        rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
745    }
746
747    fn next_num_var(&mut self) -> rty::Sort {
748        rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
749    }
750
751    pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
752        self.sort_unification_table.new_key(None)
753    }
754
755    fn next_num_vid(&mut self) -> rty::NumVid {
756        self.num_unification_table.new_key(None)
757    }
758
759    fn next_bv_size_var(&mut self) -> rty::BvSize {
760        rty::BvSize::Infer(self.next_bv_size_vid())
761    }
762
763    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
764        self.bv_size_unification_table.new_key(None)
765    }
766
767    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
768        let sort = self.synth_path(path)?;
769        self.fully_resolve(&sort)
770            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
771    }
772
773    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
774        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
775            && let Some(variant) = sort_def.opt_struct_variant()
776            && let [sort] = &variant.field_sorts(sort_args)[..]
777        {
778            Some((sort_def.did(), sort.clone()))
779        } else {
780            None
781        }
782    }
783
784    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
785        // Make sure the int literals are fully resolved
786        for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_literal) {
787            let sort = self
788                .fully_resolve(&sort)
789                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
790            self.wfckresults.node_sorts_mut().insert(fhir_id, sort);
791        }
792
793        // Make sure the binary operators are fully resolved
794        for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_bin_op) {
795            let sort = self
796                .fully_resolve(&sort)
797                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
798            self.wfckresults.bin_op_sorts_mut().insert(fhir_id, sort);
799        }
800
801        // Make sure all parameters are fully resolved
802        for (_, (param, sort)) in std::mem::take(&mut self.params) {
803            let sort = self
804                .fully_resolve(&sort)
805                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
806            self.wfckresults
807                .node_sorts_mut()
808                .insert(param.fhir_id, sort);
809        }
810
811        Ok(self.wfckresults)
812    }
813
814    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
815        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
816    }
817
818    #[track_caller]
819    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
820        self.params
821            .get(&id)
822            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
823            .1
824            .clone()
825    }
826
827    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
828        sort.fold_with(&mut ShallowResolver { infcx: self })
829    }
830
831    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
832        sort.fold_with(&mut OpportunisticResolver { infcx: self })
833    }
834
835    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
836        sort.try_fold_with(&mut FullResolver { infcx: self })
837    }
838}
839
840/// Before the main sort inference, we do a first traversal checking all implicitly scoped parameters
841/// declared with `@` or `#` and infer their sort based on the type they are indexing, e.g., if `n` was
842/// declared as `i32[@n]`, we infer `int` for `n`.
843///
844/// This prepass is necessary because sometimes the order in which we traverse expressions can
845/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
846/// complete inference regardless of how expressions are later traversed.
847///
848/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
849/// function position) such that we don't need this anymore.
850pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
851    infcx: &'a mut InferCtxt<'genv, 'tcx>,
852    errors: Errors<'genv>,
853}
854
855impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
856    pub(crate) fn infer(
857        infcx: &'a mut InferCtxt<'genv, 'tcx>,
858        node: &fhir::OwnerNode<'genv>,
859    ) -> Result {
860        let errors = Errors::new(infcx.genv.sess());
861        let mut vis = Self { infcx, errors };
862        vis.visit_node(node);
863        vis.errors.into_result()
864    }
865
866    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
867        match idx.kind {
868            fhir::ExprKind::Var(var, Some(_)) => {
869                let (_, id) = var.res.expect_param();
870                let found = self.infcx.param_sort(id);
871                self.infcx.equate(&found, expected);
872            }
873            fhir::ExprKind::Record(flds) => {
874                if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
875                    let sorts = sort_def.struct_variant().field_sorts(sort_args);
876                    if flds.len() != sorts.len() {
877                        self.errors.emit(errors::ArgCountMismatch::new(
878                            Some(idx.span),
879                            String::from("type"),
880                            sorts.len(),
881                            flds.len(),
882                        ));
883                        return;
884                    }
885                    for (f, sort) in iter::zip(flds, &sorts) {
886                        self.infer_implicit_params(f, sort);
887                    }
888                } else {
889                    self.errors.emit(errors::ArgCountMismatch::new(
890                        Some(idx.span),
891                        String::from("type"),
892                        1,
893                        flds.len(),
894                    ));
895                }
896            }
897            _ => {}
898        }
899    }
900}
901
902impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
903    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
904        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
905            let expected = self.infcx.sort_of_bty(bty.fhir_id);
906            self.infer_implicit_params(idx, &expected);
907        }
908        fhir::visit::walk_ty(self, ty);
909    }
910}
911
912impl InferCtxt<'_, '_> {
913    #[track_caller]
914    fn emit_sort_mismatch(
915        &mut self,
916        span: Span,
917        expected: &rty::Sort,
918        found: &rty::Sort,
919    ) -> ErrorGuaranteed {
920        let expected = self.resolve_vars_if_possible(expected);
921        let found = self.resolve_vars_if_possible(found);
922        self.emit_err(errors::SortMismatch::new(span, expected, found))
923    }
924
925    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
926        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
927    }
928
929    #[track_caller]
930    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
931        self.genv.sess().emit_err(err)
932    }
933}
934
935struct ShallowResolver<'a, 'genv, 'tcx> {
936    infcx: &'a mut InferCtxt<'genv, 'tcx>,
937}
938
939impl TypeFolder for ShallowResolver<'_, '_, '_> {
940    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
941        match sort {
942            rty::Sort::Infer(rty::SortVar(vid)) => {
943                // if `sort` is a sort variable, it can be resolved to an num/bit-vec variable,
944                // which can then be recursively resolved, hence the recursion. Note though that
945                // we prevent sort variables from unifying to other sort variables directly (though
946                // they may be embedded structurally), so this recursion should always be of very
947                // limited depth.
948                self.infcx
949                    .sort_unification_table
950                    .probe_value(*vid)
951                    .map(|sort| sort.fold_with(self))
952                    .unwrap_or(sort.clone())
953            }
954            rty::Sort::Infer(rty::NumVar(vid)) => {
955                // same here, a num var could had been unified with a bitvector
956                self.infcx
957                    .num_unification_table
958                    .probe_value(*vid)
959                    .map(|val| val.to_sort().fold_with(self))
960                    .unwrap_or(sort.clone())
961            }
962            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
963                self.infcx
964                    .bv_size_unification_table
965                    .probe_value(*vid)
966                    .map(rty::Sort::BitVec)
967                    .unwrap_or(sort.clone())
968            }
969            _ => sort.clone(),
970        }
971    }
972}
973
974struct OpportunisticResolver<'a, 'genv, 'tcx> {
975    infcx: &'a mut InferCtxt<'genv, 'tcx>,
976}
977
978impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
979    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
980        let s = self.infcx.shallow_resolve(sort);
981        s.super_fold_with(self)
982    }
983}
984
985struct FullResolver<'a, 'genv, 'tcx> {
986    infcx: &'a mut InferCtxt<'genv, 'tcx>,
987}
988
989impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
990    type Error = ();
991
992    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
993        let s = self.infcx.shallow_resolve(sort);
994        match s {
995            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
996            _ => s.try_super_fold_with(self),
997        }
998    }
999}