use std::{
    fs,
    io::{self, BufWriter, Write},
    path::{Path, PathBuf},
    process::{Command, Stdio},
};
7
8use flux_middle::{
9    def_id::MaybeExternId,
10    global_env::GlobalEnv,
11    queries::{QueryErr, QueryResult},
12};
13
14use crate::{
15    fixpoint_encoding::fixpoint,
16    lean_format::{self, LeanConstDecl, LeanSortDecl, LeanSortVar, LeanVar},
17};
18
19pub struct LeanEncoder<'genv, 'tcx, 'a> {
20    genv: GlobalEnv<'genv, 'tcx>,
21    lean_path: &'a Path,
22    project_name: String,
23    defs_file_name: String,
24}
25
26impl<'genv, 'tcx, 'a> LeanEncoder<'genv, 'tcx, 'a> {
27    pub fn new(
28        genv: GlobalEnv<'genv, 'tcx>,
29        lean_path: &'a Path,
30        project_name: String,
31        defs_file_name: String,
32    ) -> Self {
33        Self { genv, lean_path, project_name, defs_file_name }
34    }
35
36    fn generate_lake_project_if_not_present(&self) -> Result<(), io::Error> {
37        if !self.lean_path.join(self.project_name.as_str()).exists() {
38            Command::new("lake")
39                .arg("new")
40                .arg(self.project_name.as_str())
41                .arg("lib")
42                .spawn()
43                .and_then(|mut child| child.wait())
44                .map(|_| ())
45        } else {
46            Ok(())
47        }
48    }
49
50    fn generate_instance_file_if_not_present(
51        &self,
52        sorts: &[fixpoint::SortDecl],
53        funs: &[fixpoint::ConstDecl],
54    ) -> Result<(), io::Error> {
55        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
56        let instance_path = self.lean_path.join(
57            format!("{}/{}/Instance.lean", self.project_name, pascal_project_name.as_str(),)
58                .as_str(),
59        );
60        if !instance_path.exists() {
61            let mut instance_file = fs::File::create(instance_path)?;
62            writeln!(instance_file, "import {}.Lib", pascal_project_name.as_str())?;
63            writeln!(instance_file, "import {}.OpaqueFluxDefs\n", pascal_project_name.as_str())?;
64            writeln!(instance_file, "instance : FluxDefs where")?;
65            for sort in sorts {
66                writeln!(instance_file, "  {} := sorry", LeanSortVar(&sort.name))?;
67            }
68            for fun in funs {
69                writeln!(instance_file, "  {} := sorry", LeanVar(&fun.name, self.genv))?;
70            }
71        }
72        Ok(())
73    }
74
75    fn generate_inferred_instance_file(
76        &self,
77        sorts: &[fixpoint::SortDecl],
78        funs: &[fixpoint::ConstDecl],
79    ) -> Result<(), io::Error> {
80        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
81        let mut inferred_instance_file = fs::File::create(self.lean_path.join(format!(
82            "{}/{}/InferredInstance.lean",
83            self.project_name,
84            pascal_project_name.as_str()
85        )))?;
86        writeln!(inferred_instance_file, "import {}.Instance\n", pascal_project_name.as_str())?;
87        writeln!(inferred_instance_file, "def fluxDefsInstance : FluxDefs := inferInstance\n")?;
88        for sort in sorts {
89            writeln!(
90                inferred_instance_file,
91                "def {} := fluxDefsInstance.{}",
92                LeanSortVar(&sort.name),
93                LeanSortVar(&sort.name)
94            )?;
95        }
96        for fun in funs {
97            writeln!(
98                inferred_instance_file,
99                "def {} := fluxDefsInstance.{}",
100                LeanConstDecl(fun, self.genv),
101                LeanVar(&fun.name, self.genv)
102            )?;
103        }
104        Ok(())
105    }
106
107    fn generate_typeclass_file(
108        &self,
109        sorts: &[fixpoint::SortDecl],
110        funs: &[fixpoint::ConstDecl],
111        data_decls: &[fixpoint::DataDecl],
112    ) -> Result<(), io::Error> {
113        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
114        let mut opaque_defs_file = fs::File::create(self.lean_path.join(format!(
115            "{}/{}/OpaqueFluxDefs.lean",
116            self.project_name,
117            pascal_project_name.as_str(),
118        )))?;
119        writeln!(opaque_defs_file, "import {}.Lib", pascal_project_name.as_str())?;
120        if !data_decls.is_empty() {
121            writeln!(opaque_defs_file, "-- STRUCT DECLS --")?;
122            writeln!(opaque_defs_file, "mutual")?;
123            for data_decl in data_decls {
124                writeln!(opaque_defs_file, "{}", lean_format::LeanDataDecl(data_decl, self.genv))?;
125            }
126            writeln!(opaque_defs_file, "end")?;
127        }
128        writeln!(opaque_defs_file, "-- OPAQUE DEFS --")?;
129        writeln!(opaque_defs_file, "class FluxDefs where")?;
130        for sort in sorts {
131            writeln!(opaque_defs_file, "  {}", LeanSortDecl(sort, self.genv))?;
132        }
133        for fun in funs {
134            writeln!(opaque_defs_file, "  {}", LeanConstDecl(fun, self.genv))?;
135        }
136        self.generate_instance_file_if_not_present(sorts, funs)?;
137        self.generate_inferred_instance_file(sorts, funs)?;
138        Ok(())
139    }
140
141    fn generate_defs_file(
142        &self,
143        func_defs: &[fixpoint::FunDef],
144        has_opaques: bool,
145    ) -> Result<(), io::Error> {
146        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
147        let defs_path = self.lean_path.join(
148            format!(
149                "{}/{}/{}.lean",
150                self.project_name,
151                pascal_project_name.as_str(),
152                self.defs_file_name
153            )
154            .as_str(),
155        );
156        let mut file = fs::File::create(defs_path)?;
157
158        writeln!(file, "import {}.Lib", pascal_project_name.as_str())?;
159        if has_opaques {
160            writeln!(file, "import {}.InferredInstance", pascal_project_name.as_str())?;
161        }
162        if !func_defs.is_empty() {
163            writeln!(file, "-- FUNC DECLS --")?;
164            writeln!(file, "mutual")?;
165            for fun_def in func_defs {
166                writeln!(file, "{}", lean_format::LeanFunDef(fun_def, self.genv))?;
167            }
168            writeln!(file, "end")?;
169        }
170        Ok(())
171    }
172
173    fn generate_lib_file(&self) -> Result<(), io::Error> {
174        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
175        let mut lib_file = fs::File::create(self.lean_path.join(
176            format!("{}/{}/Lib.lean", self.project_name, pascal_project_name.as_str()).as_str(),
177        ))?;
178        writeln!(
179            lib_file,
180            "def BitVec_shiftLeft {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.shiftLeft x (s.toNat)"
181        )?;
182        writeln!(
183            lib_file,
184            "def BitVec_ushiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.ushiftRight x (s.toNat)"
185        )?;
186        writeln!(
187            lib_file,
188            "def BitVec_sshiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.sshiftRight x (s.toNat)"
189        )?;
190        writeln!(
191            lib_file,
192            "def BitVec_uge {{ n : Nat }} (x y : BitVec n) := (BitVec.ult x y).not"
193        )?;
194        writeln!(
195            lib_file,
196            "def BitVec_sge {{ n : Nat }} (x y : BitVec n) := (BitVec.slt x y).not"
197        )?;
198        writeln!(
199            lib_file,
200            "def BitVec_ugt {{ n : Nat }} (x y : BitVec n) := (BitVec.ule x y).not"
201        )?;
202        writeln!(
203            lib_file,
204            "def BitVec_sgt {{ n : Nat }} (x y : BitVec n) := (BitVec.sle x y).not"
205        )?;
206        writeln!(
207            lib_file,
208            "def SmtMap (Key Val : Type) [Inhabited Key] [BEq Key] [Inhabited Val] : Type := Key -> Val"
209        )?;
210        writeln!(
211            lib_file,
212            "def SmtMap_default {{ Key Val: Type }} (v : Val) [Inhabited Key] [BEq Key] [Inhabited Val] : SmtMap Key Val := fun _ => v"
213        )?;
214        writeln!(
215            lib_file,
216            "def SmtMap_store {{ Key Val : Type }} [Inhabited Key] [BEq Key] [Inhabited Val] (m : SmtMap Key Val) (k : Key) (v : Val) : SmtMap Key Val :=\n  fun x => if x == k then v else m x"
217        )?;
218        writeln!(
219            lib_file,
220            "def SmtMap_select {{ Key Val : Type }} [Inhabited Key] [BEq Key] [Inhabited Val] (m : SmtMap Key Val) (k : Key) := m k"
221        )?;
222        Ok(())
223    }
224
225    pub fn encode_defs(
226        &self,
227        opaque_sorts: &[fixpoint::SortDecl],
228        opaque_funs: &[fixpoint::ConstDecl],
229        data_decls: &[fixpoint::DataDecl],
230        func_defs: &[fixpoint::FunDef],
231    ) -> Result<(), io::Error> {
232        self.generate_lake_project_if_not_present()?;
233        self.generate_lib_file()?;
234        let has_opaques =
235            !opaque_sorts.is_empty() || !opaque_funs.is_empty() || !data_decls.is_empty();
236        if has_opaques {
237            self.generate_typeclass_file(opaque_sorts, opaque_funs, data_decls)?;
238        }
239        if !func_defs.is_empty() {
240            self.generate_defs_file(func_defs, has_opaques)?;
241        }
242        Ok(())
243    }
244
245    fn generate_theorem_file(
246        &self,
247        theorem_name: &str,
248        kvars: &[fixpoint::KVarDecl],
249        cstr: &fixpoint::Constraint,
250    ) -> Result<(), io::Error> {
251        let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
252        let theorem_path = self.lean_path.join(
253            format!(
254                "{}/{}/{}.lean",
255                self.project_name,
256                pascal_project_name.as_str(),
257                Self::snake_case_to_pascal_case(theorem_name)
258            )
259            .as_str(),
260        );
261        let mut theorem_file = fs::File::create(theorem_path)?;
262        writeln!(theorem_file, "import {}.Lib", pascal_project_name.as_str())?;
263        if self
264            .lean_path
265            .join(
266                format!(
267                    "{}/{}/{}.lean",
268                    self.project_name.as_str(),
269                    pascal_project_name.as_str(),
270                    self.defs_file_name.as_str()
271                )
272                .as_str(),
273            )
274            .exists()
275        {
276            writeln!(
277                theorem_file,
278                "import {}.{}",
279                pascal_project_name.as_str(),
280                self.defs_file_name.as_str()
281            )?;
282        }
283        if self
284            .lean_path
285            .join(
286                format!(
287                    "{}/{}/InferredInstance.lean",
288                    self.project_name.as_str(),
289                    pascal_project_name.as_str(),
290                )
291                .as_str(),
292            )
293            .exists()
294        {
295            writeln!(theorem_file, "import {}.InferredInstance", pascal_project_name.as_str())?;
296        }
297        writeln!(
298            theorem_file,
299            "def {} := {}",
300            theorem_name.replace(".", "_"),
301            lean_format::LeanKConstraint(kvars, cstr, self.genv)
302        )
303    }
304
305    fn generate_proof_file_if_not_present(&self, theorem_name: &str) -> Result<(), io::Error> {
306        let module_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
307        let proof_name = format!("{theorem_name}_proof");
308        let proof_path = self.lean_path.join(
309            format!(
310                "{}/{}/{}.lean",
311                self.project_name.as_str(),
312                module_name.as_str(),
313                Self::snake_case_to_pascal_case(proof_name.as_str())
314            )
315            .as_str(),
316        );
317        if proof_path.exists() {
318            return Ok(());
319        }
320        let mut proof_file = fs::File::create(proof_path)?;
321        writeln!(proof_file, "import {}.Lib", module_name.as_str())?;
322        writeln!(
323            proof_file,
324            "import {}.{}",
325            module_name.as_str(),
326            Self::snake_case_to_pascal_case(theorem_name)
327        )?;
328        writeln!(proof_file, "def {proof_name} : {theorem_name} := by")?;
329        writeln!(proof_file, "  unfold {theorem_name}")?;
330        writeln!(proof_file, "  sorry")
331    }
332
333    pub fn encode_constraint(
334        &self,
335        def_id: MaybeExternId,
336        kvars: &[fixpoint::KVarDecl],
337        cstr: &fixpoint::Constraint,
338    ) -> Result<(), io::Error> {
339        self.generate_lake_project_if_not_present()?;
340        self.generate_lib_file()?;
341        let theorem_name = self
342            .genv
343            .tcx()
344            .def_path(def_id.resolved_id())
345            .to_filename_friendly_no_crate()
346            .replace("-", "_");
347        self.generate_theorem_file(theorem_name.as_str(), kvars, cstr)?;
348        self.generate_proof_file_if_not_present(theorem_name.as_str())
349    }
350
351    fn check_proof_help(&self, theorem_name: &str) -> io::Result<()> {
352        let proof_name = format!("{theorem_name}_proof");
353        let project_path = self.lean_path.join(self.project_name.as_str());
354        let proof_path = project_path.join(format!(
355            "{}/{}.lean",
356            Self::snake_case_to_pascal_case(self.project_name.as_str()),
357            Self::snake_case_to_pascal_case(proof_name.as_str())
358        ));
359        let child = Command::new("lake")
360            .arg("--dir")
361            .arg(project_path.to_str().unwrap())
362            .arg("lean")
363            .arg(proof_path.to_str().unwrap())
364            .stdout(Stdio::piped())
365            .stderr(Stdio::piped())
366            .spawn()?;
367        let out = child.wait_with_output()?;
368        if out.stderr.is_empty() && out.stdout.is_empty() {
369            Ok(())
370        } else {
371            let stderr = std::str::from_utf8(&out.stderr)
372                .unwrap_or("Lean exited with a non-zero return code");
373            Err(io::Error::other(stderr))
374        }
375    }
376
377    pub fn check_proof(&self, def_id: MaybeExternId) -> QueryResult<()> {
378        let theorem_name = self
379            .genv
380            .tcx()
381            .def_path(def_id.resolved_id())
382            .to_filename_friendly_no_crate()
383            .replace("-", "_");
384        self.check_proof_help(theorem_name.as_str()).map_err(|_| {
385            let msg = format!("checking proof for {} failed", theorem_name.as_str());
386            let span = self.genv.tcx().def_span(def_id.resolved_id());
387            QueryErr::Emitted(
388                self.genv
389                    .sess()
390                    .dcx()
391                    .handle()
392                    .struct_span_err(span, msg)
393                    .emit(),
394            )
395        })
396    }
397
398    fn snake_case_to_pascal_case(snake: &str) -> String {
399        snake
400            .split('_')
401            .filter(|s| !s.is_empty()) .map(|word| {
403                let mut chars = word.chars();
404                match chars.next() {
405                    Some(first) => first.to_ascii_uppercase().to_string() + chars.as_str(),
406                    None => String::new(),
407                }
408            })
409            .collect::<String>()
410    }
411}