flux_refineck/
primops.rs

//! This module defines the refinement rules for primitive operations.
//! Flux needs to define how to reason about primitive operations on different
//! [`BaseTy`]s. This is done by defining a set of rules for each operation.
//!
//! For example, equality checks depend on whether the `BaseTy` is treated as
//! refinable or opaque.
//!
//! ```ignore
//! // Make the rules for `a == b`.
//! fn mk_eq_rules() -> RuleMatcher<2> {
//!     primop_rules! {
//!         // If the `BaseTy` is refinable, then we can reason about equality.
//!         // The types listed in the `if` are refinable, and Flux will use
//!         // the refined postcondition (`bool[E::eq(a, b)]`) to reason about
//!         // the result of `==`.
//!         fn(a: T, b: T) -> bool[E::eq(a, b)]
//!         if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
//!
//!         // Otherwise, if the `BaseTy` is opaque, then we can't reason
//!         // about equality. Flux only knows that the return type is a boolean,
//!         // but the return value is unrefined.
//!         fn(a: T, b: T) -> bool
//!     }
//! }
//! ```
26use std::{fmt, hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_config::OverflowMode;
30use flux_infer::infer::ConstrReason;
31use flux_macros::primop_rules;
32use flux_middle::rty::{self, BaseTy, Expr, Sort};
33use flux_rustc_bridge::mir;
34use rty::{BinOp::Mod, Expr as E};
35use rustc_data_structures::unord::UnordMap;
36
/// The outcome of matching a primitive operation against its rule table.
#[derive(Debug)]
pub(crate) struct MatchedRule {
    /// Condition attached by a `requires` clause of the matched rule, if it has one.
    /// Presumably the caller must prove it before using `output_type` — confirm at call sites.
    pub precondition: Option<Pre>,
    /// The (possibly refined) type of the operation's result.
    pub output_type: rty::Ty,
}
42
/// A precondition produced by a rule's `requires pred => reason` clause.
#[derive(Debug)]
pub(crate) struct Pre {
    /// Why the constraint exists (e.g. `ConstrReason::Overflow`, `ConstrReason::Div`);
    /// presumably used to label the error when `pred` cannot be proven.
    pub reason: ConstrReason,
    /// The predicate over the operand indices that must hold.
    pub pred: Expr,
}
48
49pub(crate) fn match_bin_op(
50    op: mir::BinOp,
51    bty1: &BaseTy,
52    idx1: &Expr,
53    bty2: &BaseTy,
54    idx2: &Expr,
55    overflow_mode: OverflowMode,
56) -> MatchedRule {
57    let table = match overflow_mode {
58        OverflowMode::Strict => &OVERFLOW_STRICT_BIN_OPS,
59        OverflowMode::Lazy => &OVERFLOW_LAZY_BIN_OPS,
60        OverflowMode::None => &OVERFLOW_NONE_BIN_OPS,
61        OverflowMode::StrictUnder => &OVERFLOW_STRICT_UNDER_BIN_OPS,
62    };
63    table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
64}
65
66pub(crate) fn match_un_op(
67    op: mir::UnOp,
68    bty: &BaseTy,
69    idx: &Expr,
70    overflow_mode: OverflowMode,
71) -> MatchedRule {
72    let table = match overflow_mode {
73        OverflowMode::Strict => &OVERFLOW_STRICT_UN_OPS,
74        OverflowMode::None => &OVERFLOW_NONE_UN_OPS,
75        OverflowMode::Lazy | OverflowMode::StrictUnder => &OVERFLOW_LAZY_UN_OPS,
76    };
77    table.match_inputs(&op, [(bty.clone(), idx.clone())])
78}
79
/// Maps each operator of arity `N` to the matcher that holds its refinement rules.
struct RuleTable<Op: Eq + Hash, const N: usize> {
    rules: UnordMap<Op, RuleMatcher<N>>,
}
83
84impl<Op: Eq + Hash + fmt::Debug, const N: usize> RuleTable<Op, N> {
85    fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
86        (self.rules[op])(&inputs)
87            .unwrap_or_else(|| tracked_span_bug!("no primop rule for {op:?} using {inputs:?}"))
88    }
89}
90
/// The rules for an operator of arity `N`, as produced by `primop_rules!`: given each
/// operand's base type and index, returns the matched rule, or `None` if no rule applies.
type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
92
93static OVERFLOW_NONE_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
94    use mir::BinOp::*;
95    RuleTable {
96        rules: [
97            // Arith
98            (Add, mk_add_rules(OverflowMode::None)),
99            (Mul, mk_mul_rules(OverflowMode::None)),
100            (Sub, mk_sub_rules(OverflowMode::None)),
101            (Div, mk_div_rules()),
102            (Rem, mk_rem_rules()),
103            // Bitwise
104            (BitAnd, mk_bit_and_rules()),
105            (BitOr, mk_bit_or_rules()),
106            (BitXor, mk_bit_xor_rules()),
107            // Cmp
108            (Eq, mk_eq_rules()),
109            (Ne, mk_ne_rules()),
110            (Le, mk_le_rules()),
111            (Ge, mk_ge_rules()),
112            (Lt, mk_lt_rules()),
113            (Gt, mk_gt_rules()),
114            // Shifts
115            (Shl, mk_shl_rules()),
116            (Shr, mk_shr_rules()),
117        ]
118        .into_iter()
119        .collect(),
120    }
121});
122
123static OVERFLOW_STRICT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
124    use mir::BinOp::*;
125    RuleTable {
126        rules: [
127            // Arith
128            (Add, mk_add_rules(OverflowMode::Strict)),
129            (Mul, mk_mul_rules(OverflowMode::Strict)),
130            (Sub, mk_sub_rules(OverflowMode::Strict)),
131            (Div, mk_div_rules()),
132            (Rem, mk_rem_rules()),
133            // Bitwise
134            (BitAnd, mk_bit_and_rules()),
135            (BitOr, mk_bit_or_rules()),
136            (BitXor, mk_bit_xor_rules()),
137            // Cmp
138            (Eq, mk_eq_rules()),
139            (Ne, mk_ne_rules()),
140            (Le, mk_le_rules()),
141            (Ge, mk_ge_rules()),
142            (Lt, mk_lt_rules()),
143            (Gt, mk_gt_rules()),
144            // Shifts
145            (Shl, mk_shl_rules()),
146            (Shr, mk_shr_rules()),
147        ]
148        .into_iter()
149        .collect(),
150    }
151});
152
153static OVERFLOW_LAZY_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
154    use mir::BinOp::*;
155    RuleTable {
156        rules: [
157            // Arith
158            (Add, mk_add_rules(OverflowMode::Lazy)),
159            (Mul, mk_mul_rules(OverflowMode::Lazy)),
160            (Sub, mk_sub_rules(OverflowMode::Lazy)),
161            (Div, mk_div_rules()),
162            (Rem, mk_rem_rules()),
163            // Bitwise
164            (BitAnd, mk_bit_and_rules()),
165            (BitOr, mk_bit_or_rules()),
166            (BitXor, mk_bit_xor_rules()),
167            // Cmp
168            (Eq, mk_eq_rules()),
169            (Ne, mk_ne_rules()),
170            (Le, mk_le_rules()),
171            (Ge, mk_ge_rules()),
172            (Lt, mk_lt_rules()),
173            (Gt, mk_gt_rules()),
174            // Shifts
175            (Shl, mk_shl_rules()),
176            (Shr, mk_shr_rules()),
177        ]
178        .into_iter()
179        .collect(),
180    }
181});
182
183static OVERFLOW_STRICT_UNDER_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
184    use mir::BinOp::*;
185    RuleTable {
186        rules: [
187            // Arith
188            (Add, mk_add_rules(OverflowMode::StrictUnder)),
189            (Mul, mk_mul_rules(OverflowMode::StrictUnder)),
190            (Sub, mk_sub_rules(OverflowMode::StrictUnder)),
191            (Div, mk_div_rules()),
192            (Rem, mk_rem_rules()),
193            // Bitwise
194            (BitAnd, mk_bit_and_rules()),
195            (BitOr, mk_bit_or_rules()),
196            (BitXor, mk_bit_xor_rules()),
197            // Cmp
198            (Eq, mk_eq_rules()),
199            (Ne, mk_ne_rules()),
200            (Le, mk_le_rules()),
201            (Ge, mk_ge_rules()),
202            (Lt, mk_lt_rules()),
203            (Gt, mk_gt_rules()),
204            // Shifts
205            (Shl, mk_shl_rules()),
206            (Shr, mk_shr_rules()),
207        ]
208        .into_iter()
209        .collect(),
210    }
211});
212
213static OVERFLOW_NONE_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
214    use mir::UnOp::*;
215    RuleTable {
216        rules: [(Neg, mk_neg_rules(OverflowMode::None)), (Not, mk_not_rules())]
217            .into_iter()
218            .collect(),
219    }
220});
221
222static OVERFLOW_LAZY_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
223    use mir::UnOp::*;
224    RuleTable {
225        rules: [(Neg, mk_neg_rules(OverflowMode::Lazy)), (Not, mk_not_rules())]
226            .into_iter()
227            .collect(),
228    }
229});
230
231static OVERFLOW_STRICT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
232    use mir::UnOp::*;
233    RuleTable {
234        rules: [(Neg, mk_neg_rules(OverflowMode::Strict)), (Not, mk_not_rules())]
235            .into_iter()
236            .collect(),
237    }
238});
239
240fn valid_int(e: impl Into<Expr>, int_ty: rty::IntTy) -> rty::Expr {
241    let e1 = e.into();
242    let e2 = e1.clone();
243    E::and(E::ge(e1, E::int_min(int_ty)), E::le(e2, E::int_max(int_ty)))
244}
245
246fn valid_uint(e: impl Into<Expr>, uint_ty: rty::UintTy) -> rty::Expr {
247    let e1 = e.into();
248    let e2 = e1.clone();
249    E::and(E::ge(e1, 0), E::le(e2, E::uint_max(uint_ty)))
250}
251
252/// `a + b`
253fn mk_add_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
254    match overflow_mode {
255        OverflowMode::Strict => {
256            primop_rules! {
257                fn(a: T, b: T) -> T[a + b]
258                requires valid_int(a + b, int_ty) => ConstrReason::Overflow
259                if let &BaseTy::Int(int_ty) = T
260
261                fn(a: T, b: T) -> T[a + b]
262                requires valid_uint(a + b, uint_ty) => ConstrReason::Overflow
263                if let &BaseTy::Uint(uint_ty) = T
264
265                fn(a: T, b: T) -> T
266            }
267        }
268
269        OverflowMode::Lazy | OverflowMode::StrictUnder => {
270            primop_rules! {
271                fn(a: T, b: T) -> T{v: E::implies(valid_int(a + b, int_ty), E::eq(v, a+b)) }
272                if let &BaseTy::Int(int_ty) = T
273
274                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a + b, uint_ty), E::eq(v, a+b)) }
275                if let &BaseTy::Uint(uint_ty) = T
276
277                fn(a: T, b: T) -> T
278            }
279        }
280
281        OverflowMode::None => {
282            primop_rules! {
283                fn(a: T, b: T) -> T[a + b]
284                if T.is_integral()
285
286                fn(a: T, b: T) -> T
287            }
288        }
289    }
290}
291
292/// `a * b`
293fn mk_mul_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
294    match overflow_mode {
295        OverflowMode::Strict => {
296            primop_rules! {
297                fn(a: T, b: T) -> T[a * b]
298                requires valid_int(a * b, int_ty) => ConstrReason::Overflow
299                if let &BaseTy::Int(int_ty) = T
300
301                fn(a: T, b: T) -> T[a * b]
302                requires valid_uint(a * b, uint_ty) => ConstrReason::Overflow
303                if let &BaseTy::Uint(uint_ty) = T
304
305                fn(a: T, b: T) -> T
306            }
307        }
308
309        OverflowMode::Lazy | OverflowMode::StrictUnder => {
310            primop_rules! {
311                fn(a: T, b: T) -> T{v: E::implies(valid_int(a * b, int_ty), E::eq(v, a * b)) }
312                if let &BaseTy::Int(int_ty) = T
313
314                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a * b, uint_ty), E::eq(v, a * b)) }
315                if let &BaseTy::Uint(uint_ty) = T
316
317                fn(a: T, b: T) -> T
318            }
319        }
320
321        OverflowMode::None => {
322            primop_rules!(
323                fn(a: T, b: T) -> T[a * b]
324                if T.is_integral()
325
326                fn(a: T, b: T) -> T
327                if T.is_float()
328            )
329        }
330    }
331}
332
333/// `a - b`
334fn mk_sub_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
335    match overflow_mode {
336        OverflowMode::Strict => {
337            primop_rules! {
338                fn(a: T, b: T) -> T[a - b]
339                requires valid_int(a - b, int_ty) => ConstrReason::Overflow
340                if let &BaseTy::Int(int_ty) = T
341
342                fn(a: T, b: T) -> T[a - b]
343                requires valid_uint(a - b, uint_ty) => ConstrReason::Overflow
344                if let &BaseTy::Uint(uint_ty) = T
345
346                fn(a: T, b: T) -> T
347            }
348        }
349
350        // like Lazy, but we also check for underflow on unsigned subtraction
351        OverflowMode::StrictUnder => {
352            primop_rules! {
353                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
354                if let &BaseTy::Int(int_ty) = T
355
356                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
357                requires E::ge(a - b, 0) => ConstrReason::Underflow
358                if let &BaseTy::Uint(uint_ty) = T
359
360                fn(a: T, b: T) -> T
361            }
362        }
363
364        OverflowMode::Lazy => {
365            primop_rules! {
366                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
367                if let &BaseTy::Int(int_ty) = T
368
369                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
370                if let &BaseTy::Uint(uint_ty) = T
371
372                fn(a: T, b: T) -> T
373            }
374        }
375
376        OverflowMode::None => {
377            primop_rules! {
378                fn(a: T, b: T) -> T[a - b]
379                requires E::ge(a - b, 0) => ConstrReason::Underflow
380                if T.is_unsigned()
381
382                fn(a: T, b: T) -> T[a - b]
383                if T.is_signed()
384
385                fn(a: T, b: T) -> T
386                if T.is_float()
387            }
388        }
389    }
390}
391
392/// `a/b`
393fn mk_div_rules() -> RuleMatcher<2> {
394    primop_rules! {
395        fn(a: T, b: T) -> T[a/b]
396        requires E::ne(b, 0) => ConstrReason::Div
397        if T.is_integral()
398
399        fn(a: T, b: T) -> T
400        if T.is_float()
401    }
402}
403
/// `a % b`
///
/// Integer remainder requires a nonzero divisor (`ConstrReason::Rem`). Unsigned results are
/// indexed by `a mod b`. Signed results are tied to `a mod b` only when both operands are
/// non-negative. Presumably this is because Rust's `%` takes the sign of the dividend,
/// which differs from a mathematical `mod` on negative operands. Float remainders are
/// unrefined.
fn mk_rem_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
        requires E::ne(b, 0) => ConstrReason::Rem
        if T.is_unsigned()

        // Signed: only pin the result when the operands are non-negative.
        fn(a: T, b: T) -> T{v: E::implies(
                                   E::and(E::ge(a, 0), E::ge(b, 0)),
                                   E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
        requires E::ne(b, 0) => ConstrReason::Rem
        if T.is_signed()

        fn (a: T, b: T) -> T
        if T.is_float()
    }
}
421
/// `a & b`
///
/// Integer operands are indexed by `E::prim_val(BitAnd, a, b)`. The `| E::prim_rel(..)`
/// part presumably attaches that predicate over the operands as a constraint on the result
/// (confirm against `primop_rules!`). On booleans, `&` is logical conjunction.
fn mk_bit_and_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitAnd, a, b)] | E::prim_rel(rty::BinOp::BitAnd, a, b) }
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b)]
    }
}
431
/// `a | b`
///
/// Integer operands are indexed by `E::prim_val(BitOr, a, b)`. The `| E::prim_rel(..)`
/// part presumably attaches that predicate over the operands as a constraint on the result
/// (confirm against `primop_rules!`). On booleans, `|` is logical disjunction.
fn mk_bit_or_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitOr, a, b)] | E::prim_rel(rty::BinOp::BitOr, a, b) }
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::or(a, b)]
    }
}
441
442/// `a ^ b`
443fn mk_bit_xor_rules() -> RuleMatcher<2> {
444    primop_rules! {
445        fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitXor, a, b)] | E::prim_rel(rty::BinOp::BitXor, a, b) }
446        if T.is_integral()
447    }
448}
449
/// `a == b`
///
/// For integral, `bool`, `char` and `str` operands the result is indexed by `a == b`.
/// For every other base type the result is an unrefined `bool`.
fn mk_eq_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::eq(a, b)]
        if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()

        fn(a: T, b: T) -> bool
    }
}
459
460/// `a != b`
461fn mk_ne_rules() -> RuleMatcher<2> {
462    primop_rules! {
463        fn(a: T, b: T) -> bool[E::ne(a, b)]
464        if T.is_integral() || T.is_bool()
465
466        fn(a: T, b: T) -> bool
467    }
468}
469
/// `a <= b`
///
/// Integral operands are compared numerically. Booleans are ordered `false < true`, so
/// `a <= b` is the implication `a => b`. Other base types give an unrefined `bool`.
fn mk_le_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::le(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::implies(a, b)]

        fn(a: T, b: T) -> bool
    }
}
481
/// `a >= b`
///
/// Integral operands are compared numerically. Booleans are ordered `false < true`, so
/// `a >= b` is the implication `b => a`. Other base types give an unrefined `bool`.
fn mk_ge_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::ge(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::implies(b, a)]

        fn(a: T, b: T) -> bool
    }
}
493
/// `a < b`
///
/// Integral operands are compared numerically. Booleans are ordered `false < true`, so
/// `a < b` holds only for `a == false && b == true`. Other base types give an unrefined
/// `bool`.
fn mk_lt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::lt(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]

        fn(a: T, b: T) -> bool
    }
}
505
/// `a > b`
///
/// Integral operands are compared numerically. Booleans are ordered `false < true`, so
/// `a > b` holds only for `a == true && b == false`. Other base types give an unrefined
/// `bool`.
fn mk_gt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::gt(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b.not())]

        fn(a: T, b: T) -> bool
    }
}
517
/// `a << b`
///
/// The shift amount `b` may have a different integral type `S` than the shifted value `a`;
/// the result has `a`'s type `T`. It is indexed by `E::prim_val(BitShl, a, b)` and
/// constrained by `E::prim_rel(BitShl, a, b)`. There is no fallback rule.
fn mk_shl_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> { T[E::prim_val(rty::BinOp::BitShl, a, b)] | E::prim_rel(rty::BinOp::BitShl, a, b) }
        if T.is_integral() && S.is_integral()
    }
}
525
/// `a >> b`
///
/// The shift amount `b` may have a different integral type `S` than the shifted value `a`;
/// the result has `a`'s type `T`. It is indexed by `E::prim_val(BitShr, a, b)` and
/// constrained by `E::prim_rel(BitShr, a, b)`. There is no fallback rule.
fn mk_shr_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> { T[E::prim_val(rty::BinOp::BitShr, a, b)] | E::prim_rel(rty::BinOp::BitShr, a, b) }
        if T.is_integral() && S.is_integral()
    }
}
533
/// `-a`
///
/// Signed negation overflows only for `int_min(int_ty)`:
/// - `Strict`: the caller must prove `a != int_min` (`ConstrReason::Overflow`).
/// - `Lazy` / `StrictUnder`: the result equals `-a` only when `a != int_min`.
/// - `None`: integral results are indexed by `-a` unconditionally.
///
/// NOTE(review): floats are indexed by `-a` in `Strict` and `Lazy`/`StrictUnder`, but are
/// unrefined under `None`. Confirm which behavior is intended.
fn mk_neg_rules(overflow_mode: OverflowMode) -> RuleMatcher<1> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T) -> T[a.neg()]
                requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T) -> T[a.neg()]
                if T.is_float()
            }
        }
        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T) -> T{v: E::implies(E::ne(a, E::int_min(int_ty)), E::eq(v, a.neg())) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T) -> T[a.neg()]
                if T.is_float()
            }
        }
        OverflowMode::None => {
            primop_rules! {
                fn(a: T) -> T[a.neg()]
                if T.is_integral()

                fn(a: T) -> T
                if T.is_float()
            }
        }
    }
}
567
/// `!a`
///
/// On `bool`, `!` is logical negation and the result is indexed by `!a`. On integral
/// types (bitwise not) the result is unrefined.
fn mk_not_rules() -> RuleMatcher<1> {
    primop_rules! {
        fn(a: bool) -> bool[a.not()]

        fn(a: T) -> T
        if T.is_integral()
    }
}