// flux_refineck/invariants.rs

1use flux_common::{iter::IterExt, result::ResultExt};
2use flux_config::InferOpts;
3use flux_errors::ErrorGuaranteed;
4use flux_infer::{
5    fixpoint_encoding::FixQueryCache,
6    infer::{ConstrReason, GlobalEnvExt, Tag},
7};
8use flux_middle::{
9    FixpointQueryKind,
10    def_id::MaybeExternId,
11    fhir,
12    global_env::GlobalEnv,
13    queries::try_query,
14    rty::{self},
15};
16use rustc_infer::infer::TyCtxtInferExt;
17use rustc_middle::ty::TypingMode;
18use rustc_span::{DUMMY_SP, Span};
19
20pub fn check_invariants(
21    genv: GlobalEnv,
22    cache: &mut FixQueryCache,
23    def_id: MaybeExternId,
24    invariants: &[fhir::Expr],
25    adt_def: &rty::AdtDef,
26) -> Result<(), ErrorGuaranteed> {
27    // FIXME(nilehmann) maybe we should record whether the invariants were generated with overflow
28    // checking enabled and only assume them in code that also overflow checking enabled.
29    // Although, enable overflow checking locally is unsound in general.
30    //
31    // The good way would be to make overflow checking a property of a type that can be turned on
32    // and off locally. Then we consider an overflow-checked `T` distinct from a non-checked one and
33    // error/warn in case of a mismatch: overflow-checked types can flow to non-checked code but not
34    // the other way around.
35    let opts = genv.infer_opts(def_id.local_id());
36    adt_def
37        .invariants()
38        .iter_identity()
39        .enumerate()
40        .try_for_each_exhaust(|(idx, invariant)| {
41            let span = invariants[idx].span;
42            check_invariant(genv, cache, def_id, adt_def, span, invariant, opts)
43        })
44}
45
46fn check_invariant(
47    genv: GlobalEnv,
48    cache: &mut FixQueryCache,
49    def_id: MaybeExternId,
50    adt_def: &rty::AdtDef,
51    span: Span,
52    invariant: &rty::Invariant,
53    opts: InferOpts,
54) -> Result<(), ErrorGuaranteed> {
55    let resolved_id = def_id.resolved_id();
56
57    let region_infercx = genv
58        .tcx()
59        .infer_ctxt()
60        .with_next_trait_solver(true)
61        .build(TypingMode::non_body_analysis());
62
63    let mut infcx_root = try_query(|| {
64        genv.infcx_root(&region_infercx, opts)
65            .identity_for_item(resolved_id)?
66            .build()
67    })
68    .emit(&genv)?;
69
70    for variant_idx in adt_def.variants().indices() {
71        let mut rcx = infcx_root.infcx(resolved_id, &region_infercx);
72
73        let variant_sig = genv
74            .variant_sig(adt_def.did(), variant_idx)
75            .emit(&genv)?
76            .expect("cannot check opaque structs")
77            .instantiate_identity()
78            .replace_bound_refts_with(|sort, _, kind| {
79                rty::Expr::fvar(rcx.define_bound_reft_var(sort, kind))
80            });
81
82        for ty in variant_sig.fields() {
83            let ty = rcx.unpack(ty);
84            rcx.assume_invariants(&ty);
85        }
86        let pred = invariant.apply(&variant_sig.idx);
87        rcx.check_pred(&pred, Tag::new(ConstrReason::Other, DUMMY_SP));
88    }
89    let answer = infcx_root
90        .execute_fixpoint_query(cache, def_id, FixpointQueryKind::Invariant)
91        .emit(&genv)?;
92
93    if answer.errors.is_empty() {
94        Ok(())
95    } else {
96        Err(genv.sess().emit_err(errors::Invalid { span }))
97    }
98}
99
/// Diagnostics emitted by the invariant checker in this file.
mod errors {
    use flux_errors::E0999;
    use flux_macros::Diagnostic;
    use rustc_span::Span;

    // Emitted by `check_invariant` when the fixpoint query for an invariant returns errors, i.e.
    // the invariant could not be verified for the ADT's variants. The message text presumably
    // comes from the `refineck_invalid_invariant` fluent slug defined outside this file.
    #[derive(Diagnostic)]
    #[diag(refineck_invalid_invariant, code = E0999)]
    pub struct Invalid {
        // Span of the invariant expression that failed to verify.
        #[primary_span]
        pub span: Span,
    }
}
111}