1use std::{hash::Hash, sync::LazyLock};
27
28use flux_common::tracked_span_bug;
29use flux_infer::infer::ConstrReason;
30use flux_macros::primop_rules;
31use flux_middle::rty::{self, BaseTy, Expr, Sort};
32use flux_rustc_bridge::mir;
33use rty::{BinOp::Mod, Expr as E};
34use rustc_data_structures::unord::UnordMap;
35
/// The outcome of matching a primitive operation against its rule table.
pub(crate) struct MatchedRule {
    /// Constraint the operands must satisfy for the operation to be accepted;
    /// `None` when the matched rule imposes no precondition.
    pub precondition: Option<Pre>,
    /// Type of the operation's result as given by the matched rule.
    pub output_type: rty::Ty,
}
40
/// A precondition attached to a matched primop rule.
pub(crate) struct Pre {
    /// Why the constraint exists (e.g. `ConstrReason::Overflow`, `Div`, `Rem`);
    /// presumably used to explain a failed check to the user.
    pub reason: ConstrReason,
    /// The predicate that must hold.
    pub pred: Expr,
}
45
46pub(crate) fn match_bin_op(
47 op: mir::BinOp,
48 bty1: &BaseTy,
49 idx1: &Expr,
50 bty2: &BaseTy,
51 idx2: &Expr,
52 check_overflow: bool,
53) -> MatchedRule {
54 let table = if check_overflow { &OVERFLOW_BIN_OPS } else { &DEFAULT_BIN_OPS };
55 table.match_inputs(&op, [(bty1.clone(), idx1.clone()), (bty2.clone(), idx2.clone())])
56}
57
58pub(crate) fn match_un_op(
59 op: mir::UnOp,
60 bty: &BaseTy,
61 idx: &Expr,
62 check_overflow: bool,
63) -> MatchedRule {
64 let table = if check_overflow { &OVERFLOW_UN_OPS } else { &DEFAULT_UN_OPS };
65 table.match_inputs(&op, [(bty.clone(), idx.clone())])
66}
67
/// Associates each operator with the matcher for its `N`-operand primop rules.
struct RuleTable<Op: Eq + Hash, const N: usize> {
    rules: UnordMap<Op, RuleMatcher<N>>,
}
71
72impl<Op: Eq + Hash, const N: usize> RuleTable<Op, N> {
73 fn match_inputs(&self, op: &Op, inputs: [(BaseTy, Expr); N]) -> MatchedRule {
74 (self.rules[op])(&inputs)
75 .unwrap_or_else(|| tracked_span_bug!("no primop rule for {inputs:?}"))
76 }
77}
78
/// Matcher over `N` operands, each a `(base type, index)` pair; returns `None` when no
/// rule applies to the given operands.
type RuleMatcher<const N: usize> = fn(&[(BaseTy, Expr); N]) -> Option<MatchedRule>;
80
81static DEFAULT_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
82 use mir::BinOp::*;
83 RuleTable {
84 rules: [
85 (Add, mk_add_rules(false)),
87 (Mul, mk_mul_rules(false)),
88 (Sub, mk_sub_rules(false)),
89 (Div, mk_div_rules()),
90 (Rem, mk_rem_rules()),
91 (BitAnd, mk_bit_and_rules()),
93 (BitOr, mk_bit_or_rules()),
94 (BitXor, mk_bit_xor_rules()),
95 (Eq, mk_eq_rules()),
97 (Ne, mk_ne_rules()),
98 (Le, mk_le_rules()),
99 (Ge, mk_ge_rules()),
100 (Lt, mk_lt_rules()),
101 (Gt, mk_gt_rules()),
102 (Shl, mk_shl_rules()),
104 (Shr, mk_shr_rules()),
105 ]
106 .into_iter()
107 .collect(),
108 }
109});
110
111static OVERFLOW_BIN_OPS: LazyLock<RuleTable<mir::BinOp, 2>> = LazyLock::new(|| {
112 use mir::BinOp::*;
113 RuleTable {
114 rules: [
115 (Add, mk_add_rules(true)),
117 (Mul, mk_mul_rules(true)),
118 (Sub, mk_sub_rules(true)),
119 (Div, mk_div_rules()),
120 (Rem, mk_rem_rules()),
121 (BitAnd, mk_bit_and_rules()),
123 (BitOr, mk_bit_or_rules()),
124 (BitXor, mk_bit_xor_rules()),
125 (Eq, mk_eq_rules()),
127 (Ne, mk_ne_rules()),
128 (Le, mk_le_rules()),
129 (Ge, mk_ge_rules()),
130 (Lt, mk_lt_rules()),
131 (Gt, mk_gt_rules()),
132 (Shl, mk_shl_rules()),
134 (Shr, mk_shr_rules()),
135 ]
136 .into_iter()
137 .collect(),
138 }
139});
140
141static DEFAULT_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
142 use mir::UnOp::*;
143 RuleTable {
144 rules: [(Neg, mk_neg_rules(false)), (Not, mk_not_rules())]
145 .into_iter()
146 .collect(),
147 }
148});
149
150static OVERFLOW_UN_OPS: LazyLock<RuleTable<mir::UnOp, 1>> = LazyLock::new(|| {
151 use mir::UnOp::*;
152 RuleTable {
153 rules: [(Neg, mk_neg_rules(true)), (Not, mk_not_rules())]
154 .into_iter()
155 .collect(),
156 }
157});
158
159fn mk_add_rules(check_overflow: bool) -> RuleMatcher<2> {
161 if check_overflow {
162 primop_rules! {
163 fn(a: T, b: T) -> T[a + b]
164 requires E::and(
165 E::ge(a + b, E::int_min(int_ty)),
166 E::le(a + b, E::int_max(int_ty)),
167 ) => ConstrReason::Overflow
168 if let &BaseTy::Int(int_ty) = T
169
170 fn(a: T, b: T) -> T[a + b]
171 requires E::le(a + b, E::uint_max(uint_ty)) => ConstrReason::Overflow
172 if let &BaseTy::Uint(uint_ty) = T
173
174 fn(a: T, b: T) -> T
175 }
176 } else {
177 primop_rules! {
178 fn(a: T, b: T) -> T[a + b]
179 if T.is_integral()
180
181 fn(a: T, b: T) -> T
182 }
183 }
184}
185
186fn mk_mul_rules(check_overflow: bool) -> RuleMatcher<2> {
188 if check_overflow {
189 primop_rules! {
190 fn(a: T, b: T) -> T[a * b]
191 requires E::and(
192 E::ge(a * b, E::int_min(int_ty)),
193 E::le(a * b, E::int_max(int_ty)),
194 ) => ConstrReason::Overflow
195 if let &BaseTy::Int(int_ty) = T
196
197 fn(a: T, b: T) -> T[a * b]
198 requires E::le(a * b, E::uint_max(uint_ty)) => ConstrReason::Overflow
199 if let &BaseTy::Uint(uint_ty) = T
200
201 fn(a: T, b: T) -> T
202 }
203 } else {
204 primop_rules!(
205 fn(a: T, b: T) -> T[a * b]
206 if T.is_integral()
207
208 fn(a: T, b: T) -> T
209 if T.is_float()
210 )
211 }
212}
213
214fn mk_sub_rules(check_overflow: bool) -> RuleMatcher<2> {
216 if check_overflow {
217 primop_rules! {
218 fn(a: T, b: T) -> T[a - b]
219 requires E::and(
220 E::ge(a - b, E::int_min(int_ty)),
221 E::le(a - b, E::int_max(int_ty)),
222 ) => ConstrReason::Overflow
223 if let &BaseTy::Int(int_ty) = T
224
225 fn(a: T, b: T) -> T[a - b]
226 requires E::and(
227 E::ge(a - b, 0),
228 E::le(a - b, E::uint_max(uint_ty)),
229 ) => ConstrReason::Overflow
230 if let &BaseTy::Uint(uint_ty) = T
231
232 fn(a: T, b: T) -> T
233 }
234 } else {
235 primop_rules! {
236 fn(a: T, b: T) -> T[a - b]
237 requires E::ge(a - b, 0) => ConstrReason::Overflow
238 if T.is_unsigned()
239
240 fn(a: T, b: T) -> T[a - b]
241 if T.is_signed()
242
243 fn(a: T, b: T) -> T
244 if T.is_float()
245 }
246 }
247}
248
/// Rules for `a / b`.
///
/// Integer division requires a non-zero divisor (reported as `ConstrReason::Div`)
/// and is indexed by `a / b`; float division is unrefined.
fn mk_div_rules() -> RuleMatcher<2> {
    primop_rules! {
        // NOTE(review): unlike `mk_rem_rules`, signed operands are indexed by `a/b`
        // without a non-negativity guard; confirm the encoding of `/` matches Rust's
        // truncating division for negative operands.
        fn(a: T, b: T) -> T[a/b]
        requires E::ne(b, 0) => ConstrReason::Div
        if T.is_integral()

        fn(a: T, b: T) -> T
        if T.is_float()
    }
}
260
261fn mk_rem_rules() -> RuleMatcher<2> {
263 primop_rules! {
264 fn(a: T, b: T) -> T[E::binary_op(Mod(Sort::Int), a, b)]
265 requires E::ne(b, 0) => ConstrReason::Rem
266 if T.is_unsigned()
267
268 fn(a: T, b: T) -> T{v: E::implies(
269 E::and(E::ge(a, 0), E::ge(b, 0)),
270 E::eq(v, E::binary_op(Mod(Sort::Int), a, b))) }
271 requires E::ne(b, 0) => ConstrReason::Rem
272 if T.is_signed()
273 }
274}
275
/// Rules for `a & b`: unrefined on integers, logical conjunction on `bool`.
fn mk_bit_and_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> T
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::and(a, b)]
    }
}
285
/// Rules for `a | b`: unrefined on integers, logical disjunction on `bool`.
fn mk_bit_or_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> T
        if T.is_integral()

        fn(a: bool, b: bool) -> bool[E::or(a, b)]
    }
}
295
296fn mk_bit_xor_rules() -> RuleMatcher<2> {
298 primop_rules! {
299 fn(a: T, b: T) -> T
300 if T.is_integral()
301 }
302}
303
/// Rules for `a == b`.
///
/// Integers, bools, chars and strs are indexed by `a == b`. Any other type yields an
/// unrefined `bool`.
fn mk_eq_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::eq(a, b)]
        if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()

        fn(a: T, b: T) -> bool
    }
}
313
314fn mk_ne_rules() -> RuleMatcher<2> {
316 primop_rules! {
317 fn(a: T, b: T) -> bool[E::ne(a, b)]
318 if T.is_integral() || T.is_bool()
319
320 fn(a: T, b: T) -> bool
321 }
322}
323
/// Rules for `a <= b`.
fn mk_le_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::le(a, b)]
        if T.is_integral()

        // With `false < true`, `a <= b` on bools is `!a || b`, i.e. `a => b`.
        fn(a: bool, b: bool) -> bool[E::implies(a, b)]

        fn(a: T, b: T) -> bool
    }
}
335
/// Rules for `a >= b`.
fn mk_ge_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::ge(a, b)]
        if T.is_integral()

        // With `false < true`, `a >= b` on bools is `a || !b`, i.e. `b => a`.
        fn(a: bool, b: bool) -> bool[E::implies(b, a)]

        fn(a: T, b: T) -> bool
    }
}
347
/// Rules for `a < b`.
fn mk_lt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::lt(a, b)]
        if T.is_integral()

        // With `false < true`, `a < b` on bools holds only for `a = false, b = true`.
        fn(a: bool, b: bool) -> bool[E::and(a.not(), b)]

        fn(a: T, b: T) -> bool
    }
}
359
/// Rules for `a > b`.
fn mk_gt_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: T) -> bool[E::gt(a, b)]
        if T.is_integral()

        // With `false < true`, `a > b` on bools holds only for `a = true, b = false`.
        fn(a: bool, b: bool) -> bool[E::and(a, b.not())]

        fn(a: T, b: T) -> bool
    }
}
371
/// Rules for `a << b`.
///
/// The shift amount may have a different integral type (`S`) than the value being
/// shifted (`T`). The result has type `T` and is unrefined.
fn mk_shl_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> T
        if T.is_integral() && S.is_integral()
    }
}
379
/// Rules for `a >> b`.
///
/// The shift amount may have a different integral type (`S`) than the value being
/// shifted (`T`). The result has type `T` and is unrefined.
fn mk_shr_rules() -> RuleMatcher<2> {
    primop_rules! {
        fn(a: T, b: S) -> T
        if T.is_integral() && S.is_integral()
    }
}
387
388fn mk_neg_rules(check_overflow: bool) -> RuleMatcher<1> {
390 if check_overflow {
391 primop_rules! {
392 fn(a: T) -> T[a.neg()]
393 requires E::ne(a, E::int_min(int_ty)) => ConstrReason::Overflow
394 if let &BaseTy::Int(int_ty) = T
395
396 fn(a: T) -> T[a.neg()]
397 if T.is_float()
398 }
399 } else {
400 primop_rules! {
401 fn(a: T) -> T[a.neg()]
402 if T.is_integral()
403
404 fn(a: T) -> T
405 if T.is_float()
406 }
407 }
408}
409
/// Rules for `!a`: logical negation on `bool` (indexed by `a.not()`); bitwise
/// complement on integers (unrefined).
fn mk_not_rules() -> RuleMatcher<1> {
    primop_rules! {
        fn(a: bool) -> bool[a.not()]

        fn(a: T) -> T
        if T.is_integral()
    }
}