flux_fhir_analysis/wf/
sortck.rs

1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9    fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10    global_env::GlobalEnv,
11    queries::QueryResult,
12    rty::{
13        self, AdtSortDef, FuncSort, List, WfckResults,
14        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15    },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
28type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
30pub(super) struct InferCtxt<'genv, 'tcx> {
31    pub genv: GlobalEnv<'genv, 'tcx>,
32    pub owner: FluxOwnerId,
33    pub wfckresults: WfckResults,
34    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
35    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
36    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
37    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
38    node_sort: FxHashMap<FhirId, rty::Sort>,
39    path_args: UnordMap<FhirId, rty::GenericArgs>,
40    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
41    sort_of_literal: NodeMap<'genv, rty::Sort>,
42    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
43    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
44}
45
46pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
47    match op {
48        fhir::BinOp::BitAnd
49        | fhir::BinOp::BitOr
50        | fhir::BinOp::BitXor
51        | fhir::BinOp::BitShl
52        | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
53        _ => None,
54    }
55}
56
57impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
58    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
59        // We skip 0 because that's used for sort dummy self types during conv.
60        let mut sort_unification_table = InPlaceUnificationTable::new();
61        sort_unification_table.new_key(None);
62        Self {
63            genv,
64            owner,
65            wfckresults: WfckResults::new(owner),
66            sort_unification_table,
67            num_unification_table: InPlaceUnificationTable::new(),
68            bv_size_unification_table: InPlaceUnificationTable::new(),
69            params: Default::default(),
70            node_sort: Default::default(),
71            path_args: Default::default(),
72            sort_of_alias_reft: Default::default(),
73            sort_of_literal: Default::default(),
74            sort_of_bin_op: Default::default(),
75            sort_args_of_app: Default::default(),
76        }
77    }
78
79    fn check_abs(
80        &mut self,
81        expr: &fhir::Expr<'genv>,
82        params: &[fhir::RefineParam],
83        body: &fhir::Expr<'genv>,
84        expected: &rty::Sort,
85    ) -> Result {
86        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
87            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
88        };
89
90        let fsort = fsort.expect_mono();
91
92        if params.len() != fsort.inputs().len() {
93            return Err(self.emit_err(errors::ParamCountMismatch::new(
94                expr.span,
95                fsort.inputs().len(),
96                params.len(),
97            )));
98        }
99        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
100            let found = self.param_sort(param.id);
101            if self.try_equate(&found, expected).is_none() {
102                return Err(self.emit_sort_mismatch(param.span, expected, &found));
103            }
104            Ok(())
105        })?;
106        self.check_expr(body, fsort.output())?;
107        self.wfckresults
108            .node_sorts_mut()
109            .insert(body.fhir_id, fsort.output().clone());
110        Ok(())
111    }
112
113    fn check_field_exprs(
114        &mut self,
115        expr_span: Span,
116        sort_def: &AdtSortDef,
117        sort_args: &[rty::Sort],
118        field_exprs: &[fhir::FieldExpr<'genv>],
119        spread: &Option<&fhir::Spread<'genv>>,
120        expected: &rty::Sort,
121    ) -> Result {
122        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
123        let mut used_fields = FxHashMap::default();
124        for expr in field_exprs {
125            // make sure that the field is actually a field
126            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
127                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
128            };
129            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
130                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
131            }
132            self.check_expr(&expr.expr, sort)?;
133        }
134        if let Some(spread) = spread {
135            // must check that the spread is of the same sort as the constructor
136            self.check_expr(&spread.expr, expected)
137        } else if sort_by_field_name.len() != used_fields.len() {
138            // emit an error because all fields are not used
139            let missing_fields = sort_by_field_name
140                .into_keys()
141                .filter(|x| !used_fields.contains_key(x))
142                .collect();
143            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
144        } else {
145            Ok(())
146        }
147    }
148
149    fn check_constructor(
150        &mut self,
151        expr: &fhir::Expr,
152        field_exprs: &[fhir::FieldExpr<'genv>],
153        spread: &Option<&fhir::Spread<'genv>>,
154        expected: &rty::Sort,
155    ) -> Result {
156        let expected = self.resolve_vars_if_possible(expected);
157        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
158            self.wfckresults
159                .record_ctors_mut()
160                .insert(expr.fhir_id, sort_def.did());
161            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
162            Ok(())
163        } else {
164            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
165        }
166    }
167
168    fn check_record(
169        &mut self,
170        arg: &fhir::Expr<'genv>,
171        flds: &[fhir::Expr<'genv>],
172        expected: &rty::Sort,
173    ) -> Result {
174        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
175            let sorts = sort_def.struct_variant().field_sorts(sort_args);
176            if flds.len() != sorts.len() {
177                return Err(self.emit_err(errors::ArgCountMismatch::new(
178                    Some(arg.span),
179                    String::from("type"),
180                    sorts.len(),
181                    flds.len(),
182                )));
183            }
184            self.wfckresults
185                .record_ctors_mut()
186                .insert(arg.fhir_id, sort_def.did());
187
188            izip!(flds, &sorts)
189                .map(|(arg, expected)| self.check_expr(arg, expected))
190                .try_collect_exhaust()
191        } else {
192            Err(self.emit_err(errors::ArgCountMismatch::new(
193                Some(arg.span),
194                String::from("type"),
195                1,
196                flds.len(),
197            )))
198        }
199    }
200
201    pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
202        match &expr.kind {
203            fhir::ExprKind::IfThenElse(p, e1, e2) => {
204                self.check_expr(p, &rty::Sort::Bool)?;
205                self.check_expr(e1, expected)?;
206                self.check_expr(e2, expected)?;
207            }
208            fhir::ExprKind::Abs(params, body) => {
209                self.check_abs(expr, params, body, expected)?;
210            }
211            fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
212            fhir::ExprKind::Constructor(None, exprs, spread) => {
213                self.check_constructor(expr, exprs, spread, expected)?;
214            }
215            fhir::ExprKind::UnaryOp(..)
216            | fhir::ExprKind::BinaryOp(..)
217            | fhir::ExprKind::Dot(..)
218            | fhir::ExprKind::App(..)
219            | fhir::ExprKind::Alias(..)
220            | fhir::ExprKind::Var(..)
221            | fhir::ExprKind::Literal(..)
222            | fhir::ExprKind::BoundedQuant(..)
223            | fhir::ExprKind::Block(..)
224            | fhir::ExprKind::Constructor(..)
225            | fhir::ExprKind::PrimApp(..) => {
226                let found = self.synth_expr(expr)?;
227                let found = self.resolve_vars_if_possible(&found);
228                let expected = self.resolve_vars_if_possible(expected);
229                if !self.is_coercible(&found, &expected, expr.fhir_id) {
230                    return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
231                }
232            }
233            fhir::ExprKind::Err(_) => {
234                // an error has already been reported so we can just skip
235            }
236        }
237        Ok(())
238    }
239
240    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
241        let found = self.synth_path(loc);
242        if found == rty::Sort::Loc {
243            Ok(())
244        } else {
245            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
246        }
247    }
248
249    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
250        match lit {
251            fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
252            fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
253            fhir::Lit::Int(_, None) => {
254                let sort = self.next_num_var();
255                self.sort_of_literal.insert(*expr, sort.clone());
256                sort
257            }
258            fhir::Lit::Bool(_) => rty::Sort::Bool,
259            fhir::Lit::Str(_) => rty::Sort::Str,
260            fhir::Lit::Char(_) => rty::Sort::Char,
261        }
262    }
263
264    fn synth_prim_app(
265        &mut self,
266        op: &fhir::BinOp,
267        e1: &fhir::Expr<'genv>,
268        e2: &fhir::Expr<'genv>,
269        span: Span,
270    ) -> Result<rty::Sort> {
271        let Some((inputs, output)) = prim_op_sort(op) else {
272            return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
273        };
274        let [sort1, sort2] = &inputs[..] else {
275            return Err(self.emit_err(errors::ArgCountMismatch::new(
276                Some(span),
277                String::from("primop app"),
278                inputs.len(),
279                2,
280            )));
281        };
282        self.check_expr(e1, sort1)?;
283        self.check_expr(e2, sort2)?;
284        Ok(output)
285    }
286
287    fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
288        match expr.kind {
289            fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
290            fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
291                Ok(self.synth_type_relative_path(expr))
292            }
293            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
294            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
295            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
296            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
297            fhir::ExprKind::App(callee, args) => {
298                let sort = self.ensure_resolved_path(&callee)?;
299                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
300                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
301                };
302                let fsort = self.instantiate_func_sort(expr, poly_fsort);
303                self.synth_app(fsort, args, expr.span)
304            }
305            fhir::ExprKind::BoundedQuant(.., body) => {
306                self.check_expr(body, &rty::Sort::Bool)?;
307                Ok(rty::Sort::Bool)
308            }
309            fhir::ExprKind::Alias(_alias_reft, args) => {
310                // To check the application we only need the sort of `_alias_reft` which we collected
311                // during early conv, but should we do any extra checks on `_alias_reft`?
312                let fsort = self.sort_of_alias_reft(expr.fhir_id);
313                self.synth_app(fsort, args, expr.span)
314            }
315            fhir::ExprKind::IfThenElse(p, e1, e2) => {
316                self.check_expr(p, &rty::Sort::Bool)?;
317                let sort = self.synth_expr(e1)?;
318                self.check_expr(e2, &sort)?;
319                Ok(sort)
320            }
321            fhir::ExprKind::Dot(base, fld) => {
322                let sort = self.synth_expr(base)?;
323                let sort = self
324                    .fully_resolve(&sort)
325                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
326                match &sort {
327                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
328                        let (proj, sort) = sort_def
329                            .struct_variant()
330                            .field_by_name(sort_def.did(), sort_args, fld.name)
331                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
332                        self.wfckresults
333                            .field_projs_mut()
334                            .insert(expr.fhir_id, proj);
335                        Ok(sort)
336                    }
337                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
338                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
339                    }
340                    _ => Err(self.emit_field_not_found(&sort, fld)),
341                }
342            }
343            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
344                // if we have S then we can synthesize otherwise we fail
345                // first get the sort based on the path - for example S { ... } => S
346                // and we should expect sort to be a struct or enum app
347                let path_def_id = match path.res {
348                    fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
349                    _ => span_bug!(expr.span, "unexpected path in constructor"),
350                };
351                let sort_def = self
352                    .genv
353                    .adt_sort_def_of(path_def_id)
354                    .map_err(|e| self.emit_err(e))?;
355                // generate fresh inferred sorts for each param
356                let fresh_args: rty::List<_> = (0..sort_def.param_count())
357                    .map(|_| self.next_sort_var())
358                    .collect();
359                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
360                // check fields & spread against fresh args
361                self.check_field_exprs(
362                    expr.span,
363                    &sort_def,
364                    &fresh_args,
365                    field_exprs,
366                    &spread,
367                    &sort,
368                )?;
369                Ok(sort)
370            }
371            fhir::ExprKind::Block(decls, body) => {
372                for decl in decls {
373                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
374                }
375                self.synth_expr(body)
376            }
377            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
378        }
379    }
380
381    fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
382        self.node_sort
383            .get(&path.fhir_id)
384            .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
385            .clone()
386    }
387
388    fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
389        self.node_sort
390            .get(&expr.fhir_id)
391            .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
392            .clone()
393    }
394
395    fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
396        if matches!(op, fhir::BinOp::Mod) {
397            let sort = self
398                .fully_resolve(sort)
399                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
400            if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
401                span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
402            }
403        }
404        Ok(())
405    }
406
407    fn synth_binary_op(
408        &mut self,
409        expr: &fhir::Expr<'genv>,
410        op: fhir::BinOp,
411        e1: &fhir::Expr<'genv>,
412        e2: &fhir::Expr<'genv>,
413    ) -> Result<rty::Sort> {
414        match op {
415            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
416                self.check_expr(e1, &rty::Sort::Bool)?;
417                self.check_expr(e2, &rty::Sort::Bool)?;
418                Ok(rty::Sort::Bool)
419            }
420            fhir::BinOp::Eq | fhir::BinOp::Ne => {
421                let sort = self.next_sort_var();
422                self.check_expr(e1, &sort)?;
423                self.check_expr(e2, &sort)?;
424                Ok(rty::Sort::Bool)
425            }
426            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
427                let sort = self.next_sort_var();
428                self.check_expr(e1, &sort)?;
429                self.check_expr(e2, &sort)?;
430                self.sort_of_bin_op.insert(*expr, sort.clone());
431                Ok(rty::Sort::Bool)
432            }
433            fhir::BinOp::Add
434            | fhir::BinOp::Sub
435            | fhir::BinOp::Mul
436            | fhir::BinOp::Div
437            | fhir::BinOp::Mod => {
438                let sort = self.next_num_var();
439                self.check_expr(e1, &sort)?;
440                self.check_expr(e2, &sort)?;
441                self.sort_of_bin_op.insert(*expr, sort.clone());
442                // check that the sort is integral for mod
443                self.check_integral(op, &sort, expr.span)?;
444
445                Ok(sort)
446            }
447            fhir::BinOp::BitAnd
448            | fhir::BinOp::BitOr
449            | fhir::BinOp::BitXor
450            | fhir::BinOp::BitShl
451            | fhir::BinOp::BitShr => {
452                let sort = rty::Sort::BitVec(self.next_bv_size_var());
453                self.check_expr(e1, &sort)?;
454                self.check_expr(e2, &sort)?;
455                Ok(sort)
456            }
457        }
458    }
459
460    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
461        match op {
462            fhir::UnOp::Not => {
463                self.check_expr(e, &rty::Sort::Bool)?;
464                Ok(rty::Sort::Bool)
465            }
466            fhir::UnOp::Neg => {
467                self.check_expr(e, &rty::Sort::Int)?;
468                Ok(rty::Sort::Int)
469            }
470        }
471    }
472
473    fn synth_app(
474        &mut self,
475        fsort: FuncSort,
476        args: &[fhir::Expr<'genv>],
477        span: Span,
478    ) -> Result<rty::Sort> {
479        if args.len() != fsort.inputs().len() {
480            return Err(self.emit_err(errors::ArgCountMismatch::new(
481                Some(span),
482                String::from("function"),
483                fsort.inputs().len(),
484                args.len(),
485            )));
486        }
487
488        iter::zip(args, fsort.inputs())
489            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
490
491        Ok(fsort.output().clone())
492    }
493
494    fn instantiate_func_sort(
495        &mut self,
496        app_expr: &fhir::Expr<'genv>,
497        fsort: rty::PolyFuncSort,
498    ) -> rty::FuncSort {
499        let args = fsort
500            .params()
501            .map(|kind| {
502                match kind {
503                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
504                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
505                }
506            })
507            .collect_vec();
508        self.sort_args_of_app
509            .insert(*app_expr, List::from_slice(&args));
510        fsort.instantiate(&args)
511    }
512
513    pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
514        self.node_sort.insert(fhir_id, sort);
515    }
516
517    pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
518        self.node_sort
519            .get(&bty.fhir_id)
520            .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
521            .clone()
522    }
523
524    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
525        self.path_args.insert(fhir_id, args);
526    }
527
528    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
529        self.path_args
530            .get(&fhir_id)
531            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
532            .clone()
533    }
534
535    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
536        self.sort_of_alias_reft.insert(fhir_id, fsort);
537    }
538
539    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
540        self.sort_of_alias_reft
541            .get(&fhir_id)
542            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
543            .clone()
544    }
545
546    fn deeply_normalize_sorts<T: TypeFoldable + Clone>(
547        genv: GlobalEnv,
548        owner: FluxOwnerId,
549        t: &T,
550    ) -> QueryResult<T> {
551        let infcx = genv
552            .tcx()
553            .infer_ctxt()
554            .with_next_trait_solver(true)
555            .build(TypingMode::non_body_analysis());
556        if let Some(def_id) = owner.def_id() {
557            let def_id = genv.maybe_extern_id(def_id).resolved_id();
558            t.deeply_normalize_sorts(def_id, genv, &infcx)
559        } else {
560            Ok(t.clone())
561        }
562    }
563
564    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
565    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
566    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
567    // expanding aliases in `ConvCtxt::conv_sort`.
568    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
569        for sort in self.node_sort.values_mut() {
570            *sort = Self::deeply_normalize_sorts(self.genv, self.owner, sort)?;
571        }
572        for fsort in self.sort_of_alias_reft.values_mut() {
573            *fsort = Self::deeply_normalize_sorts(self.genv, self.owner, fsort)?;
574        }
575        Ok(())
576    }
577}
578
579impl<'genv> InferCtxt<'genv, '_> {
580    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
581        self.params.insert(param.id, (param, sort));
582    }
583
584    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
585    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
586    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
587    /// a value into a record.
588    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
589        if self.try_equate(sort1, sort2).is_some() {
590            return true;
591        }
592
593        let mut sort1 = sort1.clone();
594        let mut sort2 = sort2.clone();
595
596        let mut coercions = vec![];
597        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
598            coercions.push(rty::Coercion::Project(def_id));
599            sort1 = sort.clone();
600        }
601        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
602            coercions.push(rty::Coercion::Inject(def_id));
603            sort2 = sort.clone();
604        }
605        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
606        self.try_equate(&sort1, &sort2).is_some()
607    }
608
609    fn is_coercible_from_func(
610        &mut self,
611        sort: &rty::Sort,
612        fhir_id: FhirId,
613    ) -> Option<rty::PolyFuncSort> {
614        if let rty::Sort::Func(fsort) = sort {
615            Some(fsort.clone())
616        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
617            self.wfckresults
618                .coercions_mut()
619                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
620            Some(fsort.clone())
621        } else {
622            None
623        }
624    }
625
626    fn is_coercible_to_func(
627        &mut self,
628        sort: &rty::Sort,
629        fhir_id: FhirId,
630    ) -> Option<rty::PolyFuncSort> {
631        if let rty::Sort::Func(fsort) = sort {
632            Some(fsort.clone())
633        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
634            self.wfckresults
635                .coercions_mut()
636                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
637            Some(fsort.clone())
638        } else {
639            None
640        }
641    }
642
643    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
644        let sort1 = self.resolve_vars_if_possible(sort1);
645        let sort2 = self.resolve_vars_if_possible(sort2);
646        self.try_equate_inner(&sort1, &sort2)
647    }
648
649    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
650        match (sort1, sort2) {
651            (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
652                self.sort_unification_table
653                    .unify_var_var(*vid1, *vid2)
654                    .ok()?;
655            }
656            (rty::Sort::Infer(rty::SortVar(vid)), sort)
657            | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
658                self.sort_unification_table
659                    .unify_var_value(*vid, Some(sort.clone()))
660                    .ok()?;
661            }
662            (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
663                self.num_unification_table
664                    .unify_var_var(*vid1, *vid2)
665                    .ok()?;
666            }
667            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
668            | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
669                self.num_unification_table
670                    .unify_var_value(*vid, Some(rty::NumVarValue::Int))
671                    .ok()?;
672            }
673            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
674            | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
675                self.num_unification_table
676                    .unify_var_value(*vid, Some(rty::NumVarValue::Real))
677                    .ok()?;
678            }
679            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
680            | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
681                self.num_unification_table
682                    .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
683                    .ok()?;
684            }
685
686            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
687                if ctor1 != ctor2 || args1.len() != args2.len() {
688                    return None;
689                }
690                let mut args = vec![];
691                for (s1, s2) in args1.iter().zip(args2.iter()) {
692                    args.push(self.try_equate_inner(s1, s2)?);
693                }
694            }
695            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
696                self.try_equate_bv_sizes(*size1, *size2)?;
697            }
698            _ if sort1 == sort2 => {}
699            _ => return None,
700        }
701        Some(sort1.clone())
702    }
703
704    fn try_equate_bv_sizes(
705        &mut self,
706        size1: rty::BvSize,
707        size2: rty::BvSize,
708    ) -> Option<rty::BvSize> {
709        match (size1, size2) {
710            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
711                self.bv_size_unification_table
712                    .unify_var_var(vid1, vid2)
713                    .ok()?;
714            }
715            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
716                self.bv_size_unification_table
717                    .unify_var_value(vid, Some(size))
718                    .ok()?;
719            }
720            _ if size1 == size2 => {}
721            _ => return None,
722        }
723        Some(size1)
724    }
725
726    fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
727        self.try_equate(sort1, sort2)
728            .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
729    }
730
731    pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
732        rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
733    }
734
735    fn next_num_var(&mut self) -> rty::Sort {
736        rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
737    }
738
739    pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
740        self.sort_unification_table.new_key(None)
741    }
742
743    fn next_num_vid(&mut self) -> rty::NumVid {
744        self.num_unification_table.new_key(None)
745    }
746
747    fn next_bv_size_var(&mut self) -> rty::BvSize {
748        rty::BvSize::Infer(self.next_bv_size_vid())
749    }
750
751    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
752        self.bv_size_unification_table.new_key(None)
753    }
754
755    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
756        let sort = self.synth_path(path);
757        self.fully_resolve(&sort)
758            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
759    }
760
761    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
762        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
763            && let Some(variant) = sort_def.opt_struct_variant()
764            && let [sort] = &variant.field_sorts(sort_args)[..]
765        {
766            Some((sort_def.did(), sort.clone()))
767        } else {
768            None
769        }
770    }
771
    /// Consumes the inference context and converts it into the final [`WfckResults`].
    ///
    /// Every sort recorded during checking must be fully resolved by now. This covers literals,
    /// binary operators, sort arguments of function applications, and refinement parameters.
    /// The first entry that is not resolved emits a diagnostic and its `ErrorGuaranteed` is
    /// returned; later entries are not checked.
    ///
    /// Applications of the `Cast` builtin are also rejected when their cast kind is
    /// `Uninterpreted` and uninterpreted casts are not allowed.
    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
        // Make sure the int literals are fully resolved. This must happen before everything else
        // such that we properly apply the fallback for unconstrained num vars.
        for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
            // Fall back to `int` when a num variable is unconstrained. Note that we unconditionally
            // unify the variable. This is fine because if the variable has already been unified,
            // the operation will fail and this won't have any effect. Also note that unifying a
            // variable could solve variables that appear later in this for loop. This is also fine
            // because the order doesn't matter as we are unifying everything to `int`.
            if let rty::Sort::Infer(rty::SortInfer::NumVar(vid)) = &sort {
                let _ = self
                    .num_unification_table
                    .unify_var_value(*vid, Some(rty::NumVarValue::Int));
            }

            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
            self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
        }

        // Make sure the binary operators are fully resolved. This runs after the literal loop,
        // so num vars shared with an unconstrained literal have already defaulted to `int`.
        for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
            self.wfckresults
                .bin_op_sorts_mut()
                .insert(node.fhir_id, sort);
        }

        // The owner's inference options decide whether uninterpreted casts are allowed; owners
        // without a `DefId` fall back to the global `flux_config` setting.
        let allow_uninterpreted_cast = self
            .owner
            .def_id()
            .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
                self.genv.infer_opts(def_id).allow_uninterpreted_cast
            });

        // Make sure that function applications are fully resolved
        for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
            let mut res = vec![];
            for sort_arg in &sort_args {
                let sort_arg = match sort_arg {
                    rty::SortArg::Sort(sort) => {
                        let sort = self
                            .fully_resolve(sort)
                            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
                        rty::SortArg::Sort(sort)
                    }
                    // A bit-vector size variable must have been bound to a size by now.
                    rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
                        let size = self
                            .bv_size_unification_table
                            .probe_value(*vid)
                            .ok_or_else(|| {
                                self.emit_err(errors::CannotInferSort::new(node.span))
                            })?;
                        rty::SortArg::BvSize(size)
                    }
                    // Bit-vector sizes that are not inference variables are kept unchanged.
                    _ => sort_arg.clone(),
                };
                res.push(sort_arg);
            }
            // `Cast` applications carry exactly two sort arguments: the source and target sorts.
            // Reject uninterpreted casts unless they are explicitly allowed.
            if let fhir::ExprKind::App(callee, _) = node.kind
                && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
            {
                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
                    span_bug!(node.span, "invalid sort args!")
                };
                if !allow_uninterpreted_cast
                    && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
                {
                    return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
                }
            }
            self.wfckresults
                .fn_app_sorts_mut()
                .insert(node.fhir_id, res.into());
        }

        // Make sure all parameters are fully resolved. An unresolved parameter is reported as
        // needing an explicit sort annotation.
        for (_, (param, sort)) in std::mem::take(&mut self.params) {
            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
            self.wfckresults.param_sorts_mut().insert(param.id, sort);
        }

        Ok(self.wfckresults)
    }
861
862    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
863        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
864    }
865
866    #[track_caller]
867    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
868        self.params
869            .get(&id)
870            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
871            .1
872            .clone()
873    }
874
875    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
876        sort.fold_with(&mut ShallowResolver { infcx: self })
877    }
878
879    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
880        sort.fold_with(&mut OpportunisticResolver { infcx: self })
881    }
882
883    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
884        sort.try_fold_with(&mut FullResolver { infcx: self })
885    }
886}
887
888/// Before the main sort inference, we do a first traversal checking all implicitly scoped parameters
889/// declared with `@` or `#` and infer their sort based on the type they are indexing, e.g., if `n` was
890/// declared as `i32[@n]`, we infer `int` for `n`.
891///
892/// This prepass is necessary because sometimes the order in which we traverse expressions can
893/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
894/// complete inference regardless of how expressions are later traversed.
895///
896/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
897/// function position) such that we don't need this anymore.
898pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
899    infcx: &'a mut InferCtxt<'genv, 'tcx>,
900    errors: Errors<'genv>,
901}
902
903impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
904    pub(crate) fn infer(
905        infcx: &'a mut InferCtxt<'genv, 'tcx>,
906        node: &fhir::OwnerNode<'genv>,
907    ) -> Result {
908        let errors = Errors::new(infcx.genv.sess());
909        let mut vis = Self { infcx, errors };
910        vis.visit_node(node);
911        vis.errors.to_result()
912    }
913
914    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
915        match idx.kind {
916            fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
917                let (_, id) = var.res.expect_param();
918                let found = self.infcx.param_sort(id);
919                self.infcx.equate(&found, expected);
920            }
921            fhir::ExprKind::Record(flds) => {
922                if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
923                    let sorts = sort_def.struct_variant().field_sorts(sort_args);
924                    if flds.len() != sorts.len() {
925                        self.errors.emit(errors::ArgCountMismatch::new(
926                            Some(idx.span),
927                            String::from("type"),
928                            sorts.len(),
929                            flds.len(),
930                        ));
931                        return;
932                    }
933                    for (f, sort) in iter::zip(flds, &sorts) {
934                        self.infer_implicit_params(f, sort);
935                    }
936                } else {
937                    self.errors.emit(errors::ArgCountMismatch::new(
938                        Some(idx.span),
939                        String::from("type"),
940                        1,
941                        flds.len(),
942                    ));
943                }
944            }
945            _ => {}
946        }
947    }
948}
949
950impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
951    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
952        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
953            let expected = self.infcx.sort_of_bty(bty);
954            self.infer_implicit_params(idx, &expected);
955        }
956        fhir::visit::walk_ty(self, ty);
957    }
958}
959
960impl InferCtxt<'_, '_> {
961    #[track_caller]
962    fn emit_sort_mismatch(
963        &mut self,
964        span: Span,
965        expected: &rty::Sort,
966        found: &rty::Sort,
967    ) -> ErrorGuaranteed {
968        let expected = self.resolve_vars_if_possible(expected);
969        let found = self.resolve_vars_if_possible(found);
970        self.emit_err(errors::SortMismatch::new(span, expected, found))
971    }
972
973    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
974        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
975    }
976
977    #[track_caller]
978    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
979        self.genv.sess().emit_err(err)
980    }
981}
982
983struct ShallowResolver<'a, 'genv, 'tcx> {
984    infcx: &'a mut InferCtxt<'genv, 'tcx>,
985}
986
987impl TypeFolder for ShallowResolver<'_, '_, '_> {
988    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
989        match sort {
990            rty::Sort::Infer(rty::SortVar(vid)) => {
991                // if `sort` is a sort variable, it can be resolved to an num/bit-vec variable,
992                // which can then be recursively resolved, hence the recursion. Note though that
993                // we prevent sort variables from unifying to other sort variables directly (though
994                // they may be embedded structurally), so this recursion should always be of very
995                // limited depth.
996                self.infcx
997                    .sort_unification_table
998                    .probe_value(*vid)
999                    .map(|sort| sort.fold_with(self))
1000                    .unwrap_or(sort.clone())
1001            }
1002            rty::Sort::Infer(rty::NumVar(vid)) => {
1003                // same here, a num var could had been unified with a bitvector
1004                self.infcx
1005                    .num_unification_table
1006                    .probe_value(*vid)
1007                    .map(|val| val.to_sort().fold_with(self))
1008                    .unwrap_or(sort.clone())
1009            }
1010            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1011                self.infcx
1012                    .bv_size_unification_table
1013                    .probe_value(*vid)
1014                    .map(rty::Sort::BitVec)
1015                    .unwrap_or(sort.clone())
1016            }
1017            _ => sort.clone(),
1018        }
1019    }
1020}
1021
1022struct OpportunisticResolver<'a, 'genv, 'tcx> {
1023    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1024}
1025
1026impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1027    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1028        let s = self.infcx.shallow_resolve(sort);
1029        s.super_fold_with(self)
1030    }
1031}
1032
1033struct FullResolver<'a, 'genv, 'tcx> {
1034    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1035}
1036
1037impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1038    type Error = ();
1039
1040    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1041        let s = self.infcx.shallow_resolve(sort);
1042        match s {
1043            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1044            _ => s.try_super_fold_with(self),
1045        }
1046    }
1047}
1048
1049/// Map to associate data to a node (i.e., an expression).
1050///
1051/// Used to record elaborated information.
1052#[derive_where(Default)]
1053struct NodeMap<'genv, T> {
1054    map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
1055}
1056
1057impl<'genv, T> NodeMap<'genv, T> {
1058    /// Add a `node` to the map with associated `data`
1059    fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1060        assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1061    }
1062}
1063
1064impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1065    type Item = (fhir::Expr<'genv>, T);
1066
1067    type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1068
1069    fn into_iter(self) -> Self::IntoIter {
1070        self.map.into_values()
1071    }
1072}