flux_infer/
lean_encoding.rs

1use std::{
2    fs,
3    io::{self, Write},
4    path,
5    process::{Command, Stdio},
6};
7
8use flux_middle::{
9    def_id::MaybeExternId,
10    global_env::GlobalEnv,
11    queries::{QueryErr, QueryResult},
12};
13use itertools::Itertools;
14use liquid_fixpoint::Identifier;
15
16use crate::fixpoint_encoding::fixpoint::{BinRel, ConstDecl, Constraint, Expr, FunDef, Pred};
17
18pub(crate) struct ConstDef(pub ConstDecl, pub Option<Expr>);
19
20pub(crate) struct LeanEncoder<'genv, 'tcx> {
21    def_id: MaybeExternId,
22    genv: GlobalEnv<'genv, 'tcx>,
23    fun_defs: Vec<FunDef>,
24    constants: Vec<ConstDef>,
25    constraint: Constraint,
26}
27
28impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
29    pub fn new(
30        def_id: MaybeExternId,
31        genv: GlobalEnv<'genv, 'tcx>,
32        fun_defs: Vec<FunDef>,
33        constants: Vec<ConstDecl>,
34        constraint: Constraint,
35    ) -> Self {
36        let constants = Self::extract_const_defs(constants, &constraint);
37        Self { def_id, genv, fun_defs, constants, constraint }
38    }
39
40    pub fn fun_defs(&self) -> &[FunDef] {
41        &self.fun_defs
42    }
43
44    pub fn constraint(&self) -> &Constraint {
45        &self.constraint
46    }
47
48    pub fn constants(&self) -> &[ConstDef] {
49        &self.constants
50    }
51
52    pub(crate) fn theorem_name(&self) -> String {
53        self.genv
54            .tcx()
55            .def_path(self.def_id.resolved_id())
56            .to_filename_friendly_no_crate()
57            .replace("-", "_")
58    }
59
60    fn proof_name(&self) -> String {
61        format!("{}_proof", self.theorem_name()).to_string()
62    }
63
64    fn generate_lake_project_if_not_present(
65        &self,
66        lean_path: &path::Path,
67        project_name: &str,
68    ) -> Result<(), io::Error> {
69        if !lean_path.join(project_name).exists() {
70            Command::new("lake")
71                .arg("new")
72                .arg(project_name)
73                .arg("lib")
74                .spawn()
75                .and_then(|mut child| child.wait())
76                .map(|_| ())
77        } else {
78            Ok(())
79        }
80    }
81
82    fn generate_def_file(
83        &self,
84        lean_path: &path::Path,
85        project_name: &str,
86    ) -> Result<(), io::Error> {
87        self.generate_lake_project_if_not_present(lean_path, project_name)?;
88        let theorem_path = lean_path.join(
89            format!(
90                "{project_name}/{}/{}.lean",
91                Self::snake_case_to_pascal_case(project_name),
92                Self::snake_case_to_pascal_case(self.theorem_name().as_str())
93            )
94            .as_str(),
95        );
96        let mut file = fs::File::create(theorem_path)?;
97        writeln!(file, "{self}")
98    }
99
100    fn generate_proof_file_if_not_present(
101        &self,
102        lean_path: &path::Path,
103        project_name: &str,
104    ) -> Result<(), io::Error> {
105        self.generate_def_file(lean_path, project_name)?;
106        let module_name = Self::snake_case_to_pascal_case(project_name);
107        let proof_name = self.proof_name();
108        let proof_path = lean_path.join(
109            format!(
110                "{project_name}/{}/{}.lean",
111                module_name.as_str(),
112                Self::snake_case_to_pascal_case(proof_name.as_str())
113            )
114            .as_str(),
115        );
116        let theorem_name = self.theorem_name();
117        if !proof_path.exists() {
118            let mut file = std::fs::File::create(proof_path)?;
119            writeln!(
120                file,
121                "import {}.{}",
122                Self::snake_case_to_pascal_case(module_name.as_str()),
123                Self::snake_case_to_pascal_case(theorem_name.as_str())
124            )?;
125            writeln!(file, "def {proof_name} : {theorem_name} := by")?;
126            writeln!(file, "  unfold {theorem_name}")?;
127            writeln!(file, "  sorry")
128        } else {
129            Ok(())
130        }
131    }
132
133    fn check_proof_help(&self, lean_path: &path::Path, project_name: &str) -> std::io::Result<()> {
134        self.generate_proof_file_if_not_present(lean_path, project_name)?;
135        let project_path = lean_path.join(project_name);
136        let proof_path = project_path.join(format!(
137            "{}/{}.lean",
138            Self::snake_case_to_pascal_case(project_name),
139            Self::snake_case_to_pascal_case(self.proof_name().as_str())
140        ));
141        let child = Command::new("lake")
142            .arg("--dir")
143            .arg(project_path.to_str().unwrap())
144            .arg("lean")
145            .arg(proof_path.to_str().unwrap())
146            .stdout(Stdio::piped())
147            .stderr(Stdio::piped())
148            .spawn()?;
149        let out = child.wait_with_output()?;
150        if out.stderr.is_empty() && out.stdout.is_empty() {
151            Ok(())
152        } else {
153            let stderr = std::str::from_utf8(&out.stderr)
154                .unwrap_or("Lean exited with a non-zero return code");
155            Err(io::Error::other(stderr))
156        }
157    }
158
159    pub fn check_proof(&self, lean_path: &path::Path, project_name: &str) -> QueryResult<()> {
160        self.check_proof_help(lean_path, project_name).map_err(|_| {
161            let msg = format!("checking proof for {} failed", self.theorem_name());
162            let span = self.genv.tcx().def_span(self.def_id.resolved_id());
163            QueryErr::Emitted(
164                self.genv
165                    .sess()
166                    .dcx()
167                    .handle()
168                    .struct_span_err(span, msg)
169                    .emit(),
170            )
171        })
172    }
173
174    fn snake_case_to_pascal_case(snake: &str) -> String {
175        snake
176            .split('_')
177            .filter(|s| !s.is_empty()) // skip empty segments (handles double underscores)
178            .map(|word| {
179                let mut chars = word.chars();
180                match chars.next() {
181                    Some(first) => first.to_ascii_uppercase().to_string() + chars.as_str(),
182                    None => String::new(),
183                }
184            })
185            .collect::<String>()
186    }
187
188    fn extract_const_defs(const_decls: Vec<ConstDecl>, constraint: &Constraint) -> Vec<ConstDef> {
189        const_decls
190            .into_iter()
191            .map(|const_decl| {
192                let mut defs = vec![];
193                Self::extract_const_def(&const_decl, constraint, &mut defs);
194                if defs.len() <= 1 {
195                    ConstDef(const_decl, defs.pop())
196                } else {
197                    panic!("Constant {} has {} definitions", const_decl.name.display(), defs.len())
198                }
199            })
200            .collect_vec()
201    }
202
203    fn extract_const_def(const_decl: &ConstDecl, constraint: &Constraint, acc: &mut Vec<Expr>) {
204        match constraint {
205            Constraint::ForAll(bind, inner) => {
206                if let Pred::Expr(Expr::Atom(BinRel::Eq, equals)) = &bind.pred {
207                    if let Expr::Var(vl) = &equals[0]
208                        && vl.eq(&const_decl.name)
209                    {
210                        acc.push(equals[1].clone());
211                    }
212                    if let Expr::Var(vr) = &equals[1]
213                        && vr.eq(&const_decl.name)
214                    {
215                        acc.push(equals[0].clone());
216                    }
217                    Self::extract_const_def(const_decl, inner.as_ref(), acc);
218                }
219            }
220            Constraint::Conj(conjuncts) => {
221                conjuncts
222                    .iter()
223                    .for_each(|constraint| Self::extract_const_def(const_decl, constraint, acc));
224            }
225            Constraint::Pred(..) => {}
226        }
227    }
228}