1use std::{
2 fs,
3 io::{self, Write},
4 path::PathBuf,
5 process::{Command, Stdio},
6};
7
8use flux_common::{dbg, dbg::SpanTrace};
9use flux_config as config;
10use flux_middle::{
11 def_id::MaybeExternId,
12 global_env::GlobalEnv,
13 queries::{QueryErr, QueryResult},
14 rty::PrettyMap,
15};
16
17use crate::{
18 fixpoint_encoding::fixpoint,
19 lean_format::{self, LeanCtxt, LeanSortVar, WithLeanCtxt},
20};
21
/// Writes Lean source files (library helpers, definitions, theorems and proof stubs) into a
/// Lake project on disk and invokes `lake` to create the project and check proofs.
pub struct LeanEncoder<'genv, 'tcx> {
    /// Global environment; used to reach the `TyCtxt`, the session, and as part of the
    /// `LeanCtxt` handed to the Lean formatter.
    genv: GlobalEnv<'genv, 'tcx>,
    /// Directory that contains the Lake project (`<working dir>/<config::lean_dir()>`).
    path: PathBuf,
    /// Name of the Lake project, taken from `config::lean_project()`.
    project: String,
    /// File stem (without `.lean`) of the module holding function definitions; set to `Defs`.
    defs_file_name: String,
    /// Variable naming map forwarded to the Lean formatter through `LeanCtxt`.
    pretty_var_map: PrettyMap<fixpoint::LocalVar>,
}
29
30impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
31 pub fn new(
32 genv: GlobalEnv<'genv, 'tcx>,
33 pretty_var_map: PrettyMap<fixpoint::LocalVar>,
34 ) -> Self {
35 let defs_file_name = "Defs".to_string();
36 let path = genv
37 .tcx()
38 .sess
39 .opts
40 .working_dir
41 .local_path_if_available()
42 .to_path_buf()
43 .join(config::lean_dir());
44 let project = config::lean_project().to_string();
45 Self { genv, path, project, defs_file_name, pretty_var_map }
46 }
47
48 fn generate_lake_project_if_not_present(&self) -> Result<(), io::Error> {
49 if !self.path.join(&self.project).exists() {
50 Command::new("lake")
51 .arg("new")
52 .arg(&self.project)
53 .arg("lib")
54 .spawn()
55 .and_then(|mut child| child.wait())
56 .map(|_| ())
57 } else {
58 Ok(())
59 }
60 }
61
62 fn generate_sort_def_file_if_not_present(
63 &self,
64 sort: &fixpoint::SortDecl,
65 ) -> Result<(), io::Error> {
66 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
67 let pascal_sort_file_name =
68 Self::snake_case_to_pascal_case(&format!("{}", LeanSortVar(&sort.name)));
69 let instance_path = self
70 .path
71 .join(format!("{}/{pascal_project_name}/{pascal_sort_file_name}.lean", self.project));
72 if !instance_path.exists() {
73 let mut instance_file = fs::File::create(instance_path)?;
74 writeln!(instance_file, "import {}.Lib\n", pascal_project_name)?;
75 let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
76 writeln!(instance_file, "def {} := sorry", WithLeanCtxt { item: sort, cx: &cx })?;
77 }
78 Ok(())
79 }
80
81 fn generate_func_def_file_if_not_present(
82 &self,
83 fun: &fixpoint::ConstDecl,
84 has_opaque_sorts: bool,
85 has_data_decls: bool,
86 ) -> Result<(), io::Error> {
87 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
88 let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
89 let fun_name = format!("{}", WithLeanCtxt { item: &fun.name, cx: &cx });
90 let pascal_fun_file_name = Self::snake_case_to_pascal_case(&fun_name);
91 let instance_path = self
92 .path
93 .join(format!("{}/{pascal_project_name}/{pascal_fun_file_name}.lean", self.project));
94 if !instance_path.exists() {
95 let mut instance_file = fs::File::create(instance_path)?;
96 writeln!(instance_file, "import {pascal_project_name}.Lib")?;
97 if has_opaque_sorts {
98 writeln!(instance_file, "import {pascal_project_name}.OpaqueSorts")?;
99 }
100 if has_data_decls {
101 writeln!(instance_file, "import {pascal_project_name}.Structs")?;
102 }
103 writeln!(instance_file, "def {} := sorry", WithLeanCtxt { item: fun, cx: &cx })?;
104 }
105 Ok(())
106 }
107
108 fn generate_opaque_sort_files(&self, sorts: &[fixpoint::SortDecl]) -> Result<(), io::Error> {
109 if sorts.is_empty() {
110 return Ok(());
111 }
112 for sort in sorts {
113 self.generate_sort_def_file_if_not_present(sort)?;
114 }
115 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
116 let mut opaque_sorts_file = fs::File::create(
117 self.path
118 .join(format!("{}/{pascal_project_name}/OpaqueSorts.lean", self.project)),
119 )?;
120 for sort in sorts {
121 let pascal_sort_file_name =
122 Self::snake_case_to_pascal_case(&format!("{}", LeanSortVar(&sort.name)));
123 writeln!(opaque_sorts_file, "import {pascal_project_name}.{pascal_sort_file_name}")?;
124 }
125 Ok(())
126 }
127
128 fn generate_struct_defs_file(
129 &self,
130 data_decls: &[fixpoint::DataDecl],
131 has_opaque_sorts: bool,
132 ) -> Result<(), io::Error> {
133 if data_decls.is_empty() {
134 return Ok(());
135 }
136 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
137 let mut structs_file = fs::File::create(
138 self.path
139 .join(format!("{}/{}/Structs.lean", self.project, pascal_project_name,)),
140 )?;
141 writeln!(structs_file, "import {}.Lib", pascal_project_name)?;
142 if has_opaque_sorts {
143 writeln!(structs_file, "import {}.OpaqueSorts", pascal_project_name)?;
144 }
145 writeln!(structs_file, "-- STRUCT DECLS --")?;
146 writeln!(structs_file, "mutual")?;
147 let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
148 for data_decl in data_decls {
149 writeln!(structs_file, "{}", WithLeanCtxt { item: data_decl, cx: &cx })?;
150 }
151 writeln!(structs_file, "end")?;
152 Ok(())
153 }
154
155 fn generate_opaque_func_files(
156 &self,
157 funs: &[fixpoint::ConstDecl],
158 has_opaque_sorts: bool,
159 has_data_decls: bool,
160 ) -> Result<(), io::Error> {
161 if funs.is_empty() {
162 return Ok(());
163 }
164 for fun in funs {
165 self.generate_func_def_file_if_not_present(fun, has_opaque_sorts, has_data_decls)?;
166 }
167 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
168 let mut opaque_funcs_file = fs::File::create(
169 self.path
170 .join(format!("{}/{pascal_project_name}/OpaqueFuncs.lean", self.project)),
171 )?;
172 let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
173 for fun in funs {
174 let fun_name = format!("{}", WithLeanCtxt { item: &fun.name, cx: &cx });
175 let pascal_fun_file_name = Self::snake_case_to_pascal_case(&fun_name);
176 writeln!(opaque_funcs_file, "import {pascal_project_name}.{pascal_fun_file_name}")?;
177 }
178 Ok(())
179 }
180
181 fn generate_func_defs_file(
182 &self,
183 func_defs: &[fixpoint::FunDef],
184 has_opaque_sorts: bool,
185 has_data_decls: bool,
186 has_opaque_funcs: bool,
187 ) -> Result<(), io::Error> {
188 if func_defs.is_empty() {
189 return Ok(());
190 }
191 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
192 let defs_path = self
193 .path
194 .join(format!("{}/{pascal_project_name}/{}.lean", self.project, self.defs_file_name));
195 let mut file = fs::File::create(defs_path)?;
196
197 writeln!(file, "import {pascal_project_name}.Lib")?;
198 if has_opaque_sorts {
199 writeln!(file, "import {pascal_project_name}.OpaqueSorts")?;
200 }
201 if has_data_decls {
202 writeln!(file, "import {pascal_project_name}.Structs")?;
203 }
204 if has_opaque_funcs {
205 writeln!(file, "import {pascal_project_name}.OpaqueFuncs")?;
206 }
207 writeln!(file, "-- FUNC DECLS --")?;
208 writeln!(file, "mutual")?;
209 let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
210 for fun_def in func_defs {
211 writeln!(file, "{}", WithLeanCtxt { item: fun_def, cx: &cx })?;
212 }
213 writeln!(file, "end")?;
214 Ok(())
215 }
216
217 fn generate_lib_file(&self) -> Result<(), io::Error> {
218 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
219 let mut lib_file = fs::File::create(
220 self.path
221 .join(format!("{}/{pascal_project_name}/Lib.lean", self.project)),
222 )?;
223 writeln!(
224 lib_file,
225 "def BitVec_shiftLeft {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.shiftLeft x (s.toNat)"
226 )?;
227 writeln!(
228 lib_file,
229 "def BitVec_ushiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.ushiftRight x (s.toNat)"
230 )?;
231 writeln!(
232 lib_file,
233 "def BitVec_sshiftRight {{ n : Nat }} (x s : BitVec n) : BitVec n := BitVec.sshiftRight x (s.toNat)"
234 )?;
235 writeln!(
236 lib_file,
237 "def BitVec_uge {{ n : Nat }} (x y : BitVec n) := (BitVec.ult x y).not"
238 )?;
239 writeln!(
240 lib_file,
241 "def BitVec_sge {{ n : Nat }} (x y : BitVec n) := (BitVec.slt x y).not"
242 )?;
243 writeln!(
244 lib_file,
245 "def BitVec_ugt {{ n : Nat }} (x y : BitVec n) := (BitVec.ule x y).not"
246 )?;
247 writeln!(
248 lib_file,
249 "def BitVec_sgt {{ n : Nat }} (x y : BitVec n) := (BitVec.sle x y).not"
250 )?;
251 writeln!(
252 lib_file,
253 "def SmtMap (t0 t1 : Type) [Inhabited t0] [BEq t0] [Inhabited t1] : Type := t0 -> t1"
254 )?;
255 writeln!(
256 lib_file,
257 "def SmtMap_default {{ t0 t1: Type }} (v : t1) [Inhabited t0] [BEq t0] [Inhabited t1] : SmtMap t0 t1 := fun _ => v"
258 )?;
259 writeln!(
260 lib_file,
261 "def SmtMap_store {{ t0 t1 : Type }} [Inhabited t0] [BEq t0] [Inhabited t1] (m : SmtMap t0 t1) (k : t0) (v : t1) : SmtMap t0 t1 :=\n fun x => if x == k then v else m x"
262 )?;
263 writeln!(
264 lib_file,
265 "def SmtMap_select {{ t0 t1 : Type }} [Inhabited t0] [BEq t0] [Inhabited t1] (m : SmtMap t0 t1) (k : t0) := m k"
266 )?;
267 Ok(())
268 }
269
270 pub fn encode_defs(
271 &self,
272 opaque_sorts: &[fixpoint::SortDecl],
273 opaque_funs: &[fixpoint::ConstDecl],
274 data_decls: &[fixpoint::DataDecl],
275 func_defs: &[fixpoint::FunDef],
276 ) -> Result<(), io::Error> {
277 self.generate_lake_project_if_not_present()?;
278 self.generate_lib_file()?;
279
280 let has_opaque_sorts = !opaque_sorts.is_empty();
281 let has_data_decls = !data_decls.is_empty();
282 let has_opaque_funcs = !opaque_funs.is_empty();
283
284 self.generate_opaque_sort_files(opaque_sorts)?;
285 self.generate_struct_defs_file(data_decls, has_opaque_sorts)?;
286 self.generate_opaque_func_files(opaque_funs, has_opaque_sorts, has_data_decls)?;
287 self.generate_func_defs_file(
288 func_defs,
289 has_opaque_sorts,
290 has_data_decls,
291 has_opaque_funcs,
292 )?;
293 Ok(())
294 }
295
296 fn theorem_path(&self, theorem_name: &str) -> PathBuf {
297 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
298
299 self.path.join(format!(
300 "{}/{}/{}.lean",
301 self.project,
302 pascal_project_name,
303 Self::snake_case_to_pascal_case(theorem_name)
304 ))
305 }
306
307 fn generate_theorem_file(
308 &self,
309 theorem_name: &str,
310 kvars: &[fixpoint::KVarDecl],
311 cstr: &fixpoint::Constraint,
312 ) -> Result<(), io::Error> {
313 let pascal_project_name = Self::snake_case_to_pascal_case(&self.project);
314 let theorem_path = self.theorem_path(theorem_name);
315 let mut theorem_file = fs::File::create(theorem_path)?;
316 writeln!(theorem_file, "import {}.Lib", pascal_project_name)?;
317 if self
318 .path
319 .join(format!("{}/{pascal_project_name}/{}.lean", self.project, self.defs_file_name))
320 .exists()
321 {
322 writeln!(theorem_file, "import {pascal_project_name}.{}", self.defs_file_name)?;
323 }
324 if self
325 .path
326 .join(format!("{}/{pascal_project_name}/OpaqueSorts.lean", self.project))
327 .exists()
328 {
329 writeln!(theorem_file, "import {pascal_project_name}.OpaqueSorts")?;
330 }
331 if self
332 .path
333 .join(format!("{}/{pascal_project_name}/OpaqueFuncs.lean", self.project))
334 .exists()
335 {
336 writeln!(theorem_file, "import {pascal_project_name}.OpaqueFuncs")?;
337 }
338 let cx = LeanCtxt { genv: self.genv, pretty_var_map: &self.pretty_var_map };
339 writeln!(
340 theorem_file,
341 "def {} := {}",
342 theorem_name.replace(".", "_"),
343 WithLeanCtxt { item: lean_format::LeanKConstraint { kvars, constr: cstr }, cx: &cx }
344 )
345 }
346
347 fn generate_proof_file_if_not_present(
348 &self,
349 def_id: MaybeExternId,
350 theorem_name: &str,
351 ) -> Result<(), io::Error> {
352 let module_name = Self::snake_case_to_pascal_case(&self.project);
353 let proof_name = format!("{theorem_name}_proof");
354 let proof_path = self.path.join(format!(
355 "{}/{}/{}.lean",
356 self.project,
357 module_name,
358 Self::snake_case_to_pascal_case(&proof_name)
359 ));
360
361 if let Some(span) = self.genv.proven_externally(def_id.local_id()) {
362 let dst_span = SpanTrace::from_path(&proof_path, 3, 5, proof_name.len());
363 dbg::hyperlink_json!(self.genv.tcx(), span, dst_span);
364 }
365
366 if proof_path.exists() {
367 return Ok(());
368 }
369 let mut proof_file = fs::File::create(proof_path)?;
370 writeln!(proof_file, "import {module_name}.Lib")?;
371 writeln!(
372 proof_file,
373 "import {}.{}",
374 module_name,
375 Self::snake_case_to_pascal_case(theorem_name)
376 )?;
377 writeln!(proof_file, "def {proof_name} : {theorem_name} := by")?;
378 writeln!(proof_file, " unfold {theorem_name}")?;
379 writeln!(proof_file, " sorry")
380 }
381
382 pub fn encode_constraint(
383 &self,
384 def_id: MaybeExternId,
385 kvars: &[fixpoint::KVarDecl],
386 cstr: &fixpoint::Constraint,
387 ) -> Result<(), io::Error> {
388 self.generate_lake_project_if_not_present()?;
389 self.generate_lib_file()?;
390 let theorem_name = self
391 .genv
392 .tcx()
393 .def_path(def_id.resolved_id())
394 .to_filename_friendly_no_crate()
395 .replace("-", "_");
396
397 self.generate_theorem_file(&theorem_name, kvars, cstr)?;
398
399 self.generate_proof_file_if_not_present(def_id, &theorem_name)
400 }
401
402 fn check_proof_help(&self, theorem_name: &str) -> io::Result<()> {
403 let proof_name = format!("{theorem_name}_proof");
404 let project_path = self.path.join(&self.project);
405 let proof_path = format!(
406 "{}/{}.lean",
407 Self::snake_case_to_pascal_case(&self.project),
408 Self::snake_case_to_pascal_case(&proof_name)
409 );
410 let child = Command::new("lake")
411 .arg("--quiet")
412 .arg("lean")
413 .arg(proof_path)
414 .stdout(Stdio::piped())
415 .stderr(Stdio::piped())
416 .current_dir(project_path.as_path())
417 .spawn()?;
418 let out = child.wait_with_output()?;
419 if out.stderr.is_empty() && out.stdout.is_empty() {
420 Ok(())
421 } else {
422 let stderr = std::str::from_utf8(&out.stderr)
423 .unwrap_or("Lean exited with a non-zero return code");
424 Err(io::Error::other(stderr))
425 }
426 }
427
428 pub fn check_proof(&self, def_id: MaybeExternId) -> QueryResult<()> {
429 let theorem_name = self
430 .genv
431 .tcx()
432 .def_path(def_id.resolved_id())
433 .to_filename_friendly_no_crate()
434 .replace("-", "_");
435 self.check_proof_help(&theorem_name).map_err(|_| {
436 let msg = format!("checking proof for {theorem_name} failed");
437 let span = self.genv.tcx().def_span(def_id.resolved_id());
438 QueryErr::Emitted(
439 self.genv
440 .sess()
441 .dcx()
442 .handle()
443 .struct_span_err(span, msg)
444 .emit(),
445 )
446 })
447 }
448
449 fn snake_case_to_pascal_case(snake: &str) -> String {
450 snake
451 .split('_')
452 .filter(|s| !s.is_empty()) .map(|word| {
454 let mut chars = word.chars();
455 match chars.next() {
456 Some(first) => first.to_ascii_uppercase().to_string() + chars.as_str(),
457 None => String::new(),
458 }
459 })
460 .collect::<String>()
461 }
462}