flux_fhir_analysis/wf/
param_usage.rs

//! Code to check whether refinement parameters are used in allowed positions.
//!
//! The correct usage of a parameter depends on whether its infer mode is [evar] or [kvar].
//! For evar mode, parameters must be used at least once as an index in a position that fully
//! determines their value (see <https://arxiv.org/pdf/2209.13000.pdf> for details). Parameters
//! with kvar mode (i.e., abstract refinement predicates) must only be used in function position
//! in a top-level conjunction such that they result in a proper horn constraint after being
//! substituted by a kvar as required by fixpoint.
//!
//! [evar]: `fhir::InferMode::EVar`
//! [kvar]: `fhir::InferMode::KVar`
12
13use flux_errors::Errors;
14use flux_middle::{
15    fhir::{self, visit::Visitor},
16    rty, walk_list,
17};
18use rustc_data_structures::snapshot_map;
19use rustc_span::ErrorGuaranteed;
20
21use super::{
22    errors::{InvalidParamPos, ParamNotDetermined},
23    sortck::InferCtxt,
24};
25
26type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
27
28pub(crate) fn check<'genv>(infcx: &InferCtxt<'genv, '_>, node: &fhir::OwnerNode<'genv>) -> Result {
29    ParamUsesChecker::new(infcx).run(|ck| ck.visit_node(node))
30}
31
32struct ParamUsesChecker<'a, 'genv, 'tcx> {
33    infcx: &'a InferCtxt<'genv, 'tcx>,
34    /// Keeps track of all refinement parameters that are used as an index such that their value is fully
35    /// determined. The name xi is taken from [1], where the well-formedness judgment uses an uppercase
36    /// Xi (Ξ) for a context that is similar in purpose.
37    ///
38    /// This is basically a set of [`fhir::ParamId`] implemented with a snapshot map such that elements
39    /// can be removed in batch when there's a change in polarity.
40    ///
41    /// [1]: https://arxiv.org/pdf/2209.13000.pdf
42    xi: snapshot_map::SnapshotMap<fhir::ParamId, ()>,
43    errors: Errors<'genv>,
44}
45
46impl<'a, 'genv, 'tcx> ParamUsesChecker<'a, 'genv, 'tcx> {
47    fn new(infcx: &'a InferCtxt<'genv, 'tcx>) -> Self {
48        Self { infcx, xi: Default::default(), errors: Errors::new(infcx.genv.sess()) }
49    }
50
51    fn run(mut self, f: impl FnOnce(&mut Self)) -> Result {
52        f(&mut self);
53        self.errors.into_result()
54    }
55
56    /// Insert params that are considered to be value determined to `xi`.
57    fn insert_value_determined(&mut self, expr: &fhir::Expr) {
58        match expr.kind {
59            fhir::ExprKind::Var(fhir::QPathExpr::Resolved(path, _))
60                if let fhir::Res::Param(_, id) = path.res =>
61            {
62                self.xi.insert(id, ());
63            }
64            fhir::ExprKind::Record(fields) => {
65                for field in fields {
66                    self.insert_value_determined(field);
67                }
68            }
69            fhir::ExprKind::Constructor(_, fields, _) => {
70                for field in fields {
71                    self.insert_value_determined(&field.expr);
72                }
73            }
74            _ => {}
75        }
76    }
77
78    /// Checks that refinement parameters of function sort are used in allowed positions.
79    fn check_func_params_uses(&mut self, expr: &fhir::Expr, is_top_level_conj: bool) {
80        match expr.kind {
81            fhir::ExprKind::BinaryOp(bin_op, e1, e2) | fhir::ExprKind::PrimApp(bin_op, e1, e2) => {
82                let is_top_level_conj = is_top_level_conj && matches!(bin_op, fhir::BinOp::And);
83                self.check_func_params_uses(e1, is_top_level_conj);
84                self.check_func_params_uses(e2, is_top_level_conj);
85            }
86            fhir::ExprKind::UnaryOp(_, e) => self.check_func_params_uses(e, false),
87            fhir::ExprKind::App(func, args) => {
88                if !is_top_level_conj
89                    && let fhir::Res::Param(_, id) = func.res
90                    && let fhir::InferMode::KVar = self.infcx.infer_mode(id)
91                {
92                    self.errors
93                        .emit(InvalidParamPos::new(func.span, &self.infcx.param_sort(id)));
94                }
95                for arg in args {
96                    self.check_func_params_uses(arg, false);
97                }
98            }
99            fhir::ExprKind::Alias(_, func_args) => {
100                // TODO(nilehmann) should we check the usage inside the `AliasPred`?
101                for arg in func_args {
102                    self.check_func_params_uses(arg, false);
103                }
104            }
105            fhir::ExprKind::Var(fhir::QPathExpr::Resolved(path, _)) => {
106                if let fhir::Res::Param(_, id) = path.res
107                    && let sort @ rty::Sort::Func(_) = self.infcx.param_sort(id)
108                {
109                    self.errors.emit(InvalidParamPos::new(path.span, &sort));
110                }
111            }
112            fhir::ExprKind::Var(fhir::QPathExpr::TypeRelative(..)) => {
113                // TODO(nilehmann) should we check the usage inside the `qself`?
114            }
115            fhir::ExprKind::IfThenElse(e1, e2, e3) => {
116                self.check_func_params_uses(e1, false);
117                self.check_func_params_uses(e3, false);
118                self.check_func_params_uses(e2, false);
119            }
120            fhir::ExprKind::Literal(_) => {}
121            fhir::ExprKind::Dot(base, _) => {
122                self.check_func_params_uses(base, false);
123            }
124            fhir::ExprKind::Abs(_, body) => {
125                self.check_func_params_uses(body, true);
126            }
127            fhir::ExprKind::BoundedQuant(_, _, _, body) => {
128                self.check_func_params_uses(body, false);
129            }
130            fhir::ExprKind::Record(fields) => {
131                for field in fields {
132                    self.check_func_params_uses(field, is_top_level_conj);
133                }
134            }
135            fhir::ExprKind::Constructor(_, fields, spread) => {
136                if let Some(spread) = spread {
137                    self.check_func_params_uses(&spread.expr, false);
138                }
139                for field in fields {
140                    self.check_func_params_uses(&field.expr, false);
141                }
142            }
143            fhir::ExprKind::Block(decls, body) => {
144                for decl in decls {
145                    self.check_func_params_uses(&decl.init, false);
146                }
147                self.check_func_params_uses(body, false);
148            }
149            fhir::ExprKind::Err(_) => {
150                // an error has already been reported so we can just skip
151            }
152        }
153    }
154
155    /// Check that Hindly parameters in `params` appear in a value determined position
156    fn check_params_are_value_determined(&mut self, params: &[fhir::RefineParam]) {
157        for param in params {
158            let determined = self.xi.remove(param.id);
159            if self.infcx.infer_mode(param.id) == fhir::InferMode::EVar && !determined {
160                self.errors
161                    .emit(ParamNotDetermined::new(param.span, param.name));
162            }
163        }
164    }
165}
166
167impl<'genv> fhir::visit::Visitor<'genv> for ParamUsesChecker<'_, 'genv, '_> {
168    fn visit_node(&mut self, node: &fhir::OwnerNode<'genv>) {
169        if node.fn_sig().is_some() {
170            // Check early refinement parameters in fn-like nodes
171            let snapshot = self.xi.snapshot();
172            fhir::visit::walk_node(self, node);
173            self.check_params_are_value_determined(node.generics().refinement_params);
174            self.xi.rollback_to(snapshot);
175        } else {
176            fhir::visit::walk_node(self, node);
177        }
178    }
179
180    fn visit_ty_alias(&mut self, ty_alias: &fhir::TyAlias<'genv>) {
181        fhir::visit::walk_ty_alias(self, ty_alias);
182        self.check_params_are_value_determined(ty_alias.index.as_slice());
183    }
184
185    fn visit_struct_def(&mut self, struct_def: &fhir::StructDef<'genv>) {
186        if let fhir::StructKind::Transparent { fields } = struct_def.kind {
187            walk_list!(self, visit_field_def, fields);
188            self.check_params_are_value_determined(struct_def.params);
189        }
190    }
191
192    fn visit_variant(&mut self, variant: &fhir::VariantDef<'genv>) {
193        let snapshot = self.xi.snapshot();
194        fhir::visit::walk_variant(self, variant);
195        self.check_params_are_value_determined(variant.params);
196        self.xi.rollback_to(snapshot);
197    }
198
199    fn visit_variant_ret(&mut self, ret: &fhir::VariantRet<'genv>) {
200        let snapshot = self.xi.snapshot();
201        fhir::visit::walk_variant_ret(self, ret);
202        self.xi.rollback_to(snapshot);
203    }
204
205    fn visit_fn_output(&mut self, output: &fhir::FnOutput<'genv>) {
206        let snapshot = self.xi.snapshot();
207        fhir::visit::walk_fn_output(self, output);
208        self.check_params_are_value_determined(output.params);
209        self.xi.rollback_to(snapshot);
210    }
211
212    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
213        match &ty.kind {
214            fhir::TyKind::StrgRef(_, loc, ty) => {
215                let (_, id) = loc.res.expect_param();
216                self.xi.insert(id, ());
217                self.visit_ty(ty);
218            }
219            fhir::TyKind::Exists(params, ty) => {
220                self.visit_ty(ty);
221                self.check_params_are_value_determined(params);
222            }
223            fhir::TyKind::Indexed(bty, expr) => {
224                fhir::visit::walk_bty(self, bty);
225                self.insert_value_determined(expr);
226                self.check_func_params_uses(expr, false);
227            }
228            _ => fhir::visit::walk_ty(self, ty),
229        }
230    }
231
232    fn visit_expr(&mut self, expr: &fhir::Expr) {
233        self.check_func_params_uses(expr, true);
234    }
235
236    fn visit_path_segment(&mut self, segment: &fhir::PathSegment<'genv>) {
237        let is_box = self.infcx.genv.is_box(segment.res);
238
239        for (i, arg) in segment.args.iter().enumerate() {
240            let snapshot = self.xi.snapshot();
241            self.visit_generic_arg(arg);
242            if !(is_box && i == 0) {
243                self.xi.rollback_to(snapshot);
244            }
245        }
246        walk_list!(self, visit_assoc_item_constraint, segment.constraints);
247    }
248}