flux_infer/fixpoint_encoding/
decoding.rs

1use flux_common::tracked_span_bug;
2use flux_middle::{
3    big_int::BigInt,
4    rty::{self, Binder, EarlyReftParam, InternalFuncKind, List, SpecFuncKind},
5};
6use flux_rustc_bridge::lowering::Lower;
7use itertools::Itertools;
8use rustc_hir::def_id::DefId;
9use rustc_type_ir::BoundVar;
10
11use super::{ConstKey, FixpointCtxt, fixpoint};
12use crate::fixpoint_encoding::FixpointSolution;
13
14impl<'genv, 'tcx, Tag> FixpointCtxt<'genv, 'tcx, Tag>
15where
16    Tag: std::hash::Hash + Eq + Copy,
17{
18    pub(crate) fn fixpoint_to_solution(
19        &mut self,
20        sol: &FixpointSolution,
21    ) -> rty::Binder<rty::Expr> {
22        let mut vars = vec![];
23        let mut sorts = vec![];
24        for (var, sort) in &sol.0 {
25            let fixpoint::Var::Local(local_var) = var else {
26                tracked_span_bug!("encountered non-local variable in binder: {var:?}");
27            };
28            vars.push(*local_var);
29            sorts.push(
30                self.fixpoint_to_sort(sort)
31                    .unwrap_or_else(|_| tracked_span_bug!("failed to parse sort: {sort:?}")),
32            );
33        }
34        self.ecx.local_var_env.push_layer(vars);
35        let expr = self
36            .fixpoint_to_expr(&sol.1)
37            .unwrap_or_else(|err| tracked_span_bug!("failed to convert expr: {err:?}"));
38        self.ecx.local_var_env.pop_layer();
39        rty::Binder::bind_with_sorts(expr, &sorts)
40    }
41
42    fn fixpoint_to_sort_ctor(
43        &self,
44        ctor: &fixpoint::SortCtor,
45    ) -> Result<rty::SortCtor, FixpointParseError> {
46        match ctor {
47            fixpoint::SortCtor::Set => Ok(rty::SortCtor::Set),
48            fixpoint::SortCtor::Map => Ok(rty::SortCtor::Map),
49            fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(_)) => {
50                panic!("oh no! tuple!") // Ok(rty::SortCtor::Tuple(*size))
51            }
52            fixpoint::SortCtor::Data(fixpoint::DataSort::User(opaque_id)) => {
53                let def_id = self.scx.opaque_sorts[opaque_id.as_usize()];
54                Ok(rty::SortCtor::User(def_id))
55            }
56            fixpoint::SortCtor::Data(fixpoint::DataSort::Adt(adt_id)) => {
57                let def_id = self.scx.adt_sorts[adt_id.as_usize()];
58                let Ok(adt_sort_def) = self.genv.adt_sort_def_of(def_id) else {
59                    return Err(FixpointParseError::UnknownAdt(def_id));
60                };
61                Ok(rty::SortCtor::Adt(adt_sort_def))
62            }
63        }
64    }
65
66    pub(crate) fn fixpoint_to_sort(
67        &self,
68        fsort: &fixpoint::Sort,
69    ) -> Result<rty::Sort, FixpointParseError> {
70        match fsort {
71            fixpoint::Sort::Int => Ok(rty::Sort::Int),
72            fixpoint::Sort::Real => Ok(rty::Sort::Real),
73            fixpoint::Sort::Bool => Ok(rty::Sort::Bool),
74            fixpoint::Sort::Str => Ok(rty::Sort::Str),
75            fixpoint::Sort::Func(sorts) => {
76                let sort1 = self.fixpoint_to_sort(&sorts[0])?;
77                let sort2 = self.fixpoint_to_sort(&sorts[1])?;
78                let fsort = rty::FuncSort::new(vec![sort1], sort2);
79                let poly_sort = rty::PolyFuncSort::new(List::empty(), fsort);
80                Ok(rty::Sort::Func(poly_sort))
81            }
82            fixpoint::Sort::App(ctor, args) => {
83                let ctor = self.fixpoint_to_sort_ctor(ctor)?;
84                let args = args
85                    .iter()
86                    .map(|fsort| self.fixpoint_to_sort(fsort))
87                    .try_collect()?;
88                Ok(rty::Sort::App(ctor, args))
89            }
90            fixpoint::Sort::BitVec(fsort) if let fixpoint::Sort::BvSize(size) = **fsort => {
91                Ok(rty::Sort::BitVec(rty::BvSize::Fixed(size)))
92            }
93            _ => unimplemented!("fixpoint_to_sort:  {fsort:?}"),
94        }
95    }
96
97    fn is_curried_primop_app(
98        &mut self,
99        fhead: &fixpoint::Expr,
100        fargs: &[fixpoint::Expr],
101        op_args: &mut Vec<fixpoint::Expr>,
102    ) -> Option<rty::BinOp> {
103        match fhead {
104            fixpoint::Expr::Var(fixpoint::Var::Global(global_var, _))
105            | fixpoint::Expr::Var(fixpoint::Var::Const(global_var, _)) => {
106                if let Some(ConstKey::PrimOp(bin_op)) =
107                    self.ecx.const_env.const_map_rev.get(global_var)
108                {
109                    op_args.reverse();
110                    Some(bin_op.clone())
111                } else {
112                    None
113                }
114            }
115            fixpoint::Expr::App(fhead_inner, _, fargs_inner, _) => {
116                if fargs.len() == 1 {
117                    op_args.push(fargs[0].clone());
118                }
119                self.is_curried_primop_app(fhead_inner, fargs_inner, op_args)
120            }
121            _ => None,
122        }
123    }
124
125    #[allow(dead_code)]
126    pub(crate) fn fixpoint_to_expr(
127        &mut self,
128        fexpr: &fixpoint::Expr,
129    ) -> Result<rty::Expr, FixpointParseError> {
130        match fexpr {
131            fixpoint::Expr::Constant(constant) => {
132                let c = match constant {
133                    fixpoint::Constant::Numeral(num) => rty::Constant::Int(BigInt::from(*num)),
134                    fixpoint::Constant::Real(dec) => rty::Constant::Real(rty::Real(dec.0)),
135                    fixpoint::Constant::Boolean(b) => rty::Constant::Bool(*b),
136                    fixpoint::Constant::String(s) => rty::Constant::Str(s.0),
137                    fixpoint::Constant::BitVec(bv, size) => rty::Constant::BitVec(*bv, *size),
138                };
139                Ok(rty::Expr::constant(c))
140            }
141            fixpoint::Expr::Var(fvar) => {
142                match fvar {
143                    fixpoint::Var::Underscore => {
144                        unreachable!("Underscore should not appear in exprs")
145                    }
146                    fixpoint::Var::Global(global_var, _) | fixpoint::Var::Const(global_var, _) => {
147                        if let Some(const_key) = self.ecx.const_env.const_map_rev.get(global_var) {
148                            match const_key {
149                                ConstKey::RustConst(def_id) => Ok(rty::Expr::const_def_id(*def_id)),
150                                ConstKey::Alias(_flux_id, _args) => {
151                                    unreachable!("Should be special-cased as the head of an app")
152                                }
153                                ConstKey::Lambda(lambda) => Ok(rty::Expr::abs(lambda.clone())),
154                                ConstKey::PrimOp(bin_op) => {
155                                    Ok(rty::Expr::internal_func(InternalFuncKind::Rel(
156                                        bin_op.clone(),
157                                    )))
158                                }
159                                ConstKey::Cast(_sort, _sort1) => {
160                                    unreachable!(
161                                        "Should be specially handled as the head of a function app."
162                                    )
163                                }
164                            }
165                        } else {
166                            Err(FixpointParseError::NoGlobalVar(*global_var))
167                        }
168                    }
169                    fixpoint::Var::Local(fname) => {
170                        if let Some(expr) = self.ecx.local_var_env.reverse_map.get(fname) {
171                            return Ok(expr.clone());
172                        }
173
174                        for (depth, layer) in self.ecx.local_var_env.layers.iter().rev().enumerate()
175                        {
176                            for (idx, var) in layer.iter().enumerate() {
177                                if fname == var {
178                                    return Ok(rty::Expr::bvar(
179                                        rty::DebruijnIndex::from_usize(depth),
180                                        BoundVar::from_usize(idx),
181                                        rty::BoundReftKind::Anon,
182                                    ));
183                                }
184                            }
185                        }
186
187                        Err(FixpointParseError::NoLocalVar(*fname))
188                    }
189                    fixpoint::Var::DataCtor(adt_id, variant_idx) => {
190                        let def_id = self.scx.adt_sorts[adt_id.as_usize()];
191                        Ok(rty::Expr::ctor_enum(def_id, *variant_idx))
192                    }
193                    fixpoint::Var::TupleCtor { .. }
194                    | fixpoint::Var::TupleProj { .. }
195                    | fixpoint::Var::DataProj { .. }
196                    | fixpoint::Var::UIFRel(_) => {
197                        unreachable!(
198                            "Trying to convert an atomic var, but reached a var that should only occur as the head of an app (and be special-cased in conversion as a result)"
199                        )
200                    }
201                    fixpoint::Var::Param(EarlyReftParam { index, name }) => {
202                        Ok(rty::Expr::early_param(*index, *name))
203                    }
204                    fixpoint::Var::ConstGeneric(const_generic) => {
205                        Ok(rty::Expr::const_generic(*const_generic))
206                    }
207                }
208            }
209            fixpoint::Expr::App(fhead, _sort_args, fargs, _out_sort) => {
210                let mut op_args = vec![];
211                if let Some(bin_op) = self.is_curried_primop_app(fhead, fargs, &mut op_args) {
212                    if op_args.len() != 2 {
213                        return Err(FixpointParseError::PrimOpArityMismatch(fargs.len()));
214                    } else {
215                        let e1 = self.fixpoint_to_expr(&op_args[0])?;
216                        let e2 = self.fixpoint_to_expr(&op_args[1])?;
217                        return Ok(rty::Expr::prim_val(bin_op, e1, e2));
218                    }
219                }
220                match &**fhead {
221                    fixpoint::Expr::Var(fixpoint::Var::TupleProj { arity, field }) => {
222                        if fargs.len() == 1 {
223                            let earg = self.fixpoint_to_expr(&fargs[0])?;
224                            Ok(rty::Expr::field_proj(
225                                earg,
226                                rty::FieldProj::Tuple { arity: *arity, field: *field },
227                            ))
228                        } else {
229                            Err(FixpointParseError::ProjArityMismatch(fargs.len()))
230                        }
231                    }
232                    fixpoint::Expr::Var(fixpoint::Var::DataProj { adt_id, field }) => {
233                        if fargs.len() == 1 {
234                            let earg = self.fixpoint_to_expr(&fargs[0])?;
235                            Ok(rty::Expr::field_proj(
236                                earg,
237                                rty::FieldProj::Adt {
238                                    def_id: self.scx.adt_sorts[adt_id.as_usize()],
239                                    field: *field,
240                                },
241                            ))
242                        } else {
243                            Err(FixpointParseError::ProjArityMismatch(fargs.len()))
244                        }
245                    }
246                    fixpoint::Expr::Var(fixpoint::Var::TupleCtor { arity }) => {
247                        if fargs.len() == *arity {
248                            let eargs = fargs
249                                .iter()
250                                .map(|farg| self.fixpoint_to_expr(farg))
251                                .try_collect()?;
252                            Ok(rty::Expr::tuple(eargs))
253                        } else {
254                            Err(FixpointParseError::TupleCtorArityMismatch(*arity, fargs.len()))
255                        }
256                    }
257                    fixpoint::Expr::Var(fixpoint::Var::UIFRel(fbinrel)) => {
258                        if fargs.len() == 2 {
259                            let e1 = self.fixpoint_to_expr(&fargs[0])?;
260                            let e2 = self.fixpoint_to_expr(&fargs[1])?;
261                            let binrel = match fbinrel {
262                                fixpoint::BinRel::Eq => rty::BinOp::Eq,
263                                fixpoint::BinRel::Ne => rty::BinOp::Ne,
264                                // FIXME: (ck) faked sort information
265                                //
266                                // This needs to be a sort that goes to the UIFRel
267                                // case in fixpoint conversion. Again, if we actually
268                                // need to inspect the sorts this will die unless the
269                                // arguments are actually Strs.
270                                fixpoint::BinRel::Gt => rty::BinOp::Gt(rty::Sort::Str),
271                                fixpoint::BinRel::Ge => rty::BinOp::Ge(rty::Sort::Str),
272                                fixpoint::BinRel::Lt => rty::BinOp::Lt(rty::Sort::Str),
273                                fixpoint::BinRel::Le => rty::BinOp::Le(rty::Sort::Str),
274                            };
275                            Ok(rty::Expr::binary_op(binrel, e1, e2))
276                        } else {
277                            Err(FixpointParseError::UIFRelArityMismatch(fargs.len()))
278                        }
279                    }
280                    fixpoint::Expr::Var(fixpoint::Var::Global(global_var, _))
281                    | fixpoint::Expr::Var(fixpoint::Var::Const(global_var, _)) => {
282                        if let Some(const_key) = self.ecx.const_env.const_map_rev.get(global_var) {
283                            match const_key {
284                                // NOTE: Only a few of these are meaningfully needed,
285                                // e.g. ConstKey::Alias because the rty Expr has its
286                                // args as a part of it.
287                                ConstKey::PrimOp(_) => {
288                                    unreachable!(
289                                        "Should have been handled by is_curried_primop_app"
290                                    )
291                                }
292                                ConstKey::Cast(sort1, sort2) => {
293                                    if fargs.len() != 1 {
294                                        Err(FixpointParseError::CastArityMismatch(fargs.len()))
295                                    } else {
296                                        Ok(rty::Expr::cast(
297                                            sort1.clone(),
298                                            sort2.clone(),
299                                            self.fixpoint_to_expr(&fargs[0])?,
300                                        ))
301                                    }
302                                }
303                                ConstKey::Alias(assoc_id, generic_args) => {
304                                    let lowered_args: flux_rustc_bridge::ty::GenericArgs =
305                                        generic_args.lower(self.genv.tcx()).unwrap();
306                                    let generic_args = rty::refining::Refiner::default_for_item(
307                                        self.genv,
308                                        assoc_id.parent(),
309                                    )
310                                    .unwrap()
311                                    .refine_generic_args(assoc_id.parent(), &lowered_args)
312                                    .unwrap();
313                                    let alias_reft =
314                                        rty::AliasReft { assoc_id: *assoc_id, args: generic_args };
315                                    let args = fargs
316                                        .iter()
317                                        .map(|farg| self.fixpoint_to_expr(farg))
318                                        .try_collect()?;
319                                    Ok(rty::Expr::alias(alias_reft, args))
320                                }
321                                ConstKey::RustConst(..) | ConstKey::Lambda(..) => {
322                                    // These should be treated as a normal app.
323                                    self.fixpoint_app_to_expr(fhead, fargs)
324                                }
325                            }
326                        } else {
327                            Err(FixpointParseError::NoGlobalVar(*global_var))
328                        }
329                    }
330                    fhead => self.fixpoint_app_to_expr(fhead, fargs),
331                }
332            }
333            fixpoint::Expr::Neg(fexpr) => {
334                let e = self.fixpoint_to_expr(fexpr)?;
335                Ok(rty::Expr::neg(&e))
336            }
337            fixpoint::Expr::BinaryOp(fbinop, boxed_args) => {
338                let binop = match fbinop {
339                    // FIXME: (ck) faked sort information
340                    //
341                    // See what we do for binrel for an explanation.
342                    fixpoint::BinOp::Add => rty::BinOp::Add(rty::Sort::Int),
343                    fixpoint::BinOp::Sub => rty::BinOp::Sub(rty::Sort::Int),
344                    fixpoint::BinOp::Mul => rty::BinOp::Mul(rty::Sort::Int),
345                    fixpoint::BinOp::Div => rty::BinOp::Div(rty::Sort::Int),
346                    fixpoint::BinOp::Mod => rty::BinOp::Mod(rty::Sort::Int),
347                };
348                let [fe1, fe2] = &**boxed_args;
349                let e1 = self.fixpoint_to_expr(fe1)?;
350                let e2 = self.fixpoint_to_expr(fe2)?;
351                Ok(rty::Expr::binary_op(binop, e1, e2))
352            }
353            fixpoint::Expr::IfThenElse(boxed_args) => {
354                let [fe1, fe2, fe3] = &**boxed_args;
355                let e1 = self.fixpoint_to_expr(fe1)?;
356                let e2 = self.fixpoint_to_expr(fe2)?;
357                let e3 = self.fixpoint_to_expr(fe3)?;
358                Ok(rty::Expr::ite(e1, e2, e3))
359            }
360            fixpoint::Expr::And(fexprs) => {
361                let exprs: Vec<rty::Expr> = fexprs
362                    .iter()
363                    .map(|fexpr| self.fixpoint_to_expr(fexpr))
364                    .try_collect()?;
365                Ok(rty::Expr::and_from_iter(exprs))
366            }
367            fixpoint::Expr::Or(fexprs) => {
368                let exprs: Vec<rty::Expr> = fexprs
369                    .iter()
370                    .map(|fexpr| self.fixpoint_to_expr(fexpr))
371                    .try_collect()?;
372                Ok(rty::Expr::or_from_iter(exprs))
373            }
374            fixpoint::Expr::Not(fexpr) => {
375                let e = self.fixpoint_to_expr(fexpr)?;
376                Ok(rty::Expr::not(&e))
377            }
378            fixpoint::Expr::Imp(boxed_args) => {
379                let [fe1, fe2] = &**boxed_args;
380                let e1 = self.fixpoint_to_expr(fe1)?;
381                let e2 = self.fixpoint_to_expr(fe2)?;
382                Ok(rty::Expr::binary_op(rty::BinOp::Imp, e1, e2))
383            }
384            fixpoint::Expr::Iff(boxed_args) => {
385                let [fe1, fe2] = &**boxed_args;
386                let e1 = self.fixpoint_to_expr(fe1)?;
387                let e2 = self.fixpoint_to_expr(fe2)?;
388                Ok(rty::Expr::binary_op(rty::BinOp::Iff, e1, e2))
389            }
390            fixpoint::Expr::Atom(fbinrel, boxed_args) => {
391                let binrel = match fbinrel {
392                    fixpoint::BinRel::Eq => rty::BinOp::Eq,
393                    fixpoint::BinRel::Ne => rty::BinOp::Ne,
394                    // FIXME: (ck) faked sort information
395                    //
396                    // I'm pretty sure that it is OK to give `rty::Sort::Int`
397                    // because we only emit `fixpoint::BinRel::Gt`, etc. when we
398                    // have an Int/Real/Char sort (and further this sort info
399                    // isn't further used). But if we inspect this in other
400                    // places then things could break.
401                    fixpoint::BinRel::Gt => rty::BinOp::Gt(rty::Sort::Int),
402                    fixpoint::BinRel::Ge => rty::BinOp::Ge(rty::Sort::Int),
403                    fixpoint::BinRel::Lt => rty::BinOp::Lt(rty::Sort::Int),
404                    fixpoint::BinRel::Le => rty::BinOp::Le(rty::Sort::Int),
405                };
406                let [fe1, fe2] = &**boxed_args;
407                let e1 = self.fixpoint_to_expr(fe1)?;
408                let e2 = self.fixpoint_to_expr(fe2)?;
409                Ok(rty::Expr::binary_op(binrel, e1, e2))
410            }
411            fixpoint::Expr::Let(_var, _boxed_args) => {
412                // TODO: (ck) uncomment this and fix the missing code in the todo!()
413                //
414                // let [fe1, fe2] = &**boxed_args;
415                // let e1 = self.fixpoint_to_expr(fe1)?;
416                // let e2 = self.fixpoint_to_expr(fe2)?;
417                // let e2_binder =
418                todo!("Convert `var` in e2 to locally nameless var, then fill in sort");
419                // Ok(rty::Expr::let_(e1, e2_binder))
420            }
421            fixpoint::Expr::ThyFunc(itf) => Ok(rty::Expr::global_func(SpecFuncKind::Thy(*itf))),
422            fixpoint::Expr::IsCtor(var, fe) => {
423                let (def_id, variant_idx) = match var {
424                    fixpoint::Var::DataCtor(adt_id, variant_idx) => {
425                        let def_id = self.scx.adt_sorts[adt_id.as_usize()];
426                        Ok((def_id, *variant_idx))
427                    }
428                    _ => Err(FixpointParseError::WrongVarInIsCtor(*var)),
429                }?;
430                let e = self.fixpoint_to_expr(fe)?;
431                Ok(rty::Expr::is_ctor(def_id, variant_idx, e))
432            }
433            fixpoint::Expr::Quantifier(q, binder, body) => {
434                let expr = self.fixpoint_to_bind_expr(binder, body)?;
435                match q {
436                    fixpoint::Quantifier::Exists => Ok(rty::Expr::exists(expr)),
437                    fixpoint::Quantifier::Forall => Ok(rty::Expr::forall(expr)),
438                }
439            }
440        }
441    }
442
443    fn fixpoint_to_bind_expr(
444        &mut self,
445        binder: &[(fixpoint::Var, fixpoint::Sort)],
446        body: &fixpoint::Expr,
447    ) -> Result<rty::Binder<rty::Expr>, FixpointParseError> {
448        let mut vars = vec![];
449        let mut sorts = vec![];
450        for (var, sort) in binder {
451            let fixpoint::Var::Local(local_var) = var else {
452                return Err(FixpointParseError::WrongVarInBinder(*var));
453            };
454            vars.push(*local_var);
455            sorts.push(self.fixpoint_to_sort(sort)?);
456        }
457        self.ecx.local_var_env.push_layer(vars);
458        let body = self.fixpoint_to_expr(body)?;
459        self.ecx.local_var_env.pop_layer();
460        Ok(Binder::bind_with_sorts(body, &sorts))
461    }
462
463    fn fixpoint_app_to_expr(
464        &mut self,
465        fhead: &fixpoint::Expr,
466        fargs: &[fixpoint::Expr],
467    ) -> Result<rty::Expr, FixpointParseError> {
468        let head = self.fixpoint_to_expr(fhead)?;
469        let args = fargs
470            .iter()
471            .map(|farg| self.fixpoint_to_expr(farg))
472            .try_collect()?;
473        Ok(rty::Expr::app(head, List::empty(), args))
474    }
475}
476
477#[derive(Debug)]
478pub enum FixpointParseError {
479    /// UIFRels are encoded as Apps, but they are as of right now only binary
480    /// relations so they must have 2 arguments only.
481    UIFRelArityMismatch(usize),
482    /// Expected arity (based off of the tuple ctor), actual arity (the numer of args)
483    TupleCtorArityMismatch(usize, usize),
484    /// The number of arguments should only ever be 1 for a tuple proj
485    ProjArityMismatch(usize),
486    NoGlobalVar(fixpoint::GlobalVar),
487    /// Casts should only have 1 arg
488    CastArityMismatch(usize),
489    PrimOpArityMismatch(usize),
490    NoLocalVar(fixpoint::LocalVar),
491    /// Expecting fixpoint::Var::DataCtor
492    WrongVarInIsCtor(fixpoint::Var),
493    /// Expecting fixpoint::Var::LocalVar
494    WrongVarInBinder(fixpoint::Var),
495    UnknownAdt(DefId),
496}