flux_macros/
primops.rs

1use std::collections::HashMap;
2
3use itertools::Itertools;
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::{
7    Ident, Lifetime, Token, braced, bracketed, parenthesized,
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    token,
12};
13
/// Unwraps a `Result` whose error type provides `to_compile_error` (e.g. `syn::Error`).
///
/// On `Err`, returns the error rendered as a compile error (converted with `.into()`) from the
/// *enclosing* function, so it can only be used inside a function returning that type.
macro_rules! unwrap_result {
    ($result:expr) => {{
        match $result {
            Err(err) => return err.to_compile_error().into(),
            Ok(value) => value,
        }
    }};
}
22
23pub fn primop_rules(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    let rules = parse_macro_input!(input as Rules);
25
26    let argc = unwrap_result!(rules.check_arg_count());
27
28    let rules = rules.0.into_iter().enumerate().map(|(i, rule)| {
29        Renderer::new(i, rule)
30            .render()
31            .unwrap_or_else(|err| err.to_compile_error())
32    });
33    let args = args(argc);
34    quote! {
35        #[allow(unused_variables, non_snake_case)]
36        |#args| {
37            #(#rules)*
38            None
39        }
40    }
41    .into()
42}
43
44fn args(n: usize) -> TokenStream {
45    let args = (0..n).map(|i| {
46        let bty = mk_bty_arg(i);
47        let idx = mk_idx_arg(i);
48        quote!((#bty, #idx))
49    });
50    quote!([#(#args),*])
51}
52
53struct Rules(Vec<Rule>);
54
55impl Rules {
56    /// Check that the number of arguments is the same in all rules
57    fn check_arg_count(&self) -> syn::Result<usize> {
58        let argc = self.0.first().map(|rule| rule.args.len()).unwrap_or(0);
59        for rule in &self.0 {
60            if rule.args.len() != argc {
61                return Err(syn::Error::new(
62                    Span::call_site(),
63                    "all rules must have the same number of arguments",
64                ));
65            }
66        }
67        Ok(argc)
68    }
69}
70
71impl Parse for Rules {
72    fn parse(input: ParseStream) -> syn::Result<Self> {
73        let mut v = vec![];
74        while !input.is_empty() {
75            v.push(input.parse()?);
76        }
77        Ok(Rules(v))
78    }
79}
80
81struct Renderer {
82    lbl: Lifetime,
83    rule: Rule,
84    /// The set of metavars and the index of the inputs they match
85    metavars: HashMap<String, Vec<usize>>,
86}
87
88impl Renderer {
89    fn new(i: usize, rule: Rule) -> Self {
90        let mut metavars: HashMap<String, Vec<usize>> = HashMap::new();
91        for (i, input) in rule.args.iter().enumerate() {
92            let bty_str = input.bty.to_string();
93            if !is_primitive_type(&bty_str) {
94                metavars.entry(bty_str).or_default().push(i);
95            }
96        }
97
98        let lbl = syn::Lifetime::new(&format!("'lbl{i}"), Span::call_site());
99
100        Self { lbl, rule, metavars }
101    }
102
103    fn render(&self) -> syn::Result<TokenStream> {
104        let lbl = &self.lbl;
105        let metavar_matching = self.metavar_matching();
106        let primitive_checks = self.check_primitive_types();
107        let declare_metavars = self.declare_metavars();
108        let guards = self.guards();
109        let declare_idxs_names = self.declare_idxs_names();
110        let output_type = self.output_type()?;
111        let precondition = self.precondition();
112        Ok(quote! {
113            #lbl: {
114                #metavar_matching
115                #primitive_checks
116                #declare_metavars
117                #guards
118
119                #declare_idxs_names
120                let precondition = #precondition;
121                let v = Expr::nu();
122                let output_type = #output_type;
123                return Some(MatchedRule { precondition, output_type })
124            }
125        })
126    }
127
128    fn bty_arg_or_prim(&self, ident: &syn::Ident) -> syn::Result<TokenStream> {
129        let ident_str = ident.to_string();
130        if is_primitive_type(ident) {
131            Ok(quote!(BaseTy::from_primitive_str(#ident_str).unwrap()))
132        } else {
133            self.metavars
134                .get(&ident_str)
135                .map(|idxs| {
136                    let arg = mk_bty_arg(idxs[0]);
137                    quote!(#arg.clone())
138                })
139                .ok_or_else(|| {
140                    syn::Error::new(ident.span(), format!("cannot find metavariable `{ident_str}`"))
141                })
142        }
143    }
144
145    fn output_type(&self) -> syn::Result<TokenStream> {
146        let out = match &self.rule.output {
147            Output::Base(bty) => {
148                let bty = self.bty_arg_or_prim(bty)?;
149                quote!(#bty.to_ty())
150            }
151            Output::Indexed(bty, idx) => {
152                let bty = self.bty_arg_or_prim(bty)?;
153                quote!(rty::Ty::indexed( #bty, #idx))
154            }
155            Output::Exists(bty, pred) => {
156                let bty = self.bty_arg_or_prim(bty)?;
157                quote!(rty::Ty::exists_with_constr( #bty, #pred))
158            }
159            Output::Constr(bty, idx, pred) => {
160                let bty = self.bty_arg_or_prim(bty)?;
161                quote!(rty::Ty::constr(#pred, rty::Ty::indexed( #bty, #idx)))
162            }
163        };
164        Ok(out)
165    }
166
167    /// Generates the code that checks that all the inputs matching the same metavariable are equal
168    fn metavar_matching(&self) -> TokenStream {
169        let lbl = &self.lbl;
170        let checks = self.metavars.values().map(|idxs| {
171            let checks = idxs.iter().tuple_windows().map(|(i, j)| {
172                let bty_arg1 = mk_bty_arg(*i);
173                let bty_arg2 = mk_bty_arg(*j);
174                quote! {
175                    if #bty_arg2 != #bty_arg1 {
176                        break #lbl;
177                    }
178                }
179            });
180            quote!(#(#checks)*)
181        });
182        quote!(#(#checks)*)
183    }
184
185    /// Generates the code that checks if an arg matching a primitive type has indeed that type
186    fn check_primitive_types(&self) -> TokenStream {
187        let lbl = &self.lbl;
188        self.rule
189            .args
190            .iter()
191            .enumerate()
192            .flat_map(|(i, arg)| {
193                let bty = &arg.bty;
194                if is_primitive_type(bty) {
195                    let bty_str = bty.to_string();
196                    let bty_arg = mk_bty_arg(i);
197                    Some(quote! {
198                        let Some(s) = #bty_arg.primitive_symbol() else {
199                            break #lbl;
200                        };
201                        if s.as_str() != #bty_str {
202                            break #lbl;
203                        }
204                    })
205                } else {
206                    None
207                }
208            })
209            .collect()
210    }
211
212    fn precondition(&self) -> TokenStream {
213        if let Some(requires) = &self.rule.requires {
214            let reason = &requires.reason;
215            let pred = &requires.pred;
216            quote!(Some(Pre { reason: #reason, pred: #pred }))
217        } else {
218            quote!(None)
219        }
220    }
221
222    /// Declare metavars as variables so they can be accessed in the guards
223    fn declare_metavars(&self) -> TokenStream {
224        self.metavars
225            .iter()
226            .map(|(var, matching_positions)| {
227                let var = syn::Ident::new(var, Span::call_site());
228                let bty_arg = mk_bty_arg(matching_positions[0]);
229                quote! {
230                    let #var = #bty_arg;
231                }
232            })
233            .collect()
234    }
235
236    fn declare_idxs_names(&self) -> TokenStream {
237        self.rule
238            .args
239            .iter()
240            .enumerate()
241            .map(|(i, arg)| {
242                let name = &arg.name;
243                let idx_arg = mk_idx_arg(i);
244                quote!(let #name = #idx_arg;)
245            })
246            .collect()
247    }
248
249    fn guards(&self) -> TokenStream {
250        self.rule
251            .guards
252            .iter()
253            .map(|guard| self.guard(guard))
254            .collect()
255    }
256
257    fn guard(&self, guard: &Guard) -> TokenStream {
258        let lbl = &self.lbl;
259        match guard {
260            Guard::If(if_, expr) => quote! {#if_ !(#expr) { break #lbl; }},
261            Guard::IfLet(let_) => quote!(#let_ else { break #lbl; };),
262            Guard::Let(let_) => quote!(#let_;),
263        }
264    }
265}
266
267struct Rule {
268    args: Punctuated<Arg, Token![,]>,
269    output: Output,
270    requires: Option<Requires>,
271    guards: Vec<Guard>,
272}
273
274impl Parse for Rule {
275    fn parse(input: ParseStream) -> syn::Result<Self> {
276        let _: Token![fn] = input.parse()?;
277        let content;
278        parenthesized!(content in input);
279        let inputs = content.parse_terminated(Arg::parse, Token![,])?;
280        let _: Token![->] = input.parse()?;
281        let output = input.parse()?;
282        let requires = if input.peek(kw::requires) { Some(input.parse()?) } else { None };
283        let guards = parse_guards(input)?;
284        Ok(Rule { args: inputs, output, requires, guards })
285    }
286}
287
288/// An arg of the form `a: T`
289struct Arg {
290    name: syn::Ident,
291    bty: syn::Ident,
292}
293
294impl Parse for Arg {
295    fn parse(input: ParseStream) -> syn::Result<Self> {
296        let name = input.parse()?;
297        let _: Token![:] = input.parse()?;
298        let bty = input.parse()?;
299        Ok(Arg { name, bty })
300    }
301}
302
303enum Output {
304    Base(syn::Ident),
305    Indexed(syn::Ident, TokenStream),
306    Exists(syn::Ident, TokenStream),
307    Constr(syn::Ident, TokenStream, TokenStream),
308}
309
310impl Parse for Output {
311    fn parse(input: ParseStream) -> syn::Result<Self> {
312        if input.peek(token::Brace) {
313            let content;
314            braced!(content in input);
315            let bty = content.parse()?;
316            let idx = parse_index(&content)?;
317            let _: Token![|] = content.parse()?;
318            Ok(Output::Constr(bty, idx, content.parse()?))
319        } else {
320            let bty: syn::Ident = input.parse()?;
321            if input.peek(token::Bracket) {
322                Ok(Output::Indexed(bty, parse_index(input)?))
323            } else if input.peek(token::Brace) {
324                let content;
325                braced!(content in input);
326                let _: syn::Ident = content.parse()?;
327                let _: Token![:] = content.parse()?;
328                Ok(Output::Exists(bty, content.parse()?))
329            } else {
330                Ok(Output::Base(bty))
331            }
332        }
333    }
334}
335
336fn parse_index(input: ParseStream) -> syn::Result<TokenStream> {
337    let content;
338    bracketed!(content in input);
339    content.parse()
340}
341
342struct Requires {
343    pred: syn::Expr,
344    reason: syn::Path,
345}
346
347impl Parse for Requires {
348    fn parse(input: ParseStream) -> syn::Result<Self> {
349        let _: kw::requires = input.parse()?;
350        let pred = input.parse()?;
351        let _: Token![=>] = input.parse()?;
352        let reason = input.parse()?;
353        Ok(Requires { pred, reason })
354    }
355}
356
357fn parse_guards(input: ParseStream) -> syn::Result<Vec<Guard>> {
358    let mut guards = vec![];
359    while !input.is_empty() && (input.peek(Token![let]) || input.peek(Token![if])) {
360        guards.push(input.parse()?);
361    }
362    Ok(guards)
363}
364enum Guard {
365    If(Token![if], syn::Expr),
366    IfLet(syn::ExprLet),
367    Let(syn::ExprLet),
368}
369
370impl Parse for Guard {
371    fn parse(input: ParseStream) -> syn::Result<Self> {
372        let lookahead = input.lookahead1();
373        if lookahead.peek(Token![if]) {
374            let if_ = input.parse()?;
375            if input.peek(Token![let]) {
376                Ok(Guard::IfLet(input.parse()?))
377            } else {
378                Ok(Guard::If(if_, input.parse()?))
379            }
380        } else if lookahead.peek(Token![let]) {
381            Ok(Guard::Let(input.parse()?))
382        } else {
383            Err(lookahead.error())
384        }
385    }
386}
387
388fn mk_idx_arg(i: usize) -> Ident {
389    Ident::new(&format!("idx{i}"), Span::call_site())
390}
391
392fn mk_bty_arg(i: usize) -> Ident {
393    Ident::new(&format!("bty{i}"), Span::call_site())
394}
395
/// Returns `true` if `s` is the name of a Rust primitive type.
///
/// Accepts anything comparable to `str`, including `String`, `syn::Ident`, and (since `T` may be
/// unsized) a plain `&str`.
fn is_primitive_type<T>(s: &T) -> bool
where
    T: PartialEq<str> + ?Sized,
{
    const PRIMITIVES: [&str; 17] = [
        "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64",
        "isize", "usize", "bool", "char", "str",
    ];
    PRIMITIVES.iter().any(|prim| s == *prim)
}
418
419mod kw {
420    syn::custom_keyword!(requires);
421}