flux_fhir_analysis/wf/sortck.rs

1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9    fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10    global_env::GlobalEnv,
11    queries::QueryResult,
12    rty::{
13        self, AdtSortDef, FuncSort, List, WfckResults,
14        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15    },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
28type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
30pub(super) struct InferCtxt<'genv, 'tcx> {
31    pub genv: GlobalEnv<'genv, 'tcx>,
32    pub owner: FluxOwnerId,
33    pub wfckresults: WfckResults,
34    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
35    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
36    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
37    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
38    node_sort: FxHashMap<FhirId, rty::Sort>,
39    path_args: UnordMap<FhirId, rty::GenericArgs>,
40    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
41    sort_of_literal: NodeMap<'genv, rty::Sort>,
42    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
43    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
44}
45
46pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
47    match op {
48        fhir::BinOp::BitAnd
49        | fhir::BinOp::BitOr
50        | fhir::BinOp::BitXor
51        | fhir::BinOp::BitShl
52        | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
53        _ => None,
54    }
55}
56
57impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
58    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
59        // We skip 0 because that's used for sort dummy self types during conv.
60        let mut sort_unification_table = InPlaceUnificationTable::new();
61        sort_unification_table.new_key(None);
62        Self {
63            genv,
64            owner,
65            wfckresults: WfckResults::new(owner),
66            sort_unification_table,
67            num_unification_table: InPlaceUnificationTable::new(),
68            bv_size_unification_table: InPlaceUnificationTable::new(),
69            params: Default::default(),
70            node_sort: Default::default(),
71            path_args: Default::default(),
72            sort_of_alias_reft: Default::default(),
73            sort_of_literal: Default::default(),
74            sort_of_bin_op: Default::default(),
75            sort_args_of_app: Default::default(),
76        }
77    }
78
79    fn check_abs(
80        &mut self,
81        expr: &fhir::Expr<'genv>,
82        params: &[fhir::RefineParam],
83        body: &fhir::Expr<'genv>,
84        expected: &rty::Sort,
85    ) -> Result {
86        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
87            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
88        };
89
90        let fsort = fsort.expect_mono();
91
92        if params.len() != fsort.inputs().len() {
93            return Err(self.emit_err(errors::ParamCountMismatch::new(
94                expr.span,
95                fsort.inputs().len(),
96                params.len(),
97            )));
98        }
99        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
100            let found = self.param_sort(param.id);
101            if self.try_equate(&found, expected).is_none() {
102                return Err(self.emit_sort_mismatch(param.span, expected, &found));
103            }
104            Ok(())
105        })?;
106        self.check_expr(body, fsort.output())?;
107        self.wfckresults
108            .node_sorts_mut()
109            .insert(body.fhir_id, fsort.output().clone());
110        Ok(())
111    }
112
113    fn check_field_exprs(
114        &mut self,
115        expr_span: Span,
116        sort_def: &AdtSortDef,
117        sort_args: &[rty::Sort],
118        field_exprs: &[fhir::FieldExpr<'genv>],
119        spread: &Option<&fhir::Spread<'genv>>,
120        expected: &rty::Sort,
121    ) -> Result {
122        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
123        let mut used_fields = FxHashMap::default();
124        for expr in field_exprs {
125            // make sure that the field is actually a field
126            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
127                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
128            };
129            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
130                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
131            }
132            self.check_expr(&expr.expr, sort)?;
133        }
134        if let Some(spread) = spread {
135            // must check that the spread is of the same sort as the constructor
136            self.check_expr(&spread.expr, expected)
137        } else if sort_by_field_name.len() != used_fields.len() {
138            // emit an error because all fields are not used
139            let missing_fields = sort_by_field_name
140                .into_keys()
141                .filter(|x| !used_fields.contains_key(x))
142                .collect();
143            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
144        } else {
145            Ok(())
146        }
147    }
148
149    fn check_constructor(
150        &mut self,
151        expr: &fhir::Expr,
152        field_exprs: &[fhir::FieldExpr<'genv>],
153        spread: &Option<&fhir::Spread<'genv>>,
154        expected: &rty::Sort,
155    ) -> Result {
156        let expected = self.resolve_vars_if_possible(expected);
157        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
158            self.wfckresults
159                .record_ctors_mut()
160                .insert(expr.fhir_id, sort_def.did());
161            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
162            Ok(())
163        } else {
164            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
165        }
166    }
167
168    fn check_record(
169        &mut self,
170        arg: &fhir::Expr<'genv>,
171        flds: &[fhir::Expr<'genv>],
172        expected: &rty::Sort,
173    ) -> Result {
174        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
175            let sorts = sort_def.struct_variant().field_sorts(sort_args);
176            if flds.len() != sorts.len() {
177                return Err(self.emit_err(errors::ArgCountMismatch::new(
178                    Some(arg.span),
179                    String::from("type"),
180                    sorts.len(),
181                    flds.len(),
182                )));
183            }
184            self.wfckresults
185                .record_ctors_mut()
186                .insert(arg.fhir_id, sort_def.did());
187
188            izip!(flds, &sorts)
189                .map(|(arg, expected)| self.check_expr(arg, expected))
190                .try_collect_exhaust()
191        } else {
192            Err(self.emit_err(errors::ArgCountMismatch::new(
193                Some(arg.span),
194                String::from("type"),
195                1,
196                flds.len(),
197            )))
198        }
199    }
200
201    pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
202        match &expr.kind {
203            fhir::ExprKind::IfThenElse(p, e1, e2) => {
204                self.check_expr(p, &rty::Sort::Bool)?;
205                self.check_expr(e1, expected)?;
206                self.check_expr(e2, expected)?;
207            }
208            fhir::ExprKind::Abs(params, body) => {
209                self.check_abs(expr, params, body, expected)?;
210            }
211            fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
212            fhir::ExprKind::Constructor(None, exprs, spread) => {
213                self.check_constructor(expr, exprs, spread, expected)?;
214            }
215            fhir::ExprKind::UnaryOp(..)
216            | fhir::ExprKind::BinaryOp(..)
217            | fhir::ExprKind::Dot(..)
218            | fhir::ExprKind::App(..)
219            | fhir::ExprKind::Alias(..)
220            | fhir::ExprKind::Var(..)
221            | fhir::ExprKind::Literal(..)
222            | fhir::ExprKind::BoundedQuant(..)
223            | fhir::ExprKind::Block(..)
224            | fhir::ExprKind::Constructor(..)
225            | fhir::ExprKind::PrimApp(..) => {
226                let found = self.synth_expr(expr)?;
227                let found = self.resolve_vars_if_possible(&found);
228                let expected = self.resolve_vars_if_possible(expected);
229                if !self.is_coercible(&found, &expected, expr.fhir_id) {
230                    return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
231                }
232            }
233            fhir::ExprKind::Err(_) => {
234                // an error has already been reported so we can just skip
235            }
236        }
237        Ok(())
238    }
239
240    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
241        let found = self.synth_path(loc);
242        if found == rty::Sort::Loc {
243            Ok(())
244        } else {
245            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
246        }
247    }
248
249    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
250        match lit {
251            fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
252            fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
253            fhir::Lit::Int(_, None) => {
254                let sort = self.next_num_var();
255                self.sort_of_literal.insert(*expr, sort.clone());
256                sort
257            }
258            fhir::Lit::Bool(_) => rty::Sort::Bool,
259            fhir::Lit::Str(_) => rty::Sort::Str,
260            fhir::Lit::Char(_) => rty::Sort::Char,
261        }
262    }
263
264    fn synth_prim_app(
265        &mut self,
266        op: &fhir::BinOp,
267        e1: &fhir::Expr<'genv>,
268        e2: &fhir::Expr<'genv>,
269        span: Span,
270    ) -> Result<rty::Sort> {
271        let Some((inputs, output)) = prim_op_sort(op) else {
272            return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
273        };
274        let [sort1, sort2] = &inputs[..] else {
275            return Err(self.emit_err(errors::ArgCountMismatch::new(
276                Some(span),
277                String::from("primop app"),
278                inputs.len(),
279                2,
280            )));
281        };
282        self.check_expr(e1, sort1)?;
283        self.check_expr(e2, sort2)?;
284        Ok(output)
285    }
286
287    fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
288        match expr.kind {
289            fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
290            fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
291                Ok(self.synth_type_relative_path(expr))
292            }
293            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
294            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
295            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
296            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
297            fhir::ExprKind::App(callee, args) => {
298                let sort = self.ensure_resolved_path(&callee)?;
299                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
300                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
301                };
302                let fsort = self.instantiate_func_sort(expr, poly_fsort);
303                self.synth_app(fsort, args, expr.span)
304            }
305            fhir::ExprKind::BoundedQuant(.., body) => {
306                self.check_expr(body, &rty::Sort::Bool)?;
307                Ok(rty::Sort::Bool)
308            }
309            fhir::ExprKind::Alias(_alias_reft, args) => {
310                // To check the application we only need the sort of `_alias_reft` which we collected
311                // during early conv, but should we do any extra checks on `_alias_reft`?
312                let fsort = self.sort_of_alias_reft(expr.fhir_id);
313                self.synth_app(fsort, args, expr.span)
314            }
315            fhir::ExprKind::IfThenElse(p, e1, e2) => {
316                self.check_expr(p, &rty::Sort::Bool)?;
317                let sort = self.synth_expr(e1)?;
318                self.check_expr(e2, &sort)?;
319                Ok(sort)
320            }
321            fhir::ExprKind::Dot(base, fld) => {
322                let sort = self.synth_expr(base)?;
323                let sort = self
324                    .fully_resolve(&sort)
325                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
326                match &sort {
327                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
328                        let (proj, sort) = sort_def
329                            .struct_variant()
330                            .field_by_name(sort_def.did(), sort_args, fld.name)
331                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
332                        self.wfckresults
333                            .field_projs_mut()
334                            .insert(expr.fhir_id, proj);
335                        Ok(sort)
336                    }
337                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
338                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
339                    }
340                    _ => Err(self.emit_field_not_found(&sort, fld)),
341                }
342            }
343            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
344                // if we have S then we can synthesize otherwise we fail
345                // first get the sort based on the path - for example S { ... } => S
346                // and we should expect sort to be a struct or enum app
347                let path_def_id = match path.res {
348                    fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
349                    _ => span_bug!(expr.span, "unexpected path in constructor"),
350                };
351                let sort_def = self
352                    .genv
353                    .adt_sort_def_of(path_def_id)
354                    .map_err(|e| self.emit_err(e))?;
355                // generate fresh inferred sorts for each param
356                let fresh_args: rty::List<_> = (0..sort_def.param_count())
357                    .map(|_| self.next_sort_var())
358                    .collect();
359                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
360                // check fields & spread against fresh args
361                self.check_field_exprs(
362                    expr.span,
363                    &sort_def,
364                    &fresh_args,
365                    field_exprs,
366                    &spread,
367                    &sort,
368                )?;
369                Ok(sort)
370            }
371            fhir::ExprKind::Block(decls, body) => {
372                for decl in decls {
373                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
374                }
375                self.synth_expr(body)
376            }
377            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
378        }
379    }
380
381    fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
382        self.node_sort
383            .get(&path.fhir_id)
384            .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
385            .clone()
386    }
387
388    fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
389        self.node_sort
390            .get(&expr.fhir_id)
391            .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
392            .clone()
393    }
394
395    fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
396        if matches!(op, fhir::BinOp::Mod) {
397            let sort = self
398                .fully_resolve(sort)
399                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
400            if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
401                span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
402            }
403        }
404        Ok(())
405    }
406
407    fn synth_binary_op(
408        &mut self,
409        expr: &fhir::Expr<'genv>,
410        op: fhir::BinOp,
411        e1: &fhir::Expr<'genv>,
412        e2: &fhir::Expr<'genv>,
413    ) -> Result<rty::Sort> {
414        match op {
415            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
416                self.check_expr(e1, &rty::Sort::Bool)?;
417                self.check_expr(e2, &rty::Sort::Bool)?;
418                Ok(rty::Sort::Bool)
419            }
420            fhir::BinOp::Eq | fhir::BinOp::Ne => {
421                let sort = self.next_sort_var();
422                self.check_expr(e1, &sort)?;
423                self.check_expr(e2, &sort)?;
424                Ok(rty::Sort::Bool)
425            }
426            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
427                let sort = self.next_sort_var();
428                self.check_expr(e1, &sort)?;
429                self.check_expr(e2, &sort)?;
430                self.sort_of_bin_op.insert(*expr, sort.clone());
431                Ok(rty::Sort::Bool)
432            }
433            fhir::BinOp::Add
434            | fhir::BinOp::Sub
435            | fhir::BinOp::Mul
436            | fhir::BinOp::Div
437            | fhir::BinOp::Mod => {
438                let sort = self.next_num_var();
439                self.check_expr(e1, &sort)?;
440                self.check_expr(e2, &sort)?;
441                self.sort_of_bin_op.insert(*expr, sort.clone());
442                // check that the sort is integral for mod
443                self.check_integral(op, &sort, expr.span)?;
444
445                Ok(sort)
446            }
447            fhir::BinOp::BitAnd
448            | fhir::BinOp::BitOr
449            | fhir::BinOp::BitXor
450            | fhir::BinOp::BitShl
451            | fhir::BinOp::BitShr => {
452                let sort = rty::Sort::BitVec(self.next_bv_size_var());
453                self.check_expr(e1, &sort)?;
454                self.check_expr(e2, &sort)?;
455                Ok(sort)
456            }
457        }
458    }
459
460    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
461        match op {
462            fhir::UnOp::Not => {
463                self.check_expr(e, &rty::Sort::Bool)?;
464                Ok(rty::Sort::Bool)
465            }
466            fhir::UnOp::Neg => {
467                self.check_expr(e, &rty::Sort::Int)?;
468                Ok(rty::Sort::Int)
469            }
470        }
471    }
472
473    fn synth_app(
474        &mut self,
475        fsort: FuncSort,
476        args: &[fhir::Expr<'genv>],
477        span: Span,
478    ) -> Result<rty::Sort> {
479        if args.len() != fsort.inputs().len() {
480            return Err(self.emit_err(errors::ArgCountMismatch::new(
481                Some(span),
482                String::from("function"),
483                fsort.inputs().len(),
484                args.len(),
485            )));
486        }
487
488        iter::zip(args, fsort.inputs())
489            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
490
491        Ok(fsort.output().clone())
492    }
493
494    fn instantiate_func_sort(
495        &mut self,
496        app_expr: &fhir::Expr<'genv>,
497        fsort: rty::PolyFuncSort,
498    ) -> rty::FuncSort {
499        let args = fsort
500            .params()
501            .map(|kind| {
502                match kind {
503                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
504                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
505                }
506            })
507            .collect_vec();
508        self.sort_args_of_app
509            .insert(*app_expr, List::from_slice(&args));
510        fsort.instantiate(&args)
511    }
512
513    pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
514        self.node_sort.insert(fhir_id, sort);
515    }
516
517    pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
518        self.node_sort
519            .get(&bty.fhir_id)
520            .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
521            .clone()
522    }
523
524    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
525        self.path_args.insert(fhir_id, args);
526    }
527
528    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
529        self.path_args
530            .get(&fhir_id)
531            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
532            .clone()
533    }
534
535    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
536        self.sort_of_alias_reft.insert(fhir_id, fsort);
537    }
538
539    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
540        self.sort_of_alias_reft
541            .get(&fhir_id)
542            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
543            .clone()
544    }
545
546    fn normalize_projection_sort(
547        genv: GlobalEnv,
548        owner: FluxOwnerId,
549        sort: rty::Sort,
550    ) -> rty::Sort {
551        let infcx = genv
552            .tcx()
553            .infer_ctxt()
554            .with_next_trait_solver(true)
555            .build(TypingMode::non_body_analysis());
556        if let Some(def_id) = owner.def_id()
557            && let def_id = genv.maybe_extern_id(def_id).resolved_id()
558            && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
559        {
560            sort
561        } else {
562            sort
563        }
564    }
565
566    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
567    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
568    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
569    // expanding aliases in `ConvCtxt::conv_sort`.
570    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
571        let genv = self.genv;
572        for sort in self.node_sort.values_mut() {
573            *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
574        }
575        for fsort in self.sort_of_alias_reft.values_mut() {
576            *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
577        }
578        Ok(())
579    }
580}
581
582impl<'genv> InferCtxt<'genv, '_> {
583    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
584        self.params.insert(param.id, (param, sort));
585    }
586
587    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
588    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
589    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
590    /// a value into a record.
591    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
592        if self.try_equate(sort1, sort2).is_some() {
593            return true;
594        }
595
596        let mut sort1 = sort1.clone();
597        let mut sort2 = sort2.clone();
598
599        let mut coercions = vec![];
600        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
601            coercions.push(rty::Coercion::Project(def_id));
602            sort1 = sort.clone();
603        }
604        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
605            coercions.push(rty::Coercion::Inject(def_id));
606            sort2 = sort.clone();
607        }
608        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
609        self.try_equate(&sort1, &sort2).is_some()
610    }
611
612    fn is_coercible_from_func(
613        &mut self,
614        sort: &rty::Sort,
615        fhir_id: FhirId,
616    ) -> Option<rty::PolyFuncSort> {
617        if let rty::Sort::Func(fsort) = sort {
618            Some(fsort.clone())
619        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
620            self.wfckresults
621                .coercions_mut()
622                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
623            Some(fsort.clone())
624        } else {
625            None
626        }
627    }
628
629    fn is_coercible_to_func(
630        &mut self,
631        sort: &rty::Sort,
632        fhir_id: FhirId,
633    ) -> Option<rty::PolyFuncSort> {
634        if let rty::Sort::Func(fsort) = sort {
635            Some(fsort.clone())
636        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
637            self.wfckresults
638                .coercions_mut()
639                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
640            Some(fsort.clone())
641        } else {
642            None
643        }
644    }
645
646    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
647        let sort1 = self.resolve_vars_if_possible(sort1);
648        let sort2 = self.resolve_vars_if_possible(sort2);
649        self.try_equate_inner(&sort1, &sort2)
650    }
651
652    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
653        match (sort1, sort2) {
654            (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
655                self.sort_unification_table
656                    .unify_var_var(*vid1, *vid2)
657                    .ok()?;
658            }
659            (rty::Sort::Infer(rty::SortVar(vid)), sort)
660            | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
661                self.sort_unification_table
662                    .unify_var_value(*vid, Some(sort.clone()))
663                    .ok()?;
664            }
665            (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
666                self.num_unification_table
667                    .unify_var_var(*vid1, *vid2)
668                    .ok()?;
669            }
670            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
671            | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
672                self.num_unification_table
673                    .unify_var_value(*vid, Some(rty::NumVarValue::Int))
674                    .ok()?;
675            }
676            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
677            | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
678                self.num_unification_table
679                    .unify_var_value(*vid, Some(rty::NumVarValue::Real))
680                    .ok()?;
681            }
682            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
683            | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
684                self.num_unification_table
685                    .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
686                    .ok()?;
687            }
688
689            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
690                if ctor1 != ctor2 || args1.len() != args2.len() {
691                    return None;
692                }
693                let mut args = vec![];
694                for (s1, s2) in args1.iter().zip(args2.iter()) {
695                    args.push(self.try_equate_inner(s1, s2)?);
696                }
697            }
698            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
699                self.try_equate_bv_sizes(*size1, *size2)?;
700            }
701            _ if sort1 == sort2 => {}
702            _ => return None,
703        }
704        Some(sort1.clone())
705    }
706
707    fn try_equate_bv_sizes(
708        &mut self,
709        size1: rty::BvSize,
710        size2: rty::BvSize,
711    ) -> Option<rty::BvSize> {
712        match (size1, size2) {
713            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
714                self.bv_size_unification_table
715                    .unify_var_var(vid1, vid2)
716                    .ok()?;
717            }
718            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
719                self.bv_size_unification_table
720                    .unify_var_value(vid, Some(size))
721                    .ok()?;
722            }
723            _ if size1 == size2 => {}
724            _ => return None,
725        }
726        Some(size1)
727    }
728
729    fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
730        self.try_equate(sort1, sort2)
731            .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
732    }
733
734    pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
735        rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
736    }
737
738    fn next_num_var(&mut self) -> rty::Sort {
739        rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
740    }
741
742    pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
743        self.sort_unification_table.new_key(None)
744    }
745
746    fn next_num_vid(&mut self) -> rty::NumVid {
747        self.num_unification_table.new_key(None)
748    }
749
750    fn next_bv_size_var(&mut self) -> rty::BvSize {
751        rty::BvSize::Infer(self.next_bv_size_vid())
752    }
753
754    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
755        self.bv_size_unification_table.new_key(None)
756    }
757
758    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
759        let sort = self.synth_path(path);
760        self.fully_resolve(&sort)
761            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
762    }
763
764    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
765        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
766            && let Some(variant) = sort_def.opt_struct_variant()
767            && let [sort] = &variant.field_sorts(sort_args)[..]
768        {
769            Some((sort_def.did(), sort.clone()))
770        } else {
771            None
772        }
773    }
774
775    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
776        // Make sure the int literals are fully resolved. This must happen before everything else
777        // such that we properly apply the fallback for unconstrained num vars.
778        for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
779            // Fallback to `int` when a num variable is unconstrained. Note that we unconditionally
780            // unify the variable. This is fine because if the variable has already been unified,
781            // the operation will fail and this won't have any effect. Also note that unifying a
782            // variable could solve variables that appear later in this for loop. This is also fine
783            // because the order doesn't matter as we are unifying everything to `int`.
784            if let rty::Sort::Infer(rty::SortInfer::NumVar(vid)) = &sort {
785                let _ = self
786                    .num_unification_table
787                    .unify_var_value(*vid, Some(rty::NumVarValue::Int));
788            }
789
790            let sort = self
791                .fully_resolve(&sort)
792                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
793            self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
794        }
795
796        // Make sure the binary operators are fully resolved
797        for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
798            let sort = self
799                .fully_resolve(&sort)
800                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
801            self.wfckresults
802                .bin_op_sorts_mut()
803                .insert(node.fhir_id, sort);
804        }
805
806        let allow_uninterpreted_cast = self
807            .owner
808            .def_id()
809            .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
810                self.genv.infer_opts(def_id).allow_uninterpreted_cast
811            });
812
813        // Make sure that function applications are fully resolved
814        for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
815            let mut res = vec![];
816            for sort_arg in &sort_args {
817                let sort_arg = match sort_arg {
818                    rty::SortArg::Sort(sort) => {
819                        let sort = self
820                            .fully_resolve(sort)
821                            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
822                        rty::SortArg::Sort(sort)
823                    }
824                    rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
825                        let size = self
826                            .bv_size_unification_table
827                            .probe_value(*vid)
828                            .ok_or_else(|| {
829                                self.emit_err(errors::CannotInferSort::new(node.span))
830                            })?;
831                        rty::SortArg::BvSize(size)
832                    }
833                    _ => sort_arg.clone(),
834                };
835                res.push(sort_arg);
836            }
837            if let fhir::ExprKind::App(callee, _) = node.kind
838                && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
839            {
840                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
841                    span_bug!(node.span, "invalid sort args!")
842                };
843                if !allow_uninterpreted_cast
844                    && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
845                {
846                    return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
847                }
848            }
849            self.wfckresults
850                .fn_app_sorts_mut()
851                .insert(node.fhir_id, res.into());
852        }
853
854        // Make sure all parameters are fully resolved
855        for (_, (param, sort)) in std::mem::take(&mut self.params) {
856            let sort = self
857                .fully_resolve(&sort)
858                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
859            self.wfckresults.param_sorts_mut().insert(param.id, sort);
860        }
861
862        Ok(self.wfckresults)
863    }
864
865    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
866        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
867    }
868
869    #[track_caller]
870    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
871        self.params
872            .get(&id)
873            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
874            .1
875            .clone()
876    }
877
878    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
879        sort.fold_with(&mut ShallowResolver { infcx: self })
880    }
881
882    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
883        sort.fold_with(&mut OpportunisticResolver { infcx: self })
884    }
885
886    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
887        sort.try_fold_with(&mut FullResolver { infcx: self })
888    }
889}
890
891/// Before the main sort inference, we do a first traversal checking all implicitly scoped parameters
892/// declared with `@` or `#` and infer their sort based on the type they are indexing, e.g., if `n` was
893/// declared as `i32[@n]`, we infer `int` for `n`.
894///
895/// This prepass is necessary because sometimes the order in which we traverse expressions can
896/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
897/// complete inference regardless of how expressions are later traversed.
898///
899/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
900/// function position) such that we don't need this anymore.
901pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
902    infcx: &'a mut InferCtxt<'genv, 'tcx>,
903    errors: Errors<'genv>,
904}
905
906impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
907    pub(crate) fn infer(
908        infcx: &'a mut InferCtxt<'genv, 'tcx>,
909        node: &fhir::OwnerNode<'genv>,
910    ) -> Result {
911        let errors = Errors::new(infcx.genv.sess());
912        let mut vis = Self { infcx, errors };
913        vis.visit_node(node);
914        vis.errors.into_result()
915    }
916
917    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
918        match idx.kind {
919            fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
920                let (_, id) = var.res.expect_param();
921                let found = self.infcx.param_sort(id);
922                self.infcx.equate(&found, expected);
923            }
924            fhir::ExprKind::Record(flds) => {
925                if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
926                    let sorts = sort_def.struct_variant().field_sorts(sort_args);
927                    if flds.len() != sorts.len() {
928                        self.errors.emit(errors::ArgCountMismatch::new(
929                            Some(idx.span),
930                            String::from("type"),
931                            sorts.len(),
932                            flds.len(),
933                        ));
934                        return;
935                    }
936                    for (f, sort) in iter::zip(flds, &sorts) {
937                        self.infer_implicit_params(f, sort);
938                    }
939                } else {
940                    self.errors.emit(errors::ArgCountMismatch::new(
941                        Some(idx.span),
942                        String::from("type"),
943                        1,
944                        flds.len(),
945                    ));
946                }
947            }
948            _ => {}
949        }
950    }
951}
952
953impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
954    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
955        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
956            let expected = self.infcx.sort_of_bty(bty);
957            self.infer_implicit_params(idx, &expected);
958        }
959        fhir::visit::walk_ty(self, ty);
960    }
961}
962
963impl InferCtxt<'_, '_> {
964    #[track_caller]
965    fn emit_sort_mismatch(
966        &mut self,
967        span: Span,
968        expected: &rty::Sort,
969        found: &rty::Sort,
970    ) -> ErrorGuaranteed {
971        let expected = self.resolve_vars_if_possible(expected);
972        let found = self.resolve_vars_if_possible(found);
973        self.emit_err(errors::SortMismatch::new(span, expected, found))
974    }
975
976    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
977        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
978    }
979
980    #[track_caller]
981    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
982        self.genv.sess().emit_err(err)
983    }
984}
985
986struct ShallowResolver<'a, 'genv, 'tcx> {
987    infcx: &'a mut InferCtxt<'genv, 'tcx>,
988}
989
990impl TypeFolder for ShallowResolver<'_, '_, '_> {
991    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
992        match sort {
993            rty::Sort::Infer(rty::SortVar(vid)) => {
994                // if `sort` is a sort variable, it can be resolved to an num/bit-vec variable,
995                // which can then be recursively resolved, hence the recursion. Note though that
996                // we prevent sort variables from unifying to other sort variables directly (though
997                // they may be embedded structurally), so this recursion should always be of very
998                // limited depth.
999                self.infcx
1000                    .sort_unification_table
1001                    .probe_value(*vid)
1002                    .map(|sort| sort.fold_with(self))
1003                    .unwrap_or(sort.clone())
1004            }
1005            rty::Sort::Infer(rty::NumVar(vid)) => {
1006                // same here, a num var could had been unified with a bitvector
1007                self.infcx
1008                    .num_unification_table
1009                    .probe_value(*vid)
1010                    .map(|val| val.to_sort().fold_with(self))
1011                    .unwrap_or(sort.clone())
1012            }
1013            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1014                self.infcx
1015                    .bv_size_unification_table
1016                    .probe_value(*vid)
1017                    .map(rty::Sort::BitVec)
1018                    .unwrap_or(sort.clone())
1019            }
1020            _ => sort.clone(),
1021        }
1022    }
1023}
1024
1025struct OpportunisticResolver<'a, 'genv, 'tcx> {
1026    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1027}
1028
1029impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1030    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1031        let s = self.infcx.shallow_resolve(sort);
1032        s.super_fold_with(self)
1033    }
1034}
1035
1036struct FullResolver<'a, 'genv, 'tcx> {
1037    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1038}
1039
1040impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1041    type Error = ();
1042
1043    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1044        let s = self.infcx.shallow_resolve(sort);
1045        match s {
1046            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1047            _ => s.try_super_fold_with(self),
1048        }
1049    }
1050}
1051
1052/// Map to associate data to a node (i.e., an expression).
1053///
1054/// Used to record elaborated information.
1055#[derive_where(Default)]
1056struct NodeMap<'genv, T> {
1057    map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
1058}
1059
1060impl<'genv, T> NodeMap<'genv, T> {
1061    /// Add a `node` to the map with associated `data`
1062    fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1063        assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1064    }
1065}
1066
1067impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1068    type Item = (fhir::Expr<'genv>, T);
1069
1070    type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1071
1072    fn into_iter(self) -> Self::IntoIter {
1073        self.map.into_values()
1074    }
1075}