1use std::{fmt, hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_config::OverflowMode;
30use flux_infer::infer::ConstrReason;
31use flux_macros::primop_rules;
32use flux_middle::rty::{self, BaseTy, Expr, Sort};
33use flux_rustc_bridge::mir;
34use rty::{
35 BinOp::{BitAnd, BitOr, BitShl, BitShr, BitXor, Mod},
36 Expr as E,
37};
38use rustc_data_structures::unord::UnordMap;
39
40#[derive(Debug)]
41pub(crate) struct MatchedRule {
42 pub precondition: Option<Pre>,
43 pub output_type: rty::Ty,
44}
45
46#[derive(Debug)]
47pub(crate) struct Pre {
48 pub reason: ConstrReason,
49 pub pred: Expr,
50}
51
52pub(crate) fn match_bin_op(
53 op: mir::BinOp,
54 bty1: &BaseTy,
55 idx1: &Expr,
56 bty2: &BaseTy,
57 idx2: &Expr,
58 overflow_mode: OverflowMode,
59) -> MatchedRule {
60 let table = match overflow_mode {
61 OverflowMode::Strict => &OVERFLOW_STRICT_BIN_OPS,
62 OverflowMode::Lazy => &OVERFLOW_LAZY_BIN_OPS,
63 OverflowMode::None => &OVERFLOW_NONE_BIN_OPS,
64 OverflowMode::StrictUnder => &OVERFLOW_STRICT_UNDER_BIN_OPS,
65 };
66 table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
67}
68
69pub(crate) fn match_un_op(
70 op: mir::UnOp,
71 bty: &BaseTy,
72 idx: &Expr,
73 overflow_mode: OverflowMode,
74) -> MatchedRule {
75 let table = match overflow_mode {
76 OverflowMode::Strict => &OVERFLOW_STRICT_UN_OPS,
77 OverflowMode::None => &OVERFLOW_NONE_UN_OPS,
78 OverflowMode::Lazy | OverflowMode::StrictUnder => &OVERFLOW_LAZY_UN_OPS,
79 };
80 table.match_inputs(&op, [(bty.clone(), idx.clone())])
81}
82
83struct RuleTable<Op: Eq + Hash, const N: usize> {
84 rules: UnordMap<Op, RuleMatcher<N>>,
85}
86
87impl<Op: Eq + Hash + fmt::Debug, const N: usize> RuleTable<Op, N> {
88 fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
89 (self.rules[op])(&inputs)
90 .unwrap_or_else(|| tracked_span_bug!("no primop rule for {op:?} using {inputs:?}"))
91 }
92}
93
94type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
95
96static OVERFLOW_NONE_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
97 use mir::BinOp::*;
98 RuleTable {
99 rules: [
100 (Add, mk_add_rules(OverflowMode::None)),
102 (Mul, mk_mul_rules(OverflowMode::None)),
103 (Sub, mk_sub_rules(OverflowMode::None)),
104 (Div, mk_div_rules()),
105 (Rem, mk_rem_rules()),
106 (BitAnd, mk_bit_and_rules()),
108 (BitOr, mk_bit_or_rules()),
109 (BitXor, mk_bit_xor_rules()),
110 (Eq, mk_eq_rules()),
112 (Ne, mk_ne_rules()),
113 (Le, mk_le_rules()),
114 (Ge, mk_ge_rules()),
115 (Lt, mk_lt_rules()),
116 (Gt, mk_gt_rules()),
117 (Shl, mk_shl_rules()),
119 (Shr, mk_shr_rules()),
120 ]
121 .into_iter()
122 .collect(),
123 }
124});
125
126static OVERFLOW_STRICT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
127 use mir::BinOp::*;
128 RuleTable {
129 rules: [
130 (Add, mk_add_rules(OverflowMode::Strict)),
132 (Mul, mk_mul_rules(OverflowMode::Strict)),
133 (Sub, mk_sub_rules(OverflowMode::Strict)),
134 (Div, mk_div_rules()),
135 (Rem, mk_rem_rules()),
136 (BitAnd, mk_bit_and_rules()),
138 (BitOr, mk_bit_or_rules()),
139 (BitXor, mk_bit_xor_rules()),
140 (Eq, mk_eq_rules()),
142 (Ne, mk_ne_rules()),
143 (Le, mk_le_rules()),
144 (Ge, mk_ge_rules()),
145 (Lt, mk_lt_rules()),
146 (Gt, mk_gt_rules()),
147 (Shl, mk_shl_rules()),
149 (Shr, mk_shr_rules()),
150 ]
151 .into_iter()
152 .collect(),
153 }
154});
155
156static OVERFLOW_LAZY_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
157 use mir::BinOp::*;
158 RuleTable {
159 rules: [
160 (Add, mk_add_rules(OverflowMode::Lazy)),
162 (Mul, mk_mul_rules(OverflowMode::Lazy)),
163 (Sub, mk_sub_rules(OverflowMode::Lazy)),
164 (Div, mk_div_rules()),
165 (Rem, mk_rem_rules()),
166 (BitAnd, mk_bit_and_rules()),
168 (BitOr, mk_bit_or_rules()),
169 (BitXor, mk_bit_xor_rules()),
170 (Eq, mk_eq_rules()),
172 (Ne, mk_ne_rules()),
173 (Le, mk_le_rules()),
174 (Ge, mk_ge_rules()),
175 (Lt, mk_lt_rules()),
176 (Gt, mk_gt_rules()),
177 (Shl, mk_shl_rules()),
179 (Shr, mk_shr_rules()),
180 ]
181 .into_iter()
182 .collect(),
183 }
184});
185
186static OVERFLOW_STRICT_UNDER_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
187 use mir::BinOp::*;
188 RuleTable {
189 rules: [
190 (Add, mk_add_rules(OverflowMode::StrictUnder)),
192 (Mul, mk_mul_rules(OverflowMode::StrictUnder)),
193 (Sub, mk_sub_rules(OverflowMode::StrictUnder)),
194 (Div, mk_div_rules()),
195 (Rem, mk_rem_rules()),
196 (BitAnd, mk_bit_and_rules()),
198 (BitOr, mk_bit_or_rules()),
199 (BitXor, mk_bit_xor_rules()),
200 (Eq, mk_eq_rules()),
202 (Ne, mk_ne_rules()),
203 (Le, mk_le_rules()),
204 (Ge, mk_ge_rules()),
205 (Lt, mk_lt_rules()),
206 (Gt, mk_gt_rules()),
207 (Shl, mk_shl_rules()),
209 (Shr, mk_shr_rules()),
210 ]
211 .into_iter()
212 .collect(),
213 }
214});
215
216static OVERFLOW_NONE_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
217 use mir::UnOp::*;
218 RuleTable {
219 rules: [(Neg, mk_neg_rules(OverflowMode::None)), (Not, mk_not_rules())]
220 .into_iter()
221 .collect(),
222 }
223});
224
225static OVERFLOW_LAZY_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
226 use mir::UnOp::*;
227 RuleTable {
228 rules: [(Neg, mk_neg_rules(OverflowMode::Lazy)), (Not, mk_not_rules())]
229 .into_iter()
230 .collect(),
231 }
232});
233
234static OVERFLOW_STRICT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
235 use mir::UnOp::*;
236 RuleTable {
237 rules: [(Neg, mk_neg_rules(OverflowMode::Strict)), (Not, mk_not_rules())]
238 .into_iter()
239 .collect(),
240 }
241});
242
243fn valid_int(e: impl Into<Expr>, int_ty: rty::IntTy) -> rty::Expr {
244 let e1 = e.into();
245 let e2 = e1.clone();
246 E::and(E::ge(e1, E::int_min(int_ty)), E::le(e2, E::int_max(int_ty)))
247}
248
249fn valid_uint(e: impl Into<Expr>, uint_ty: rty::UintTy) -> rty::Expr {
250 let e1 = e.into();
251 let e2 = e1.clone();
252 E::and(E::ge(e1, 0), E::le(e2, E::uint_max(uint_ty)))
253}
254
255fn mk_add_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
257 match overflow_mode {
258 OverflowMode::Strict => {
259 primop_rules! {
260 fn(a: T, b: T) -> T[a + b]
261 requires valid_int(a + b, int_ty) => ConstrReason::Overflow
262 if let &BaseTy::Int(int_ty) = T
263
264 fn(a: T, b: T) -> T[a + b]
265 requires valid_uint(a + b, uint_ty) => ConstrReason::Overflow
266 if let &BaseTy::Uint(uint_ty) = T
267
268 fn(a: T, b: T) -> T
269 }
270 }
271
272 OverflowMode::Lazy | OverflowMode::StrictUnder => {
273 primop_rules! {
274 fn(a: T, b: T) -> T{v: E::implies(valid_int(a + b, int_ty), E::eq(v, a+b)) }
275 if let &BaseTy::Int(int_ty) = T
276
277 fn(a: T, b: T) -> T{v: E::implies(valid_uint(a + b, uint_ty), E::eq(v, a+b)) }
278 if let &BaseTy::Uint(uint_ty) = T
279
280 fn(a: T, b: T) -> T
281 }
282 }
283
284 OverflowMode::None => {
285 primop_rules! {
286 fn(a: T, b: T) -> T[a + b]
287 if T.is_integral()
288
289 fn(a: T, b: T) -> T
290 }
291 }
292 }
293}
294
295fn mk_mul_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
297 match overflow_mode {
298 OverflowMode::Strict => {
299 primop_rules! {
300 fn(a: T, b: T) -> T[a * b]
301 requires valid_int(a * b, int_ty) => ConstrReason::Overflow
302 if let &BaseTy::Int(int_ty) = T
303
304 fn(a: T, b: T) -> T[a * b]
305 requires valid_uint(a * b, uint_ty) => ConstrReason::Overflow
306 if let &BaseTy::Uint(uint_ty) = T
307
308 fn(a: T, b: T) -> T
309 }
310 }
311
312 OverflowMode::Lazy | OverflowMode::StrictUnder => {
313 primop_rules! {
314 fn(a: T, b: T) -> T{v: E::implies(valid_int(a * b, int_ty), E::eq(v, a * b)) }
315 if let &BaseTy::Int(int_ty) = T
316
317 fn(a: T, b: T) -> T{v: E::implies(valid_uint(a * b, uint_ty), E::eq(v, a * b)) }
318 if let &BaseTy::Uint(uint_ty) = T
319
320 fn(a: T, b: T) -> T
321 }
322 }
323
324 OverflowMode::None => {
325 primop_rules!(
326 fn(a: T, b: T) -> T[a * b]
327 if T.is_integral()
328
329 fn(a: T, b: T) -> T
330 if T.is_float()
331 )
332 }
333 }
334}
335
336fn mk_sub_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
338 match overflow_mode {
339 OverflowMode::Strict => {
340 primop_rules! {
341 fn(a: T, b: T) -> T[a - b]
342 requires valid_int(a - b, int_ty) => ConstrReason::Overflow
343 if let &BaseTy::Int(int_ty) = T
344
345 fn(a: T, b: T) -> T[a - b]
346 requires valid_uint(a - b, uint_ty) => ConstrReason::Overflow
347 if let &BaseTy::Uint(uint_ty) = T
348
349 fn(a: T, b: T) -> T
350 }
351 }
352
353 OverflowMode::StrictUnder => {
355 primop_rules! {
356 fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
357 if let &BaseTy::Int(int_ty) = T
358
359 fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
360 requires E::ge(a - b, 0) => ConstrReason::Underflow
361 if let &BaseTy::Uint(uint_ty) = T
362
363 fn(a: T, b: T) -> T
364 }
365 }
366
367 OverflowMode::Lazy => {
368 primop_rules! {
369 fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
370 if let &BaseTy::Int(int_ty) = T
371
372 fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
373 if let &BaseTy::Uint(uint_ty) = T
374
375 fn(a: T, b: T) -> T
376 }
377 }
378
379 OverflowMode::None => {
380 primop_rules! {
381 fn(a: T, b: T) -> T[a - b]
382 requires E::ge(a - b, 0) => ConstrReason::Underflow
383 if T.is_unsigned()
384
385 fn(a: T, b: T) -> T[a - b]
386 if T.is_signed()
387
388 fn(a: T, b: T) -> T
389 if T.is_float()
390 }
391 }
392 }
393}
394
395fn mk_div_rules() -> RuleMatcher<2> {
397 primop_rules! {
398 fn(a: T, b: T) -> T[a/b]
399 requires E::ne(b, 0) => ConstrReason::Div
400 if T.is_integral()
401
402 fn(a: T, b: T) -> T
403 if T.is_float()
404 }
405}
406
407fn mk_rem_rules() -> RuleMatcher<2> {
409 primop_rules! {
410 fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
411 requires E::ne(b, 0) => ConstrReason::Rem
412 if T.is_unsigned()
413
414 fn(a: T, b: T) -> T{v: E::implies(
415 E::and(E::ge(a, 0), E::ge(b, 0)),
416 E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
417 requires E::ne(b, 0) => ConstrReason::Rem
418 if T.is_signed()
419
420 fn (a: T, b: T) -> T
421 if T.is_float()
422 }
423}
424
425fn mk_bit_and_rules() -> RuleMatcher<2> {
427 primop_rules! {
428 fn(a: T, b: T) -> { T[E::prim_val(BitAnd(Sort::Int), a, b)] | E::prim_rel(BitAnd(Sort::Int), a, b) }
429 if T.is_integral()
430
431 fn(a: bool, b: bool) -> bool[E::and(a, b)]
432 }
433}
434
435fn mk_bit_or_rules() -> RuleMatcher<2> {
437 primop_rules! {
438 fn(a: T, b: T) -> { T[E::prim_val(BitOr(Sort::Int), a, b)] | E::prim_rel(BitOr(Sort::Int), a, b) }
439 if T.is_integral()
440
441 fn(a: bool, b: bool) -> bool[E::or(a, b)]
442 }
443}
444
445fn mk_bit_xor_rules() -> RuleMatcher<2> {
447 primop_rules! {
448 fn(a: T, b: T) -> { T[E::prim_val(BitXor(Sort::Int), a, b)] | E::prim_rel(BitXor(Sort::Int), a, b) }
449 if T.is_integral()
450 }
451}
452
453#[allow(unreachable_code)]
455fn mk_eq_rules() -> RuleMatcher<2> {
457 primop_rules! {
458 fn(a: T, b: T) -> bool[E::eq(a, b)]
459 if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
460 fn(a: T, b: S) -> bool
461 }
462}
463
464#[allow(unreachable_code)]
465fn mk_ne_rules() -> RuleMatcher<2> {
467 primop_rules! {
468 fn(a: T, b: T) -> bool[E::ne(a, b)]
469 if T.is_integral() || T.is_bool()
470
471 fn(a: T, b: S) -> bool
472 }
473}
474
475#[allow(unreachable_code)]
476fn mk_le_rules() -> RuleMatcher<2> {
478 primop_rules! {
479 fn(a: T, b: T) -> bool[E::le(a, b)]
480 if T.is_integral()
481
482 fn(a: bool, b: bool) -> bool[E::implies(a, b)]
483
484 fn(a: T, b: S) -> bool
485 }
486}
487
488#[allow(unreachable_code)]
489fn mk_ge_rules() -> RuleMatcher<2> {
491 primop_rules! {
492 fn(a: T, b: T) -> bool[E::ge(a, b)]
493 if T.is_integral()
494
495 fn(a: bool, b: bool) -> bool[E::implies(b, a)]
496
497 fn(a: T, b: S) -> bool
498 }
499}
500
501#[allow(unreachable_code)]
502fn mk_lt_rules() -> RuleMatcher<2> {
504 primop_rules! {
505 fn(a: T, b: T) -> bool[E::lt(a, b)]
506 if T.is_integral()
507
508 fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]
509
510 fn(a: T, b: S) -> bool
511 }
512}
513
514#[allow(unreachable_code)]
515fn mk_gt_rules() -> RuleMatcher<2> {
517 primop_rules! {
518 fn(a: T, b: T) -> bool[E::gt(a, b)]
519 if T.is_integral()
520
521 fn(a: bool, b: bool) -> bool[E::and(a, b.not())]
522
523 fn(a: T, b: S) -> bool
524 }
525}
526
527fn mk_shl_rules() -> RuleMatcher<2> {
529 primop_rules! {
530 fn(a: T, b: S) -> { T[E::prim_val(BitShl(Sort::Int), a, b)] | E::prim_rel(BitShl(Sort::Int), a, b) }
531 if T.is_integral() && S.is_integral()
532 }
533}
534
535fn mk_shr_rules() -> RuleMatcher<2> {
537 primop_rules! {
538 fn(a: T, b: S) -> { T[E::prim_val(BitShr(Sort::Int), a, b)] | E::prim_rel(BitShr(Sort::Int), a, b) }
539 if T.is_integral() && S.is_integral()
540 }
541}
542
543fn mk_neg_rules(overflow_mode: OverflowMode) -> RuleMatcher<1> {
545 match overflow_mode {
546 OverflowMode::Strict => {
547 primop_rules! {
548 fn(a: T) -> T[a.neg()]
549 requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
550 if let &BaseTy::Int(int_ty) = T
551
552 fn(a: T) -> T[a.neg()]
553 if T.is_float()
554 }
555 }
556 OverflowMode::Lazy | OverflowMode::StrictUnder => {
557 primop_rules! {
558 fn(a: T) -> T{v: E::implies(E::ne(a, E::int_min(int_ty)), E::eq(v, a.neg())) }
559 if let &BaseTy::Int(int_ty) = T
560
561 fn(a: T) -> T[a.neg()]
562 if T.is_float()
563 }
564 }
565 OverflowMode::None => {
566 primop_rules! {
567 fn(a: T) -> T[a.neg()]
568 if T.is_integral()
569
570 fn(a: T) -> T
571 if T.is_float()
572 }
573 }
574 }
575}
576
577fn mk_not_rules() -> RuleMatcher<1> {
579 primop_rules! {
580 fn(a: bool) -> bool[a.not()]
581
582 fn(a: T) -> T
583 if T.is_integral()
584 }
585}