1use std::{
2 fs,
3 io::{self, Write},
4 path,
5 process::{Command, Stdio},
6};
7
8use flux_middle::{
9 def_id::MaybeExternId,
10 global_env::GlobalEnv,
11 queries::{QueryErr, QueryResult},
12};
13use itertools::Itertools;
14use liquid_fixpoint::Identifier;
15
16use crate::fixpoint_encoding::fixpoint::{BinRel, ConstDecl, Constraint, Expr, FunDef, Pred};
17
18pub(crate) struct ConstDef(pub ConstDecl, pub Option<Expr>);
19
20pub(crate) struct LeanEncoder<'genv, 'tcx> {
21 def_id: MaybeExternId,
22 genv: GlobalEnv<'genv, 'tcx>,
23 fun_defs: Vec<FunDef>,
24 constants: Vec<ConstDef>,
25 constraint: Constraint,
26}
27
28impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
29 pub fn new(
30 def_id: MaybeExternId,
31 genv: GlobalEnv<'genv, 'tcx>,
32 fun_defs: Vec<FunDef>,
33 constants: Vec<ConstDecl>,
34 constraint: Constraint,
35 ) -> Self {
36 let constants = Self::extract_const_defs(constants, &constraint);
37 Self { def_id, genv, fun_defs, constants, constraint }
38 }
39
40 pub fn fun_defs(&self) -> &[FunDef] {
41 &self.fun_defs
42 }
43
44 pub fn constraint(&self) -> &Constraint {
45 &self.constraint
46 }
47
48 pub fn constants(&self) -> &[ConstDef] {
49 &self.constants
50 }
51
52 pub(crate) fn theorem_name(&self) -> String {
53 self.genv
54 .tcx()
55 .def_path(self.def_id.resolved_id())
56 .to_filename_friendly_no_crate()
57 .replace("-", "_")
58 }
59
60 fn proof_name(&self) -> String {
61 format!("{}_proof", self.theorem_name()).to_string()
62 }
63
64 fn generate_lake_project_if_not_present(
65 &self,
66 lean_path: &path::Path,
67 project_name: &str,
68 ) -> Result<(), io::Error> {
69 if !lean_path.join(project_name).exists() {
70 Command::new("lake")
71 .arg("new")
72 .arg(project_name)
73 .arg("lib")
74 .spawn()
75 .and_then(|mut child| child.wait())
76 .map(|_| ())
77 } else {
78 Ok(())
79 }
80 }
81
82 fn generate_def_file(
83 &self,
84 lean_path: &path::Path,
85 project_name: &str,
86 ) -> Result<(), io::Error> {
87 self.generate_lake_project_if_not_present(lean_path, project_name)?;
88 let theorem_path = lean_path.join(
89 format!(
90 "{project_name}/{}/{}.lean",
91 Self::snake_case_to_pascal_case(project_name),
92 Self::snake_case_to_pascal_case(self.theorem_name().as_str())
93 )
94 .as_str(),
95 );
96 let mut file = fs::File::create(theorem_path)?;
97 writeln!(file, "{self}")
98 }
99
100 fn generate_proof_file_if_not_present(
101 &self,
102 lean_path: &path::Path,
103 project_name: &str,
104 ) -> Result<(), io::Error> {
105 self.generate_def_file(lean_path, project_name)?;
106 let module_name = Self::snake_case_to_pascal_case(project_name);
107 let proof_name = self.proof_name();
108 let proof_path = lean_path.join(
109 format!(
110 "{project_name}/{}/{}.lean",
111 module_name.as_str(),
112 Self::snake_case_to_pascal_case(proof_name.as_str())
113 )
114 .as_str(),
115 );
116 let theorem_name = self.theorem_name();
117 if !proof_path.exists() {
118 let mut file = std::fs::File::create(proof_path)?;
119 writeln!(
120 file,
121 "import {}.{}",
122 Self::snake_case_to_pascal_case(module_name.as_str()),
123 Self::snake_case_to_pascal_case(theorem_name.as_str())
124 )?;
125 writeln!(file, "def {proof_name} : {theorem_name} := by")?;
126 writeln!(file, " unfold {theorem_name}")?;
127 writeln!(file, " sorry")
128 } else {
129 Ok(())
130 }
131 }
132
133 fn check_proof_help(&self, lean_path: &path::Path, project_name: &str) -> std::io::Result<()> {
134 self.generate_proof_file_if_not_present(lean_path, project_name)?;
135 let project_path = lean_path.join(project_name);
136 let proof_path = project_path.join(format!(
137 "{}/{}.lean",
138 Self::snake_case_to_pascal_case(project_name),
139 Self::snake_case_to_pascal_case(self.proof_name().as_str())
140 ));
141 let child = Command::new("lake")
142 .arg("--dir")
143 .arg(project_path.to_str().unwrap())
144 .arg("lean")
145 .arg(proof_path.to_str().unwrap())
146 .stdout(Stdio::piped())
147 .stderr(Stdio::piped())
148 .spawn()?;
149 let out = child.wait_with_output()?;
150 if out.stderr.is_empty() && out.stdout.is_empty() {
151 Ok(())
152 } else {
153 let stderr = std::str::from_utf8(&out.stderr)
154 .unwrap_or("Lean exited with a non-zero return code");
155 Err(io::Error::other(stderr))
156 }
157 }
158
159 pub fn check_proof(&self, lean_path: &path::Path, project_name: &str) -> QueryResult<()> {
160 self.check_proof_help(lean_path, project_name).map_err(|_| {
161 let msg = format!("checking proof for {} failed", self.theorem_name());
162 let span = self.genv.tcx().def_span(self.def_id.resolved_id());
163 QueryErr::Emitted(
164 self.genv
165 .sess()
166 .dcx()
167 .handle()
168 .struct_span_err(span, msg)
169 .emit(),
170 )
171 })
172 }
173
174 fn snake_case_to_pascal_case(snake: &str) -> String {
175 snake
176 .split('_')
177 .filter(|s| !s.is_empty()) .map(|word| {
179 let mut chars = word.chars();
180 match chars.next() {
181 Some(first) => first.to_ascii_uppercase().to_string() + chars.as_str(),
182 None => String::new(),
183 }
184 })
185 .collect::<String>()
186 }
187
188 fn extract_const_defs(const_decls: Vec<ConstDecl>, constraint: &Constraint) -> Vec<ConstDef> {
189 const_decls
190 .into_iter()
191 .map(|const_decl| {
192 let mut defs = vec![];
193 Self::extract_const_def(&const_decl, constraint, &mut defs);
194 if defs.len() <= 1 {
195 ConstDef(const_decl, defs.pop())
196 } else {
197 panic!("Constant {} has {} definitions", const_decl.name.display(), defs.len())
198 }
199 })
200 .collect_vec()
201 }
202
203 fn extract_const_def(const_decl: &ConstDecl, constraint: &Constraint, acc: &mut Vec<Expr>) {
204 match constraint {
205 Constraint::ForAll(bind, inner) => {
206 if let Pred::Expr(Expr::Atom(BinRel::Eq, equals)) = &bind.pred {
207 if let Expr::Var(vl) = &equals[0]
208 && vl.eq(&const_decl.name)
209 {
210 acc.push(equals[1].clone());
211 }
212 if let Expr::Var(vr) = &equals[1]
213 && vr.eq(&const_decl.name)
214 {
215 acc.push(equals[0].clone());
216 }
217 Self::extract_const_def(const_decl, inner.as_ref(), acc);
218 }
219 }
220 Constraint::Conj(conjuncts) => {
221 conjuncts
222 .iter()
223 .for_each(|constraint| Self::extract_const_def(const_decl, constraint, acc));
224 }
225 Constraint::Pred(..) => {}
226 }
227 }
228}