flux_infer/
lean_format.rs

1use core::fmt;
2use std::fmt::Write;
3
4use flux_middle::{
5    global_env::GlobalEnv,
6    rty::{PrettyMap, PrettyVar},
7};
8use itertools::Itertools;
9use liquid_fixpoint::{FixpointFmt, Identifier, ThyFunc};
10
11use crate::fixpoint_encoding::fixpoint::{
12    BinOp, BinRel, ConstDecl, Constant, Constraint, DataDecl, DataField, DataSort, Expr, FunDef,
13    KVarDecl, LocalVar, Pred, Sort, SortCtor, SortDecl, Var,
14};
15
16pub struct LeanCtxt<'a, 'genv, 'tcx> {
17    pub genv: GlobalEnv<'genv, 'tcx>,
18    pub pretty_var_map: &'a PrettyMap<LocalVar>,
19}
20
21pub struct WithLeanCtxt<'a, 'b, 'genv, 'tcx, T> {
22    pub item: T,
23    pub cx: &'a LeanCtxt<'b, 'genv, 'tcx>,
24}
25
26impl<'a, 'b, 'genv, 'tcx, T: LeanFmt> fmt::Display for WithLeanCtxt<'a, 'b, 'genv, 'tcx, T> {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        self.item.lean_fmt(f, self.cx)
29    }
30}
31
32pub trait LeanFmt {
33    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result;
34}
35
36struct LeanSort<'a>(&'a Sort);
37pub struct LeanKConstraint<'a> {
38    pub kvars: &'a [KVarDecl],
39    pub constr: &'a Constraint,
40}
41
42pub struct LeanSortVar<'a>(pub &'a DataSort);
43struct LeanKVarDecl<'a>(&'a KVarDecl);
44struct LeanThyFunc<'a>(&'a ThyFunc);
45
46impl LeanFmt for &SortDecl {
47    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
48        (*self).lean_fmt(f, cx)
49    }
50}
51
52impl LeanFmt for SortDecl {
53    fn lean_fmt(&self, f: &mut fmt::Formatter, _cx: &LeanCtxt) -> fmt::Result {
54        write!(
55            f,
56            "{} {} : Type",
57            LeanSortVar(&self.name),
58            (0..(self.vars))
59                .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
60                .format(" ")
61        )
62    }
63}
64
65impl LeanFmt for &ConstDecl {
66    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
67        (*self).lean_fmt(f, cx)
68    }
69}
70
71impl LeanFmt for ConstDecl {
72    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
73        self.name.lean_fmt(f, cx)?;
74        write!(f, " : {}", LeanSort(&self.sort))
75    }
76}
77
78impl LeanFmt for &DataField {
79    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
80        (*self).lean_fmt(f, cx)
81    }
82}
83
84impl LeanFmt for DataField {
85    fn lean_fmt(&self, f: &mut fmt::Formatter, _cx: &LeanCtxt) -> fmt::Result {
86        write!(
87            f,
88            "({} : {})",
89            self.name.display().to_string().replace("$", "_"),
90            LeanSort(&self.sort)
91        )
92    }
93}
94
95impl<'a> fmt::Display for LeanSortVar<'a> {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        match self.0 {
98            DataSort::User(def_id) => write!(f, "{}", def_id.name()),
99            _ => write!(f, "{}", self.0.display().to_string().replace("$", "_")),
100        }
101    }
102}
103
104impl LeanFmt for &DataDecl {
105    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
106        (*self).lean_fmt(f, cx)
107    }
108}
109
110impl LeanFmt for DataDecl {
111    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
112        if self.ctors.len() == 1 {
113            writeln!(f, "@[ext]")?;
114            writeln!(
115                f,
116                "structure {} {} where",
117                LeanSortVar(&self.name),
118                (0..self.vars)
119                    .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
120                    .format(" ")
121            )?;
122            writeln!(f, "  {}::", self.ctors[0].name.display().to_string().replace("$", "_"),)?;
123            for field in &self.ctors[0].fields {
124                write!(f, "  ")?;
125                field.lean_fmt(f, cx)?;
126                writeln!(f)?;
127            }
128        } else {
129            writeln!(
130                f,
131                "inductive {} {} where",
132                LeanSortVar(&self.name),
133                (0..self.vars)
134                    .map(|i| format!("(t{i} : Type) [Inhabited t{i}]"))
135                    .format(" ")
136            )?;
137            for data_ctor in &self.ctors {
138                write!(f, "| ")?;
139                data_ctor.name.lean_fmt(f, cx)?;
140                for field in &data_ctor.fields {
141                    write!(f, " ")?;
142                    field.lean_fmt(f, cx)?;
143                }
144                writeln!(f)?;
145            }
146        }
147        Ok(())
148    }
149}
150
151impl<'a> fmt::Display for LeanThyFunc<'a> {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        match self.0 {
154            ThyFunc::IntToBv8 => write!(f, "BitVec.ofInt 8"),
155            ThyFunc::IntToBv32 => write!(f, "BitVec.ofInt 32"),
156            ThyFunc::IntToBv64 => write!(f, "BitVec.ofInt 64"),
157            ThyFunc::Bv8ToInt | ThyFunc::Bv32ToInt | ThyFunc::Bv64ToInt => {
158                write!(f, "BitVec.toNat")
159            }
160            ThyFunc::BvAdd => write!(f, "BitVec.add"),
161            ThyFunc::BvSub => write!(f, "BitVec.sub"),
162            ThyFunc::BvMul => write!(f, "BitVec.mul"),
163            ThyFunc::BvNeg => write!(f, "BitVec.neg"),
164            ThyFunc::BvSdiv => write!(f, "BitVec.sdiv"),
165            ThyFunc::BvSrem => write!(f, "BitVec.srem"),
166            ThyFunc::BvUdiv => write!(f, "BitVec.udiv"),
167            ThyFunc::BvAnd => write!(f, "BitVec.and"),
168            ThyFunc::BvOr => write!(f, "BitVec.or"),
169            ThyFunc::BvXor => write!(f, "BitVec.xor"),
170            ThyFunc::BvNot => write!(f, "BitVec.not"),
171            ThyFunc::BvSle => write!(f, "BitVec.sle"),
172            ThyFunc::BvSlt => write!(f, "BitVec.slt"),
173            ThyFunc::BvUle => write!(f, "BitVec.ule"),
174            ThyFunc::BvUlt => write!(f, "BitVec.ult"),
175            ThyFunc::BvAshr => write!(f, "BitVec_sshiftRight"),
176            ThyFunc::BvLshr => write!(f, "BitVec_ushiftRight"),
177            ThyFunc::BvShl => write!(f, "BitVec_shiftLeft"),
178            ThyFunc::BvSignExtend(size) => write!(f, "BitVec.signExtend {}", size),
179            ThyFunc::BvZeroExtend(size) => write!(f, "BitVec.zeroExtend {}", size),
180            ThyFunc::BvUrem => write!(f, "BitVec.umod"),
181            ThyFunc::BvSge => write!(f, "BitVec_sge"),
182            ThyFunc::BvSgt => write!(f, "BitVec_sgt"),
183            ThyFunc::BvUge => write!(f, "BitVec_uge"),
184            ThyFunc::BvUgt => write!(f, "BitVec_ugt"),
185            ThyFunc::MapDefault => write!(f, "SmtMap_default"),
186            ThyFunc::MapSelect => write!(f, "SmtMap_select"),
187            ThyFunc::MapStore => write!(f, "SmtMap_store"),
188            func => panic!("Unsupported theory function {}", func),
189        }
190    }
191}
192
193impl LeanFmt for &Var {
194    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
195        (*self).lean_fmt(f, cx)
196    }
197}
198
199impl LeanFmt for Var {
200    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
201        match self {
202            Var::Global(_gvar, Some(def_id)) => {
203                let path = cx
204                    .genv
205                    .tcx()
206                    .def_path(def_id.parent())
207                    .to_filename_friendly_no_crate()
208                    .replace("-", "_");
209                if path.is_empty() {
210                    write!(f, "{}", def_id.name())
211                } else {
212                    write!(f, "{path}_{}", def_id.name())
213                }
214            }
215            Var::DataCtor(adt_id, _) | Var::DataProj { adt_id, field: _ } => {
216                write!(
217                    f,
218                    "{}.{}",
219                    LeanSortVar(&DataSort::Adt(*adt_id)),
220                    self.display().to_string().replace("$", "_")
221                )
222            }
223            Var::Local(local_var) => {
224                write!(f, "{}", cx.pretty_var_map.get(&PrettyVar::Local(*local_var)))
225            }
226            Var::Param(param) => {
227                write!(f, "{}", cx.pretty_var_map.get(&PrettyVar::Param(*param)))
228            }
229            _ => {
230                write!(f, "{}", self.display().to_string().replace("$", "_"))
231            }
232        }
233    }
234}
235
236impl<'a> fmt::Display for LeanSort<'a> {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        match self.0 {
239            Sort::Int => write!(f, "Int"),
240            Sort::Bool => write!(f, "Prop"),
241            Sort::Real => write!(f, "Real"),
242            Sort::Str => write!(f, "String"),
243            Sort::Func(f_sort) => {
244                write!(f, "({} -> {})", LeanSort(&f_sort[0]), LeanSort(&f_sort[1]))
245            }
246            Sort::App(sort_ctor, args) => {
247                match sort_ctor {
248                    SortCtor::Data(sort) => {
249                        if args.is_empty() {
250                            write!(f, "{}", LeanSortVar(sort))
251                        } else {
252                            write!(
253                                f,
254                                "({} {})",
255                                LeanSortVar(sort),
256                                args.iter().map(LeanSort).format(" ")
257                            )
258                        }
259                    }
260                    SortCtor::Map => {
261                        write!(f, "(SmtMap {} {})", LeanSort(&args[0]), LeanSort(&args[1]))
262                    }
263                    _ => todo!(),
264                }
265            }
266            Sort::BitVec(bv_size) => {
267                match bv_size.as_ref() {
268                    Sort::BvSize(size) => write!(f, "BitVec {}", size),
269                    s => panic!("encountered sort {} where bitvec size was expected", LeanSort(s)),
270                }
271            }
272            Sort::Abs(v, sort) => {
273                write!(f, "{{t{v} : Type}} -> [Inhabited t{v}] -> {}", LeanSort(sort.as_ref()))
274            }
275            Sort::Var(v) => write!(f, "t{v}"),
276            s => todo!("{:?}", s),
277        }
278    }
279}
280
281impl LeanFmt for &Expr {
282    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
283        (*self).lean_fmt(f, cx)
284    }
285}
286
287impl LeanFmt for Expr {
288    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
289        match self {
290            Expr::Var(v) => v.lean_fmt(f, cx),
291            Expr::Constant(c) => {
292                match c {
293                    Constant::Numeral(n) => write!(f, "{n}",),
294                    Constant::Boolean(b) => write!(f, "{}", if *b { "True" } else { "False" }),
295                    Constant::String(s) => write!(f, "{}", s.display()),
296                    Constant::Real(n) => write!(f, "{n}.0"),
297                    Constant::BitVec(bv, size) => write!(f, "{}#{}", bv, size),
298                }
299            }
300            Expr::BinaryOp(bin_op, args) => {
301                let bin_op_str = match bin_op {
302                    BinOp::Add => "+",
303                    BinOp::Sub => "-",
304                    BinOp::Mul => "*",
305                    BinOp::Div => "/",
306                    BinOp::Mod => "%",
307                };
308                write!(f, "(")?;
309                args[0].lean_fmt(f, cx)?;
310                write!(f, " {} ", bin_op_str)?;
311                args[1].lean_fmt(f, cx)?;
312                write!(f, ")")
313            }
314            Expr::Atom(bin_rel, args) => {
315                let bin_rel_str = match bin_rel {
316                    BinRel::Eq => "=",
317                    BinRel::Ne => "≠",
318                    BinRel::Le => "≤",
319                    BinRel::Lt => "<",
320                    BinRel::Ge => "≥",
321                    BinRel::Gt => ">",
322                };
323                write!(f, "(")?;
324                args[0].lean_fmt(f, cx)?;
325                write!(f, " {} ", bin_rel_str)?;
326                args[1].lean_fmt(f, cx)?;
327                write!(f, ")")
328            }
329            Expr::App(function, sort_args, args) => {
330                write!(f, "(")?;
331                function.as_ref().lean_fmt(f, cx)?;
332                if let Some(sort_args) = sort_args {
333                    for (i, s_arg) in sort_args.iter().enumerate() {
334                        write!(f, " (t{i} := {})", LeanSort(s_arg))?;
335                    }
336                }
337                for arg in args {
338                    write!(f, " ")?;
339                    arg.lean_fmt(f, cx)?;
340                }
341                write!(f, ")")
342            }
343            Expr::And(exprs) => {
344                write!(f, "(")?;
345                for (i, expr) in exprs.iter().enumerate() {
346                    if i > 0 {
347                        write!(f, " && ")?;
348                    }
349                    expr.lean_fmt(f, cx)?;
350                }
351                write!(f, ")")
352            }
353            Expr::Or(exprs) => {
354                write!(f, "(")?;
355                for (i, expr) in exprs.iter().enumerate() {
356                    if i > 0 {
357                        write!(f, " || ")?;
358                    }
359                    expr.lean_fmt(f, cx)?;
360                }
361                write!(f, ")")
362            }
363            Expr::Neg(inner) => {
364                write!(f, "(-")?;
365                inner.as_ref().lean_fmt(f, cx)?;
366                write!(f, ")")
367            }
368            Expr::IfThenElse(ite) => {
369                let [condition, if_true, if_false] = ite.as_ref();
370                write!(f, "(if ")?;
371                condition.lean_fmt(f, cx)?;
372                write!(f, " then ")?;
373                if_true.lean_fmt(f, cx)?;
374                write!(f, " else ")?;
375                if_false.lean_fmt(f, cx)?;
376                write!(f, ")")
377            }
378            Expr::Not(inner) => {
379                write!(f, "(¬")?;
380                inner.as_ref().lean_fmt(f, cx)?;
381                write!(f, ")")
382            }
383            Expr::Imp(implication) => {
384                let [lhs, rhs] = implication.as_ref();
385                write!(f, "(")?;
386                lhs.lean_fmt(f, cx)?;
387                write!(f, " -> ")?;
388                rhs.lean_fmt(f, cx)?;
389                write!(f, ")")
390            }
391            Expr::Iff(equiv) => {
392                let [lhs, rhs] = equiv.as_ref();
393                write!(f, "(")?;
394                lhs.lean_fmt(f, cx)?;
395                write!(f, " <-> ")?;
396                rhs.lean_fmt(f, cx)?;
397                write!(f, ")")
398            }
399            Expr::Let(binder, exprs) => {
400                let [def, body] = exprs.as_ref();
401                write!(f, "(let ")?;
402                binder.lean_fmt(f, cx)?;
403                write!(f, " := ")?;
404                def.lean_fmt(f, cx)?;
405                write!(f, "; ")?;
406                body.lean_fmt(f, cx)?;
407                write!(f, ")")
408            }
409            Expr::ThyFunc(thy_func) => {
410                write!(f, "{}", LeanThyFunc(thy_func))
411            }
412            Expr::IsCtor(..) => {
413                todo!("not yet implemented: datatypes in lean")
414            }
415            Expr::Exists(..) => {
416                todo!("not yet implemented: exists in lean")
417            }
418            Expr::BoundVar(_) => {
419                unreachable!("bound vars should only be present in fixpoint output")
420            }
421        }
422    }
423}
424
425impl LeanFmt for &FunDef {
426    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
427        (*self).lean_fmt(f, cx)
428    }
429}
430
431impl LeanFmt for FunDef {
432    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
433        let FunDef { name, args, out, comment: _, body } = self;
434        write!(f, "def ")?;
435        name.lean_fmt(f, cx)?;
436        for (arg, arg_sort) in args {
437            write!(f, " (")?;
438            arg.lean_fmt(f, cx)?;
439            write!(f, " : {})", LeanSort(arg_sort))?;
440        }
441        writeln!(f, " : {} :=", LeanSort(out))?;
442        write!(f, "  ")?;
443        body.lean_fmt(f, cx)?;
444        writeln!(f)
445    }
446}
447
448impl LeanFmt for &Pred {
449    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
450        (*self).lean_fmt(f, cx)
451    }
452}
453
454impl LeanFmt for Pred {
455    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
456        match self {
457            Pred::Expr(expr) => expr.lean_fmt(f, cx),
458            Pred::And(preds) => {
459                write!(f, "(")?;
460                for (i, pred) in preds.iter().enumerate() {
461                    if i > 0 {
462                        write!(f, " ∧ ")?;
463                    }
464                    pred.lean_fmt(f, cx)?;
465                }
466                write!(f, ")")
467            }
468            Pred::KVar(kvid, args) => {
469                write!(f, "({}", kvid.display().to_string().replace("$", "_"))?;
470                for arg in args {
471                    write!(f, " ")?;
472                    arg.lean_fmt(f, cx)?;
473                }
474                write!(f, ")")
475            }
476        }
477    }
478}
479
480impl<'a> fmt::Display for LeanKVarDecl<'a> {
481    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
482        let sorts = self
483            .0
484            .sorts
485            .iter()
486            .enumerate()
487            .map(|(i, sort)| format!("(a{i} : {})", LeanSort(sort)))
488            .format(" -> ");
489        write!(f, "∃ {} : {} -> Prop", self.0.kvid.display().to_string().replace("$", "_"), sorts)
490    }
491}
492
493impl<'a> LeanFmt for LeanKConstraint<'a> {
494    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
495        if self.kvars.is_empty() {
496            self.constr.lean_fmt(f, cx)
497        } else {
498            write!(f, "{}, ", self.kvars.iter().map(LeanKVarDecl).format(", "))?;
499            self.constr.lean_fmt(f, cx)
500        }
501    }
502}
503
504impl LeanFmt for &Constraint {
505    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
506        (*self).lean_fmt(f, cx)
507    }
508}
509
510impl LeanFmt for Constraint {
511    fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
512        let mut fmt_cx = ConstraintFormatter::default();
513        fmt_cx.incr();
514        fmt_cx.newline(f)?;
515        self.fmt_nested(f, cx, &mut fmt_cx)?;
516        fmt_cx.decr();
517        Ok(())
518    }
519}
520
521impl FormatNested for Constraint {
522    fn fmt_nested(
523        &self,
524        f: &mut fmt::Formatter,
525        lean_cx: &LeanCtxt,
526        fmt_cx: &mut ConstraintFormatter,
527    ) -> fmt::Result {
528        match self {
529            Constraint::ForAll(bind, inner) => {
530                let trivial_pred = bind.pred.is_trivially_true();
531                let trivial_bind = bind.name.display().to_string().starts_with("_");
532                if !trivial_bind {
533                    write!(f, "∀ (")?;
534                    bind.name.lean_fmt(f, lean_cx)?;
535                    write!(f, " : {}),", LeanSort(&bind.sort))?;
536                    fmt_cx.incr();
537                    fmt_cx.newline(f)?;
538                }
539                if !trivial_pred {
540                    bind.pred.lean_fmt(f, lean_cx)?;
541                    write!(f, " ->")?;
542                    fmt_cx.incr();
543                    fmt_cx.newline(f)?;
544                }
545                inner.fmt_nested(f, lean_cx, fmt_cx)?;
546                if !trivial_pred {
547                    fmt_cx.decr();
548                }
549                if !trivial_bind {
550                    fmt_cx.decr();
551                }
552                Ok(())
553            }
554            Constraint::Conj(constraints) => {
555                let n = constraints.len();
556                for (i, constraint) in constraints.iter().enumerate() {
557                    constraint.fmt_nested(f, lean_cx, fmt_cx)?;
558                    if i < n - 1 {
559                        write!(f, " ∧")?;
560                    }
561                    fmt_cx.newline(f)?;
562                }
563                Ok(())
564            }
565            Constraint::Pred(pred, _) => pred.lean_fmt(f, lean_cx),
566        }
567    }
568}
569
570pub trait FormatNested {
571    fn fmt_nested(
572        &self,
573        f: &mut fmt::Formatter,
574        lean_cx: &LeanCtxt,
575        fmt_cx: &mut ConstraintFormatter,
576    ) -> fmt::Result;
577}
578
/// Indentation state for multi-line constraint output: one space per level.
#[derive(Default)]
pub struct ConstraintFormatter {
    level: u32,
}

impl ConstraintFormatter {
    /// Moves one indentation level deeper.
    pub fn incr(&mut self) {
        self.level = self.level + 1;
    }

    /// Moves one indentation level back out. Calls must balance earlier `incr`s;
    /// going below zero overflows.
    pub fn decr(&mut self) {
        self.level = self.level - 1;
    }

    /// Starts a new line and indents it to the current level.
    pub fn newline(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_char('\n').and_then(|()| self.padding(f))
    }

    /// Writes one space per indentation level.
    pub fn padding(&self, f: &mut fmt::Formatter) -> fmt::Result {
        (0..self.level).try_for_each(|_| f.write_char(' '))
    }
}