// flux_fhir_analysis/wf/sortck.rs

1use std::iter;
2
3use ena::unify::InPlaceUnificationTable;
4use flux_common::{bug, iter::IterExt, result::ResultExt, span_bug, tracked_span_bug};
5use flux_errors::{ErrorGuaranteed, Errors};
6use flux_infer::projections::NormalizeExt;
7use flux_middle::{
8    THEORY_FUNCS,
9    fhir::{self, ExprRes, FhirId, FluxOwnerId, visit::Visitor as _},
10    global_env::GlobalEnv,
11    queries::QueryResult,
12    rty::{
13        self, AdtSortDef, FuncSort, List, WfckResults,
14        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15    },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::{FxHashMap, FxHashSet};
21use rustc_middle::ty::TypingMode;
22use rustc_span::{Span, def_id::DefId, symbol::Ident};
23
24use super::errors;
25use crate::rustc_infer::infer::TyCtxtInferExt;
26
27type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
28
29pub(super) struct InferCtxt<'genv, 'tcx> {
30    pub genv: GlobalEnv<'genv, 'tcx>,
31    pub owner: FluxOwnerId,
32    pub wfckresults: WfckResults,
33    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
34    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
35    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
36    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
37    sort_of_bty: FxHashMap<FhirId, rty::Sort>,
38    sort_of_literal: FxHashMap<FhirId, (rty::Sort, Span)>,
39    sort_of_bin_op: FxHashMap<FhirId, (rty::Sort, Span)>,
40    sort_args_of_app: FxHashMap<FhirId, (List<rty::SortArg>, Span)>,
41    casts: FxHashSet<FhirId>,
42    path_args: UnordMap<FhirId, rty::GenericArgs>,
43    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
44}
45
46pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
47    match op {
48        fhir::BinOp::BitAnd | fhir::BinOp::BitOr | fhir::BinOp::BitShl | fhir::BinOp::BitShr => {
49            Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int))
50        }
51        _ => None,
52    }
53}
54
55impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
56    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
57        // We skip 0 because that's used for sort dummy self types during conv.
58        let mut sort_unification_table = InPlaceUnificationTable::new();
59        sort_unification_table.new_key(None);
60        Self {
61            genv,
62            owner,
63            wfckresults: WfckResults::new(owner),
64            sort_unification_table,
65            num_unification_table: InPlaceUnificationTable::new(),
66            bv_size_unification_table: InPlaceUnificationTable::new(),
67            params: Default::default(),
68            sort_of_bty: Default::default(),
69            sort_of_literal: Default::default(),
70            sort_of_bin_op: Default::default(),
71            path_args: Default::default(),
72            sort_of_alias_reft: Default::default(),
73            sort_args_of_app: Default::default(),
74            casts: Default::default(),
75        }
76    }
77
78    fn check_abs(
79        &mut self,
80        expr: &fhir::Expr,
81        params: &[fhir::RefineParam],
82        body: &fhir::Expr,
83        expected: &rty::Sort,
84    ) -> Result {
85        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
86            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
87        };
88
89        let fsort = fsort.expect_mono();
90
91        if params.len() != fsort.inputs().len() {
92            return Err(self.emit_err(errors::ParamCountMismatch::new(
93                expr.span,
94                fsort.inputs().len(),
95                params.len(),
96            )));
97        }
98        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
99            let found = self.param_sort(param.id);
100            if self.try_equate(&found, expected).is_none() {
101                return Err(self.emit_sort_mismatch(param.span, expected, &found));
102            }
103            Ok(())
104        })?;
105        self.check_expr(body, fsort.output())?;
106        self.wfckresults
107            .node_sorts_mut()
108            .insert(body.fhir_id, fsort.output().clone());
109        Ok(())
110    }
111
112    fn check_field_exprs(
113        &mut self,
114        expr_span: Span,
115        sort_def: &AdtSortDef,
116        sort_args: &[rty::Sort],
117        field_exprs: &[fhir::FieldExpr],
118        spread: &Option<&fhir::Spread>,
119        expected: &rty::Sort,
120    ) -> Result {
121        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
122        let mut used_fields = FxHashMap::default();
123        for expr in field_exprs {
124            // make sure that the field is actually a field
125            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
126                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
127            };
128            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
129                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
130            }
131            self.check_expr(&expr.expr, sort)?;
132        }
133        if let Some(spread) = spread {
134            // must check that the spread is of the same sort as the constructor
135            self.check_expr(&spread.expr, expected)
136        } else if sort_by_field_name.len() != used_fields.len() {
137            // emit an error because all fields are not used
138            let missing_fields = sort_by_field_name
139                .into_keys()
140                .filter(|x| !used_fields.contains_key(x))
141                .collect();
142            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
143        } else {
144            Ok(())
145        }
146    }
147
148    fn check_constructor(
149        &mut self,
150        expr: &fhir::Expr,
151        field_exprs: &[fhir::FieldExpr],
152        spread: &Option<&fhir::Spread>,
153        expected: &rty::Sort,
154    ) -> Result {
155        let expected = self.resolve_vars_if_possible(expected);
156        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
157            self.wfckresults
158                .record_ctors_mut()
159                .insert(expr.fhir_id, sort_def.did());
160            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
161            Ok(())
162        } else {
163            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
164        }
165    }
166
167    fn check_record(
168        &mut self,
169        arg: &fhir::Expr,
170        flds: &[fhir::Expr],
171        expected: &rty::Sort,
172    ) -> Result {
173        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
174            let sorts = sort_def.struct_variant().field_sorts(sort_args);
175            if flds.len() != sorts.len() {
176                return Err(self.emit_err(errors::ArgCountMismatch::new(
177                    Some(arg.span),
178                    String::from("type"),
179                    sorts.len(),
180                    flds.len(),
181                )));
182            }
183            self.wfckresults
184                .record_ctors_mut()
185                .insert(arg.fhir_id, sort_def.did());
186
187            izip!(flds, &sorts)
188                .map(|(arg, expected)| self.check_expr(arg, expected))
189                .try_collect_exhaust()
190        } else {
191            Err(self.emit_err(errors::ArgCountMismatch::new(
192                Some(arg.span),
193                String::from("type"),
194                1,
195                flds.len(),
196            )))
197        }
198    }
199
200    pub(super) fn check_expr(&mut self, expr: &fhir::Expr, expected: &rty::Sort) -> Result {
201        match &expr.kind {
202            fhir::ExprKind::IfThenElse(p, e1, e2) => {
203                self.check_expr(p, &rty::Sort::Bool)?;
204                self.check_expr(e1, expected)?;
205                self.check_expr(e2, expected)?;
206            }
207            fhir::ExprKind::Abs(params, body) => {
208                self.check_abs(expr, params, body, expected)?;
209            }
210            fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
211            fhir::ExprKind::Constructor(None, exprs, spread) => {
212                self.check_constructor(expr, exprs, spread, expected)?;
213            }
214            fhir::ExprKind::UnaryOp(..)
215            | fhir::ExprKind::BinaryOp(..)
216            | fhir::ExprKind::Dot(..)
217            | fhir::ExprKind::App(..)
218            | fhir::ExprKind::Alias(..)
219            | fhir::ExprKind::Var(..)
220            | fhir::ExprKind::Literal(..)
221            | fhir::ExprKind::BoundedQuant(..)
222            | fhir::ExprKind::Block(..)
223            | fhir::ExprKind::Constructor(..)
224            | fhir::ExprKind::PrimApp(..) => {
225                let found = self.synth_expr(expr)?;
226                let found = self.resolve_vars_if_possible(&found);
227                let expected = self.resolve_vars_if_possible(expected);
228                if !self.is_coercible(&found, &expected, expr.fhir_id) {
229                    return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
230                }
231            }
232            fhir::ExprKind::Err(_) => {
233                // an error has already been reported so we can just skip
234            }
235        }
236        Ok(())
237    }
238
239    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
240        let found = self.synth_path(loc)?;
241        if found == rty::Sort::Loc {
242            Ok(())
243        } else {
244            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
245        }
246    }
247
248    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr) -> rty::Sort {
249        match lit {
250            fhir::Lit::Int(_) => {
251                let sort = self.next_num_var();
252                self.sort_of_literal
253                    .insert(expr.fhir_id, (sort.clone(), expr.span));
254                sort
255            }
256            fhir::Lit::Bool(_) => rty::Sort::Bool,
257            fhir::Lit::Real(_) => rty::Sort::Real,
258            fhir::Lit::Str(_) => rty::Sort::Str,
259            fhir::Lit::Char(_) => rty::Sort::Char,
260        }
261    }
262
263    fn synth_prim_app(
264        &mut self,
265        op: &fhir::BinOp,
266        e1: &fhir::Expr,
267        e2: &fhir::Expr,
268        span: Span,
269    ) -> Result<rty::Sort> {
270        let Some((inputs, output)) = prim_op_sort(op) else {
271            return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
272        };
273        let [sort1, sort2] = &inputs[..] else {
274            return Err(self.emit_err(errors::ArgCountMismatch::new(
275                Some(span),
276                String::from("primop app"),
277                inputs.len(),
278                2,
279            )));
280        };
281        self.check_expr(e1, sort1)?;
282        self.check_expr(e2, sort2)?;
283        Ok(output)
284    }
285
286    fn synth_expr(&mut self, expr: &fhir::Expr) -> Result<rty::Sort> {
287        match expr.kind {
288            fhir::ExprKind::Var(var, _) => self.synth_path(&var),
289            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
290            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
291            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
292            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
293            fhir::ExprKind::App(callee, args) => {
294                let sort = self.ensure_resolved_path(&callee)?;
295                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
296                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
297                };
298                let fhir_id = expr.fhir_id;
299                let span = expr.span;
300                let fsort = self.instantiate_func_sort(poly_fsort, fhir_id, span);
301                if matches!(callee.res, ExprRes::GlobalFunc(fhir::SpecFuncKind::Cast)) {
302                    self.casts.insert(fhir_id);
303                }
304                self.synth_app(fsort, args, expr.span)
305            }
306            fhir::ExprKind::BoundedQuant(.., body) => {
307                self.check_expr(body, &rty::Sort::Bool)?;
308                Ok(rty::Sort::Bool)
309            }
310            fhir::ExprKind::Alias(_alias_reft, args) => {
311                // To check the application we only need the sort of `_alias_reft` which we collected
312                // during early conv, but should we do any extra checks on `_alias_reft`?
313                let fsort = self.sort_of_alias_reft(expr.fhir_id);
314                self.synth_app(fsort, args, expr.span)
315            }
316            fhir::ExprKind::IfThenElse(p, e1, e2) => {
317                self.check_expr(p, &rty::Sort::Bool)?;
318                let sort = self.synth_expr(e1)?;
319                self.check_expr(e2, &sort)?;
320                Ok(sort)
321            }
322            fhir::ExprKind::Dot(base, fld) => {
323                let sort = self.synth_expr(base)?;
324                let sort = self
325                    .fully_resolve(&sort)
326                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
327                match &sort {
328                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
329                        let (proj, sort) = sort_def
330                            .struct_variant()
331                            .field_by_name(sort_def.did(), sort_args, fld.name)
332                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
333                        self.wfckresults
334                            .field_projs_mut()
335                            .insert(expr.fhir_id, proj);
336                        Ok(sort)
337                    }
338                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
339                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
340                    }
341                    _ => Err(self.emit_field_not_found(&sort, fld)),
342                }
343            }
344            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
345                // if we have S then we can synthesize otherwise we fail
346                // first get the sort based on the path - for example S { ... } => S
347                // and we should expect sort to be a struct or enum app
348                let path_def_id = match path.res {
349                    ExprRes::Ctor(def_id) => def_id,
350                    _ => span_bug!(expr.span, "unexpected path in constructor"),
351                };
352                let sort_def = self
353                    .genv
354                    .adt_sort_def_of(path_def_id)
355                    .map_err(|e| self.emit_err(e))?;
356                // generate fresh inferred sorts for each param
357                let fresh_args: rty::List<_> = (0..sort_def.param_count())
358                    .map(|_| self.next_sort_var())
359                    .collect();
360                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
361                // check fields & spread against fresh args
362                self.check_field_exprs(
363                    expr.span,
364                    &sort_def,
365                    &fresh_args,
366                    field_exprs,
367                    &spread,
368                    &sort,
369                )?;
370                Ok(sort)
371            }
372            fhir::ExprKind::Block(decls, body) => {
373                for decl in decls {
374                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
375                }
376                self.synth_expr(body)
377            }
378            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
379        }
380    }
381
382    fn synth_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
383        match path.res {
384            ExprRes::Param(_, id) => Ok(self.param_sort(id)),
385            ExprRes::Const(def_id) => {
386                if let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? {
387                    let info = self.genv.constant_info(def_id).emit(&self.genv)?;
388                    // non-integral constant
389                    if sort != rty::Sort::Int && matches!(info, rty::ConstantInfo::Uninterpreted) {
390                        Err(self.emit_err(errors::ConstantAnnotationNeeded::new(path.span)))?;
391                    }
392                    Ok(sort)
393                } else {
394                    span_bug!(path.span, "unexpected const")
395                }
396            }
397            ExprRes::Variant(def_id) => {
398                let Some(sort) = self.genv.sort_of_def_id(def_id).emit(&self.genv)? else {
399                    span_bug!(path.span, "unexpected variant {def_id:?}")
400                };
401                Ok(sort)
402            }
403            ExprRes::ConstGeneric(_) => Ok(rty::Sort::Int), // TODO: generalize generic-const sorts
404            ExprRes::NumConst(_) => Ok(rty::Sort::Int),
405            ExprRes::GlobalFunc(spec_func) => {
406                let fsort = match spec_func {
407                    fhir::SpecFuncKind::Def(name) | fhir::SpecFuncKind::Uif(name) => {
408                        self.genv.func_sort(name)
409                    }
410                    fhir::SpecFuncKind::Thy(itf) => THEORY_FUNCS.get(&itf).unwrap().sort.clone(),
411                    fhir::SpecFuncKind::Cast => {
412                        rty::PolyFuncSort::new(
413                            List::from_arr([rty::SortParamKind::Sort, rty::SortParamKind::Sort]),
414                            rty::FuncSort::new(
415                                vec![rty::Sort::Var(rty::ParamSort::from(0_usize))],
416                                rty::Sort::Var(rty::ParamSort::from(1_usize)),
417                            ),
418                        )
419                    }
420                };
421                Ok(rty::Sort::Func(fsort))
422            }
423            ExprRes::Ctor(_) => {
424                span_bug!(path.span, "unexpected constructor in var position")
425            }
426        }
427    }
428
429    fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
430        if matches!(op, fhir::BinOp::Mod) {
431            let sort = self
432                .fully_resolve(sort)
433                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
434            if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
435                span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
436            }
437        }
438        Ok(())
439    }
440
441    fn synth_binary_op(
442        &mut self,
443        expr: &fhir::Expr,
444        op: fhir::BinOp,
445        e1: &fhir::Expr,
446        e2: &fhir::Expr,
447    ) -> Result<rty::Sort> {
448        match op {
449            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
450                self.check_expr(e1, &rty::Sort::Bool)?;
451                self.check_expr(e2, &rty::Sort::Bool)?;
452                Ok(rty::Sort::Bool)
453            }
454            fhir::BinOp::Eq | fhir::BinOp::Ne => {
455                let sort = self.next_sort_var();
456                self.check_expr(e1, &sort)?;
457                self.check_expr(e2, &sort)?;
458                Ok(rty::Sort::Bool)
459            }
460            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
461                let sort = self.next_sort_var();
462                self.check_expr(e1, &sort)?;
463                self.check_expr(e2, &sort)?;
464                self.sort_of_bin_op
465                    .insert(expr.fhir_id, (sort.clone(), expr.span));
466                Ok(rty::Sort::Bool)
467            }
468            fhir::BinOp::Add
469            | fhir::BinOp::Sub
470            | fhir::BinOp::Mul
471            | fhir::BinOp::Div
472            | fhir::BinOp::Mod => {
473                let sort = self.next_num_var();
474                self.check_expr(e1, &sort)?;
475                self.check_expr(e2, &sort)?;
476                self.sort_of_bin_op
477                    .insert(expr.fhir_id, (sort.clone(), expr.span));
478                // check that the sort is integral for mod
479                self.check_integral(op, &sort, expr.span)?;
480
481                Ok(sort)
482            }
483            fhir::BinOp::BitAnd
484            | fhir::BinOp::BitOr
485            | fhir::BinOp::BitShl
486            | fhir::BinOp::BitShr => {
487                let sort = rty::Sort::BitVec(self.next_bv_size_var());
488                self.check_expr(e1, &sort)?;
489                self.check_expr(e2, &sort)?;
490                Ok(sort)
491            }
492        }
493    }
494
495    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr) -> Result<rty::Sort> {
496        match op {
497            fhir::UnOp::Not => {
498                self.check_expr(e, &rty::Sort::Bool)?;
499                Ok(rty::Sort::Bool)
500            }
501            fhir::UnOp::Neg => {
502                self.check_expr(e, &rty::Sort::Int)?;
503                Ok(rty::Sort::Int)
504            }
505        }
506    }
507
508    fn synth_app(&mut self, fsort: FuncSort, args: &[fhir::Expr], span: Span) -> Result<rty::Sort> {
509        if args.len() != fsort.inputs().len() {
510            return Err(self.emit_err(errors::ArgCountMismatch::new(
511                Some(span),
512                String::from("function"),
513                fsort.inputs().len(),
514                args.len(),
515            )));
516        }
517
518        iter::zip(args, fsort.inputs())
519            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
520
521        Ok(fsort.output().clone())
522    }
523
524    fn instantiate_func_sort(
525        &mut self,
526        fsort: rty::PolyFuncSort,
527        fhir_id: FhirId,
528        span: Span,
529    ) -> rty::FuncSort {
530        let args = fsort
531            .params()
532            .map(|kind| {
533                match kind {
534                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
535                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
536                }
537            })
538            .collect_vec();
539        self.sort_args_of_app
540            .insert(fhir_id, (List::from_vec(args.clone()), span));
541        fsort.instantiate(&args)
542    }
543
544    pub(crate) fn insert_sort_for_bty(&mut self, fhir_id: FhirId, sort: rty::Sort) {
545        self.sort_of_bty.insert(fhir_id, sort);
546    }
547
548    pub(crate) fn sort_of_bty(&self, fhir_id: FhirId) -> rty::Sort {
549        self.sort_of_bty
550            .get(&fhir_id)
551            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
552            .clone()
553    }
554
555    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
556        self.path_args.insert(fhir_id, args);
557    }
558
559    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
560        self.path_args
561            .get(&fhir_id)
562            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
563            .clone()
564    }
565
566    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
567        self.sort_of_alias_reft.insert(fhir_id, fsort);
568    }
569
570    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
571        self.sort_of_alias_reft
572            .get(&fhir_id)
573            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
574            .clone()
575    }
576
577    fn normalize_projection_sort(
578        genv: GlobalEnv,
579        owner: FluxOwnerId,
580        sort: rty::Sort,
581    ) -> rty::Sort {
582        let infcx = genv
583            .tcx()
584            .infer_ctxt()
585            .with_next_trait_solver(true)
586            .build(TypingMode::non_body_analysis());
587        if let Some(def_id) = owner.def_id()
588            && let def_id = genv.maybe_extern_id(def_id).resolved_id()
589            && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
590        {
591            sort
592        } else {
593            sort
594        }
595    }
596
597    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
598    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
599    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
600    // expanding aliases in `ConvCtxt::conv_sort`.
601    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
602        let genv = self.genv;
603        for sort in self.sort_of_bty.values_mut() {
604            *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
605        }
606        for fsort in self.sort_of_alias_reft.values_mut() {
607            *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
608        }
609        Ok(())
610    }
611}
612
613impl<'genv> InferCtxt<'genv, '_> {
614    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
615        self.params.insert(param.id, (param, sort));
616    }
617
618    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
619    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
620    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
621    /// a value into a record.
622    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
623        if self.try_equate(sort1, sort2).is_some() {
624            return true;
625        }
626
627        let mut sort1 = sort1.clone();
628        let mut sort2 = sort2.clone();
629
630        let mut coercions = vec![];
631        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
632            coercions.push(rty::Coercion::Project(def_id));
633            sort1 = sort.clone();
634        }
635        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
636            coercions.push(rty::Coercion::Inject(def_id));
637            sort2 = sort.clone();
638        }
639        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
640        self.try_equate(&sort1, &sort2).is_some()
641    }
642
643    fn is_coercible_from_func(
644        &mut self,
645        sort: &rty::Sort,
646        fhir_id: FhirId,
647    ) -> Option<rty::PolyFuncSort> {
648        if let rty::Sort::Func(fsort) = sort {
649            Some(fsort.clone())
650        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
651            self.wfckresults
652                .coercions_mut()
653                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
654            Some(fsort.clone())
655        } else {
656            None
657        }
658    }
659
660    fn is_coercible_to_func(
661        &mut self,
662        sort: &rty::Sort,
663        fhir_id: FhirId,
664    ) -> Option<rty::PolyFuncSort> {
665        if let rty::Sort::Func(fsort) = sort {
666            Some(fsort.clone())
667        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
668            self.wfckresults
669                .coercions_mut()
670                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
671            Some(fsort.clone())
672        } else {
673            None
674        }
675    }
676
677    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
678        let sort1 = self.resolve_vars_if_possible(sort1);
679        let sort2 = self.resolve_vars_if_possible(sort2);
680        self.try_equate_inner(&sort1, &sort2)
681    }
682
683    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
684        match (sort1, sort2) {
685            (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
686                self.sort_unification_table
687                    .unify_var_var(*vid1, *vid2)
688                    .ok()?;
689            }
690            (rty::Sort::Infer(rty::SortVar(vid)), sort)
691            | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
692                self.sort_unification_table
693                    .unify_var_value(*vid, Some(sort.clone()))
694                    .ok()?;
695            }
696            (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
697                self.num_unification_table
698                    .unify_var_var(*vid1, *vid2)
699                    .ok()?;
700            }
701            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
702            | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
703                self.num_unification_table
704                    .unify_var_value(*vid, Some(rty::NumVarValue::Int))
705                    .ok()?;
706            }
707            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
708            | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
709                self.num_unification_table
710                    .unify_var_value(*vid, Some(rty::NumVarValue::Real))
711                    .ok()?;
712            }
713            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
714            | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
715                self.num_unification_table
716                    .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
717                    .ok()?;
718            }
719
720            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
721                if ctor1 != ctor2 || args1.len() != args2.len() {
722                    return None;
723                }
724                let mut args = vec![];
725                for (s1, s2) in args1.iter().zip(args2.iter()) {
726                    args.push(self.try_equate_inner(s1, s2)?);
727                }
728            }
729            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
730                self.try_equate_bv_sizes(*size1, *size2)?;
731            }
732            _ if sort1 == sort2 => {}
733            _ => return None,
734        }
735        Some(sort1.clone())
736    }
737
738    fn try_equate_bv_sizes(
739        &mut self,
740        size1: rty::BvSize,
741        size2: rty::BvSize,
742    ) -> Option<rty::BvSize> {
743        match (size1, size2) {
744            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
745                self.bv_size_unification_table
746                    .unify_var_var(vid1, vid2)
747                    .ok()?;
748            }
749            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
750                self.bv_size_unification_table
751                    .unify_var_value(vid, Some(size))
752                    .ok()?;
753            }
754            _ if size1 == size2 => {}
755            _ => return None,
756        }
757        Some(size1)
758    }
759
760    fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
761        self.try_equate(sort1, sort2)
762            .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
763    }
764
765    pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
766        rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
767    }
768
769    fn next_num_var(&mut self) -> rty::Sort {
770        rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
771    }
772
773    pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
774        self.sort_unification_table.new_key(None)
775    }
776
777    fn next_num_vid(&mut self) -> rty::NumVid {
778        self.num_unification_table.new_key(None)
779    }
780
781    fn next_bv_size_var(&mut self) -> rty::BvSize {
782        rty::BvSize::Infer(self.next_bv_size_vid())
783    }
784
785    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
786        self.bv_size_unification_table.new_key(None)
787    }
788
789    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
790        let sort = self.synth_path(path)?;
791        self.fully_resolve(&sort)
792            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
793    }
794
795    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
796        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
797            && let Some(variant) = sort_def.opt_struct_variant()
798            && let [sort] = &variant.field_sorts(sort_args)[..]
799        {
800            Some((sort_def.did(), sort.clone()))
801        } else {
802            None
803        }
804    }
805
806    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
807        // Make sure the int literals are fully resolved
808        for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_literal) {
809            let sort = self
810                .fully_resolve(&sort)
811                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
812            self.wfckresults.node_sorts_mut().insert(fhir_id, sort);
813        }
814
815        // Make sure the binary operators are fully resolved
816        for (fhir_id, (sort, span)) in std::mem::take(&mut self.sort_of_bin_op) {
817            let sort = self
818                .fully_resolve(&sort)
819                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
820            self.wfckresults.bin_op_sorts_mut().insert(fhir_id, sort);
821        }
822
823        let allow_uninterpreted_cast = self
824            .owner
825            .def_id()
826            .map_or(false, |def_id| self.genv.infer_opts(def_id).allow_uninterpreted_cast);
827
828        // Make sure that function applications are fully resolved
829        for (fhir_id, (sort_args, span)) in std::mem::take(&mut self.sort_args_of_app) {
830            let mut res = vec![];
831            for sort_arg in &sort_args {
832                let sort_arg = if let rty::SortArg::Sort(sort) = sort_arg {
833                    let sort = self
834                        .fully_resolve(sort)
835                        .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
836                    rty::SortArg::Sort(sort)
837                } else {
838                    sort_arg.clone()
839                };
840                res.push(sort_arg);
841            }
842            if self.casts.contains(&fhir_id) {
843                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
844                    span_bug!(span, "invalid sort args!")
845                };
846                if !allow_uninterpreted_cast
847                    && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
848                {
849                    return Err(self.emit_err(errors::InvalidCast::new(span, from, to)));
850                }
851            }
852            self.wfckresults
853                .fn_app_sorts_mut()
854                .insert(fhir_id, res.into());
855        }
856
857        // Make sure all parameters are fully resolved
858        for (_, (param, sort)) in std::mem::take(&mut self.params) {
859            let sort = self
860                .fully_resolve(&sort)
861                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
862            self.wfckresults
863                .node_sorts_mut()
864                .insert(param.fhir_id, sort);
865        }
866
867        Ok(self.wfckresults)
868    }
869
870    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
871        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
872    }
873
874    #[track_caller]
875    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
876        self.params
877            .get(&id)
878            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
879            .1
880            .clone()
881    }
882
883    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
884        sort.fold_with(&mut ShallowResolver { infcx: self })
885    }
886
887    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
888        sort.fold_with(&mut OpportunisticResolver { infcx: self })
889    }
890
891    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
892        sort.try_fold_with(&mut FullResolver { infcx: self })
893    }
894}
895
896/// Before the main sort inference, we do a first traversal checking all implicitly scoped parameters
897/// declared with `@` or `#` and infer their sort based on the type they are indexing, e.g., if `n` was
898/// declared as `i32[@n]`, we infer `int` for `n`.
899///
900/// This prepass is necessary because sometimes the order in which we traverse expressions can
901/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
902/// complete inference regardless of how expressions are later traversed.
903///
904/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
905/// function position) such that we don't need this anymore.
906pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
907    infcx: &'a mut InferCtxt<'genv, 'tcx>,
908    errors: Errors<'genv>,
909}
910
911impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
912    pub(crate) fn infer(
913        infcx: &'a mut InferCtxt<'genv, 'tcx>,
914        node: &fhir::OwnerNode<'genv>,
915    ) -> Result {
916        let errors = Errors::new(infcx.genv.sess());
917        let mut vis = Self { infcx, errors };
918        vis.visit_node(node);
919        vis.errors.into_result()
920    }
921
922    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
923        match idx.kind {
924            fhir::ExprKind::Var(var, Some(_)) => {
925                let (_, id) = var.res.expect_param();
926                let found = self.infcx.param_sort(id);
927                self.infcx.equate(&found, expected);
928            }
929            fhir::ExprKind::Record(flds) => {
930                if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
931                    let sorts = sort_def.struct_variant().field_sorts(sort_args);
932                    if flds.len() != sorts.len() {
933                        self.errors.emit(errors::ArgCountMismatch::new(
934                            Some(idx.span),
935                            String::from("type"),
936                            sorts.len(),
937                            flds.len(),
938                        ));
939                        return;
940                    }
941                    for (f, sort) in iter::zip(flds, &sorts) {
942                        self.infer_implicit_params(f, sort);
943                    }
944                } else {
945                    self.errors.emit(errors::ArgCountMismatch::new(
946                        Some(idx.span),
947                        String::from("type"),
948                        1,
949                        flds.len(),
950                    ));
951                }
952            }
953            _ => {}
954        }
955    }
956}
957
958impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
959    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
960        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
961            let expected = self.infcx.sort_of_bty(bty.fhir_id);
962            self.infer_implicit_params(idx, &expected);
963        }
964        fhir::visit::walk_ty(self, ty);
965    }
966}
967
968impl InferCtxt<'_, '_> {
969    #[track_caller]
970    fn emit_sort_mismatch(
971        &mut self,
972        span: Span,
973        expected: &rty::Sort,
974        found: &rty::Sort,
975    ) -> ErrorGuaranteed {
976        let expected = self.resolve_vars_if_possible(expected);
977        let found = self.resolve_vars_if_possible(found);
978        self.emit_err(errors::SortMismatch::new(span, expected, found))
979    }
980
981    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
982        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
983    }
984
985    #[track_caller]
986    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
987        self.genv.sess().emit_err(err)
988    }
989}
990
991struct ShallowResolver<'a, 'genv, 'tcx> {
992    infcx: &'a mut InferCtxt<'genv, 'tcx>,
993}
994
995impl TypeFolder for ShallowResolver<'_, '_, '_> {
996    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
997        match sort {
998            rty::Sort::Infer(rty::SortVar(vid)) => {
999                // if `sort` is a sort variable, it can be resolved to an num/bit-vec variable,
1000                // which can then be recursively resolved, hence the recursion. Note though that
1001                // we prevent sort variables from unifying to other sort variables directly (though
1002                // they may be embedded structurally), so this recursion should always be of very
1003                // limited depth.
1004                self.infcx
1005                    .sort_unification_table
1006                    .probe_value(*vid)
1007                    .map(|sort| sort.fold_with(self))
1008                    .unwrap_or(sort.clone())
1009            }
1010            rty::Sort::Infer(rty::NumVar(vid)) => {
1011                // same here, a num var could had been unified with a bitvector
1012                self.infcx
1013                    .num_unification_table
1014                    .probe_value(*vid)
1015                    .map(|val| val.to_sort().fold_with(self))
1016                    .unwrap_or(sort.clone())
1017            }
1018            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1019                self.infcx
1020                    .bv_size_unification_table
1021                    .probe_value(*vid)
1022                    .map(rty::Sort::BitVec)
1023                    .unwrap_or(sort.clone())
1024            }
1025            _ => sort.clone(),
1026        }
1027    }
1028}
1029
1030struct OpportunisticResolver<'a, 'genv, 'tcx> {
1031    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1032}
1033
1034impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1035    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1036        let s = self.infcx.shallow_resolve(sort);
1037        s.super_fold_with(self)
1038    }
1039}
1040
1041struct FullResolver<'a, 'genv, 'tcx> {
1042    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1043}
1044
1045impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1046    type Error = ();
1047
1048    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1049        let s = self.infcx.shallow_resolve(sort);
1050        match s {
1051            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1052            _ => s.try_super_fold_with(self),
1053        }
1054    }
1055}