flux_infer/fixpoint_encoding/
decoding.rs1use flux_middle::{
2 big_int::BigInt,
3 rty::{self, Binder, EarlyReftParam, InternalFuncKind, List, SpecFuncKind},
4};
5use flux_rustc_bridge::lowering::Lower;
6use itertools::Itertools;
7use rustc_hir::def_id::DefId;
8use rustc_type_ir::BoundVar;
9
10use super::{ConstKey, FixpointCtxt, fixpoint};
11
12impl<'genv, 'tcx, Tag> FixpointCtxt<'genv, 'tcx, Tag>
13where
14 Tag: std::hash::Hash + Eq + Copy,
15{
16 fn fixpoint_to_sort_ctor(
17 &self,
18 ctor: &fixpoint::SortCtor,
19 ) -> Result<rty::SortCtor, FixpointParseError> {
20 match ctor {
21 fixpoint::SortCtor::Set => Ok(rty::SortCtor::Set),
22 fixpoint::SortCtor::Map => Ok(rty::SortCtor::Map),
23 fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(_)) => {
24 panic!("oh no! tuple!") }
26 fixpoint::SortCtor::Data(fixpoint::DataSort::User(def_id)) => {
27 Ok(rty::SortCtor::User(*def_id))
28 }
29 fixpoint::SortCtor::Data(fixpoint::DataSort::Adt(adt_id)) => {
30 let def_id = self.scx.adt_sorts[adt_id.as_usize()];
31 let Ok(adt_sort_def) = self.genv.adt_sort_def_of(def_id) else {
32 return Err(FixpointParseError::UnknownAdt(def_id));
33 };
34 Ok(rty::SortCtor::Adt(adt_sort_def))
35 }
36 }
37 }
38
39 pub(crate) fn fixpoint_to_sort(
40 &self,
41 fsort: &fixpoint::Sort,
42 ) -> Result<rty::Sort, FixpointParseError> {
43 match fsort {
44 fixpoint::Sort::Int => Ok(rty::Sort::Int),
45 fixpoint::Sort::Real => Ok(rty::Sort::Real),
46 fixpoint::Sort::Bool => Ok(rty::Sort::Bool),
47 fixpoint::Sort::Str => Ok(rty::Sort::Str),
48 fixpoint::Sort::Func(sorts) => {
49 let sort1 = self.fixpoint_to_sort(&sorts[0])?;
50 let sort2 = self.fixpoint_to_sort(&sorts[1])?;
51 let fsort = rty::FuncSort::new(vec![sort1], sort2);
52 let poly_sort = rty::PolyFuncSort::new(List::empty(), fsort);
53 Ok(rty::Sort::Func(poly_sort))
54 }
55 fixpoint::Sort::App(ctor, args) => {
56 let ctor = self.fixpoint_to_sort_ctor(ctor)?;
57 let args = args
58 .iter()
59 .map(|fsort| self.fixpoint_to_sort(fsort))
60 .try_collect()?;
61 Ok(rty::Sort::App(ctor, args))
62 }
63 fixpoint::Sort::BitVec(fsort) if let fixpoint::Sort::BvSize(size) = **fsort => {
64 Ok(rty::Sort::BitVec(rty::BvSize::Fixed(size)))
65 }
66 _ => unimplemented!("fixpoint_to_sort: {fsort:?}"),
67 }
68 }
69
70 #[allow(dead_code)]
71 pub(crate) fn fixpoint_to_expr(
72 &self,
73 fexpr: &fixpoint::Expr,
74 ) -> Result<rty::Expr, FixpointParseError> {
75 match fexpr {
76 fixpoint::Expr::Constant(constant) => {
77 let c = match constant {
78 fixpoint::Constant::Numeral(num) => rty::Constant::Int(BigInt::from(*num)),
79 fixpoint::Constant::Real(dec) => rty::Constant::Real(rty::Real(*dec)),
80 fixpoint::Constant::Boolean(b) => rty::Constant::Bool(*b),
81 fixpoint::Constant::String(s) => rty::Constant::Str(s.0),
82 fixpoint::Constant::BitVec(bv, size) => rty::Constant::BitVec(*bv, *size),
83 };
84 Ok(rty::Expr::constant(c))
85 }
86 fixpoint::Expr::Var(fvar) => {
87 match fvar {
88 fixpoint::Var::Underscore => {
89 unreachable!("Underscore should not appear in exprs")
90 }
91 fixpoint::Var::Global(global_var, _) => {
92 if let Some(const_key) = self.ecx.const_env.const_map_rev.get(global_var) {
93 match const_key {
94 ConstKey::Uif(def_id) => {
95 Ok(rty::Expr::global_func(SpecFuncKind::Uif(*def_id)))
96 }
97 ConstKey::RustConst(def_id) => Ok(rty::Expr::const_def_id(*def_id)),
98 ConstKey::Alias(_flux_id, _args) => {
99 unreachable!("Should be special-cased as the head of an app")
100 }
101 ConstKey::Lambda(lambda) => Ok(rty::Expr::abs(lambda.clone())),
102 ConstKey::PrimOp(bin_op) => {
103 Ok(rty::Expr::internal_func(InternalFuncKind::Rel(
104 bin_op.clone(),
105 )))
106 }
107 ConstKey::Cast(_sort, _sort1) => {
108 unreachable!(
109 "Should be specially handled as the head of a function app."
110 )
111 }
112 }
113 } else {
114 Err(FixpointParseError::NoGlobalVar(*global_var))
115 }
116 }
117 fixpoint::Var::Local(fname) => {
118 if let Some(expr) = self.ecx.local_var_env.reverse_map.get(fname) {
119 Ok(expr.clone())
120 } else {
121 Err(FixpointParseError::NoLocalVar(*fname))
122 }
123 }
124 fixpoint::Var::DataCtor(adt_id, variant_idx) => {
125 let def_id = self.scx.adt_sorts[adt_id.as_usize()];
126 Ok(rty::Expr::ctor_enum(def_id, *variant_idx))
127 }
128 fixpoint::Var::TupleCtor { .. }
129 | fixpoint::Var::TupleProj { .. }
130 | fixpoint::Var::DataProj { .. }
131 | fixpoint::Var::UIFRel(_) => {
132 unreachable!(
133 "Trying to convert an atomic var, but reached a var that should only occur as the head of an app (and be special-cased in conversion as a result)"
134 )
135 }
136 fixpoint::Var::Param(EarlyReftParam { index, name }) => {
137 Ok(rty::Expr::early_param(*index, *name))
138 }
139 fixpoint::Var::ConstGeneric(const_generic) => {
140 Ok(rty::Expr::const_generic(*const_generic))
141 }
142 }
143 }
144 fixpoint::Expr::App(fhead, fargs) => {
145 match &**fhead {
146 fixpoint::Expr::Var(fixpoint::Var::TupleProj { arity, field }) => {
147 if fargs.len() == 1 {
148 let earg = self.fixpoint_to_expr(&fargs[0])?;
149 Ok(rty::Expr::field_proj(
150 earg,
151 rty::FieldProj::Tuple { arity: *arity, field: *field },
152 ))
153 } else {
154 Err(FixpointParseError::ProjArityMismatch(fargs.len()))
155 }
156 }
157 fixpoint::Expr::Var(fixpoint::Var::DataProj { adt_id, field }) => {
158 if fargs.len() == 1 {
159 let earg = self.fixpoint_to_expr(&fargs[0])?;
160 Ok(rty::Expr::field_proj(
161 earg,
162 rty::FieldProj::Adt {
163 def_id: self.scx.adt_sorts[adt_id.as_usize()],
164 field: *field,
165 },
166 ))
167 } else {
168 Err(FixpointParseError::ProjArityMismatch(fargs.len()))
169 }
170 }
171 fixpoint::Expr::Var(fixpoint::Var::TupleCtor { arity }) => {
172 if fargs.len() == *arity {
173 let eargs = fargs
174 .iter()
175 .map(|farg| self.fixpoint_to_expr(farg))
176 .try_collect()?;
177 Ok(rty::Expr::tuple(eargs))
178 } else {
179 Err(FixpointParseError::TupleCtorArityMismatch(*arity, fargs.len()))
180 }
181 }
182 fixpoint::Expr::Var(fixpoint::Var::UIFRel(fbinrel)) => {
183 if fargs.len() == 2 {
184 let e1 = self.fixpoint_to_expr(&fargs[0])?;
185 let e2 = self.fixpoint_to_expr(&fargs[1])?;
186 let binrel = match fbinrel {
187 fixpoint::BinRel::Eq => rty::BinOp::Eq,
188 fixpoint::BinRel::Ne => rty::BinOp::Ne,
189 fixpoint::BinRel::Gt => rty::BinOp::Gt(rty::Sort::Str),
196 fixpoint::BinRel::Ge => rty::BinOp::Ge(rty::Sort::Str),
197 fixpoint::BinRel::Lt => rty::BinOp::Lt(rty::Sort::Str),
198 fixpoint::BinRel::Le => rty::BinOp::Le(rty::Sort::Str),
199 };
200 Ok(rty::Expr::binary_op(binrel, e1, e2))
201 } else {
202 Err(FixpointParseError::UIFRelArityMismatch(fargs.len()))
203 }
204 }
205 fixpoint::Expr::Var(fixpoint::Var::Global(global_var, _)) => {
206 if let Some(const_key) = self.ecx.const_env.const_map_rev.get(global_var) {
207 match const_key {
208 ConstKey::PrimOp(bin_op) => {
212 if fargs.len() != 2 {
213 Err(FixpointParseError::PrimOpArityMismatch(fargs.len()))
214 } else {
215 Ok(rty::Expr::prim_val(
216 bin_op.clone(),
217 self.fixpoint_to_expr(&fargs[0])?,
218 self.fixpoint_to_expr(&fargs[1])?,
219 ))
220 }
221 }
222 ConstKey::Cast(sort1, sort2) => {
223 if fargs.len() != 1 {
224 Err(FixpointParseError::CastArityMismatch(fargs.len()))
225 } else {
226 Ok(rty::Expr::cast(
227 sort1.clone(),
228 sort2.clone(),
229 self.fixpoint_to_expr(&fargs[0])?,
230 ))
231 }
232 }
233 ConstKey::Alias(assoc_id, generic_args) => {
234 let lowered_args: flux_rustc_bridge::ty::GenericArgs =
235 generic_args.lower(self.genv.tcx()).unwrap();
236 let generic_args = rty::refining::Refiner::default_for_item(
237 self.genv,
238 assoc_id.parent(),
239 )
240 .unwrap()
241 .refine_generic_args(assoc_id.parent(), &lowered_args)
242 .unwrap();
243 let alias_reft =
244 rty::AliasReft { assoc_id: *assoc_id, args: generic_args };
245 let args = fargs
246 .iter()
247 .map(|farg| self.fixpoint_to_expr(farg))
248 .try_collect()?;
249 Ok(rty::Expr::alias(alias_reft, args))
250 }
251 ConstKey::Uif(..)
252 | ConstKey::RustConst(..)
253 | ConstKey::Lambda(..) => {
254 self.fixpoint_app_to_expr(fhead, fargs)
256 }
257 }
258 } else {
259 Err(FixpointParseError::NoGlobalVar(*global_var))
260 }
261 }
262 fhead => self.fixpoint_app_to_expr(fhead, fargs),
263 }
264 }
265 fixpoint::Expr::Neg(fexpr) => {
266 let e = self.fixpoint_to_expr(fexpr)?;
267 Ok(rty::Expr::neg(&e))
268 }
269 fixpoint::Expr::BinaryOp(fbinop, boxed_args) => {
270 let binop = match fbinop {
271 fixpoint::BinOp::Add => rty::BinOp::Add(rty::Sort::Int),
275 fixpoint::BinOp::Sub => rty::BinOp::Sub(rty::Sort::Int),
276 fixpoint::BinOp::Mul => rty::BinOp::Mul(rty::Sort::Int),
277 fixpoint::BinOp::Div => rty::BinOp::Div(rty::Sort::Int),
278 fixpoint::BinOp::Mod => rty::BinOp::Mod(rty::Sort::Int),
279 };
280 let [fe1, fe2] = &**boxed_args;
281 let e1 = self.fixpoint_to_expr(fe1)?;
282 let e2 = self.fixpoint_to_expr(fe2)?;
283 Ok(rty::Expr::binary_op(binop, e1, e2))
284 }
285 fixpoint::Expr::IfThenElse(boxed_args) => {
286 let [fe1, fe2, fe3] = &**boxed_args;
287 let e1 = self.fixpoint_to_expr(fe1)?;
288 let e2 = self.fixpoint_to_expr(fe2)?;
289 let e3 = self.fixpoint_to_expr(fe3)?;
290 Ok(rty::Expr::ite(e1, e2, e3))
291 }
292 fixpoint::Expr::And(fexprs) => {
293 let exprs: Vec<rty::Expr> = fexprs
294 .iter()
295 .map(|fexpr| self.fixpoint_to_expr(fexpr))
296 .try_collect()?;
297 Ok(rty::Expr::and_from_iter(exprs))
298 }
299 fixpoint::Expr::Or(fexprs) => {
300 let exprs: Vec<rty::Expr> = fexprs
301 .iter()
302 .map(|fexpr| self.fixpoint_to_expr(fexpr))
303 .try_collect()?;
304 Ok(rty::Expr::or_from_iter(exprs))
305 }
306 fixpoint::Expr::Not(fexpr) => {
307 let e = self.fixpoint_to_expr(fexpr)?;
308 Ok(rty::Expr::not(&e))
309 }
310 fixpoint::Expr::Imp(boxed_args) => {
311 let [fe1, fe2] = &**boxed_args;
312 let e1 = self.fixpoint_to_expr(fe1)?;
313 let e2 = self.fixpoint_to_expr(fe2)?;
314 Ok(rty::Expr::binary_op(rty::BinOp::Imp, e1, e2))
315 }
316 fixpoint::Expr::Iff(boxed_args) => {
317 let [fe1, fe2] = &**boxed_args;
318 let e1 = self.fixpoint_to_expr(fe1)?;
319 let e2 = self.fixpoint_to_expr(fe2)?;
320 Ok(rty::Expr::binary_op(rty::BinOp::Iff, e1, e2))
321 }
322 fixpoint::Expr::Atom(fbinrel, boxed_args) => {
323 let binrel = match fbinrel {
324 fixpoint::BinRel::Eq => rty::BinOp::Eq,
325 fixpoint::BinRel::Ne => rty::BinOp::Ne,
326 fixpoint::BinRel::Gt => rty::BinOp::Gt(rty::Sort::Int),
334 fixpoint::BinRel::Ge => rty::BinOp::Ge(rty::Sort::Int),
335 fixpoint::BinRel::Lt => rty::BinOp::Lt(rty::Sort::Int),
336 fixpoint::BinRel::Le => rty::BinOp::Le(rty::Sort::Int),
337 };
338 let [fe1, fe2] = &**boxed_args;
339 let e1 = self.fixpoint_to_expr(fe1)?;
340 let e2 = self.fixpoint_to_expr(fe2)?;
341 Ok(rty::Expr::binary_op(binrel, e1, e2))
342 }
343 fixpoint::Expr::Let(_var, _boxed_args) => {
344 todo!("Convert `var` in e2 to locally nameless var, then fill in sort");
351 }
353 fixpoint::Expr::ThyFunc(itf) => Ok(rty::Expr::global_func(SpecFuncKind::Thy(*itf))),
354 fixpoint::Expr::IsCtor(var, fe) => {
355 let (def_id, variant_idx) = match var {
356 fixpoint::Var::DataCtor(adt_id, variant_idx) => {
357 let def_id = self.scx.adt_sorts[adt_id.as_usize()];
358 Ok((def_id, *variant_idx))
359 }
360 _ => Err(FixpointParseError::WrongVarInIsCtor(*var)),
361 }?;
362 let e = self.fixpoint_to_expr(fe)?;
363 Ok(rty::Expr::is_ctor(def_id, variant_idx, e))
364 }
365 fixpoint::Expr::Exists(sorts, body) => {
366 let sorts: Vec<_> = sorts
367 .iter()
368 .map(|fsort| self.fixpoint_to_sort(fsort))
369 .try_collect()?;
370 let body = self.fixpoint_to_expr(body)?;
371 Ok(rty::Expr::exists(Binder::bind_with_sorts(body, &sorts)))
372 }
373 fixpoint::Expr::BoundVar(fixpoint::BoundVar { level, idx }) => {
374 Ok(rty::Expr::bvar(
375 rty::DebruijnIndex::from_usize(*level),
376 BoundVar::from_usize(*idx),
377 rty::BoundReftKind::Anon,
378 ))
379 }
380 }
381 }
382
383 fn fixpoint_app_to_expr(
384 &self,
385 fhead: &fixpoint::Expr,
386 fargs: &[fixpoint::Expr],
387 ) -> Result<rty::Expr, FixpointParseError> {
388 let head = self.fixpoint_to_expr(fhead)?;
389 let args = fargs
390 .iter()
391 .map(|farg| self.fixpoint_to_expr(farg))
392 .try_collect()?;
393 Ok(rty::Expr::app(head, List::empty(), args))
394 }
395}
396
397#[derive(Debug)]
398pub enum FixpointParseError {
399 UIFRelArityMismatch(usize),
402 TupleCtorArityMismatch(usize, usize),
404 ProjArityMismatch(usize),
406 NoGlobalVar(fixpoint::GlobalVar),
407 CastArityMismatch(usize),
409 PrimOpArityMismatch(usize),
410 NoLocalVar(fixpoint::LocalVar),
411 WrongVarInIsCtor(fixpoint::Var),
413 UnknownAdt(DefId),
414}