flux_macros/
primops.rs

1use std::collections::HashMap;
2
3use itertools::Itertools;
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::{
7    Ident, Lifetime, Token, braced, bracketed, parenthesized,
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    token,
12};
13
/// Unwraps a `syn::Result`-like value inside a proc-macro entry point.
///
/// On `Ok(value)` the macro evaluates to `value`. On `Err(err)` it returns early from the
/// enclosing function with `err.to_compile_error().into()`, so the error surfaces as a
/// compile error at the macro's use site.
macro_rules! unwrap_result {
    ($result:expr) => {{
        match $result {
            Ok(value) => value,
            Err(err) => return err.to_compile_error().into(),
        }
    }};
}
22
23pub fn primop_rules(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    let rules = parse_macro_input!(input as Rules);
25
26    let argc = unwrap_result!(rules.check_arg_count());
27
28    let rules = rules.0.into_iter().enumerate().map(|(i, rule)| {
29        Renderer::new(i, rule)
30            .render()
31            .unwrap_or_else(|err| err.to_compile_error())
32    });
33    let args = args(argc);
34    quote! {
35        #[allow(unused_variables, non_snake_case)]
36        |#args| {
37            #(#rules)*
38            None
39        }
40    }
41    .into()
42}
43
44fn args(n: usize) -> TokenStream {
45    let args = (0..n).map(|i| {
46        let bty = mk_bty_arg(i);
47        let idx = mk_idx_arg(i);
48        quote!((#bty, #idx))
49    });
50    quote!([#(#args),*])
51}
52
53struct Rules(Vec<Rule>);
54
55impl Rules {
56    /// Check that the number of arguments is the same in all rules
57    fn check_arg_count(&self) -> syn::Result<usize> {
58        let argc = self.0.first().map(|rule| rule.args.len()).unwrap_or(0);
59        for rule in &self.0 {
60            if rule.args.len() != argc {
61                return Err(syn::Error::new(
62                    Span::call_site(),
63                    "all rules must have the same number of arguments",
64                ));
65            }
66        }
67        Ok(argc)
68    }
69}
70
71impl Parse for Rules {
72    fn parse(input: ParseStream) -> syn::Result<Self> {
73        let mut v = vec![];
74        while !input.is_empty() {
75            v.push(input.parse()?);
76        }
77        Ok(Rules(v))
78    }
79}
80
81struct Renderer {
82    lbl: Lifetime,
83    rule: Rule,
84    /// The set of metavars and the index of the inputs they match
85    metavars: HashMap<String, Vec<usize>>,
86}
87
88impl Renderer {
89    fn new(i: usize, rule: Rule) -> Self {
90        let mut metavars: HashMap<String, Vec<usize>> = HashMap::new();
91        for (i, input) in rule.args.iter().enumerate() {
92            let bty_str = input.bty.to_string();
93            if !is_primitive_type(&bty_str) {
94                metavars.entry(bty_str).or_default().push(i);
95            }
96        }
97
98        let lbl = syn::Lifetime::new(&format!("'lbl{}", i), Span::call_site());
99
100        Self { lbl, rule, metavars }
101    }
102
103    fn render(&self) -> syn::Result<TokenStream> {
104        let lbl = &self.lbl;
105        let metavar_matching = self.metavar_matching();
106        let primitive_checks = self.check_primitive_types();
107        let declare_metavars = self.declare_metavars();
108        let guards = self.guards();
109        let declare_idxs_names = self.declare_idxs_names();
110        let output_type = self.output_type()?;
111        let precondition = self.precondition();
112        Ok(quote! {
113            #lbl: {
114                #metavar_matching
115                #primitive_checks
116                #declare_metavars
117                #guards
118
119                #declare_idxs_names
120                let precondition = #precondition;
121                let v = Expr::nu();
122                let output_type = #output_type;
123                return Some(MatchedRule { precondition, output_type })
124            }
125        })
126    }
127
128    fn bty_arg_or_prim(&self, ident: &syn::Ident) -> syn::Result<TokenStream> {
129        let ident_str = ident.to_string();
130        if is_primitive_type(ident) {
131            Ok(quote!(BaseTy::from_primitive_str(#ident_str).unwrap()))
132        } else {
133            self.metavars
134                .get(&ident_str)
135                .map(|idxs| {
136                    let arg = mk_bty_arg(idxs[0]);
137                    quote!(#arg.clone())
138                })
139                .ok_or_else(|| {
140                    syn::Error::new(ident.span(), format!("cannot find metavariable `{ident_str}`"))
141                })
142        }
143    }
144
145    fn output_type(&self) -> syn::Result<TokenStream> {
146        let out = match &self.rule.output {
147            Output::Base(bty) => {
148                let bty = self.bty_arg_or_prim(bty)?;
149                quote!(#bty.to_ty())
150            }
151            Output::Indexed(bty, idx) => {
152                let bty = self.bty_arg_or_prim(bty)?;
153                quote!(rty::Ty::indexed( #bty, #idx))
154            }
155            Output::Exists(bty, pred) => {
156                let bty = self.bty_arg_or_prim(bty)?;
157                quote!(rty::Ty::exists_with_constr( #bty, #pred))
158            }
159        };
160        Ok(out)
161    }
162
163    /// Generates the code that checks that all the inputs matching the same metavariable are equal
164    fn metavar_matching(&self) -> TokenStream {
165        let lbl = &self.lbl;
166        let checks = self.metavars.values().map(|idxs| {
167            let checks = idxs.iter().tuple_windows().map(|(i, j)| {
168                let bty_arg1 = mk_bty_arg(*i);
169                let bty_arg2 = mk_bty_arg(*j);
170                quote! {
171                    if #bty_arg2 != #bty_arg1 {
172                        break #lbl;
173                    }
174                }
175            });
176            quote!(#(#checks)*)
177        });
178        quote!(#(#checks)*)
179    }
180
181    /// Generates the code that checks if an arg matching a primitive type has indeed that type
182    fn check_primitive_types(&self) -> TokenStream {
183        let lbl = &self.lbl;
184        self.rule
185            .args
186            .iter()
187            .enumerate()
188            .flat_map(|(i, arg)| {
189                let bty = &arg.bty;
190                if is_primitive_type(bty) {
191                    let bty_str = bty.to_string();
192                    let bty_arg = mk_bty_arg(i);
193                    Some(quote! {
194                        let Some(s) = #bty_arg.primitive_symbol() else {
195                            break #lbl;
196                        };
197                        if s.as_str() != #bty_str {
198                            break #lbl;
199                        }
200                    })
201                } else {
202                    None
203                }
204            })
205            .collect()
206    }
207
208    fn precondition(&self) -> TokenStream {
209        if let Some(requires) = &self.rule.requires {
210            let reason = &requires.reason;
211            let pred = &requires.pred;
212            quote!(Some(Pre { reason: #reason, pred: #pred }))
213        } else {
214            quote!(None)
215        }
216    }
217
218    /// Declare metavars as variables so they can be accessed in the guards
219    fn declare_metavars(&self) -> TokenStream {
220        self.metavars
221            .iter()
222            .map(|(var, matching_positions)| {
223                let var = syn::Ident::new(var, Span::call_site());
224                let bty_arg = mk_bty_arg(matching_positions[0]);
225                quote! {
226                    let #var = #bty_arg;
227                }
228            })
229            .collect()
230    }
231
232    fn declare_idxs_names(&self) -> TokenStream {
233        self.rule
234            .args
235            .iter()
236            .enumerate()
237            .map(|(i, arg)| {
238                let name = &arg.name;
239                let idx_arg = mk_idx_arg(i);
240                quote!(let #name = #idx_arg;)
241            })
242            .collect()
243    }
244
245    fn guards(&self) -> TokenStream {
246        self.rule
247            .guards
248            .iter()
249            .map(|guard| self.guard(guard))
250            .collect()
251    }
252
253    fn guard(&self, guard: &Guard) -> TokenStream {
254        let lbl = &self.lbl;
255        match guard {
256            Guard::If(if_, expr) => quote! {#if_ !(#expr) { break #lbl; }},
257            Guard::IfLet(let_) => quote!(#let_ else { break #lbl; };),
258            Guard::Let(let_) => quote!(#let_;),
259        }
260    }
261}
262
263struct Rule {
264    args: Punctuated<Arg, Token![,]>,
265    output: Output,
266    requires: Option<Requires>,
267    guards: Vec<Guard>,
268}
269
270impl Parse for Rule {
271    fn parse(input: ParseStream) -> syn::Result<Self> {
272        let _: Token![fn] = input.parse()?;
273        let content;
274        parenthesized!(content in input);
275        let inputs = content.parse_terminated(Arg::parse, Token![,])?;
276        let _: Token![->] = input.parse()?;
277        let output = input.parse()?;
278        let requires = if input.peek(kw::requires) { Some(input.parse()?) } else { None };
279        let guards = parse_guards(input)?;
280        Ok(Rule { args: inputs, output, requires, guards })
281    }
282}
283
284/// An arg of the form `a: T`
285struct Arg {
286    name: syn::Ident,
287    bty: syn::Ident,
288}
289
290impl Parse for Arg {
291    fn parse(input: ParseStream) -> syn::Result<Self> {
292        let name = input.parse()?;
293        let _: Token![:] = input.parse()?;
294        let bty = input.parse()?;
295        Ok(Arg { name, bty })
296    }
297}
298
299enum Output {
300    Base(syn::Ident),
301    Indexed(syn::Ident, TokenStream),
302    Exists(syn::Ident, TokenStream),
303}
304
305impl Parse for Output {
306    fn parse(input: ParseStream) -> syn::Result<Self> {
307        let bty: syn::Ident = input.parse()?;
308        if input.peek(token::Bracket) {
309            let content;
310            bracketed!(content in input);
311            Ok(Output::Indexed(bty, content.parse()?))
312        } else if input.peek(token::Brace) {
313            let content;
314            braced!(content in input);
315            let _: syn::Ident = content.parse()?;
316            let _: Token![:] = content.parse()?;
317            Ok(Output::Exists(bty, content.parse()?))
318        } else {
319            Ok(Output::Base(bty))
320        }
321    }
322}
323
324struct Requires {
325    pred: syn::Expr,
326    reason: syn::Path,
327}
328
329impl Parse for Requires {
330    fn parse(input: ParseStream) -> syn::Result<Self> {
331        let _: kw::requires = input.parse()?;
332        let pred = input.parse()?;
333        let _: Token![=>] = input.parse()?;
334        let reason = input.parse()?;
335        Ok(Requires { pred, reason })
336    }
337}
338
339fn parse_guards(input: ParseStream) -> syn::Result<Vec<Guard>> {
340    let mut guards = vec![];
341    while !input.is_empty() && (input.peek(Token![let]) || input.peek(Token![if])) {
342        guards.push(input.parse()?);
343    }
344    Ok(guards)
345}
346enum Guard {
347    If(Token![if], syn::Expr),
348    IfLet(syn::ExprLet),
349    Let(syn::ExprLet),
350}
351
352impl Parse for Guard {
353    fn parse(input: ParseStream) -> syn::Result<Self> {
354        let lookahead = input.lookahead1();
355        if lookahead.peek(Token![if]) {
356            let if_ = input.parse()?;
357            if input.peek(Token![let]) {
358                Ok(Guard::IfLet(input.parse()?))
359            } else {
360                Ok(Guard::If(if_, input.parse()?))
361            }
362        } else if lookahead.peek(Token![let]) {
363            Ok(Guard::Let(input.parse()?))
364        } else {
365            Err(lookahead.error())
366        }
367    }
368}
369
370fn mk_idx_arg(i: usize) -> Ident {
371    Ident::new(&format!("idx{}", i), Span::call_site())
372}
373
374fn mk_bty_arg(i: usize) -> Ident {
375    Ident::new(&format!("bty{}", i), Span::call_site())
376}
377
/// Returns `true` if `s` spells one of the primitive type names accepted in rules.
/// Any other name is treated as a metavariable by the callers.
fn is_primitive_type<T>(s: &T) -> bool
where
    T: PartialEq<str>,
{
    const PRIMITIVES: [&str; 17] = [
        "i8", "i16", "i32", "i64", "i128", //
        "u8", "u16", "u32", "u64", "u128", //
        "f32", "f64", //
        "isize", "usize", //
        "bool", "char", "str",
    ];
    PRIMITIVES.iter().any(|prim| s == *prim)
}
400
401mod kw {
402    syn::custom_keyword!(requires);
403}