1use std::collections::HashMap;
2
3use itertools::Itertools;
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::{
7 Ident, Lifetime, Token, braced, bracketed, parenthesized,
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11 token,
12};
13
/// Evaluates a `Result` and yields its `Ok` value.
///
/// On `Err`, the error's `to_compile_error()` tokens are converted with `.into()` and returned
/// from the enclosing function. This lets a proc-macro entry point bail out with a compile error.
macro_rules! unwrap_result {
    ($e:expr) => {
        match $e {
            Ok(value) => value,
            Err(err) => return err.to_compile_error().into(),
        }
    };
}
22
23pub fn primop_rules(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24 let rules = parse_macro_input!(input as Rules);
25
26 let argc = unwrap_result!(rules.check_arg_count());
27
28 let rules = rules.0.into_iter().enumerate().map(|(i, rule)| {
29 Renderer::new(i, rule)
30 .render()
31 .unwrap_or_else(|err| err.to_compile_error())
32 });
33 let args = args(argc);
34 quote! {
35 #[allow(unused_variables, non_snake_case)]
36 |#args| {
37 #(#rules)*
38 None
39 }
40 }
41 .into()
42}
43
44fn args(n: usize) -> TokenStream {
45 let args = (0..n).map(|i| {
46 let bty = mk_bty_arg(i);
47 let idx = mk_idx_arg(i);
48 quote!((#bty, #idx))
49 });
50 quote!([#(#args),*])
51}
52
/// The whole macro input: the rules in the order they were written.
struct Rules(Vec<Rule>);
54
55impl Rules {
56 fn check_arg_count(&self) -> syn::Result<usize> {
58 let argc = self.0.first().map(|rule| rule.args.len()).unwrap_or(0);
59 for rule in &self.0 {
60 if rule.args.len() != argc {
61 return Err(syn::Error::new(
62 Span::call_site(),
63 "all rules must have the same number of arguments",
64 ));
65 }
66 }
67 Ok(argc)
68 }
69}
70
71impl Parse for Rules {
72 fn parse(input: ParseStream) -> syn::Result<Self> {
73 let mut v = vec![];
74 while !input.is_empty() {
75 v.push(input.parse()?);
76 }
77 Ok(Rules(v))
78 }
79}
80
/// Turns one parsed [`Rule`] into a labeled block inside the generated closure.
struct Renderer {
    /// Label of this rule's block; every failed check `break`s out of it.
    lbl: Lifetime,
    /// The rule being rendered.
    rule: Rule,
    /// Maps each non-primitive base-type name (a metavariable) to the positions of the
    /// arguments that use it.
    metavars: HashMap<String, Vec<usize>>,
}
87
88impl Renderer {
89 fn new(i: usize, rule: Rule) -> Self {
90 let mut metavars: HashMap<String, Vec<usize>> = HashMap::new();
91 for (i, input) in rule.args.iter().enumerate() {
92 let bty_str = input.bty.to_string();
93 if !is_primitive_type(&bty_str) {
94 metavars.entry(bty_str).or_default().push(i);
95 }
96 }
97
98 let lbl = syn::Lifetime::new(&format!("'lbl{}", i), Span::call_site());
99
100 Self { lbl, rule, metavars }
101 }
102
103 fn render(&self) -> syn::Result<TokenStream> {
104 let lbl = &self.lbl;
105 let metavar_matching = self.metavar_matching();
106 let primitive_checks = self.check_primitive_types();
107 let declare_metavars = self.declare_metavars();
108 let guards = self.guards();
109 let declare_idxs_names = self.declare_idxs_names();
110 let output_type = self.output_type()?;
111 let precondition = self.precondition();
112 Ok(quote! {
113 #lbl: {
114 #metavar_matching
115 #primitive_checks
116 #declare_metavars
117 #guards
118
119 #declare_idxs_names
120 let precondition = #precondition;
121 let v = Expr::nu();
122 let output_type = #output_type;
123 return Some(MatchedRule { precondition, output_type })
124 }
125 })
126 }
127
128 fn bty_arg_or_prim(&self, ident: &syn::Ident) -> syn::Result<TokenStream> {
129 let ident_str = ident.to_string();
130 if is_primitive_type(ident) {
131 Ok(quote!(BaseTy::from_primitive_str(#ident_str).unwrap()))
132 } else {
133 self.metavars
134 .get(&ident_str)
135 .map(|idxs| {
136 let arg = mk_bty_arg(idxs[0]);
137 quote!(#arg.clone())
138 })
139 .ok_or_else(|| {
140 syn::Error::new(ident.span(), format!("cannot find metavariable `{ident_str}`"))
141 })
142 }
143 }
144
145 fn output_type(&self) -> syn::Result<TokenStream> {
146 let out = match &self.rule.output {
147 Output::Base(bty) => {
148 let bty = self.bty_arg_or_prim(bty)?;
149 quote!(#bty.to_ty())
150 }
151 Output::Indexed(bty, idx) => {
152 let bty = self.bty_arg_or_prim(bty)?;
153 quote!(rty::Ty::indexed( #bty, #idx))
154 }
155 Output::Exists(bty, pred) => {
156 let bty = self.bty_arg_or_prim(bty)?;
157 quote!(rty::Ty::exists_with_constr( #bty, #pred))
158 }
159 };
160 Ok(out)
161 }
162
163 fn metavar_matching(&self) -> TokenStream {
165 let lbl = &self.lbl;
166 let checks = self.metavars.values().map(|idxs| {
167 let checks = idxs.iter().tuple_windows().map(|(i, j)| {
168 let bty_arg1 = mk_bty_arg(*i);
169 let bty_arg2 = mk_bty_arg(*j);
170 quote! {
171 if #bty_arg2 != #bty_arg1 {
172 break #lbl;
173 }
174 }
175 });
176 quote!(#(#checks)*)
177 });
178 quote!(#(#checks)*)
179 }
180
181 fn check_primitive_types(&self) -> TokenStream {
183 let lbl = &self.lbl;
184 self.rule
185 .args
186 .iter()
187 .enumerate()
188 .flat_map(|(i, arg)| {
189 let bty = &arg.bty;
190 if is_primitive_type(bty) {
191 let bty_str = bty.to_string();
192 let bty_arg = mk_bty_arg(i);
193 Some(quote! {
194 let Some(s) = #bty_arg.primitive_symbol() else {
195 break #lbl;
196 };
197 if s.as_str() != #bty_str {
198 break #lbl;
199 }
200 })
201 } else {
202 None
203 }
204 })
205 .collect()
206 }
207
208 fn precondition(&self) -> TokenStream {
209 if let Some(requires) = &self.rule.requires {
210 let reason = &requires.reason;
211 let pred = &requires.pred;
212 quote!(Some(Pre { reason: #reason, pred: #pred }))
213 } else {
214 quote!(None)
215 }
216 }
217
218 fn declare_metavars(&self) -> TokenStream {
220 self.metavars
221 .iter()
222 .map(|(var, matching_positions)| {
223 let var = syn::Ident::new(var, Span::call_site());
224 let bty_arg = mk_bty_arg(matching_positions[0]);
225 quote! {
226 let #var = #bty_arg;
227 }
228 })
229 .collect()
230 }
231
232 fn declare_idxs_names(&self) -> TokenStream {
233 self.rule
234 .args
235 .iter()
236 .enumerate()
237 .map(|(i, arg)| {
238 let name = &arg.name;
239 let idx_arg = mk_idx_arg(i);
240 quote!(let #name = #idx_arg;)
241 })
242 .collect()
243 }
244
245 fn guards(&self) -> TokenStream {
246 self.rule
247 .guards
248 .iter()
249 .map(|guard| self.guard(guard))
250 .collect()
251 }
252
253 fn guard(&self, guard: &Guard) -> TokenStream {
254 let lbl = &self.lbl;
255 match guard {
256 Guard::If(if_, expr) => quote! {#if_ !(#expr) { break #lbl; }},
257 Guard::IfLet(let_) => quote!(#let_ else { break #lbl; };),
258 Guard::Let(let_) => quote!(#let_;),
259 }
260 }
261}
262
/// One rule: `fn(name: bty, ...) -> output [requires pred => reason] guards*`.
struct Rule {
    /// Comma-separated `name: bty` arguments.
    args: Punctuated<Arg, Token![,]>,
    /// Type produced when the rule matches.
    output: Output,
    /// Optional precondition introduced by the `requires` keyword.
    requires: Option<Requires>,
    /// `if`, `if let`, and `let` guards following the signature, in source order.
    guards: Vec<Guard>,
}
269
270impl Parse for Rule {
271 fn parse(input: ParseStream) -> syn::Result<Self> {
272 let _: Token![fn] = input.parse()?;
273 let content;
274 parenthesized!(content in input);
275 let inputs = content.parse_terminated(Arg::parse, Token![,])?;
276 let _: Token![->] = input.parse()?;
277 let output = input.parse()?;
278 let requires = if input.peek(kw::requires) { Some(input.parse()?) } else { None };
279 let guards = parse_guards(input)?;
280 Ok(Rule { args: inputs, output, requires, guards })
281 }
282}
283
/// A rule argument written `name: bty`.
struct Arg {
    /// Name bound to the matching `idxN` closure argument in the generated code.
    name: syn::Ident,
    /// Base type: either a primitive type name or a metavariable.
    bty: syn::Ident,
}
289
290impl Parse for Arg {
291 fn parse(input: ParseStream) -> syn::Result<Self> {
292 let name = input.parse()?;
293 let _: Token![:] = input.parse()?;
294 let bty = input.parse()?;
295 Ok(Arg { name, bty })
296 }
297}
298
/// The output type of a rule.
enum Output {
    /// `bty`: the base type converted with `to_ty()`.
    Base(syn::Ident),
    /// `bty[idx]`: the base type and index tokens passed to `rty::Ty::indexed`.
    Indexed(syn::Ident, TokenStream),
    /// `bty{x: pred}`: the base type and predicate passed to `rty::Ty::exists_with_constr`.
    /// The binder name `x` is discarded by the parser. The generated code always binds the
    /// existential variable as `v`, so the predicate must refer to it as `v`.
    Exists(syn::Ident, TokenStream),
}
304
305impl Parse for Output {
306 fn parse(input: ParseStream) -> syn::Result<Self> {
307 let bty: syn::Ident = input.parse()?;
308 if input.peek(token::Bracket) {
309 let content;
310 bracketed!(content in input);
311 Ok(Output::Indexed(bty, content.parse()?))
312 } else if input.peek(token::Brace) {
313 let content;
314 braced!(content in input);
315 let _: syn::Ident = content.parse()?;
316 let _: Token![:] = content.parse()?;
317 Ok(Output::Exists(bty, content.parse()?))
318 } else {
319 Ok(Output::Base(bty))
320 }
321 }
322}
323
/// A `requires pred => reason` precondition attached to a rule.
struct Requires {
    /// Predicate emitted as the `pred` field of the generated `Pre`.
    pred: syn::Expr,
    /// Path emitted as the `reason` field of the generated `Pre`.
    reason: syn::Path,
}
328
329impl Parse for Requires {
330 fn parse(input: ParseStream) -> syn::Result<Self> {
331 let _: kw::requires = input.parse()?;
332 let pred = input.parse()?;
333 let _: Token![=>] = input.parse()?;
334 let reason = input.parse()?;
335 Ok(Requires { pred, reason })
336 }
337}
338
339fn parse_guards(input: ParseStream) -> syn::Result<Vec<Guard>> {
340 let mut guards = vec![];
341 while !input.is_empty() && (input.peek(Token![let]) || input.peek(Token![if])) {
342 guards.push(input.parse()?);
343 }
344 Ok(guards)
345}
/// A guard that must hold for a rule to apply.
enum Guard {
    /// `if expr`: the rule is skipped when `expr` is false.
    If(Token![if], syn::Expr),
    /// `if let pat = expr`: rendered as `let ... else`, skipping the rule when the pattern
    /// does not match.
    IfLet(syn::ExprLet),
    /// `let pat = expr`: rendered as a plain `let` binding.
    Let(syn::ExprLet),
}
351
352impl Parse for Guard {
353 fn parse(input: ParseStream) -> syn::Result<Self> {
354 let lookahead = input.lookahead1();
355 if lookahead.peek(Token![if]) {
356 let if_ = input.parse()?;
357 if input.peek(Token![let]) {
358 Ok(Guard::IfLet(input.parse()?))
359 } else {
360 Ok(Guard::If(if_, input.parse()?))
361 }
362 } else if lookahead.peek(Token![let]) {
363 Ok(Guard::Let(input.parse()?))
364 } else {
365 Err(lookahead.error())
366 }
367 }
368}
369
370fn mk_idx_arg(i: usize) -> Ident {
371 Ident::new(&format!("idx{}", i), Span::call_site())
372}
373
374fn mk_bty_arg(i: usize) -> Ident {
375 Ident::new(&format!("bty{}", i), Span::call_site())
376}
377
/// Returns `true` if `s` names one of Rust's built-in primitive types.
///
/// It is generic over anything comparable with `str`, so it accepts both the `String` and the
/// `syn::Ident` forms of a type name.
fn is_primitive_type<T>(s: &T) -> bool
where
    T: PartialEq<str>,
{
    const PRIMITIVES: [&str; 17] = [
        "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64",
        "isize", "usize", "bool", "char", "str",
    ];
    PRIMITIVES.iter().any(|prim| s == *prim)
}
400
401mod kw {
402 syn::custom_keyword!(requires);
403}