1use std::{fmt, hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_config::OverflowMode;
30use flux_infer::infer::ConstrReason;
31use flux_macros::primop_rules;
32use flux_middle::rty::{self, BaseTy, Expr, Sort};
33use flux_rustc_bridge::mir;
34use rty::{
35 BinOp::{BitAnd, BitOr, BitShl, BitShr, BitXor, Mod},
36 Expr as E,
37};
38use rustc_data_structures::unord::UnordMap;
39
/// The outcome of matching a primitive operation against a rule table.
///
/// Holds the optional side condition produced by the matched rule together with the
/// type assigned to the operation's result.
#[derive(Debug)]
pub(crate) struct MatchedRule {
    /// Obligation from the matched rule's `requires` clause, if it has one; presumably
    /// checked by the caller before the result type is trusted — TODO confirm at call sites.
    pub precondition: Option<Pre>,
    /// Type given to the result of the operation.
    pub output_type: rty::Ty,
}
45
/// A precondition attached to a primop rule (the `requires pred => reason` part of a rule).
#[derive(Debug)]
pub(crate) struct Pre {
    /// Why the constraint exists (e.g. `ConstrReason::Overflow`, `ConstrReason::Div`);
    /// presumably used to explain a failed check to the user.
    pub reason: ConstrReason,
    /// Predicate over the operands' indices that must hold.
    pub pred: Expr,
}
51
52pub(crate) fn match_bin_op(
53 op: mir::BinOp,
54 bty1: &BaseTy,
55 idx1: &Expr,
56 bty2: &BaseTy,
57 idx2: &Expr,
58 overflow_mode: OverflowMode,
59) -> MatchedRule {
60 let table = match overflow_mode {
61 OverflowMode::Strict => &OVERFLOW_STRICT_BIN_OPS,
62 OverflowMode::Lazy => &OVERFLOW_LAZY_BIN_OPS,
63 OverflowMode::None => &OVERFLOW_NONE_BIN_OPS,
64 OverflowMode::StrictUnder => &OVERFLOW_STRICT_UNDER_BIN_OPS,
65 };
66 table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
67}
68
69pub(crate) fn match_un_op(
70 op: mir::UnOp,
71 bty: &BaseTy,
72 idx: &Expr,
73 overflow_mode: OverflowMode,
74) -> MatchedRule {
75 let table = match overflow_mode {
76 OverflowMode::Strict => &OVERFLOW_STRICT_UN_OPS,
77 OverflowMode::None => &OVERFLOW_NONE_UN_OPS,
78 OverflowMode::Lazy | OverflowMode::StrictUnder => &OVERFLOW_LAZY_UN_OPS,
79 };
80 table.match_inputs(&op, [(bty.clone(), idx.clone())])
81}
82
83struct RuleTable<Op: Eq + Hash, const N: usize> {
84 rules: UnordMap<Op, RuleMatcher<N>>,
85}
86
87impl<Op: Eq + Hash + fmt::Debug, const N: usize> RuleTable<Op, N> {
88 fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
89 (self.rules[op])(&inputs)
90 .unwrap_or_else(|| tracked_span_bug!("no primop rule for {op:?} using {inputs:?}"))
91 }
92}
93
/// Matcher for an operator of arity `N`: given each operand's base type and index, it
/// returns the matching rule, or `None` when no rule applies (which
/// `RuleTable::match_inputs` turns into a tracked span bug).
type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
95
96static OVERFLOW_NONE_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
97 use mir::BinOp::*;
98 RuleTable {
99 rules: [
100 (Add, mk_add_rules(OverflowMode::None)),
102 (Mul, mk_mul_rules(OverflowMode::None)),
103 (Sub, mk_sub_rules(OverflowMode::None)),
104 (Div, mk_div_rules()),
105 (Rem, mk_rem_rules()),
106 (BitAnd, mk_bit_and_rules()),
108 (BitOr, mk_bit_or_rules()),
109 (BitXor, mk_bit_xor_rules()),
110 (Eq, mk_eq_rules()),
112 (Ne, mk_ne_rules()),
113 (Le, mk_le_rules()),
114 (Ge, mk_ge_rules()),
115 (Lt, mk_lt_rules()),
116 (Gt, mk_gt_rules()),
117 (Shl, mk_shl_rules()),
119 (Shr, mk_shr_rules()),
120 ]
121 .into_iter()
122 .collect(),
123 }
124});
125
126static OVERFLOW_STRICT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
127 use mir::BinOp::*;
128 RuleTable {
129 rules: [
130 (Add, mk_add_rules(OverflowMode::Strict)),
132 (Mul, mk_mul_rules(OverflowMode::Strict)),
133 (Sub, mk_sub_rules(OverflowMode::Strict)),
134 (Div, mk_div_rules()),
135 (Rem, mk_rem_rules()),
136 (BitAnd, mk_bit_and_rules()),
138 (BitOr, mk_bit_or_rules()),
139 (BitXor, mk_bit_xor_rules()),
140 (Eq, mk_eq_rules()),
142 (Ne, mk_ne_rules()),
143 (Le, mk_le_rules()),
144 (Ge, mk_ge_rules()),
145 (Lt, mk_lt_rules()),
146 (Gt, mk_gt_rules()),
147 (Shl, mk_shl_rules()),
149 (Shr, mk_shr_rules()),
150 ]
151 .into_iter()
152 .collect(),
153 }
154});
155
156static OVERFLOW_LAZY_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
157 use mir::BinOp::*;
158 RuleTable {
159 rules: [
160 (Add, mk_add_rules(OverflowMode::Lazy)),
162 (Mul, mk_mul_rules(OverflowMode::Lazy)),
163 (Sub, mk_sub_rules(OverflowMode::Lazy)),
164 (Div, mk_div_rules()),
165 (Rem, mk_rem_rules()),
166 (BitAnd, mk_bit_and_rules()),
168 (BitOr, mk_bit_or_rules()),
169 (BitXor, mk_bit_xor_rules()),
170 (Eq, mk_eq_rules()),
172 (Ne, mk_ne_rules()),
173 (Le, mk_le_rules()),
174 (Ge, mk_ge_rules()),
175 (Lt, mk_lt_rules()),
176 (Gt, mk_gt_rules()),
177 (Shl, mk_shl_rules()),
179 (Shr, mk_shr_rules()),
180 ]
181 .into_iter()
182 .collect(),
183 }
184});
185
186static OVERFLOW_STRICT_UNDER_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
187 use mir::BinOp::*;
188 RuleTable {
189 rules: [
190 (Add, mk_add_rules(OverflowMode::StrictUnder)),
192 (Mul, mk_mul_rules(OverflowMode::StrictUnder)),
193 (Sub, mk_sub_rules(OverflowMode::StrictUnder)),
194 (Div, mk_div_rules()),
195 (Rem, mk_rem_rules()),
196 (BitAnd, mk_bit_and_rules()),
198 (BitOr, mk_bit_or_rules()),
199 (BitXor, mk_bit_xor_rules()),
200 (Eq, mk_eq_rules()),
202 (Ne, mk_ne_rules()),
203 (Le, mk_le_rules()),
204 (Ge, mk_ge_rules()),
205 (Lt, mk_lt_rules()),
206 (Gt, mk_gt_rules()),
207 (Shl, mk_shl_rules()),
209 (Shr, mk_shr_rules()),
210 ]
211 .into_iter()
212 .collect(),
213 }
214});
215
216static OVERFLOW_NONE_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
217 use mir::UnOp::*;
218 RuleTable {
219 rules: [(Neg, mk_neg_rules(OverflowMode::None)), (Not, mk_not_rules())]
220 .into_iter()
221 .collect(),
222 }
223});
224
225static OVERFLOW_LAZY_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
226 use mir::UnOp::*;
227 RuleTable {
228 rules: [(Neg, mk_neg_rules(OverflowMode::Lazy)), (Not, mk_not_rules())]
229 .into_iter()
230 .collect(),
231 }
232});
233
234static OVERFLOW_STRICT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
235 use mir::UnOp::*;
236 RuleTable {
237 rules: [(Neg, mk_neg_rules(OverflowMode::Strict)), (Not, mk_not_rules())]
238 .into_iter()
239 .collect(),
240 }
241});
242
243fn valid_int(e: impl Into<Expr>, int_ty: rty::IntTy) -> rty::Expr {
244 let e1 = e.into();
245 let e2 = e1.clone();
246 E::and(E::ge(e1, E::int_min(int_ty)), E::le(e2, E::int_max(int_ty)))
247}
248
249fn valid_uint(e: impl Into<Expr>, uint_ty: rty::UintTy) -> rty::Expr {
250 let e1 = e.into();
251 let e2 = e1.clone();
252 E::and(E::ge(e1, 0), E::le(e2, E::uint_max(uint_ty)))
253}
254
/// Typing rules for `+` under `overflow_mode`.
///
/// - `Strict`: signed and unsigned results are exactly `a + b`. A precondition requires
///   the sum to fit in the operand type (`ConstrReason::Overflow`).
/// - `Lazy` / `StrictUnder`: there is no precondition. The result equals `a + b` only
///   under the hypothesis that the sum fits in the type.
/// - `None`: integral results are exactly `a + b`, with no range check.
///
/// Each group ends with an unguarded rule giving any other operand type (presumably floats)
/// an unrefined result. Rules appear to be tried in order, so that fallback must stay last.
fn mk_add_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a + b]
                requires valid_int(a + b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a + b]
                requires valid_uint(a + b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a + b, int_ty), E::eq(v, a+b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a + b, uint_ty), E::eq(v, a+b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules! {
                fn(a: T, b: T) -> T[a + b]
                if T.is_integral()

                fn(a: T, b: T) -> T
            }
        }
    }
}
294
/// Typing rules for `*` under `overflow_mode`.
///
/// - `Strict`: results are exactly `a * b`. A precondition requires the product to fit
///   in the operand type (`ConstrReason::Overflow`).
/// - `Lazy` / `StrictUnder`: the result equals `a * b` only under the hypothesis that the
///   product fits; there is no precondition.
/// - `None`: integral results are exactly `a * b`, and floats get an unrefined result.
///   Unlike the other modes, there is no catch-all rule here, so only integral and float
///   operands match.
fn mk_mul_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a * b]
                requires valid_int(a * b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a * b]
                requires valid_uint(a * b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a * b, int_ty), E::eq(v, a * b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a * b, uint_ty), E::eq(v, a * b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules!(
                fn(a: T, b: T) -> T[a * b]
                if T.is_integral()

                fn(a: T, b: T) -> T
                if T.is_float()
            )
        }
    }
}
335
/// Typing rules for `-` under `overflow_mode`.
///
/// - `Strict`: results are exactly `a - b`. A precondition requires the difference to fit
///   in the operand type (`ConstrReason::Overflow`).
/// - `StrictUnder`: results equal `a - b` only under the hypothesis that the difference
///   fits. Unsigned subtraction additionally requires `a - b >= 0`
///   (`ConstrReason::Underflow`).
/// - `Lazy`: the same hypothetical refinement as `StrictUnder`, without the underflow
///   precondition.
/// - `None`: results are exactly `a - b`. Unsigned subtraction still requires
///   `a - b >= 0` (`ConstrReason::Underflow`), and floats get an unrefined result.
fn mk_sub_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a - b]
                requires valid_int(a - b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a - b]
                requires valid_uint(a - b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
                if let &BaseTy::Int(int_ty) = T

                // Unsigned: lazy about overflow, but underflow is still checked.
                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
                requires E::ge(a - b, 0) => ConstrReason::Underflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules! {
                fn(a: T, b: T) -> T[a - b]
                requires E::ge(a - b, 0) => ConstrReason::Underflow
                if T.is_unsigned()

                fn(a: T, b: T) -> T[a - b]
                if T.is_signed()

                fn(a: T, b: T) -> T
                if T.is_float()
            }
        }
    }
}
394
395fn mk_div_rules() -> RuleMatcher<2> {
397 primop_rules! {
398 fn(a: T, b: T) -> T[a/b]
399 requires E::ne(b, 0) => ConstrReason::Div
400 if T.is_integral()
401
402 fn(a: T, b: T) -> T
403 if T.is_float()
404 }
405}
406
/// Typing rules for `%`.
///
/// Integral remainders require a non-zero divisor (`ConstrReason::Rem`).
///
/// - Unsigned results are exactly `Mod(a, b)`.
/// - Signed results equal `Mod(a, b)` only when both operands are non-negative. Rust's `%`
///   takes the sign of the dividend, so it can disagree with a mathematical modulus on
///   negative operands.
/// - Floats get an unrefined result.
fn mk_rem_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
        requires E::ne(b, 0) => ConstrReason::Rem
        if T.is_unsigned()

        fn(a: T, b: T) -> T{v: E::implies(
            E::and(E::ge(a, 0), E::ge(b, 0)),
            E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
        requires E::ne(b, 0) => ConstrReason::Rem
        if T.is_signed()

        fn (a: T, b: T) -> T
        if T.is_float()
    }
}
424
/// Typing rules for `&`.
///
/// Integral operands are described with the `BitAnd` primitive, as both a value
/// (`E::prim_val`) and a relation (`E::prim_rel`). Boolean `&` is logical conjunction.
fn mk_bit_and_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> { T[E::prim_val(BitAnd(Sort::Int), a, b)] | E::prim_rel(BitAnd(Sort::Int), a, b) }
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b)]
    }
}
434
/// Typing rules for `|`.
///
/// Integral operands are described with the `BitOr` primitive, as both a value
/// (`E::prim_val`) and a relation (`E::prim_rel`). Boolean `|` is logical disjunction.
fn mk_bit_or_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> { T[E::prim_val(BitOr(Sort::Int), a, b)] | E::prim_rel(BitOr(Sort::Int), a, b) }
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::or(a, b)]
    }
}
444
445fn mk_bit_xor_rules() -> RuleMatcher<2> {
447 primop_rules! {
448 fn(a: T, b: T) -> { T[E::prim_val(BitXor(Sort::Int), a, b)] | E::prim_rel(BitXor(Sort::Int), a, b) }
449 if T.is_integral()
450 }
451}
452
/// Typing rules for `==`.
///
/// For integral, boolean, `char` and `str` operands the result is exactly `a == b`. Any
/// other operand type yields an unrefined `bool`.
fn mk_eq_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::eq(a, b)]
        if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()

        fn(a: T, b: T) -> bool
    }
}
462
463fn mk_ne_rules() -> RuleMatcher<2> {
465 primop_rules! {
466 fn(a: T, b: T) -> bool[E::ne(a, b)]
467 if T.is_integral() || T.is_bool()
468
469 fn(a: T, b: T) -> bool
470 }
471}
472
/// Typing rules for `<=`.
///
/// - Integral operands are exactly `a <= b`.
/// - Booleans are ordered `false < true`, so `a <= b` is the implication `a => b`.
/// - Any other operand type yields an unrefined `bool`.
fn mk_le_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::le(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::implies(a, b)]

        fn(a: T, b: T) -> bool
    }
}
484
/// Typing rules for `>=`.
///
/// - Integral operands are exactly `a >= b`.
/// - For booleans (`false < true`), `a >= b` is the implication `b => a`.
/// - Any other operand type yields an unrefined `bool`.
fn mk_ge_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::ge(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::implies(b, a)]

        fn(a: T, b: T) -> bool
    }
}
496
/// Typing rules for `<`.
///
/// - Integral operands are exactly `a < b`.
/// - For booleans (`false < true`), `a < b` holds only when `a` is false and `b` is true.
/// - Any other operand type yields an unrefined `bool`.
fn mk_lt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::lt(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]

        fn(a: T, b: T) -> bool
    }
}
508
/// Typing rules for `>`.
///
/// - Integral operands are exactly `a > b`.
/// - For booleans (`false < true`), `a > b` holds only when `a` is true and `b` is false.
/// - Any other operand type yields an unrefined `bool`.
fn mk_gt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::gt(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b.not())]

        fn(a: T, b: T) -> bool
    }
}
520
/// Typing rules for `<<`.
///
/// The shift amount `b` may have a different integral type (`S`) from the shifted value
/// `a`; the result has `a`'s type. It is described with the `BitShl` primitive, as both a
/// value and a relation.
fn mk_shl_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> { T[E::prim_val(BitShl(Sort::Int), a, b)] | E::prim_rel(BitShl(Sort::Int), a, b) }
        if T.is_integral() && S.is_integral()
    }
}
528
/// Typing rules for `>>`.
///
/// The shift amount `b` may have a different integral type (`S`) from the shifted value
/// `a`; the result has `a`'s type. It is described with the `BitShr` primitive, as both a
/// value and a relation.
fn mk_shr_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> { T[E::prim_val(BitShr(Sort::Int), a, b)] | E::prim_rel(BitShr(Sort::Int), a, b) }
        if T.is_integral() && S.is_integral()
    }
}
536
/// Typing rules for unary `-` under `overflow_mode`.
///
/// - `Strict`: signed negation is exact and requires `a != int_min(int_ty)`
///   (`ConstrReason::Overflow`), since negating the minimum value overflows.
/// - `Lazy` / `StrictUnder`: signed negation equals `-a` only under the hypothesis
///   `a != int_min(int_ty)`; there is no precondition.
/// - `None`: integral negation is exactly `-a`.
///
/// Unsigned integers have no rule in `Strict`/`Lazy`, presumably because Rust rejects unary
/// minus on unsigned types.
///
/// NOTE(review): floats are refined to `a.neg()` in `Strict`/`Lazy`/`StrictUnder` but left
/// unrefined in `None` — confirm this asymmetry is intended.
fn mk_neg_rules(overflow_mode: OverflowMode) -> RuleMatcher<1> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T) -> T[a.neg()]
                requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T) -> T[a.neg()]
                if T.is_float()
            }
        }
        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T) -> T{v: E::implies(E::ne(a, E::int_min(int_ty)), E::eq(v, a.neg())) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T) -> T[a.neg()]
                if T.is_float()
            }
        }
        OverflowMode::None => {
            primop_rules! {
                fn(a: T) -> T[a.neg()]
                if T.is_integral()

                fn(a: T) -> T
                if T.is_float()
            }
        }
    }
}
570
/// Typing rules for unary `!`.
///
/// Logical negation of a `bool` is exact. Bitwise `!` on an integral type yields an
/// unrefined value of the same type.
fn mk_not_rules() -> RuleMatcher<1> {
    primop_rules! {
        fn(a: bool) -> bool[a.not()]

        fn(a: T) -> T
        if T.is_integral()
    }
}