use flux_errors::Errors;
use flux_middle::{
fhir::{self, visit::Visitor},
rty, walk_list,
};
use rustc_data_structures::snapshot_map;
use rustc_span::ErrorGuaranteed;
use super::{
errors::{InvalidParamPos, ParamNotDetermined},
sortck::InferCtxt,
};
/// Module-local result alias: failures carry only an `ErrorGuaranteed` token, and the success
/// type defaults to `()`.
type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
/// Checks how refinement parameters are used inside `node`.
///
/// Builds a fresh [`ParamUsesChecker`], visits `node` with it, and returns whatever the
/// checker's accumulated `Errors` converts into (presumably `Err` once any diagnostic was
/// emitted — see `Errors::into_result`).
pub(crate) fn check<'genv>(infcx: &InferCtxt<'genv, '_>, node: &fhir::OwnerNode<'genv>) -> Result {
    let checker = ParamUsesChecker::new(infcx);
    checker.run(|this| this.visit_node(node))
}
/// Visitor state for checking how refinement parameters are used in `fhir` items.
struct ParamUsesChecker<'a, 'genv, 'tcx> {
    /// Inference context; queried for parameter sorts and infer modes, and for `genv`.
    infcx: &'a InferCtxt<'genv, 'tcx>,
    /// Set of parameters (values are always `()`) recorded as value-determined so far.
    /// Snapshots are taken and rolled back around sub-walks so that insertions stay scoped
    /// to the construct being visited.
    xi: snapshot_map::SnapshotMap<fhir::ParamId, ()>,
    /// Diagnostics emitted during the walk; converted into a `Result` at the end of `run`.
    errors: Errors<'genv>,
}
impl<'a, 'genv, 'tcx> ParamUsesChecker<'a, 'genv, 'tcx> {
fn new(infcx: &'a InferCtxt<'genv, 'tcx>) -> Self {
Self { infcx, xi: Default::default(), errors: Errors::new(infcx.genv.sess()) }
}
fn run(mut self, f: impl FnOnce(&mut Self)) -> Result {
f(&mut self);
self.errors.into_result()
}
fn check_func_params_uses(
&mut self,
expr: &fhir::Expr,
is_top_level_conj: bool,
is_top_level_var: bool,
) {
match expr.kind {
fhir::ExprKind::BinaryOp(bin_op, e1, e2) => {
let is_pred = is_top_level_conj && matches!(bin_op, fhir::BinOp::And);
self.check_func_params_uses(e1, is_pred, false);
self.check_func_params_uses(e2, is_pred, false);
}
fhir::ExprKind::UnaryOp(_, e) => self.check_func_params_uses(e, false, false),
fhir::ExprKind::App(func, args) => {
if !is_top_level_conj
&& let fhir::ExprRes::Param(_, id) = func.res
&& let fhir::InferMode::KVar = self.infcx.infer_mode(id)
{
self.errors
.emit(InvalidParamPos::new(func.span, &self.infcx.param_sort(id)));
}
for arg in args {
self.check_func_params_uses(arg, false, false);
}
}
fhir::ExprKind::Alias(_, func_args) => {
for arg in func_args {
self.check_func_params_uses(arg, false, is_top_level_var);
}
}
fhir::ExprKind::Var(var, _) => {
if let fhir::ExprRes::Param(_, id) = var.res
&& let sort @ rty::Sort::Func(_) = self.infcx.param_sort(id)
{
self.errors.emit(InvalidParamPos::new(var.span, &sort));
}
if let fhir::ExprRes::Param(_, id) = var.res
&& is_top_level_var
{
self.xi.insert(id, ());
}
}
fhir::ExprKind::IfThenElse(e1, e2, e3) => {
self.check_func_params_uses(e1, false, false);
self.check_func_params_uses(e3, false, false);
self.check_func_params_uses(e2, false, false);
}
fhir::ExprKind::Literal(_) => {}
fhir::ExprKind::Dot(var, _) => {
if let fhir::ExprRes::Param(_, id) = var.res
&& let sort @ rty::Sort::Func(_) = &self.infcx.param_sort(id)
{
self.errors.emit(InvalidParamPos::new(var.span, sort));
}
}
fhir::ExprKind::Abs(_, body) => {
self.check_func_params_uses(body, true, is_top_level_var);
}
fhir::ExprKind::Record(fields) => {
for field in fields {
self.check_func_params_uses(field, is_top_level_conj, is_top_level_var);
}
}
fhir::ExprKind::Constructor(_path, exprs, _spread) => {
for expr in exprs {
self.check_func_params_uses(&expr.expr, false, false);
}
}
}
}
fn check_params_are_value_determined(&mut self, params: &[fhir::RefineParam]) {
for param in params {
let determined = self.xi.remove(param.id);
if self.infcx.infer_mode(param.id) == fhir::InferMode::EVar && !determined {
self.errors
.emit(ParamNotDetermined::new(param.span, param.name));
}
}
}
}
impl<'genv> fhir::visit::Visitor<'genv> for ParamUsesChecker<'_, 'genv, '_> {
    /// Nodes without a function signature are walked as-is. For nodes with one, insertions into
    /// `xi` are scoped to the node: after the walk the generics' refinement parameters are
    /// checked and `xi` is restored.
    fn visit_node(&mut self, node: &fhir::OwnerNode<'genv>) {
        if node.fn_sig().is_none() {
            fhir::visit::walk_node(self, node);
            return;
        }
        let saved = self.xi.snapshot();
        fhir::visit::walk_node(self, node);
        self.check_params_are_value_determined(node.generics().refinement_params);
        self.xi.rollback_to(saved);
    }

    /// Walks the alias, then checks its (optional) index parameter.
    fn visit_ty_alias(&mut self, ty_alias: &fhir::TyAlias<'genv>) {
        fhir::visit::walk_ty_alias(self, ty_alias);
        let index_params = ty_alias.index.as_slice();
        self.check_params_are_value_determined(index_params);
    }

    /// Only transparent structs are inspected: their fields are walked and the struct's
    /// refinement parameters checked. Any other struct kind is skipped entirely.
    fn visit_struct_def(&mut self, struct_def: &fhir::StructDef<'genv>) {
        let fhir::StructKind::Transparent { fields } = struct_def.kind else {
            return;
        };
        walk_list!(self, visit_field_def, fields);
        self.check_params_are_value_determined(struct_def.params);
    }

    /// Checks a variant's parameters against what its own walk determined, then restores `xi`.
    fn visit_variant(&mut self, variant: &fhir::VariantDef<'genv>) {
        let saved = self.xi.snapshot();
        fhir::visit::walk_variant(self, variant);
        self.check_params_are_value_determined(variant.params);
        self.xi.rollback_to(saved);
    }

    /// Anything recorded while walking a variant's return type is discarded afterwards.
    fn visit_variant_ret(&mut self, ret: &fhir::VariantRet<'genv>) {
        let saved = self.xi.snapshot();
        fhir::visit::walk_variant_ret(self, ret);
        self.xi.rollback_to(saved);
    }

    /// Checks the output's own parameters against what the output walk determined, then
    /// restores `xi`.
    fn visit_fn_output(&mut self, output: &fhir::FnOutput<'genv>) {
        let saved = self.xi.snapshot();
        fhir::visit::walk_fn_output(self, output);
        self.check_params_are_value_determined(output.params);
        self.xi.rollback_to(saved);
    }

    /// Handles the type forms that affect `xi`; everything else is walked generically.
    fn visit_ty(&mut self, ty: &fhir::Ty<'genv>) {
        match &ty.kind {
            fhir::TyKind::StrgRef(_, loc, inner) => {
                // The location parameter of a strong reference is recorded as determined.
                let (_, loc_id) = loc.res.expect_param();
                self.xi.insert(loc_id, ());
                self.visit_ty(inner);
            }
            fhir::TyKind::Exists(params, inner) => {
                self.visit_ty(inner);
                self.check_params_are_value_determined(params);
            }
            fhir::TyKind::Indexed(bty, idx) => {
                fhir::visit::walk_bty(self, bty);
                // The index is not a predicate, but a bare parameter at its top is determining.
                self.check_func_params_uses(idx, false, true);
            }
            _ => fhir::visit::walk_ty(self, ty),
        }
    }

    /// Free-standing expressions are checked as top-level predicates with no determining
    /// position.
    fn visit_expr(&mut self, expr: &fhir::Expr) {
        self.check_func_params_uses(expr, true, false);
    }

    /// Generic arguments are visited one at a time with `xi` snapshotted around each. Only the
    /// first argument of `Box` keeps what it recorded; all others are rolled back.
    fn visit_path_segment(&mut self, segment: &fhir::PathSegment<'genv>) {
        let is_box = self.infcx.genv.is_box(segment.res);
        for (idx, generic_arg) in segment.args.iter().enumerate() {
            let saved = self.xi.snapshot();
            self.visit_generic_arg(generic_arg);
            let keeps_determined = is_box && idx == 0;
            if !keeps_determined {
                self.xi.rollback_to(saved);
            }
        }
        walk_list!(self, visit_assoc_item_constraint, segment.constraints);
    }
}