flux_fhir_analysis/wf/mod.rs

1//! Checks type well-formedness
2//!
3//! Well-formedness checking assumes names are correctly bound which is guaranteed after desugaring.
4
5mod errors;
6mod param_usage;
7mod sortck;
8
9use flux_common::result::{ErrorCollector, ResultExt as _};
10use flux_errors::Errors;
11use flux_middle::{
12    def_id::MaybeExternId,
13    fhir::{self, FhirId, FluxOwnerId, visit::Visitor},
14    global_env::GlobalEnv,
15    queries::QueryResult,
16    rty::{self, WfckResults},
17};
18use rustc_errors::ErrorGuaranteed;
19use rustc_hash::FxHashSet;
20use rustc_hir::{
21    OwnerId,
22    def::DefKind,
23    def_id::{CrateNum, DefId, DefIndex},
24};
25
26use self::sortck::{ImplicitParamInferer, InferCtxt};
27use crate::{
28    conv::{ConvPhase, WfckResultsProvider},
29    wf::sortck::prim_op_sort,
30};
31
32type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
33
34pub(crate) fn check_flux_item<'genv>(
35    genv: GlobalEnv<'genv, '_>,
36    item: fhir::FluxItem<'genv>,
37) -> Result<WfckResults> {
38    let owner = FluxOwnerId::Flux(item.def_id());
39    let mut infcx = InferCtxt::new(genv, owner);
40
41    Wf::with(&mut infcx, |wf| {
42        wf.declare_params_for_flux_item(item)?;
43        wf.check_flux_item(item);
44        Ok(())
45    })?;
46    infcx.into_results()
47}
48
49pub(crate) fn check_constant_expr(
50    genv: GlobalEnv,
51    owner: OwnerId,
52    expr: &fhir::Expr,
53    sort: &rty::Sort,
54) -> Result<WfckResults> {
55    let mut infcx = InferCtxt::new(genv, FluxOwnerId::Rust(owner));
56    infcx.check_expr(expr, sort)?;
57    infcx.into_results()
58}
59
60pub(crate) fn check_invariants<'genv>(
61    genv: GlobalEnv<'genv, '_>,
62    adt_def_id: MaybeExternId<OwnerId>,
63    params: &[fhir::RefineParam<'genv>],
64    invariants: &[fhir::Expr<'genv>],
65) -> Result<WfckResults> {
66    let owner = FluxOwnerId::Rust(adt_def_id.local_id());
67    let mut infcx = InferCtxt::new(genv, owner);
68    Wf::with(&mut infcx, |wf| {
69        wf.declare_params_for_invariants(params, invariants)?;
70
71        // Run first conv phase to gather sorts for associated refinements.
72        // This must run after declaring parameters because conversion expects
73        // the parameter sorts to be present in wfckresults.
74        wf.as_conv_ctxt()
75            .conv_invariants(adt_def_id.map(|it| it.def_id), params, invariants)
76            .emit(&wf.errors)?;
77
78        for invariant in invariants {
79            wf.infcx
80                .check_expr(invariant, &rty::Sort::Bool)
81                .collect_err(&mut wf.errors);
82        }
83        Ok(())
84    })?;
85    infcx.into_results()
86}
87
88pub(crate) fn check_node<'genv>(
89    genv: GlobalEnv<'genv, '_>,
90    node: &fhir::OwnerNode<'genv>,
91) -> Result<WfckResults> {
92    let mut infcx = InferCtxt::new(genv, node.owner_id().local_id().into());
93    Wf::with(&mut infcx, |wf| {
94        wf.init_infcx(node)
95            .map_err(|err| err.at(genv.tcx().def_span(node.owner_id().local_id())))
96            .emit(&genv)?;
97
98        ImplicitParamInferer::infer(wf.infcx, node)?;
99
100        wf.check_node(node);
101        Ok(())
102    })?;
103
104    param_usage::check(&infcx, node)?;
105
106    infcx.into_results()
107}
108
109struct Wf<'a, 'genv, 'tcx> {
110    infcx: &'a mut InferCtxt<'genv, 'tcx>,
111    errors: Errors<'genv>,
112    next_type_index: u32,
113    next_region_index: u32,
114    next_const_index: u32,
115}
116
117impl<'a, 'genv, 'tcx> Wf<'a, 'genv, 'tcx> {
118    fn with(infcx: &'a mut InferCtxt<'genv, 'tcx>, f: impl FnOnce(&mut Self) -> Result) -> Result {
119        let errors = Errors::new(infcx.genv.sess());
120        let mut wf = Self {
121            infcx,
122            errors,
123            // We start sorts and types from 1 to skip the trait object dummy self type.
124            // See [`rty::Ty::trait_object_dummy_self`]
125            next_type_index: 1,
126            next_region_index: 0,
127            next_const_index: 0,
128        };
129        f(&mut wf)?;
130        wf.errors.into_result()
131    }
132
133    fn check_flux_item(&mut self, item: fhir::FluxItem<'genv>) {
134        self.visit_flux_item(&item);
135    }
136
137    fn check_node(&mut self, node: &fhir::OwnerNode<'genv>) {
138        self.visit_node(node);
139    }
140
141    // We special-case primop applications to declare their parameters because their
142    // parameters are implicit from the underlying primop and must not be declared explicitly.
143    fn declare_params_for_prim_prop(&mut self, prim_prop: &fhir::PrimProp<'genv>) -> Result {
144        let Some((sorts, _)) = prim_op_sort(&prim_prop.op) else {
145            return Err(self
146                .errors
147                .emit(errors::UnsupportedPrimOp::new(prim_prop.span, prim_prop.op)));
148        };
149        if prim_prop.args.len() != sorts.len() {
150            return Err(self.errors.emit(errors::ArgCountMismatch::new(
151                Some(prim_prop.span),
152                String::from("primop"),
153                sorts.len(),
154                prim_prop.args.len(),
155            )));
156        }
157        for (arg, sort) in prim_prop.args.iter().zip(sorts) {
158            self.infcx.declare_param(*arg, sort);
159        }
160        Ok(())
161    }
162
163    /// Recursively traverse `item` and declare all refinement parameters
164    fn declare_params_for_flux_item(&mut self, item: fhir::FluxItem<'genv>) -> Result {
165        if let fhir::FluxItem::PrimProp(prim_prop) = item {
166            self.declare_params_for_prim_prop(prim_prop)
167        } else {
168            visit_refine_params(|vis| vis.visit_flux_item(&item), |param| self.declare_param(param))
169        }
170    }
171
172    /// Recursively traverse `node` and declare all refinement parameters
173    fn declare_params_for_node(&mut self, node: &fhir::OwnerNode<'genv>) -> Result {
174        visit_refine_params(|vis| vis.visit_node(node), |param| self.declare_param(param))
175    }
176
177    /// Recursively traverse `invariants` and declare all refinement parameters
178    fn declare_params_for_invariants(
179        &mut self,
180        params: &[fhir::RefineParam<'genv>],
181        invariants: &[fhir::Expr<'genv>],
182    ) -> Result {
183        for param in params {
184            self.declare_param(param)?;
185        }
186        for expr in invariants {
187            visit_refine_params(|vis| vis.visit_expr(expr), |param| self.declare_param(param))?;
188        }
189        Ok(())
190    }
191
192    fn declare_param(&mut self, param: &fhir::RefineParam<'genv>) -> Result {
193        let sort = self
194            .as_conv_ctxt()
195            .conv_sort(&param.sort)
196            .emit(&self.genv())?;
197        self.infcx.declare_param(*param, sort);
198        Ok(())
199    }
200
201    /// To check for well-formedness we need to know the sort of base types. For example, to check if
202    /// the type `i32[e]` is well formed, we need to know that the sort of `i32` is `int` so we can
203    /// check the expression `e` against it. Computing the sort from a base type is subtle and hard
204    /// to do in `fhir` so we must do it in `rty`. However, to convert from `fhir` to `rty` we need
205    /// elaborated information from sort checking which we do in `fhir`.
206    ///
207    /// To break this circularity, we do conversion in two phases. In the first phase, we do conversion
208    /// without elaborated information. This results in types in `rty` with incorrect refinements but
209    /// with the right *shape* to compute their sorts. We use these sorts for sort checking and then do
210    /// conversion again with the elaborated information.
211    ///
212    /// This function initializes the [inference context] by running the first phase of conversion and
213    /// collecting the sort of all base types.
214    ///
215    /// [inference context]: InferCtxt
216    fn init_infcx(&mut self, node: &fhir::OwnerNode<'genv>) -> QueryResult {
217        let def_id = node.owner_id().map(|id| id.def_id);
218        self.declare_params_for_node(node)?;
219        let cx = self.as_conv_ctxt();
220        match node {
221            fhir::OwnerNode::Item(item) => {
222                match &item.kind {
223                    fhir::ItemKind::Enum(enum_def) => {
224                        cx.conv_enum_variants(def_id, enum_def)?;
225                        cx.conv_generic_predicates(def_id, &item.generics)?;
226                    }
227                    fhir::ItemKind::Struct(struct_def) => {
228                        cx.conv_struct_variant(def_id, struct_def)?;
229                        cx.conv_generic_predicates(def_id, &item.generics)?;
230                    }
231                    fhir::ItemKind::TyAlias(ty_alias) => {
232                        cx.conv_type_alias(def_id, ty_alias)?;
233                        cx.conv_generic_predicates(def_id, &item.generics)?;
234                    }
235                    fhir::ItemKind::Trait(trait_) => {
236                        for assoc_reft in trait_.assoc_refinements {
237                            if let Some(body) = assoc_reft.body {
238                                cx.conv_assoc_reft_body(
239                                    assoc_reft.params,
240                                    &body,
241                                    &assoc_reft.output,
242                                )?;
243                            }
244                        }
245                        cx.conv_generic_predicates(def_id, &item.generics)?;
246                    }
247                    fhir::ItemKind::Impl(impl_) => {
248                        for assoc_reft in impl_.assoc_refinements {
249                            cx.conv_assoc_reft_body(
250                                assoc_reft.params,
251                                &assoc_reft.body,
252                                &assoc_reft.output,
253                            )?;
254                        }
255                        cx.conv_generic_predicates(def_id, &item.generics)?;
256                    }
257                    fhir::ItemKind::Fn(fn_sig) => {
258                        cx.conv_fn_sig(def_id, fn_sig)?;
259                        cx.conv_generic_predicates(def_id, &item.generics)?;
260                    }
261                    fhir::ItemKind::Const(_) => {}
262                }
263            }
264            fhir::OwnerNode::TraitItem(trait_item) => {
265                match trait_item.kind {
266                    fhir::TraitItemKind::Fn(fn_sig) => {
267                        cx.conv_fn_sig(def_id, &fn_sig)?;
268                        cx.conv_generic_predicates(def_id, &trait_item.generics)?;
269                    }
270                    fhir::TraitItemKind::Type => {}
271                    fhir::TraitItemKind::Const => {}
272                }
273            }
274            fhir::OwnerNode::ImplItem(impl_item) => {
275                match impl_item.kind {
276                    fhir::ImplItemKind::Fn(fn_sig) => {
277                        cx.conv_fn_sig(def_id, &fn_sig)?;
278                        cx.conv_generic_predicates(def_id, &impl_item.generics)?;
279                    }
280                    fhir::ImplItemKind::Type => {}
281                    fhir::ImplItemKind::Const => {}
282                }
283            }
284            fhir::OwnerNode::ForeignItem(impl_item) => {
285                match impl_item.kind {
286                    fhir::ForeignItemKind::Fn(fn_sig, generics) => {
287                        cx.conv_fn_sig(def_id, &fn_sig)?;
288                        cx.conv_generic_predicates(def_id, generics)?;
289                    }
290                }
291            }
292        }
293        self.infcx.normalize_sorts()
294    }
295
296    fn check_output_locs(&mut self, fn_decl: &fhir::FnDecl) {
297        let mut output_locs = FxHashSet::default();
298        for ens in fn_decl.output.ensures {
299            if let fhir::Ensures::Type(loc, ..) = ens
300                && let (_, id) = loc.res.expect_param()
301                && !output_locs.insert(id)
302            {
303                self.errors.emit(errors::DuplicatedEnsures::new(loc));
304            }
305        }
306
307        for ty in fn_decl.inputs {
308            if let fhir::TyKind::StrgRef(_, loc, _) = ty.kind
309                && let (_, id) = loc.res.expect_param()
310                && !output_locs.contains(&id)
311            {
312                self.errors.emit(errors::MissingEnsures::new(loc));
313            }
314        }
315    }
316}
317
318impl<'genv> fhir::visit::Visitor<'genv> for Wf<'_, 'genv, '_> {
319    fn visit_qualifier(&mut self, qual: &fhir::Qualifier<'genv>) {
320        self.infcx
321            .check_expr(&qual.expr, &rty::Sort::Bool)
322            .collect_err(&mut self.errors);
323    }
324
325    fn visit_prim_prop(&mut self, prim_prop: &fhir::PrimProp<'genv>) {
326        let Some((sorts, _)) = prim_op_sort(&prim_prop.op) else {
327            self.errors
328                .emit(errors::UnsupportedPrimOp::new(prim_prop.span, prim_prop.op));
329            return;
330        };
331
332        if prim_prop.args.len() != sorts.len() {
333            self.errors.emit(errors::ArgCountMismatch::new(
334                Some(prim_prop.span),
335                String::from("primop"),
336                sorts.len(),
337                prim_prop.args.len(),
338            ));
339            return;
340        }
341        self.infcx
342            .check_expr(&prim_prop.body, &rty::Sort::Bool)
343            .collect_err(&mut self.errors);
344    }
345
346    fn visit_func(&mut self, func: &fhir::SpecFunc<'genv>) {
347        if let Some(body) = &func.body {
348            let Ok(output) = self.as_conv_ctxt().conv_sort(&func.sort).emit(&self.errors) else {
349                return;
350            };
351            self.infcx
352                .check_expr(body, &output)
353                .collect_err(&mut self.errors);
354        }
355    }
356
357    fn visit_impl_assoc_reft(&mut self, assoc_reft: &fhir::ImplAssocReft) {
358        let Ok(output) = self
359            .as_conv_ctxt()
360            .conv_sort(&assoc_reft.output)
361            .emit(&self.errors)
362        else {
363            return;
364        };
365        self.infcx
366            .check_expr(&assoc_reft.body, &output)
367            .collect_err(&mut self.errors);
368    }
369
370    fn visit_trait_assoc_reft(&mut self, assoc_reft: &fhir::TraitAssocReft) {
371        if let Some(body) = &assoc_reft.body {
372            let Ok(output) = self
373                .as_conv_ctxt()
374                .conv_sort(&assoc_reft.output)
375                .emit(&self.errors)
376            else {
377                return;
378            };
379            self.infcx
380                .check_expr(body, &output)
381                .collect_err(&mut self.errors);
382        }
383    }
384
385    fn visit_variant_ret(&mut self, ret: &fhir::VariantRet) {
386        let genv = self.infcx.genv;
387        let enum_id = ret.enum_id;
388        let Ok(adt_sort_def) = genv.adt_sort_def_of(enum_id).emit(&self.errors) else { return };
389        if adt_sort_def.is_reflected() {
390            return;
391        }
392        let Ok(args) = rty::GenericArg::identity_for_item(genv, enum_id).emit(&self.errors) else {
393            return;
394        };
395        let expected = adt_sort_def.to_sort(&args);
396        self.infcx
397            .check_expr(&ret.idx, &expected)
398            .collect_err(&mut self.errors);
399    }
400
401    fn visit_fn_decl(&mut self, decl: &fhir::FnDecl<'genv>) {
402        fhir::visit::walk_fn_decl(self, decl);
403        self.check_output_locs(decl);
404    }
405
406    fn visit_requires(&mut self, requires: &fhir::Requires<'genv>) {
407        self.infcx
408            .check_expr(&requires.pred, &rty::Sort::Bool)
409            .collect_err(&mut self.errors);
410    }
411
412    fn visit_ensures(&mut self, ensures: &fhir::Ensures<'genv>) {
413        match ensures {
414            fhir::Ensures::Type(loc, ty) => {
415                self.infcx.check_loc(loc).collect_err(&mut self.errors);
416                self.visit_ty(ty);
417            }
418            fhir::Ensures::Pred(pred) => {
419                self.infcx
420                    .check_expr(pred, &rty::Sort::Bool)
421                    .collect_err(&mut self.errors);
422            }
423        }
424    }
425
426    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
427        match &ty.kind {
428            fhir::TyKind::Indexed(bty, idx) => {
429                let expected = self.infcx.sort_of_bty(bty.fhir_id);
430                self.infcx
431                    .check_expr(idx, &expected)
432                    .collect_err(&mut self.errors);
433                self.visit_bty(bty);
434            }
435            fhir::TyKind::StrgRef(_, loc, ty) => {
436                self.infcx.check_loc(loc).collect_err(&mut self.errors);
437                self.visit_ty(ty);
438            }
439            fhir::TyKind::Constr(pred, ty) => {
440                self.visit_ty(ty);
441                self.infcx
442                    .check_expr(pred, &rty::Sort::Bool)
443                    .collect_err(&mut self.errors);
444            }
445            _ => fhir::visit::walk_ty(self, ty),
446        }
447    }
448
449    fn visit_path(&mut self, path: &fhir::Path<'genv>) {
450        let genv = self.genv();
451        if let fhir::Res::Def(DefKind::TyAlias, def_id) = path.res {
452            let Ok(generics) = genv.refinement_generics_of(def_id).emit(&self.errors) else {
453                return;
454            };
455
456            let args = self.infcx.path_args(path.fhir_id);
457            for (i, expr) in path.refine.iter().enumerate() {
458                let Ok(param) = generics.param_at(i, genv).emit(&self.errors) else { return };
459                let param = param.instantiate(genv.tcx(), &args, &[]);
460                self.infcx
461                    .check_expr(expr, &param.sort)
462                    .collect_err(&mut self.errors);
463            }
464        };
465        fhir::visit::walk_path(self, path);
466    }
467}
468
469struct RefineParamVisitor<F> {
470    f: F,
471    err: Option<ErrorGuaranteed>,
472}
473
474impl<'v, F> fhir::visit::Visitor<'v> for RefineParamVisitor<F>
475where
476    F: FnMut(&fhir::RefineParam<'v>) -> Result,
477{
478    fn visit_refine_param(&mut self, param: &fhir::RefineParam<'v>) {
479        (self.f)(param).collect_err(&mut self.err);
480    }
481}
482
483fn visit_refine_params<'a, F>(visit: impl FnOnce(&mut RefineParamVisitor<F>), f: F) -> Result
484where
485    F: FnMut(&fhir::RefineParam<'a>) -> Result,
486{
487    let mut visitor = RefineParamVisitor { f, err: None };
488    visit(&mut visitor);
489    visitor.err.into_result()
490}
491
492impl<'genv, 'tcx> ConvPhase<'genv, 'tcx> for Wf<'_, 'genv, 'tcx> {
493    /// We don't expand type aliases before sort checking because we need every base type in `fhir`
494    /// to match a type in `rty`.
495    const EXPAND_TYPE_ALIASES: bool = false;
496    const HAS_ELABORATED_INFORMATION: bool = false;
497
498    type Results = InferCtxt<'genv, 'tcx>;
499
500    fn genv(&self) -> GlobalEnv<'genv, 'tcx> {
501        self.infcx.genv
502    }
503
504    fn owner(&self) -> FluxOwnerId {
505        self.infcx.wfckresults.owner
506    }
507
508    fn next_sort_vid(&mut self) -> rty::SortVid {
509        self.infcx.next_sort_vid()
510    }
511
512    fn next_type_vid(&mut self) -> rty::TyVid {
513        self.next_type_index = self.next_type_index.checked_add(1).unwrap();
514        rty::TyVid::from_u32(self.next_type_index - 1)
515    }
516
517    fn next_region_vid(&mut self) -> rty::RegionVid {
518        self.next_region_index = self.next_region_index.checked_add(1).unwrap();
519        rty::RegionVid::from_u32(self.next_region_index - 1)
520    }
521
522    fn next_const_vid(&mut self) -> rty::ConstVid {
523        self.next_const_index = self.next_const_index.checked_add(1).unwrap();
524        rty::ConstVid::from_u32(self.next_const_index - 1)
525    }
526
527    fn results(&self) -> &Self::Results {
528        self.infcx
529    }
530
531    fn insert_bty_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
532        self.infcx.insert_sort_for_bty(fhir_id, sort);
533    }
534
535    fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
536        self.infcx.insert_path_args(fhir_id, args);
537    }
538
539    fn insert_alias_reft_sort(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
540        self.infcx.insert_sort_for_alias_reft(fhir_id, fsort);
541    }
542}
543
544/// The purpose of doing conversion before sort checking is to collect the sorts of base types.
545/// Thus, what we return here mostly doesn't matter because the refinements on a type should not
546/// affect its sort. The one exception is the sort we generate for refinement parameters.
547///
548/// For instance, consider the following definition where we refine a struct with a polymorphic set:
549/// ```ignore
550/// #[flux::refined_by(elems: Set<T>)]
551/// struct RSet<T> { ... }
552/// ```
553/// Now, consider the type `RSet<i32{v: v >= 0}>`. This type desugars to `RSet<λv:σ. {i32[v] | v >= 0}>`
554/// where the sort `σ` needs to be inferred. The type `RSet<λv:σ. {i32[v] | v >= 0}>` has sort
555/// `RSet<σ>` where `RSet` is the sort-level representation of the `RSet` type. Thus, it is important
556/// that the inference variable we generate for `σ` is the same we use for sort checking.
557impl WfckResultsProvider for InferCtxt<'_, '_> {
558    fn bin_op_sort(&self, _: FhirId) -> rty::Sort {
559        rty::Sort::Err
560    }
561
562    fn coercions_for(&self, _: FhirId) -> &[rty::Coercion] {
563        &[]
564    }
565
566    fn field_proj(&self, _: FhirId) -> rty::FieldProj {
567        rty::FieldProj::Tuple { arity: 0, field: 0 }
568    }
569
570    fn record_ctor(&self, _: FhirId) -> DefId {
571        DefId { index: DefIndex::from_u32(0), krate: CrateNum::from_u32(0) }
572    }
573
574    fn param_sort(&self, param: &fhir::RefineParam) -> rty::Sort {
575        self.param_sort(param.id)
576    }
577
578    fn node_sort(&self, _: FhirId) -> rty::Sort {
579        rty::Sort::Err
580    }
581}