flux_fhir_analysis/wf/sortck.rs

use std::iter;

use derive_where::derive_where;
use ena::unify::InPlaceUnificationTable;
use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
use flux_errors::{ErrorGuaranteed, Errors};
use flux_infer::projections::NormalizeExt;
use flux_middle::{
    fhir::{self, FhirId, FluxOwnerId, visit::Visitor as _},
    global_env::GlobalEnv,
    queries::QueryResult,
    rty::{
        self, AdtSortDef, FuncSort, List, WfckResults,
        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
    },
};
use itertools::{Itertools, izip};
use rustc_data_structures::unord::UnordMap;
use rustc_errors::Diagnostic;
use rustc_hash::FxHashMap;
use rustc_hir::def::DefKind;
use rustc_middle::ty::TypingMode;
use rustc_span::{Span, def_id::DefId, symbol::Ident};

use super::errors;
use crate::rustc_infer::infer::TyCtxtInferExt;

type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;

pub(super) struct InferCtxt<'genv, 'tcx> {
    pub genv: GlobalEnv<'genv, 'tcx>,
    pub owner: FluxOwnerId,
    pub wfckresults: WfckResults,
    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
    num_unification_table: InPlaceUnificationTable<rty::NumVid>,
    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
    node_sort: FxHashMap<FhirId, rty::Sort>,
    path_args: UnordMap<FhirId, rty::GenericArgs>,
    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
    sort_of_literal: NodeMap<'genv, rty::Sort>,
    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
}

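/// Returns the input and output sorts of the bitwise primitive operators supported in primop
/// applications (all of them `(int, int) -> int`), or `None` for any other operator.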
pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
    match op {
        fhir::BinOp::BitAnd
        | fhir::BinOp::BitOr
        | fhir::BinOp::BitXor
        | fhir::BinOp::BitShl
        | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
        _ => None,
    }
}

impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
        // We skip 0 because that's used for the sort of dummy self types during conv.
        let mut sort_unification_table = InPlaceUnificationTable::new();
        sort_unification_table.new_key(None);
        Self {
            genv,
            owner,
            wfckresults: WfckResults::new(owner),
            sort_unification_table,
            num_unification_table: InPlaceUnificationTable::new(),
            bv_size_unification_table: InPlaceUnificationTable::new(),
            params: Default::default(),
            node_sort: Default::default(),
            path_args: Default::default(),
            sort_of_alias_reft: Default::default(),
            sort_of_literal: Default::default(),
            sort_of_bin_op: Default::default(),
            sort_args_of_app: Default::default(),
        }
    }

    fn check_abs(
        &mut self,
        expr: &fhir::Expr<'genv>,
        params: &[fhir::RefineParam],
        body: &fhir::Expr<'genv>,
        expected: &rty::Sort,
    ) -> Result {
        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
        };

        let fsort = fsort.expect_mono();

        if params.len() != fsort.inputs().len() {
            return Err(self.emit_err(errors::ParamCountMismatch::new(
                expr.span,
                fsort.inputs().len(),
                params.len(),
            )));
        }
        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
            let found = self.param_sort(param.id);
            if self.try_equate(&found, expected).is_none() {
                return Err(self.emit_sort_mismatch(param.span, expected, &found));
            }
            Ok(())
        })?;
        self.check_expr(body, fsort.output())?;
        self.wfckresults
            .node_sorts_mut()
            .insert(body.fhir_id, fsort.output().clone());
        Ok(())
    }

    fn check_field_exprs(
        &mut self,
        expr_span: Span,
        sort_def: &AdtSortDef,
        sort_args: &[rty::Sort],
        field_exprs: &[fhir::FieldExpr<'genv>],
        spread: &Option<&fhir::Spread<'genv>>,
        expected: &rty::Sort,
    ) -> Result {
        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
        let mut used_fields = FxHashMap::default();
        for expr in field_exprs {
            // make sure the field actually exists in the expected sort
            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
            };
            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
            }
            self.check_expr(&expr.expr, sort)?;
        }
        if let Some(spread) = spread {
            // the spread must have the same sort as the constructor
            self.check_expr(&spread.expr, expected)
        } else if sort_by_field_name.len() != used_fields.len() {
            // emit an error because not all fields were initialized
            let missing_fields = sort_by_field_name
                .into_keys()
                .filter(|x| !used_fields.contains_key(x))
                .collect();
            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
        } else {
            Ok(())
        }
    }

    fn check_constructor(
        &mut self,
        expr: &fhir::Expr,
        field_exprs: &[fhir::FieldExpr<'genv>],
        spread: &Option<&fhir::Spread<'genv>>,
        expected: &rty::Sort,
    ) -> Result {
        let expected = self.resolve_vars_if_possible(expected);
        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
            self.wfckresults
                .record_ctors_mut()
                .insert(expr.fhir_id, sort_def.did());
            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
            Ok(())
        } else {
            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
        }
    }

    fn check_record(
        &mut self,
        arg: &fhir::Expr<'genv>,
        flds: &[fhir::Expr<'genv>],
        expected: &rty::Sort,
    ) -> Result {
        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
            let sorts = sort_def.struct_variant().field_sorts(sort_args);
            if flds.len() != sorts.len() {
                return Err(self.emit_err(errors::ArgCountMismatch::new(
                    Some(arg.span),
                    String::from("type"),
                    sorts.len(),
                    flds.len(),
                )));
            }
            self.wfckresults
                .record_ctors_mut()
                .insert(arg.fhir_id, sort_def.did());

            izip!(flds, &sorts)
                .map(|(arg, expected)| self.check_expr(arg, expected))
                .try_collect_exhaust()
        } else {
            Err(self.emit_err(errors::ArgCountMismatch::new(
                Some(arg.span),
                String::from("type"),
                1,
                flds.len(),
            )))
        }
    }

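    /// Checks that `expr` has sort `expected`. Expressions whose sort must be synthesized are
    /// handled by falling back to [`Self::synth_expr`] and then checking that the synthesized
    /// sort is coercible to `expected`.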
    pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
        match &expr.kind {
            fhir::ExprKind::IfThenElse(p, e1, e2) => {
                self.check_expr(p, &rty::Sort::Bool)?;
                self.check_expr(e1, expected)?;
                self.check_expr(e2, expected)?;
            }
            fhir::ExprKind::Abs(params, body) => {
                self.check_abs(expr, params, body, expected)?;
            }
            fhir::ExprKind::Record(fields) => self.check_record(expr, fields, expected)?,
            fhir::ExprKind::Constructor(None, exprs, spread) => {
                self.check_constructor(expr, exprs, spread, expected)?;
            }
            fhir::ExprKind::UnaryOp(..)
            | fhir::ExprKind::BinaryOp(..)
            | fhir::ExprKind::Dot(..)
            | fhir::ExprKind::App(..)
            | fhir::ExprKind::Alias(..)
            | fhir::ExprKind::Var(..)
            | fhir::ExprKind::Literal(..)
            | fhir::ExprKind::BoundedQuant(..)
            | fhir::ExprKind::Block(..)
            | fhir::ExprKind::Constructor(..)
            | fhir::ExprKind::PrimApp(..) => {
                let found = self.synth_expr(expr)?;
                let found = self.resolve_vars_if_possible(&found);
                let expected = self.resolve_vars_if_possible(expected);
                if !self.is_coercible(&found, &expected, expr.fhir_id) {
                    return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
                }
            }
            fhir::ExprKind::Err(_) => {
                // an error has already been reported so we can just skip
            }
        }
        Ok(())
    }

    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
        let found = self.synth_path(loc);
        if found == rty::Sort::Loc {
            Ok(())
        } else {
            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
        }
    }

    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
        match lit {
            fhir::Lit::Int(_, Some(fhir::NumLitKind::Int)) => rty::Sort::Int,
            fhir::Lit::Int(_, Some(fhir::NumLitKind::Real)) => rty::Sort::Real,
            fhir::Lit::Int(_, None) => {
                let sort = self.next_num_var();
                self.sort_of_literal.insert(*expr, sort.clone());
                sort
            }
            fhir::Lit::Bool(_) => rty::Sort::Bool,
            fhir::Lit::Str(_) => rty::Sort::Str,
            fhir::Lit::Char(_) => rty::Sort::Char,
        }
    }

    fn synth_prim_app(
        &mut self,
        op: &fhir::BinOp,
        e1: &fhir::Expr<'genv>,
        e2: &fhir::Expr<'genv>,
        span: Span,
    ) -> Result<rty::Sort> {
        let Some((inputs, output)) = prim_op_sort(op) else {
            return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
        };
        let [sort1, sort2] = &inputs[..] else {
            return Err(self.emit_err(errors::ArgCountMismatch::new(
                Some(span),
                String::from("primop app"),
                inputs.len(),
                2,
            )));
        };
        self.check_expr(e1, sort1)?;
        self.check_expr(e2, sort2)?;
        Ok(output)
    }

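    /// Synthesizes a sort for `expr`. Together with [`Self::check_expr`], this implements a
    /// simple bidirectional discipline: checking switches to synthesis for expressions whose
    /// sort is determined by their subexpressions.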
    fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
        match expr.kind {
            fhir::ExprKind::Var(var, _) => Ok(self.synth_path(&var)),
            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
            fhir::ExprKind::App(callee, args) => {
                let sort = self.ensure_resolved_path(&callee)?;
                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
                };
                let fsort = self.instantiate_func_sort(expr, poly_fsort);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::BoundedQuant(.., body) => {
                self.check_expr(body, &rty::Sort::Bool)?;
                Ok(rty::Sort::Bool)
            }
            fhir::ExprKind::Alias(_alias_reft, args) => {
                // To check the application we only need the sort of `_alias_reft`, which we
                // collected during early conv. Should we do any extra checks on `_alias_reft`?
                let fsort = self.sort_of_alias_reft(expr.fhir_id);
                self.synth_app(fsort, args, expr.span)
            }
            fhir::ExprKind::IfThenElse(p, e1, e2) => {
                self.check_expr(p, &rty::Sort::Bool)?;
                let sort = self.synth_expr(e1)?;
                self.check_expr(e2, &sort)?;
                Ok(sort)
            }
            fhir::ExprKind::Dot(base, fld) => {
                let sort = self.synth_expr(base)?;
                let sort = self
                    .fully_resolve(&sort)
                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
                match &sort {
                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
                        let (proj, sort) = sort_def
                            .struct_variant()
                            .field_by_name(sort_def.did(), sort_args, fld.name)
                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
                        self.wfckresults
                            .field_projs_mut()
                            .insert(expr.fhir_id, proj);
                        Ok(sort)
                    }
                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
                    }
                    _ => Err(self.emit_field_not_found(&sort, fld)),
                }
            }
            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
                // With an explicit path `S` we can synthesize a sort; otherwise we fail.
                // First get the sort from the path (e.g., `S { ... }` gives `S`), which must
                // resolve to a struct or enum.
                let path_def_id = match path.res {
                    fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
                    _ => span_bug!(expr.span, "unexpected path in constructor"),
                };
                let sort_def = self
                    .genv
                    .adt_sort_def_of(path_def_id)
                    .map_err(|e| self.emit_err(e))?;
                // generate a fresh sort variable for each sort parameter
                let fresh_args: rty::List<_> = (0..sort_def.param_count())
                    .map(|_| self.next_sort_var())
                    .collect();
                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
                // check fields & spread against fresh args
                self.check_field_exprs(
                    expr.span,
                    &sort_def,
                    &fresh_args,
                    field_exprs,
                    &spread,
                    &sort,
                )?;
                Ok(sort)
            }
            fhir::ExprKind::Block(decls, body) => {
                for decl in decls {
                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
                }
                self.synth_expr(body)
            }
            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
        }
    }

    fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
        self.node_sort
            .get(&path.fhir_id)
            .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
            .clone()
    }

    fn check_integral(&mut self, op: fhir::BinOp, sort: &rty::Sort, span: Span) -> Result {
        if matches!(op, fhir::BinOp::Mod) {
            let sort = self
                .fully_resolve(sort)
                .map_err(|_| self.emit_err(errors::CannotInferSort::new(span)))?;
            if !matches!(sort, rty::Sort::Int | rty::Sort::BitVec(_)) {
                span_bug!(span, "unexpected sort {sort:?} for operator {op:?}");
            }
        }
        Ok(())
    }

    fn synth_binary_op(
        &mut self,
        expr: &fhir::Expr<'genv>,
        op: fhir::BinOp,
        e1: &fhir::Expr<'genv>,
        e2: &fhir::Expr<'genv>,
    ) -> Result<rty::Sort> {
        match op {
            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
                self.check_expr(e1, &rty::Sort::Bool)?;
                self.check_expr(e2, &rty::Sort::Bool)?;
                Ok(rty::Sort::Bool)
            }
            fhir::BinOp::Eq | fhir::BinOp::Ne => {
                let sort = self.next_sort_var();
                self.check_expr(e1, &sort)?;
                self.check_expr(e2, &sort)?;
                Ok(rty::Sort::Bool)
            }
            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
                let sort = self.next_sort_var();
                self.check_expr(e1, &sort)?;
                self.check_expr(e2, &sort)?;
                self.sort_of_bin_op.insert(*expr, sort.clone());
                Ok(rty::Sort::Bool)
            }
            fhir::BinOp::Add
            | fhir::BinOp::Sub
            | fhir::BinOp::Mul
            | fhir::BinOp::Div
            | fhir::BinOp::Mod => {
                let sort = self.next_num_var();
                self.check_expr(e1, &sort)?;
                self.check_expr(e2, &sort)?;
                self.sort_of_bin_op.insert(*expr, sort.clone());
                // check that the sort is integral for mod
                self.check_integral(op, &sort, expr.span)?;

                Ok(sort)
            }
            fhir::BinOp::BitAnd
            | fhir::BinOp::BitOr
            | fhir::BinOp::BitXor
            | fhir::BinOp::BitShl
            | fhir::BinOp::BitShr => {
                let sort = rty::Sort::BitVec(self.next_bv_size_var());
                self.check_expr(e1, &sort)?;
                self.check_expr(e2, &sort)?;
                Ok(sort)
            }
        }
    }

    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
        match op {
            fhir::UnOp::Not => {
                self.check_expr(e, &rty::Sort::Bool)?;
                Ok(rty::Sort::Bool)
            }
            fhir::UnOp::Neg => {
                self.check_expr(e, &rty::Sort::Int)?;
                Ok(rty::Sort::Int)
            }
        }
    }

    fn synth_app(
        &mut self,
        fsort: FuncSort,
        args: &[fhir::Expr<'genv>],
        span: Span,
    ) -> Result<rty::Sort> {
        if args.len() != fsort.inputs().len() {
            return Err(self.emit_err(errors::ArgCountMismatch::new(
                Some(span),
                String::from("function"),
                fsort.inputs().len(),
                args.len(),
            )));
        }

        iter::zip(args, fsort.inputs())
            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;

        Ok(fsort.output().clone())
    }

    fn instantiate_func_sort(
        &mut self,
        app_expr: &fhir::Expr<'genv>,
        fsort: rty::PolyFuncSort,
    ) -> rty::FuncSort {
        let args = fsort
            .params()
            .map(|kind| {
                match kind {
                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
                }
            })
            .collect_vec();
        self.sort_args_of_app
            .insert(*app_expr, List::from_slice(&args));
        fsort.instantiate(&args)
    }

    pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
        self.node_sort.insert(fhir_id, sort);
    }

    pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
        self.node_sort
            .get(&bty.fhir_id)
            .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
            .clone()
    }

    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
        self.path_args.insert(fhir_id, args);
    }

    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
        self.path_args
            .get(&fhir_id)
            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
            .clone()
    }

    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
        self.sort_of_alias_reft.insert(fhir_id, fsort);
    }

    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
        self.sort_of_alias_reft
            .get(&fhir_id)
            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
            .clone()
    }

    fn normalize_projection_sort(
        genv: GlobalEnv,
        owner: FluxOwnerId,
        sort: rty::Sort,
    ) -> rty::Sort {
        let infcx = genv
            .tcx()
            .infer_ctxt()
            .with_next_trait_solver(true)
            .build(TypingMode::non_body_analysis());
        if let Some(def_id) = owner.def_id()
            && let def_id = genv.maybe_extern_id(def_id).resolved_id()
            && let Ok(sort) = sort.normalize_sorts(def_id, genv, &infcx)
        {
            sort
        } else {
            sort
        }
    }

    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
    // expanding aliases in `ConvCtxt::conv_sort`.
    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
        let genv = self.genv;
        for sort in self.node_sort.values_mut() {
            *sort = Self::normalize_projection_sort(genv, self.owner, sort.clone());
        }
        for fsort in self.sort_of_alias_reft.values_mut() {
            *fsort = genv.deep_normalize_weak_alias_sorts(fsort)?;
        }
        Ok(())
    }
}

impl<'genv> InferCtxt<'genv, '_> {
    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
        self.params.insert(param.id, (param, sort));
    }

    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
    /// a value into a record.
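    ///
    /// For illustration (with a hypothetical sort name): if `Pos` is an ADT sort with a single
    /// field of sort `int`, a value of sort `Pos` can be used where an `int` is expected
    /// (recorded as a [`rty::Coercion::Project`]) and an `int` where a `Pos` is expected
    /// (recorded as a [`rty::Coercion::Inject`]).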
    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
        if self.try_equate(sort1, sort2).is_some() {
            return true;
        }

        let mut sort1 = sort1.clone();
        let mut sort2 = sort2.clone();

        let mut coercions = vec![];
        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
            coercions.push(rty::Coercion::Project(def_id));
            sort1 = sort.clone();
        }
        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
            coercions.push(rty::Coercion::Inject(def_id));
            sort2 = sort.clone();
        }
        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
        self.try_equate(&sort1, &sort2).is_some()
    }

    fn is_coercible_from_func(
        &mut self,
        sort: &rty::Sort,
        fhir_id: FhirId,
    ) -> Option<rty::PolyFuncSort> {
        if let rty::Sort::Func(fsort) = sort {
            Some(fsort.clone())
        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
            self.wfckresults
                .coercions_mut()
                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
            Some(fsort.clone())
        } else {
            None
        }
    }

    fn is_coercible_to_func(
        &mut self,
        sort: &rty::Sort,
        fhir_id: FhirId,
    ) -> Option<rty::PolyFuncSort> {
        if let rty::Sort::Func(fsort) = sort {
            Some(fsort.clone())
        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
            self.wfckresults
                .coercions_mut()
                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
            Some(fsort.clone())
        } else {
            None
        }
    }

    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
        let sort1 = self.resolve_vars_if_possible(sort1);
        let sort2 = self.resolve_vars_if_possible(sort2);
        self.try_equate_inner(&sort1, &sort2)
    }

    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
        match (sort1, sort2) {
            (rty::Sort::Infer(rty::SortVar(vid1)), rty::Sort::Infer(rty::SortVar(vid2))) => {
                self.sort_unification_table
                    .unify_var_var(*vid1, *vid2)
                    .ok()?;
            }
            (rty::Sort::Infer(rty::SortVar(vid)), sort)
            | (sort, rty::Sort::Infer(rty::SortVar(vid))) => {
                self.sort_unification_table
                    .unify_var_value(*vid, Some(sort.clone()))
                    .ok()?;
            }
            (rty::Sort::Infer(rty::NumVar(vid1)), rty::Sort::Infer(rty::NumVar(vid2))) => {
                self.num_unification_table
                    .unify_var_var(*vid1, *vid2)
                    .ok()?;
            }
            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Int)
            | (rty::Sort::Int, rty::Sort::Infer(rty::NumVar(vid))) => {
                self.num_unification_table
                    .unify_var_value(*vid, Some(rty::NumVarValue::Int))
                    .ok()?;
            }
            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::Real)
            | (rty::Sort::Real, rty::Sort::Infer(rty::NumVar(vid))) => {
                self.num_unification_table
                    .unify_var_value(*vid, Some(rty::NumVarValue::Real))
                    .ok()?;
            }
            (rty::Sort::Infer(rty::NumVar(vid)), rty::Sort::BitVec(sz))
            | (rty::Sort::BitVec(sz), rty::Sort::Infer(rty::NumVar(vid))) => {
                self.num_unification_table
                    .unify_var_value(*vid, Some(rty::NumVarValue::BitVec(*sz)))
                    .ok()?;
            }

            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
                if ctor1 != ctor2 || args1.len() != args2.len() {
                    return None;
                }
                let mut args = vec![];
                for (s1, s2) in args1.iter().zip(args2.iter()) {
                    args.push(self.try_equate_inner(s1, s2)?);
                }
            }
            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
                self.try_equate_bv_sizes(*size1, *size2)?;
            }
            _ if sort1 == sort2 => {}
            _ => return None,
        }
        Some(sort1.clone())
    }

    fn try_equate_bv_sizes(
        &mut self,
        size1: rty::BvSize,
        size2: rty::BvSize,
    ) -> Option<rty::BvSize> {
        match (size1, size2) {
            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
                self.bv_size_unification_table
                    .unify_var_var(vid1, vid2)
                    .ok()?;
            }
            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
                self.bv_size_unification_table
                    .unify_var_value(vid, Some(size))
                    .ok()?;
            }
            _ if size1 == size2 => {}
            _ => return None,
        }
        Some(size1)
    }

    fn equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> rty::Sort {
        self.try_equate(sort1, sort2)
            .unwrap_or_else(|| bug!("failed to equate sorts: `{sort1:?}` `{sort2:?}`"))
    }

    pub(crate) fn next_sort_var(&mut self) -> rty::Sort {
        rty::Sort::Infer(rty::SortVar(self.next_sort_vid()))
    }

    fn next_num_var(&mut self) -> rty::Sort {
        rty::Sort::Infer(rty::NumVar(self.next_num_vid()))
    }

    pub(crate) fn next_sort_vid(&mut self) -> rty::SortVid {
        self.sort_unification_table.new_key(None)
    }

    fn next_num_vid(&mut self) -> rty::NumVid {
        self.num_unification_table.new_key(None)
    }

    fn next_bv_size_var(&mut self) -> rty::BvSize {
        rty::BvSize::Infer(self.next_bv_size_vid())
    }

    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
        self.bv_size_unification_table.new_key(None)
    }

    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
        let sort = self.synth_path(path);
        self.fully_resolve(&sort)
            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
    }

    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
            && let Some(variant) = sort_def.opt_struct_variant()
            && let [sort] = &variant.field_sorts(sort_args)[..]
        {
            Some((sort_def.did(), sort.clone()))
        } else {
            None
        }
    }

    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
        // Make sure the int literals are fully resolved. This must happen before everything else
        // such that we properly apply the fallback for unconstrained num vars.
        for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
            // Fallback to `int` when a num variable is unconstrained. Note that we unconditionally
            // unify the variable. This is fine because if the variable has already been unified,
            // the operation will fail and this won't have any effect. Also note that unifying a
            // variable could solve variables that appear later in this for loop. This is also fine
            // because the order doesn't matter as we are unifying everything to `int`.
            if let rty::Sort::Infer(rty::SortInfer::NumVar(vid)) = &sort {
                let _ = self
                    .num_unification_table
                    .unify_var_value(*vid, Some(rty::NumVarValue::Int));
            }

            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
            self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
        }

        // Make sure the binary operators are fully resolved
        for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
            self.wfckresults
                .bin_op_sorts_mut()
                .insert(node.fhir_id, sort);
        }

        let allow_uninterpreted_cast = self
            .owner
            .def_id()
            .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
                self.genv.infer_opts(def_id).allow_uninterpreted_cast
            });

        // Make sure that function applications are fully resolved
        for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
            let mut res = vec![];
            for sort_arg in &sort_args {
                let sort_arg = match sort_arg {
                    rty::SortArg::Sort(sort) => {
                        let sort = self
                            .fully_resolve(sort)
                            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
                        rty::SortArg::Sort(sort)
                    }
                    rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
                        let size = self
                            .bv_size_unification_table
                            .probe_value(*vid)
                            .ok_or_else(|| {
                                self.emit_err(errors::CannotInferSort::new(node.span))
                            })?;
                        rty::SortArg::BvSize(size)
                    }
                    _ => sort_arg.clone(),
                };
                res.push(sort_arg);
            }
            if let fhir::ExprKind::App(callee, _) = node.kind
                && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
            {
                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
                    span_bug!(node.span, "invalid sort args!")
                };
                if !allow_uninterpreted_cast
                    && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
                {
                    return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
                }
            }
            self.wfckresults
                .fn_app_sorts_mut()
                .insert(node.fhir_id, res.into());
        }

        // Make sure all parameters are fully resolved
        for (_, (param, sort)) in std::mem::take(&mut self.params) {
            let sort = self
                .fully_resolve(&sort)
                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
            self.wfckresults.param_sorts_mut().insert(param.id, sort);
        }

        Ok(self.wfckresults)
    }

    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
    }

    #[track_caller]
    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
        self.params
            .get(&id)
            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
            .1
            .clone()
    }

    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
        sort.fold_with(&mut ShallowResolver { infcx: self })
    }

    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
        sort.fold_with(&mut OpportunisticResolver { infcx: self })
    }

    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
        sort.try_fold_with(&mut FullResolver { infcx: self })
    }
}

/// Before the main sort inference, we do a first traversal that checks all implicitly scoped
/// parameters declared with `@` or `#` and infers their sort based on the type they index, e.g.,
/// if `n` was declared as `i32[@n]`, we infer `int` for `n`.
///
/// This prepass is necessary because sometimes the order in which we traverse expressions can
/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
/// complete inference regardless of how expressions are later traversed.
///
/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
/// function position) such that we don't need this anymore.
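///
/// For example (an illustrative signature in flux surface syntax): given `fn(x: i32[@n]) -> bool[n > 0]`,
/// visiting the indexed type `i32[@n]` equates the sort of `n` with the index sort of `i32`, i.e.,
/// `int`, before the refinement `n > 0` is ever sort checked.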
pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
    errors: Errors<'genv>,
}

impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
    pub(crate) fn infer(
        infcx: &'a mut InferCtxt<'genv, 'tcx>,
        node: &fhir::OwnerNode<'genv>,
    ) -> Result {
        let errors = Errors::new(infcx.genv.sess());
        let mut vis = Self { infcx, errors };
        vis.visit_node(node);
        vis.errors.into_result()
    }

    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
        match idx.kind {
            fhir::ExprKind::Var(var, Some(_)) => {
                let (_, id) = var.res.expect_param();
                let found = self.infcx.param_sort(id);
                self.infcx.equate(&found, expected);
            }
            fhir::ExprKind::Record(flds) => {
                if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = expected {
                    let sorts = sort_def.struct_variant().field_sorts(sort_args);
                    if flds.len() != sorts.len() {
                        self.errors.emit(errors::ArgCountMismatch::new(
                            Some(idx.span),
                            String::from("type"),
                            sorts.len(),
                            flds.len(),
                        ));
                        return;
                    }
                    for (f, sort) in iter::zip(flds, &sorts) {
                        self.infer_implicit_params(f, sort);
                    }
                } else {
                    self.errors.emit(errors::ArgCountMismatch::new(
                        Some(idx.span),
                        String::from("type"),
                        1,
                        flds.len(),
                    ));
                }
            }
            _ => {}
        }
    }
}

impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
            let expected = self.infcx.sort_of_bty(bty);
            self.infer_implicit_params(idx, &expected);
        }
        fhir::visit::walk_ty(self, ty);
    }
}

impl InferCtxt<'_, '_> {
    #[track_caller]
    fn emit_sort_mismatch(
        &mut self,
        span: Span,
        expected: &rty::Sort,
        found: &rty::Sort,
    ) -> ErrorGuaranteed {
        let expected = self.resolve_vars_if_possible(expected);
        let found = self.resolve_vars_if_possible(found);
        self.emit_err(errors::SortMismatch::new(span, expected, found))
    }

    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
    }

    #[track_caller]
    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
        self.genv.sess().emit_err(err)
    }
}

struct ShallowResolver<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}

impl TypeFolder for ShallowResolver<'_, '_, '_> {
    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
        match sort {
            rty::Sort::Infer(rty::SortVar(vid)) => {
                // if `sort` is a sort variable, it can be resolved to a num/bit-vec variable,
                // which can then be recursively resolved, hence the recursion. Note though that
                // we prevent sort variables from unifying to other sort variables directly (though
                // they may be embedded structurally), so this recursion should always be of very
                // limited depth.
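                // (Illustrative chain: a sort variable may resolve to a num variable, which may
                // resolve to a bitvector sort with an unresolved size; that size is in turn
                // handled by the bit-vector arm below.)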
                self.infcx
                    .sort_unification_table
                    .probe_value(*vid)
                    .map(|sort| sort.fold_with(self))
                    .unwrap_or(sort.clone())
            }
            rty::Sort::Infer(rty::NumVar(vid)) => {
                // same here: a num var could have been unified with a bitvector
                self.infcx
                    .num_unification_table
                    .probe_value(*vid)
                    .map(|val| val.to_sort().fold_with(self))
                    .unwrap_or(sort.clone())
            }
            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
                self.infcx
                    .bv_size_unification_table
                    .probe_value(*vid)
                    .map(rty::Sort::BitVec)
                    .unwrap_or(sort.clone())
            }
            _ => sort.clone(),
        }
    }
}

struct OpportunisticResolver<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}

impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
        let s = self.infcx.shallow_resolve(sort);
        s.super_fold_with(self)
    }
}

struct FullResolver<'a, 'genv, 'tcx> {
    infcx: &'a mut InferCtxt<'genv, 'tcx>,
}

impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
    type Error = ();

    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
        let s = self.infcx.shallow_resolve(sort);
        match s {
            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
            _ => s.try_super_fold_with(self),
        }
    }
}

/// Map to associate data with a node (i.e., an expression).
///
/// Used to record elaborated information.
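///
/// For example, `InferCtxt::sort_of_literal` uses a `NodeMap` to remember the (possibly
/// unresolved) sort assigned to each integer literal so it can be fully resolved later in
/// `InferCtxt::into_results`.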
#[derive_where(Default)]
struct NodeMap<'genv, T> {
    map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
}

impl<'genv, T> NodeMap<'genv, T> {
    /// Add a `node` to the map with associated `data`
    fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
        assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
    }
}

impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
    type Item = (fhir::Expr<'genv>, T);

    type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.map.into_values()
    }
}