1use std::collections::HashMap;
2
3use itertools::Itertools;
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::{
7 Ident, Lifetime, Token, braced, bracketed, parenthesized,
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11 token,
12};
13
/// Evaluates a `Result`: yields the `Ok` value, or returns early from the
/// enclosing function with the error converted via `to_compile_error().into()`.
macro_rules! unwrap_result {
    ($result:expr) => {
        match $result {
            Ok(value) => value,
            Err(error) => return error.to_compile_error().into(),
        }
    };
}
22
23pub fn primop_rules(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24 let rules = parse_macro_input!(input as Rules);
25
26 let argc = unwrap_result!(rules.check_arg_count());
27
28 let rules = rules.0.into_iter().enumerate().map(|(i, rule)| {
29 Renderer::new(i, rule)
30 .render()
31 .unwrap_or_else(|err| err.to_compile_error())
32 });
33 let args = args(argc);
34 quote! {
35 #[allow(unused_variables, non_snake_case)]
36 |#args| {
37 #(#rules)*
38 None
39 }
40 }
41 .into()
42}
43
44fn args(n: usize) -> TokenStream {
45 let args = (0..n).map(|i| {
46 let bty = mk_bty_arg(i);
47 let idx = mk_idx_arg(i);
48 quote!((#bty, #idx))
49 });
50 quote!([#(#args),*])
51}
52
53struct Rules(Vec<Rule>);
54
55impl Rules {
56 fn check_arg_count(&self) -> syn::Result<usize> {
58 let argc = self.0.first().map(|rule| rule.args.len()).unwrap_or(0);
59 for rule in &self.0 {
60 if rule.args.len() != argc {
61 return Err(syn::Error::new(
62 Span::call_site(),
63 "all rules must have the same number of arguments",
64 ));
65 }
66 }
67 Ok(argc)
68 }
69}
70
71impl Parse for Rules {
72 fn parse(input: ParseStream) -> syn::Result<Self> {
73 let mut v = vec![];
74 while !input.is_empty() {
75 v.push(input.parse()?);
76 }
77 Ok(Rules(v))
78 }
79}
80
81struct Renderer {
82 lbl: Lifetime,
83 rule: Rule,
84 metavars: HashMap<String, Vec<usize>>,
86}
87
88impl Renderer {
89 fn new(i: usize, rule: Rule) -> Self {
90 let mut metavars: HashMap<String, Vec<usize>> = HashMap::new();
91 for (i, input) in rule.args.iter().enumerate() {
92 let bty_str = input.bty.to_string();
93 if !is_primitive_type(&bty_str) {
94 metavars.entry(bty_str).or_default().push(i);
95 }
96 }
97
98 let lbl = syn::Lifetime::new(&format!("'lbl{i}"), Span::call_site());
99
100 Self { lbl, rule, metavars }
101 }
102
103 fn render(&self) -> syn::Result<TokenStream> {
104 let lbl = &self.lbl;
105 let metavar_matching = self.metavar_matching();
106 let primitive_checks = self.check_primitive_types();
107 let declare_metavars = self.declare_metavars();
108 let guards = self.guards();
109 let declare_idxs_names = self.declare_idxs_names();
110 let output_type = self.output_type()?;
111 let precondition = self.precondition();
112 Ok(quote! {
113 #lbl: {
114 #metavar_matching
115 #primitive_checks
116 #declare_metavars
117 #guards
118
119 #declare_idxs_names
120 let precondition = #precondition;
121 let v = Expr::nu();
122 let output_type = #output_type;
123 return Some(MatchedRule { precondition, output_type })
124 }
125 })
126 }
127
128 fn bty_arg_or_prim(&self, ident: &syn::Ident) -> syn::Result<TokenStream> {
129 let ident_str = ident.to_string();
130 if is_primitive_type(ident) {
131 Ok(quote!(BaseTy::from_primitive_str(#ident_str).unwrap()))
132 } else {
133 self.metavars
134 .get(&ident_str)
135 .map(|idxs| {
136 let arg = mk_bty_arg(idxs[0]);
137 quote!(#arg.clone())
138 })
139 .ok_or_else(|| {
140 syn::Error::new(ident.span(), format!("cannot find metavariable `{ident_str}`"))
141 })
142 }
143 }
144
145 fn output_type(&self) -> syn::Result<TokenStream> {
146 let out = match &self.rule.output {
147 Output::Base(bty) => {
148 let bty = self.bty_arg_or_prim(bty)?;
149 quote!(#bty.to_ty())
150 }
151 Output::Indexed(bty, idx) => {
152 let bty = self.bty_arg_or_prim(bty)?;
153 quote!(rty::Ty::indexed( #bty, #idx))
154 }
155 Output::Exists(bty, pred) => {
156 let bty = self.bty_arg_or_prim(bty)?;
157 quote!(rty::Ty::exists_with_constr( #bty, #pred))
158 }
159 Output::Constr(bty, idx, pred) => {
160 let bty = self.bty_arg_or_prim(bty)?;
161 quote!(rty::Ty::constr(#pred, rty::Ty::indexed( #bty, #idx)))
162 }
163 };
164 Ok(out)
165 }
166
167 fn metavar_matching(&self) -> TokenStream {
169 let lbl = &self.lbl;
170 let checks = self.metavars.values().map(|idxs| {
171 let checks = idxs.iter().tuple_windows().map(|(i, j)| {
172 let bty_arg1 = mk_bty_arg(*i);
173 let bty_arg2 = mk_bty_arg(*j);
174 quote! {
175 if #bty_arg2 != #bty_arg1 {
176 break #lbl;
177 }
178 }
179 });
180 quote!(#(#checks)*)
181 });
182 quote!(#(#checks)*)
183 }
184
185 fn check_primitive_types(&self) -> TokenStream {
187 let lbl = &self.lbl;
188 self.rule
189 .args
190 .iter()
191 .enumerate()
192 .flat_map(|(i, arg)| {
193 let bty = &arg.bty;
194 if is_primitive_type(bty) {
195 let bty_str = bty.to_string();
196 let bty_arg = mk_bty_arg(i);
197 Some(quote! {
198 let Some(s) = #bty_arg.primitive_symbol() else {
199 break #lbl;
200 };
201 if s.as_str() != #bty_str {
202 break #lbl;
203 }
204 })
205 } else {
206 None
207 }
208 })
209 .collect()
210 }
211
212 fn precondition(&self) -> TokenStream {
213 if let Some(requires) = &self.rule.requires {
214 let reason = &requires.reason;
215 let pred = &requires.pred;
216 quote!(Some(Pre { reason: #reason, pred: #pred }))
217 } else {
218 quote!(None)
219 }
220 }
221
222 fn declare_metavars(&self) -> TokenStream {
224 self.metavars
225 .iter()
226 .map(|(var, matching_positions)| {
227 let var = syn::Ident::new(var, Span::call_site());
228 let bty_arg = mk_bty_arg(matching_positions[0]);
229 quote! {
230 let #var = #bty_arg;
231 }
232 })
233 .collect()
234 }
235
236 fn declare_idxs_names(&self) -> TokenStream {
237 self.rule
238 .args
239 .iter()
240 .enumerate()
241 .map(|(i, arg)| {
242 let name = &arg.name;
243 let idx_arg = mk_idx_arg(i);
244 quote!(let #name = #idx_arg;)
245 })
246 .collect()
247 }
248
249 fn guards(&self) -> TokenStream {
250 self.rule
251 .guards
252 .iter()
253 .map(|guard| self.guard(guard))
254 .collect()
255 }
256
257 fn guard(&self, guard: &Guard) -> TokenStream {
258 let lbl = &self.lbl;
259 match guard {
260 Guard::If(if_, expr) => quote! {#if_ !(#expr) { break #lbl; }},
261 Guard::IfLet(let_) => quote!(#let_ else { break #lbl; };),
262 Guard::Let(let_) => quote!(#let_;),
263 }
264 }
265}
266
267struct Rule {
268 args: Punctuated<Arg, Token![,]>,
269 output: Output,
270 requires: Option<Requires>,
271 guards: Vec<Guard>,
272}
273
274impl Parse for Rule {
275 fn parse(input: ParseStream) -> syn::Result<Self> {
276 let _: Token![fn] = input.parse()?;
277 let content;
278 parenthesized!(content in input);
279 let inputs = content.parse_terminated(Arg::parse, Token![,])?;
280 let _: Token![->] = input.parse()?;
281 let output = input.parse()?;
282 let requires = if input.peek(kw::requires) { Some(input.parse()?) } else { None };
283 let guards = parse_guards(input)?;
284 Ok(Rule { args: inputs, output, requires, guards })
285 }
286}
287
288struct Arg {
290 name: syn::Ident,
291 bty: syn::Ident,
292}
293
294impl Parse for Arg {
295 fn parse(input: ParseStream) -> syn::Result<Self> {
296 let name = input.parse()?;
297 let _: Token![:] = input.parse()?;
298 let bty = input.parse()?;
299 Ok(Arg { name, bty })
300 }
301}
302
303enum Output {
304 Base(syn::Ident),
305 Indexed(syn::Ident, TokenStream),
306 Exists(syn::Ident, TokenStream),
307 Constr(syn::Ident, TokenStream, TokenStream),
308}
309
310impl Parse for Output {
311 fn parse(input: ParseStream) -> syn::Result<Self> {
312 if input.peek(token::Brace) {
313 let content;
314 braced!(content in input);
315 let bty = content.parse()?;
316 let idx = parse_index(&content)?;
317 let _: Token![|] = content.parse()?;
318 Ok(Output::Constr(bty, idx, content.parse()?))
319 } else {
320 let bty: syn::Ident = input.parse()?;
321 if input.peek(token::Bracket) {
322 Ok(Output::Indexed(bty, parse_index(input)?))
323 } else if input.peek(token::Brace) {
324 let content;
325 braced!(content in input);
326 let _: syn::Ident = content.parse()?;
327 let _: Token![:] = content.parse()?;
328 Ok(Output::Exists(bty, content.parse()?))
329 } else {
330 Ok(Output::Base(bty))
331 }
332 }
333 }
334}
335
336fn parse_index(input: ParseStream) -> syn::Result<TokenStream> {
337 let content;
338 bracketed!(content in input);
339 content.parse()
340}
341
342struct Requires {
343 pred: syn::Expr,
344 reason: syn::Path,
345}
346
347impl Parse for Requires {
348 fn parse(input: ParseStream) -> syn::Result<Self> {
349 let _: kw::requires = input.parse()?;
350 let pred = input.parse()?;
351 let _: Token![=>] = input.parse()?;
352 let reason = input.parse()?;
353 Ok(Requires { pred, reason })
354 }
355}
356
357fn parse_guards(input: ParseStream) -> syn::Result<Vec<Guard>> {
358 let mut guards = vec![];
359 while !input.is_empty() && (input.peek(Token![let]) || input.peek(Token![if])) {
360 guards.push(input.parse()?);
361 }
362 Ok(guards)
363}
364enum Guard {
365 If(Token![if], syn::Expr),
366 IfLet(syn::ExprLet),
367 Let(syn::ExprLet),
368}
369
370impl Parse for Guard {
371 fn parse(input: ParseStream) -> syn::Result<Self> {
372 let lookahead = input.lookahead1();
373 if lookahead.peek(Token![if]) {
374 let if_ = input.parse()?;
375 if input.peek(Token![let]) {
376 Ok(Guard::IfLet(input.parse()?))
377 } else {
378 Ok(Guard::If(if_, input.parse()?))
379 }
380 } else if lookahead.peek(Token![let]) {
381 Ok(Guard::Let(input.parse()?))
382 } else {
383 Err(lookahead.error())
384 }
385 }
386}
387
388fn mk_idx_arg(i: usize) -> Ident {
389 Ident::new(&format!("idx{i}"), Span::call_site())
390}
391
392fn mk_bty_arg(i: usize) -> Ident {
393 Ident::new(&format!("bty{i}"), Span::call_site())
394}
395
/// Returns `true` if `s` is the name of a Rust primitive scalar type, or `str`.
///
/// The macro treats any base-type name not in this list as a metavariable.
/// `T` may be unsized, so callers can pass a `&str` directly, as well as a
/// `&String` or a `&syn::Ident`.
fn is_primitive_type<T>(s: &T) -> bool
where
    T: PartialEq<str> + ?Sized,
{
    const PRIMITIVES: [&str; 17] = [
        "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64",
        "isize", "usize", "bool", "char", "str",
    ];
    PRIMITIVES.iter().any(|&prim| s == prim)
}
418
419mod kw {
420 syn::custom_keyword!(requires);
421}