flux_fhir_analysis/wf/
mod.rs

1//! Checks type well-formedness
2//!
3//! Well-formedness checking assumes names are correctly bound which is guaranteed after desugaring.
4
5mod errors;
6mod param_usage;
7mod sortck;
8
9use flux_common::result::{ErrorCollector, ResultExt as _};
10use flux_errors::Errors;
11use flux_middle::{
12    def_id::MaybeExternId,
13    fhir::{self, FhirId, FluxOwnerId, visit::Visitor},
14    global_env::GlobalEnv,
15    queries::QueryResult,
16    rty::{self, WfckResults},
17};
18use rustc_errors::ErrorGuaranteed;
19use rustc_hash::FxHashSet;
20use rustc_hir::{
21    OwnerId,
22    def::DefKind,
23    def_id::{CrateNum, DefId, DefIndex},
24};
25
26use self::sortck::{ImplicitParamInferer, InferCtxt};
27use crate::{
28    conv::{ConvPhase, WfckResultsProvider},
29    wf::sortck::prim_op_sort,
30};
31
32type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
33
34pub(crate) fn check_flux_item<'genv>(
35    genv: GlobalEnv<'genv, '_>,
36    item: fhir::FluxItem<'genv>,
37) -> Result<WfckResults> {
38    let owner = FluxOwnerId::Flux(item.def_id());
39    let mut infcx = InferCtxt::new(genv, owner);
40
41    Wf::with(&mut infcx, |wf| {
42        wf.init_infcx_for_flux_item(item).emit(&genv)?;
43        wf.check_flux_item(item);
44        Ok(())
45    })?;
46    infcx.into_results()
47}
48
49pub(crate) fn check_constant_expr<'genv>(
50    genv: GlobalEnv<'genv, '_>,
51    owner: MaybeExternId<OwnerId>,
52    expr: &fhir::Expr<'genv>,
53    sort: &rty::Sort,
54) -> Result<WfckResults> {
55    let mut infcx = InferCtxt::new(genv, FluxOwnerId::Rust(owner));
56    Wf::with(&mut infcx, |wf| {
57        wf.declare_params_in_expr(expr)?;
58        wf.as_conv_ctxt()
59            .conv_constant_expr(expr)
60            .emit(&wf.errors)?;
61        wf.check_expr(expr, sort);
62        Ok(())
63    })?;
64    infcx.into_results()
65}
66
67pub(crate) fn check_invariants<'genv>(
68    genv: GlobalEnv<'genv, '_>,
69    adt_def_id: MaybeExternId<OwnerId>,
70    params: &[fhir::RefineParam<'genv>],
71    invariants: &[fhir::Expr<'genv>],
72) -> Result<WfckResults> {
73    let owner = FluxOwnerId::Rust(adt_def_id);
74    let mut infcx = InferCtxt::new(genv, owner);
75    Wf::with(&mut infcx, |wf| {
76        wf.declare_params_for_invariants(params, invariants)?;
77
78        // Run first conv phase to gather sorts for associated refinements.
79        // This must run after declaring parameters because conversion expects
80        // the parameter sorts to be present in wfckresults.
81        wf.as_conv_ctxt()
82            .conv_invariants(adt_def_id.map(|it| it.def_id), params, invariants)
83            .emit(&wf.errors)?;
84
85        for invariant in invariants {
86            wf.check_expr(invariant, &rty::Sort::Bool);
87        }
88        Ok(())
89    })?;
90    infcx.into_results()
91}
92
93pub(crate) fn check_node<'genv>(
94    genv: GlobalEnv<'genv, '_>,
95    node: &fhir::OwnerNode<'genv>,
96) -> Result<WfckResults> {
97    let mut infcx = InferCtxt::new(genv, node.owner_id().into());
98    Wf::with(&mut infcx, |wf| {
99        wf.init_infcx_for_node(node)
100            .map_err(|err| err.at(genv.tcx().def_span(node.owner_id().local_id())))
101            .emit(&genv)?;
102
103        ImplicitParamInferer::infer(wf.infcx, node)?;
104
105        wf.check_node(node);
106        Ok(())
107    })?;
108
109    param_usage::check(&infcx, node)?;
110
111    infcx.into_results()
112}
113
114struct Wf<'a, 'genv, 'tcx> {
115    infcx: &'a mut InferCtxt<'genv, 'tcx>,
116    errors: Errors<'genv>,
117    next_type_index: u32,
118    next_region_index: u32,
119    next_const_index: u32,
120}
121
122impl<'a, 'genv, 'tcx> Wf<'a, 'genv, 'tcx> {
123    fn with(infcx: &'a mut InferCtxt<'genv, 'tcx>, f: impl FnOnce(&mut Self) -> Result) -> Result {
124        let errors = Errors::new(infcx.genv.sess());
125        let mut wf = Self {
126            infcx,
127            errors,
128            // We start sorts and types from 1 to skip the trait object dummy self type.
129            // See [`rty::Ty::trait_object_dummy_self`]
130            next_type_index: 1,
131            next_region_index: 0,
132            next_const_index: 0,
133        };
134        f(&mut wf)?;
135        wf.errors.into_result()
136    }
137
138    fn check_flux_item(&mut self, item: fhir::FluxItem<'genv>) {
139        self.visit_flux_item(&item);
140    }
141
142    fn check_node(&mut self, node: &fhir::OwnerNode<'genv>) {
143        self.visit_node(node);
144    }
145
146    fn check_expr(&mut self, expr: &fhir::Expr<'genv>, sort: &rty::Sort) {
147        self.infcx
148            .check_expr(expr, sort)
149            .collect_err(&mut self.errors);
150    }
151
152    // We special-case primop applications to declare their parameters because their
153    // parameters are implicit from the underlying primop and must not be declared explicitly.
154    fn declare_params_for_primop_prop(&mut self, primop_prop: &fhir::PrimOpProp<'genv>) -> Result {
155        let Some((sorts, _)) = prim_op_sort(&primop_prop.op) else {
156            return Err(self
157                .errors
158                .emit(errors::UnsupportedPrimOp::new(primop_prop.span, primop_prop.op)));
159        };
160        if primop_prop.args.len() != sorts.len() {
161            return Err(self.errors.emit(errors::ArgCountMismatch::new(
162                Some(primop_prop.span),
163                String::from("primop"),
164                sorts.len(),
165                primop_prop.args.len(),
166            )));
167        }
168        for (arg, sort) in primop_prop.args.iter().zip(sorts) {
169            self.infcx.declare_param(*arg, sort);
170        }
171        visit_refine_params(
172            |vis| vis.visit_expr(&primop_prop.body),
173            |param| self.declare_param(param),
174        )
175    }
176
177    /// Recursively traverse `item` and declare all refinement parameters
178    fn declare_params_for_flux_item(&mut self, item: fhir::FluxItem<'genv>) -> Result {
179        if let fhir::FluxItem::PrimOpProp(primop_prop) = item {
180            self.declare_params_for_primop_prop(primop_prop)
181        } else {
182            visit_refine_params(|vis| vis.visit_flux_item(&item), |param| self.declare_param(param))
183        }
184    }
185
186    /// Recursively traverse `node` and declare all refinement parameters
187    fn declare_params_for_node(&mut self, node: &fhir::OwnerNode<'genv>) -> Result {
188        visit_refine_params(|vis| vis.visit_node(node), |param| self.declare_param(param))
189    }
190
191    /// Recursively traverse `invariants` and declare all refinement parameters
192    fn declare_params_for_invariants(
193        &mut self,
194        params: &[fhir::RefineParam<'genv>],
195        invariants: &[fhir::Expr<'genv>],
196    ) -> Result {
197        for param in params {
198            self.declare_param(param)?;
199        }
200        for expr in invariants {
201            self.declare_params_in_expr(expr)?;
202        }
203        Ok(())
204    }
205
206    fn declare_params_in_expr(&mut self, expr: &fhir::Expr<'genv>) -> Result {
207        visit_refine_params(|vis| vis.visit_expr(expr), |param| self.declare_param(param))
208    }
209
210    fn declare_param(&mut self, param: &fhir::RefineParam<'genv>) -> Result {
211        let sort = self
212            .as_conv_ctxt()
213            .conv_sort(&param.sort)
214            .emit(&self.genv())?;
215        self.infcx.declare_param(*param, sort);
216        Ok(())
217    }
218
219    /// To check for well-formedness we need to synthesize sorts for some nodes which is hard to
220    /// compute in `fhir`. For example, to check if the type `i32[e]` is well formed, we need to
221    /// know that the sort of `i32` is `int` so we can check the expression `e` against it. Computing
222    /// the sort from a base type is subtle and hard to do in `fhir` so we must do it in `rty`.
223    /// However, to convert from `fhir` to `rty` we need elaborated information from sort checking
224    /// which we do in `fhir`.
225    ///
226    /// To break this circularity, we do conversion in two phases. In the first phase, we do conversion
227    /// without elaborated information. This results in types in `rty` with incorrect refinements but
228    /// with the right *shape* to compute their sorts. We use these sorts for sort checking and then do
229    /// conversion again with the elaborated information.
230    ///
231    /// This function initializes the [inference context] by running the first phase of conversion and
232    /// collecting the sorts of some nodes that are hard to compute in `fhir`.
233    ///
234    /// [inference context]: InferCtxt
235    fn init_infcx_for_node(&mut self, node: &fhir::OwnerNode<'genv>) -> QueryResult {
236        let def_id = node.owner_id().map(|id| id.def_id);
237        self.declare_params_for_node(node)?;
238        let cx = self.as_conv_ctxt();
239        match node {
240            fhir::OwnerNode::Item(item) => {
241                match &item.kind {
242                    fhir::ItemKind::Enum(enum_def) => {
243                        cx.conv_enum_variants(def_id, enum_def)?;
244                        cx.conv_generic_predicates(def_id, &item.generics)?;
245                    }
246                    fhir::ItemKind::Struct(struct_def) => {
247                        cx.conv_struct_variant(def_id, struct_def)?;
248                        cx.conv_generic_predicates(def_id, &item.generics)?;
249                    }
250                    fhir::ItemKind::TyAlias(ty_alias) => {
251                        cx.conv_type_alias(def_id, ty_alias)?;
252                        cx.conv_generic_predicates(def_id, &item.generics)?;
253                    }
254                    fhir::ItemKind::Trait(trait_) => {
255                        for assoc_reft in trait_.assoc_refinements {
256                            if let Some(body) = assoc_reft.body {
257                                cx.conv_assoc_reft_body(
258                                    assoc_reft.params,
259                                    &body,
260                                    &assoc_reft.output,
261                                )?;
262                            }
263                        }
264                        cx.conv_generic_predicates(def_id, &item.generics)?;
265                    }
266                    fhir::ItemKind::Impl(impl_) => {
267                        for assoc_reft in impl_.assoc_refinements {
268                            cx.conv_assoc_reft_body(
269                                assoc_reft.params,
270                                &assoc_reft.body,
271                                &assoc_reft.output,
272                            )?;
273                        }
274                        cx.conv_generic_predicates(def_id, &item.generics)?;
275                    }
276                    fhir::ItemKind::Fn(fn_sig) => {
277                        cx.conv_fn_sig(def_id, fn_sig)?;
278                        cx.conv_generic_predicates(def_id, &item.generics)?;
279                    }
280                    fhir::ItemKind::Static(ty) => {
281                        if let Some(ty) = ty {
282                            cx.conv_static_ty(ty)?;
283                        }
284                    }
285                    fhir::ItemKind::Const(_) => {}
286                }
287            }
288            fhir::OwnerNode::TraitItem(trait_item) => {
289                match trait_item.kind {
290                    fhir::TraitItemKind::Fn(fn_sig) => {
291                        cx.conv_fn_sig(def_id, &fn_sig)?;
292                        cx.conv_generic_predicates(def_id, &trait_item.generics)?;
293                    }
294                    fhir::TraitItemKind::Type => {}
295                    fhir::TraitItemKind::Const => {}
296                }
297            }
298            fhir::OwnerNode::ImplItem(impl_item) => {
299                match impl_item.kind {
300                    fhir::ImplItemKind::Fn(fn_sig) => {
301                        cx.conv_fn_sig(def_id, &fn_sig)?;
302                        cx.conv_generic_predicates(def_id, &impl_item.generics)?;
303                    }
304                    fhir::ImplItemKind::Type => {}
305                    fhir::ImplItemKind::Const => {}
306                }
307            }
308            fhir::OwnerNode::ForeignItem(impl_item) => {
309                match impl_item.kind {
310                    fhir::ForeignItemKind::Fn(fn_sig, generics) => {
311                        cx.conv_fn_sig(def_id, &fn_sig)?;
312                        cx.conv_generic_predicates(def_id, generics)?;
313                    }
314                    fhir::ForeignItemKind::Static(_, _, _, _) => {
315                        // TODO: conv_ty if we want refinements on extern statics?
316                    }
317                }
318            }
319        }
320        self.infcx.normalize_sorts()
321    }
322
323    fn init_infcx_for_flux_item(&mut self, item: fhir::FluxItem<'genv>) -> QueryResult {
324        self.declare_params_for_flux_item(item)?;
325        let cx = self.as_conv_ctxt();
326        match item {
327            fhir::FluxItem::Qualifier(qualifier) => {
328                cx.conv_qualifier(qualifier)?;
329            }
330            fhir::FluxItem::Func(spec_func) => {
331                cx.conv_defn(spec_func)?;
332            }
333            fhir::FluxItem::PrimOpProp(prim_op_prop) => {
334                cx.conv_primop_prop(prim_op_prop)?;
335            }
336            fhir::FluxItem::SortDecl(_sort_decl) => {}
337        }
338        Ok(())
339    }
340
341    fn check_output_locs(&mut self, fn_decl: &fhir::FnDecl) {
342        let mut output_locs = FxHashSet::default();
343        for ens in fn_decl.output.ensures {
344            if let fhir::Ensures::Type(loc, ..) = ens
345                && let (_, id) = loc.res.expect_param()
346                && !output_locs.insert(id)
347            {
348                self.errors.emit(errors::DuplicatedEnsures::new(loc));
349            }
350        }
351
352        for ty in fn_decl.inputs {
353            if let fhir::TyKind::StrgRef(_, loc, _) = ty.kind
354                && let (_, id) = loc.res.expect_param()
355                && !output_locs.contains(&id)
356            {
357                self.errors.emit(errors::MissingEnsures::new(loc));
358            }
359        }
360    }
361}
362
363impl<'genv> fhir::visit::Visitor<'genv> for Wf<'_, 'genv, '_> {
364    fn visit_qualifier(&mut self, qual: &fhir::Qualifier<'genv>) {
365        self.check_expr(&qual.expr, &rty::Sort::Bool);
366    }
367
368    fn visit_primop_prop(&mut self, primop_prop: &fhir::PrimOpProp<'genv>) {
369        let Some((sorts, _)) = prim_op_sort(&primop_prop.op) else {
370            self.errors
371                .emit(errors::UnsupportedPrimOp::new(primop_prop.span, primop_prop.op));
372            return;
373        };
374
375        if primop_prop.args.len() != sorts.len() {
376            self.errors.emit(errors::ArgCountMismatch::new(
377                Some(primop_prop.span),
378                String::from("primop"),
379                sorts.len(),
380                primop_prop.args.len(),
381            ));
382            return;
383        }
384        self.check_expr(&primop_prop.body, &rty::Sort::Bool);
385    }
386
387    fn visit_func(&mut self, func: &fhir::SpecFunc<'genv>) {
388        if let Some(body) = &func.body {
389            let Ok(output) = self.as_conv_ctxt().conv_sort(&func.sort).emit(&self.errors) else {
390                return;
391            };
392            self.check_expr(body, &output);
393        }
394    }
395
396    fn visit_impl_assoc_reft(&mut self, assoc_reft: &fhir::ImplAssocReft<'genv>) {
397        let Ok(output) = self
398            .as_conv_ctxt()
399            .conv_sort(&assoc_reft.output)
400            .emit(&self.errors)
401        else {
402            return;
403        };
404        self.check_expr(&assoc_reft.body, &output);
405    }
406
407    fn visit_trait_assoc_reft(&mut self, assoc_reft: &fhir::TraitAssocReft<'genv>) {
408        if let Some(body) = &assoc_reft.body {
409            let Ok(output) = self
410                .as_conv_ctxt()
411                .conv_sort(&assoc_reft.output)
412                .emit(&self.errors)
413            else {
414                return;
415            };
416            self.check_expr(body, &output);
417        }
418    }
419
420    fn visit_variant_ret(&mut self, ret: &fhir::VariantRet<'genv>) {
421        let genv = self.infcx.genv;
422        let enum_id = ret.enum_id;
423        let Ok(adt_sort_def) = genv.adt_sort_def_of(enum_id).emit(&self.errors) else { return };
424        if adt_sort_def.is_reflected() {
425            return;
426        }
427        let Ok(args) = rty::GenericArg::identity_for_item(genv, enum_id).emit(&self.errors) else {
428            return;
429        };
430        let expected = adt_sort_def.to_sort(&args);
431        self.check_expr(&ret.idx, &expected);
432    }
433
434    fn visit_fn_decl(&mut self, decl: &fhir::FnDecl<'genv>) {
435        fhir::visit::walk_fn_decl(self, decl);
436        self.check_output_locs(decl);
437    }
438
439    fn visit_requires(&mut self, requires: &fhir::Requires<'genv>) {
440        self.check_expr(&requires.pred, &rty::Sort::Bool);
441    }
442
443    fn visit_ensures(&mut self, ensures: &fhir::Ensures<'genv>) {
444        match ensures {
445            fhir::Ensures::Type(loc, ty) => {
446                self.infcx.check_loc(loc).collect_err(&mut self.errors);
447                self.visit_ty(ty);
448            }
449            fhir::Ensures::Pred(pred) => {
450                self.check_expr(pred, &rty::Sort::Bool);
451            }
452        }
453    }
454
455    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
456        match &ty.kind {
457            fhir::TyKind::Indexed(bty, idx) => {
458                let expected = self.infcx.sort_of_bty(bty);
459                self.check_expr(idx, &expected);
460                self.visit_bty(bty);
461            }
462            fhir::TyKind::StrgRef(_, loc, ty) => {
463                self.infcx.check_loc(loc).collect_err(&mut self.errors);
464                self.visit_ty(ty);
465            }
466            fhir::TyKind::Constr(pred, ty) => {
467                self.visit_ty(ty);
468                self.check_expr(pred, &rty::Sort::Bool);
469            }
470            _ => fhir::visit::walk_ty(self, ty),
471        }
472    }
473
474    fn visit_path(&mut self, path: &fhir::Path<'genv>) {
475        let genv = self.genv();
476        if let fhir::Res::Def(DefKind::TyAlias, def_id) = path.res {
477            let Ok(generics) = genv.refinement_generics_of(def_id).emit(&self.errors) else {
478                return;
479            };
480
481            let args = self.infcx.path_args(path.fhir_id);
482            for (i, expr) in path.refine.iter().enumerate() {
483                let Ok(param) = generics.param_at(i, genv).emit(&self.errors) else { return };
484                let param = param.instantiate(genv.tcx(), &args, &[]);
485                self.check_expr(expr, &param.sort);
486            }
487        };
488        fhir::visit::walk_path(self, path);
489    }
490}
491
492struct RefineParamVisitor<F> {
493    f: F,
494    err: Option<ErrorGuaranteed>,
495}
496
497impl<'v, F> fhir::visit::Visitor<'v> for RefineParamVisitor<F>
498where
499    F: FnMut(&fhir::RefineParam<'v>) -> Result,
500{
501    fn visit_refine_param(&mut self, param: &fhir::RefineParam<'v>) {
502        (self.f)(param).collect_err(&mut self.err);
503    }
504}
505
506fn visit_refine_params<'a, F>(visit: impl FnOnce(&mut RefineParamVisitor<F>), f: F) -> Result
507where
508    F: FnMut(&fhir::RefineParam<'a>) -> Result,
509{
510    let mut visitor = RefineParamVisitor { f, err: None };
511    visit(&mut visitor);
512    visitor.err.into_result()
513}
514
515impl<'genv, 'tcx> ConvPhase<'genv, 'tcx> for Wf<'_, 'genv, 'tcx> {
516    /// We don't expand type aliases before sort checking because we need every base type in `fhir`
517    /// to match a type in `rty`.
518    const EXPAND_TYPE_ALIASES: bool = false;
519    const HAS_ELABORATED_INFORMATION: bool = false;
520
521    type Results = InferCtxt<'genv, 'tcx>;
522
523    fn genv(&self) -> GlobalEnv<'genv, 'tcx> {
524        self.infcx.genv
525    }
526
527    fn owner(&self) -> FluxOwnerId {
528        self.infcx.wfckresults.owner
529    }
530
531    fn next_sort_vid(&mut self) -> rty::SortVid {
532        self.infcx.next_sort_vid(Default::default())
533    }
534
535    fn next_type_vid(&mut self) -> rty::TyVid {
536        self.next_type_index = self.next_type_index.checked_add(1).unwrap();
537        rty::TyVid::from_u32(self.next_type_index - 1)
538    }
539
540    fn next_region_vid(&mut self) -> rty::RegionVid {
541        self.next_region_index = self.next_region_index.checked_add(1).unwrap();
542        rty::RegionVid::from_u32(self.next_region_index - 1)
543    }
544
545    fn next_const_vid(&mut self) -> rty::ConstVid {
546        self.next_const_index = self.next_const_index.checked_add(1).unwrap();
547        rty::ConstVid::from_u32(self.next_const_index - 1)
548    }
549
550    fn results(&self) -> &Self::Results {
551        self.infcx
552    }
553
554    fn insert_node_sort(&mut self, fhir_id: FhirId, sort: rty::Sort) {
555        self.infcx.insert_node_sort(fhir_id, sort);
556    }
557
558    fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
559        self.infcx.insert_path_args(fhir_id, args);
560    }
561
562    fn insert_alias_reft_sort(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
563        self.infcx.insert_sort_for_alias_reft(fhir_id, fsort);
564    }
565}
566
567/// The purpose of doing conversion before sort checking is to collect the sorts of base types.
568/// Thus, what we return here mostly doesn't matter because the refinements on a type should not
569/// affect its sort. The one exception is the sort we generate for refinement parameters.
570///
571/// For instance, consider the following definition where we refine a struct with a polymorphic set:
572/// ```ignore
573/// #[flux::refined_by(elems: Set<T>)]
574/// struct RSet<T> { ... }
575/// ```
576/// Now, consider the type `RSet<i32{v: v >= 0}>`. This type desugars to `RSet<λv:σ. {i32[v] | v >= 0}>`
577/// where the sort `σ` needs to be inferred. The type `RSet<λv:σ. {i32[v] | v >= 0}>` has sort
578/// `RSet<σ>` where `RSet` is the sort-level representation of the `RSet` type. Thus, it is important
579/// that the inference variable we generate for `σ` is the same we use for sort checking.
580impl WfckResultsProvider for InferCtxt<'_, '_> {
581    fn bin_op_sort(&self, _: FhirId) -> rty::Sort {
582        rty::Sort::Err
583    }
584
585    fn coercions_for(&self, _: FhirId) -> &[rty::Coercion] {
586        &[]
587    }
588
589    fn field_proj(&self, _: FhirId) -> rty::FieldProj {
590        rty::FieldProj::Tuple { arity: 0, field: 0 }
591    }
592
593    fn record_ctor(&self, _: FhirId) -> DefId {
594        DefId { index: DefIndex::from_u32(0), krate: CrateNum::from_u32(0) }
595    }
596
597    fn param_sort(&self, param_id: fhir::ParamId) -> rty::Sort {
598        self.param_sort(param_id)
599    }
600
601    fn node_sort(&self, _: FhirId) -> rty::Sort {
602        rty::Sort::Err
603    }
604
605    fn node_sort_args(&self, _: FhirId) -> rty::List<rty::SortArg> {
606        rty::List::empty()
607    }
608}