1use std::{fmt, hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_infer::infer::ConstrReason;
30use flux_macros::primop_rules;
31use flux_middle::rty::{self, BaseTy, Expr, Sort};
32use flux_rustc_bridge::mir;
33use rty::{BinOp::Mod, Expr as E};
34use rustc_data_structures::unord::UnordMap;
35
36#[derive(Debug)]
37pub(crate) struct MatchedRule {
38 pub precondition: Option<Pre>,
39 pub output_type: rty::Ty,
40}
41
42#[derive(Debug)]
43pub(crate) struct Pre {
44 pub reason: ConstrReason,
45 pub pred: Expr,
46}
47
48pub(crate) fn match_bin_op(
49 op: mir::BinOp,
50 bty1: &BaseTy,
51 idx1: &Expr,
52 bty2: &BaseTy,
53 idx2: &Expr,
54 check_overflow: bool,
55) -> MatchedRule {
56 let table = if check_overflow { &OVERFLOW_BIN_OPS } else { &DEFAULT_BIN_OPS };
57 table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
58}
59
60pub(crate) fn match_un_op(
61 op: mir::UnOp,
62 bty: &BaseTy,
63 idx: &Expr,
64 check_overflow: bool,
65) -> MatchedRule {
66 let table = if check_overflow { &OVERFLOW_UN_OPS } else { &DEFAULT_UN_OPS };
67 table.match_inputs(&op, [(bty.clone(), idx.clone())])
68}
69
70struct RuleTable<Op: Eq + Hash, const N: usize> {
71 rules: UnordMap<Op, RuleMatcher<N>>,
72}
73
74impl<Op: Eq + Hash + fmt::Debug, const N: usize> RuleTable<Op, N> {
75 fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
76 (self.rules[op])(&inputs)
77 .unwrap_or_else(|| tracked_span_bug!("no primop rule for {op:?} using {inputs:?}"))
78 }
79}
80
81type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
82
83static DEFAULT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
84 use mir::BinOp::*;
85 RuleTable {
86 rules: [
87 (Add, mk_add_rules(false)),
89 (Mul, mk_mul_rules(false)),
90 (Sub, mk_sub_rules(false)),
91 (Div, mk_div_rules()),
92 (Rem, mk_rem_rules()),
93 (BitAnd, mk_bit_and_rules()),
95 (BitOr, mk_bit_or_rules()),
96 (BitXor, mk_bit_xor_rules()),
97 (Eq, mk_eq_rules()),
99 (Ne, mk_ne_rules()),
100 (Le, mk_le_rules()),
101 (Ge, mk_ge_rules()),
102 (Lt, mk_lt_rules()),
103 (Gt, mk_gt_rules()),
104 (Shl, mk_shl_rules()),
106 (Shr, mk_shr_rules()),
107 ]
108 .into_iter()
109 .collect(),
110 }
111});
112
113static OVERFLOW_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
114 use mir::BinOp::*;
115 RuleTable {
116 rules: [
117 (Add, mk_add_rules(true)),
119 (Mul, mk_mul_rules(true)),
120 (Sub, mk_sub_rules(true)),
121 (Div, mk_div_rules()),
122 (Rem, mk_rem_rules()),
123 (BitAnd, mk_bit_and_rules()),
125 (BitOr, mk_bit_or_rules()),
126 (BitXor, mk_bit_xor_rules()),
127 (Eq, mk_eq_rules()),
129 (Ne, mk_ne_rules()),
130 (Le, mk_le_rules()),
131 (Ge, mk_ge_rules()),
132 (Lt, mk_lt_rules()),
133 (Gt, mk_gt_rules()),
134 (Shl, mk_shl_rules()),
136 (Shr, mk_shr_rules()),
137 ]
138 .into_iter()
139 .collect(),
140 }
141});
142
143static DEFAULT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
144 use mir::UnOp::*;
145 RuleTable {
146 rules: [(Neg, mk_neg_rules(false)), (Not, mk_not_rules())]
147 .into_iter()
148 .collect(),
149 }
150});
151
152static OVERFLOW_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
153 use mir::UnOp::*;
154 RuleTable {
155 rules: [(Neg, mk_neg_rules(true)), (Not, mk_not_rules())]
156 .into_iter()
157 .collect(),
158 }
159});
160
161fn mk_add_rules(check_overflow: bool) -> RuleMatcher<2> {
163 if check_overflow {
164 primop_rules! {
165 fn(a: T, b: T) -> T[a + b]
166 requires E::and(
167 E::ge(a + b, E::int_min(int_ty)),
168 E::le(a + b, E::int_max(int_ty)),
169 ) => ConstrReason::Overflow
170 if let &BaseTy::Int(int_ty) = T
171
172 fn(a: T, b: T) -> T[a + b]
173 requires E::le(a + b, E::uint_max(uint_ty)) => ConstrReason::Overflow
174 if let &BaseTy::Uint(uint_ty) = T
175
176 fn(a: T, b: T) -> T
177 }
178 } else {
179 primop_rules! {
180 fn(a: T, b: T) -> T[a + b]
181 if T.is_integral()
182
183 fn(a: T, b: T) -> T
184 }
185 }
186}
187
188fn mk_mul_rules(check_overflow: bool) -> RuleMatcher<2> {
190 if check_overflow {
191 primop_rules! {
192 fn(a: T, b: T) -> T[a * b]
193 requires E::and(
194 E::ge(a * b, E::int_min(int_ty)),
195 E::le(a * b, E::int_max(int_ty)),
196 ) => ConstrReason::Overflow
197 if let &BaseTy::Int(int_ty) = T
198
199 fn(a: T, b: T) -> T[a * b]
200 requires E::le(a * b, E::uint_max(uint_ty)) => ConstrReason::Overflow
201 if let &BaseTy::Uint(uint_ty) = T
202
203 fn(a: T, b: T) -> T
204 }
205 } else {
206 primop_rules!(
207 fn(a: T, b: T) -> T[a * b]
208 if T.is_integral()
209
210 fn(a: T, b: T) -> T
211 if T.is_float()
212 )
213 }
214}
215
216fn mk_sub_rules(check_overflow: bool) -> RuleMatcher<2> {
218 if check_overflow {
219 primop_rules! {
220 fn(a: T, b: T) -> T[a - b]
221 requires E::and(
222 E::ge(a - b, E::int_min(int_ty)),
223 E::le(a - b, E::int_max(int_ty)),
224 ) => ConstrReason::Overflow
225 if let &BaseTy::Int(int_ty) = T
226
227 fn(a: T, b: T) -> T[a - b]
228 requires E::and(
229 E::ge(a - b, 0),
230 E::le(a - b, E::uint_max(uint_ty)),
231 ) => ConstrReason::Overflow
232 if let &BaseTy::Uint(uint_ty) = T
233
234 fn(a: T, b: T) -> T
235 }
236 } else {
237 primop_rules! {
238 fn(a: T, b: T) -> T[a - b]
239 requires E::ge(a - b, 0) => ConstrReason::Overflow
240 if T.is_unsigned()
241
242 fn(a: T, b: T) -> T[a - b]
243 if T.is_signed()
244
245 fn(a: T, b: T) -> T
246 if T.is_float()
247 }
248 }
249}
250
251fn mk_div_rules() -> RuleMatcher<2> {
253 primop_rules! {
254 fn(a: T, b: T) -> T[a/b]
255 requires E::ne(b, 0) => ConstrReason::Div
256 if T.is_integral()
257
258 fn(a: T, b: T) -> T
259 if T.is_float()
260 }
261}
262
263fn mk_rem_rules() -> RuleMatcher<2> {
265 primop_rules! {
266 fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
267 requires E::ne(b, 0) => ConstrReason::Rem
268 if T.is_unsigned()
269
270 fn(a: T, b: T) -> T{v: E::implies(
271 E::and(E::ge(a, 0), E::ge(b, 0)),
272 E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
273 requires E::ne(b, 0) => ConstrReason::Rem
274 if T.is_signed()
275
276 fn (a: T, b: T) -> T
277 if T.is_float()
278 }
279}
280
281fn mk_bit_and_rules() -> RuleMatcher<2> {
283 primop_rules! {
284 fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitAnd, a, b)] | E::prim_rel(rty::BinOp::BitAnd, a, b) }
285 if T.is_integral()
286
287 fn(a: bool, b: bool) -> bool[E::and(a, b)]
288 }
289}
290
291fn mk_bit_or_rules() -> RuleMatcher<2> {
293 primop_rules! {
294 fn(a: T, b: T) -> { T[E::prim_val(rty::BinOp::BitOr, a, b)] | E::prim_rel(rty::BinOp::BitOr, a, b) }
295 if T.is_integral()
296
297 fn(a: bool, b: bool) -> bool[E::or(a, b)]
298 }
299}
300
301fn mk_bit_xor_rules() -> RuleMatcher<2> {
303 primop_rules! {
304 fn(a: T, b: T) -> T
305 if T.is_integral()
306 }
307}
308
309fn mk_eq_rules() -> RuleMatcher<2> {
311 primop_rules! {
312 fn(a: T, b: T) -> bool[E::eq(a, b)]
313 if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
314
315 fn(a: T, b: T) -> bool
316 }
317}
318
319fn mk_ne_rules() -> RuleMatcher<2> {
321 primop_rules! {
322 fn(a: T, b: T) -> bool[E::ne(a, b)]
323 if T.is_integral() || T.is_bool()
324
325 fn(a: T, b: T) -> bool
326 }
327}
328
329fn mk_le_rules() -> RuleMatcher<2> {
331 primop_rules! {
332 fn(a: T, b: T) -> bool[E::le(a, b)]
333 if T.is_integral()
334
335 fn(a: bool, b: bool) -> bool[E::implies(a, b)]
336
337 fn(a: T, b: T) -> bool
338 }
339}
340
341fn mk_ge_rules() -> RuleMatcher<2> {
343 primop_rules! {
344 fn(a: T, b: T) -> bool[E::ge(a, b)]
345 if T.is_integral()
346
347 fn(a: bool, b: bool) -> bool[E::implies(b, a)]
348
349 fn(a: T, b: T) -> bool
350 }
351}
352
353fn mk_lt_rules() -> RuleMatcher<2> {
355 primop_rules! {
356 fn(a: T, b: T) -> bool[E::lt(a, b)]
357 if T.is_integral()
358
359 fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]
360
361 fn(a: T, b: T) -> bool
362 }
363}
364
365fn mk_gt_rules() -> RuleMatcher<2> {
367 primop_rules! {
368 fn(a: T, b: T) -> bool[E::gt(a, b)]
369 if T.is_integral()
370
371 fn(a: bool, b: bool) -> bool[E::and(a, b.not())]
372
373 fn(a: T, b: T) -> bool
374 }
375}
376
377fn mk_shl_rules() -> RuleMatcher<2> {
379 primop_rules! {
380 fn(a: T, b: S) -> { T[E::prim_val(rty::BinOp::BitShl, a, b)] | E::prim_rel(rty::BinOp::BitShl, a, b) }
381 if T.is_integral() && S.is_integral()
382 }
383}
384
385fn mk_shr_rules() -> RuleMatcher<2> {
387 primop_rules! {
388 fn(a: T, b: S) -> { T[E::prim_val(rty::BinOp::BitShr, a, b)] | E::prim_rel(rty::BinOp::BitShr, a, b) }
389 if T.is_integral() && S.is_integral()
390 }
391}
392
393fn mk_neg_rules(check_overflow: bool) -> RuleMatcher<1> {
395 if check_overflow {
396 primop_rules! {
397 fn(a: T) -> T[a.neg()]
398 requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
399 if let &BaseTy::Int(int_ty) = T
400
401 fn(a: T) -> T[a.neg()]
402 if T.is_float()
403 }
404 } else {
405 primop_rules! {
406 fn(a: T) -> T[a.neg()]
407 if T.is_integral()
408
409 fn(a: T) -> T
410 if T.is_float()
411 }
412 }
413}
414
415fn mk_not_rules() -> RuleMatcher<1> {
417 primop_rules! {
418 fn(a: bool) -> bool[a.not()]
419
420 fn(a: T) -> T
421 if T.is_integral()
422 }
423}