flux_fhir_analysis/wf/sortck.rs

1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9    fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10    global_env::GlobalEnv,
11    queries::QueryResult,
12    rty::{
13        self, AdtSortDef, FuncSort, List, WfckResults,
14        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15    },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
28type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
30pub(super) struct InferCtxt<'genv, 'tcx> {
31    pub genv: GlobalEnv<'genv, 'tcx>,
32    pub owner: FluxOwnerId,
33    pub wfckresults: WfckResults,
34    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
35    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
36    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
37    node_sort: FxHashMap<FhirId, rty::Sort>,
38    path_args: UnordMap<FhirId, rty::GenericArgs>,
39    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
40    sort_of_literal: NodeMap<'genv, rty::Sort>,
41    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
42    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
43}
44
45pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
46    match op {
47        fhir::BinOp::BitAnd
48        | fhir::BinOp::BitOr
49        | fhir::BinOp::BitXor
50        | fhir::BinOp::BitShl
51        | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
52        _ => None,
53    }
54}
55
56impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
57    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
58        // We skip 0 because that's used for sort dummy self types during conv.
59        let mut sort_unification_table = InPlaceUnificationTable::new();
60        sort_unification_table.new_key(Default::default());
61        Self {
62            genv,
63            owner,
64            wfckresults: WfckResults::new(owner),
65            sort_unification_table,
66            bv_size_unification_table: InPlaceUnificationTable::new(),
67            params: Default::default(),
68            node_sort: Default::default(),
69            path_args: Default::default(),
70            sort_of_alias_reft: Default::default(),
71            sort_of_literal: Default::default(),
72            sort_of_bin_op: Default::default(),
73            sort_args_of_app: Default::default(),
74        }
75    }
76
77    fn check_abs(
78        &mut self,
79        expr: &fhir::Expr<'genv>,
80        params: &[fhir::RefineParam],
81        body: &fhir::Expr<'genv>,
82        expected: &rty::Sort,
83    ) -> Result {
84        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
85            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
86        };
87
88        let fsort = fsort.expect_mono();
89
90        if params.len() != fsort.inputs().len() {
91            return Err(self.emit_err(errors::ParamCountMismatch::new(
92                expr.span,
93                fsort.inputs().len(),
94                params.len(),
95            )));
96        }
97        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
98            let found = self.param_sort(param.id);
99            if self.try_equate(&found, expected).is_none() {
100                return Err(self.emit_sort_mismatch(param.span, expected, &found));
101            }
102            Ok(())
103        })?;
104        self.check_expr(body, fsort.output())?;
105        self.wfckresults
106            .node_sorts_mut()
107            .insert(body.fhir_id, fsort.output().clone());
108        Ok(())
109    }
110
111    fn check_field_exprs(
112        &mut self,
113        expr_span: Span,
114        sort_def: &AdtSortDef,
115        sort_args: &[rty::Sort],
116        field_exprs: &[fhir::FieldExpr<'genv>],
117        spread: &Option<&fhir::Spread<'genv>>,
118        expected: &rty::Sort,
119    ) -> Result {
120        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
121        let mut used_fields = FxHashMap::default();
122        for expr in field_exprs {
123            // make sure that the field is actually a field
124            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
125                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
126            };
127            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
128                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
129            }
130            self.check_expr(&expr.expr, sort)?;
131        }
132        if let Some(spread) = spread {
133            // must check that the spread is of the same sort as the constructor
134            self.check_expr(&spread.expr, expected)
135        } else if sort_by_field_name.len() != used_fields.len() {
136            // emit an error because all fields are not used
137            let missing_fields = sort_by_field_name
138                .into_keys()
139                .filter(|x| !used_fields.contains_key(x))
140                .collect();
141            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
142        } else {
143            Ok(())
144        }
145    }
146
147    fn check_constructor(
148        &mut self,
149        expr: &fhir::Expr,
150        field_exprs: &[fhir::FieldExpr<'genv>],
151        spread: &Option<&fhir::Spread<'genv>>,
152        expected: &rty::Sort,
153    ) -> Result {
154        let expected = self.resolve_vars_if_possible(expected);
155        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
156            self.wfckresults
157                .record_ctors_mut()
158                .insert(expr.fhir_id, sort_def.did());
159            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
160            Ok(())
161        } else {
162            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
163        }
164    }
165
166    fn check_record(
167        &mut self,
168        arg: &fhir::Expr<'genv>,
169        flds: &[fhir::Expr<'genv>],
170        expected: &rty::Sort,
171    ) -> Result {
172        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
173            let sorts = sort_def.struct_variant().field_sorts(sort_args);
174            if flds.len() != sorts.len() {
175                return Err(self.emit_err(errors::ArgCountMismatch::new(
176                    Some(arg.span),
177                    String::from("type"),
178                    sorts.len(),
179                    flds.len(),
180                )));
181            }
182            self.wfckresults
183                .record_ctors_mut()
184                .insert(arg.fhir_id, sort_def.did());
185
186            izip!(flds, &sorts)
187                .map(|(arg, expected)| self.check_expr(arg, expected))
188                .try_collect_exhaust()
189        } else {
190            Err(self.emit_err(errors::ArgCountMismatch::new(
191                Some(arg.span),
192                String::from("type"),
193                1,
194                flds.len(),
195            )))
196        }
197    }
198
199    pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
200        match &expr.kind {
201            fhir::ExprKind::IfThenElse(p, e1, e2) => {
202                self.check_expr(p, &rty::Sort::Bool)?;
203                self.check_expr(e1, expected)?;
204                self.check_expr(e2, expected)?;
205            }
206            fhir::ExprKind::Abs(params, body) => {
207                self.check_abs(expr, params, body, expected)?;
208            }
209            fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
210            fhir::ExprKind::Constructor(None, exprs, spread) => {
211                self.check_constructor(expr, exprs, spread, expected)?;
212            }
213            fhir::ExprKind::UnaryOp(..)
214            | fhir::ExprKind::SetLiteral(..)
215            | fhir::ExprKind::BinaryOp(..)
216            | fhir::ExprKind::Dot(..)
217            | fhir::ExprKind::App(..)
218            | fhir::ExprKind::Alias(..)
219            | fhir::ExprKind::Var(..)
220            | fhir::ExprKind::Literal(..)
221            | fhir::ExprKind::BoundedQuant(..)
222            | fhir::ExprKind::Block(..)
223            | fhir::ExprKind::Constructor(..)
224            | fhir::ExprKind::PrimApp(..) => {
225                let found = self.synth_expr(expr)?;
226                let found = self.resolve_vars_if_possible(&found);
227                let expected = self.resolve_vars_if_possible(expected);
228                if !self.is_coercible(&found, &expected, expr.fhir_id) {
229                    return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
230                }
231            }
232            fhir::ExprKind::Err(_) => {
233                // an error has already been reported so we can just skip
234            }
235        }
236        Ok(())
237    }
238
239    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
240        let found = self.synth_path(loc);
241        if found == rty::Sort::Loc {
242            Ok(())
243        } else {
244            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
245        }
246    }
247
248    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
249        match lit {
250            fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
251            fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
252            fhir::Lit::Int(_, None) => {
253                let sort = self.next_sort_var_with_cstr(rty::SortCstr::NUMERIC);
254                self.sort_of_literal.insert(*expr, sort.clone());
255                sort
256            }
257            fhir::Lit::Bool(_) => rty::Sort::Bool,
258            fhir::Lit::Str(_) => rty::Sort::Str,
259            fhir::Lit::Char(_) => rty::Sort::Char,
260        }
261    }
262
263    fn synth_prim_app(
264        &mut self,
265        op: &fhir::BinOp,
266        e1: &fhir::Expr<'genv>,
267        e2: &fhir::Expr<'genv>,
268        span: Span,
269    ) -> Result<rty::Sort> {
270        let Some((inputs, output)) = prim_op_sort(op) else {
271            return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
272        };
273        let [sort1, sort2] = &inputs[..] else {
274            return Err(self.emit_err(errors::ArgCountMismatch::new(
275                Some(span),
276                String::from("primop app"),
277                inputs.len(),
278                2,
279            )));
280        };
281        self.check_expr(e1, sort1)?;
282        self.check_expr(e2, sort2)?;
283        Ok(output)
284    }
285
286    fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
287        match expr.kind {
288            fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
289            fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
290                Ok(self.synth_type_relative_path(expr))
291            }
292            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
293            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
294            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
295            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
296            fhir::ExprKind::App(callee, args) => {
297                let sort = self.ensure_resolved_path(&callee)?;
298                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
299                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
300                };
301                let fsort = self.instantiate_func_sort(expr, poly_fsort);
302                self.synth_app(fsort, args, expr.span)
303            }
304            fhir::ExprKind::BoundedQuant(.., body) => {
305                self.check_expr(body, &rty::Sort::Bool)?;
306                Ok(rty::Sort::Bool)
307            }
308            fhir::ExprKind::Alias(_alias_reft, args) => {
309                // To check the application we only need the sort of `_alias_reft` which we collected
310                // during early conv, but should we do any extra checks on `_alias_reft`?
311                let fsort = self.sort_of_alias_reft(expr.fhir_id);
312                self.synth_app(fsort, args, expr.span)
313            }
314            fhir::ExprKind::IfThenElse(p, e1, e2) => {
315                self.check_expr(p, &rty::Sort::Bool)?;
316                let sort = self.synth_expr(e1)?;
317                self.check_expr(e2, &sort)?;
318                Ok(sort)
319            }
320            fhir::ExprKind::Dot(base, fld) => {
321                let sort = self.synth_expr(base)?;
322                let sort = self
323                    .fully_resolve(&sort)
324                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
325                match &sort {
326                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
327                        let (proj, sort) = sort_def
328                            .struct_variant()
329                            .field_by_name(sort_def.did(), sort_args, fld.name)
330                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
331                        self.wfckresults
332                            .field_projs_mut()
333                            .insert(expr.fhir_id, proj);
334                        Ok(sort)
335                    }
336                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
337                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
338                    }
339                    _ => Err(self.emit_field_not_found(&sort, fld)),
340                }
341            }
342            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
343                // if we have S then we can synthesize otherwise we fail
344                // first get the sort based on the path - for example S { ... } => S
345                // and we should expect sort to be a struct or enum app
346                let path_def_id = match path.res {
347                    fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
348                    _ => span_bug!(expr.span, "unexpected path in constructor"),
349                };
350                let sort_def = self
351                    .genv
352                    .adt_sort_def_of(path_def_id)
353                    .map_err(|e| self.emit_err(e))?;
354                // generate fresh inferred sorts for each param
355                let fresh_args: rty::List<_> = (0..sort_def.param_count())
356                    .map(|_| self.next_sort_var())
357                    .collect();
358                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
359                // check fields & spread against fresh args
360                self.check_field_exprs(
361                    expr.span,
362                    &sort_def,
363                    &fresh_args,
364                    field_exprs,
365                    &spread,
366                    &sort,
367                )?;
368                Ok(sort)
369            }
370            fhir::ExprKind::Block(decls, body) => {
371                for decl in decls {
372                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
373                }
374                self.synth_expr(body)
375            }
376            fhir::ExprKind::SetLiteral(elems) => {
377                let elem_sort = self.next_sort_var();
378                for elem in elems {
379                    self.check_expr(elem, &elem_sort)?;
380                }
381                Ok(rty::Sort::App(rty::SortCtor::Set, List::singleton(elem_sort)))
382            }
383            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
384        }
385    }
386
387    fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
388        self.node_sort
389            .get(&path.fhir_id)
390            .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
391            .clone()
392    }
393
394    fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
395        self.node_sort
396            .get(&expr.fhir_id)
397            .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
398            .clone()
399    }
400
401    fn synth_binary_op(
402        &mut self,
403        expr: &fhir::Expr<'genv>,
404        op: fhir::BinOp,
405        e1: &fhir::Expr<'genv>,
406        e2: &fhir::Expr<'genv>,
407    ) -> Result<rty::Sort> {
408        match op {
409            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
410                self.check_expr(e1, &rty::Sort::Bool)?;
411                self.check_expr(e2, &rty::Sort::Bool)?;
412                Ok(rty::Sort::Bool)
413            }
414            fhir::BinOp::Eq | fhir::BinOp::Ne => {
415                let sort = self.next_sort_var();
416                self.check_expr(e1, &sort)?;
417                self.check_expr(e2, &sort)?;
418                Ok(rty::Sort::Bool)
419            }
420            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
421                let sort = self.next_sort_var();
422                self.check_expr(e1, &sort)?;
423                self.check_expr(e2, &sort)?;
424                self.sort_of_bin_op.insert(*expr, sort.clone());
425                Ok(rty::Sort::Bool)
426            }
427            fhir::BinOp::Add
428            | fhir::BinOp::Sub
429            | fhir::BinOp::Mul
430            | fhir::BinOp::Div
431            | fhir::BinOp::Mod
432            | fhir::BinOp::BitAnd
433            | fhir::BinOp::BitOr
434            | fhir::BinOp::BitXor
435            | fhir::BinOp::BitShl
436            | fhir::BinOp::BitShr => {
437                let sort = self.next_sort_var_with_cstr(rty::SortCstr::from_bin_op(op));
438                self.check_expr(e1, &sort)?;
439                self.check_expr(e2, &sort)?;
440                self.sort_of_bin_op.insert(*expr, sort.clone());
441                Ok(sort)
442            }
443        }
444    }
445
446    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
447        match op {
448            fhir::UnOp::Not => {
449                self.check_expr(e, &rty::Sort::Bool)?;
450                Ok(rty::Sort::Bool)
451            }
452            fhir::UnOp::Neg => {
453                self.check_expr(e, &rty::Sort::Int)?;
454                Ok(rty::Sort::Int)
455            }
456        }
457    }
458
459    fn synth_app(
460        &mut self,
461        fsort: FuncSort,
462        args: &[fhir::Expr<'genv>],
463        span: Span,
464    ) -> Result<rty::Sort> {
465        if args.len() != fsort.inputs().len() {
466            return Err(self.emit_err(errors::ArgCountMismatch::new(
467                Some(span),
468                String::from("function"),
469                fsort.inputs().len(),
470                args.len(),
471            )));
472        }
473
474        iter::zip(args, fsort.inputs())
475            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
476
477        Ok(fsort.output().clone())
478    }
479
480    fn instantiate_func_sort(
481        &mut self,
482        app_expr: &fhir::Expr<'genv>,
483        fsort: rty::PolyFuncSort,
484    ) -> rty::FuncSort {
485        let args = fsort
486            .params()
487            .map(|kind| {
488                match kind {
489                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
490                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
491                }
492            })
493            .collect_vec();
494        self.sort_args_of_app
495            .insert(*app_expr, List::from_slice(&args));
496        fsort.instantiate(&args)
497    }
498
499    pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
500        self.node_sort.insert(fhir_id, sort);
501    }
502
503    pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
504        self.node_sort
505            .get(&bty.fhir_id)
506            .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
507            .clone()
508    }
509
510    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
511        self.path_args.insert(fhir_id, args);
512    }
513
514    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
515        self.path_args
516            .get(&fhir_id)
517            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
518            .clone()
519    }
520
521    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
522        self.sort_of_alias_reft.insert(fhir_id, fsort);
523    }
524
525    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
526        self.sort_of_alias_reft
527            .get(&fhir_id)
528            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
529            .clone()
530    }
531
532    fn deeply_normalize_sorts<T: TypeFoldable + Clone>(
533        genv: GlobalEnv,
534        owner: FluxOwnerId,
535        t: &T,
536    ) -> QueryResult<T> {
537        let infcx = genv
538            .tcx()
539            .infer_ctxt()
540            .with_next_trait_solver(true)
541            .build(TypingMode::non_body_analysis());
542        if let Some(def_id) = owner.def_id() {
543            let def_id = genv.maybe_extern_id(def_id).resolved_id();
544            t.deeply_normalize_sorts(def_id, genv, &infcx)
545        } else {
546            Ok(t.clone())
547        }
548    }
549
550    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
551    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
552    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
553    // expanding aliases in `ConvCtxt::conv_sort`.
554    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
555        for sort in self.node_sort.values_mut() {
556            *sort = Self::deeply_normalize_sorts(self.genv, self.owner, sort)?;
557        }
558        for fsort in self.sort_of_alias_reft.values_mut() {
559            *fsort = Self::deeply_normalize_sorts(self.genv, self.owner, fsort)?;
560        }
561        Ok(())
562    }
563}
564
565impl<'genv> InferCtxt<'genv, '_> {
566    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
567        self.params.insert(param.id, (param, sort));
568    }
569
570    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
571    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
572    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
573    /// a value into a record.
574    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
575        if self.try_equate(sort1, sort2).is_some() {
576            return true;
577        }
578
579        let mut sort1 = sort1.clone();
580        let mut sort2 = sort2.clone();
581
582        let mut coercions = vec![];
583        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
584            coercions.push(rty::Coercion::Project(def_id));
585            sort1 = sort.clone();
586        }
587        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
588            coercions.push(rty::Coercion::Inject(def_id));
589            sort2 = sort.clone();
590        }
591        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
592        self.try_equate(&sort1, &sort2).is_some()
593    }
594
595    fn is_coercible_from_func(
596        &mut self,
597        sort: &rty::Sort,
598        fhir_id: FhirId,
599    ) -> Option<rty::PolyFuncSort> {
600        if let rty::Sort::Func(fsort) = sort {
601            Some(fsort.clone())
602        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
603            self.wfckresults
604                .coercions_mut()
605                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
606            Some(fsort.clone())
607        } else {
608            None
609        }
610    }
611
612    fn is_coercible_to_func(
613        &mut self,
614        sort: &rty::Sort,
615        fhir_id: FhirId,
616    ) -> Option<rty::PolyFuncSort> {
617        if let rty::Sort::Func(fsort) = sort {
618            Some(fsort.clone())
619        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
620            self.wfckresults
621                .coercions_mut()
622                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
623            Some(fsort.clone())
624        } else {
625            None
626        }
627    }
628
629    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
630        let sort1 = self.resolve_vars_if_possible(sort1);
631        let sort2 = self.resolve_vars_if_possible(sort2);
632        self.try_equate_inner(&sort1, &sort2)
633    }
634
635    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
636        match (sort1, sort2) {
637            (rty::Sort::Infer(vid1), rty::Sort::Infer(vid2)) => {
638                self.sort_unification_table
639                    .unify_var_var(*vid1, *vid2)
640                    .ok()?;
641            }
642            (rty::Sort::Infer(vid), sort) | (sort, rty::Sort::Infer(vid)) => {
643                self.sort_unification_table
644                    .unify_var_value(*vid, rty::SortVarVal::Solved(sort.clone()))
645                    .ok()?;
646            }
647            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
648                if ctor1 != ctor2 || args1.len() != args2.len() {
649                    return None;
650                }
651                let mut args = vec![];
652                for (s1, s2) in args1.iter().zip(args2.iter()) {
653                    args.push(self.try_equate_inner(s1, s2)?);
654                }
655            }
656            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
657                self.try_equate_bv_sizes(*size1, *size2)?;
658            }
659            _ if sort1 == sort2 => {}
660            _ => return None,
661        }
662        Some(sort1.clone())
663    }
664
665    fn try_equate_bv_sizes(
666        &mut self,
667        size1: rty::BvSize,
668        size2: rty::BvSize,
669    ) -> Option<rty::BvSize> {
670        match (size1, size2) {
671            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
672                self.bv_size_unification_table
673                    .unify_var_var(vid1, vid2)
674                    .ok()?;
675            }
676            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
677                self.bv_size_unification_table
678                    .unify_var_value(vid, Some(size))
679                    .ok()?;
680            }
681            _ if size1 == size2 => {}
682            _ => return None,
683        }
684        Some(size1)
685    }
686
687    fn next_sort_var(&mut self) -> rty::Sort {
688        rty::Sort::Infer(self.next_sort_vid(Default::default()))
689    }
690
691    fn next_sort_var_with_cstr(&mut self, cstr: rty::SortCstr) -> rty::Sort {
692        rty::Sort::Infer(self.next_sort_vid(rty::SortVarVal::Unsolved(cstr)))
693    }
694
695    pub(crate) fn next_sort_vid(&mut self, val: rty::SortVarVal) -> rty::SortVid {
696        self.sort_unification_table.new_key(val)
697    }
698
699    fn next_bv_size_var(&mut self) -> rty::BvSize {
700        rty::BvSize::Infer(self.next_bv_size_vid())
701    }
702
703    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
704        self.bv_size_unification_table.new_key(None)
705    }
706
707    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
708        let sort = self.synth_path(path);
709        self.fully_resolve(&sort)
710            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
711    }
712
713    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
714        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
715            && let Some(variant) = sort_def.opt_struct_variant()
716            && let [sort] = &variant.field_sorts(sort_args)[..]
717        {
718            Some((sort_def.did(), sort.clone()))
719        } else {
720            None
721        }
722    }
723
724    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
725        // Make sure the int literals are fully resolved. This must happen before everything else
726        // such that we properly apply the fallback for unconstrained num vars.
727        for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
728            // Fallback to `int` when a num variable is unconstrained. Note that we unconditionally
729            // unify the variable. This is fine because if the variable has already been unified,
730            // the operation will fail and this won't have any effect. Also note that unifying a
731            // variable could solve variables that appear later in this for loop. This is also fine
732            // because the order doesn't matter as we are unifying everything to `int`.
733            if let rty::Sort::Infer(vid) = &sort {
734                let _ = self
735                    .sort_unification_table
736                    .unify_var_value(*vid, rty::SortVarVal::Solved(rty::Sort::Int));
737            }
738
739            let sort = self
740                .fully_resolve(&sort)
741                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
742            self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
743        }
744
745        // Make sure the binary operators are fully resolved
746        for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
747            let sort = self
748                .fully_resolve(&sort)
749                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
750            self.wfckresults
751                .bin_op_sorts_mut()
752                .insert(node.fhir_id, sort);
753        }
754
755        let allow_uninterpreted_cast = self
756            .owner
757            .def_id()
758            .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
759                self.genv.infer_opts(def_id).allow_uninterpreted_cast
760            });
761
762        // Make sure that function applications are fully resolved
763        for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
764            let mut res = vec![];
765            for sort_arg in &sort_args {
766                let sort_arg = match sort_arg {
767                    rty::SortArg::Sort(sort) => {
768                        let sort = self
769                            .fully_resolve(sort)
770                            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
771                        rty::SortArg::Sort(sort)
772                    }
773                    rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
774                        let size = self
775                            .bv_size_unification_table
776                            .probe_value(*vid)
777                            .ok_or_else(|| {
778                                self.emit_err(errors::CannotInferSort::new(node.span))
779                            })?;
780                        rty::SortArg::BvSize(size)
781                    }
782                    _ => sort_arg.clone(),
783                };
784                res.push(sort_arg);
785            }
786            if let fhir::ExprKind::App(callee, _) = node.kind
787                && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
788            {
789                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
790                    span_bug!(node.span, "invalid sort args!")
791                };
792                if !allow_uninterpreted_cast
793                    && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
794                {
795                    return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
796                }
797            }
798            self.wfckresults
799                .fn_app_sorts_mut()
800                .insert(node.fhir_id, res.into());
801        }
802
803        // Make sure all parameters are fully resolved
804        for (_, (param, sort)) in std::mem::take(&mut self.params) {
805            let sort = self
806                .fully_resolve(&sort)
807                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
808            self.wfckresults.param_sorts_mut().insert(param.id, sort);
809        }
810
811        Ok(self.wfckresults)
812    }
813
814    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
815        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
816    }
817
818    #[track_caller]
819    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
820        self.params
821            .get(&id)
822            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
823            .1
824            .clone()
825    }
826
827    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
828        sort.fold_with(&mut ShallowResolver { infcx: self })
829    }
830
831    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
832        sort.fold_with(&mut OpportunisticResolver { infcx: self })
833    }
834
835    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
836        sort.try_fold_with(&mut FullResolver { infcx: self })
837    }
838}
839
/// Before the main sort inference, we do a first traversal checking all implicitly scoped parameters
/// declared with `@` or `#` and infer their sort based on the type they are indexing, e.g., if `n` was
/// declared as `i32[@n]`, we infer `int` for `n`.
///
/// This prepass is necessary because sometimes the order in which we traverse expressions can
/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
/// complete inference regardless of how expressions are later traversed.
///
/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
/// function position) so that this prepass is no longer needed.
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    /// Inference context whose parameter sorts are unified with the sorts of indexed types.
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    /// Collects diagnostics (e.g., index arity mismatches) so the whole node is traversed
    /// before the prepass reports failure.
    errors: Errors<'genv>,
}
854
855impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
856    pub(crate) fn infer(
857        infcx: &'a mut InferCtxt<'genv, 'tcx>,
858        node: &fhir::OwnerNode<'genv>,
859    ) -> Result {
860        let errors = Errors::new(infcx.genv.sess());
861        let mut vis = Self { infcx, errors };
862        vis.visit_node(node);
863        vis.errors.to_result()
864    }
865
866    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
867        match idx.kind {
868            fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
869                let (_, id) = var.res.expect_param();
870                let found = self.infcx.param_sort(id);
871                self.infcx.try_equate(&found, expected).unwrap_or_else(|| {
872                    span_bug!(idx.span, "failed to equate sorts: `{found:?}` `{expected:?}`")
873                });
874            }
875            fhir::ExprKind::Record(flds) => {
876                if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
877                    let sorts = sort_def.struct_variant().field_sorts(sort_args);
878                    if flds.len() != sorts.len() {
879                        self.errors.emit(errors::ArgCountMismatch::new(
880                            Some(idx.span),
881                            String::from("type"),
882                            sorts.len(),
883                            flds.len(),
884                        ));
885                        return;
886                    }
887                    for (f, sort) in iter::zip(flds, &sorts) {
888                        self.infer_implicit_params(f, sort);
889                    }
890                } else {
891                    self.errors.emit(errors::ArgCountMismatch::new(
892                        Some(idx.span),
893                        String::from("type"),
894                        1,
895                        flds.len(),
896                    ));
897                }
898            }
899            _ => {}
900        }
901    }
902}
903
904impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
905    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
906        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
907            let expected = self.infcx.sort_of_bty(bty);
908            self.infer_implicit_params(idx, &expected);
909        }
910        fhir::visit::walk_ty(self, ty);
911    }
912}
913
914impl InferCtxt<'_, '_> {
915    #[track_caller]
916    fn emit_sort_mismatch(
917        &mut self,
918        span: Span,
919        expected: &rty::Sort,
920        found: &rty::Sort,
921    ) -> ErrorGuaranteed {
922        let expected = self.resolve_vars_if_possible(expected);
923        let found = self.resolve_vars_if_possible(found);
924        self.emit_err(errors::SortMismatch::new(span, expected, found))
925    }
926
927    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
928        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
929    }
930
931    #[track_caller]
932    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
933        self.genv.sess().emit_err(err)
934    }
935}
936
937struct ShallowResolver<'a, 'genv, 'tcx> {
938    infcx: &'a mut InferCtxt<'genv, 'tcx>,
939}
940
941impl TypeFolder for ShallowResolver<'_, '_, '_> {
942    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
943        match sort {
944            rty::Sort::Infer(vid) => {
945                // if `sort` is a sort variable, it can be resolved to a bit-vec variable,
946                // which can then be recursively resolved, hence the recursion. Note though that
947                // we prevent sort variables from unifying to other sort variables directly (though
948                // they may be embedded structurally), so this recursion should always be of very
949                // limited depth.
950                self.infcx
951                    .sort_unification_table
952                    .probe_value(*vid)
953                    .map_solved(|sort| sort.fold_with(self))
954                    .solved_or(sort)
955            }
956            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
957                self.infcx
958                    .bv_size_unification_table
959                    .probe_value(*vid)
960                    .map(rty::Sort::BitVec)
961                    .unwrap_or(sort.clone())
962            }
963            _ => sort.clone(),
964        }
965    }
966}
967
968struct OpportunisticResolver<'a, 'genv, 'tcx> {
969    infcx: &'a mut InferCtxt<'genv, 'tcx>,
970}
971
972impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
973    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
974        let s = self.infcx.shallow_resolve(sort);
975        s.super_fold_with(self)
976    }
977}
978
979struct FullResolver<'a, 'genv, 'tcx> {
980    infcx: &'a mut InferCtxt<'genv, 'tcx>,
981}
982
983impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
984    type Error = ();
985
986    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
987        let s = self.infcx.shallow_resolve(sort);
988        match s {
989            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
990            _ => s.try_super_fold_with(self),
991        }
992    }
993}
994
/// Map associating data with a node (i.e., an expression), keyed by the node's [`FhirId`].
///
/// Used to record elaborated information. The expression is stored alongside the data so that,
/// when the map is drained, consumers still have access to the node itself (e.g., its span).
// `derive_where` is used instead of `#[derive(Default)]` so that `NodeMap<'_, T>: Default`
// does not require `T: Default`.
#[derive_where(Default)]
struct NodeMap<'genv, T> {
    /// Entries keyed by `FhirId`; each value holds the expression and its associated data.
    map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
}
1002
1003impl<'genv, T> NodeMap<'genv, T> {
1004    /// Add a `node` to the map with associated `data`
1005    fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1006        assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1007    }
1008}
1009
1010impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1011    type Item = (fhir::Expr<'genv>, T);
1012
1013    type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1014
1015    fn into_iter(self) -> Self::IntoIter {
1016        self.map.into_values()
1017    }
1018}