flux_refineck/
primops.rs

//! This module defines the refinement rules for primitive operations.
//! Flux needs to define how to reason about primitive operations on different
//! [`BaseTy`]s. This is done by defining a set of rules for each operation.
//!
//! For example, equality checks depend on whether the `BaseTy` is treated as
//! refinable or opaque.
//!
//! ```ignore
//! // Make the rules for `a == b`.
//! fn mk_eq_rules() -> RuleMatcher<2> {
//!     primop_rules! {
//!         // If the `BaseTy` is refinable, then we can reason about equality.
//!         // The types accepted by the `if` guard are refinable, and Flux uses
//!         // the refined postcondition (`bool[E::eq(a, b)]`) to reason about
//!         // the result of `==`.
//!         fn(a: T, b: T) -> bool[E::eq(a, b)]
//!         if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
//!
//!         // Otherwise, the `BaseTy` is opaque and we can't reason about
//!         // equality. Flux only knows that the return type is a boolean,
//!         // but the return value is unrefined.
//!         fn(a: T, b: T) -> bool
//!     }
//! }
//! ```
26use std::{fmt, hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_infer::infer::ConstrReason;
30use flux_macros::primop_rules;
31use flux_middle::rty::{self, BaseTy, Expr, Sort};
32use flux_rustc_bridge::mir;
33use rty::{BinOp::Mod, Expr as E};
34use rustc_data_structures::unord::UnordMap;
35
36#[derive(Debug)]
37pub(crate) struct MatchedRule {
38    pub precondition: Option<Pre>,
39    pub output_type: rty::Ty,
40}
41
42#[derive(Debug)]
43pub(crate) struct Pre {
44    pub reason: ConstrReason,
45    pub pred: Expr,
46}
47
48pub(crate) fn match_bin_op(
49    op: mir::BinOp,
50    bty1: &BaseTy,
51    idx1: &Expr,
52    bty2: &BaseTy,
53    idx2: &Expr,
54    check_overflow: bool,
55) -> MatchedRule {
56    let table = if check_overflow { &OVERFLOW_BIN_OPS } else { &DEFAULT_BIN_OPS };
57    table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
58}
59
60pub(crate) fn match_un_op(
61    op: mir::UnOp,
62    bty: &BaseTy,
63    idx: &Expr,
64    check_overflow: bool,
65) -> MatchedRule {
66    let table = if check_overflow { &OVERFLOW_UN_OPS } else { &DEFAULT_UN_OPS };
67    table.match_inputs(&op, [(bty.clone(), idx.clone())])
68}
69
70struct RuleTable<Op: Eq + Hash, const N: usize> {
71    rules: UnordMap<Op, RuleMatcher<N>>,
72}
73
74impl<Op: Eq + Hash + fmt::Debug, const N: usize> RuleTable<Op, N> {
75    fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
76        (self.rules[op])(&inputs)
77            .unwrap_or_else(|| tracked_span_bug!("no primop rule for {op:?} using {inputs:?}"))
78    }
79}
80
81type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
82
83static DEFAULT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
84    use mir::BinOp::*;
85    RuleTable {
86        rules: [
87            // Arith
88            (Add, mk_add_rules(false)),
89            (Mul, mk_mul_rules(false)),
90            (Sub, mk_sub_rules(false)),
91            (Div, mk_div_rules()),
92            (Rem, mk_rem_rules()),
93            // Bitwise
94            (BitAnd, mk_bit_and_rules()),
95            (BitOr, mk_bit_or_rules()),
96            (BitXor, mk_bit_xor_rules()),
97            // Cmp
98            (Eq, mk_eq_rules()),
99            (Ne, mk_ne_rules()),
100            (Le, mk_le_rules()),
101            (Ge, mk_ge_rules()),
102            (Lt, mk_lt_rules()),
103            (Gt, mk_gt_rules()),
104            // Shifts
105            (Shl, mk_shl_rules()),
106            (Shr, mk_shr_rules()),
107        ]
108        .into_iter()
109        .collect(),
110    }
111});
112
113static OVERFLOW_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
114    use mir::BinOp::*;
115    RuleTable {
116        rules: [
117            // Arith
118            (Add, mk_add_rules(true)),
119            (Mul, mk_mul_rules(true)),
120            (Sub, mk_sub_rules(true)),
121            (Div, mk_div_rules()),
122            (Rem, mk_rem_rules()),
123            // Bitwise
124            (BitAnd, mk_bit_and_rules()),
125            (BitOr, mk_bit_or_rules()),
126            (BitXor, mk_bit_xor_rules()),
127            // Cmp
128            (Eq, mk_eq_rules()),
129            (Ne, mk_ne_rules()),
130            (Le, mk_le_rules()),
131            (Ge, mk_ge_rules()),
132            (Lt, mk_lt_rules()),
133            (Gt, mk_gt_rules()),
134            // Shifts
135            (Shl, mk_shl_rules()),
136            (Shr, mk_shr_rules()),
137        ]
138        .into_iter()
139        .collect(),
140    }
141});
142
143static DEFAULT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
144    use mir::UnOp::*;
145    RuleTable {
146        rules: [(Neg, mk_neg_rules(false)), (Not, mk_not_rules())]
147            .into_iter()
148            .collect(),
149    }
150});
151
152static OVERFLOW_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
153    use mir::UnOp::*;
154    RuleTable {
155        rules: [(Neg, mk_neg_rules(true)), (Not, mk_not_rules())]
156            .into_iter()
157            .collect(),
158    }
159});
160
161/// `a + b`
162fn mk_add_rules(check_overflow: bool) -> RuleMatcher<2> {
163    if check_overflow {
164        primop_rules! {
165            fn(a: T, b: T) -> T[a + b]
166            requires E::and(
167                         E::ge(a + b, E::int_min(int_ty)),
168                         E::le(a + b, E::int_max(int_ty)),
169                     ) => ConstrReason::Overflow
170            if let &BaseTy::Int(int_ty) = T
171
172            fn(a: T, b: T) -> T[a + b]
173            requires E::le(a + b, E::uint_max(uint_ty)) => ConstrReason::Overflow
174            if let &BaseTy::Uint(uint_ty) = T
175
176            fn(a: T, b: T) -> T
177        }
178    } else {
179        primop_rules! {
180            fn(a: T, b: T) -> T[a + b]
181            if T.is_integral()
182
183            fn(a: T, b: T) -> T
184        }
185    }
186}
187
188/// `a * b`
189fn mk_mul_rules(check_overflow: bool) -> RuleMatcher<2> {
190    if check_overflow {
191        primop_rules! {
192            fn(a: T, b: T) -> T[a * b]
193            requires E::and(
194                         E::ge(a * b, E::int_min(int_ty)),
195                         E::le(a * b, E::int_max(int_ty)),
196                     ) => ConstrReason::Overflow
197            if let &BaseTy::Int(int_ty) = T
198
199            fn(a: T, b: T) -> T[a * b]
200            requires E::le(a * b, E::uint_max(uint_ty)) => ConstrReason::Overflow
201            if let &BaseTy::Uint(uint_ty) = T
202
203            fn(a: T, b: T) -> T
204        }
205    } else {
206        primop_rules!(
207            fn(a: T, b: T) -> T[a * b]
208            if T.is_integral()
209
210            fn(a: T, b: T) -> T
211            if T.is_float()
212        )
213    }
214}
215
216/// `a - b`
217fn mk_sub_rules(check_overflow: bool) -> RuleMatcher<2> {
218    if check_overflow {
219        primop_rules! {
220            fn(a: T, b: T) -> T[a - b]
221            requires E::and(
222                         E::ge(a - b, E::int_min(int_ty)),
223                         E::le(a - b, E::int_max(int_ty)),
224                     ) => ConstrReason::Overflow
225            if let &BaseTy::Int(int_ty) = T
226
227            fn(a: T, b: T) -> T[a - b]
228            requires E::and(
229                         E::ge(a - b, 0),
230                         E::le(a - b, E::uint_max(uint_ty)),
231                     ) => ConstrReason::Overflow
232            if let &BaseTy::Uint(uint_ty) = T
233
234            fn(a: T, b: T) -> T
235        }
236    } else {
237        primop_rules! {
238            fn(a: T, b: T) -> T[a - b]
239            requires E::ge(a - b, 0) => ConstrReason::Overflow
240            if T.is_unsigned()
241
242            fn(a: T, b: T) -> T[a - b]
243            if T.is_signed()
244
245            fn(a: T, b: T) -> T
246            if T.is_float()
247        }
248    }
249}
250
251/// `a/b`
252fn mk_div_rules() -> RuleMatcher<2> {
253    primop_rules! {
254        fn(a: T, b: T) -> T[a/b]
255        requires E::ne(b, 0) => ConstrReason::Div
256        if T.is_integral()
257
258        fn(a: T, b: T) -> T
259        if T.is_float()
260    }
261}
262
263/// `a % b`
264fn mk_rem_rules() -> RuleMatcher<2> {
265    primop_rules! {
266        fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
267        requires E::ne(b, 0) => ConstrReason::Rem
268        if T.is_unsigned()
269
270        fn(a: T, b: T) -> T{v: E::implies(
271                                   E::and(E::ge(a, 0), E::ge(b, 0)),
272                                   E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
273        requires E::ne(b, 0) => ConstrReason::Rem
274        if T.is_signed()
275
276        fn (a: T, b: T) -> T
277        if T.is_float()
278    }
279}
280
281/// `a & b`
282fn mk_bit_and_rules() -> RuleMatcher<2> {
283    primop_rules! {
284        fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitAnd, a, b)] | E::prim_rel(rty::BinOp::BitAnd, a, b) }
285        if T.is_integral()
286
287        fn(a: bool, b: bool) -> bool[E::and(a, b)]
288    }
289}
290
291/// `a | b`
292fn mk_bit_or_rules() -> RuleMatcher<2> {
293    primop_rules! {
294        fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitOr, a, b)] | E::prim_rel(rty::BinOp::BitOr, a, b) }
295        if T.is_integral()
296
297        fn(a: bool, b: bool) -> bool[E::or(a, b)]
298    }
299}
300
301/// `a ^ b`
302fn mk_bit_xor_rules() -> RuleMatcher<2> {
303    primop_rules! {
304        fn(a: T, b: T) -> T
305        if T.is_integral()
306    }
307}
308
309/// `a == b`
310fn mk_eq_rules() -> RuleMatcher<2> {
311    primop_rules! {
312        fn(a: T, b: T) -> bool[E::eq(a, b)]
313        if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
314
315        fn(a: T, b: T) -> bool
316    }
317}
318
319/// `a != b`
320fn mk_ne_rules() -> RuleMatcher<2> {
321    primop_rules! {
322        fn(a: T, b: T) -> bool[E::ne(a, b)]
323        if T.is_integral() || T.is_bool()
324
325        fn(a: T, b: T) -> bool
326    }
327}
328
329/// `a <= b`
330fn mk_le_rules() -> RuleMatcher<2> {
331    primop_rules! {
332        fn(a: T, b: T) -> bool[E::le(a, b)]
333        if T.is_integral()
334
335        fn(a: bool, b: bool) -> bool[E::implies(a, b)]
336
337        fn(a: T, b: T) -> bool
338    }
339}
340
341/// `a >= b`
342fn mk_ge_rules() -> RuleMatcher<2> {
343    primop_rules! {
344        fn(a: T, b: T) -> bool[E::ge(a, b)]
345        if T.is_integral()
346
347        fn(a: bool, b: bool) -> bool[E::implies(b, a)]
348
349        fn(a: T, b: T) -> bool
350    }
351}
352
353/// `a < b`
354fn mk_lt_rules() -> RuleMatcher<2> {
355    primop_rules! {
356        fn(a: T, b: T) -> bool[E::lt(a, b)]
357        if T.is_integral()
358
359        fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]
360
361        fn(a: T, b: T) -> bool
362    }
363}
364
365/// `a > b`
366fn mk_gt_rules() -> RuleMatcher<2> {
367    primop_rules! {
368        fn(a: T, b: T) -> bool[E::gt(a, b)]
369        if T.is_integral()
370
371        fn(a: bool, b: bool) -> bool[E::and(a, b.not())]
372
373        fn(a: T, b: T) -> bool
374    }
375}
376
377/// `a << b`
378fn mk_shl_rules() -> RuleMatcher<2> {
379    primop_rules! {
380        fn(a: T, b: S) -> { T[E::prim_val(rty::BinOp::BitShl, a, b)] | E::prim_rel(rty::BinOp::BitShl, a, b) }
381        if T.is_integral() && S.is_integral()
382    }
383}
384
385/// `a >> b`
386fn mk_shr_rules() -> RuleMatcher<2> {
387    primop_rules! {
388        fn(a: T, b: S) -> { T[E::prim_val(rty::BinOp::BitShr, a, b)] | E::prim_rel(rty::BinOp::BitShr, a, b) }
389        if T.is_integral() && S.is_integral()
390    }
391}
392
393/// `-a`
394fn mk_neg_rules(check_overflow: bool) -> RuleMatcher<1> {
395    if check_overflow {
396        primop_rules! {
397            fn(a: T) -> T[a.neg()]
398            requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
399            if let &BaseTy::Int(int_ty) = T
400
401            fn(a: T) -> T[a.neg()]
402            if T.is_float()
403        }
404    } else {
405        primop_rules! {
406            fn(a: T) -> T[a.neg()]
407            if T.is_integral()
408
409            fn(a: T) -> T
410            if T.is_float()
411        }
412    }
413}
414
415/// `!a`
416fn mk_not_rules() -> RuleMatcher<1> {
417    primop_rules! {
418        fn(a: bool) -> bool[a.not()]
419
420        fn(a: T) -> T
421        if T.is_integral()
422    }
423}