flux_infer/
lean_encoding.rs

1use std::{
2    fs,
3    io::{self, Write},
4    path::PathBuf,
5    process::{Command, Stdio},
6};
7
8use flux_common::{dbg, dbg::SpanTrace};
9use flux_config as config;
10use flux_middle::{
11    def_id::MaybeExternId,
12    global_env::GlobalEnv,
13    queries::{QueryErr, QueryResult},
14    rty::PrettyMap,
15};
16
17use crate::{
18    fixpoint_encoding::fixpoint,
19    lean_format::{self, LeanCtxt, LeanSortVar, WithLeanCtxt},
20};
21
22pub struct LeanEncoder<'genv, 'tcx> {
23    genv: GlobalEnv<'genv, 'tcx>,
24    path: PathBuf,
25    project: String,
26    defs_file_name: String,
27    pretty_var_map: PrettyMap<fixpoint::LocalVar>,
28}
29
30impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
31    pub fn new(
32        genv: GlobalEnv<'genv, 'tcx>,
33        pretty_var_map: PrettyMap<fixpoint::LocalVar>,
34    ) -> Self {
35        let defs_file_name = "Defs".to_string();
36        let path = genv
37            .tcx()
38            .sess
39            .opts
40            .working_dir
41            .local_path_if_available()
42            .to_path_buf()
43            .join(config::lean_dir());
44        let project = config::lean_project().to_string();
45        Self { genv, path, project, defs_file_name, pretty_var_map }
46    }
47
48    fn generate_lake_project_if_not_present(&self) -> Result<(), io::Error> {
49        if !self.path.join(&self.project).exists() {
50            Command::new("lake")
51                .arg("new")
52                .arg(&self.project)
53                .arg("lib")
54                .spawn()
55                .and_then(|mut child| child.wait())
56                .map(|_| ())
57        } else {
58            Ok(())
59        }
60    }
61
62    fn generate_sort_def_file_if_not_present(
63        &self,
64        sort: &fixpoint::SortDecl,
65    ) -> Result<(), io::Error> {
66        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
67        let pascal_sort_file_name =
68            Self::snake_case_to_pascal_case(&format!("{}", LeanSortVar(&sort.name)));
69        let instance_path = self
70            .path
71            .join(format!("{}/{pascal_project_name}/{pascal_sort_file_name}.lean", self.project));
72        if !instance_path.exists() {
73            let mut instance_file = fs::File::create(instance_path)?;
74            writeln!(instance_file, "import {}.Lib\n", pascal_project_name)?;
75            let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
76            writeln!(instance_file, "def {} := sorry", WithLeanCtxt { item: sort, cx: &cx })?;
77        }
78        Ok(())
79    }
80
81    fn generate_func_def_file_if_not_present(
82        &self,
83        fun: &fixpoint::ConstDecl,
84        has_opaque_sorts: bool,
85        has_data_decls: bool,
86    ) -> Result<(), io::Error> {
87        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
88        let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
89        let fun_name = format!("{}", WithLeanCtxt { item: &fun.name, cx: &cx });
90        let pascal_fun_file_name = Self::snake_case_to_pascal_case(&fun_name);
91        let instance_path = self
92            .path
93            .join(format!("{}/{pascal_project_name}/{pascal_fun_file_name}.lean", self.project));
94        if !instance_path.exists() {
95            let mut instance_file = fs::File::create(instance_path)?;
96            writeln!(instance_file, "import {pascal_project_name}.Lib")?;
97            if has_opaque_sorts {
98                writeln!(instance_file, "import {pascal_project_name}.OpaqueSorts")?;
99            }
100            if has_data_decls {
101                writeln!(instance_file, "import {pascal_project_name}.Structs")?;
102            }
103            writeln!(instance_file, "def {} := sorry", WithLeanCtxt { item: fun, cx: &cx })?;
104        }
105        Ok(())
106    }
107
108    fn generate_opaque_sort_files(&self, sorts: &[fixpoint::SortDecl]) -> Result<(), io::Error> {
109        if sorts.is_empty() {
110            return Ok(());
111        }
112        for sort in sorts {
113            self.generate_sort_def_file_if_not_present(sort)?;
114        }
115        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
116        let mut opaque_sorts_file = fs::File::create(
117            self.path
118                .join(format!("{}/{pascal_project_name}/OpaqueSorts.lean", self.project)),
119        )?;
120        for sort in sorts {
121            let pascal_sort_file_name =
122                Self::snake_case_to_pascal_case(&format!("{}", LeanSortVar(&sort.name)));
123            writeln!(opaque_sorts_file, "import {pascal_project_name}.{pascal_sort_file_name}")?;
124        }
125        Ok(())
126    }
127
128    fn generate_struct_defs_file(
129        &self,
130        data_decls: &[fixpoint::DataDecl],
131        has_opaque_sorts: bool,
132    ) -> Result<(), io::Error> {
133        if data_decls.is_empty() {
134            return Ok(());
135        }
136        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
137        let mut structs_file = fs::File::create(
138            self.path
139                .join(format!("{}/{}/Structs.lean", self.project, pascal_project_name,)),
140        )?;
141        writeln!(structs_file, "import {}.Lib", pascal_project_name)?;
142        if has_opaque_sorts {
143            writeln!(structs_file, "import {}.OpaqueSorts", pascal_project_name)?;
144        }
145        writeln!(structs_file, "-- STRUCT DECLS --")?;
146        writeln!(structs_file, "mutual")?;
147        let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
148        for data_decl in data_decls {
149            writeln!(structs_file, "{}", WithLeanCtxt { item: data_decl, cx: &cx })?;
150        }
151        writeln!(structs_file, "end")?;
152        Ok(())
153    }
154
155    fn generate_opaque_func_files(
156        &self,
157        funs: &[fixpoint::ConstDecl],
158        has_opaque_sorts: bool,
159        has_data_decls: bool,
160    ) -> Result<(), io::Error> {
161        if funs.is_empty() {
162            return Ok(());
163        }
164        for fun in funs {
165            self.generate_func_def_file_if_not_present(fun, has_opaque_sorts, has_data_decls)?;
166        }
167        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
168        let mut opaque_funcs_file = fs::File::create(
169            self.path
170                .join(format!("{}/{pascal_project_name}/OpaqueFuncs.lean", self.project)),
171        )?;
172        let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
173        for fun in funs {
174            let fun_name = format!("{}", WithLeanCtxt { item: &fun.name, cx: &cx });
175            let pascal_fun_file_name = Self::snake_case_to_pascal_case(&fun_name);
176            writeln!(opaque_funcs_file, "import {pascal_project_name}.{pascal_fun_file_name}")?;
177        }
178        Ok(())
179    }
180
181    fn generate_func_defs_file(
182        &self,
183        func_defs: &[fixpoint::FunDef],
184        has_opaque_sorts: bool,
185        has_data_decls: bool,
186        has_opaque_funcs: bool,
187    ) -> Result<(), io::Error> {
188        if func_defs.is_empty() {
189            return Ok(());
190        }
191        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
192        let defs_path = self
193            .path
194            .join(format!("{}/{pascal_project_name}/{}.lean", self.project, self.defs_file_name));
195        let mut file = fs::File::create(defs_path)?;
196
197        writeln!(file, "import {pascal_project_name}.Lib")?;
198        if has_opaque_sorts {
199            writeln!(file, "import {pascal_project_name}.OpaqueSorts")?;
200        }
201        if has_data_decls {
202            writeln!(file, "import {pascal_project_name}.Structs")?;
203        }
204        if has_opaque_funcs {
205            writeln!(file, "import {pascal_project_name}.OpaqueFuncs")?;
206        }
207        writeln!(file, "-- FUNC DECLS --")?;
208        writeln!(file, "mutual")?;
209        let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
210        for fun_def in func_defs {
211            writeln!(file, "{}", WithLeanCtxt { item: fun_def, cx: &cx })?;
212        }
213        writeln!(file, "end")?;
214        Ok(())
215    }
216
217    fn generate_lib_file(&self) -> Result<(), io::Error> {
218        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
219        let mut lib_file = fs::File::create(
220            self.path
221                .join(format!("{}/{pascal_project_name}/Lib.lean", self.project)),
222        )?;
223        writeln!(
224            lib_file,
225            "def BitVec_shiftLeft {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.shiftLeft x (s.toNat)"
226        )?;
227        writeln!(
228            lib_file,
229            "def BitVec_ushiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.ushiftRight x (s.toNat)"
230        )?;
231        writeln!(
232            lib_file,
233            "def BitVec_sshiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.sshiftRight x (s.toNat)"
234        )?;
235        writeln!(
236            lib_file,
237            "def BitVec_uge {{ n : Nat }} (x y : BitVec n) := (BitVec.ult x y).not"
238        )?;
239        writeln!(
240            lib_file,
241            "def BitVec_sge {{ n : Nat }} (x y : BitVec n) := (BitVec.slt x y).not"
242        )?;
243        writeln!(
244            lib_file,
245            "def BitVec_ugt {{ n : Nat }} (x y : BitVec n) := (BitVec.ule x y).not"
246        )?;
247        writeln!(
248            lib_file,
249            "def BitVec_sgt {{ n : Nat }} (x y : BitVec n) := (BitVec.sle x y).not"
250        )?;
251        writeln!(
252            lib_file,
253            "def SmtMap (t0 t1 : Type) [Inhabited t0] [BEq t0] [Inhabited t1] : Type := t0 -> t1"
254        )?;
255        writeln!(
256            lib_file,
257            "def SmtMap_default {{ t0 t1: Type }} (v : t1) [Inhabited t0] [BEq t0] [Inhabited t1] : SmtMap t0 t1 := fun _ => v"
258        )?;
259        writeln!(
260            lib_file,
261            "def SmtMap_store {{ t0 t1 : Type }} [Inhabited t0] [BEq t0] [Inhabited t1] (m : SmtMap t0 t1) (k : t0) (v : t1) : SmtMap t0 t1 :=\n  fun x => if x == k then v else m x"
262        )?;
263        writeln!(
264            lib_file,
265            "def SmtMap_select {{ t0 t1 : Type }} [Inhabited t0] [BEq t0] [Inhabited t1] (m : SmtMap t0 t1) (k : t0) := m k"
266        )?;
267        Ok(())
268    }
269
270    pub fn encode_defs(
271        &self,
272        opaque_sorts: &[fixpoint::SortDecl],
273        opaque_funs: &[fixpoint::ConstDecl],
274        data_decls: &[fixpoint::DataDecl],
275        func_defs: &[fixpoint::FunDef],
276    ) -> Result<(), io::Error> {
277        self.generate_lake_project_if_not_present()?;
278        self.generate_lib_file()?;
279
280        let has_opaque_sorts = !opaque_sorts.is_empty();
281        let has_data_decls = !data_decls.is_empty();
282        let has_opaque_funcs = !opaque_funs.is_empty();
283
284        self.generate_opaque_sort_files(opaque_sorts)?;
285        self.generate_struct_defs_file(data_decls, has_opaque_sorts)?;
286        self.generate_opaque_func_files(opaque_funs, has_opaque_sorts, has_data_decls)?;
287        self.generate_func_defs_file(
288            func_defs,
289            has_opaque_sorts,
290            has_data_decls,
291            has_opaque_funcs,
292        )?;
293        Ok(())
294    }
295
296    fn theorem_path(&self, theorem_name: &str) -> PathBuf {
297        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
298
299        self.path.join(format!(
300            "{}/{}/{}.lean",
301            self.project,
302            pascal_project_name,
303            Self::snake_case_to_pascal_case(theorem_name)
304        ))
305    }
306
307    fn generate_theorem_file(
308        &self,
309        theorem_name: &str,
310        kvars: &[fixpoint::KVarDecl],
311        cstr: &fixpoint::Constraint,
312    ) -> Result<(), io::Error> {
313        let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
314        let theorem_path = self.theorem_path(theorem_name);
315        let mut theorem_file = fs::File::create(theorem_path)?;
316        writeln!(theorem_file, "import {}.Lib", pascal_project_name)?;
317        if self
318            .path
319            .join(format!("{}/{pascal_project_name}/{}.lean", self.project, self.defs_file_name))
320            .exists()
321        {
322            writeln!(theorem_file, "import {pascal_project_name}.{}", self.defs_file_name)?;
323        }
324        if self
325            .path
326            .join(format!("{}/{pascal_project_name}/OpaqueSorts.lean", self.project))
327            .exists()
328        {
329            writeln!(theorem_file, "import {pascal_project_name}.OpaqueSorts")?;
330        }
331        if self
332            .path
333            .join(format!("{}/{pascal_project_name}/OpaqueFuncs.lean", self.project))
334            .exists()
335        {
336            writeln!(theorem_file, "import {pascal_project_name}.OpaqueFuncs")?;
337        }
338        let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
339        writeln!(
340            theorem_file,
341            "def {} := {}",
342            theorem_name.replace(".", "_"),
343            WithLeanCtxt { item: lean_format::LeanKConstraint { kvars, constr: cstr }, cx: &cx }
344        )
345    }
346
347    fn generate_proof_file_if_not_present(
348        &self,
349        def_id: MaybeExternId,
350        theorem_name: &str,
351    ) -> Result<(), io::Error> {
352        let module_name = Self::snake_case_to_pascal_case(&self.project);
353        let proof_name = format!("{theorem_name}_proof");
354        let proof_path = self.path.join(format!(
355            "{}/{}/{}.lean",
356            self.project,
357            module_name,
358            Self::snake_case_to_pascal_case(&proof_name)
359        ));
360
361        if let Some(span) = self.genv.proven_externally(def_id.local_id()) {
362            let dst_span = SpanTrace::from_path(&proof_path, 3, 5, proof_name.len());
363            dbg::hyperlink_json!(self.genv.tcx(), span, dst_span);
364        }
365
366        if proof_path.exists() {
367            return Ok(());
368        }
369        let mut proof_file = fs::File::create(proof_path)?;
370        writeln!(proof_file, "import {module_name}.Lib")?;
371        writeln!(
372            proof_file,
373            "import {}.{}",
374            module_name,
375            Self::snake_case_to_pascal_case(theorem_name)
376        )?;
377        writeln!(proof_file, "def {proof_name} : {theorem_name} := by")?;
378        writeln!(proof_file, "  unfold {theorem_name}")?;
379        writeln!(proof_file, "  sorry")
380    }
381
382    pub fn encode_constraint(
383        &self,
384        def_id: MaybeExternId,
385        kvars: &[fixpoint::KVarDecl],
386        cstr: &fixpoint::Constraint,
387    ) -> Result<(), io::Error> {
388        self.generate_lake_project_if_not_present()?;
389        self.generate_lib_file()?;
390        let theorem_name = self
391            .genv
392            .tcx()
393            .def_path(def_id.resolved_id())
394            .to_filename_friendly_no_crate()
395            .replace("-", "_");
396
397        self.generate_theorem_file(&theorem_name, kvars, cstr)?;
398
399        self.generate_proof_file_if_not_present(def_id, &theorem_name)
400    }
401
402    fn check_proof_help(&self, theorem_name: &str) -> io::Result<()> {
403        let proof_name = format!("{theorem_name}_proof");
404        let project_path = self.path.join(&self.project);
405        let proof_path = format!(
406            "{}/{}.lean",
407            Self::snake_case_to_pascal_case(&self.project),
408            Self::snake_case_to_pascal_case(&proof_name)
409        );
410        let child = Command::new("lake")
411            .arg("--quiet")
412            .arg("lean")
413            .arg(proof_path)
414            .stdout(Stdio::piped())
415            .stderr(Stdio::piped())
416            .current_dir(project_path.as_path())
417            .spawn()?;
418        let out = child.wait_with_output()?;
419        if out.stderr.is_empty() && out.stdout.is_empty() {
420            Ok(())
421        } else {
422            let stderr = std::str::from_utf8(&out.stderr)
423                .unwrap_or("Lean exited with a non-zero return code");
424            Err(io::Error::other(stderr))
425        }
426    }
427
428    pub fn check_proof(&self, def_id: MaybeExternId) -> QueryResult<()> {
429        let theorem_name = self
430            .genv
431            .tcx()
432            .def_path(def_id.resolved_id())
433            .to_filename_friendly_no_crate()
434            .replace("-", "_");
435        self.check_proof_help(&theorem_name).map_err(|_| {
436            let msg = format!("checking proof for {theorem_name} failed");
437            let span = self.genv.tcx().def_span(def_id.resolved_id());
438            QueryErr::Emitted(
439                self.genv
440                    .sess()
441                    .dcx()
442                    .handle()
443                    .struct_span_err(span, msg)
444                    .emit(),
445            )
446        })
447    }
448
449    fn snake_case_to_pascal_case(snake: &str) -> String {
450        snake
451            .split('_')
452            .filter(|s| !s.is_empty()) // skip empty segments (handles double underscores)
453            .map(|word| {
454                let mut chars = word.chars();
455                match chars.next() {
456                    Some(first) => first.to_ascii_uppercase().to_string() + chars.as_str(),
457                    None => String::new(),
458                }
459            })
460            .collect::<String>()
461    }
462}