1use std::{fmt, hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_config::OverflowMode;
30use flux_infer::infer::ConstrReason;
31use flux_macros::primop_rules;
32use flux_middle::rty::{self, BaseTy, Expr, Sort};
33use flux_rustc_bridge::mir;
34use rty::{BinOp::Mod, Expr as E};
35use rustc_data_structures::unord::UnordMap;
36
/// Result of successfully matching a primitive operation against a rule table.
#[derive(Debug)]
pub(crate) struct MatchedRule {
    /// Obligation attached to the matched rule (the rule's `requires ... => reason` clause), or
    /// `None` when the rule has no precondition.
    pub precondition: Option<Pre>,
    /// Type assigned to the result of the operation by the matched rule.
    pub output_type: rty::Ty,
}
42
/// A precondition produced by a primop rule.
#[derive(Debug)]
pub(crate) struct Pre {
    /// Why the predicate is required (e.g. `ConstrReason::Overflow`, `ConstrReason::Div`);
    /// presumably used to label the error if `pred` cannot be proven — confirm at the call site.
    pub reason: ConstrReason,
    /// The predicate over the operands' indices that must hold.
    pub pred: Expr,
}
48
49pub(crate) fn match_bin_op(
50 op: mir::BinOp,
51 bty1: &BaseTy,
52 idx1: &Expr,
53 bty2: &BaseTy,
54 idx2: &Expr,
55 overflow_mode: OverflowMode,
56) -> MatchedRule {
57 let table = match overflow_mode {
58 OverflowMode::Strict => &OVERFLOW_STRICT_BIN_OPS,
59 OverflowMode::Lazy => &OVERFLOW_LAZY_BIN_OPS,
60 OverflowMode::None => &OVERFLOW_NONE_BIN_OPS,
61 OverflowMode::StrictUnder => &OVERFLOW_STRICT_UNDER_BIN_OPS,
62 };
63 table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
64}
65
66pub(crate) fn match_un_op(
67 op: mir::UnOp,
68 bty: &BaseTy,
69 idx: &Expr,
70 overflow_mode: OverflowMode,
71) -> MatchedRule {
72 let table = match overflow_mode {
73 OverflowMode::Strict => &OVERFLOW_STRICT_UN_OPS,
74 OverflowMode::None => &OVERFLOW_NONE_UN_OPS,
75 OverflowMode::Lazy | OverflowMode::StrictUnder => &OVERFLOW_LAZY_UN_OPS,
76 };
77 table.match_inputs(&op, [(bty.clone(), idx.clone())])
78}
79
/// Maps each operator `Op` to the matcher holding its typing rules.
///
/// `N` is the operator's arity: the number of `(BaseTy, Expr)` operands passed to the matcher.
struct RuleTable<Op: Eq + Hash, const N: usize> {
    rules: UnordMap<Op, RuleMatcher<N>>,
}
83
84impl<Op: Eq + Hash + fmt::Debug, const N: usize> RuleTable<Op, N> {
85 fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
86 (self.rules[op])(&inputs)
87 .unwrap_or_else(|| tracked_span_bug!("no primop rule for {op:?} using {inputs:?}"))
88 }
89}
90
/// A matcher for one operator: given its `N` operands (base type and index of each), returns the
/// rule that applies, or `None` if no rule accepts those operand types.
type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
92
93static OVERFLOW_NONE_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
94 use mir::BinOp::*;
95 RuleTable {
96 rules: [
97 (Add, mk_add_rules(OverflowMode::None)),
99 (Mul, mk_mul_rules(OverflowMode::None)),
100 (Sub, mk_sub_rules(OverflowMode::None)),
101 (Div, mk_div_rules()),
102 (Rem, mk_rem_rules()),
103 (BitAnd, mk_bit_and_rules()),
105 (BitOr, mk_bit_or_rules()),
106 (BitXor, mk_bit_xor_rules()),
107 (Eq, mk_eq_rules()),
109 (Ne, mk_ne_rules()),
110 (Le, mk_le_rules()),
111 (Ge, mk_ge_rules()),
112 (Lt, mk_lt_rules()),
113 (Gt, mk_gt_rules()),
114 (Shl, mk_shl_rules()),
116 (Shr, mk_shr_rules()),
117 ]
118 .into_iter()
119 .collect(),
120 }
121});
122
123static OVERFLOW_STRICT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
124 use mir::BinOp::*;
125 RuleTable {
126 rules: [
127 (Add, mk_add_rules(OverflowMode::Strict)),
129 (Mul, mk_mul_rules(OverflowMode::Strict)),
130 (Sub, mk_sub_rules(OverflowMode::Strict)),
131 (Div, mk_div_rules()),
132 (Rem, mk_rem_rules()),
133 (BitAnd, mk_bit_and_rules()),
135 (BitOr, mk_bit_or_rules()),
136 (BitXor, mk_bit_xor_rules()),
137 (Eq, mk_eq_rules()),
139 (Ne, mk_ne_rules()),
140 (Le, mk_le_rules()),
141 (Ge, mk_ge_rules()),
142 (Lt, mk_lt_rules()),
143 (Gt, mk_gt_rules()),
144 (Shl, mk_shl_rules()),
146 (Shr, mk_shr_rules()),
147 ]
148 .into_iter()
149 .collect(),
150 }
151});
152
153static OVERFLOW_LAZY_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
154 use mir::BinOp::*;
155 RuleTable {
156 rules: [
157 (Add, mk_add_rules(OverflowMode::Lazy)),
159 (Mul, mk_mul_rules(OverflowMode::Lazy)),
160 (Sub, mk_sub_rules(OverflowMode::Lazy)),
161 (Div, mk_div_rules()),
162 (Rem, mk_rem_rules()),
163 (BitAnd, mk_bit_and_rules()),
165 (BitOr, mk_bit_or_rules()),
166 (BitXor, mk_bit_xor_rules()),
167 (Eq, mk_eq_rules()),
169 (Ne, mk_ne_rules()),
170 (Le, mk_le_rules()),
171 (Ge, mk_ge_rules()),
172 (Lt, mk_lt_rules()),
173 (Gt, mk_gt_rules()),
174 (Shl, mk_shl_rules()),
176 (Shr, mk_shr_rules()),
177 ]
178 .into_iter()
179 .collect(),
180 }
181});
182
183static OVERFLOW_STRICT_UNDER_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
184 use mir::BinOp::*;
185 RuleTable {
186 rules: [
187 (Add, mk_add_rules(OverflowMode::StrictUnder)),
189 (Mul, mk_mul_rules(OverflowMode::StrictUnder)),
190 (Sub, mk_sub_rules(OverflowMode::StrictUnder)),
191 (Div, mk_div_rules()),
192 (Rem, mk_rem_rules()),
193 (BitAnd, mk_bit_and_rules()),
195 (BitOr, mk_bit_or_rules()),
196 (BitXor, mk_bit_xor_rules()),
197 (Eq, mk_eq_rules()),
199 (Ne, mk_ne_rules()),
200 (Le, mk_le_rules()),
201 (Ge, mk_ge_rules()),
202 (Lt, mk_lt_rules()),
203 (Gt, mk_gt_rules()),
204 (Shl, mk_shl_rules()),
206 (Shr, mk_shr_rules()),
207 ]
208 .into_iter()
209 .collect(),
210 }
211});
212
213static OVERFLOW_NONE_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
214 use mir::UnOp::*;
215 RuleTable {
216 rules: [(Neg, mk_neg_rules(OverflowMode::None)), (Not, mk_not_rules())]
217 .into_iter()
218 .collect(),
219 }
220});
221
222static OVERFLOW_LAZY_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
223 use mir::UnOp::*;
224 RuleTable {
225 rules: [(Neg, mk_neg_rules(OverflowMode::Lazy)), (Not, mk_not_rules())]
226 .into_iter()
227 .collect(),
228 }
229});
230
231static OVERFLOW_STRICT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
232 use mir::UnOp::*;
233 RuleTable {
234 rules: [(Neg, mk_neg_rules(OverflowMode::Strict)), (Not, mk_not_rules())]
235 .into_iter()
236 .collect(),
237 }
238});
239
240fn valid_int(e: impl Into<Expr>, int_ty: rty::IntTy) -> rty::Expr {
241 let e1 = e.into();
242 let e2 = e1.clone();
243 E::and(E::ge(e1, E::int_min(int_ty)), E::le(e2, E::int_max(int_ty)))
244}
245
246fn valid_uint(e: impl Into<Expr>, uint_ty: rty::UintTy) -> rty::Expr {
247 let e1 = e.into();
248 let e2 = e1.clone();
249 E::and(E::ge(e1, 0), E::le(e2, E::uint_max(uint_ty)))
250}
251
/// Rules for `a + b`.
///
/// Rules in each `primop_rules!` list appear to be tried in order; every list ends with an
/// unguarded rule that acts as the fallback.
///
/// - `Strict`: the result is indexed by `a + b`, and the rule `requires` that `a + b` fits in the
///   operand's integer type (reported as `ConstrReason::Overflow`).
/// - `Lazy` / `StrictUnder`: no obligation; the result `v` equals `a + b` only under the
///   hypothesis that `a + b` is in range.
/// - `None`: integral results are indexed by `a + b` with no range check.
///
/// Any other operand type gets an unrefined `T`.
fn mk_add_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a + b]
                requires valid_int(a + b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a + b]
                requires valid_uint(a + b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a + b, int_ty), E::eq(v, a+b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a + b, uint_ty), E::eq(v, a+b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules! {
                fn(a: T, b: T) -> T[a + b]
                if T.is_integral()

                fn(a: T, b: T) -> T
            }
        }
    }
}
291
/// Rules for `a * b`.
///
/// - `Strict`: the result is indexed by `a * b`, and the rule `requires` that `a * b` fits in the
///   operand's integer type (reported as `ConstrReason::Overflow`). Other types get an unrefined `T`.
/// - `Lazy` / `StrictUnder`: no obligation; the result `v` equals `a * b` only under the
///   hypothesis that `a * b` is in range. Other types get an unrefined `T`.
/// - `None`: integral results are indexed by `a * b` with no range check; floats get an
///   unrefined `T`. Unlike `mk_add_rules`, the last rule here is guarded by `T.is_float()`.
fn mk_mul_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a * b]
                requires valid_int(a * b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a * b]
                requires valid_uint(a * b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a * b, int_ty), E::eq(v, a * b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a * b, uint_ty), E::eq(v, a * b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules!(
                fn(a: T, b: T) -> T[a * b]
                if T.is_integral()

                fn(a: T, b: T) -> T
                if T.is_float()
            )
        }
    }
}
332
/// Rules for `a - b`.
///
/// - `Strict`: the result is indexed by `a - b`, and the rule `requires` that `a - b` fits in the
///   operand's integer type (reported as `ConstrReason::Overflow`).
/// - `StrictUnder`: range violations are treated lazily (the result `v` equals `a - b` only when
///   `a - b` is in range). Unsigned subtraction additionally `requires` `a - b >= 0` (reported as
///   `ConstrReason::Underflow`).
/// - `Lazy`: range violations are treated lazily for both signed and unsigned operands, with no
///   obligation.
/// - `None`: unsigned results are indexed by `a - b` and `require` `a - b >= 0`
///   (`ConstrReason::Underflow`); signed results are indexed by `a - b` with no check; floats get
///   an unrefined `T`.
fn mk_sub_rules(overflow_mode: OverflowMode) -> RuleMatcher<2> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T, b: T) -> T[a - b]
                requires valid_int(a - b, int_ty) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T[a - b]
                requires valid_uint(a - b, uint_ty) => ConstrReason::Overflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
                if let &BaseTy::Int(int_ty) = T

                // Unsigned: lazy on overflow, but underflow below zero is a checked obligation.
                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
                requires E::ge(a - b, 0) => ConstrReason::Underflow
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::Lazy => {
            primop_rules! {
                fn(a: T, b: T) -> T{v: E::implies(valid_int(a - b, int_ty), E::eq(v, a - b)) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T, b: T) -> T{v: E::implies(valid_uint(a - b, uint_ty), E::eq(v, a - b)) }
                if let &BaseTy::Uint(uint_ty) = T

                fn(a: T, b: T) -> T
            }
        }

        OverflowMode::None => {
            primop_rules! {
                fn(a: T, b: T) -> T[a - b]
                requires E::ge(a - b, 0) => ConstrReason::Underflow
                if T.is_unsigned()

                fn(a: T, b: T) -> T[a - b]
                if T.is_signed()

                fn(a: T, b: T) -> T
                if T.is_float()
            }
        }
    }
}
391
/// Rules for `a / b` (the same in every overflow mode).
///
/// Integral results are indexed by `a / b`, and the rule `requires` `b != 0` (reported as
/// `ConstrReason::Div`). Floats get an unrefined `T`.
///
/// NOTE(review): for signed operands Rust's `/` truncates toward zero. If the index-level `a/b`
/// is encoded as SMT-LIB integer `div` (floor/Euclidean), the refinement disagrees with Rust when
/// an operand is negative (e.g. `-7 / 2`). `mk_rem_rules` guards its signed case with
/// `a >= 0 && b >= 0`; consider doing the same here. TODO confirm against the fixpoint encoding.
fn mk_div_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> T[a/b]
        requires E::ne(b, 0) => ConstrReason::Div
        if T.is_integral()

        fn(a: T, b: T) -> T
        if T.is_float()
    }
}
403
/// Rules for `a % b` (the same in every overflow mode).
///
/// - Unsigned: the result is indexed by the integer `a mod b`, and the rule `requires` `b != 0`
///   (reported as `ConstrReason::Rem`).
/// - Signed: also `requires` `b != 0`. The result equals `a mod b` only when both operands are
///   non-negative, because Rust's `%` takes the sign of the dividend and so differs from `mod` for
///   negative operands.
/// - Floats: unrefined `T`.
fn mk_rem_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
        requires E::ne(b, 0) => ConstrReason::Rem
        if T.is_unsigned()

        fn(a: T, b: T) -> T{v: E::implies(
            E::and(E::ge(a, 0), E::ge(b, 0)),
            E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
        requires E::ne(b, 0) => ConstrReason::Rem
        if T.is_signed()

        fn (a: T, b: T) -> T
        if T.is_float()
    }
}
421
/// Rules for `a & b`.
///
/// For integers the result is indexed by the primitive `BitAnd` value and constrained by the
/// matching primitive relation (`E::prim_val` / `E::prim_rel`). For booleans the result is indexed
/// by the logical conjunction `a && b`.
fn mk_bit_and_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitAnd, a, b)] | E::prim_rel(rty::BinOp::BitAnd, a, b) }
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b)]
    }
}
431
/// Rules for `a | b`.
///
/// For integers the result is indexed by the primitive `BitOr` value and constrained by the
/// matching primitive relation (`E::prim_val` / `E::prim_rel`). For booleans the result is indexed
/// by the logical disjunction `a || b`.
fn mk_bit_or_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitOr, a, b)] | E::prim_rel(rty::BinOp::BitOr, a, b) }
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::or(a, b)]
    }
}
441
442fn mk_bit_xor_rules() -> RuleMatcher<2> {
444 primop_rules! {
445 fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitXor, a, b)] | E::prim_rel(rty::BinOp::BitXor, a, b) }
446 if T.is_integral()
447 }
448}
449
/// Rules for `a == b`.
///
/// For integers, `bool`, `char` and `str` the result is indexed by the index-level equality
/// `a == b`. Any other operand type gets an unrefined `bool`.
fn mk_eq_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::eq(a, b)]
        if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()

        fn(a: T, b: T) -> bool
    }
}
459
460fn mk_ne_rules() -> RuleMatcher<2> {
462 primop_rules! {
463 fn(a: T, b: T) -> bool[E::ne(a, b)]
464 if T.is_integral() || T.is_bool()
465
466 fn(a: T, b: T) -> bool
467 }
468}
469
/// Rules for `a <= b`.
///
/// Integers are indexed by `a <= b`. For booleans (ordered `false < true`), `a <= b` holds exactly
/// when `a` implies `b`. Any other operand type gets an unrefined `bool`.
fn mk_le_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::le(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::implies(a, b)]

        fn(a: T, b: T) -> bool
    }
}
481
/// Rules for `a >= b`.
///
/// Integers are indexed by `a >= b`. For booleans (ordered `false < true`), `a >= b` holds exactly
/// when `b` implies `a`. Any other operand type gets an unrefined `bool`.
fn mk_ge_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::ge(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::implies(b, a)]

        fn(a: T, b: T) -> bool
    }
}
493
/// Rules for `a < b`.
///
/// Integers are indexed by `a < b`. For booleans (ordered `false < true`), `a < b` holds only for
/// `a == false && b == true`, i.e. `!a && b`. Any other operand type gets an unrefined `bool`.
fn mk_lt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::lt(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]

        fn(a: T, b: T) -> bool
    }
}
505
/// Rules for `a > b`.
///
/// Integers are indexed by `a > b`. For booleans (ordered `false < true`), `a > b` holds only for
/// `a == true && b == false`, i.e. `a && !b`. Any other operand type gets an unrefined `bool`.
fn mk_gt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::gt(a, b)]
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b.not())]

        fn(a: T, b: T) -> bool
    }
}
517
/// Rules for `a << b`.
///
/// The shift amount may have a different integer type (`S`) than the value being shifted (`T`).
/// The result has type `T`, indexed by the primitive `BitShl` value and constrained by the
/// matching primitive relation. There is no rule for non-integral operands.
fn mk_shl_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> { T[E::prim_val(rty::BinOp::BitShl, a, b)] | E::prim_rel(rty::BinOp::BitShl, a, b) }
        if T.is_integral() && S.is_integral()
    }
}
525
/// Rules for `a >> b`.
///
/// The shift amount may have a different integer type (`S`) than the value being shifted (`T`).
/// The result has type `T`, indexed by the primitive `BitShr` value and constrained by the
/// matching primitive relation. There is no rule for non-integral operands.
fn mk_shr_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> { T[E::prim_val(rty::BinOp::BitShr, a, b)] | E::prim_rel(rty::BinOp::BitShr, a, b) }
        if T.is_integral() && S.is_integral()
    }
}
533
/// Rules for unary `-a`.
///
/// - `Strict`: signed integers are indexed by `-a`, and the rule `requires`
///   `a != int_min(int_ty)` (reported as `ConstrReason::Overflow`); floats are indexed by `-a`.
/// - `Lazy` / `StrictUnder`: signed integers equal `-a` only under the hypothesis
///   `a != int_min(int_ty)`; floats are indexed by `-a`.
/// - `None`: integral operands are indexed by `-a` with no check; floats get an unrefined `T`.
///
/// Strict/Lazy have no unsigned rule; Rust rejects unary `-` on unsigned integers.
///
/// NOTE(review): floats are indexed by `a.neg()` in `Strict`/`Lazy` but left unrefined in `None`.
/// Confirm this difference is intended.
fn mk_neg_rules(overflow_mode: OverflowMode) -> RuleMatcher<1> {
    match overflow_mode {
        OverflowMode::Strict => {
            primop_rules! {
                fn(a: T) -> T[a.neg()]
                requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
                if let &BaseTy::Int(int_ty) = T

                fn(a: T) -> T[a.neg()]
                if T.is_float()
            }
        }
        OverflowMode::Lazy | OverflowMode::StrictUnder => {
            primop_rules! {
                fn(a: T) -> T{v: E::implies(E::ne(a, E::int_min(int_ty)), E::eq(v, a.neg())) }
                if let &BaseTy::Int(int_ty) = T

                fn(a: T) -> T[a.neg()]
                if T.is_float()
            }
        }
        OverflowMode::None => {
            primop_rules! {
                fn(a: T) -> T[a.neg()]
                if T.is_integral()

                fn(a: T) -> T
                if T.is_float()
            }
        }
    }
}
567
/// Rules for unary `!a`.
///
/// For booleans the result is indexed by the logical negation of `a`. For integers (bitwise not)
/// the result is an unrefined `T`.
fn mk_not_rules() -> RuleMatcher<1> {
    primop_rules! {
        fn(a: bool) -> bool[a.not()]

        fn(a: T) -> T
        if T.is_integral()
    }
}