flux_refineck/
primops.rs

//! This file defines the refinement rules for primitive operations.
//! Flux needs to define how to reason about primitive operations on different
//! [`BaseTy`]s. This is done by defining a set of rules for each operation.
//!
//! For example, equality checks depend on whether the `BaseTy` is treated as
//! refinable or opaque.
//!
//! ```ignore
//! // Make the rules for `a == b`.
//! fn mk_eq_rules() -> RuleMatcher<2> {
//!     primop_rules! {
//!         // If the `BaseTy` is refinable, then we can reason about equality.
//!         // The types accepted by the `if` guard are refinable, so Flux uses
//!         // the refined postcondition (`bool[E::eq(a, b)]`) to reason about
//!         // the result of `==`.
//!         fn(a: T, b: T) -> bool[E::eq(a, b)]
//!         if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
//!
//!         // Otherwise the `BaseTy` is opaque and we can't reason about
//!         // equality: Flux only knows that the result is a `bool`, but the
//!         // returned value is unrefined.
//!         fn(a: T, b: T) -> bool
//!     }
//! }
//! ```
26use std::{fmt, hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_config::OverflowMode;
30use flux_infer::infer::ConstrReason;
31use flux_macros::primop_rules;
32use flux_middle::rty::{self, BaseTy, Expr, Sort};
33use flux_rustc_bridge::mir;
34use rty::{
35    BinOp::{BitAnd, BitOr, BitShl, BitShr, BitXor, Mod},
36    Expr as E,
37};
38use rustc_data_structures::unord::UnordMap;
39
40#[derive(Debug)]
41pub(crate) struct MatchedRule {
42    pub precondition: Option<Pre>,
43    pub output_type: rty::Ty,
44}
45
46#[derive(Debug)]
47pub(crate) struct Pre {
48    pub reason: ConstrReason,
49    pub pred: Expr,
50}
51
52pub(crate) fn match_bin_op(
53    op: mir::BinOp,
54    bty1: &BaseTy,
55    idx1: &Expr,
56    bty2: &BaseTy,
57    idx2: &Expr,
58    overflow_mode: OverflowMode,
59) -> MatchedRule {
60    let table = match overflow_mode {
61        OverflowMode::Strict => &OVERFLOW_STRICT_BIN_OPS,
62        OverflowMode::Lazy => &OVERFLOW_LAZY_BIN_OPS,
63        OverflowMode::None => &OVERFLOW_NONE_BIN_OPS,
64        OverflowMode::StrictUnder => &OVERFLOW_STRICT_UNDER_BIN_OPS,
65    };
66    table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
67}
68
69pub(crate) fn match_un_op(
70    op: mir::UnOp,
71    bty: &BaseTy,
72    idx: &Expr,
73    overflow_mode: OverflowMode,
74) -> MatchedRule {
75    let table = match overflow_mode {
76        OverflowMode::Strict => &OVERFLOW_STRICT_UN_OPS,
77        OverflowMode::None => &OVERFLOW_NONE_UN_OPS,
78        OverflowMode::Lazy | OverflowMode::StrictUnder => &OVERFLOW_LAZY_UN_OPS,
79    };
80    table.match_inputs(&op, [(bty.clone(), idx.clone())])
81}
82
83struct RuleTable<Op: Eq + Hash, const N: usize> {
84    rules: UnordMap<Op, RuleMatcher<N>>,
85}
86
87impl<Op: Eq + Hash + fmt::Debug, const N: usize> RuleTable<Op, N> {
88    fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
89        (self.rules[op])(&inputs)
90            .unwrap_or_else(|| tracked_span_bug!("no primop rule for {op:?} using {inputs:?}"))
91    }
92}
93
94type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
95
96static OVERFLOW_NONE_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
97    use mir::BinOp::*;
98    RuleTable {
99        rules: [
100            // Arith
101            (Add, mk_add_rules(OverflowMode::None)),
102            (Mul, mk_mul_rules(OverflowMode::None)),
103            (Sub, mk_sub_rules(OverflowMode::None)),
104            (Div, mk_div_rules()),
105            (Rem, mk_rem_rules()),
106            // Bitwise
107            (BitAnd, mk_bit_and_rules()),
108            (BitOr, mk_bit_or_rules()),
109            (BitXor, mk_bit_xor_rules()),
110            // Cmp
111            (Eq, mk_eq_rules()),
112            (Ne, mk_ne_rules()),
113            (Le, mk_le_rules()),
114            (Ge, mk_ge_rules()),
115            (Lt, mk_lt_rules()),
116            (Gt, mk_gt_rules()),
117            // Shifts
118            (Shl, mk_shl_rules()),
119            (Shr, mk_shr_rules()),
120        ]
121        .into_iter()
122        .collect(),
123    }
124});
125
126static OVERFLOW_STRICT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
127    use mir::BinOp::*;
128    RuleTable {
129        rules: [
130            // Arith
131            (Add, mk_add_rules(OverflowMode::Strict)),
132            (Mul, mk_mul_rules(OverflowMode::Strict)),
133            (Sub, mk_sub_rules(OverflowMode::Strict)),
134            (Div, mk_div_rules()),
135            (Rem, mk_rem_rules()),
136            // Bitwise
137            (BitAnd, mk_bit_and_rules()),
138            (BitOr, mk_bit_or_rules()),
139            (BitXor, mk_bit_xor_rules()),
140            // Cmp
141            (Eq, mk_eq_rules()),
142            (Ne, mk_ne_rules()),
143            (Le, mk_le_rules()),
144            (Ge, mk_ge_rules()),
145            (Lt, mk_lt_rules()),
146            (Gt, mk_gt_rules()),
147            // Shifts
148            (Shl, mk_shl_rules()),
149            (Shr, mk_shr_rules()),
150        ]
151        .into_iter()
152        .collect(),
153    }
154});
155
156static OVERFLOW_LAZY_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
157    use mir::BinOp::*;
158    RuleTable {
159        rules: [
160            // Arith
161            (Add, mk_add_rules(OverflowMode::Lazy)),
162            (Mul, mk_mul_rules(OverflowMode::Lazy)),
163            (Sub, mk_sub_rules(OverflowMode::Lazy)),
164            (Div, mk_div_rules()),
165            (Rem, mk_rem_rules()),
166            // Bitwise
167            (BitAnd, mk_bit_and_rules()),
168            (BitOr, mk_bit_or_rules()),
169            (BitXor, mk_bit_xor_rules()),
170            // Cmp
171            (Eq, mk_eq_rules()),
172            (Ne, mk_ne_rules()),
173            (Le, mk_le_rules()),
174            (Ge, mk_ge_rules()),
175            (Lt, mk_lt_rules()),
176            (Gt, mk_gt_rules()),
177            // Shifts
178            (Shl, mk_shl_rules()),
179            (Shr, mk_shr_rules()),
180        ]
181        .into_iter()
182        .collect(),
183    }
184});
185
186static OVERFLOW_STRICT_UNDER_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
187    use mir::BinOp::*;
188    RuleTable {
189        rules: [
190            // Arith
191            (Add, mk_add_rules(OverflowMode::StrictUnder)),
192            (Mul, mk_mul_rules(OverflowMode::StrictUnder)),
193            (Sub, mk_sub_rules(OverflowMode::StrictUnder)),
194            (Div, mk_div_rules()),
195            (Rem, mk_rem_rules()),
196            // Bitwise
197            (BitAnd, mk_bit_and_rules()),
198            (BitOr, mk_bit_or_rules()),
199            (BitXor, mk_bit_xor_rules()),
200            // Cmp
201            (Eq, mk_eq_rules()),
202            (Ne, mk_ne_rules()),
203            (Le, mk_le_rules()),
204            (Ge, mk_ge_rules()),
205            (Lt, mk_lt_rules()),
206            (Gt, mk_gt_rules()),
207            // Shifts
208            (Shl, mk_shl_rules()),
209            (Shr, mk_shr_rules()),
210        ]
211        .into_iter()
212        .collect(),
213    }
214});
215
216static OVERFLOW_NONE_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
217    use mir::UnOp::*;
218    RuleTable {
219        rules: [(Neg, mk_neg_rules(OverflowMode::None)), (Not, mk_not_rules())]
220            .into_iter()
221            .collect(),
222    }
223});
224
225static OVERFLOW_LAZY_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
226    use mir::UnOp::*;
227    RuleTable {
228        rules: [(Neg, mk_neg_rules(OverflowMode::Lazy)), (Not, mk_not_rules())]
229            .into_iter()
230            .collect(),
231    }
232});
233
234static OVERFLOW_STRICT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
235    use mir::UnOp::*;
236    RuleTable {
237        rules: [(Neg, mk_neg_rules(OverflowMode::Strict)), (Not, mk_not_rules())]
238            .into_iter()
239            .collect(),
240    }
241});
242
243fn valid_int(e: impl Into<Expr>, int_ty: rty::IntTy) -> rty::Expr {
244    let e1 = e.into();
245    let e2 = e1.clone();
246    E::and(E::ge(e1, E::int_min(int_ty)), E::le(e2, E::int_max(int_ty)))
247}
248
249fn valid_uint(e: impl Into<Expr>, uint_ty: rty::UintTy) -> rty::Expr {
250    let e1 = e.into();
251    let e2 = e1.clone();
252    E::and(E::ge(e1, 0), E::le(e2, E::uint_max(uint_ty)))
253}
254
/// `a + b`
///
/// Rules for addition, selected by `overflow_mode`:
/// - `Strict`: signed and unsigned results are indexed by `a + b`, and the sum
///   must lie within the type's range (`ConstrReason::Overflow`).
/// - `Lazy` / `StrictUnder`: no precondition; the result equals `a + b` only
///   when the sum is within the type's range, and is otherwise unconstrained.
/// - `None`: integral results are indexed by `a + b` with no range check.
///
/// In every mode, the final unguarded rule gives any other `T` (e.g. floats)
/// an unrefined `T`.
fn mk_add_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a + b]
                requires valid_int(a + b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a + b]
                requires valid_uint(a + b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a + b, int_ty), E::eq(v, a+b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a + b, uint_ty), E::eq(v, a+b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules! {
                fn(a: T, b: T) -> T[a + b]
                if T.is_integral()

                fn(a: T, b: T) -> T
            }
        }
    }
}
294
/// `a * b`
///
/// Rules for multiplication, selected by `overflow_mode`:
/// - `Strict`: signed and unsigned results are indexed by `a * b`, and the
///   product must lie within the type's range (`ConstrReason::Overflow`).
/// - `Lazy` / `StrictUnder`: no precondition; the result equals `a * b` only
///   when the product is within the type's range.
/// - `None`: integral results are indexed by `a * b`; floats are unrefined.
///   Unlike `mk_add_rules`, this arm has no unguarded catch-all rule.
///
/// In `Strict` and `Lazy` / `StrictUnder`, any other `T` gets an unrefined `T`.
fn mk_mul_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a * b]
                requires valid_int(a * b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a * b]
                requires valid_uint(a * b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a * b, int_ty), E::eq(v, a * b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a * b, uint_ty), E::eq(v, a * b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules!(
                fn(a: T, b: T) -> T[a * b]
                if T.is_integral()

                fn(a: T, b: T) -> T
                if T.is_float()
            )
        }
    }
}
335
/// `a - b`
///
/// Rules for subtraction, selected by `overflow_mode`:
/// - `Strict`: signed and unsigned results are indexed by `a - b`, and the
///   difference must lie within the type's range (`ConstrReason::Overflow`).
/// - `StrictUnder`: like `Lazy`, but unsigned subtraction additionally requires
///   `a - b >= 0` (`ConstrReason::Underflow`).
/// - `Lazy`: no precondition; the result equals `a - b` only when the
///   difference is within the type's range.
/// - `None`: unsigned subtraction is indexed by `a - b` and requires
///   `a - b >= 0` (`ConstrReason::Underflow`); signed subtraction is indexed by
///   `a - b` with no check; floats are unrefined.
///
/// In the first three modes, any other `T` gets an unrefined `T`.
fn mk_sub_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a - b]
                requires valid_int(a - b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a - b]
                requires valid_uint(a - b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        // like Lazy, but we also check for underflow on unsigned subtraction
        OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
                requires E::ge(a - b, 0) => ConstrReason::Underflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules! {
                fn(a: T, b: T) -> T[a - b]
                requires E::ge(a - b, 0) => ConstrReason::Underflow
                if T.is_unsigned()

                fn(a: T, b: T) -> T[a - b]
                if T.is_signed()

                fn(a: T, b: T) -> T
                if T.is_float()
            }
        }
    }
}
394
/// `a/b`
///
/// Integral division is indexed by `a / b` and requires a non-zero divisor
/// (`ConstrReason::Div`); float division yields an unrefined `T`. The rules are
/// the same in every overflow mode.
///
/// NOTE(review): no rule here checks the signed `MIN / -1` overflow case — confirm
/// whether that is handled elsewhere.
fn mk_div_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> T[a/b]
        requires E::ne(b, 0) => ConstrReason::Div
        if T.is_integral()

        fn(a: T, b: T) -> T
        if T.is_float()
    }
}
406
/// `a % b`
///
/// - Unsigned: indexed by `Mod(Sort::Int)` of `a` and `b`; requires `b != 0`
///   (`ConstrReason::Rem`).
/// - Signed: requires `b != 0`; the result equals `a mod b` only when both
///   operands are non-negative, and is otherwise unconstrained.
/// - Floats: unrefined `T`.
///
/// NOTE(review): the non-negativity guard presumably exists because the
/// logic's `Mod` need not agree with Rust's `%` on negative operands — confirm.
fn mk_rem_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
        requires E::ne(b, 0) => ConstrReason::Rem
        if T.is_unsigned()

        fn(a: T, b: T) -> T{v: E::implies(
                                   E::and(E::ge(a, 0), E::ge(b, 0)),
                                   E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
        requires E::ne(b, 0) => ConstrReason::Rem
        if T.is_signed()

        fn (a: T, b: T) -> T
        if T.is_float()
    }
}
424
/// `a & b`
///
/// Integral results are indexed by `E::prim_val(BitAnd, a, b)` and constrained
/// by `E::prim_rel(BitAnd, a, b)`; on `bool` operands `&` is logical conjunction.
fn mk_bit_and_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> { T[E::prim_val(BitAnd(Sort::Int), a, b)] | E::prim_rel(BitAnd(Sort::Int), a, b) }
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b)]
    }
}
434
/// `a | b`
///
/// Integral results are indexed by `E::prim_val(BitOr, a, b)` and constrained
/// by `E::prim_rel(BitOr, a, b)`; on `bool` operands `|` is logical disjunction.
fn mk_bit_or_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> { T[E::prim_val(BitOr(Sort::Int), a, b)] | E::prim_rel(BitOr(Sort::Int), a, b) }
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::or(a, b)]
    }
}
444
445/// `a ^ b`
446fn mk_bit_xor_rules() -> RuleMatcher<2> {
447    primop_rules! {
448        fn(a: T, b: T) -> { T[E::prim_val(BitXor(Sort::Int), a, b)] | E::prim_rel(BitXor(Sort::Int), a, b) }
449        if T.is_integral()
450    }
451}
452
/// `a == b`
///
/// For integral, `bool`, `char` and `str` operands the result is refined to
/// `a == b`; for any other base type it is an unrefined `bool`.
fn mk_eq_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::eq(a, b)]
        if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()

        fn(a: T, b: T) -> bool
    }
}
462
463/// `a != b`
464fn mk_ne_rules() -> RuleMatcher<2> {
465    primop_rules! {
466        fn(a: T, b: T) -> bool[E::ne(a, b)]
467        if T.is_integral() || T.is_bool()
468
469        fn(a: T, b: T) -> bool
470    }
471}
472
/// `a <= b`
///
/// Integral comparisons are refined to `a <= b`. For `bool`, `false < true`, so
/// `a <= b` holds exactly when `a` implies `b`. Other types give an unrefined `bool`.
fn mk_le_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::le(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::implies(a, b)]

        fn(a: T, b: T) -> bool
    }
}
484
/// `a >= b`
///
/// Integral comparisons are refined to `a >= b`. For `bool`, `a >= b` holds
/// exactly when `b` implies `a`. Other types give an unrefined `bool`.
fn mk_ge_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::ge(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::implies(b, a)]

        fn(a: T, b: T) -> bool
    }
}
496
/// `a < b`
///
/// Integral comparisons are refined to `a < b`. For `bool`, `a < b` holds only
/// for `false < true`, i.e. `!a && b`. Other types give an unrefined `bool`.
fn mk_lt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::lt(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]

        fn(a: T, b: T) -> bool
    }
}
508
/// `a > b`
///
/// Integral comparisons are refined to `a > b`. For `bool`, `a > b` holds only
/// for `true > false`, i.e. `a && !b`. Other types give an unrefined `bool`.
fn mk_gt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::gt(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b.not())]

        fn(a: T, b: T) -> bool
    }
}
520
/// `a << b`
///
/// The shift amount may have a different integral type `S` than the shifted
/// value `T`. The result has type `T`, is indexed by `E::prim_val(BitShl, a, b)`
/// and is constrained by `E::prim_rel(BitShl, a, b)`.
fn mk_shl_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> { T[E::prim_val(BitShl(Sort::Int), a, b)] | E::prim_rel(BitShl(Sort::Int), a, b) }
        if T.is_integral() && S.is_integral()
    }
}
528
/// `a >> b`
///
/// The shift amount may have a different integral type `S` than the shifted
/// value `T`. The result has type `T`, is indexed by `E::prim_val(BitShr, a, b)`
/// and is constrained by `E::prim_rel(BitShr, a, b)`.
fn mk_shr_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> { T[E::prim_val(BitShr(Sort::Int), a, b)] | E::prim_rel(BitShr(Sort::Int), a, b) }
        if T.is_integral() && S.is_integral()
    }
}
536
537/// `-a`
538fn mk_neg_rules(overflow_mode: OverflowMode) -> RuleMatcher<1> {
539    match overflow_mode {
540        OverflowMode::Strict => {
541            primop_rules! {
542                fn(a: T) -> T[a.neg()]
543                requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
544                if let &BaseTy::Int(int_ty) = T
545
546                fn(a: T) -> T[a.neg()]
547                if T.is_float()
548            }
549        }
550        OverflowMode::Lazy | OverflowMode::StrictUnder => {
551            primop_rules! {
552                fn(a: T) -> T{v: E::implies(E::ne(a, E::int_min(int_ty)), E::eq(v, a.neg())) }
553                if let &BaseTy::Int(int_ty) = T
554
555                fn(a: T) -> T[a.neg()]
556                if T.is_float()
557            }
558        }
559        OverflowMode::None => {
560            primop_rules! {
561                fn(a: T) -> T[a.neg()]
562                if T.is_integral()
563
564                fn(a: T) -> T
565                if T.is_float()
566            }
567        }
568    }
569}
570
/// `!a`
///
/// On `bool` the result is refined to the logical negation of `a`; bitwise `!`
/// on integral types yields an unrefined `T`. The rules are the same in every
/// overflow mode.
fn mk_not_rules() -> RuleMatcher<1> {
    primop_rules! {
        fn(a: bool) -> bool[a.not()]

        fn(a: T) -> T
        if T.is_integral()
    }
}