flux_fhir_analysis/wf/sortck.rs

1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9    fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10    global_env::GlobalEnv,
11    queries::QueryResult,
12    rty::{
13        self, AdtSortDef, FuncSort, List, WfckResults,
14        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15    },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
/// Result type used throughout sort checking. `T` defaults to `()` for checks that produce no
/// value; the error is the `ErrorGuaranteed` returned by `emit_err`/`emit_sort_mismatch`.
type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
/// Inference state used to sort-check the refinement expressions of a single owner.
pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    /// The item whose expressions are being checked.
    pub owner: FluxOwnerId,
    /// Results recorded while checking (node sorts, record constructors, field projections,
    /// coercions, binary-operator sorts).
    pub wfckresults: WfckResults,
    /// Union-find table for sort inference variables. Key 0 is reserved (see `new`).
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    /// Union-find table for bit-vector size inference variables.
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    /// Refinement parameters registered through `declare_param`, with their sorts.
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    /// Sorts recorded for FHIR nodes (paths, base types) through `insert_node_sort`.
    node_sort: FxHashMap<FhirId, rty::Sort>,
    /// Generic arguments recorded per path through `insert_path_args`.
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    /// Function sorts of alias-reft applications, registered via `insert_sort_for_alias_reft`.
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
    /// Unsuffixed integer literals whose sort is still an inference variable.
    sort_of_literal: NodeMap<'genv, rty::Sort>,
    /// Sort (possibly an inference variable) of the operands of recorded binary operators.
    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
    /// Sort arguments created when instantiating polymorphic function sorts at applications.
    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
}
44
45pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
46    match op {
47        fhir::BinOp::BitAnd
48        | fhir::BinOp::BitOr
49        | fhir::BinOp::BitXor
50        | fhir::BinOp::BitShl
51        | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
52        _ => None,
53    }
54}
55
56impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
57    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
58        // We skip 0 because that's used for sort dummy self types during conv.
59        let mut sort_unification_table = InPlaceUnificationTable::new();
60        sort_unification_table.new_key(Default::default());
61        Self {
62            genv,
63            owner,
64            wfckresults: WfckResults::new(owner),
65            sort_unification_table,
66            bv_size_unification_table: InPlaceUnificationTable::new(),
67            params: Default::default(),
68            node_sort: Default::default(),
69            path_args: Default::default(),
70            sort_of_alias_reft: Default::default(),
71            sort_of_literal: Default::default(),
72            sort_of_bin_op: Default::default(),
73            sort_args_of_app: Default::default(),
74        }
75    }
76
77    fn check_abs(
78        &mut self,
79        expr: &fhir::Expr<'genv>,
80        params: &[fhir::RefineParam],
81        body: &fhir::Expr<'genv>,
82        expected: &rty::Sort,
83    ) -> Result {
84        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
85            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
86        };
87
88        let fsort = fsort.expect_mono();
89
90        if params.len() != fsort.inputs().len() {
91            return Err(self.emit_err(errors::ParamCountMismatch::new(
92                expr.span,
93                fsort.inputs().len(),
94                params.len(),
95            )));
96        }
97        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
98            let found = self.param_sort(param.id);
99            if self.try_equate(&found, expected).is_none() {
100                return Err(self.emit_sort_mismatch(param.span, expected, &found));
101            }
102            Ok(())
103        })?;
104        self.check_expr(body, fsort.output())?;
105        self.wfckresults
106            .node_sorts_mut()
107            .insert(body.fhir_id, fsort.output().clone());
108        Ok(())
109    }
110
111    fn check_field_exprs(
112        &mut self,
113        expr_span: Span,
114        sort_def: &AdtSortDef,
115        sort_args: &[rty::Sort],
116        field_exprs: &[fhir::FieldExpr<'genv>],
117        spread: Option<&fhir::Spread<'genv>>,
118        expected: &rty::Sort,
119    ) -> Result {
120        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
121        let mut used_fields = FxHashMap::default();
122        for expr in field_exprs {
123            // make sure that the field is actually a field
124            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
125                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
126            };
127            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
128                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
129            }
130            self.check_expr(&expr.expr, sort)?;
131        }
132        if let Some(spread) = spread {
133            // must check that the spread is of the same sort as the constructor
134            self.check_expr(&spread.expr, expected)
135        } else if sort_by_field_name.len() != used_fields.len() {
136            // emit an error because all fields are not used
137            let missing_fields = sort_by_field_name
138                .into_keys()
139                .filter(|x| !used_fields.contains_key(x))
140                .collect();
141            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
142        } else {
143            Ok(())
144        }
145    }
146
147    fn check_constructor(
148        &mut self,
149        expr: &fhir::Expr,
150        field_exprs: &[fhir::FieldExpr<'genv>],
151        spread: Option<&fhir::Spread<'genv>>,
152        expected: &rty::Sort,
153    ) -> Result {
154        let expected = self.resolve_vars_if_possible(expected);
155        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
156            self.wfckresults
157                .record_ctors_mut()
158                .insert(expr.fhir_id, sort_def.did());
159            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
160            Ok(())
161        } else {
162            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
163        }
164    }
165
166    fn check_record(
167        &mut self,
168        arg: &fhir::Expr<'genv>,
169        flds: &[fhir::Expr<'genv>],
170        expected: &rty::Sort,
171    ) -> Result {
172        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
173            let sorts = sort_def.struct_variant().field_sorts(sort_args);
174            if flds.len() != sorts.len() {
175                return Err(self.emit_err(errors::ArgCountMismatch::new(
176                    Some(arg.span),
177                    String::from("type"),
178                    sorts.len(),
179                    flds.len(),
180                )));
181            }
182            self.wfckresults
183                .record_ctors_mut()
184                .insert(arg.fhir_id, sort_def.did());
185
186            izip!(flds, &sorts)
187                .map(|(arg, expected)| self.check_expr(arg, expected))
188                .try_collect_exhaust()
189        } else {
190            Err(self.emit_err(errors::ArgCountMismatch::new(
191                Some(arg.span),
192                String::from("type"),
193                1,
194                flds.len(),
195            )))
196        }
197    }
198
199    pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
200        match expr.kind {
201            fhir::ExprKind::IfThenElse(p, e1, e2) => {
202                self.check_expr(p, &rty::Sort::Bool)?;
203                self.check_expr(e1, expected)?;
204                self.check_expr(e2, expected)?;
205                return Ok(());
206            }
207            fhir::ExprKind::Abs(params, body) => {
208                self.check_abs(expr, params, body, expected)?;
209                return Ok(());
210            }
211            fhir::ExprKind::Record(fields) => {
212                self.check_record(expr, fields, expected)?;
213                return Ok(());
214            }
215            fhir::ExprKind::Constructor(None, exprs, spread) => {
216                self.check_constructor(expr, exprs, spread, expected)?;
217                return Ok(());
218            }
219            fhir::ExprKind::Tuple(exprs) => {
220                // Check tuple against expected tuple sort
221                if let rty::Sort::Tuple(expected_sorts) = expected
222                    && exprs.len() == expected_sorts.len()
223                {
224                    for (expr, expected_sort) in iter::zip(exprs, expected_sorts) {
225                        self.check_expr(expr, expected_sort)?;
226                    }
227                    return Ok(());
228                }
229            }
230            fhir::ExprKind::Err(_) => {
231                // an error has already been reported so we can just skip
232                return Ok(());
233            }
234            fhir::ExprKind::UnaryOp(..)
235            | fhir::ExprKind::SetLiteral(..)
236            | fhir::ExprKind::BinaryOp(..)
237            | fhir::ExprKind::Dot(..)
238            | fhir::ExprKind::App(..)
239            | fhir::ExprKind::Alias(..)
240            | fhir::ExprKind::Var(..)
241            | fhir::ExprKind::Literal(..)
242            | fhir::ExprKind::BoundedQuant(..)
243            | fhir::ExprKind::Block(..)
244            | fhir::ExprKind::Constructor(..)
245            | fhir::ExprKind::PrimApp(..) => {}
246        }
247        // Fallback to synthesis and then check
248        let found = self.synth_expr(expr)?;
249        let found = self.resolve_vars_if_possible(&found);
250        let expected = self.resolve_vars_if_possible(expected);
251        if !self.is_coercible(&found, &expected, expr.fhir_id) {
252            return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
253        }
254        Ok(())
255    }
256
257    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
258        let found = self.synth_path(loc);
259        if found == rty::Sort::Loc {
260            Ok(())
261        } else {
262            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
263        }
264    }
265
266    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
267        match lit {
268            fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
269            fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
270            fhir::Lit::Int(_, None) => {
271                let sort = self.next_sort_var_with_cstr(rty::SortCstr::NUMERIC);
272                self.sort_of_literal.insert(*expr, sort.clone());
273                sort
274            }
275            fhir::Lit::Bool(_) => rty::Sort::Bool,
276            fhir::Lit::Str(_) => rty::Sort::Str,
277            fhir::Lit::Char(_) => rty::Sort::Char,
278        }
279    }
280
281    fn synth_prim_app(
282        &mut self,
283        op: &fhir::BinOp,
284        e1: &fhir::Expr<'genv>,
285        e2: &fhir::Expr<'genv>,
286        span: Span,
287    ) -> Result<rty::Sort> {
288        let Some((inputs, output)) = prim_op_sort(op) else {
289            return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
290        };
291        let [sort1, sort2] = &inputs[..] else {
292            return Err(self.emit_err(errors::ArgCountMismatch::new(
293                Some(span),
294                String::from("primop app"),
295                inputs.len(),
296                2,
297            )));
298        };
299        self.check_expr(e1, sort1)?;
300        self.check_expr(e2, sort2)?;
301        Ok(output)
302    }
303
304    fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
305        match expr.kind {
306            fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
307            fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
308                Ok(self.synth_type_relative_path(expr))
309            }
310            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
311            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
312            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
313            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
314            fhir::ExprKind::App(callee, args) => {
315                let sort = self.ensure_resolved_path(&callee)?;
316                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
317                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
318                };
319                let fsort = self.instantiate_func_sort(expr, poly_fsort);
320                self.synth_app(fsort, args, expr.span)
321            }
322            fhir::ExprKind::BoundedQuant(.., body) => {
323                self.check_expr(body, &rty::Sort::Bool)?;
324                Ok(rty::Sort::Bool)
325            }
326            fhir::ExprKind::Alias(_alias_reft, args) => {
327                // To check the application we only need the sort of `_alias_reft` which we collected
328                // during early conv, but should we do any extra checks on `_alias_reft`?
329                let fsort = self.sort_of_alias_reft(expr.fhir_id);
330                self.synth_app(fsort, args, expr.span)
331            }
332            fhir::ExprKind::IfThenElse(p, e1, e2) => {
333                self.check_expr(p, &rty::Sort::Bool)?;
334                let sort = self.synth_expr(e1)?;
335                self.check_expr(e2, &sort)?;
336                Ok(sort)
337            }
338            fhir::ExprKind::Dot(base, fld) => {
339                let sort = self.synth_expr(base)?;
340                let sort = self
341                    .fully_resolve(&sort)
342                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
343                match &sort {
344                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
345                        let (proj, sort) = sort_def
346                            .struct_variant()
347                            .field_by_name(sort_def.did(), sort_args, fld.name)
348                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
349                        self.wfckresults
350                            .field_projs_mut()
351                            .insert(expr.fhir_id, proj);
352                        Ok(sort)
353                    }
354                    rty::Sort::Tuple(sorts) => {
355                        // Parse field name as integer for tuple field access
356                        if let Ok(field_idx) = fld.name.as_str().parse::<usize>()
357                            && field_idx < sorts.len()
358                        {
359                            let proj = rty::FieldProj::Tuple {
360                                arity: sorts.len(),
361                                field: field_idx as u32,
362                            };
363                            self.wfckresults
364                                .field_projs_mut()
365                                .insert(expr.fhir_id, proj);
366                            Ok(sorts[field_idx].clone())
367                        } else {
368                            Err(self.emit_field_not_found(&sort, fld))
369                        }
370                    }
371                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
372                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
373                    }
374                    _ => Err(self.emit_field_not_found(&sort, fld)),
375                }
376            }
377            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
378                // if we have S then we can synthesize otherwise we fail
379                // first get the sort based on the path - for example S { ... } => S
380                // and we should expect sort to be a struct or enum app
381                let path_def_id = match path.res {
382                    fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
383                    _ => span_bug!(expr.span, "unexpected path in constructor"),
384                };
385                let sort_def = self
386                    .genv
387                    .adt_sort_def_of(path_def_id)
388                    .map_err(|e| self.emit_err(e))?;
389                // generate fresh inferred sorts for each param
390                let fresh_args: rty::List<_> = (0..sort_def.param_count())
391                    .map(|_| self.next_sort_var())
392                    .collect();
393                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
394                // check fields & spread against fresh args
395                self.check_field_exprs(
396                    expr.span,
397                    &sort_def,
398                    &fresh_args,
399                    field_exprs,
400                    spread,
401                    &sort,
402                )?;
403                Ok(sort)
404            }
405            fhir::ExprKind::Block(decls, body) => {
406                for decl in decls {
407                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
408                }
409                self.synth_expr(body)
410            }
411            fhir::ExprKind::SetLiteral(elems) => {
412                let elem_sort = self.next_sort_var();
413                for elem in elems {
414                    self.check_expr(elem, &elem_sort)?;
415                }
416                Ok(rty::Sort::App(rty::SortCtor::Set, List::singleton(elem_sort)))
417            }
418            fhir::ExprKind::Tuple(exprs) => {
419                let sorts = exprs
420                    .iter()
421                    .map(|expr| self.synth_expr(expr))
422                    .try_collect()?;
423                Ok(rty::Sort::Tuple(sorts))
424            }
425            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
426        }
427    }
428
429    fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
430        self.node_sort
431            .get(&path.fhir_id)
432            .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
433            .clone()
434    }
435
436    fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
437        self.node_sort
438            .get(&expr.fhir_id)
439            .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
440            .clone()
441    }
442
443    fn synth_binary_op(
444        &mut self,
445        expr: &fhir::Expr<'genv>,
446        op: fhir::BinOp,
447        e1: &fhir::Expr<'genv>,
448        e2: &fhir::Expr<'genv>,
449    ) -> Result<rty::Sort> {
450        match op {
451            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
452                self.check_expr(e1, &rty::Sort::Bool)?;
453                self.check_expr(e2, &rty::Sort::Bool)?;
454                Ok(rty::Sort::Bool)
455            }
456            fhir::BinOp::Eq | fhir::BinOp::Ne => {
457                let sort = self.next_sort_var();
458                self.check_expr(e1, &sort)?;
459                self.check_expr(e2, &sort)?;
460                Ok(rty::Sort::Bool)
461            }
462            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
463                let sort = self.next_sort_var();
464                self.check_expr(e1, &sort)?;
465                self.check_expr(e2, &sort)?;
466                self.sort_of_bin_op.insert(*expr, sort.clone());
467                Ok(rty::Sort::Bool)
468            }
469            fhir::BinOp::Add
470            | fhir::BinOp::Sub
471            | fhir::BinOp::Mul
472            | fhir::BinOp::Div
473            | fhir::BinOp::Mod
474            | fhir::BinOp::BitAnd
475            | fhir::BinOp::BitOr
476            | fhir::BinOp::BitXor
477            | fhir::BinOp::BitShl
478            | fhir::BinOp::BitShr => {
479                let sort = self.next_sort_var_with_cstr(rty::SortCstr::from_bin_op(op));
480                self.check_expr(e1, &sort)?;
481                self.check_expr(e2, &sort)?;
482                self.sort_of_bin_op.insert(*expr, sort.clone());
483                Ok(sort)
484            }
485        }
486    }
487
488    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
489        match op {
490            fhir::UnOp::Not => {
491                self.check_expr(e, &rty::Sort::Bool)?;
492                Ok(rty::Sort::Bool)
493            }
494            fhir::UnOp::Neg => {
495                self.check_expr(e, &rty::Sort::Int)?;
496                Ok(rty::Sort::Int)
497            }
498        }
499    }
500
501    fn synth_app(
502        &mut self,
503        fsort: FuncSort,
504        args: &[fhir::Expr<'genv>],
505        span: Span,
506    ) -> Result<rty::Sort> {
507        if args.len() != fsort.inputs().len() {
508            return Err(self.emit_err(errors::ArgCountMismatch::new(
509                Some(span),
510                String::from("function"),
511                fsort.inputs().len(),
512                args.len(),
513            )));
514        }
515
516        iter::zip(args, fsort.inputs())
517            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
518
519        Ok(fsort.output().clone())
520    }
521
522    fn instantiate_func_sort(
523        &mut self,
524        app_expr: &fhir::Expr<'genv>,
525        fsort: rty::PolyFuncSort,
526    ) -> rty::FuncSort {
527        let args = fsort
528            .params()
529            .map(|kind| {
530                match kind {
531                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
532                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
533                }
534            })
535            .collect_vec();
536        self.sort_args_of_app
537            .insert(*app_expr, List::from_slice(&args));
538        fsort.instantiate(&args)
539    }
540
541    pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
542        self.node_sort.insert(fhir_id, sort);
543    }
544
545    pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
546        self.node_sort
547            .get(&bty.fhir_id)
548            .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
549            .clone()
550    }
551
552    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
553        self.path_args.insert(fhir_id, args);
554    }
555
556    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
557        self.path_args
558            .get(&fhir_id)
559            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
560            .clone()
561    }
562
563    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
564        self.sort_of_alias_reft.insert(fhir_id, fsort);
565    }
566
567    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
568        self.sort_of_alias_reft
569            .get(&fhir_id)
570            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
571            .clone()
572    }
573
574    fn deeply_normalize_sorts<T: TypeFoldable + Clone>(
575        genv: GlobalEnv,
576        owner: FluxOwnerId,
577        t: &T,
578    ) -> QueryResult<T> {
579        let infcx = genv
580            .tcx()
581            .infer_ctxt()
582            .with_next_trait_solver(true)
583            .build(TypingMode::non_body_analysis());
584        if let Some(def_id) = owner.resolved_id() {
585            t.deeply_normalize_sorts(def_id, genv, &infcx)
586        } else {
587            Ok(t.clone())
588        }
589    }
590
591    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
592    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
593    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
594    // expanding aliases in `ConvCtxt::conv_sort`.
595    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
596        for sort in self.node_sort.values_mut() {
597            *sort = Self::deeply_normalize_sorts(self.genv, self.owner, sort)?;
598        }
599        for fsort in self.sort_of_alias_reft.values_mut() {
600            *fsort = Self::deeply_normalize_sorts(self.genv, self.owner, fsort)?;
601        }
602        Ok(())
603    }
604}
605
606impl<'genv> InferCtxt<'genv, '_> {
607    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
608        self.params.insert(param.id, (param, sort));
609    }
610
611    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
612    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
613    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
614    /// a value into a record.
615    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
616        if self.try_equate(sort1, sort2).is_some() {
617            return true;
618        }
619
620        let mut sort1 = sort1.clone();
621        let mut sort2 = sort2.clone();
622
623        let mut coercions = vec![];
624        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
625            coercions.push(rty::Coercion::Project(def_id));
626            sort1 = sort.clone();
627        }
628        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
629            coercions.push(rty::Coercion::Inject(def_id));
630            sort2 = sort.clone();
631        }
632        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
633        self.try_equate(&sort1, &sort2).is_some()
634    }
635
636    fn is_coercible_from_func(
637        &mut self,
638        sort: &rty::Sort,
639        fhir_id: FhirId,
640    ) -> Option<rty::PolyFuncSort> {
641        if let rty::Sort::Func(fsort) = sort {
642            Some(fsort.clone())
643        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
644            self.wfckresults
645                .coercions_mut()
646                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
647            Some(fsort.clone())
648        } else {
649            None
650        }
651    }
652
653    fn is_coercible_to_func(
654        &mut self,
655        sort: &rty::Sort,
656        fhir_id: FhirId,
657    ) -> Option<rty::PolyFuncSort> {
658        if let rty::Sort::Func(fsort) = sort {
659            Some(fsort.clone())
660        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
661            self.wfckresults
662                .coercions_mut()
663                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
664            Some(fsort.clone())
665        } else {
666            None
667        }
668    }
669
670    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
671        let sort1 = self.resolve_vars_if_possible(sort1);
672        let sort2 = self.resolve_vars_if_possible(sort2);
673        self.try_equate_inner(&sort1, &sort2)
674    }
675
676    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
677        match (sort1, sort2) {
678            (rty::Sort::Infer(vid1), rty::Sort::Infer(vid2)) => {
679                self.sort_unification_table
680                    .unify_var_var(*vid1, *vid2)
681                    .ok()?;
682            }
683            (rty::Sort::Infer(vid), sort) | (sort, rty::Sort::Infer(vid)) => {
684                self.sort_unification_table
685                    .unify_var_value(*vid, rty::SortVarVal::Solved(sort.clone()))
686                    .ok()?;
687            }
688            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
689                if ctor1 != ctor2 || args1.len() != args2.len() {
690                    return None;
691                }
692                for (s1, s2) in iter::zip(args1, args2) {
693                    self.try_equate_inner(s1, s2)?;
694                }
695            }
696            (rty::Sort::Tuple(sorts1), rty::Sort::Tuple(sorts2)) => {
697                if sorts1.len() != sorts2.len() {
698                    return None;
699                }
700                for (s1, s2) in iter::zip(sorts1, sorts2) {
701                    self.try_equate_inner(s1, s2)?;
702                }
703            }
704            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
705                self.try_equate_bv_sizes(*size1, *size2)?;
706            }
707            _ if sort1 == sort2 => {}
708            _ => return None,
709        }
710        Some(sort1.clone())
711    }
712
713    fn try_equate_bv_sizes(
714        &mut self,
715        size1: rty::BvSize,
716        size2: rty::BvSize,
717    ) -> Option<rty::BvSize> {
718        match (size1, size2) {
719            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
720                self.bv_size_unification_table
721                    .unify_var_var(vid1, vid2)
722                    .ok()?;
723            }
724            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
725                self.bv_size_unification_table
726                    .unify_var_value(vid, Some(size))
727                    .ok()?;
728            }
729            _ if size1 == size2 => {}
730            _ => return None,
731        }
732        Some(size1)
733    }
734
735    fn next_sort_var(&mut self) -> rty::Sort {
736        rty::Sort::Infer(self.next_sort_vid(Default::default()))
737    }
738
739    fn next_sort_var_with_cstr(&mut self, cstr: rty::SortCstr) -> rty::Sort {
740        rty::Sort::Infer(self.next_sort_vid(rty::SortVarVal::Unsolved(cstr)))
741    }
742
743    pub(crate) fn next_sort_vid(&mut self, val: rty::SortVarVal) -> rty::SortVid {
744        self.sort_unification_table.new_key(val)
745    }
746
747    fn next_bv_size_var(&mut self) -> rty::BvSize {
748        rty::BvSize::Infer(self.next_bv_size_vid())
749    }
750
751    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
752        self.bv_size_unification_table.new_key(None)
753    }
754
755    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
756        let sort = self.synth_path(path);
757        self.fully_resolve(&sort)
758            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
759    }
760
761    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
762        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
763            && let Some(variant) = sort_def.opt_struct_variant()
764            && let [sort] = &variant.field_sorts(sort_args)[..]
765        {
766            Some((sort_def.did(), sort.clone()))
767        } else {
768            None
769        }
770    }
771
772    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
773        // Make sure the int literals are fully resolved. This must happen before everything else
774        // such that we properly apply the fallback for unconstrained num vars.
775        for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
776            // Fallback to `int` when a num variable is unconstrained. Note that we unconditionally
777            // unify the variable. This is fine because if the variable has already been unified,
778            // the operation will fail and this won't have any effect. Also note that unifying a
779            // variable could solve variables that appear later in this for loop. This is also fine
780            // because the order doesn't matter as we are unifying everything to `int`.
781            if let rty::Sort::Infer(vid) = &sort {
782                let _ = self
783                    .sort_unification_table
784                    .unify_var_value(*vid, rty::SortVarVal::Solved(rty::Sort::Int));
785            }
786
787            let sort = self
788                .fully_resolve(&sort)
789                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
790            self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
791        }
792
793        // Make sure the binary operators are fully resolved
794        for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
795            let sort = self
796                .fully_resolve(&sort)
797                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
798            self.wfckresults
799                .bin_op_sorts_mut()
800                .insert(node.fhir_id, sort);
801        }
802
803        let allow_uninterpreted_cast =
804            self.owner
805                .as_rust()
806                .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
807                    self.genv
808                        .infer_opts(def_id.local_id().def_id)
809                        .allow_uninterpreted_cast
810                });
811
812        // Make sure that function applications are fully resolved
813        for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
814            let mut res = vec![];
815            for sort_arg in &sort_args {
816                let sort_arg = match sort_arg {
817                    rty::SortArg::Sort(sort) => {
818                        let sort = self
819                            .fully_resolve(sort)
820                            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
821                        rty::SortArg::Sort(sort)
822                    }
823                    rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
824                        let size = self
825                            .bv_size_unification_table
826                            .probe_value(*vid)
827                            .ok_or_else(|| {
828                                self.emit_err(errors::CannotInferSort::new(node.span))
829                            })?;
830                        rty::SortArg::BvSize(size)
831                    }
832                    _ => sort_arg.clone(),
833                };
834                res.push(sort_arg);
835            }
836            if let fhir::ExprKind::App(callee, _) = node.kind
837                && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
838            {
839                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
840                    span_bug!(node.span, "invalid sort args!")
841                };
842                if !allow_uninterpreted_cast
843                    && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
844                {
845                    return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
846                }
847            }
848            self.wfckresults
849                .fn_app_sorts_mut()
850                .insert(node.fhir_id, res.into());
851        }
852
853        // Make sure all parameters are fully resolved
854        for (_, (param, sort)) in std::mem::take(&mut self.params) {
855            let sort = self
856                .fully_resolve(&sort)
857                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
858            self.wfckresults.param_sorts_mut().insert(param.id, sort);
859        }
860
861        Ok(self.wfckresults)
862    }
863
864    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
865        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
866    }
867
868    #[track_caller]
869    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
870        self.params
871            .get(&id)
872            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
873            .1
874            .clone()
875    }
876
877    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
878        sort.fold_with(&mut ShallowResolver { infcx: self })
879    }
880
881    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
882        sort.fold_with(&mut OpportunisticResolver { infcx: self })
883    }
884
885    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
886        sort.try_fold_with(&mut FullResolver { infcx: self })
887    }
888}
889
890/// Before the main sort inference, we do a first traversal checking all implicitly scoped parameters
891/// declared with `@` or `#` and infer their sort based on the type they are indexing, e.g., if `n` was
892/// declared as `i32[@n]`, we infer `int` for `n`.
893///
894/// This prepass is necessary because sometimes the order in which we traverse expressions can
895/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
896/// complete inference regardless of how expressions are later traversed.
897///
898/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
899/// function position) such that we don't need this anymore.
900pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
901    infcx: &'a mut InferCtxt<'genv, 'tcx>,
902    errors: Errors<'genv>,
903}
904
905impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
906    pub(crate) fn infer(
907        infcx: &'a mut InferCtxt<'genv, 'tcx>,
908        node: &fhir::OwnerNode<'genv>,
909    ) -> Result {
910        let errors = Errors::new(infcx.genv.sess());
911        let mut vis = Self { infcx, errors };
912        vis.visit_node(node);
913        vis.errors.to_result()
914    }
915
916    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
917        match idx.kind {
918            fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
919                let (_, id) = var.res.expect_param();
920                let found = self.infcx.param_sort(id);
921                self.infcx.try_equate(&found, expected).unwrap_or_else(|| {
922                    span_bug!(idx.span, "failed to equate sorts: `{found:?}` `{expected:?}`")
923                });
924            }
925            fhir::ExprKind::Record(flds) => {
926                if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
927                    let sorts = sort_def.struct_variant().field_sorts(sort_args);
928                    if flds.len() != sorts.len() {
929                        self.errors.emit(errors::ArgCountMismatch::new(
930                            Some(idx.span),
931                            String::from("type"),
932                            sorts.len(),
933                            flds.len(),
934                        ));
935                        return;
936                    }
937                    for (f, sort) in iter::zip(flds, &sorts) {
938                        self.infer_implicit_params(f, sort);
939                    }
940                } else {
941                    self.errors.emit(errors::ArgCountMismatch::new(
942                        Some(idx.span),
943                        String::from("type"),
944                        1,
945                        flds.len(),
946                    ));
947                }
948            }
949            _ => {}
950        }
951    }
952}
953
954impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
955    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
956        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
957            let expected = self.infcx.sort_of_bty(bty);
958            self.infer_implicit_params(idx, &expected);
959        }
960        fhir::visit::walk_ty(self, ty);
961    }
962}
963
964impl InferCtxt<'_, '_> {
965    #[track_caller]
966    fn emit_sort_mismatch(
967        &mut self,
968        span: Span,
969        expected: &rty::Sort,
970        found: &rty::Sort,
971    ) -> ErrorGuaranteed {
972        let expected = self.resolve_vars_if_possible(expected);
973        let found = self.resolve_vars_if_possible(found);
974        self.emit_err(errors::SortMismatch::new(span, expected, found))
975    }
976
977    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
978        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
979    }
980
981    #[track_caller]
982    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
983        self.genv.sess().emit_err(err)
984    }
985}
986
987struct ShallowResolver<'a, 'genv, 'tcx> {
988    infcx: &'a mut InferCtxt<'genv, 'tcx>,
989}
990
991impl TypeFolder for ShallowResolver<'_, '_, '_> {
992    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
993        match sort {
994            rty::Sort::Infer(vid) => {
995                // if `sort` is a sort variable, it can be resolved to a bit-vec variable,
996                // which can then be recursively resolved, hence the recursion. Note though that
997                // we prevent sort variables from unifying to other sort variables directly (though
998                // they may be embedded structurally), so this recursion should always be of very
999                // limited depth.
1000                self.infcx
1001                    .sort_unification_table
1002                    .probe_value(*vid)
1003                    .map_solved(|sort| sort.fold_with(self))
1004                    .solved_or(sort)
1005            }
1006            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1007                self.infcx
1008                    .bv_size_unification_table
1009                    .probe_value(*vid)
1010                    .map(rty::Sort::BitVec)
1011                    .unwrap_or(sort.clone())
1012            }
1013            _ => sort.clone(),
1014        }
1015    }
1016}
1017
1018struct OpportunisticResolver<'a, 'genv, 'tcx> {
1019    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1020}
1021
1022impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1023    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1024        let s = self.infcx.shallow_resolve(sort);
1025        s.super_fold_with(self)
1026    }
1027}
1028
1029struct FullResolver<'a, 'genv, 'tcx> {
1030    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1031}
1032
1033impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1034    type Error = ();
1035
1036    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1037        let s = self.infcx.shallow_resolve(sort);
1038        match s {
1039            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1040            _ => s.try_super_fold_with(self),
1041        }
1042    }
1043}
1044
1045/// Map to associate data to a node (i.e., an expression).
1046///
1047/// Used to record elaborated information.
1048#[derive_where(Default)]
1049struct NodeMap<'genv, T> {
1050    map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
1051}
1052
1053impl<'genv, T> NodeMap<'genv, T> {
1054    /// Add a `node` to the map with associated `data`
1055    fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1056        assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1057    }
1058}
1059
1060impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1061    type Item = (fhir::Expr<'genv>, T);
1062
1063    type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1064
1065    fn into_iter(self) -> Self::IntoIter {
1066        self.map.into_values()
1067    }
1068}