flux_infer/
lean_format.rs

1use core::fmt;
2use std::{fmt::Write, iter, sync::LazyLock};
3
4use flux_common::{
5    bug,
6    dbg::{self, as_subscript},
7};
8use flux_middle::{
9    def_id::FluxDefId,
10    global_env::GlobalEnv,
11    rty::{PrettyMap, PrettyVar},
12};
13use itertools::Itertools;
14use liquid_fixpoint::{FixpointFmt, Identifier, ThyFunc};
15use rustc_data_structures::fx::FxIndexSet;
16use rustc_hash::FxHashMap;
17use rustc_hir::def_id::DefId;
18
19use crate::fixpoint_encoding::{
20    ClosedSolution, InterpretedConst, KVarSolutions,
21    fixpoint::{
22        self, AdtId, BinOp, BinRel, Constant, Constraint, DataDecl, DataField, DataSort, Expr,
23        FunDef, FunSort, GlobalVar, KVarDecl, KVid, LocalVar, Pred, Sort, SortCtor, SortDecl, Var,
24    },
25};
26
27pub struct LeanCtxt<'a, 'genv, 'tcx> {
28    pub genv: GlobalEnv<'genv, 'tcx>,
29    pub pretty_var_map: &'a PrettyMap<LocalVar>,
30    pub adt_map: &'a FxIndexSet<DefId>,
31    pub opaque_adt_map: &'a [(FluxDefId, SortDecl)],
32    pub kvar_solutions: &'a KVarSolutions,
33    /// Maps the per-run `GlobalVar` index of a primop constant to its stable Lean name.
34    pub primop_var_map: &'a FxHashMap<GlobalVar, String>,
35}
36
37pub struct WithLeanCtxt<'a, 'b, 'genv, 'tcx, T> {
38    pub item: T,
39    pub cx: &'a LeanCtxt<'b, 'genv, 'tcx>,
40}
41
42impl<'a, 'b, 'genv, 'tcx, T: LeanFmt> fmt::Display for WithLeanCtxt<'a, 'b, 'genv, 'tcx, T> {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        self.item.lean_fmt(f, self.cx)
45    }
46}
47
48pub trait LeanFmt {
49    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result;
50}
51
52pub struct LeanKConstraint<'a> {
53    pub theorem_name: &'a str,
54    pub kvars: &'a [KVarDecl],
55    pub constr: &'a Constraint,
56    pub should_fail: bool,
57}
58
59struct LeanThyFunc<'a>(&'a ThyFunc);
60
61impl LeanFmt for SortDecl {
62    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
63        self.name.lean_fmt(f, cx)?;
64        write!(
65            f,
66            " {} : Type",
67            (0..(self.vars))
68                .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
69                .format(" ")
70        )
71    }
72}
73
74impl LeanFmt for InterpretedConst {
75    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
76        write!(
77            f,
78            "abbrev {} : {} := {}",
79            WithLeanCtxt { item: &self.0.name, cx },
80            WithLeanCtxt { item: &self.0.sort, cx },
81            WithLeanCtxt { item: &self.1, cx }
82        )
83    }
84}
85
86// TODO(lean-localize-imports) this seems wrong, but related to lack of storing `VariantIdx` in the `DataProj`
87impl LeanFmt for DataField {
88    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
89        write!(
90            f,
91            "({} : {})",
92            sanitize_name(&self.name.display().to_string()),
93            WithLeanCtxt { item: &self.sort, cx }
94        )
95    }
96}
97
98impl LeanFmt for DataSort {
99    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> std::fmt::Result {
100        match self {
101            DataSort::User(opaque_id) => {
102                let (def_id, _) = cx.opaque_adt_map.get(opaque_id.as_usize()).unwrap();
103                write!(f, "{}", snake_case_to_pascal_case(def_id.name().as_str()))
104            }
105            DataSort::Tuple(n) => write!(f, "Tupleₓ{}", dbg::as_subscript(n)),
106            DataSort::Adt(adt_id) => {
107                let def_id = cx.adt_map.get_index(adt_id.as_usize()).unwrap();
108                write!(f, "{}", def_id_to_pascal_case(def_id, &cx.genv.tcx()))
109            }
110        }
111    }
112}
113
114impl LeanFmt for DataDecl {
115    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
116        if self.ctors.len() == 1 {
117            writeln!(f, "@[ext]")?;
118            write!(f, "structure ")?;
119            self.name.lean_fmt(f, cx)?;
120            writeln!(
121                f,
122                " {} where",
123                (0..self.vars)
124                    .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
125                    .format(" ")
126            )?;
127            let ctor = &self.ctors[0];
128            if let fixpoint::Var::DataCtor(adt_id, _) = &ctor.name {
129                writeln!(
130                    f,
131                    "  mk{}{} ::",
132                    WithLeanCtxt { item: &DataSort::Adt(*adt_id), cx },
133                    as_subscript(0)
134                )?;
135                for (idx, field) in ctor.fields.iter().enumerate() {
136                    writeln!(
137                        f,
138                        "    {} : {} ",
139                        WithLeanCtxt { item: LeanField(*adt_id, idx.try_into().unwrap()), cx },
140                        WithLeanCtxt { item: &field.sort, cx }
141                    )?;
142                }
143            } else {
144                bug!("unexpected ctor {ctor:?} in datadecl");
145            };
146        } else {
147            write!(f, "inductive ")?;
148            self.name.lean_fmt(f, cx)?;
149            writeln!(
150                f,
151                " {} where",
152                (0..self.vars)
153                    .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
154                    .format(" ")
155            )?;
156            for data_ctor in &self.ctors {
157                let fixpoint::Var::DataCtor(adt_id, variant_id) = &data_ctor.name else {
158                    bug!("unexpected ctor {data_ctor:?} in datadecl")
159                };
160                write!(f, "| ")?;
161                write!(
162                    f,
163                    " mk{}{} ",
164                    WithLeanCtxt { item: &DataSort::Adt(*adt_id), cx },
165                    as_subscript(variant_id.as_usize()),
166                )?;
167                // data_ctor.name.lean_fmt(f, cx)?;
168                for field in &data_ctor.fields {
169                    write!(f, " ")?;
170                    field.lean_fmt(f, cx)?;
171                }
172                writeln!(f)?;
173            }
174        }
175        Ok(())
176    }
177}
178
179impl<'a> fmt::Display for LeanThyFunc<'a> {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        match self.0 {
182            ThyFunc::IntToBv8 => write!(f, "BitVec.ofInt 8"),
183            ThyFunc::IntToBv32 => write!(f, "BitVec.ofInt 32"),
184            ThyFunc::IntToBv64 => write!(f, "BitVec.ofInt 64"),
185            ThyFunc::Bv8ToInt | ThyFunc::Bv32ToInt | ThyFunc::Bv64ToInt => {
186                write!(f, "BitVec.toNat")
187            }
188            ThyFunc::BvAdd => write!(f, "BitVec.add"),
189            ThyFunc::BvSub => write!(f, "BitVec.sub"),
190            ThyFunc::BvMul => write!(f, "BitVec.mul"),
191            ThyFunc::BvNeg => write!(f, "BitVec.neg"),
192            ThyFunc::BvSdiv => write!(f, "BitVec.sdiv"),
193            ThyFunc::BvSrem => write!(f, "BitVec.srem"),
194            ThyFunc::BvUdiv => write!(f, "BitVec.udiv"),
195            ThyFunc::BvAnd => write!(f, "BitVec.and"),
196            ThyFunc::BvOr => write!(f, "BitVec.or"),
197            ThyFunc::BvXor => write!(f, "BitVec.xor"),
198            ThyFunc::BvNot => write!(f, "BitVec.not"),
199            ThyFunc::BvSle => write!(f, "BitVec.sle"),
200            ThyFunc::BvSlt => write!(f, "BitVec.slt"),
201            ThyFunc::BvUle => write!(f, "BitVec.ule"),
202            ThyFunc::BvUlt => write!(f, "BitVec.ult"),
203            ThyFunc::BvAshr => write!(f, "BitVec_sshiftRight"),
204            ThyFunc::BvLshr => write!(f, "BitVec_ushiftRight"),
205            ThyFunc::BvShl => write!(f, "BitVec_shiftLeft"),
206            ThyFunc::BvSignExtend(size) => write!(f, "BitVec_signExtend {}", size),
207            ThyFunc::BvZeroExtend(size) => write!(f, "BitVec_zeroExtend {}", size),
208            ThyFunc::BvUrem => write!(f, "BitVec.umod"),
209            ThyFunc::BvSge => write!(f, "BitVec_sge"),
210            ThyFunc::BvSgt => write!(f, "BitVec_sgt"),
211            ThyFunc::BvUge => write!(f, "BitVec_uge"),
212            ThyFunc::BvUgt => write!(f, "BitVec_ugt"),
213            ThyFunc::MapDefault => write!(f, "SmtMap_default"),
214            ThyFunc::MapSelect => write!(f, "SmtMap_select"),
215            ThyFunc::MapStore => write!(f, "SmtMap_store"),
216            func => panic!("Unsupported theory function {}", func),
217        }
218    }
219}
220
221impl<T: LeanFmt> LeanFmt for &T {
222    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
223        (*self).lean_fmt(f, cx)
224    }
225}
226
227struct LeanAdt(AdtId);
228struct LeanDataProj(AdtId, u32);
229struct LeanField(AdtId, u32);
230
231impl LeanFmt for LeanField {
232    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
233        let adt_id = self.0;
234        if let Some(def_id) = cx.adt_map.get_index(adt_id.as_usize())
235            && let Ok(adt_sort_def) = cx.genv.adt_sort_def_of(def_id)
236        {
237            write!(
238                f,
239                "{}",
240                sanitize_name(
241                    adt_sort_def.struct_variant().field_names()[self.1 as usize].as_str()
242                )
243            )
244        } else {
245            write!(f, "fld{}", as_subscript(self.1 as usize))
246        }
247    }
248}
249
250impl LeanFmt for LeanAdt {
251    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
252        let def_id = cx.adt_map.get_index(self.0.as_usize()).unwrap();
253        write!(f, "{}", def_id_to_pascal_case(def_id, &cx.genv.tcx()))
254    }
255}
256
257impl LeanFmt for LeanDataProj {
258    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
259        write!(
260            f,
261            "{}.{}",
262            WithLeanCtxt { item: LeanAdt(self.0), cx },
263            WithLeanCtxt { item: LeanField(self.0, self.1), cx }
264        )
265    }
266}
267
268impl LeanFmt for Var {
269    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
270        match self {
271            Var::Global(_gvar, def_id) => {
272                let path = cx
273                    .genv
274                    .tcx()
275                    .def_path(def_id.parent())
276                    .to_filename_friendly_no_crate();
277                if path.is_empty() {
278                    write!(f, "{}", def_id.name())
279                } else {
280                    write!(f, "{}_{}", sanitize_name(&path), def_id.name())
281                }
282            }
283            Var::Const(gvar, None) => {
284                if let Some(stable) = cx.primop_var_map.get(gvar) {
285                    write!(f, "{stable}")
286                } else {
287                    write!(f, "{}", sanitize_name(&self.display().to_string()))
288                }
289            }
290            Var::Const(_, Some(did)) => {
291                let path = cx.genv.tcx().def_path(*did).to_filename_friendly_no_crate();
292                write!(f, "{}", sanitize_name(&path))
293            }
294            Var::DataCtor(adt_id, idx) => {
295                write!(
296                    f,
297                    "{}.mk{}{}",
298                    WithLeanCtxt { item: LeanAdt(*adt_id), cx },
299                    WithLeanCtxt { item: &DataSort::Adt(*adt_id), cx },
300                    as_subscript(idx.as_usize())
301                )
302            }
303            Var::DataProj { adt_id, field } => LeanDataProj(*adt_id, *field).lean_fmt(f, cx),
304            Var::Local(local_var) => {
305                write!(f, "{}", cx.pretty_var_map.get(&PrettyVar::Local(*local_var)))
306            }
307            Var::Param(param) => {
308                write!(f, "{}", cx.pretty_var_map.get(&PrettyVar::Param(*param)))
309            }
310            _ => {
311                write!(f, "{}", sanitize_name(&self.display().to_string()))
312            }
313        }
314    }
315}
316
317impl LeanFmt for Sort {
318    fn lean_fmt(&self, f: &mut std::fmt::Formatter, cx: &LeanCtxt) -> std::fmt::Result {
319        match self {
320            Sort::Int => write!(f, "Int"),
321            Sort::Bool => write!(f, "Prop"),
322            Sort::Real => write!(f, "Real"),
323            Sort::Str => write!(f, "String"),
324            Sort::Func(f_sort) => {
325                write!(
326                    f,
327                    "({} -> {})",
328                    WithLeanCtxt { item: &f_sort[0], cx },
329                    WithLeanCtxt { item: &f_sort[1], cx }
330                )
331            }
332            Sort::App(sort_ctor, args) => {
333                match sort_ctor {
334                    SortCtor::Data(sort) => {
335                        if args.is_empty() {
336                            sort.lean_fmt(f, cx)
337                        } else {
338                            write!(
339                                f,
340                                "({} {})",
341                                WithLeanCtxt { item: sort, cx },
342                                args.iter()
343                                    .map(|arg| { WithLeanCtxt { item: arg, cx } })
344                                    .format(" ")
345                            )
346                        }
347                    }
348                    SortCtor::Map => {
349                        write!(
350                            f,
351                            "(SmtMap {} {})",
352                            WithLeanCtxt { item: &args[0], cx },
353                            WithLeanCtxt { item: &args[1], cx }
354                        )
355                    }
356                    _ => todo!(),
357                }
358            }
359            Sort::BitVec(bv_size) => {
360                match bv_size.as_ref() {
361                    Sort::BvSize(size) => write!(f, "BitVec {}", size),
362                    s => {
363                        panic!(
364                            "encountered sort {} where bitvec size was expected",
365                            WithLeanCtxt { item: s, cx }
366                        )
367                    }
368                }
369            }
370            Sort::Abs(v, sort) => {
371                write!(
372                    f,
373                    "{{t{v} : Type}} -> [Inhabited t{v}] -> {}",
374                    WithLeanCtxt { item: sort.as_ref(), cx }
375                )
376            }
377            Sort::Var(v) => write!(f, "t{v}"),
378            Sort::BvSize(size) => {
379                panic!("sort BvSize({size}) should only occur as an argument to BitVec")
380            }
381        }
382    }
383}
384
385impl LeanFmt for Expr {
386    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
387        match self {
388            Expr::Var(v) => v.lean_fmt(f, cx),
389            Expr::Constant(c) => {
390                match c {
391                    Constant::Numeral(n) => write!(f, "{n}",),
392                    Constant::Boolean(b) => write!(f, "{}", if *b { "True" } else { "False" }),
393                    Constant::String(s) => write!(f, "{}", s.display()),
394                    Constant::Real(n) => write!(f, "{n}.0"),
395                    Constant::BitVec(bv, size) => write!(f, "{}#{}", bv, size),
396                }
397            }
398            Expr::BinaryOp(bin_op, args) => {
399                let bin_op_str = match bin_op {
400                    BinOp::Add => "+",
401                    BinOp::Sub => "-",
402                    BinOp::Mul => "*",
403                    BinOp::Div => "/",
404                    BinOp::Mod => "%",
405                };
406                write!(f, "(")?;
407                args[0].lean_fmt(f, cx)?;
408                write!(f, " {} ", bin_op_str)?;
409                args[1].lean_fmt(f, cx)?;
410                write!(f, ")")
411            }
412            Expr::Atom(bin_rel, args) => {
413                let bin_rel_str = match bin_rel {
414                    BinRel::Eq => "=",
415                    BinRel::Ne => "≠",
416                    BinRel::Le => "≤",
417                    BinRel::Lt => "<",
418                    BinRel::Ge => "≥",
419                    BinRel::Gt => ">",
420                };
421                write!(f, "(")?;
422                args[0].lean_fmt(f, cx)?;
423                write!(f, " {} ", bin_rel_str)?;
424                args[1].lean_fmt(f, cx)?;
425                write!(f, ")")
426            }
427            Expr::App(function, sort_args, args, out_sort) => {
428                if out_sort.is_some() {
429                    write!(f, "(")?;
430                }
431                write!(f, "(")?;
432                function.as_ref().lean_fmt(f, cx)?;
433                if let Some(sort_args) = sort_args {
434                    for (i, s_arg) in sort_args.iter().enumerate() {
435                        if matches!(s_arg, Sort::BvSize(..)) {
436                            continue;
437                        }
438                        write!(f, " (t{i} := {})", WithLeanCtxt { item: s_arg, cx })?;
439                    }
440                }
441                for arg in args {
442                    write!(f, " ")?;
443                    arg.lean_fmt(f, cx)?;
444                }
445                write!(f, ")")?;
446                if let Some(out_sort) = out_sort {
447                    write!(f, " : (")?;
448                    out_sort.lean_fmt(f, cx)?;
449                    write!(f, "))")?;
450                }
451                Ok(())
452            }
453            Expr::And(exprs) => {
454                write!(f, "(")?;
455                for (i, expr) in exprs.iter().enumerate() {
456                    if i > 0 {
457                        write!(f, " ∧ ")?;
458                    }
459                    expr.lean_fmt(f, cx)?;
460                }
461                write!(f, ")")
462            }
463            Expr::Or(exprs) => {
464                write!(f, "(")?;
465                for (i, expr) in exprs.iter().enumerate() {
466                    if i > 0 {
467                        write!(f, " ∨ ")?;
468                    }
469                    expr.lean_fmt(f, cx)?;
470                }
471                write!(f, ")")
472            }
473            Expr::Neg(inner) => {
474                write!(f, "(-")?;
475                inner.as_ref().lean_fmt(f, cx)?;
476                write!(f, ")")
477            }
478            Expr::IfThenElse(ite) => {
479                let [condition, if_true, if_false] = ite.as_ref();
480                write!(f, "(if ")?;
481                condition.lean_fmt(f, cx)?;
482                write!(f, " then ")?;
483                if_true.lean_fmt(f, cx)?;
484                write!(f, " else ")?;
485                if_false.lean_fmt(f, cx)?;
486                write!(f, ")")
487            }
488            Expr::Not(inner) => {
489                write!(f, "(")?;
490                write!(f, "¬")?;
491                inner.as_ref().lean_fmt(f, cx)?;
492                write!(f, ")")
493            }
494            Expr::Imp(implication) => {
495                let [lhs, rhs] = implication.as_ref();
496                write!(f, "(")?;
497                lhs.lean_fmt(f, cx)?;
498                write!(f, " -> ")?;
499                rhs.lean_fmt(f, cx)?;
500                write!(f, ")")
501            }
502            Expr::Iff(equiv) => {
503                let [lhs, rhs] = equiv.as_ref();
504                write!(f, "(")?;
505                lhs.lean_fmt(f, cx)?;
506                write!(f, " <-> ")?;
507                rhs.lean_fmt(f, cx)?;
508                write!(f, ")")
509            }
510            Expr::Let(binder, exprs) => {
511                let [def, body] = exprs.as_ref();
512                write!(f, "(let ")?;
513                binder.lean_fmt(f, cx)?;
514                write!(f, " := ")?;
515                def.lean_fmt(f, cx)?;
516                write!(f, "; ")?;
517                body.lean_fmt(f, cx)?;
518                write!(f, ")")
519            }
520            Expr::ThyFunc(thy_func) => {
521                write!(f, "{}", LeanThyFunc(thy_func))
522            }
523            Expr::IsCtor(..) => {
524                todo!("not yet implemented: datatypes in lean")
525            }
526            Expr::Exists(bind, expr) => {
527                write!(f, "(∃ ")?;
528                for (var, sort) in bind {
529                    write!(f, "(")?;
530                    var.lean_fmt(f, cx)?;
531                    write!(f, " : {})", WithLeanCtxt { item: sort, cx })?;
532                }
533                write!(f, ", ")?;
534                expr.lean_fmt(f, cx)?;
535                write!(f, ")")?;
536                Ok(())
537            }
538        }
539    }
540}
541
542impl LeanFmt for FunDef {
543    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
544        let FunDef { name, sort, comment: _, body } = self;
545        write!(f, "noncomputable def ")?;
546        name.lean_fmt(f, cx)?;
547        if let Some(body) = body {
548            for (arg, arg_sort) in iter::zip(&body.args, &sort.inputs) {
549                write!(f, " (")?;
550                arg.lean_fmt(f, cx)?;
551                write!(f, " : {})", WithLeanCtxt { item: arg_sort, cx })?;
552            }
553            writeln!(f, " : {} :=", WithLeanCtxt { item: &sort.output, cx })?;
554            write!(f, "  ")?;
555            body.expr.lean_fmt(f, cx)?;
556        } else {
557            write!(f, " : {} := sorry", WithLeanCtxt { item: sort, cx })?;
558        }
559        writeln!(f)
560    }
561}
562
563impl LeanFmt for FunSort {
564    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
565        for i in 0..self.params {
566            write!(f, "{{t{i} : Type}} -> [Inhabited t{i}] -> ")?;
567        }
568        if self.inputs.is_empty() {
569            write!(f, "{}", WithLeanCtxt { item: &self.output, cx })
570        } else {
571            write!(
572                f,
573                "{} -> {}",
574                self.inputs.iter().format_with(" -> ", |sort, f| {
575                    f(&format_args!("{}", WithLeanCtxt { item: sort, cx }))
576                }),
577                WithLeanCtxt { item: &self.output, cx }
578            )
579        }
580    }
581}
582
583impl LeanFmt for Pred {
584    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
585        match self {
586            Pred::Expr(expr) => expr.lean_fmt(f, cx),
587            Pred::And(preds) => {
588                write!(f, "(")?;
589                for (i, pred) in preds.iter().enumerate() {
590                    if i > 0 {
591                        write!(f, " ∧ ")?;
592                    }
593                    pred.lean_fmt(f, cx)?;
594                }
595                write!(f, ")")
596            }
597            Pred::KVar(kvid, args) => {
598                write!(f, "({}", sanitize_name(&kvid.display().to_string()))?;
599                for imp in cx
600                    .kvar_solutions
601                    .non_cut_solutions
602                    .get(kvid)
603                    .map(|sol| sol.0.clone())
604                    .unwrap_or(vec![])
605                {
606                    write!(f, " ")?;
607                    imp.0.lean_fmt(f, cx)?;
608                }
609                for arg in args {
610                    write!(f, " ")?;
611                    arg.lean_fmt(f, cx)?;
612                }
613                write!(f, ")")
614            }
615        }
616    }
617}
618
619impl LeanFmt for KVarDecl {
620    fn lean_fmt(&self, f: &mut std::fmt::Formatter, cx: &LeanCtxt) -> std::fmt::Result {
621        let implicits: Vec<_> = cx
622            .kvar_solutions
623            .non_cut_solutions
624            .get(&self.kvid)
625            .map(|solution| solution.0.clone())
626            .unwrap_or(vec![]);
627        let sorts = implicits
628            .iter()
629            .map(|(_, sort)| sort)
630            .chain(&self.sorts)
631            .enumerate()
632            .map(|(i, sort)| format!("(a{i} : {})", WithLeanCtxt { item: sort, cx }))
633            .format(" -> ");
634        write!(f, "∃ {} : {} -> Prop", sanitize_name(&self.kvid.display().to_string()), sorts)
635    }
636}
637
638impl LeanFmt for (&KVid, &ClosedSolution) {
639    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
640        let (kvid, (implicit, (explicit, inner))) = self;
641        write!(f, "def k{} ", kvid.as_usize())?;
642        for (arg, sort) in implicit.iter().chain(explicit) {
643            write!(f, "(")?;
644            arg.lean_fmt(f, cx)?;
645            write!(f, " : {}) ", WithLeanCtxt { item: sort, cx })?;
646        }
647        writeln!(f, ": Prop :=")?;
648        write!(f, "  ")?;
649        inner.lean_fmt(f, cx)?;
650        Ok(())
651    }
652}
653
654impl<'a> LeanFmt for LeanKConstraint<'a> {
655    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
656        let theorem_name = self.theorem_name.replace(".", "_");
657        let namespace = format!("{}KVarSolutions", snake_case_to_pascal_case(&theorem_name));
658        if !cx.kvar_solutions.is_empty() {
659            writeln!(f, "namespace {namespace}\n")?;
660
661            if !cx.kvar_solutions.cut_solutions.is_empty() {
662                writeln!(f, "-- cyclic (cut) kvars")?;
663                for kvar_solution in &cx.kvar_solutions.cut_solutions {
664                    kvar_solution.lean_fmt(f, cx)?;
665                    writeln!(f)?;
666                }
667            }
668
669            if !cx.kvar_solutions.non_cut_solutions.is_empty() {
670                writeln!(f, "-- acyclic (non-cut) kvars")?;
671                for kvar_solution in &cx.kvar_solutions.non_cut_solutions {
672                    kvar_solution.lean_fmt(f, cx)?;
673                    writeln!(f)?;
674                }
675            }
676            writeln!(f, "\nend {namespace}\n\n")?;
677            writeln!(f, "open {namespace}\n\n")?;
678        }
679
680        write!(f, "\n\ndef {theorem_name} := ")?;
681
682        if self.should_fail {
683            write!(f, "¬")?;
684        }
685
686        if self.kvars.is_empty() {
687            self.constr.lean_fmt(f, cx)
688        } else {
689            write!(
690                f,
691                "{}, ",
692                self.kvars
693                    .iter()
694                    .map(|kvar| { WithLeanCtxt { item: kvar, cx } })
695                    .format(", ")
696            )?;
697            self.constr.lean_fmt(f, cx)
698        }
699    }
700}
701
702impl LeanFmt for Constraint {
703    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
704        let mut fmt_cx = ConstraintFormatter::default();
705        fmt_cx.incr();
706        fmt_cx.newline(f)?;
707        self.fmt_nested(f, cx, &mut fmt_cx)?;
708        fmt_cx.decr();
709        Ok(())
710    }
711}
712
713impl FormatNested for Constraint {
714    fn fmt_nested(
715        &self,
716        f: &mut fmt::Formatter,
717        lean_cx: &LeanCtxt,
718        fmt_cx: &mut ConstraintFormatter,
719    ) -> fmt::Result {
720        match self {
721            Constraint::ForAll(bind, inner) => {
722                let trivial_pred = bind.pred.is_trivially_true();
723                let trivial_bind = bind.name.display().to_string().starts_with("_");
724                if !trivial_bind {
725                    write!(f, "∀ (")?;
726                    bind.name.lean_fmt(f, lean_cx)?;
727                    write!(f, " : {}),", WithLeanCtxt { item: &bind.sort, cx: lean_cx })?;
728                    fmt_cx.incr();
729                    fmt_cx.newline(f)?;
730                }
731                if !trivial_pred {
732                    bind.pred.lean_fmt(f, lean_cx)?;
733                    write!(f, " ->")?;
734                    fmt_cx.incr();
735                    fmt_cx.newline(f)?;
736                }
737                inner.fmt_nested(f, lean_cx, fmt_cx)?;
738                if !trivial_pred {
739                    fmt_cx.decr();
740                }
741                if !trivial_bind {
742                    fmt_cx.decr();
743                }
744                Ok(())
745            }
746            Constraint::Conj(constraints) => {
747                let n = constraints.len();
748                for (i, constraint) in constraints.iter().enumerate() {
749                    write!(f, "(")?;
750                    constraint.fmt_nested(f, lean_cx, fmt_cx)?;
751                    write!(f, ")")?;
752                    if i < n - 1 {
753                        write!(f, " ∧")?;
754                    }
755                    fmt_cx.newline(f)?;
756                }
757                Ok(())
758            }
759            Constraint::Pred(pred, _) => pred.lean_fmt(f, lean_cx),
760        }
761    }
762}
763
764pub trait FormatNested {
765    fn fmt_nested(
766        &self,
767        f: &mut fmt::Formatter,
768        lean_cx: &LeanCtxt,
769        fmt_cx: &mut ConstraintFormatter,
770    ) -> fmt::Result;
771}
772
/// Indentation state used while pretty-printing nested constraints: one space per level.
#[derive(Default)]
pub struct ConstraintFormatter {
    level: u32,
}

impl ConstraintFormatter {
    /// Goes one indentation level deeper.
    pub fn incr(&mut self) {
        self.level = self.level + 1;
    }

    /// Goes back one indentation level; calls must be balanced with [`Self::incr`].
    pub fn decr(&mut self) {
        self.level = self.level - 1;
    }

    /// Ends the current line and indents the next one to the current level.
    pub fn newline(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f)?;
        self.padding(f)
    }

    /// Writes `level` spaces.
    pub fn padding(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:width$}", "", width = self.level as usize)
    }
}
799
800static IMPL_RE: LazyLock<regex::Regex> =
801    std::sync::LazyLock::new(|| regex::Regex::new(r"\{impl#(\d+)\}").unwrap());
802
/// Returns `true` if `name` is a Lean 4 keyword and therefore cannot be used as a bare
/// identifier in the generated source.
///
/// Previously only `end` and `def` were listed. Names such as `from`, `at`, `by`,
/// `fun`, `have`, `show`, `open` or `Type` are plausible Rust/flux identifiers and
/// produced unparsable Lean.
fn is_reserved(name: &str) -> bool {
    matches!(
        name,
        "abbrev"
            | "at"
            | "attribute"
            | "axiom"
            | "break"
            | "by"
            | "calc"
            | "catch"
            | "class"
            | "continue"
            | "def"
            | "deriving"
            | "do"
            | "else"
            | "end"
            | "example"
            | "finally"
            | "for"
            | "forall"
            | "from"
            | "fun"
            | "have"
            | "if"
            | "import"
            | "in"
            | "inductive"
            | "instance"
            | "let"
            | "local"
            | "macro"
            | "match"
            | "mut"
            | "mutual"
            | "namespace"
            | "noncomputable"
            | "nomatch"
            | "notation"
            | "opaque"
            | "open"
            | "partial"
            | "private"
            | "protected"
            | "return"
            | "section"
            | "show"
            | "structure"
            | "suffices"
            | "syntax"
            | "then"
            | "theorem"
            | "try"
            | "universe"
            | "unless"
            | "unsafe"
            | "variable"
            | "where"
            | "with"
            | "Prop"
            | "Sort"
            | "Type"
    )
}
806
807fn sanitize_name(name: &str) -> String {
808    if is_reserved(name) {
809        format!("{name}_")
810    } else {
811        IMPL_RE
812            .replace_all(name, "impl_${1}")
813            .replace("-", "_")
814            .replace("$", "_")
815    }
816}
817
818pub fn def_id_to_pascal_case(def_id: &DefId, tcx: &rustc_middle::ty::TyCtxt) -> String {
819    let snake = tcx
820        .def_path(*def_id)
821        .to_filename_friendly_no_crate()
822        .replace("-", "_");
823    let pascal_case = snake_case_to_pascal_case(&snake);
824    IMPL_RE
825        .replace_all(&pascal_case, "Impl__${1}__")
826        .to_string()
827}
828
/// Converts `snake_case` to `PascalCase`.
///
/// Splits on `_`, drops empty segments (so repeated, leading or trailing underscores
/// vanish), and ASCII-uppercases the first character of each segment. The rest of
/// each segment is kept as is.
pub fn snake_case_to_pascal_case(snake: &str) -> String {
    let mut pascal = String::with_capacity(snake.len());
    for word in snake.split('_').filter(|w| !w.is_empty()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            pascal.push(first.to_ascii_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}