flux_infer/
lean_encoding.rs

1use std::{
2    fs,
3    io::{self, Write},
4    path::Path,
5    process::{Command, Stdio},
6};
7
8use flux_middle::{
9    def_id::MaybeExternId,
10    global_env::GlobalEnv,
11    queries::{QueryErr, QueryResult},
12};
13
14use crate::{
15    fixpoint_encoding::fixpoint,
16    lean_format::{self, LeanConstDecl, LeanSortDecl, LeanSortVar, LeanVar},
17};
18
19pub struct LeanEncoder<'genv, 'tcx, 'a> {
20    genv: GlobalEnv<'genv, 'tcx>,
21    lean_path: &'a Path,
22    project_name: String,
23    defs_file_name: String,
24}
25
26impl<'genv, 'tcx, 'a> LeanEncoder<'genv, 'tcx, 'a> {
27    pub fn new(
28        genv: GlobalEnv<'genv, 'tcx>,
29        lean_path: &'a Path,
30        project_name: String,
31        defs_file_name: String,
32    ) -> Self {
33        Self { genv, lean_path, project_name, defs_file_name }
34    }
35
36    fn generate_lake_project_if_not_present(&self) -> Result<(), io::Error> {
37        if !self.lean_path.join(self.project_name.as_str()).exists() {
38            Command::new("lake")
39                .arg("new")
40                .arg(self.project_name.as_str())
41                .arg("lib")
42                .spawn()
43                .and_then(|mut child| child.wait())
44                .map(|_| ())
45        } else {
46            Ok(())
47        }
48    }
49
50    fn generate_sorts_instance_file_if_not_present(
51        &self,
52        sorts: &[fixpoint::SortDecl],
53    ) -> Result<(), io::Error> {
54        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
55        let instance_path = self.lean_path.join(
56            format!(
57                "{}/{}/OpaqueSortsInstance.lean",
58                self.project_name,
59                pascal_project_name.as_str(),
60            )
61            .as_str(),
62        );
63        if !instance_path.exists() {
64            let mut instance_file = fs::File::create(instance_path)?;
65            writeln!(instance_file, "import {}.Lib", pascal_project_name.as_str())?;
66            writeln!(instance_file, "import {}.OpaqueSortDefs\n", pascal_project_name.as_str())?;
67            writeln!(instance_file, "instance : FluxOpaqueSorts where")?;
68            for sort in sorts {
69                writeln!(instance_file, "  {} := sorry", LeanSortVar(&sort.name))?;
70            }
71        }
72        Ok(())
73    }
74
75    fn generate_funcs_instance_file_if_not_present(
76        &self,
77        funs: &[fixpoint::ConstDecl],
78    ) -> Result<(), io::Error> {
79        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
80        let instance_path = self.lean_path.join(
81            format!(
82                "{}/{}/OpaqueFuncsInstance.lean",
83                self.project_name,
84                pascal_project_name.as_str(),
85            )
86            .as_str(),
87        );
88        if !instance_path.exists() {
89            let mut instance_file = fs::File::create(instance_path)?;
90            writeln!(instance_file, "import {}.Lib", pascal_project_name.as_str())?;
91            writeln!(instance_file, "import {}.OpaqueFuncDefs\n", pascal_project_name.as_str())?;
92            writeln!(instance_file, "instance : FluxOpaqueFuncs where")?;
93            for fun in funs {
94                writeln!(instance_file, "  {} := sorry", LeanVar(&fun.name, self.genv))?;
95            }
96        }
97        Ok(())
98    }
99
100    fn generate_sorts_inferred_instance_file(
101        &self,
102        sorts: &[fixpoint::SortDecl],
103    ) -> Result<(), io::Error> {
104        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
105        let mut inferred_instance_file = fs::File::create(self.lean_path.join(format!(
106            "{}/{}/OpaqueSorts.lean",
107            self.project_name,
108            pascal_project_name.as_str()
109        )))?;
110        writeln!(
111            inferred_instance_file,
112            "import {}.OpaqueSortsInstance\n",
113            pascal_project_name.as_str()
114        )?;
115        writeln!(
116            inferred_instance_file,
117            "def fluxOpaqueSorts : FluxOpaqueSorts := inferInstance\n"
118        )?;
119        for sort in sorts {
120            writeln!(
121                inferred_instance_file,
122                "def {} := fluxOpaqueSorts.{}",
123                LeanSortVar(&sort.name),
124                LeanSortVar(&sort.name)
125            )?;
126        }
127        Ok(())
128    }
129
130    fn generate_funcs_inferred_instance_file(
131        &self,
132        funs: &[fixpoint::ConstDecl],
133    ) -> Result<(), io::Error> {
134        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
135        let mut inferred_instance_file = fs::File::create(self.lean_path.join(format!(
136            "{}/{}/OpaqueFuncs.lean",
137            self.project_name,
138            pascal_project_name.as_str()
139        )))?;
140        writeln!(
141            inferred_instance_file,
142            "import {}.OpaqueFuncsInstance\n",
143            pascal_project_name.as_str()
144        )?;
145        writeln!(
146            inferred_instance_file,
147            "def fluxOpaqueFuncs : FluxOpaqueFuncs := inferInstance\n"
148        )?;
149        for fun in funs {
150            writeln!(
151                inferred_instance_file,
152                "def {} := fluxOpaqueFuncs.{}",
153                LeanConstDecl(fun, self.genv),
154                LeanVar(&fun.name, self.genv)
155            )?;
156        }
157        Ok(())
158    }
159
160    fn generate_sort_typeclass_files(&self, sorts: &[fixpoint::SortDecl]) -> Result<(), io::Error> {
161        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
162        let mut opaque_sorts_file = fs::File::create(self.lean_path.join(format!(
163            "{}/{}/OpaqueSortDefs.lean",
164            self.project_name,
165            pascal_project_name.as_str(),
166        )))?;
167        writeln!(opaque_sorts_file, "import {}.Lib", pascal_project_name.as_str())?;
168        writeln!(opaque_sorts_file, "-- FLUX OPAQUE SORT DEFS --")?;
169        writeln!(opaque_sorts_file, "class FluxOpaqueSorts where")?;
170        for sort in sorts {
171            writeln!(opaque_sorts_file, "  {}", LeanSortDecl(sort, self.genv))?;
172        }
173        self.generate_sorts_instance_file_if_not_present(sorts)?;
174        self.generate_sorts_inferred_instance_file(sorts)?;
175        Ok(())
176    }
177
178    fn generate_struct_defs_file(
179        &self,
180        data_decls: &[fixpoint::DataDecl],
181        has_opaque_sorts: bool,
182    ) -> Result<(), io::Error> {
183        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
184        let mut structs_file = fs::File::create(self.lean_path.join(format!(
185            "{}/{}/Structs.lean",
186            self.project_name,
187            pascal_project_name.as_str(),
188        )))?;
189        writeln!(structs_file, "import {}.Lib", pascal_project_name.as_str())?;
190        if has_opaque_sorts {
191            writeln!(structs_file, "import {}.OpaqueSorts", pascal_project_name.as_str())?;
192        }
193        writeln!(structs_file, "-- STRUCT DECLS --")?;
194        writeln!(structs_file, "mutual")?;
195        for data_decl in data_decls {
196            writeln!(structs_file, "{}", lean_format::LeanDataDecl(data_decl, self.genv))?;
197        }
198        writeln!(structs_file, "end")?;
199        Ok(())
200    }
201
202    fn generate_func_typeclass_files(
203        &self,
204        funs: &[fixpoint::ConstDecl],
205        has_opaque_sorts: bool,
206        has_data_decls: bool,
207    ) -> Result<(), io::Error> {
208        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
209        let mut opaque_funcs_file = fs::File::create(self.lean_path.join(format!(
210            "{}/{}/OpaqueFuncDefs.lean",
211            self.project_name,
212            pascal_project_name.as_str(),
213        )))?;
214        writeln!(opaque_funcs_file, "import {}.Lib", pascal_project_name.as_str())?;
215        if has_opaque_sorts {
216            writeln!(opaque_funcs_file, "import {}.OpaqueSorts", pascal_project_name.as_str())?;
217        }
218        if has_data_decls {
219            writeln!(opaque_funcs_file, "import {}.Structs", pascal_project_name.as_str())?;
220        }
221        writeln!(opaque_funcs_file, "-- OPAQUE DEFS --")?;
222        writeln!(opaque_funcs_file, "class FluxOpaqueFuncs where")?;
223        for fun in funs {
224            writeln!(opaque_funcs_file, "  {}", LeanConstDecl(fun, self.genv))?;
225        }
226        self.generate_funcs_instance_file_if_not_present(funs)?;
227        self.generate_funcs_inferred_instance_file(funs)?;
228        Ok(())
229    }
230
231    fn generate_func_defs_file(
232        &self,
233        func_defs: &[fixpoint::FunDef],
234        has_opaque_sorts: bool,
235        has_data_decls: bool,
236        has_opaque_funcs: bool,
237    ) -> Result<(), io::Error> {
238        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
239        let defs_path = self.lean_path.join(
240            format!(
241                "{}/{}/{}.lean",
242                self.project_name,
243                pascal_project_name.as_str(),
244                self.defs_file_name
245            )
246            .as_str(),
247        );
248        let mut file = fs::File::create(defs_path)?;
249
250        writeln!(file, "import {}.Lib", pascal_project_name.as_str())?;
251        if has_opaque_sorts {
252            writeln!(file, "import {}.OpaqueSorts", pascal_project_name.as_str())?;
253        }
254        if has_data_decls {
255            writeln!(file, "import {}.Structs", pascal_project_name.as_str())?;
256        }
257        if has_opaque_funcs {
258            writeln!(file, "import {}.OpaqueFuncs", pascal_project_name.as_str())?;
259        }
260        writeln!(file, "-- FUNC DECLS --")?;
261        writeln!(file, "mutual")?;
262        for fun_def in func_defs {
263            writeln!(file, "{}", lean_format::LeanFunDef(fun_def, self.genv))?;
264        }
265        writeln!(file, "end")?;
266        Ok(())
267    }
268
269    fn generate_lib_file(&self) -> Result<(), io::Error> {
270        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
271        let mut lib_file = fs::File::create(self.lean_path.join(
272            format!("{}/{}/Lib.lean", self.project_name, pascal_project_name.as_str()).as_str(),
273        ))?;
274        writeln!(
275            lib_file,
276            "def BitVec_shiftLeft {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.shiftLeft x (s.toNat)"
277        )?;
278        writeln!(
279            lib_file,
280            "def BitVec_ushiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.ushiftRight x (s.toNat)"
281        )?;
282        writeln!(
283            lib_file,
284            "def BitVec_sshiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.sshiftRight x (s.toNat)"
285        )?;
286        writeln!(
287            lib_file,
288            "def BitVec_uge {{ n : Nat }} (x y : BitVec n) := (BitVec.ult x y).not"
289        )?;
290        writeln!(
291            lib_file,
292            "def BitVec_sge {{ n : Nat }} (x y : BitVec n) := (BitVec.slt x y).not"
293        )?;
294        writeln!(
295            lib_file,
296            "def BitVec_ugt {{ n : Nat }} (x y : BitVec n) := (BitVec.ule x y).not"
297        )?;
298        writeln!(
299            lib_file,
300            "def BitVec_sgt {{ n : Nat }} (x y : BitVec n) := (BitVec.sle x y).not"
301        )?;
302        writeln!(
303            lib_file,
304            "def SmtMap (t0 t1 : Type) [Inhabited t0] [BEq t0] [Inhabited t1] : Type := t0 -> t1"
305        )?;
306        writeln!(
307            lib_file,
308            "def SmtMap_default {{ t0 t1: Type }} (v : t1) [Inhabited t0] [BEq t0] [Inhabited t1] : SmtMap t0 t1 := fun _ => v"
309        )?;
310        writeln!(
311            lib_file,
312            "def SmtMap_store {{ t0 t1 : Type }} [Inhabited t0] [BEq t0] [Inhabited t1] (m : SmtMap t0 t1) (k : t0) (v : t1) : SmtMap t0 t1 :=\n  fun x => if x == k then v else m x"
313        )?;
314        writeln!(
315            lib_file,
316            "def SmtMap_select {{ t0 t1 : Type }} [Inhabited t0] [BEq t0] [Inhabited t1] (m : SmtMap t0 t1) (k : t0) := m k"
317        )?;
318        Ok(())
319    }
320
321    pub fn encode_defs(
322        &self,
323        opaque_sorts: &[fixpoint::SortDecl],
324        opaque_funs: &[fixpoint::ConstDecl],
325        data_decls: &[fixpoint::DataDecl],
326        func_defs: &[fixpoint::FunDef],
327    ) -> Result<(), io::Error> {
328        self.generate_lake_project_if_not_present()?;
329        self.generate_lib_file()?;
330
331        let has_opaque_sorts = !opaque_sorts.is_empty();
332        let has_data_decls = !data_decls.is_empty();
333        let has_opaque_funcs = !opaque_funs.is_empty();
334
335        if has_opaque_sorts {
336            self.generate_sort_typeclass_files(opaque_sorts)?;
337        }
338        if has_data_decls {
339            self.generate_struct_defs_file(data_decls, has_opaque_sorts)?;
340        }
341        if has_opaque_funcs {
342            self.generate_func_typeclass_files(opaque_funs, has_opaque_sorts, has_data_decls)?;
343        }
344        if !func_defs.is_empty() {
345            self.generate_func_defs_file(
346                func_defs,
347                has_opaque_sorts,
348                has_data_decls,
349                has_opaque_funcs,
350            )?;
351        }
352        Ok(())
353    }
354
355    fn generate_theorem_file(
356        &self,
357        theorem_name: &str,
358        kvars: &[fixpoint::KVarDecl],
359        cstr: &fixpoint::Constraint,
360    ) -> Result<(), io::Error> {
361        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
362        let theorem_path = self.lean_path.join(
363            format!(
364                "{}/{}/{}.lean",
365                self.project_name,
366                pascal_project_name.as_str(),
367                Self::snake_case_to_pascal_case(theorem_name)
368            )
369            .as_str(),
370        );
371        let mut theorem_file = fs::File::create(theorem_path)?;
372        writeln!(theorem_file, "import {}.Lib", pascal_project_name.as_str())?;
373        if self
374            .lean_path
375            .join(
376                format!(
377                    "{}/{}/{}.lean",
378                    self.project_name.as_str(),
379                    pascal_project_name.as_str(),
380                    self.defs_file_name.as_str()
381                )
382                .as_str(),
383            )
384            .exists()
385        {
386            writeln!(
387                theorem_file,
388                "import {}.{}",
389                pascal_project_name.as_str(),
390                self.defs_file_name.as_str()
391            )?;
392        }
393        if self
394            .lean_path
395            .join(
396                format!(
397                    "{}/{}/OpaqueSorts.lean",
398                    self.project_name.as_str(),
399                    pascal_project_name.as_str(),
400                )
401                .as_str(),
402            )
403            .exists()
404        {
405            writeln!(theorem_file, "import {}.OpaqueSorts", pascal_project_name.as_str())?;
406        }
407        if self
408            .lean_path
409            .join(
410                format!(
411                    "{}/{}/OpaqueFuncs.lean",
412                    self.project_name.as_str(),
413                    pascal_project_name.as_str(),
414                )
415                .as_str(),
416            )
417            .exists()
418        {
419            writeln!(theorem_file, "import {}.OpaqueFuncs", pascal_project_name.as_str())?;
420        }
421        writeln!(
422            theorem_file,
423            "def {} := {}",
424            theorem_name.replace(".", "_"),
425            lean_format::LeanKConstraint(kvars, cstr, self.genv)
426        )
427    }
428
429    fn generate_proof_file_if_not_present(&self, theorem_name: &str) -> Result<(), io::Error> {
430        let module_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
431        let proof_name = format!("{theorem_name}_proof");
432        let proof_path = self.lean_path.join(
433            format!(
434                "{}/{}/{}.lean",
435                self.project_name.as_str(),
436                module_name.as_str(),
437                Self::snake_case_to_pascal_case(proof_name.as_str())
438            )
439            .as_str(),
440        );
441        if proof_path.exists() {
442            return Ok(());
443        }
444        let mut proof_file = fs::File::create(proof_path)?;
445        writeln!(proof_file, "import {}.Lib", module_name.as_str())?;
446        writeln!(
447            proof_file,
448            "import {}.{}",
449            module_name.as_str(),
450            Self::snake_case_to_pascal_case(theorem_name)
451        )?;
452        writeln!(proof_file, "def {proof_name} : {theorem_name} := by")?;
453        writeln!(proof_file, "  unfold {theorem_name}")?;
454        writeln!(proof_file, "  sorry")
455    }
456
457    pub fn encode_constraint(
458        &self,
459        def_id: MaybeExternId,
460        kvars: &[fixpoint::KVarDecl],
461        cstr: &fixpoint::Constraint,
462    ) -> Result<(), io::Error> {
463        self.generate_lake_project_if_not_present()?;
464        self.generate_lib_file()?;
465        let theorem_name = self
466            .genv
467            .tcx()
468            .def_path(def_id.resolved_id())
469            .to_filename_friendly_no_crate()
470            .replace("-", "_");
471        self.generate_theorem_file(theorem_name.as_str(), kvars, cstr)?;
472        self.generate_proof_file_if_not_present(theorem_name.as_str())
473    }
474
475    fn check_proof_help(&self, theorem_name: &str) -> io::Result<()> {
476        let proof_name = format!("{theorem_name}_proof");
477        let project_path = self.lean_path.join(self.project_name.as_str());
478        let proof_path = project_path.join(format!(
479            "{}/{}.lean",
480            Self::snake_case_to_pascal_case(self.project_name.as_str()),
481            Self::snake_case_to_pascal_case(proof_name.as_str())
482        ));
483        let child = Command::new("lake")
484            .arg("--quiet")
485            .arg("--dir")
486            .arg(project_path.to_str().unwrap())
487            .arg("lean")
488            .arg(proof_path.to_str().unwrap())
489            .stdout(Stdio::piped())
490            .stderr(Stdio::piped())
491            .spawn()?;
492        let out = child.wait_with_output()?;
493        if out.stderr.is_empty() && out.stdout.is_empty() {
494            Ok(())
495        } else {
496            let stderr = std::str::from_utf8(&out.stderr)
497                .unwrap_or("Lean exited with a non-zero return code");
498            Err(io::Error::other(stderr))
499        }
500    }
501
502    pub fn check_proof(&self, def_id: MaybeExternId) -> QueryResult<()> {
503        let theorem_name = self
504            .genv
505            .tcx()
506            .def_path(def_id.resolved_id())
507            .to_filename_friendly_no_crate()
508            .replace("-", "_");
509        self.check_proof_help(theorem_name.as_str()).map_err(|_| {
510            let msg = format!("checking proof for {} failed", theorem_name.as_str());
511            let span = self.genv.tcx().def_span(def_id.resolved_id());
512            QueryErr::Emitted(
513                self.genv
514                    .sess()
515                    .dcx()
516                    .handle()
517                    .struct_span_err(span, msg)
518                    .emit(),
519            )
520        })
521    }
522
523    fn snake_case_to_pascal_case(snake: &str) -> String {
524        snake
525            .split('_')
526            .filter(|s| !s.is_empty()) // skip empty segments (handles double underscores)
527            .map(|word| {
528                let mut chars = word.chars();
529                match chars.next() {
530                    Some(first) => first.to_ascii_uppercase().to_string() + chars.as_str(),
531                    None => String::new(),
532                }
533            })
534            .collect::<String>()
535    }
536}