flux_fhir_analysis/wf/sortck.rs

1use std::iter;
2
3use derive_where::derive_where;
4use ena::unify::InPlaceUnificationTable;
5use flux_common::{bug, iter::IterExt, span_bug, tracked_span_bug};
6use flux_errors::{ErrorGuaranteed, Errors};
7use flux_infer::projections::NormalizeExt;
8use flux_middle::{
9    fhir::{self, FhirId, FluxOwnerId, QPathExpr, visit::Visitor as _},
10    global_env::GlobalEnv,
11    queries::QueryResult,
12    rty::{
13        self, AdtSortDef, FuncSort, List, RecordCtor, WfckResults,
14        fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable},
15    },
16};
17use itertools::{Itertools, izip};
18use rustc_data_structures::unord::UnordMap;
19use rustc_errors::Diagnostic;
20use rustc_hash::FxHashMap;
21use rustc_hir::def::DefKind;
22use rustc_middle::ty::TypingMode;
23use rustc_span::{Span, def_id::DefId, symbol::Ident};
24
25use super::errors;
26use crate::rustc_infer::infer::TyCtxtInferExt;
27
28type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
30pub(super) struct InferCtxt<'genv, 'tcx> {
31    pub genv: GlobalEnv<'genv, 'tcx>,
32    pub owner: FluxOwnerId,
33    pub wfckresults: WfckResults,
34    sort_unification_table: InPlaceUnificationTable<rty::SortVid>,
35    bv_size_unification_table: InPlaceUnificationTable<rty::BvSizeVid>,
36    params: FxHashMap<fhir::ParamId, (fhir::RefineParam<'genv>, rty::Sort)>,
37    node_sort: FxHashMap<FhirId, rty::Sort>,
38    path_args: UnordMap<FhirId, rty::GenericArgs>,
39    sort_of_alias_reft: FxHashMap<FhirId, rty::FuncSort>,
40    sort_of_literal: NodeMap<'genv, rty::Sort>,
41    sort_of_bin_op: NodeMap<'genv, rty::Sort>,
42    sort_args_of_app: NodeMap<'genv, List<rty::SortArg>>,
43}
44
45pub fn prim_op_sort(op: &fhir::BinOp) -> Option<(Vec<rty::Sort>, rty::Sort)> {
46    match op {
47        fhir::BinOp::BitAnd
48        | fhir::BinOp::BitOr
49        | fhir::BinOp::BitXor
50        | fhir::BinOp::BitShl
51        | fhir::BinOp::BitShr => Some((vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int)),
52        _ => None,
53    }
54}
55
56impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> {
57    pub(super) fn new(genv: GlobalEnv<'genv, 'tcx>, owner: FluxOwnerId) -> Self {
58        // We skip 0 because that's used for sort dummy self types during conv.
59        let mut sort_unification_table = InPlaceUnificationTable::new();
60        sort_unification_table.new_key(Default::default());
61        Self {
62            genv,
63            owner,
64            wfckresults: WfckResults::new(owner),
65            sort_unification_table,
66            bv_size_unification_table: InPlaceUnificationTable::new(),
67            params: Default::default(),
68            node_sort: Default::default(),
69            path_args: Default::default(),
70            sort_of_alias_reft: Default::default(),
71            sort_of_literal: Default::default(),
72            sort_of_bin_op: Default::default(),
73            sort_args_of_app: Default::default(),
74        }
75    }
76
77    fn check_abs(
78        &mut self,
79        expr: &fhir::Expr<'genv>,
80        params: &[fhir::RefineParam],
81        body: &fhir::Expr<'genv>,
82        expected: &rty::Sort,
83    ) -> Result {
84        let Some(fsort) = self.is_coercible_from_func(expected, expr.fhir_id) else {
85            return Err(self.emit_err(errors::UnexpectedFun::new(expr.span, expected)));
86        };
87
88        let fsort = fsort.expect_mono();
89
90        if params.len() != fsort.inputs().len() {
91            return Err(self.emit_err(errors::ParamCountMismatch::new(
92                expr.span,
93                fsort.inputs().len(),
94                params.len(),
95            )));
96        }
97        iter::zip(params, fsort.inputs()).try_for_each_exhaust(|(param, expected)| {
98            let found = self.param_sort(param.id);
99            if self.try_equate(&found, expected).is_none() {
100                return Err(self.emit_sort_mismatch(param.span, expected, &found));
101            }
102            Ok(())
103        })?;
104        self.check_expr(body, fsort.output())?;
105        self.wfckresults
106            .node_sorts_mut()
107            .insert(body.fhir_id, fsort.output().clone());
108        Ok(())
109    }
110
111    fn check_field_exprs(
112        &mut self,
113        expr_span: Span,
114        sort_def: &AdtSortDef,
115        sort_args: &[rty::Sort],
116        field_exprs: &[fhir::FieldExpr<'genv>],
117        spread: Option<&fhir::Spread<'genv>>,
118        expected: &rty::Sort,
119    ) -> Result {
120        let sort_by_field_name = sort_def.struct_variant().sort_by_field_name(sort_args);
121        let mut used_fields = FxHashMap::default();
122        for expr in field_exprs {
123            // make sure that the field is actually a field
124            let Some(sort) = sort_by_field_name.get(&expr.ident.name) else {
125                return Err(self.emit_err(errors::FieldNotFound::new(expected.clone(), expr.ident)));
126            };
127            if let Some(old_field) = used_fields.insert(expr.ident.name, expr.ident) {
128                return Err(self.emit_err(errors::DuplicateFieldUsed::new(expr.ident, old_field)));
129            }
130            self.check_expr(&expr.expr, sort)?;
131        }
132        if let Some(spread) = spread {
133            // must check that the spread is of the same sort as the constructor
134            self.check_expr(&spread.expr, expected)
135        } else if sort_by_field_name.len() != used_fields.len() {
136            // emit an error because all fields are not used
137            let missing_fields = sort_by_field_name
138                .into_keys()
139                .filter(|x| !used_fields.contains_key(x))
140                .collect();
141            Err(self.emit_err(errors::ConstructorMissingFields::new(expr_span, missing_fields)))
142        } else {
143            Ok(())
144        }
145    }
146
147    fn check_constructor(
148        &mut self,
149        expr: &fhir::Expr,
150        field_exprs: &[fhir::FieldExpr<'genv>],
151        spread: Option<&fhir::Spread<'genv>>,
152        expected: &rty::Sort,
153    ) -> Result {
154        let expected = self.resolve_vars_if_possible(expected);
155        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = &expected {
156            self.wfckresults
157                .record_ctors_mut()
158                .insert(expr.fhir_id, RecordCtor::Struct(sort_def.did()));
159            self.check_field_exprs(expr.span, sort_def, sort_args, field_exprs, spread, &expected)?;
160            Ok(())
161        } else {
162            Err(self.emit_err(errors::UnexpectedConstructor::new(expr.span, &expected)))
163        }
164    }
165
166    fn check_record(
167        &mut self,
168        arg: &fhir::Expr<'genv>,
169        flds: &[fhir::Expr<'genv>],
170        expected: &rty::Sort,
171    ) -> Result {
172        if let Some(sorts) = expected.field_sorts() {
173            if flds.len() != sorts.len() {
174                return Err(self.emit_err(errors::ArgCountMismatch::new(
175                    Some(arg.span),
176                    String::from("type"),
177                    sorts.len(),
178                    flds.len(),
179                )));
180            }
181            if let rty::Sort::App(rty::SortCtor::Adt(sort_def), _) = expected {
182                self.wfckresults
183                    .record_ctors_mut()
184                    .insert(arg.fhir_id, RecordCtor::Struct(sort_def.did()));
185            } else if matches!(expected, rty::Sort::RawPtr) {
186                self.wfckresults
187                    .record_ctors_mut()
188                    .insert(arg.fhir_id, RecordCtor::RawPtr);
189            }
190
191            izip!(flds, &sorts)
192                .map(|(arg, expected)| self.check_expr(arg, expected))
193                .try_collect_exhaust()
194        } else {
195            Err(self.emit_err(errors::ArgCountMismatch::new(
196                Some(arg.span),
197                String::from("type"),
198                1,
199                flds.len(),
200            )))
201        }
202    }
203
204    pub(super) fn check_expr(&mut self, expr: &fhir::Expr<'genv>, expected: &rty::Sort) -> Result {
205        match expr.kind {
206            fhir::ExprKind::IfThenElse(p, e1, e2) => {
207                self.check_expr(p, &rty::Sort::Bool)?;
208                self.check_expr(e1, expected)?;
209                self.check_expr(e2, expected)?;
210                return Ok(());
211            }
212            fhir::ExprKind::Abs(params, body) => {
213                self.check_abs(expr, params, body, expected)?;
214                return Ok(());
215            }
216            fhir::ExprKind::Record(fields) => {
217                self.check_record(expr, fields, expected)?;
218                return Ok(());
219            }
220            fhir::ExprKind::Constructor(None, exprs, spread) => {
221                self.check_constructor(expr, exprs, spread, expected)?;
222                return Ok(());
223            }
224            fhir::ExprKind::Tuple(exprs) => {
225                // Check tuple against expected tuple sort
226                if let rty::Sort::Tuple(expected_sorts) = expected
227                    && exprs.len() == expected_sorts.len()
228                {
229                    for (expr, expected_sort) in iter::zip(exprs, expected_sorts) {
230                        self.check_expr(expr, expected_sort)?;
231                    }
232                    return Ok(());
233                }
234            }
235            fhir::ExprKind::Err(_) => {
236                // an error has already been reported so we can just skip
237                return Ok(());
238            }
239            fhir::ExprKind::UnaryOp(..)
240            | fhir::ExprKind::SetLiteral(..)
241            | fhir::ExprKind::BinaryOp(..)
242            | fhir::ExprKind::Dot(..)
243            | fhir::ExprKind::App(..)
244            | fhir::ExprKind::Alias(..)
245            | fhir::ExprKind::Var(..)
246            | fhir::ExprKind::Literal(..)
247            | fhir::ExprKind::Quant(..)
248            | fhir::ExprKind::Block(..)
249            | fhir::ExprKind::Constructor(..)
250            | fhir::ExprKind::PrimApp(..) => {}
251        }
252        // Fallback to synthesis and then check
253        let found = self.synth_expr(expr)?;
254        let found = self.resolve_vars_if_possible(&found);
255        let expected = self.resolve_vars_if_possible(expected);
256        if !self.is_coercible(&found, &expected, expr.fhir_id) {
257            return Err(self.emit_sort_mismatch(expr.span, &expected, &found));
258        }
259        Ok(())
260    }
261
262    pub(super) fn check_loc(&mut self, loc: &fhir::PathExpr) -> Result {
263        let found = self.synth_path(loc);
264        if found == rty::Sort::Loc {
265            Ok(())
266        } else {
267            Err(self.emit_sort_mismatch(loc.span, &rty::Sort::Loc, &found))
268        }
269    }
270
271    fn synth_lit(&mut self, lit: fhir::Lit, expr: &fhir::Expr<'genv>) -> rty::Sort {
272        match lit {
273            fhir::Lit::Int(_) => {
274                let sort = self.next_sort_var_with_cstr(rty::SortCstr::NUMERIC);
275                self.sort_of_literal.insert(*expr, sort.clone());
276                sort
277            }
278            fhir::Lit::Real(_) => rty::Sort::Real,
279            fhir::Lit::Bool(_) => rty::Sort::Bool,
280            fhir::Lit::Str(_) => rty::Sort::Str,
281            fhir::Lit::Char(_) => rty::Sort::Char,
282        }
283    }
284
285    fn synth_prim_app(
286        &mut self,
287        op: &fhir::BinOp,
288        e1: &fhir::Expr<'genv>,
289        e2: &fhir::Expr<'genv>,
290        span: Span,
291    ) -> Result<rty::Sort> {
292        let Some((inputs, output)) = prim_op_sort(op) else {
293            return Err(self.emit_err(errors::UnsupportedPrimOp::new(span, *op)));
294        };
295        let [sort1, sort2] = &inputs[..] else {
296            return Err(self.emit_err(errors::ArgCountMismatch::new(
297                Some(span),
298                String::from("primop app"),
299                inputs.len(),
300                2,
301            )));
302        };
303        self.check_expr(e1, sort1)?;
304        self.check_expr(e2, sort2)?;
305        Ok(output)
306    }
307
308    fn synth_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result<rty::Sort> {
309        match expr.kind {
310            fhir::ExprKind::Var(QPathExpr::Resolved(path, _)) => Ok(self.synth_path(&path)),
311            fhir::ExprKind::Var(QPathExpr::TypeRelative(..)) => {
312                Ok(self.synth_type_relative_path(expr))
313            }
314            fhir::ExprKind::Literal(lit) => Ok(self.synth_lit(lit, expr)),
315            fhir::ExprKind::BinaryOp(op, e1, e2) => self.synth_binary_op(expr, op, e1, e2),
316            fhir::ExprKind::PrimApp(op, e1, e2) => self.synth_prim_app(&op, e1, e2, expr.span),
317            fhir::ExprKind::UnaryOp(op, e) => self.synth_unary_op(op, e),
318            fhir::ExprKind::App(callee, args) => {
319                let sort = self.ensure_resolved_path(&callee)?;
320                let Some(poly_fsort) = self.is_coercible_to_func(&sort, callee.fhir_id) else {
321                    return Err(self.emit_err(errors::ExpectedFun::new(callee.span, &sort)));
322                };
323                let fsort = self.instantiate_func_sort(expr, poly_fsort);
324                self.synth_app(fsort, args, expr.span)
325            }
326            fhir::ExprKind::Quant(_, param, dom, body) => {
327                let sort = self.param_sort(param.id);
328                if let fhir::QuantDom::Bounded { .. } = dom
329                    && self.try_equate(&sort, &rty::Sort::Int).is_none()
330                {
331                    return Err(self.emit_err(errors::IllSortedQuantifier::new(param.span, sort)));
332                }
333                self.check_expr(body, &rty::Sort::Bool)?;
334                Ok(rty::Sort::Bool)
335            }
336            fhir::ExprKind::Alias(_alias_reft, args) => {
337                // To check the application we only need the sort of `_alias_reft` which we collected
338                // during early conv, but should we do any extra checks on `_alias_reft`?
339                let fsort = self.sort_of_alias_reft(expr.fhir_id);
340                self.synth_app(fsort, args, expr.span)
341            }
342            fhir::ExprKind::IfThenElse(p, e1, e2) => {
343                self.check_expr(p, &rty::Sort::Bool)?;
344                let sort = self.synth_expr(e1)?;
345                self.check_expr(e2, &sort)?;
346                Ok(sort)
347            }
348            fhir::ExprKind::Dot(base, fld) => {
349                let sort = self.synth_expr(base)?;
350                let sort = self
351                    .fully_resolve(&sort)
352                    .map_err(|_| self.emit_err(errors::CannotInferSort::new(base.span)))?;
353                match &sort {
354                    rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) => {
355                        let (proj, sort) = sort_def
356                            .struct_variant()
357                            .field_by_name(sort_def.did(), sort_args, fld.name)
358                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
359                        self.wfckresults
360                            .field_projs_mut()
361                            .insert(expr.fhir_id, proj);
362                        Ok(sort)
363                    }
364                    rty::Sort::Tuple(sorts) => {
365                        // Parse field name as integer for tuple field access
366                        if let Ok(field_idx) = fld.name.as_str().parse::<usize>()
367                            && field_idx < sorts.len()
368                        {
369                            let proj = rty::FieldProj::Tuple {
370                                arity: sorts.len(),
371                                field: field_idx as u32,
372                            };
373                            self.wfckresults
374                                .field_projs_mut()
375                                .insert(expr.fhir_id, proj);
376                            Ok(sorts[field_idx].clone())
377                        } else {
378                            Err(self.emit_field_not_found(&sort, fld))
379                        }
380                    }
381                    rty::Sort::RawPtr => {
382                        let field = rty::RawPtrField::from_name(fld.name)
383                            .ok_or_else(|| self.emit_field_not_found(&sort, fld))?;
384                        self.wfckresults
385                            .field_projs_mut()
386                            .insert(expr.fhir_id, rty::FieldProj::RawPtr { field });
387                        Ok(rty::Sort::Int)
388                    }
389                    rty::Sort::Bool | rty::Sort::Int | rty::Sort::Real => {
390                        Err(self.emit_err(errors::InvalidPrimitiveDotAccess::new(&sort, fld)))
391                    }
392                    _ => Err(self.emit_field_not_found(&sort, fld)),
393                }
394            }
395            fhir::ExprKind::Constructor(Some(path), field_exprs, spread) => {
396                // if we have S then we can synthesize otherwise we fail
397                // first get the sort based on the path - for example S { ... } => S
398                // and we should expect sort to be a struct or enum app
399                let path_def_id = match path.res {
400                    fhir::Res::Def(DefKind::Enum | DefKind::Struct, def_id) => def_id,
401                    _ => span_bug!(expr.span, "unexpected path in constructor"),
402                };
403                let sort_def = self
404                    .genv
405                    .adt_sort_def_of(path_def_id)
406                    .map_err(|e| self.emit_err(e))?;
407                // generate fresh inferred sorts for each param
408                let fresh_args: rty::List<_> = (0..sort_def.param_count())
409                    .map(|_| self.next_sort_var())
410                    .collect();
411                let sort = rty::Sort::App(rty::SortCtor::Adt(sort_def.clone()), fresh_args.clone());
412                // check fields & spread against fresh args
413                self.check_field_exprs(
414                    expr.span,
415                    &sort_def,
416                    &fresh_args,
417                    field_exprs,
418                    spread,
419                    &sort,
420                )?;
421                Ok(sort)
422            }
423            fhir::ExprKind::Block(decls, body) => {
424                for decl in decls {
425                    self.check_expr(&decl.init, &self.param_sort(decl.param.id))?;
426                }
427                self.synth_expr(body)
428            }
429            fhir::ExprKind::SetLiteral(elems) => {
430                let elem_sort = self.next_sort_var();
431                for elem in elems {
432                    self.check_expr(elem, &elem_sort)?;
433                }
434                Ok(rty::Sort::App(rty::SortCtor::Set, List::singleton(elem_sort)))
435            }
436            fhir::ExprKind::Tuple(exprs) => {
437                let sorts = exprs
438                    .iter()
439                    .map(|expr| self.synth_expr(expr))
440                    .try_collect()?;
441                Ok(rty::Sort::Tuple(sorts))
442            }
443            _ => Err(self.emit_err(errors::CannotInferSort::new(expr.span))),
444        }
445    }
446
447    fn synth_path(&mut self, path: &fhir::PathExpr) -> rty::Sort {
448        self.node_sort
449            .get(&path.fhir_id)
450            .unwrap_or_else(|| tracked_span_bug!("no sort found for path: `{path:?}`"))
451            .clone()
452    }
453
454    fn synth_type_relative_path(&mut self, expr: &fhir::Expr) -> rty::Sort {
455        self.node_sort
456            .get(&expr.fhir_id)
457            .unwrap_or_else(|| tracked_span_bug!("no sort found for: `{expr:?}`"))
458            .clone()
459    }
460
461    fn synth_binary_op(
462        &mut self,
463        expr: &fhir::Expr<'genv>,
464        op: fhir::BinOp,
465        e1: &fhir::Expr<'genv>,
466        e2: &fhir::Expr<'genv>,
467    ) -> Result<rty::Sort> {
468        match op {
469            fhir::BinOp::Or | fhir::BinOp::And | fhir::BinOp::Iff | fhir::BinOp::Imp => {
470                self.check_expr(e1, &rty::Sort::Bool)?;
471                self.check_expr(e2, &rty::Sort::Bool)?;
472                Ok(rty::Sort::Bool)
473            }
474            fhir::BinOp::Eq | fhir::BinOp::Ne => {
475                let sort = self.next_sort_var();
476                self.check_expr(e1, &sort)?;
477                self.check_expr(e2, &sort)?;
478                Ok(rty::Sort::Bool)
479            }
480            fhir::BinOp::Lt | fhir::BinOp::Le | fhir::BinOp::Gt | fhir::BinOp::Ge => {
481                let sort = self.next_sort_var();
482                self.check_expr(e1, &sort)?;
483                self.check_expr(e2, &sort)?;
484                self.sort_of_bin_op.insert(*expr, sort.clone());
485                Ok(rty::Sort::Bool)
486            }
487            fhir::BinOp::Add
488            | fhir::BinOp::Sub
489            | fhir::BinOp::Mul
490            | fhir::BinOp::Div
491            | fhir::BinOp::Mod
492            | fhir::BinOp::BitAnd
493            | fhir::BinOp::BitOr
494            | fhir::BinOp::BitXor
495            | fhir::BinOp::BitShl
496            | fhir::BinOp::BitShr => {
497                let sort = self.next_sort_var_with_cstr(rty::SortCstr::from_bin_op(op));
498                self.check_expr(e1, &sort)?;
499                self.check_expr(e2, &sort)?;
500                self.sort_of_bin_op.insert(*expr, sort.clone());
501                Ok(sort)
502            }
503        }
504    }
505
506    fn synth_unary_op(&mut self, op: fhir::UnOp, e: &fhir::Expr<'genv>) -> Result<rty::Sort> {
507        match op {
508            fhir::UnOp::Not => {
509                self.check_expr(e, &rty::Sort::Bool)?;
510                Ok(rty::Sort::Bool)
511            }
512            fhir::UnOp::Neg => {
513                self.check_expr(e, &rty::Sort::Int)?;
514                Ok(rty::Sort::Int)
515            }
516        }
517    }
518
519    fn synth_app(
520        &mut self,
521        fsort: FuncSort,
522        args: &[fhir::Expr<'genv>],
523        span: Span,
524    ) -> Result<rty::Sort> {
525        if args.len() != fsort.inputs().len() {
526            return Err(self.emit_err(errors::ArgCountMismatch::new(
527                Some(span),
528                String::from("function"),
529                fsort.inputs().len(),
530                args.len(),
531            )));
532        }
533
534        iter::zip(args, fsort.inputs())
535            .try_for_each_exhaust(|(arg, formal)| self.check_expr(arg, formal))?;
536
537        Ok(fsort.output().clone())
538    }
539
540    fn instantiate_func_sort(
541        &mut self,
542        app_expr: &fhir::Expr<'genv>,
543        fsort: rty::PolyFuncSort,
544    ) -> rty::FuncSort {
545        let args = fsort
546            .params()
547            .map(|kind| {
548                match kind {
549                    rty::SortParamKind::Sort => rty::SortArg::Sort(self.next_sort_var()),
550                    rty::SortParamKind::BvSize => rty::SortArg::BvSize(self.next_bv_size_var()),
551                }
552            })
553            .collect_vec();
554        self.sort_args_of_app
555            .insert(*app_expr, List::from_slice(&args));
556        fsort.instantiate(&args)
557    }
558
559    pub(crate) fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
560        self.node_sort.insert(fhir_id, sort);
561    }
562
563    pub(crate) fn sort_of_bty(&self, bty: &fhir::BaseTy) -> rty::Sort {
564        self.node_sort
565            .get(&bty.fhir_id)
566            .unwrap_or_else(|| tracked_span_bug!("no sort found for bty: `{bty:?}`"))
567            .clone()
568    }
569
570    pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
571        self.path_args.insert(fhir_id, args);
572    }
573
574    pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs {
575        self.path_args
576            .get(&fhir_id)
577            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
578            .clone()
579    }
580
581    pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
582        self.sort_of_alias_reft.insert(fhir_id, fsort);
583    }
584
585    fn sort_of_alias_reft(&self, fhir_id: FhirId) -> rty::FuncSort {
586        self.sort_of_alias_reft
587            .get(&fhir_id)
588            .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`"))
589            .clone()
590    }
591
592    fn deeply_normalize_sorts<T: TypeFoldable + Clone>(
593        genv: GlobalEnv,
594        owner: FluxOwnerId,
595        t: &T,
596    ) -> QueryResult<T> {
597        let infcx = genv
598            .tcx()
599            .infer_ctxt()
600            .with_next_trait_solver(true)
601            .build(TypingMode::non_body_analysis());
602        if let Some(def_id) = owner.resolved_id() {
603            t.deeply_normalize_sorts(def_id, genv, &infcx)
604        } else {
605            Ok(t.clone())
606        }
607    }
608
609    // FIXME(nilehmann) this is a bit of a hack. We should find a more robust way to do normalization
610    // for sort checking, including normalizing projections. [RJ: normalizing_projections is done now]
611    // Maybe we can do lazy normalization. Once we do that maybe we can also stop
612    // expanding aliases in `ConvCtxt::conv_sort`.
613    pub(crate) fn normalize_sorts(&mut self) -> QueryResult {
614        for sort in self.node_sort.values_mut() {
615            *sort = Self::deeply_normalize_sorts(self.genv, self.owner, sort)?;
616        }
617        for fsort in self.sort_of_alias_reft.values_mut() {
618            *fsort = Self::deeply_normalize_sorts(self.genv, self.owner, fsort)?;
619        }
620        Ok(())
621    }
622}
623
624impl<'genv> InferCtxt<'genv, '_> {
625    pub(super) fn declare_param(&mut self, param: fhir::RefineParam<'genv>, sort: rty::Sort) {
626        self.params.insert(param.id, (param, sort));
627    }
628
629    /// Whether a value of `sort1` can be automatically coerced to a value of `sort2`. A value of an
630    /// [`rty::SortCtor::Adt`] sort with a single field of sort `s` can be coerced to a value of sort
631    /// `s` and vice versa, i.e., we can automatically project the field out of the record or inject
632    /// a value into a record.
633    fn is_coercible(&mut self, sort1: &rty::Sort, sort2: &rty::Sort, fhir_id: FhirId) -> bool {
634        if self.try_equate(sort1, sort2).is_some() {
635            return true;
636        }
637
638        let mut sort1 = sort1.clone();
639        let mut sort2 = sort2.clone();
640
641        let mut coercions = vec![];
642        if let Some((def_id, sort)) = self.is_single_field_struct(&sort1) {
643            coercions.push(rty::Coercion::Project(def_id));
644            sort1 = sort.clone();
645        }
646        if let Some((def_id, sort)) = self.is_single_field_struct(&sort2) {
647            coercions.push(rty::Coercion::Inject(def_id));
648            sort2 = sort.clone();
649        }
650        self.wfckresults.coercions_mut().insert(fhir_id, coercions);
651        self.try_equate(&sort1, &sort2).is_some()
652    }
653
654    fn is_coercible_from_func(
655        &mut self,
656        sort: &rty::Sort,
657        fhir_id: FhirId,
658    ) -> Option<rty::PolyFuncSort> {
659        if let rty::Sort::Func(fsort) = sort {
660            Some(fsort.clone())
661        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
662            self.wfckresults
663                .coercions_mut()
664                .insert(fhir_id, vec![rty::Coercion::Inject(def_id)]);
665            Some(fsort.clone())
666        } else {
667            None
668        }
669    }
670
671    fn is_coercible_to_func(
672        &mut self,
673        sort: &rty::Sort,
674        fhir_id: FhirId,
675    ) -> Option<rty::PolyFuncSort> {
676        if let rty::Sort::Func(fsort) = sort {
677            Some(fsort.clone())
678        } else if let Some((def_id, rty::Sort::Func(fsort))) = self.is_single_field_struct(sort) {
679            self.wfckresults
680                .coercions_mut()
681                .insert(fhir_id, vec![rty::Coercion::Project(def_id)]);
682            Some(fsort.clone())
683        } else {
684            None
685        }
686    }
687
688    fn try_equate(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
689        let sort1 = self.resolve_vars_if_possible(sort1);
690        let sort2 = self.resolve_vars_if_possible(sort2);
691        self.try_equate_inner(&sort1, &sort2)
692    }
693
694    fn try_equate_inner(&mut self, sort1: &rty::Sort, sort2: &rty::Sort) -> Option<rty::Sort> {
695        match (sort1, sort2) {
696            (rty::Sort::Infer(vid1), rty::Sort::Infer(vid2)) => {
697                self.sort_unification_table
698                    .unify_var_var(*vid1, *vid2)
699                    .ok()?;
700            }
701            (rty::Sort::Infer(vid), sort) | (sort, rty::Sort::Infer(vid)) => {
702                self.sort_unification_table
703                    .unify_var_value(*vid, rty::SortVarVal::Solved(sort.clone()))
704                    .ok()?;
705            }
706            (rty::Sort::App(ctor1, args1), rty::Sort::App(ctor2, args2)) => {
707                if ctor1 != ctor2 || args1.len() != args2.len() {
708                    return None;
709                }
710                for (s1, s2) in iter::zip(args1, args2) {
711                    self.try_equate_inner(s1, s2)?;
712                }
713            }
714            (rty::Sort::Tuple(sorts1), rty::Sort::Tuple(sorts2)) => {
715                if sorts1.len() != sorts2.len() {
716                    return None;
717                }
718                for (s1, s2) in iter::zip(sorts1, sorts2) {
719                    self.try_equate_inner(s1, s2)?;
720                }
721            }
722            (rty::Sort::BitVec(size1), rty::Sort::BitVec(size2)) => {
723                self.try_equate_bv_sizes(*size1, *size2)?;
724            }
725            _ if sort1 == sort2 => {}
726            _ => return None,
727        }
728        Some(sort1.clone())
729    }
730
731    fn try_equate_bv_sizes(
732        &mut self,
733        size1: rty::BvSize,
734        size2: rty::BvSize,
735    ) -> Option<rty::BvSize> {
736        match (size1, size2) {
737            (rty::BvSize::Infer(vid1), rty::BvSize::Infer(vid2)) => {
738                self.bv_size_unification_table
739                    .unify_var_var(vid1, vid2)
740                    .ok()?;
741            }
742            (rty::BvSize::Infer(vid), size) | (size, rty::BvSize::Infer(vid)) => {
743                self.bv_size_unification_table
744                    .unify_var_value(vid, Some(size))
745                    .ok()?;
746            }
747            _ if size1 == size2 => {}
748            _ => return None,
749        }
750        Some(size1)
751    }
752
753    fn next_sort_var(&mut self) -> rty::Sort {
754        rty::Sort::Infer(self.next_sort_vid(Default::default()))
755    }
756
757    fn next_sort_var_with_cstr(&mut self, cstr: rty::SortCstr) -> rty::Sort {
758        rty::Sort::Infer(self.next_sort_vid(rty::SortVarVal::Unsolved(cstr)))
759    }
760
761    pub(crate) fn next_sort_vid(&mut self, val: rty::SortVarVal) -> rty::SortVid {
762        self.sort_unification_table.new_key(val)
763    }
764
765    fn next_bv_size_var(&mut self) -> rty::BvSize {
766        rty::BvSize::Infer(self.next_bv_size_vid())
767    }
768
769    fn next_bv_size_vid(&mut self) -> rty::BvSizeVid {
770        self.bv_size_unification_table.new_key(None)
771    }
772
773    fn ensure_resolved_path(&mut self, path: &fhir::PathExpr) -> Result<rty::Sort> {
774        let sort = self.synth_path(path);
775        self.fully_resolve(&sort)
776            .map_err(|_| self.emit_err(errors::CannotInferSort::new(path.span)))
777    }
778
779    fn is_single_field_struct(&mut self, sort: &rty::Sort) -> Option<(DefId, rty::Sort)> {
780        if let rty::Sort::App(rty::SortCtor::Adt(sort_def), sort_args) = sort
781            && let Some(variant) = sort_def.opt_struct_variant()
782            && let [sort] = &variant.field_sorts(sort_args)[..]
783        {
784            Some((sort_def.did(), sort.clone()))
785        } else {
786            None
787        }
788    }
789
790    pub(crate) fn into_results(mut self) -> Result<WfckResults> {
791        // Make sure the int literals are fully resolved. This must happen before everything else
792        // such that we properly apply the fallback for unconstrained num vars.
793        for (node, sort) in std::mem::take(&mut self.sort_of_literal) {
794            // Fallback to `int` when a num variable is unconstrained. Note that we unconditionally
795            // unify the variable. This is fine because if the variable has already been unified,
796            // the operation will fail and this won't have any effect. Also note that unifying a
797            // variable could solve variables that appear later in this for loop. This is also fine
798            // because the order doesn't matter as we are unifying everything to `int`.
799            if let rty::Sort::Infer(vid) = &sort {
800                let _ = self
801                    .sort_unification_table
802                    .unify_var_value(*vid, rty::SortVarVal::Solved(rty::Sort::Int));
803            }
804
805            let sort = self
806                .fully_resolve(&sort)
807                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
808            self.wfckresults.node_sorts_mut().insert(node.fhir_id, sort);
809        }
810
811        // Make sure the binary operators are fully resolved
812        for (node, sort) in std::mem::take(&mut self.sort_of_bin_op) {
813            let sort = self
814                .fully_resolve(&sort)
815                .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
816            self.wfckresults
817                .bin_op_sorts_mut()
818                .insert(node.fhir_id, sort);
819        }
820
821        let allow_uninterpreted_cast =
822            self.owner
823                .as_rust()
824                .map_or_else(flux_config::allow_uninterpreted_cast, |def_id| {
825                    self.genv
826                        .infer_opts(def_id.local_id().def_id)
827                        .allow_uninterpreted_cast
828                });
829
830        // Make sure that function applications are fully resolved
831        for (node, sort_args) in std::mem::take(&mut self.sort_args_of_app) {
832            let mut res = vec![];
833            for sort_arg in &sort_args {
834                let sort_arg = match sort_arg {
835                    rty::SortArg::Sort(sort) => {
836                        let sort = self
837                            .fully_resolve(sort)
838                            .map_err(|_| self.emit_err(errors::CannotInferSort::new(node.span)))?;
839                        rty::SortArg::Sort(sort)
840                    }
841                    rty::SortArg::BvSize(rty::BvSize::Infer(vid)) => {
842                        let size = self
843                            .bv_size_unification_table
844                            .probe_value(*vid)
845                            .ok_or_else(|| {
846                                self.emit_err(errors::CannotInferSort::new(node.span))
847                            })?;
848                        rty::SortArg::BvSize(size)
849                    }
850                    _ => sort_arg.clone(),
851                };
852                res.push(sort_arg);
853            }
854            if let fhir::ExprKind::App(callee, _) = node.kind
855                && matches!(callee.res, fhir::Res::GlobalFunc(fhir::SpecFuncKind::Cast))
856            {
857                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &res[..] else {
858                    span_bug!(node.span, "invalid sort args!")
859                };
860                if !allow_uninterpreted_cast
861                    && matches!(from.cast_kind(to), rty::CastKind::Uninterpreted)
862                {
863                    return Err(self.emit_err(errors::InvalidCast::new(node.span, from, to)));
864                }
865            }
866            self.wfckresults
867                .fn_app_sorts_mut()
868                .insert(node.fhir_id, res.into());
869        }
870
871        // Make sure all parameters are fully resolved
872        for (_, (param, sort)) in std::mem::take(&mut self.params) {
873            let sort = self
874                .fully_resolve(&sort)
875                .map_err(|_| self.emit_err(errors::SortAnnotationNeeded::new(&param)))?;
876            self.wfckresults.param_sorts_mut().insert(param.id, sort);
877        }
878
879        Ok(self.wfckresults)
880    }
881
882    pub(crate) fn infer_mode(&self, id: fhir::ParamId) -> fhir::InferMode {
883        fhir::InferMode::from_param_kind(self.params[&id].0.kind)
884    }
885
886    #[track_caller]
887    pub(crate) fn param_sort(&self, id: fhir::ParamId) -> rty::Sort {
888        self.params
889            .get(&id)
890            .unwrap_or_else(|| bug!("no entry found for `{id:?}`"))
891            .1
892            .clone()
893    }
894
895    fn shallow_resolve(&mut self, sort: &rty::Sort) -> rty::Sort {
896        sort.fold_with(&mut ShallowResolver { infcx: self })
897    }
898
899    fn resolve_vars_if_possible(&mut self, sort: &rty::Sort) -> rty::Sort {
900        sort.fold_with(&mut OpportunisticResolver { infcx: self })
901    }
902
903    pub(crate) fn fully_resolve(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, ()> {
904        sort.try_fold_with(&mut FullResolver { infcx: self })
905    }
906}
907
908/// Before the main sort inference, we do a first traversal checking all implicitly scoped parameters
909/// declared with `@` or `#` and infer their sort based on the type they are indexing, e.g., if `n` was
910/// declared as `i32[@n]`, we infer `int` for `n`.
911///
912/// This prepass is necessary because sometimes the order in which we traverse expressions can
913/// affect what we can infer. By resolving implicit parameters first, we ensure more consistent and
914/// complete inference regardless of how expressions are later traversed.
915///
916/// It should be possible to improve sort checking (e.g., by allowing partially resolved sorts in
917/// function position) such that we don't need this anymore.
918pub(crate) struct ImplicitParamInferer<'a, 'genv, 'tcx> {
919    infcx: &'a mut InferCtxt<'genv, 'tcx>,
920    errors: Errors<'genv>,
921}
922
923impl<'a, 'genv, 'tcx> ImplicitParamInferer<'a, 'genv, 'tcx> {
924    pub(crate) fn infer(
925        infcx: &'a mut InferCtxt<'genv, 'tcx>,
926        node: &fhir::OwnerNode<'genv>,
927    ) -> Result {
928        let errors = Errors::new(infcx.genv.sess());
929        let mut vis = Self { infcx, errors };
930        vis.visit_node(node);
931        vis.errors.to_result()
932    }
933
934    fn infer_implicit_params(&mut self, idx: &fhir::Expr, expected: &rty::Sort) {
935        match idx.kind {
936            fhir::ExprKind::Var(QPathExpr::Resolved(var, Some(_))) => {
937                let (_, id) = var.res.expect_param();
938                let found = self.infcx.param_sort(id);
939                self.infcx.try_equate(&found, expected).unwrap_or_else(|| {
940                    span_bug!(idx.span, "failed to equate sorts: `{found:?}` `{expected:?}`")
941                });
942            }
943            fhir::ExprKind::Record(flds) => {
944                if let Some(sorts) = expected.field_sorts() {
945                    if flds.len() != sorts.len() {
946                        self.errors.emit(errors::ArgCountMismatch::new(
947                            Some(idx.span),
948                            String::from("type"),
949                            sorts.len(),
950                            flds.len(),
951                        ));
952                        return;
953                    }
954                    for (f, sort) in iter::zip(flds, &sorts) {
955                        self.infer_implicit_params(f, sort);
956                    }
957                } else {
958                    self.errors.emit(errors::ArgCountMismatch::new(
959                        Some(idx.span),
960                        String::from("type"),
961                        1,
962                        flds.len(),
963                    ));
964                }
965            }
966            _ => {}
967        }
968    }
969}
970
971impl<'genv> fhir::visit::Visitor<'genv> for ImplicitParamInferer<'_, 'genv, '_> {
972    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
973        if let fhir::TyKind::Indexed(bty, idx) = &ty.kind {
974            let expected = self.infcx.sort_of_bty(bty);
975            self.infer_implicit_params(idx, &expected);
976        }
977        fhir::visit::walk_ty(self, ty);
978    }
979}
980
981impl InferCtxt<'_, '_> {
982    #[track_caller]
983    fn emit_sort_mismatch(
984        &mut self,
985        span: Span,
986        expected: &rty::Sort,
987        found: &rty::Sort,
988    ) -> ErrorGuaranteed {
989        let expected = self.resolve_vars_if_possible(expected);
990        let found = self.resolve_vars_if_possible(found);
991        self.emit_err(errors::SortMismatch::new(span, expected, found))
992    }
993
994    fn emit_field_not_found(&mut self, sort: &rty::Sort, field: Ident) -> ErrorGuaranteed {
995        self.emit_err(errors::FieldNotFound::new(sort.clone(), field))
996    }
997
998    #[track_caller]
999    fn emit_err<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
1000        self.genv.sess().emit_err(err)
1001    }
1002}
1003
1004struct ShallowResolver<'a, 'genv, 'tcx> {
1005    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1006}
1007
1008impl TypeFolder for ShallowResolver<'_, '_, '_> {
1009    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1010        match sort {
1011            rty::Sort::Infer(vid) => {
1012                // if `sort` is a sort variable, it can be resolved to a bit-vec variable,
1013                // which can then be recursively resolved, hence the recursion. Note though that
1014                // we prevent sort variables from unifying to other sort variables directly (though
1015                // they may be embedded structurally), so this recursion should always be of very
1016                // limited depth.
1017                self.infcx
1018                    .sort_unification_table
1019                    .probe_value(*vid)
1020                    .map_solved(|sort| sort.fold_with(self))
1021                    .solved_or(sort)
1022            }
1023            rty::Sort::BitVec(rty::BvSize::Infer(vid)) => {
1024                self.infcx
1025                    .bv_size_unification_table
1026                    .probe_value(*vid)
1027                    .map(rty::Sort::BitVec)
1028                    .unwrap_or(sort.clone())
1029            }
1030            _ => sort.clone(),
1031        }
1032    }
1033}
1034
1035struct OpportunisticResolver<'a, 'genv, 'tcx> {
1036    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1037}
1038
1039impl TypeFolder for OpportunisticResolver<'_, '_, '_> {
1040    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
1041        let s = self.infcx.shallow_resolve(sort);
1042        s.super_fold_with(self)
1043    }
1044}
1045
1046struct FullResolver<'a, 'genv, 'tcx> {
1047    infcx: &'a mut InferCtxt<'genv, 'tcx>,
1048}
1049
1050impl FallibleTypeFolder for FullResolver<'_, '_, '_> {
1051    type Error = ();
1052
1053    fn try_fold_sort(&mut self, sort: &rty::Sort) -> std::result::Result<rty::Sort, Self::Error> {
1054        let s = self.infcx.shallow_resolve(sort);
1055        match s {
1056            rty::Sort::Infer(_) | rty::Sort::BitVec(rty::BvSize::Infer(_)) => Err(()),
1057            _ => s.try_super_fold_with(self),
1058        }
1059    }
1060}
1061
1062/// Map to associate data to a node (i.e., an expression).
1063///
1064/// Used to record elaborated information.
1065#[derive_where(Default)]
1066struct NodeMap<'genv, T> {
1067    map: FxHashMap<FhirId, (fhir::Expr<'genv>, T)>,
1068}
1069
1070impl<'genv, T> NodeMap<'genv, T> {
1071    /// Add a `node` to the map with associated `data`
1072    fn insert(&mut self, node: fhir::Expr<'genv>, data: T) {
1073        assert!(self.map.insert(node.fhir_id, (node, data)).is_none());
1074    }
1075}
1076
1077impl<'genv, T> IntoIterator for NodeMap<'genv, T> {
1078    type Item = (fhir::Expr<'genv>, T);
1079
1080    type IntoIter = std::collections::hash_map::IntoValues<FhirId, Self::Item>;
1081
1082    fn into_iter(self) -> Self::IntoIter {
1083        self.map.into_values()
1084    }
1085}