flux_infer/
lean_format.rs

1use core::fmt;
2use std::{fmt::Write, iter};
3
4use flux_common::{
5    bug,
6    dbg::{self, as_subscript},
7};
8use flux_middle::{
9    global_env::GlobalEnv,
10    rty::{PrettyMap, PrettyVar},
11};
12use itertools::Itertools;
13use liquid_fixpoint::{FixpointFmt, Identifier, ThyFunc};
14use rustc_data_structures::fx::FxIndexSet;
15use rustc_hir::def_id::DefId;
16
17use crate::fixpoint_encoding::{
18    ClosedSolution, KVarSolutions,
19    fixpoint::{
20        self, AdtId, BinOp, BinRel, ConstDecl, Constant, Constraint, DataDecl, DataField, DataSort,
21        Expr, FunDef, FunSort, KVarDecl, KVid, LocalVar, Pred, Sort, SortCtor, SortDecl, Var,
22    },
23};
24
/// Selects how boolean-valued terms are rendered in the generated Lean text.
///
/// The printers consult this to choose the sort name for `Sort::Bool`, boolean literals and
/// the logical connectives.
#[derive(Clone, Copy, Debug)]
pub enum BoolMode {
    /// Render with `Bool`: `true`/`false`, `&&`, `||`, `!`.
    Bool,
    /// Render with `Prop`: `True`/`False`, `∧`, `∨`, `¬`.
    Prop,
}
30
31pub struct LeanCtxt<'a, 'genv, 'tcx> {
32    pub genv: GlobalEnv<'genv, 'tcx>,
33    pub pretty_var_map: &'a PrettyMap<LocalVar>,
34    pub adt_map: &'a FxIndexSet<DefId>,
35    pub kvar_solutions: &'a KVarSolutions,
36    pub bool_mode: BoolMode,
37}
38
39impl<'a, 'genv, 'tcx> LeanCtxt<'a, 'genv, 'tcx> {
40    pub(crate) fn with_bool_mode(&self, bool_mode: BoolMode) -> Self {
41        LeanCtxt { bool_mode, ..*self }
42    }
43}
44
45pub struct WithLeanCtxt<'a, 'b, 'genv, 'tcx, T> {
46    pub item: T,
47    pub cx: &'a LeanCtxt<'b, 'genv, 'tcx>,
48}
49
50impl<'a, 'b, 'genv, 'tcx, T: LeanFmt> fmt::Display for WithLeanCtxt<'a, 'b, 'genv, 'tcx, T> {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        self.item.lean_fmt(f, self.cx)
53    }
54}
55
56pub trait LeanFmt {
57    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result;
58}
59
60pub struct LeanKConstraint<'a> {
61    pub theorem_name: &'a str,
62    pub kvars: &'a [KVarDecl],
63    pub constr: &'a Constraint,
64}
65
66struct LeanThyFunc<'a>(&'a ThyFunc);
67
68impl LeanFmt for SortDecl {
69    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
70        self.name.lean_fmt(f, cx)?;
71        write!(
72            f,
73            " {} : Type",
74            (0..(self.vars))
75                .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
76                .format(" ")
77        )
78    }
79}
80
81impl LeanFmt for ConstDecl {
82    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
83        self.name.lean_fmt(f, cx)?;
84        write!(f, " : {}", WithLeanCtxt { item: &self.sort, cx })
85    }
86}
87
88// TODO(lean-localize-imports) this seems wrong, but related to lack of storing `VariantIdx` in the `DataProj`
89impl LeanFmt for DataField {
90    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
91        write!(
92            f,
93            "({} : {})",
94            self.name.display().to_string().replace("$", "_"),
95            WithLeanCtxt { item: &self.sort, cx }
96        )
97    }
98}
99
100impl LeanFmt for DataSort {
101    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> std::fmt::Result {
102        match self {
103            DataSort::User(def_id) => write!(f, "{}", def_id.name()),
104            DataSort::Tuple(n) => write!(f, "Tupleₓ{}", dbg::as_subscript(n)),
105            DataSort::Adt(adt_id) => {
106                let def_id = cx.adt_map.get_index(adt_id.as_usize()).unwrap();
107                write!(f, "{}", def_id_to_pascal_case(def_id, &cx.genv.tcx()))
108            }
109        }
110    }
111}
112
113impl LeanFmt for DataDecl {
114    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
115        if self.ctors.len() == 1 {
116            writeln!(f, "@[ext]")?;
117            write!(f, "structure ")?;
118            self.name.lean_fmt(f, cx)?;
119            writeln!(
120                f,
121                " {} where",
122                (0..self.vars)
123                    .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
124                    .format(" ")
125            )?;
126            let ctor = &self.ctors[0];
127            if let fixpoint::Var::DataCtor(adt_id, _) = &ctor.name {
128                writeln!(
129                    f,
130                    "  mk{}{} ::",
131                    WithLeanCtxt { item: &DataSort::Adt(*adt_id), cx },
132                    as_subscript(0)
133                )?;
134                for (idx, field) in ctor.fields.iter().enumerate() {
135                    writeln!(
136                        f,
137                        "    {} : {} ",
138                        WithLeanCtxt { item: LeanField(*adt_id, idx.try_into().unwrap()), cx },
139                        WithLeanCtxt { item: &field.sort, cx }
140                    )?;
141                }
142            } else {
143                bug!("unexpected ctor {ctor:?} in datadecl");
144            };
145        } else {
146            write!(f, "inductive ")?;
147            self.name.lean_fmt(f, cx)?;
148            writeln!(
149                f,
150                " {} where",
151                (0..self.vars)
152                    .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
153                    .format(" ")
154            )?;
155            for data_ctor in &self.ctors {
156                let fixpoint::Var::DataCtor(adt_id, variant_id) = &data_ctor.name else {
157                    bug!("unexpected ctor {data_ctor:?} in datadecl")
158                };
159                write!(f, "| ")?;
160                write!(
161                    f,
162                    " mk{}{} ",
163                    WithLeanCtxt { item: &DataSort::Adt(*adt_id), cx },
164                    as_subscript(variant_id.as_usize()),
165                )?;
166                // data_ctor.name.lean_fmt(f, cx)?;
167                for field in &data_ctor.fields {
168                    write!(f, " ")?;
169                    field.lean_fmt(f, cx)?;
170                }
171                writeln!(f)?;
172            }
173        }
174        Ok(())
175    }
176}
177
178impl<'a> fmt::Display for LeanThyFunc<'a> {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        match self.0 {
181            ThyFunc::IntToBv8 => write!(f, "BitVec.ofInt 8"),
182            ThyFunc::IntToBv32 => write!(f, "BitVec.ofInt 32"),
183            ThyFunc::IntToBv64 => write!(f, "BitVec.ofInt 64"),
184            ThyFunc::Bv8ToInt | ThyFunc::Bv32ToInt | ThyFunc::Bv64ToInt => {
185                write!(f, "BitVec.toNat")
186            }
187            ThyFunc::BvAdd => write!(f, "BitVec.add"),
188            ThyFunc::BvSub => write!(f, "BitVec.sub"),
189            ThyFunc::BvMul => write!(f, "BitVec.mul"),
190            ThyFunc::BvNeg => write!(f, "BitVec.neg"),
191            ThyFunc::BvSdiv => write!(f, "BitVec.sdiv"),
192            ThyFunc::BvSrem => write!(f, "BitVec.srem"),
193            ThyFunc::BvUdiv => write!(f, "BitVec.udiv"),
194            ThyFunc::BvAnd => write!(f, "BitVec.and"),
195            ThyFunc::BvOr => write!(f, "BitVec.or"),
196            ThyFunc::BvXor => write!(f, "BitVec.xor"),
197            ThyFunc::BvNot => write!(f, "BitVec.not"),
198            ThyFunc::BvSle => write!(f, "BitVec.sle"),
199            ThyFunc::BvSlt => write!(f, "BitVec.slt"),
200            ThyFunc::BvUle => write!(f, "BitVec.ule"),
201            ThyFunc::BvUlt => write!(f, "BitVec.ult"),
202            ThyFunc::BvAshr => write!(f, "BitVec_sshiftRight"),
203            ThyFunc::BvLshr => write!(f, "BitVec_ushiftRight"),
204            ThyFunc::BvShl => write!(f, "BitVec_shiftLeft"),
205            ThyFunc::BvSignExtend(size) => write!(f, "BitVec.signExtend {}", size),
206            ThyFunc::BvZeroExtend(size) => write!(f, "BitVec.zeroExtend {}", size),
207            ThyFunc::BvUrem => write!(f, "BitVec.umod"),
208            ThyFunc::BvSge => write!(f, "BitVec_sge"),
209            ThyFunc::BvSgt => write!(f, "BitVec_sgt"),
210            ThyFunc::BvUge => write!(f, "BitVec_uge"),
211            ThyFunc::BvUgt => write!(f, "BitVec_ugt"),
212            ThyFunc::MapDefault => write!(f, "SmtMap_default"),
213            ThyFunc::MapSelect => write!(f, "SmtMap_select"),
214            ThyFunc::MapStore => write!(f, "SmtMap_store"),
215            func => panic!("Unsupported theory function {}", func),
216        }
217    }
218}
219
220impl<T: LeanFmt> LeanFmt for &T {
221    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
222        (*self).lean_fmt(f, cx)
223    }
224}
225
226struct LeanAdt(AdtId);
227struct LeanDataProj(AdtId, u32);
228struct LeanField(AdtId, u32);
229
230impl LeanFmt for LeanField {
231    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
232        let adt_id = self.0;
233        if let Some(def_id) = cx.adt_map.get_index(adt_id.as_usize())
234            && let Ok(adt_sort_def) = cx.genv.adt_sort_def_of(def_id)
235        {
236            write!(f, "{}", adt_sort_def.struct_variant().field_names()[self.1 as usize])
237        } else {
238            write!(f, "fld{}", as_subscript(self.1 as usize))
239        }
240    }
241}
242
243impl LeanFmt for LeanAdt {
244    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
245        let def_id = cx.adt_map.get_index(self.0.as_usize()).unwrap();
246        write!(f, "{}", def_id_to_pascal_case(def_id, &cx.genv.tcx()))
247    }
248}
249
250impl LeanFmt for LeanDataProj {
251    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
252        write!(
253            f,
254            "{}.{}",
255            WithLeanCtxt { item: LeanAdt(self.0), cx },
256            WithLeanCtxt { item: LeanField(self.0, self.1), cx }
257        )
258    }
259}
260
261impl LeanFmt for Var {
262    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
263        match self {
264            Var::Global(_gvar, Some(def_id)) => {
265                let path = cx
266                    .genv
267                    .tcx()
268                    .def_path(def_id.parent())
269                    .to_filename_friendly_no_crate()
270                    .replace("-", "_");
271                if path.is_empty() {
272                    write!(f, "{}", def_id.name())
273                } else {
274                    write!(f, "{path}_{}", def_id.name())
275                }
276            }
277            Var::DataCtor(adt_id, idx) => {
278                write!(
279                    f,
280                    "{}.mk{}{}",
281                    WithLeanCtxt { item: LeanAdt(*adt_id), cx },
282                    WithLeanCtxt { item: &DataSort::Adt(*adt_id), cx },
283                    as_subscript(idx.as_usize())
284                )
285            }
286            Var::DataProj { adt_id, field } => LeanDataProj(*adt_id, *field).lean_fmt(f, cx),
287            Var::Local(local_var) => {
288                write!(f, "{}", cx.pretty_var_map.get(&PrettyVar::Local(*local_var)))
289            }
290            Var::Param(param) => {
291                write!(f, "{}", cx.pretty_var_map.get(&PrettyVar::Param(*param)))
292            }
293            _ => {
294                write!(f, "{}", self.display().to_string().replace("$", "_"))
295            }
296        }
297    }
298}
299
300impl LeanFmt for Sort {
301    fn lean_fmt(&self, f: &mut std::fmt::Formatter, cx: &LeanCtxt) -> std::fmt::Result {
302        match self {
303            Sort::Int => write!(f, "Int"),
304            Sort::Bool => {
305                match cx.bool_mode {
306                    BoolMode::Bool => write!(f, "Bool"),
307                    BoolMode::Prop => write!(f, "Prop"),
308                }
309            }
310            Sort::Real => write!(f, "Real"),
311            Sort::Str => write!(f, "String"),
312            Sort::Func(f_sort) => {
313                write!(
314                    f,
315                    "({} -> {})",
316                    WithLeanCtxt { item: &f_sort[0], cx },
317                    WithLeanCtxt { item: &f_sort[1], cx }
318                )
319            }
320            Sort::App(sort_ctor, args) => {
321                match sort_ctor {
322                    SortCtor::Data(sort) => {
323                        if args.is_empty() {
324                            sort.lean_fmt(f, cx)
325                        } else {
326                            write!(
327                                f,
328                                "({} {})",
329                                WithLeanCtxt { item: sort, cx },
330                                args.iter()
331                                    .map(|arg| { WithLeanCtxt { item: arg, cx } })
332                                    .format(" ")
333                            )
334                        }
335                    }
336                    SortCtor::Map => {
337                        write!(
338                            f,
339                            "(SmtMap {} {})",
340                            WithLeanCtxt { item: &args[0], cx },
341                            WithLeanCtxt { item: &args[1], cx }
342                        )
343                    }
344                    _ => todo!(),
345                }
346            }
347            Sort::BitVec(bv_size) => {
348                match bv_size.as_ref() {
349                    Sort::BvSize(size) => write!(f, "BitVec {}", size),
350                    s => {
351                        panic!(
352                            "encountered sort {} where bitvec size was expected",
353                            WithLeanCtxt { item: s, cx }
354                        )
355                    }
356                }
357            }
358            Sort::Abs(v, sort) => {
359                write!(
360                    f,
361                    "{{t{v} : Type}} -> [Inhabited t{v}] -> {}",
362                    WithLeanCtxt { item: sort.as_ref(), cx }
363                )
364            }
365            Sort::Var(v) => write!(f, "t{v}"),
366            s => todo!("{:?}", s),
367        }
368    }
369}
370
371impl LeanFmt for Expr {
372    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
373        match self {
374            Expr::Var(v) => v.lean_fmt(f, cx),
375            Expr::Constant(c) => {
376                match c {
377                    Constant::Numeral(n) => write!(f, "{n}",),
378                    Constant::Boolean(b) => {
379                        match cx.bool_mode {
380                            BoolMode::Bool => write!(f, "{}", if *b { "true" } else { "false" }),
381                            BoolMode::Prop => write!(f, "{}", if *b { "True" } else { "False" }),
382                        }
383                    }
384                    Constant::String(s) => write!(f, "{}", s.display()),
385                    Constant::Real(n) => write!(f, "{n}.0"),
386                    Constant::BitVec(bv, size) => write!(f, "{}#{}", bv, size),
387                }
388            }
389            Expr::BinaryOp(bin_op, args) => {
390                let bin_op_str = match bin_op {
391                    BinOp::Add => "+",
392                    BinOp::Sub => "-",
393                    BinOp::Mul => "*",
394                    BinOp::Div => "/",
395                    BinOp::Mod => "%",
396                };
397                write!(f, "(")?;
398                args[0].lean_fmt(f, cx)?;
399                write!(f, " {} ", bin_op_str)?;
400                args[1].lean_fmt(f, cx)?;
401                write!(f, ")")
402            }
403            Expr::Atom(bin_rel, args) => {
404                let bin_rel_str = match bin_rel {
405                    BinRel::Eq => "=",
406                    BinRel::Ne => "≠",
407                    BinRel::Le => "≤",
408                    BinRel::Lt => "<",
409                    BinRel::Ge => "≥",
410                    BinRel::Gt => ">",
411                };
412                write!(f, "(")?;
413                args[0].lean_fmt(f, cx)?;
414                write!(f, " {} ", bin_rel_str)?;
415                args[1].lean_fmt(f, cx)?;
416                write!(f, ")")
417            }
418            Expr::App(function, sort_args, args, out_sort) => {
419                if out_sort.is_some() {
420                    write!(f, "(")?;
421                }
422                write!(f, "(")?;
423                function.as_ref().lean_fmt(f, cx)?;
424                if let Some(sort_args) = sort_args {
425                    for (i, s_arg) in sort_args.iter().enumerate() {
426                        write!(f, " (t{i} := {})", WithLeanCtxt { item: s_arg, cx })?;
427                    }
428                }
429                for arg in args {
430                    write!(f, " ")?;
431                    arg.lean_fmt(f, cx)?;
432                }
433                write!(f, ")")?;
434                if let Some(out_sort) = out_sort {
435                    write!(f, " : (")?;
436                    let sort_cx = cx.with_bool_mode(BoolMode::Bool);
437                    out_sort.lean_fmt(f, &sort_cx)?;
438                    write!(f, "))")?;
439                }
440                Ok(())
441            }
442            Expr::And(exprs) => {
443                write!(f, "(")?;
444                for (i, expr) in exprs.iter().enumerate() {
445                    if i > 0 {
446                        match cx.bool_mode {
447                            BoolMode::Bool => write!(f, " && ")?,
448                            BoolMode::Prop => write!(f, " ∧ ")?,
449                        };
450                    }
451                    expr.lean_fmt(f, cx)?;
452                }
453                write!(f, ")")
454            }
455            Expr::Or(exprs) => {
456                write!(f, "(")?;
457                for (i, expr) in exprs.iter().enumerate() {
458                    if i > 0 {
459                        match cx.bool_mode {
460                            BoolMode::Bool => write!(f, " || ")?,
461                            BoolMode::Prop => write!(f, " ∨ ")?,
462                        };
463                    }
464                    expr.lean_fmt(f, cx)?;
465                }
466                write!(f, ")")
467            }
468            Expr::Neg(inner) => {
469                write!(f, "(-")?;
470                inner.as_ref().lean_fmt(f, cx)?;
471                write!(f, ")")
472            }
473            Expr::IfThenElse(ite) => {
474                let [condition, if_true, if_false] = ite.as_ref();
475                write!(f, "(if ")?;
476                condition.lean_fmt(f, cx)?;
477                write!(f, " then ")?;
478                if_true.lean_fmt(f, cx)?;
479                write!(f, " else ")?;
480                if_false.lean_fmt(f, cx)?;
481                write!(f, ")")
482            }
483            Expr::Not(inner) => {
484                write!(f, "(")?;
485                match cx.bool_mode {
486                    BoolMode::Bool => write!(f, "!")?,
487                    BoolMode::Prop => write!(f, "¬")?,
488                };
489                inner.as_ref().lean_fmt(f, cx)?;
490                write!(f, ")")
491            }
492            Expr::Imp(implication) => {
493                let [lhs, rhs] = implication.as_ref();
494                write!(f, "(")?;
495                lhs.lean_fmt(f, cx)?;
496                write!(f, " -> ")?;
497                rhs.lean_fmt(f, cx)?;
498                write!(f, ")")
499            }
500            Expr::Iff(equiv) => {
501                let [lhs, rhs] = equiv.as_ref();
502                write!(f, "(")?;
503                lhs.lean_fmt(f, cx)?;
504                write!(f, " <-> ")?;
505                rhs.lean_fmt(f, cx)?;
506                write!(f, ")")
507            }
508            Expr::Let(binder, exprs) => {
509                let [def, body] = exprs.as_ref();
510                write!(f, "(let ")?;
511                binder.lean_fmt(f, cx)?;
512                write!(f, " := ")?;
513                def.lean_fmt(f, cx)?;
514                write!(f, "; ")?;
515                body.lean_fmt(f, cx)?;
516                write!(f, ")")
517            }
518            Expr::ThyFunc(thy_func) => {
519                write!(f, "{}", LeanThyFunc(thy_func))
520            }
521            Expr::IsCtor(..) => {
522                todo!("not yet implemented: datatypes in lean")
523            }
524            Expr::Exists(bind, expr) => {
525                write!(f, "(∃ ")?;
526                for (var, sort) in bind {
527                    write!(f, "(")?;
528                    var.lean_fmt(f, cx)?;
529                    write!(f, " : {})", WithLeanCtxt { item: sort, cx })?;
530                }
531                write!(f, ", ")?;
532                expr.lean_fmt(f, cx)?;
533                write!(f, ")")?;
534                Ok(())
535            }
536        }
537    }
538}
539
540impl LeanFmt for FunDef {
541    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
542        let FunDef { name, sort, comment: _, body } = self;
543        write!(f, "def ")?;
544        name.lean_fmt(f, cx)?;
545        if let Some(body) = body {
546            for (arg, arg_sort) in iter::zip(&body.args, &sort.inputs) {
547                write!(f, " (")?;
548                arg.lean_fmt(f, cx)?;
549                write!(f, " : {})", WithLeanCtxt { item: arg_sort, cx })?;
550            }
551            writeln!(f, " : {} :=", WithLeanCtxt { item: &sort.output, cx })?;
552            write!(f, "  ")?;
553            body.expr.lean_fmt(f, cx)?;
554        } else {
555            write!(f, " : {} := sorry", WithLeanCtxt { item: sort, cx })?;
556        }
557        writeln!(f)
558    }
559}
560
561impl LeanFmt for FunSort {
562    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
563        for i in 0..self.params {
564            write!(f, "{{t{i} : Type}} -> [Inhabited t{i}] -> ")?;
565        }
566        if self.inputs.is_empty() {
567            write!(f, "{}", WithLeanCtxt { item: &self.output, cx })
568        } else {
569            write!(
570                f,
571                "{} -> {}",
572                self.inputs.iter().format_with(" -> ", |sort, f| {
573                    f(&format_args!("{}", WithLeanCtxt { item: sort, cx }))
574                }),
575                WithLeanCtxt { item: &self.output, cx }
576            )
577        }
578    }
579}
580
581impl LeanFmt for Pred {
582    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
583        match self {
584            Pred::Expr(expr) => expr.lean_fmt(f, cx),
585            Pred::And(preds) => {
586                write!(f, "(")?;
587                for (i, pred) in preds.iter().enumerate() {
588                    if i > 0 {
589                        write!(f, " ∧ ")?;
590                    }
591                    pred.lean_fmt(f, cx)?;
592                }
593                write!(f, ")")
594            }
595            Pred::KVar(kvid, args) => {
596                write!(f, "({}", kvid.display().to_string().replace("$", "_"))?;
597                for imp in cx
598                    .kvar_solutions
599                    .non_cut_solutions
600                    .get(kvid)
601                    .map(|sol| sol.0.clone())
602                    .unwrap_or(vec![])
603                {
604                    write!(f, " ")?;
605                    imp.0.lean_fmt(f, cx)?;
606                }
607                for arg in args {
608                    write!(f, " ")?;
609                    arg.lean_fmt(f, cx)?;
610                }
611                write!(f, ")")
612            }
613        }
614    }
615}
616
617impl LeanFmt for KVarDecl {
618    fn lean_fmt(&self, f: &mut std::fmt::Formatter, cx: &LeanCtxt) -> std::fmt::Result {
619        let implicits: Vec<_> = cx
620            .kvar_solutions
621            .non_cut_solutions
622            .get(&self.kvid)
623            .map(|solution| solution.0.clone())
624            .unwrap_or(vec![]);
625        let sorts = implicits
626            .iter()
627            .map(|(_, sort)| sort)
628            .chain(&self.sorts)
629            .enumerate()
630            .map(|(i, sort)| format!("(a{i} : {})", WithLeanCtxt { item: sort, cx }))
631            .format(" -> ");
632        write!(f, "∃ {} : {} -> Prop", self.kvid.display().to_string().replace("$", "_"), sorts)
633    }
634}
635
636impl LeanFmt for (&KVid, &ClosedSolution) {
637    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
638        let (kvid, (implicit, (explicit, inner))) = self;
639        write!(f, "def k{} ", kvid.as_usize())?;
640        for (arg, sort) in implicit.iter().chain(explicit) {
641            write!(f, "(")?;
642            arg.lean_fmt(f, cx)?;
643            write!(f, " : {}) ", WithLeanCtxt { item: sort, cx })?;
644        }
645        writeln!(f, ": Prop :=")?;
646        write!(f, "  ")?;
647        inner.lean_fmt(f, cx)?;
648        Ok(())
649    }
650}
651
652impl<'a> LeanFmt for LeanKConstraint<'a> {
653    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
654        let theorem_name = self.theorem_name.replace(".", "_");
655        let namespace = format!("{}KVarSolutions", snake_case_to_pascal_case(&theorem_name));
656        if !cx.kvar_solutions.is_empty() {
657            writeln!(f, "namespace {namespace}\n")?;
658
659            let cx = cx.with_bool_mode(BoolMode::Prop);
660
661            if !cx.kvar_solutions.cut_solutions.is_empty() {
662                writeln!(f, "-- cyclic (cut) kvars")?;
663                for kvar_solution in &cx.kvar_solutions.cut_solutions {
664                    kvar_solution.lean_fmt(f, &cx)?;
665                    writeln!(f)?;
666                }
667            }
668
669            if !cx.kvar_solutions.non_cut_solutions.is_empty() {
670                writeln!(f, "-- acyclic (non-cut) kvars")?;
671                for kvar_solution in &cx.kvar_solutions.non_cut_solutions {
672                    kvar_solution.lean_fmt(f, &cx)?;
673                    writeln!(f)?;
674                }
675            }
676            writeln!(f, "\nend {namespace}\n\n")?;
677            writeln!(f, "open {namespace}\n\n")?;
678        }
679
680        write!(f, "\n\ndef {theorem_name} := ")?;
681
682        if self.kvars.is_empty() {
683            self.constr.lean_fmt(f, cx)
684        } else {
685            write!(
686                f,
687                "{}, ",
688                self.kvars
689                    .iter()
690                    .map(|kvar| { WithLeanCtxt { item: kvar, cx } })
691                    .format(", ")
692            )?;
693            self.constr.lean_fmt(f, cx)
694        }
695    }
696}
697
698impl LeanFmt for Constraint {
699    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
700        let mut fmt_cx = ConstraintFormatter::default();
701        fmt_cx.incr();
702        fmt_cx.newline(f)?;
703        self.fmt_nested(f, cx, &mut fmt_cx)?;
704        fmt_cx.decr();
705        Ok(())
706    }
707}
708
709impl FormatNested for Constraint {
710    fn fmt_nested(
711        &self,
712        f: &mut fmt::Formatter,
713        lean_cx: &LeanCtxt,
714        fmt_cx: &mut ConstraintFormatter,
715    ) -> fmt::Result {
716        match self {
717            Constraint::ForAll(bind, inner) => {
718                let trivial_pred = bind.pred.is_trivially_true();
719                let trivial_bind = bind.name.display().to_string().starts_with("_");
720                if !trivial_bind {
721                    write!(f, "∀ (")?;
722                    bind.name.lean_fmt(f, lean_cx)?;
723                    write!(f, " : {}),", WithLeanCtxt { item: &bind.sort, cx: lean_cx })?;
724                    fmt_cx.incr();
725                    fmt_cx.newline(f)?;
726                }
727                if !trivial_pred {
728                    bind.pred.lean_fmt(f, lean_cx)?;
729                    write!(f, " ->")?;
730                    fmt_cx.incr();
731                    fmt_cx.newline(f)?;
732                }
733                inner.fmt_nested(f, lean_cx, fmt_cx)?;
734                if !trivial_pred {
735                    fmt_cx.decr();
736                }
737                if !trivial_bind {
738                    fmt_cx.decr();
739                }
740                Ok(())
741            }
742            Constraint::Conj(constraints) => {
743                let n = constraints.len();
744                for (i, constraint) in constraints.iter().enumerate() {
745                    write!(f, "(")?;
746                    constraint.fmt_nested(f, lean_cx, fmt_cx)?;
747                    write!(f, ")")?;
748                    if i < n - 1 {
749                        write!(f, " ∧")?;
750                    }
751                    fmt_cx.newline(f)?;
752                }
753                Ok(())
754            }
755            Constraint::Pred(pred, _) => pred.lean_fmt(f, lean_cx),
756        }
757    }
758}
759
760pub trait FormatNested {
761    fn fmt_nested(
762        &self,
763        f: &mut fmt::Formatter,
764        lean_cx: &LeanCtxt,
765        fmt_cx: &mut ConstraintFormatter,
766    ) -> fmt::Result;
767}
768
/// Tracks the indentation depth used when laying out nested constraints.
#[derive(Default)]
pub struct ConstraintFormatter {
    /// Current indentation, measured in spaces.
    level: u32,
}

impl ConstraintFormatter {
    /// Increases the indentation by one space.
    pub fn incr(&mut self) {
        self.level += 1;
    }

    /// Decreases the indentation by one space. Calls must be balanced with [`Self::incr`].
    pub fn decr(&mut self) {
        self.level -= 1;
    }

    /// Starts a new line and indents it to the current level.
    pub fn newline(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_char('\n').and_then(|()| self.padding(f))
    }

    /// Writes one space per indentation level.
    pub fn padding(&self, f: &mut fmt::Formatter) -> fmt::Result {
        (0..self.level).try_for_each(|_| f.write_char(' '))
    }
}
795
796pub fn def_id_to_pascal_case(def_id: &DefId, tcx: &rustc_middle::ty::TyCtxt) -> String {
797    let snake = tcx
798        .def_path(*def_id)
799        .to_filename_friendly_no_crate()
800        .replace("-", "_");
801    let pascal_case = snake_case_to_pascal_case(&snake);
802    let re = regex::Regex::new(r"\{impl#(\d+)\}").unwrap();
803    re.replace_all(&pascal_case, "Impl__$1__").to_string()
804}
805
/// Converts `snake_case` text to `PascalCase`.
///
/// The input is split on `_`. Empty segments (from leading, trailing or repeated
/// underscores) are dropped. The first character of each remaining segment is
/// ASCII-uppercased, and the rest of the segment is kept unchanged.
pub fn snake_case_to_pascal_case(snake: &str) -> String {
    let mut pascal = String::with_capacity(snake.len());
    for segment in snake.split('_') {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            pascal.push(head.to_ascii_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}