1use std::{
2 fs,
3 io::{self, Write},
4 path::Path,
5 process::{Command, Stdio},
6};
7
8use flux_middle::{
9 def_id::MaybeExternId,
10 global_env::GlobalEnv,
11 queries::{QueryErr, QueryResult},
12};
13
14use crate::{
15 fixpoint_encoding::fixpoint,
16 lean_format::{self, LeanConstDecl, LeanSortDecl, LeanSortVar, LeanVar},
17};
18
19pub struct LeanEncoder<'genv, 'tcx, 'a> {
20 genv: GlobalEnv<'genv, 'tcx>,
21 lean_path: &'a Path,
22 project_name: String,
23 defs_file_name: String,
24}
25
26impl<'genv, 'tcx, 'a> LeanEncoder<'genv, 'tcx, 'a> {
27 pub fn new(
28 genv: GlobalEnv<'genv, 'tcx>,
29 lean_path: &'a Path,
30 project_name: String,
31 defs_file_name: String,
32 ) -> Self {
33 Self { genv, lean_path, project_name, defs_file_name }
34 }
35
36 fn generate_lake_project_if_not_present(&self) -> Result<(), io::Error> {
37 if !self.lean_path.join(self.project_name.as_str()).exists() {
38 Command::new("lake")
39 .arg("new")
40 .arg(self.project_name.as_str())
41 .arg("lib")
42 .spawn()
43 .and_then(|mut child| child.wait())
44 .map(|_| ())
45 } else {
46 Ok(())
47 }
48 }
49
50 fn generate_sorts_instance_file_if_not_present(
51 &self,
52 sorts: &[fixpoint::SortDecl],
53 ) -> Result<(), io::Error> {
54 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
55 let instance_path = self.lean_path.join(
56 format!(
57 "{}/{}/OpaqueSortsInstance.lean",
58 self.project_name,
59 pascal_project_name.as_str(),
60 )
61 .as_str(),
62 );
63 if !instance_path.exists() {
64 let mut instance_file = fs::File::create(instance_path)?;
65 writeln!(instance_file, "import {}.Lib", pascal_project_name.as_str())?;
66 writeln!(instance_file, "import {}.OpaqueSortDefs\n", pascal_project_name.as_str())?;
67 writeln!(instance_file, "instance : FluxOpaqueSorts where")?;
68 for sort in sorts {
69 writeln!(instance_file, " {} := sorry", LeanSortVar(&sort.name))?;
70 }
71 }
72 Ok(())
73 }
74
75 fn generate_funcs_instance_file_if_not_present(
76 &self,
77 funs: &[fixpoint::ConstDecl],
78 ) -> Result<(), io::Error> {
79 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
80 let instance_path = self.lean_path.join(
81 format!(
82 "{}/{}/OpaqueFuncsInstance.lean",
83 self.project_name,
84 pascal_project_name.as_str(),
85 )
86 .as_str(),
87 );
88 if !instance_path.exists() {
89 let mut instance_file = fs::File::create(instance_path)?;
90 writeln!(instance_file, "import {}.Lib", pascal_project_name.as_str())?;
91 writeln!(instance_file, "import {}.OpaqueFuncDefs\n", pascal_project_name.as_str())?;
92 writeln!(instance_file, "instance : FluxOpaqueFuncs where")?;
93 for fun in funs {
94 writeln!(instance_file, " {} := sorry", LeanVar(&fun.name, self.genv))?;
95 }
96 }
97 Ok(())
98 }
99
100 fn generate_sorts_inferred_instance_file(
101 &self,
102 sorts: &[fixpoint::SortDecl],
103 ) -> Result<(), io::Error> {
104 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
105 let mut inferred_instance_file = fs::File::create(self.lean_path.join(format!(
106 "{}/{}/OpaqueSorts.lean",
107 self.project_name,
108 pascal_project_name.as_str()
109 )))?;
110 writeln!(
111 inferred_instance_file,
112 "import {}.OpaqueSortsInstance\n",
113 pascal_project_name.as_str()
114 )?;
115 writeln!(
116 inferred_instance_file,
117 "def fluxOpaqueSorts : FluxOpaqueSorts := inferInstance\n"
118 )?;
119 for sort in sorts {
120 writeln!(
121 inferred_instance_file,
122 "def {} := fluxOpaqueSorts.{}",
123 LeanSortVar(&sort.name),
124 LeanSortVar(&sort.name)
125 )?;
126 }
127 Ok(())
128 }
129
130 fn generate_funcs_inferred_instance_file(
131 &self,
132 funs: &[fixpoint::ConstDecl],
133 ) -> Result<(), io::Error> {
134 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
135 let mut inferred_instance_file = fs::File::create(self.lean_path.join(format!(
136 "{}/{}/OpaqueFuncs.lean",
137 self.project_name,
138 pascal_project_name.as_str()
139 )))?;
140 writeln!(
141 inferred_instance_file,
142 "import {}.OpaqueFuncsInstance\n",
143 pascal_project_name.as_str()
144 )?;
145 writeln!(
146 inferred_instance_file,
147 "def fluxOpaqueFuncs : FluxOpaqueFuncs := inferInstance\n"
148 )?;
149 for fun in funs {
150 writeln!(
151 inferred_instance_file,
152 "def {} := fluxOpaqueFuncs.{}",
153 LeanConstDecl(fun, self.genv),
154 LeanVar(&fun.name, self.genv)
155 )?;
156 }
157 Ok(())
158 }
159
160 fn generate_sort_typeclass_files(&self, sorts: &[fixpoint::SortDecl]) -> Result<(), io::Error> {
161 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
162 let mut opaque_sorts_file = fs::File::create(self.lean_path.join(format!(
163 "{}/{}/OpaqueSortDefs.lean",
164 self.project_name,
165 pascal_project_name.as_str(),
166 )))?;
167 writeln!(opaque_sorts_file, "import {}.Lib", pascal_project_name.as_str())?;
168 writeln!(opaque_sorts_file, "-- FLUX OPAQUE SORT DEFS --")?;
169 writeln!(opaque_sorts_file, "class FluxOpaqueSorts where")?;
170 for sort in sorts {
171 writeln!(opaque_sorts_file, " {}", LeanSortDecl(sort, self.genv))?;
172 }
173 self.generate_sorts_instance_file_if_not_present(sorts)?;
174 self.generate_sorts_inferred_instance_file(sorts)?;
175 Ok(())
176 }
177
178 fn generate_struct_defs_file(
179 &self,
180 data_decls: &[fixpoint::DataDecl],
181 has_opaque_sorts: bool,
182 ) -> Result<(), io::Error> {
183 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
184 let mut structs_file = fs::File::create(self.lean_path.join(format!(
185 "{}/{}/Structs.lean",
186 self.project_name,
187 pascal_project_name.as_str(),
188 )))?;
189 writeln!(structs_file, "import {}.Lib", pascal_project_name.as_str())?;
190 if has_opaque_sorts {
191 writeln!(structs_file, "import {}.OpaqueSorts", pascal_project_name.as_str())?;
192 }
193 writeln!(structs_file, "-- STRUCT DECLS --")?;
194 writeln!(structs_file, "mutual")?;
195 for data_decl in data_decls {
196 writeln!(structs_file, "{}", lean_format::LeanDataDecl(data_decl, self.genv))?;
197 }
198 writeln!(structs_file, "end")?;
199 Ok(())
200 }
201
202 fn generate_func_typeclass_files(
203 &self,
204 funs: &[fixpoint::ConstDecl],
205 has_opaque_sorts: bool,
206 has_data_decls: bool,
207 ) -> Result<(), io::Error> {
208 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
209 let mut opaque_funcs_file = fs::File::create(self.lean_path.join(format!(
210 "{}/{}/OpaqueFuncDefs.lean",
211 self.project_name,
212 pascal_project_name.as_str(),
213 )))?;
214 writeln!(opaque_funcs_file, "import {}.Lib", pascal_project_name.as_str())?;
215 if has_opaque_sorts {
216 writeln!(opaque_funcs_file, "import {}.OpaqueSorts", pascal_project_name.as_str())?;
217 }
218 if has_data_decls {
219 writeln!(opaque_funcs_file, "import {}.Structs", pascal_project_name.as_str())?;
220 }
221 writeln!(opaque_funcs_file, "-- OPAQUE DEFS --")?;
222 writeln!(opaque_funcs_file, "class FluxOpaqueFuncs where")?;
223 for fun in funs {
224 writeln!(opaque_funcs_file, " {}", LeanConstDecl(fun, self.genv))?;
225 }
226 self.generate_funcs_instance_file_if_not_present(funs)?;
227 self.generate_funcs_inferred_instance_file(funs)?;
228 Ok(())
229 }
230
231 fn generate_func_defs_file(
232 &self,
233 func_defs: &[fixpoint::FunDef],
234 has_opaque_sorts: bool,
235 has_data_decls: bool,
236 has_opaque_funcs: bool,
237 ) -> Result<(), io::Error> {
238 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
239 let defs_path = self.lean_path.join(
240 format!(
241 "{}/{}/{}.lean",
242 self.project_name,
243 pascal_project_name.as_str(),
244 self.defs_file_name
245 )
246 .as_str(),
247 );
248 let mut file = fs::File::create(defs_path)?;
249
250 writeln!(file, "import {}.Lib", pascal_project_name.as_str())?;
251 if has_opaque_sorts {
252 writeln!(file, "import {}.OpaqueSorts", pascal_project_name.as_str())?;
253 }
254 if has_data_decls {
255 writeln!(file, "import {}.Structs", pascal_project_name.as_str())?;
256 }
257 if has_opaque_funcs {
258 writeln!(file, "import {}.OpaqueFuncs", pascal_project_name.as_str())?;
259 }
260 writeln!(file, "-- FUNC DECLS --")?;
261 writeln!(file, "mutual")?;
262 for fun_def in func_defs {
263 writeln!(file, "{}", lean_format::LeanFunDef(fun_def, self.genv))?;
264 }
265 writeln!(file, "end")?;
266 Ok(())
267 }
268
269 fn generate_lib_file(&self) -> Result<(), io::Error> {
270 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
271 let mut lib_file = fs::File::create(self.lean_path.join(
272 format!("{}/{}/Lib.lean", self.project_name, pascal_project_name.as_str()).as_str(),
273 ))?;
274 writeln!(
275 lib_file,
276 "def BitVec_shiftLeft {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.shiftLeft x (s.toNat)"
277 )?;
278 writeln!(
279 lib_file,
280 "def BitVec_ushiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.ushiftRight x (s.toNat)"
281 )?;
282 writeln!(
283 lib_file,
284 "def BitVec_sshiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.sshiftRight x (s.toNat)"
285 )?;
286 writeln!(
287 lib_file,
288 "def BitVec_uge {{ n : Nat }} (x y : BitVec n) := (BitVec.ult x y).not"
289 )?;
290 writeln!(
291 lib_file,
292 "def BitVec_sge {{ n : Nat }} (x y : BitVec n) := (BitVec.slt x y).not"
293 )?;
294 writeln!(
295 lib_file,
296 "def BitVec_ugt {{ n : Nat }} (x y : BitVec n) := (BitVec.ule x y).not"
297 )?;
298 writeln!(
299 lib_file,
300 "def BitVec_sgt {{ n : Nat }} (x y : BitVec n) := (BitVec.sle x y).not"
301 )?;
302 writeln!(
303 lib_file,
304 "def SmtMap (t0 t1 : Type) [Inhabited t0] [BEq t0] [Inhabited t1] : Type := t0 -> t1"
305 )?;
306 writeln!(
307 lib_file,
308 "def SmtMap_default {{ t0 t1: Type }} (v : t1) [Inhabited t0] [BEq t0] [Inhabited t1] : SmtMap t0 t1 := fun _ => v"
309 )?;
310 writeln!(
311 lib_file,
312 "def SmtMap_store {{ t0 t1 : Type }} [Inhabited t0] [BEq t0] [Inhabited t1] (m : SmtMap t0 t1) (k : t0) (v : t1) : SmtMap t0 t1 :=\n fun x => if x == k then v else m x"
313 )?;
314 writeln!(
315 lib_file,
316 "def SmtMap_select {{ t0 t1 : Type }} [Inhabited t0] [BEq t0] [Inhabited t1] (m : SmtMap t0 t1) (k : t0) := m k"
317 )?;
318 Ok(())
319 }
320
321 pub fn encode_defs(
322 &self,
323 opaque_sorts: &[fixpoint::SortDecl],
324 opaque_funs: &[fixpoint::ConstDecl],
325 data_decls: &[fixpoint::DataDecl],
326 func_defs: &[fixpoint::FunDef],
327 ) -> Result<(), io::Error> {
328 self.generate_lake_project_if_not_present()?;
329 self.generate_lib_file()?;
330
331 let has_opaque_sorts = !opaque_sorts.is_empty();
332 let has_data_decls = !data_decls.is_empty();
333 let has_opaque_funcs = !opaque_funs.is_empty();
334
335 if has_opaque_sorts {
336 self.generate_sort_typeclass_files(opaque_sorts)?;
337 }
338 if has_data_decls {
339 self.generate_struct_defs_file(data_decls, has_opaque_sorts)?;
340 }
341 if has_opaque_funcs {
342 self.generate_func_typeclass_files(opaque_funs, has_opaque_sorts, has_data_decls)?;
343 }
344 if !func_defs.is_empty() {
345 self.generate_func_defs_file(
346 func_defs,
347 has_opaque_sorts,
348 has_data_decls,
349 has_opaque_funcs,
350 )?;
351 }
352 Ok(())
353 }
354
355 fn generate_theorem_file(
356 &self,
357 theorem_name: &str,
358 kvars: &[fixpoint::KVarDecl],
359 cstr: &fixpoint::Constraint,
360 ) -> Result<(), io::Error> {
361 let pascal_project_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
362 let theorem_path = self.lean_path.join(
363 format!(
364 "{}/{}/{}.lean",
365 self.project_name,
366 pascal_project_name.as_str(),
367 Self::snake_case_to_pascal_case(theorem_name)
368 )
369 .as_str(),
370 );
371 let mut theorem_file = fs::File::create(theorem_path)?;
372 writeln!(theorem_file, "import {}.Lib", pascal_project_name.as_str())?;
373 if self
374 .lean_path
375 .join(
376 format!(
377 "{}/{}/{}.lean",
378 self.project_name.as_str(),
379 pascal_project_name.as_str(),
380 self.defs_file_name.as_str()
381 )
382 .as_str(),
383 )
384 .exists()
385 {
386 writeln!(
387 theorem_file,
388 "import {}.{}",
389 pascal_project_name.as_str(),
390 self.defs_file_name.as_str()
391 )?;
392 }
393 if self
394 .lean_path
395 .join(
396 format!(
397 "{}/{}/OpaqueSorts.lean",
398 self.project_name.as_str(),
399 pascal_project_name.as_str(),
400 )
401 .as_str(),
402 )
403 .exists()
404 {
405 writeln!(theorem_file, "import {}.OpaqueSorts", pascal_project_name.as_str())?;
406 }
407 if self
408 .lean_path
409 .join(
410 format!(
411 "{}/{}/OpaqueFuncs.lean",
412 self.project_name.as_str(),
413 pascal_project_name.as_str(),
414 )
415 .as_str(),
416 )
417 .exists()
418 {
419 writeln!(theorem_file, "import {}.OpaqueFuncs", pascal_project_name.as_str())?;
420 }
421 writeln!(
422 theorem_file,
423 "def {} := {}",
424 theorem_name.replace(".", "_"),
425 lean_format::LeanKConstraint(kvars, cstr, self.genv)
426 )
427 }
428
429 fn generate_proof_file_if_not_present(&self, theorem_name: &str) -> Result<(), io::Error> {
430 let module_name = Self::snake_case_to_pascal_case(self.project_name.as_str());
431 let proof_name = format!("{theorem_name}_proof");
432 let proof_path = self.lean_path.join(
433 format!(
434 "{}/{}/{}.lean",
435 self.project_name.as_str(),
436 module_name.as_str(),
437 Self::snake_case_to_pascal_case(proof_name.as_str())
438 )
439 .as_str(),
440 );
441 if proof_path.exists() {
442 return Ok(());
443 }
444 let mut proof_file = fs::File::create(proof_path)?;
445 writeln!(proof_file, "import {}.Lib", module_name.as_str())?;
446 writeln!(
447 proof_file,
448 "import {}.{}",
449 module_name.as_str(),
450 Self::snake_case_to_pascal_case(theorem_name)
451 )?;
452 writeln!(proof_file, "def {proof_name} : {theorem_name} := by")?;
453 writeln!(proof_file, " unfold {theorem_name}")?;
454 writeln!(proof_file, " sorry")
455 }
456
457 pub fn encode_constraint(
458 &self,
459 def_id: MaybeExternId,
460 kvars: &[fixpoint::KVarDecl],
461 cstr: &fixpoint::Constraint,
462 ) -> Result<(), io::Error> {
463 self.generate_lake_project_if_not_present()?;
464 self.generate_lib_file()?;
465 let theorem_name = self
466 .genv
467 .tcx()
468 .def_path(def_id.resolved_id())
469 .to_filename_friendly_no_crate()
470 .replace("-", "_");
471 self.generate_theorem_file(theorem_name.as_str(), kvars, cstr)?;
472 self.generate_proof_file_if_not_present(theorem_name.as_str())
473 }
474
475 fn check_proof_help(&self, theorem_name: &str) -> io::Result<()> {
476 let proof_name = format!("{theorem_name}_proof");
477 let project_path = self.lean_path.join(self.project_name.as_str());
478 let proof_path = project_path.join(format!(
479 "{}/{}.lean",
480 Self::snake_case_to_pascal_case(self.project_name.as_str()),
481 Self::snake_case_to_pascal_case(proof_name.as_str())
482 ));
483 let child = Command::new("lake")
484 .arg("--quiet")
485 .arg("--dir")
486 .arg(project_path.to_str().unwrap())
487 .arg("lean")
488 .arg(proof_path.to_str().unwrap())
489 .stdout(Stdio::piped())
490 .stderr(Stdio::piped())
491 .spawn()?;
492 let out = child.wait_with_output()?;
493 if out.stderr.is_empty() && out.stdout.is_empty() {
494 Ok(())
495 } else {
496 let stderr = std::str::from_utf8(&out.stderr)
497 .unwrap_or("Lean exited with a non-zero return code");
498 Err(io::Error::other(stderr))
499 }
500 }
501
502 pub fn check_proof(&self, def_id: MaybeExternId) -> QueryResult<()> {
503 let theorem_name = self
504 .genv
505 .tcx()
506 .def_path(def_id.resolved_id())
507 .to_filename_friendly_no_crate()
508 .replace("-", "_");
509 self.check_proof_help(theorem_name.as_str()).map_err(|_| {
510 let msg = format!("checking proof for {} failed", theorem_name.as_str());
511 let span = self.genv.tcx().def_span(def_id.resolved_id());
512 QueryErr::Emitted(
513 self.genv
514 .sess()
515 .dcx()
516 .handle()
517 .struct_span_err(span, msg)
518 .emit(),
519 )
520 })
521 }
522
523 fn snake_case_to_pascal_case(snake: &str) -> String {
524 snake
525 .split('_')
526 .filter(|s| !s.is_empty()) .map(|word| {
528 let mut chars = word.chars();
529 match chars.next() {
530 Some(first) => first.to_ascii_uppercase().to_string() + chars.as_str(),
531 None => String::new(),
532 }
533 })
534 .collect::<String>()
535 }
536}