flux_refineck/
primops.rs

//! Refinement rules for primitive operations.
//!
//! Flux needs to know how to reason about primitive operations on the different
//! [`BaseTy`]s. This is done by defining a set of rules for each operation.
//!
//! For example, how an equality check is refined depends on whether the
//! `BaseTy` is treated as refineable or opaque.
//!
//! ```ignore
//! // Make the rules for `a == b`.
//! fn mk_eq_rules() -> RuleMatcher<2> {
//!     primop_rules! {
//!         // If the `BaseTy` is refineable, then we can reason about equality.
//!         // The types listed in the `if` are refineable, and Flux uses the
//!         // refined postcondition (`bool[E::eq(a, b)]`) to reason about the
//!         // result of `==`.
//!         fn(a: T, b: T) -> bool[E::eq(a, b)]
//!         if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
//!
//!         // Otherwise the `BaseTy` is opaque and we can't reason about
//!         // equality: Flux only knows that the result is a boolean, but the
//!         // value itself is unrefined.
//!         fn(a: T, b: T) -> bool
//!     }
//! }
//! ```
26use std::{fmt, hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_config::OverflowMode;
30use flux_infer::infer::ConstrReason;
31use flux_macros::primop_rules;
32use flux_middle::rty::{self, BaseTy, Expr, Sort};
33use flux_rustc_bridge::mir;
34use rty::{
35    BinOp::{BitAnd, BitOr, BitShl, BitShr, BitXor, Mod},
36    Expr as E,
37};
38use rustc_data_structures::unord::UnordMap;
39
40#[derive(Debug)]
41pub(crate) struct MatchedRule {
42    pub precondition: Option<Pre>,
43    pub output_type: rty::Ty,
44}
45
46#[derive(Debug)]
47pub(crate) struct Pre {
48    pub reason: ConstrReason,
49    pub pred: Expr,
50}
51
52pub(crate) fn match_bin_op(
53    op: mir::BinOp,
54    bty1: &BaseTy,
55    idx1: &Expr,
56    bty2: &BaseTy,
57    idx2: &Expr,
58    overflow_mode: OverflowMode,
59) -> MatchedRule {
60    let table = match overflow_mode {
61        OverflowMode::Strict => &OVERFLOW_STRICT_BIN_OPS,
62        OverflowMode::Lazy => &OVERFLOW_LAZY_BIN_OPS,
63        OverflowMode::None => &OVERFLOW_NONE_BIN_OPS,
64        OverflowMode::StrictUnder => &OVERFLOW_STRICT_UNDER_BIN_OPS,
65    };
66    table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
67}
68
69pub(crate) fn match_un_op(
70    op: mir::UnOp,
71    bty: &BaseTy,
72    idx: &Expr,
73    overflow_mode: OverflowMode,
74) -> MatchedRule {
75    let table = match overflow_mode {
76        OverflowMode::Strict => &OVERFLOW_STRICT_UN_OPS,
77        OverflowMode::None => &OVERFLOW_NONE_UN_OPS,
78        OverflowMode::Lazy | OverflowMode::StrictUnder => &OVERFLOW_LAZY_UN_OPS,
79    };
80    table.match_inputs(&op, [(bty.clone(), idx.clone())])
81}
82
83struct RuleTable<Op: Eq + Hash, const N: usize> {
84    rules: UnordMap<Op, RuleMatcher<N>>,
85}
86
87impl<Op: Eq + Hash + fmt::Debug, const N: usize> RuleTable<Op, N> {
88    fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
89        (self.rules[op])(&inputs)
90            .unwrap_or_else(|| tracked_span_bug!("no primop rule for {op:?} using {inputs:?}"))
91    }
92}
93
94type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
95
96static OVERFLOW_NONE_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
97    use mir::BinOp::*;
98    RuleTable {
99        rules: [
100            // Arith
101            (Add, mk_add_rules(OverflowMode::None)),
102            (Mul, mk_mul_rules(OverflowMode::None)),
103            (Sub, mk_sub_rules(OverflowMode::None)),
104            (Div, mk_div_rules()),
105            (Rem, mk_rem_rules()),
106            // Bitwise
107            (BitAnd, mk_bit_and_rules()),
108            (BitOr, mk_bit_or_rules()),
109            (BitXor, mk_bit_xor_rules()),
110            // Cmp
111            (Eq, mk_eq_rules()),
112            (Ne, mk_ne_rules()),
113            (Le, mk_le_rules()),
114            (Ge, mk_ge_rules()),
115            (Lt, mk_lt_rules()),
116            (Gt, mk_gt_rules()),
117            // Shifts
118            (Shl, mk_shl_rules()),
119            (Shr, mk_shr_rules()),
120        ]
121        .into_iter()
122        .collect(),
123    }
124});
125
126static OVERFLOW_STRICT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
127    use mir::BinOp::*;
128    RuleTable {
129        rules: [
130            // Arith
131            (Add, mk_add_rules(OverflowMode::Strict)),
132            (Mul, mk_mul_rules(OverflowMode::Strict)),
133            (Sub, mk_sub_rules(OverflowMode::Strict)),
134            (Div, mk_div_rules()),
135            (Rem, mk_rem_rules()),
136            // Bitwise
137            (BitAnd, mk_bit_and_rules()),
138            (BitOr, mk_bit_or_rules()),
139            (BitXor, mk_bit_xor_rules()),
140            // Cmp
141            (Eq, mk_eq_rules()),
142            (Ne, mk_ne_rules()),
143            (Le, mk_le_rules()),
144            (Ge, mk_ge_rules()),
145            (Lt, mk_lt_rules()),
146            (Gt, mk_gt_rules()),
147            // Shifts
148            (Shl, mk_shl_rules()),
149            (Shr, mk_shr_rules()),
150        ]
151        .into_iter()
152        .collect(),
153    }
154});
155
156static OVERFLOW_LAZY_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
157    use mir::BinOp::*;
158    RuleTable {
159        rules: [
160            // Arith
161            (Add, mk_add_rules(OverflowMode::Lazy)),
162            (Mul, mk_mul_rules(OverflowMode::Lazy)),
163            (Sub, mk_sub_rules(OverflowMode::Lazy)),
164            (Div, mk_div_rules()),
165            (Rem, mk_rem_rules()),
166            // Bitwise
167            (BitAnd, mk_bit_and_rules()),
168            (BitOr, mk_bit_or_rules()),
169            (BitXor, mk_bit_xor_rules()),
170            // Cmp
171            (Eq, mk_eq_rules()),
172            (Ne, mk_ne_rules()),
173            (Le, mk_le_rules()),
174            (Ge, mk_ge_rules()),
175            (Lt, mk_lt_rules()),
176            (Gt, mk_gt_rules()),
177            // Shifts
178            (Shl, mk_shl_rules()),
179            (Shr, mk_shr_rules()),
180        ]
181        .into_iter()
182        .collect(),
183    }
184});
185
186static OVERFLOW_STRICT_UNDER_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
187    use mir::BinOp::*;
188    RuleTable {
189        rules: [
190            // Arith
191            (Add, mk_add_rules(OverflowMode::StrictUnder)),
192            (Mul, mk_mul_rules(OverflowMode::StrictUnder)),
193            (Sub, mk_sub_rules(OverflowMode::StrictUnder)),
194            (Div, mk_div_rules()),
195            (Rem, mk_rem_rules()),
196            // Bitwise
197            (BitAnd, mk_bit_and_rules()),
198            (BitOr, mk_bit_or_rules()),
199            (BitXor, mk_bit_xor_rules()),
200            // Cmp
201            (Eq, mk_eq_rules()),
202            (Ne, mk_ne_rules()),
203            (Le, mk_le_rules()),
204            (Ge, mk_ge_rules()),
205            (Lt, mk_lt_rules()),
206            (Gt, mk_gt_rules()),
207            // Shifts
208            (Shl, mk_shl_rules()),
209            (Shr, mk_shr_rules()),
210        ]
211        .into_iter()
212        .collect(),
213    }
214});
215
216static OVERFLOW_NONE_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
217    use mir::UnOp::*;
218    RuleTable {
219        rules: [(Neg, mk_neg_rules(OverflowMode::None)), (Not, mk_not_rules())]
220            .into_iter()
221            .collect(),
222    }
223});
224
225static OVERFLOW_LAZY_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
226    use mir::UnOp::*;
227    RuleTable {
228        rules: [(Neg, mk_neg_rules(OverflowMode::Lazy)), (Not, mk_not_rules())]
229            .into_iter()
230            .collect(),
231    }
232});
233
234static OVERFLOW_STRICT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
235    use mir::UnOp::*;
236    RuleTable {
237        rules: [(Neg, mk_neg_rules(OverflowMode::Strict)), (Not, mk_not_rules())]
238            .into_iter()
239            .collect(),
240    }
241});
242
243fn valid_int(e: impl Into<Expr>, int_ty: rty::IntTy) -> rty::Expr {
244    let e1 = e.into();
245    let e2 = e1.clone();
246    E::and(E::ge(e1, E::int_min(int_ty)), E::le(e2, E::int_max(int_ty)))
247}
248
249fn valid_uint(e: impl Into<Expr>, uint_ty: rty::UintTy) -> rty::Expr {
250    let e1 = e.into();
251    let e2 = e1.clone();
252    E::and(E::ge(e1, 0), E::le(e2, E::uint_max(uint_ty)))
253}
254
255/// `a + b`
256fn mk_add_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
257    match overflow_mode {
258        OverflowMode::Strict => {
259            primop_rules! {
260                fn(a: T, b: T) -> T[a + b]
261                requires valid_int(a + b, int_ty) => ConstrReason::Overflow
262                if let &BaseTy::Int(int_ty) = T
263
264                fn(a: T, b: T) -> T[a + b]
265                requires valid_uint(a + b, uint_ty) => ConstrReason::Overflow
266                if let &BaseTy::Uint(uint_ty) = T
267
268                fn(a: T, b: T) -> T
269            }
270        }
271
272        OverflowMode::Lazy | OverflowMode::StrictUnder => {
273            primop_rules! {
274                fn(a: T, b: T) -> T{v: E::implies(valid_int(a + b, int_ty), E::eq(v, a+b)) }
275                if let &BaseTy::Int(int_ty) = T
276
277                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a + b, uint_ty), E::eq(v, a+b)) }
278                if let &BaseTy::Uint(uint_ty) = T
279
280                fn(a: T, b: T) -> T
281            }
282        }
283
284        OverflowMode::None => {
285            primop_rules! {
286                fn(a: T, b: T) -> T[a + b]
287                if T.is_integral()
288
289                fn(a: T, b: T) -> T
290            }
291        }
292    }
293}
294
295/// `a * b`
296fn mk_mul_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
297    match overflow_mode {
298        OverflowMode::Strict => {
299            primop_rules! {
300                fn(a: T, b: T) -> T[a * b]
301                requires valid_int(a * b, int_ty) => ConstrReason::Overflow
302                if let &BaseTy::Int(int_ty) = T
303
304                fn(a: T, b: T) -> T[a * b]
305                requires valid_uint(a * b, uint_ty) => ConstrReason::Overflow
306                if let &BaseTy::Uint(uint_ty) = T
307
308                fn(a: T, b: T) -> T
309            }
310        }
311
312        OverflowMode::Lazy | OverflowMode::StrictUnder => {
313            primop_rules! {
314                fn(a: T, b: T) -> T{v: E::implies(valid_int(a * b, int_ty), E::eq(v, a * b)) }
315                if let &BaseTy::Int(int_ty) = T
316
317                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a * b, uint_ty), E::eq(v, a * b)) }
318                if let &BaseTy::Uint(uint_ty) = T
319
320                fn(a: T, b: T) -> T
321            }
322        }
323
324        OverflowMode::None => {
325            primop_rules!(
326                fn(a: T, b: T) -> T[a * b]
327                if T.is_integral()
328
329                fn(a: T, b: T) -> T
330                if T.is_float()
331            )
332        }
333    }
334}
335
336/// `a - b`
337fn mk_sub_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
338    match overflow_mode {
339        OverflowMode::Strict => {
340            primop_rules! {
341                fn(a: T, b: T) -> T[a - b]
342                requires valid_int(a - b, int_ty) => ConstrReason::Overflow
343                if let &BaseTy::Int(int_ty) = T
344
345                fn(a: T, b: T) -> T[a - b]
346                requires valid_uint(a - b, uint_ty) => ConstrReason::Overflow
347                if let &BaseTy::Uint(uint_ty) = T
348
349                fn(a: T, b: T) -> T
350            }
351        }
352
353        // like Lazy, but we also check for underflow on unsigned subtraction
354        OverflowMode::StrictUnder => {
355            primop_rules! {
356                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
357                if let &BaseTy::Int(int_ty) = T
358
359                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
360                requires E::ge(a - b, 0) => ConstrReason::Underflow
361                if let &BaseTy::Uint(uint_ty) = T
362
363                fn(a: T, b: T) -> T
364            }
365        }
366
367        OverflowMode::Lazy => {
368            primop_rules! {
369                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
370                if let &BaseTy::Int(int_ty) = T
371
372                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
373                if let &BaseTy::Uint(uint_ty) = T
374
375                fn(a: T, b: T) -> T
376            }
377        }
378
379        OverflowMode::None => {
380            primop_rules! {
381                fn(a: T, b: T) -> T[a - b]
382                requires E::ge(a - b, 0) => ConstrReason::Underflow
383                if T.is_unsigned()
384
385                fn(a: T, b: T) -> T[a - b]
386                if T.is_signed()
387
388                fn(a: T, b: T) -> T
389                if T.is_float()
390            }
391        }
392    }
393}
394
395/// `a/b`
396fn mk_div_rules() -> RuleMatcher<2> {
397    primop_rules! {
398        fn(a: T, b: T) -> T[a/b]
399        requires E::ne(b, 0) => ConstrReason::Div
400        if T.is_integral()
401
402        fn(a: T, b: T) -> T
403        if T.is_float()
404    }
405}
406
407/// `a % b`
408fn mk_rem_rules() -> RuleMatcher<2> {
409    primop_rules! {
410        fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
411        requires E::ne(b, 0) => ConstrReason::Rem
412        if T.is_unsigned()
413
414        fn(a: T, b: T) -> T{v: E::implies(
415                                   E::and(E::ge(a, 0), E::ge(b, 0)),
416                                   E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
417        requires E::ne(b, 0) => ConstrReason::Rem
418        if T.is_signed()
419
420        fn (a: T, b: T) -> T
421        if T.is_float()
422    }
423}
424
425/// `a & b`
426fn mk_bit_and_rules() -> RuleMatcher<2> {
427    primop_rules! {
428        fn(a: T, b: T) -> { T[E::prim_val(BitAnd(Sort::Int), a, b)] | E::prim_rel(BitAnd(Sort::Int), a, b) }
429        if T.is_integral()
430
431        fn(a: bool, b: bool) -> bool[E::and(a, b)]
432    }
433}
434
435/// `a | b`
436fn mk_bit_or_rules() -> RuleMatcher<2> {
437    primop_rules! {
438        fn(a: T, b: T) -> { T[E::prim_val(BitOr(Sort::Int), a, b)] | E::prim_rel(BitOr(Sort::Int), a, b) }
439        if T.is_integral()
440
441        fn(a: bool, b: bool) -> bool[E::or(a, b)]
442    }
443}
444
445/// `a ^ b`
446fn mk_bit_xor_rules() -> RuleMatcher<2> {
447    primop_rules! {
448        fn(a: T, b: T) -> { T[E::prim_val(BitXor(Sort::Int), a, b)] | E::prim_rel(BitXor(Sort::Int), a, b) }
449        if T.is_integral()
450    }
451}
452
453// the (a: T, b: S) case ensures the last case is "unreachable", hence the "allow"
454#[allow(unreachable_code)]
455/// `a == b`
456fn mk_eq_rules() -> RuleMatcher<2> {
457    primop_rules! {
458        fn(a: T, b: T) -> bool[E::eq(a, b)]
459        if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
460        fn(a: T, b: S) -> bool
461    }
462}
463
464#[allow(unreachable_code)]
465/// `a != b`
466fn mk_ne_rules() -> RuleMatcher<2> {
467    primop_rules! {
468        fn(a: T, b: T) -> bool[E::ne(a, b)]
469        if T.is_integral() || T.is_bool()
470
471        fn(a: T, b: S) -> bool
472    }
473}
474
475#[allow(unreachable_code)]
476/// `a <= b`
477fn mk_le_rules() -> RuleMatcher<2> {
478    primop_rules! {
479        fn(a: T, b: T) -> bool[E::le(a, b)]
480        if T.is_integral()
481
482        fn(a: bool, b: bool) -> bool[E::implies(a, b)]
483
484        fn(a: T, b: S) -> bool
485    }
486}
487
488#[allow(unreachable_code)]
489/// `a >= b`
490fn mk_ge_rules() -> RuleMatcher<2> {
491    primop_rules! {
492        fn(a: T, b: T) -> bool[E::ge(a, b)]
493        if T.is_integral()
494
495        fn(a: bool, b: bool) -> bool[E::implies(b, a)]
496
497        fn(a: T, b: S) -> bool
498    }
499}
500
501#[allow(unreachable_code)]
502/// `a < b`
503fn mk_lt_rules() -> RuleMatcher<2> {
504    primop_rules! {
505        fn(a: T, b: T) -> bool[E::lt(a, b)]
506        if T.is_integral()
507
508        fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]
509
510        fn(a: T, b: S) -> bool
511    }
512}
513
514#[allow(unreachable_code)]
515/// `a > b`
516fn mk_gt_rules() -> RuleMatcher<2> {
517    primop_rules! {
518        fn(a: T, b: T) -> bool[E::gt(a, b)]
519        if T.is_integral()
520
521        fn(a: bool, b: bool) -> bool[E::and(a, b.not())]
522
523        fn(a: T, b: S) -> bool
524    }
525}
526
527/// `a << b`
528fn mk_shl_rules() -> RuleMatcher<2> {
529    primop_rules! {
530        fn(a: T, b: S) -> { T[E::prim_val(BitShl(Sort::Int), a, b)] | E::prim_rel(BitShl(Sort::Int), a, b) }
531        if T.is_integral() && S.is_integral()
532    }
533}
534
535/// `a >> b`
536fn mk_shr_rules() -> RuleMatcher<2> {
537    primop_rules! {
538        fn(a: T, b: S) -> { T[E::prim_val(BitShr(Sort::Int), a, b)] | E::prim_rel(BitShr(Sort::Int), a, b) }
539        if T.is_integral() && S.is_integral()
540    }
541}
542
543/// `-a`
544fn mk_neg_rules(overflow_mode: OverflowMode) -> RuleMatcher<1> {
545    match overflow_mode {
546        OverflowMode::Strict => {
547            primop_rules! {
548                fn(a: T) -> T[a.neg()]
549                requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
550                if let &BaseTy::Int(int_ty) = T
551
552                fn(a: T) -> T[a.neg()]
553                if T.is_float()
554            }
555        }
556        OverflowMode::Lazy | OverflowMode::StrictUnder => {
557            primop_rules! {
558                fn(a: T) -> T{v: E::implies(E::ne(a, E::int_min(int_ty)), E::eq(v, a.neg())) }
559                if let &BaseTy::Int(int_ty) = T
560
561                fn(a: T) -> T[a.neg()]
562                if T.is_float()
563            }
564        }
565        OverflowMode::None => {
566            primop_rules! {
567                fn(a: T) -> T[a.neg()]
568                if T.is_integral()
569
570                fn(a: T) -> T
571                if T.is_float()
572            }
573        }
574    }
575}
576
577/// `!a`
578fn mk_not_rules() -> RuleMatcher<1> {
579    primop_rules! {
580        fn(a: bool) -> bool[a.not()]
581
582        fn(a: T) -> T
583        if T.is_integral()
584    }
585}