1use flux_common::tracked_span_bug;
2use flux_middle::{
3 big_int::BigInt,
4 rty::{self, Binder, EarlyReftParam, InternalFuncKind, List, SpecFuncKind},
5};
6use flux_rustc_bridge::lowering::Lower;
7use itertools::Itertools;
8use rustc_hir::def_id::DefId;
9use rustc_type_ir::BoundVar;
10
11use super::{ConstKey, FixpointCtxt, fixpoint};
12use crate::fixpoint_encoding::FixpointSolution;
13
14impl<'genv, 'tcx, Tag> FixpointCtxt<'genv, 'tcx, Tag>
15where
16 Tag: std::hash::Hash + Eq + Copy,
17{
18 pub(crate) fn fixpoint_to_solution(
19 &mut self,
20 sol: &FixpointSolution,
21 ) -> rty::Binder<rty::Expr> {
22 let mut vars = vec![];
23 let mut sorts = vec![];
24 for (var, sort) in &sol.0 {
25 let fixpoint::Var::Local(local_var) = var else {
26 tracked_span_bug!("encountered non-local variable in binder: {var:?}");
27 };
28 vars.push(*local_var);
29 sorts.push(
30 self.fixpoint_to_sort(sort)
31 .unwrap_or_else(|_| tracked_span_bug!("failed to parse sort: {sort:?}")),
32 );
33 }
34 self.ecx.local_var_env.push_layer(vars);
35 let expr = self
36 .fixpoint_to_expr(&sol.1)
37 .unwrap_or_else(|err| tracked_span_bug!("failed to convert expr: {err:?}"));
38 self.ecx.local_var_env.pop_layer();
39 rty::Binder::bind_with_sorts(expr, &sorts)
40 }
41
42 fn fixpoint_to_sort_ctor(
43 &self,
44 ctor: &fixpoint::SortCtor,
45 ) -> Result<rty::SortCtor, FixpointParseError> {
46 match ctor {
47 fixpoint::SortCtor::Set => Ok(rty::SortCtor::Set),
48 fixpoint::SortCtor::Map => Ok(rty::SortCtor::Map),
49 fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(_)) => {
50 panic!("oh no! tuple!") }
52 fixpoint::SortCtor::Data(fixpoint::DataSort::User(opaque_id)) => {
53 let def_id = self.scx.opaque_sorts[opaque_id.as_usize()];
54 Ok(rty::SortCtor::User(def_id))
55 }
56 fixpoint::SortCtor::Data(fixpoint::DataSort::Adt(adt_id)) => {
57 let def_id = self.scx.adt_sorts[adt_id.as_usize()];
58 let Ok(adt_sort_def) = self.genv.adt_sort_def_of(def_id) else {
59 return Err(FixpointParseError::UnknownAdt(def_id));
60 };
61 Ok(rty::SortCtor::Adt(adt_sort_def))
62 }
63 }
64 }
65
66 pub(crate) fn fixpoint_to_sort(
67 &self,
68 fsort: &fixpoint::Sort,
69 ) -> Result<rty::Sort, FixpointParseError> {
70 match fsort {
71 fixpoint::Sort::Int => Ok(rty::Sort::Int),
72 fixpoint::Sort::Real => Ok(rty::Sort::Real),
73 fixpoint::Sort::Bool => Ok(rty::Sort::Bool),
74 fixpoint::Sort::Str => Ok(rty::Sort::Str),
75 fixpoint::Sort::Func(sorts) => {
76 let sort1 = self.fixpoint_to_sort(&sorts[0])?;
77 let sort2 = self.fixpoint_to_sort(&sorts[1])?;
78 let fsort = rty::FuncSort::new(vec![sort1], sort2);
79 let poly_sort = rty::PolyFuncSort::new(List::empty(), fsort);
80 Ok(rty::Sort::Func(poly_sort))
81 }
82 fixpoint::Sort::App(ctor, args) => {
83 let ctor = self.fixpoint_to_sort_ctor(ctor)?;
84 let args = args
85 .iter()
86 .map(|fsort| self.fixpoint_to_sort(fsort))
87 .try_collect()?;
88 Ok(rty::Sort::App(ctor, args))
89 }
90 fixpoint::Sort::BitVec(fsort) if let fixpoint::Sort::BvSize(size) = **fsort => {
91 Ok(rty::Sort::BitVec(rty::BvSize::Fixed(size)))
92 }
93 _ => unimplemented!("fixpoint_to_sort: {fsort:?}"),
94 }
95 }
96
97 #[allow(dead_code)]
98 pub(crate) fn fixpoint_to_expr(
99 &mut self,
100 fexpr: &fixpoint::Expr,
101 ) -> Result<rty::Expr, FixpointParseError> {
102 match fexpr {
103 fixpoint::Expr::Constant(constant) => {
104 let c = match constant {
105 fixpoint::Constant::Numeral(num) => rty::Constant::Int(BigInt::from(*num)),
106 fixpoint::Constant::Real(dec) => rty::Constant::Real(rty::Real(*dec)),
107 fixpoint::Constant::Boolean(b) => rty::Constant::Bool(*b),
108 fixpoint::Constant::String(s) => rty::Constant::Str(s.0),
109 fixpoint::Constant::BitVec(bv, size) => rty::Constant::BitVec(*bv, *size),
110 };
111 Ok(rty::Expr::constant(c))
112 }
113 fixpoint::Expr::Var(fvar) => {
114 match fvar {
115 fixpoint::Var::Underscore => {
116 unreachable!("Underscore should not appear in exprs")
117 }
118 fixpoint::Var::Global(global_var, _) | fixpoint::Var::Const(global_var, _) => {
119 if let Some(const_key) = self.ecx.const_env.const_map_rev.get(global_var) {
120 match const_key {
121 ConstKey::RustConst(def_id) => Ok(rty::Expr::const_def_id(*def_id)),
122 ConstKey::Alias(_flux_id, _args) => {
123 unreachable!("Should be special-cased as the head of an app")
124 }
125 ConstKey::Lambda(lambda) => Ok(rty::Expr::abs(lambda.clone())),
126 ConstKey::PrimOp(bin_op) => {
127 Ok(rty::Expr::internal_func(InternalFuncKind::Rel(
128 bin_op.clone(),
129 )))
130 }
131 ConstKey::Cast(_sort, _sort1) => {
132 unreachable!(
133 "Should be specially handled as the head of a function app."
134 )
135 }
136 ConstKey::PtrSize => {
137 Ok(rty::Expr::internal_func(InternalFuncKind::PtrSize))
138 }
139 }
140 } else {
141 Err(FixpointParseError::NoGlobalVar(*global_var))
142 }
143 }
144 fixpoint::Var::Local(fname) => {
145 if let Some(expr) = self.ecx.local_var_env.reverse_map.get(fname) {
146 return Ok(expr.clone());
147 }
148
149 for (depth, layer) in self.ecx.local_var_env.layers.iter().rev().enumerate()
150 {
151 for (idx, var) in layer.iter().enumerate() {
152 if fname == var {
153 return Ok(rty::Expr::bvar(
154 rty::DebruijnIndex::from_usize(depth),
155 BoundVar::from_usize(idx),
156 rty::BoundReftKind::Anon,
157 ));
158 }
159 }
160 }
161
162 Err(FixpointParseError::NoLocalVar(*fname))
163 }
164 fixpoint::Var::DataCtor(adt_id, variant_idx) => {
165 let def_id = self.scx.adt_sorts[adt_id.as_usize()];
166 Ok(rty::Expr::ctor_enum(def_id, *variant_idx))
167 }
168 fixpoint::Var::TupleCtor { .. }
169 | fixpoint::Var::TupleProj { .. }
170 | fixpoint::Var::DataProj { .. }
171 | fixpoint::Var::UIFRel(_) => {
172 unreachable!(
173 "Trying to convert an atomic var, but reached a var that should only occur as the head of an app (and be special-cased in conversion as a result)"
174 )
175 }
176 fixpoint::Var::Param(EarlyReftParam { index, name }) => {
177 Ok(rty::Expr::early_param(*index, *name))
178 }
179 fixpoint::Var::ConstGeneric(const_generic) => {
180 Ok(rty::Expr::const_generic(*const_generic))
181 }
182 }
183 }
184 fixpoint::Expr::App(fhead, _sort_args, fargs, _out_sort) => {
185 match &**fhead {
186 fixpoint::Expr::Var(fixpoint::Var::TupleProj { arity, field }) => {
187 if fargs.len() == 1 {
188 let earg = self.fixpoint_to_expr(&fargs[0])?;
189 Ok(rty::Expr::field_proj(
190 earg,
191 rty::FieldProj::Tuple { arity: *arity, field: *field },
192 ))
193 } else {
194 Err(FixpointParseError::ProjArityMismatch(fargs.len()))
195 }
196 }
197 fixpoint::Expr::Var(fixpoint::Var::DataProj { adt_id, field }) => {
198 if fargs.len() == 1 {
199 let earg = self.fixpoint_to_expr(&fargs[0])?;
200 Ok(rty::Expr::field_proj(
201 earg,
202 rty::FieldProj::Adt {
203 def_id: self.scx.adt_sorts[adt_id.as_usize()],
204 field: *field,
205 },
206 ))
207 } else {
208 Err(FixpointParseError::ProjArityMismatch(fargs.len()))
209 }
210 }
211 fixpoint::Expr::Var(fixpoint::Var::TupleCtor { arity }) => {
212 if fargs.len() == *arity {
213 let eargs = fargs
214 .iter()
215 .map(|farg| self.fixpoint_to_expr(farg))
216 .try_collect()?;
217 Ok(rty::Expr::tuple(eargs))
218 } else {
219 Err(FixpointParseError::TupleCtorArityMismatch(*arity, fargs.len()))
220 }
221 }
222 fixpoint::Expr::Var(fixpoint::Var::UIFRel(fbinrel)) => {
223 if fargs.len() == 2 {
224 let e1 = self.fixpoint_to_expr(&fargs[0])?;
225 let e2 = self.fixpoint_to_expr(&fargs[1])?;
226 let binrel = match fbinrel {
227 fixpoint::BinRel::Eq => rty::BinOp::Eq,
228 fixpoint::BinRel::Ne => rty::BinOp::Ne,
229 fixpoint::BinRel::Gt => rty::BinOp::Gt(rty::Sort::Str),
236 fixpoint::BinRel::Ge => rty::BinOp::Ge(rty::Sort::Str),
237 fixpoint::BinRel::Lt => rty::BinOp::Lt(rty::Sort::Str),
238 fixpoint::BinRel::Le => rty::BinOp::Le(rty::Sort::Str),
239 };
240 Ok(rty::Expr::binary_op(binrel, e1, e2))
241 } else {
242 Err(FixpointParseError::UIFRelArityMismatch(fargs.len()))
243 }
244 }
245 fixpoint::Expr::Var(fixpoint::Var::Global(global_var, _))
246 | fixpoint::Expr::Var(fixpoint::Var::Const(global_var, _)) => {
247 if let Some(const_key) = self.ecx.const_env.const_map_rev.get(global_var) {
248 match const_key {
249 ConstKey::PrimOp(bin_op) => {
253 if fargs.len() != 2 {
254 Err(FixpointParseError::PrimOpArityMismatch(fargs.len()))
255 } else {
256 Ok(rty::Expr::prim_val(
257 bin_op.clone(),
258 self.fixpoint_to_expr(&fargs[0])?,
259 self.fixpoint_to_expr(&fargs[1])?,
260 ))
261 }
262 }
263 ConstKey::Cast(sort1, sort2) => {
264 if fargs.len() != 1 {
265 Err(FixpointParseError::CastArityMismatch(fargs.len()))
266 } else {
267 Ok(rty::Expr::cast(
268 sort1.clone(),
269 sort2.clone(),
270 self.fixpoint_to_expr(&fargs[0])?,
271 ))
272 }
273 }
274 ConstKey::Alias(assoc_id, generic_args) => {
275 let lowered_args: flux_rustc_bridge::ty::GenericArgs =
276 generic_args.lower(self.genv.tcx()).unwrap();
277 let generic_args = rty::refining::Refiner::default_for_item(
278 self.genv,
279 assoc_id.parent(),
280 )
281 .unwrap()
282 .refine_generic_args(assoc_id.parent(), &lowered_args)
283 .unwrap();
284 let alias_reft =
285 rty::AliasReft { assoc_id: *assoc_id, args: generic_args };
286 let args = fargs
287 .iter()
288 .map(|farg| self.fixpoint_to_expr(farg))
289 .try_collect()?;
290 Ok(rty::Expr::alias(alias_reft, args))
291 }
292 ConstKey::RustConst(..)
293 | ConstKey::Lambda(..)
294 | ConstKey::PtrSize => {
295 self.fixpoint_app_to_expr(fhead, fargs)
297 }
298 }
299 } else {
300 Err(FixpointParseError::NoGlobalVar(*global_var))
301 }
302 }
303 fhead => self.fixpoint_app_to_expr(fhead, fargs),
304 }
305 }
306 fixpoint::Expr::Neg(fexpr) => {
307 let e = self.fixpoint_to_expr(fexpr)?;
308 Ok(rty::Expr::neg(&e))
309 }
310 fixpoint::Expr::BinaryOp(fbinop, boxed_args) => {
311 let binop = match fbinop {
312 fixpoint::BinOp::Add => rty::BinOp::Add(rty::Sort::Int),
316 fixpoint::BinOp::Sub => rty::BinOp::Sub(rty::Sort::Int),
317 fixpoint::BinOp::Mul => rty::BinOp::Mul(rty::Sort::Int),
318 fixpoint::BinOp::Div => rty::BinOp::Div(rty::Sort::Int),
319 fixpoint::BinOp::Mod => rty::BinOp::Mod(rty::Sort::Int),
320 };
321 let [fe1, fe2] = &**boxed_args;
322 let e1 = self.fixpoint_to_expr(fe1)?;
323 let e2 = self.fixpoint_to_expr(fe2)?;
324 Ok(rty::Expr::binary_op(binop, e1, e2))
325 }
326 fixpoint::Expr::IfThenElse(boxed_args) => {
327 let [fe1, fe2, fe3] = &**boxed_args;
328 let e1 = self.fixpoint_to_expr(fe1)?;
329 let e2 = self.fixpoint_to_expr(fe2)?;
330 let e3 = self.fixpoint_to_expr(fe3)?;
331 Ok(rty::Expr::ite(e1, e2, e3))
332 }
333 fixpoint::Expr::And(fexprs) => {
334 let exprs: Vec<rty::Expr> = fexprs
335 .iter()
336 .map(|fexpr| self.fixpoint_to_expr(fexpr))
337 .try_collect()?;
338 Ok(rty::Expr::and_from_iter(exprs))
339 }
340 fixpoint::Expr::Or(fexprs) => {
341 let exprs: Vec<rty::Expr> = fexprs
342 .iter()
343 .map(|fexpr| self.fixpoint_to_expr(fexpr))
344 .try_collect()?;
345 Ok(rty::Expr::or_from_iter(exprs))
346 }
347 fixpoint::Expr::Not(fexpr) => {
348 let e = self.fixpoint_to_expr(fexpr)?;
349 Ok(rty::Expr::not(&e))
350 }
351 fixpoint::Expr::Imp(boxed_args) => {
352 let [fe1, fe2] = &**boxed_args;
353 let e1 = self.fixpoint_to_expr(fe1)?;
354 let e2 = self.fixpoint_to_expr(fe2)?;
355 Ok(rty::Expr::binary_op(rty::BinOp::Imp, e1, e2))
356 }
357 fixpoint::Expr::Iff(boxed_args) => {
358 let [fe1, fe2] = &**boxed_args;
359 let e1 = self.fixpoint_to_expr(fe1)?;
360 let e2 = self.fixpoint_to_expr(fe2)?;
361 Ok(rty::Expr::binary_op(rty::BinOp::Iff, e1, e2))
362 }
363 fixpoint::Expr::Atom(fbinrel, boxed_args) => {
364 let binrel = match fbinrel {
365 fixpoint::BinRel::Eq => rty::BinOp::Eq,
366 fixpoint::BinRel::Ne => rty::BinOp::Ne,
367 fixpoint::BinRel::Gt => rty::BinOp::Gt(rty::Sort::Int),
375 fixpoint::BinRel::Ge => rty::BinOp::Ge(rty::Sort::Int),
376 fixpoint::BinRel::Lt => rty::BinOp::Lt(rty::Sort::Int),
377 fixpoint::BinRel::Le => rty::BinOp::Le(rty::Sort::Int),
378 };
379 let [fe1, fe2] = &**boxed_args;
380 let e1 = self.fixpoint_to_expr(fe1)?;
381 let e2 = self.fixpoint_to_expr(fe2)?;
382 Ok(rty::Expr::binary_op(binrel, e1, e2))
383 }
384 fixpoint::Expr::Let(_var, _boxed_args) => {
385 todo!("Convert `var` in e2 to locally nameless var, then fill in sort");
392 }
394 fixpoint::Expr::ThyFunc(itf) => Ok(rty::Expr::global_func(SpecFuncKind::Thy(*itf))),
395 fixpoint::Expr::IsCtor(var, fe) => {
396 let (def_id, variant_idx) = match var {
397 fixpoint::Var::DataCtor(adt_id, variant_idx) => {
398 let def_id = self.scx.adt_sorts[adt_id.as_usize()];
399 Ok((def_id, *variant_idx))
400 }
401 _ => Err(FixpointParseError::WrongVarInIsCtor(*var)),
402 }?;
403 let e = self.fixpoint_to_expr(fe)?;
404 Ok(rty::Expr::is_ctor(def_id, variant_idx, e))
405 }
406 fixpoint::Expr::Exists(binder, body) => {
407 let mut vars = vec![];
408 let mut sorts = vec![];
409 for (var, sort) in binder {
410 let fixpoint::Var::Local(local_var) = var else {
411 return Err(FixpointParseError::WrongVarInBinder(*var));
412 };
413 vars.push(*local_var);
414 sorts.push(self.fixpoint_to_sort(sort)?);
415 }
416 self.ecx.local_var_env.push_layer(vars);
417 let body = self.fixpoint_to_expr(body)?;
418 self.ecx.local_var_env.pop_layer();
419 Ok(rty::Expr::exists(Binder::bind_with_sorts(body, &sorts)))
420 }
421 }
422 }
423
424 fn fixpoint_app_to_expr(
425 &mut self,
426 fhead: &fixpoint::Expr,
427 fargs: &[fixpoint::Expr],
428 ) -> Result<rty::Expr, FixpointParseError> {
429 let head = self.fixpoint_to_expr(fhead)?;
430 let args = fargs
431 .iter()
432 .map(|farg| self.fixpoint_to_expr(farg))
433 .try_collect()?;
434 Ok(rty::Expr::app(head, List::empty(), args))
435 }
436}
437
438#[derive(Debug)]
439pub enum FixpointParseError {
440 UIFRelArityMismatch(usize),
443 TupleCtorArityMismatch(usize, usize),
445 ProjArityMismatch(usize),
447 NoGlobalVar(fixpoint::GlobalVar),
448 CastArityMismatch(usize),
450 PrimOpArityMismatch(usize),
451 NoLocalVar(fixpoint::LocalVar),
452 WrongVarInIsCtor(fixpoint::Var),
454 WrongVarInBinder(fixpoint::Var),
456 UnknownAdt(DefId),
457}