flux_refineck/
primops.rs

//! This file defines the refinement rules for primitive operations.
//! Flux needs to define how to reason about primitive operations on different
//! [`BaseTy`]s. This is done by defining a set of rules for each operation.
//!
//! For example, equality checks depend on whether the `BaseTy` is treated as
//! refinable or opaque.
//!
//! ```ignore
//! // Make the rules for `a == b`.
//! fn mk_eq_rules() -> RuleMatcher<2> {
//!     primop_rules! {
//!         // If the `BaseTy` is refinable, then we can reason about equality.
//!         // The types accepted by the `if` guard are refinable, and Flux will use
//!         // the refined postcondition (`bool[E::eq(a, b)]`) to reason about
//!         // the invariants of `==`.
//!         fn(a: T, b: T) -> bool[E::eq(a, b)]
//!         if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
//!
//!         // Otherwise, if the `BaseTy` is opaque, then we can't reason
//!         // about equality. Flux only knows that the return type is a boolean,
//!         // but the return value is unrefined.
//!         fn(a: T, b: T) -> bool
//!     }
//! }
//! ```
26use std::{hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_infer::infer::ConstrReason;
30use flux_macros::primop_rules;
31use flux_middle::rty::{self, BaseTy, Expr, Sort};
32use flux_rustc_bridge::mir;
33use rty::{BinOp::Mod, Expr as E};
34use rustc_data_structures::unord::UnordMap;
35
36pub(crate) struct MatchedRule {
37    pub precondition: Option<Pre>,
38    pub output_type: rty::Ty,
39}
40
41pub(crate) struct Pre {
42    pub reason: ConstrReason,
43    pub pred: Expr,
44}
45
46pub(crate) fn match_bin_op(
47    op: mir::BinOp,
48    bty1: &BaseTy,
49    idx1: &Expr,
50    bty2: &BaseTy,
51    idx2: &Expr,
52    check_overflow: bool,
53) -> MatchedRule {
54    let table = if check_overflow { &OVERFLOW_BIN_OPS } else { &DEFAULT_BIN_OPS };
55    table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
56}
57
58pub(crate) fn match_un_op(
59    op: mir::UnOp,
60    bty: &BaseTy,
61    idx: &Expr,
62    check_overflow: bool,
63) -> MatchedRule {
64    let table = if check_overflow { &OVERFLOW_UN_OPS } else { &DEFAULT_UN_OPS };
65    table.match_inputs(&op, [(bty.clone(), idx.clone())])
66}
67
68struct RuleTable<Op: Eq + Hash, const N: usize> {
69    rules: UnordMap<Op, RuleMatcher<N>>,
70}
71
72impl<Op: Eq + Hash, const N: usize> RuleTable<Op, N> {
73    fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
74        (self.rules[op])(&inputs)
75            .unwrap_or_else(|| tracked_span_bug!("no primop rule for {inputs:?}"))
76    }
77}
78
79type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
80
81static DEFAULT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
82    use mir::BinOp::*;
83    RuleTable {
84        rules: [
85            // Arith
86            (Add, mk_add_rules(false)),
87            (Mul, mk_mul_rules(false)),
88            (Sub, mk_sub_rules(false)),
89            (Div, mk_div_rules()),
90            (Rem, mk_rem_rules()),
91            // Bitwise
92            (BitAnd, mk_bit_and_rules()),
93            (BitOr, mk_bit_or_rules()),
94            (BitXor, mk_bit_xor_rules()),
95            // Cmp
96            (Eq, mk_eq_rules()),
97            (Ne, mk_ne_rules()),
98            (Le, mk_le_rules()),
99            (Ge, mk_ge_rules()),
100            (Lt, mk_lt_rules()),
101            (Gt, mk_gt_rules()),
102            // Shifts
103            (Shl, mk_shl_rules()),
104            (Shr, mk_shr_rules()),
105        ]
106        .into_iter()
107        .collect(),
108    }
109});
110
111static OVERFLOW_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
112    use mir::BinOp::*;
113    RuleTable {
114        rules: [
115            // Arith
116            (Add, mk_add_rules(true)),
117            (Mul, mk_mul_rules(true)),
118            (Sub, mk_sub_rules(true)),
119            (Div, mk_div_rules()),
120            (Rem, mk_rem_rules()),
121            // Bitwise
122            (BitAnd, mk_bit_and_rules()),
123            (BitOr, mk_bit_or_rules()),
124            (BitXor, mk_bit_xor_rules()),
125            // Cmp
126            (Eq, mk_eq_rules()),
127            (Ne, mk_ne_rules()),
128            (Le, mk_le_rules()),
129            (Ge, mk_ge_rules()),
130            (Lt, mk_lt_rules()),
131            (Gt, mk_gt_rules()),
132            // Shifts
133            (Shl, mk_shl_rules()),
134            (Shr, mk_shr_rules()),
135        ]
136        .into_iter()
137        .collect(),
138    }
139});
140
141static DEFAULT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
142    use mir::UnOp::*;
143    RuleTable {
144        rules: [(Neg, mk_neg_rules(false)), (Not, mk_not_rules())]
145            .into_iter()
146            .collect(),
147    }
148});
149
150static OVERFLOW_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
151    use mir::UnOp::*;
152    RuleTable {
153        rules: [(Neg, mk_neg_rules(true)), (Not, mk_not_rules())]
154            .into_iter()
155            .collect(),
156    }
157});
158
159/// `a + b`
160fn mk_add_rules(check_overflow: bool) -> RuleMatcher<2> {
161    if check_overflow {
162        primop_rules! {
163            fn(a: T, b: T) -> T[a + b]
164            requires E::and(
165                         E::ge(a + b, E::int_min(int_ty)),
166                         E::le(a + b, E::int_max(int_ty)),
167                     ) => ConstrReason::Overflow
168            if let &BaseTy::Int(int_ty) = T
169
170            fn(a: T, b: T) -> T[a + b]
171            requires E::le(a + b, E::uint_max(uint_ty)) => ConstrReason::Overflow
172            if let &BaseTy::Uint(uint_ty) = T
173
174            fn(a: T, b: T) -> T
175        }
176    } else {
177        primop_rules! {
178            fn(a: T, b: T) -> T[a + b]
179            if T.is_integral()
180
181            fn(a: T, b: T) -> T
182        }
183    }
184}
185
186/// `a * b`
187fn mk_mul_rules(check_overflow: bool) -> RuleMatcher<2> {
188    if check_overflow {
189        primop_rules! {
190            fn(a: T, b: T) -> T[a * b]
191            requires E::and(
192                         E::ge(a * b, E::int_min(int_ty)),
193                         E::le(a * b, E::int_max(int_ty)),
194                     ) => ConstrReason::Overflow
195            if let &BaseTy::Int(int_ty) = T
196
197            fn(a: T, b: T) -> T[a * b]
198            requires E::le(a * b, E::uint_max(uint_ty)) => ConstrReason::Overflow
199            if let &BaseTy::Uint(uint_ty) = T
200
201            fn(a: T, b: T) -> T
202        }
203    } else {
204        primop_rules!(
205            fn(a: T, b: T) -> T[a * b]
206            if T.is_integral()
207
208            fn(a: T, b: T) -> T
209            if T.is_float()
210        )
211    }
212}
213
214/// `a - b`
215fn mk_sub_rules(check_overflow: bool) -> RuleMatcher<2> {
216    if check_overflow {
217        primop_rules! {
218            fn(a: T, b: T) -> T[a - b]
219            requires E::and(
220                         E::ge(a - b, E::int_min(int_ty)),
221                         E::le(a - b, E::int_max(int_ty)),
222                     ) => ConstrReason::Overflow
223            if let &BaseTy::Int(int_ty) = T
224
225            fn(a: T, b: T) -> T[a - b]
226            requires E::and(
227                         E::ge(a - b, 0),
228                         E::le(a - b, E::uint_max(uint_ty)),
229                     ) => ConstrReason::Overflow
230            if let &BaseTy::Uint(uint_ty) = T
231
232            fn(a: T, b: T) -> T
233        }
234    } else {
235        primop_rules! {
236            fn(a: T, b: T) -> T[a - b]
237            requires E::ge(a - b, 0) => ConstrReason::Overflow
238            if T.is_unsigned()
239
240            fn(a: T, b: T) -> T[a - b]
241            if T.is_signed()
242
243            fn(a: T, b: T) -> T
244            if T.is_float()
245        }
246    }
247}
248
249/// `a/b`
250fn mk_div_rules() -> RuleMatcher<2> {
251    primop_rules! {
252        fn(a: T, b: T) -> T[a/b]
253        requires E::ne(b, 0) => ConstrReason::Div
254        if T.is_integral()
255
256        fn(a: T, b: T) -> T
257        if T.is_float()
258    }
259}
260
261/// `a % b`
262fn mk_rem_rules() -> RuleMatcher<2> {
263    primop_rules! {
264        fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
265        requires E::ne(b, 0) => ConstrReason::Rem
266        if T.is_unsigned()
267
268        fn(a: T, b: T) -> T{v: E::implies(
269                                   E::and(E::ge(a, 0), E::ge(b, 0)),
270                                   E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
271        requires E::ne(b, 0) => ConstrReason::Rem
272        if T.is_signed()
273    }
274}
275
276/// `a & b`
277fn mk_bit_and_rules() -> RuleMatcher<2> {
278    primop_rules! {
279        fn(a: T, b: T) -> T
280        if T.is_integral()
281
282        fn(a: bool, b: bool) -> bool[E::and(a, b)]
283    }
284}
285
286/// `a | b`
287fn mk_bit_or_rules() -> RuleMatcher<2> {
288    primop_rules! {
289        fn(a: T, b: T) -> T
290        if T.is_integral()
291
292        fn(a: bool, b: bool) -> bool[E::or(a, b)]
293    }
294}
295
296/// `a ^ b`
297fn mk_bit_xor_rules() -> RuleMatcher<2> {
298    primop_rules! {
299        fn(a: T, b: T) -> T
300        if T.is_integral()
301    }
302}
303
304/// `a == b`
305fn mk_eq_rules() -> RuleMatcher<2> {
306    primop_rules! {
307        fn(a: T, b: T) -> bool[E::eq(a, b)]
308        if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
309
310        fn(a: T, b: T) -> bool
311    }
312}
313
314/// `a != b`
315fn mk_ne_rules() -> RuleMatcher<2> {
316    primop_rules! {
317        fn(a: T, b: T) -> bool[E::ne(a, b)]
318        if T.is_integral() || T.is_bool()
319
320        fn(a: T, b: T) -> bool
321    }
322}
323
324/// `a <= b`
325fn mk_le_rules() -> RuleMatcher<2> {
326    primop_rules! {
327        fn(a: T, b: T) -> bool[E::le(a, b)]
328        if T.is_integral()
329
330        fn(a: bool, b: bool) -> bool[E::implies(a, b)]
331
332        fn(a: T, b: T) -> bool
333    }
334}
335
336/// `a >= b`
337fn mk_ge_rules() -> RuleMatcher<2> {
338    primop_rules! {
339        fn(a: T, b: T) -> bool[E::ge(a, b)]
340        if T.is_integral()
341
342        fn(a: bool, b: bool) -> bool[E::implies(b, a)]
343
344        fn(a: T, b: T) -> bool
345    }
346}
347
348/// `a < b`
349fn mk_lt_rules() -> RuleMatcher<2> {
350    primop_rules! {
351        fn(a: T, b: T) -> bool[E::lt(a, b)]
352        if T.is_integral()
353
354        fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]
355
356        fn(a: T, b: T) -> bool
357    }
358}
359
360/// `a > b`
361fn mk_gt_rules() -> RuleMatcher<2> {
362    primop_rules! {
363        fn(a: T, b: T) -> bool[E::gt(a, b)]
364        if T.is_integral()
365
366        fn(a: bool, b: bool) -> bool[E::and(a, b.not())]
367
368        fn(a: T, b: T) -> bool
369    }
370}
371
372/// `a << b`
373fn mk_shl_rules() -> RuleMatcher<2> {
374    primop_rules! {
375        fn(a: T, b: S) -> T
376        if T.is_integral() && S.is_integral()
377    }
378}
379
380/// `a >> b`
381fn mk_shr_rules() -> RuleMatcher<2> {
382    primop_rules! {
383        fn(a: T, b: S) -> T
384        if T.is_integral() && S.is_integral()
385    }
386}
387
388/// `-a`
389fn mk_neg_rules(check_overflow: bool) -> RuleMatcher<1> {
390    if check_overflow {
391        primop_rules! {
392            fn(a: T) -> T[a.neg()]
393            requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
394            if let &BaseTy::Int(int_ty) = T
395
396            fn(a: T) -> T[a.neg()]
397            if T.is_float()
398        }
399    } else {
400        primop_rules! {
401            fn(a: T) -> T[a.neg()]
402            if T.is_integral()
403
404            fn(a: T) -> T
405            if T.is_float()
406        }
407    }
408}
409
410/// `!a`
411fn mk_not_rules() -> RuleMatcher<1> {
412    primop_rules! {
413        fn(a: bool) -> bool[a.not()]
414
415        fn(a: T) -> T
416        if T.is_integral()
417    }
418}