// flux_infer/fixpoint_encoding/decoding.rs

1use flux_common::tracked_span_bug;
2use flux_middle::{
3    big_int::BigInt,
4    rty::{self, Binder, EarlyReftParam, InternalFuncKind, List, SpecFuncKind},
5};
6use flux_rustc_bridge::lowering::Lower;
7use itertools::Itertools;
8use rustc_hir::def_id::DefId;
9use rustc_type_ir::BoundVar;
10
11use super::{ConstKey, FixpointCtxt, fixpoint};
12use crate::fixpoint_encoding::FixpointSolution;
13
14impl<'genv, 'tcx, Tag> FixpointCtxt<'genv, 'tcx, Tag>
15where
16    Tag: std::hash::Hash + Eq + Copy,
17{
18    pub(crate) fn fixpoint_to_solution(
19        &mut self,
20        sol: &FixpointSolution,
21    ) -> rty::Binder<rty::Expr> {
22        let mut vars = vec![];
23        let mut sorts = vec![];
24        for (var, sort) in &sol.0 {
25            let fixpoint::Var::Local(local_var) = var else {
26                tracked_span_bug!("encountered non-local variable in binder: {var:?}");
27            };
28            vars.push(*local_var);
29            sorts.push(
30                self.fixpoint_to_sort(sort)
31                    .unwrap_or_else(|_| tracked_span_bug!("failed to parse sort: {sort:?}")),
32            );
33        }
34        self.ecx.local_var_env.push_layer(vars);
35        let expr = self
36            .fixpoint_to_expr(&sol.1)
37            .unwrap_or_else(|err| tracked_span_bug!("failed to convert expr: {err:?}"));
38        self.ecx.local_var_env.pop_layer();
39        rty::Binder::bind_with_sorts(expr, &sorts)
40    }
41
42    fn fixpoint_to_sort_ctor(
43        &self,
44        ctor: &fixpoint::SortCtor,
45    ) -> Result<rty::SortCtor, FixpointParseError> {
46        match ctor {
47            fixpoint::SortCtor::Set => Ok(rty::SortCtor::Set),
48            fixpoint::SortCtor::Map => Ok(rty::SortCtor::Map),
49            fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(_)) => {
50                panic!("oh no! tuple!") // Ok(rty::SortCtor::Tuple(*size))
51            }
52            fixpoint::SortCtor::Data(fixpoint::DataSort::User(def_id)) => {
53                Ok(rty::SortCtor::User(*def_id))
54            }
55            fixpoint::SortCtor::Data(fixpoint::DataSort::Adt(adt_id)) => {
56                let def_id = self.scx.adt_sorts[adt_id.as_usize()];
57                let Ok(adt_sort_def) = self.genv.adt_sort_def_of(def_id) else {
58                    return Err(FixpointParseError::UnknownAdt(def_id));
59                };
60                Ok(rty::SortCtor::Adt(adt_sort_def))
61            }
62        }
63    }
64
65    pub(crate) fn fixpoint_to_sort(
66        &self,
67        fsort: &fixpoint::Sort,
68    ) -> Result<rty::Sort, FixpointParseError> {
69        match fsort {
70            fixpoint::Sort::Int => Ok(rty::Sort::Int),
71            fixpoint::Sort::Real => Ok(rty::Sort::Real),
72            fixpoint::Sort::Bool => Ok(rty::Sort::Bool),
73            fixpoint::Sort::Str => Ok(rty::Sort::Str),
74            fixpoint::Sort::Func(sorts) => {
75                let sort1 = self.fixpoint_to_sort(&sorts[0])?;
76                let sort2 = self.fixpoint_to_sort(&sorts[1])?;
77                let fsort = rty::FuncSort::new(vec![sort1], sort2);
78                let poly_sort = rty::PolyFuncSort::new(List::empty(), fsort);
79                Ok(rty::Sort::Func(poly_sort))
80            }
81            fixpoint::Sort::App(ctor, args) => {
82                let ctor = self.fixpoint_to_sort_ctor(ctor)?;
83                let args = args
84                    .iter()
85                    .map(|fsort| self.fixpoint_to_sort(fsort))
86                    .try_collect()?;
87                Ok(rty::Sort::App(ctor, args))
88            }
89            fixpoint::Sort::BitVec(fsort) if let fixpoint::Sort::BvSize(size) = **fsort => {
90                Ok(rty::Sort::BitVec(rty::BvSize::Fixed(size)))
91            }
92            _ => unimplemented!("fixpoint_to_sort:  {fsort:?}"),
93        }
94    }
95
96    #[allow(dead_code)]
97    pub(crate) fn fixpoint_to_expr(
98        &mut self,
99        fexpr: &fixpoint::Expr,
100    ) -> Result<rty::Expr, FixpointParseError> {
101        match fexpr {
102            fixpoint::Expr::Constant(constant) => {
103                let c = match constant {
104                    fixpoint::Constant::Numeral(num) => rty::Constant::Int(BigInt::from(*num)),
105                    fixpoint::Constant::Real(dec) => rty::Constant::Real(rty::Real(*dec)),
106                    fixpoint::Constant::Boolean(b) => rty::Constant::Bool(*b),
107                    fixpoint::Constant::String(s) => rty::Constant::Str(s.0),
108                    fixpoint::Constant::BitVec(bv, size) => rty::Constant::BitVec(*bv, *size),
109                };
110                Ok(rty::Expr::constant(c))
111            }
112            fixpoint::Expr::Var(fvar) => {
113                match fvar {
114                    fixpoint::Var::Underscore => {
115                        unreachable!("Underscore should not appear in exprs")
116                    }
117                    fixpoint::Var::Global(global_var, _) | fixpoint::Var::Const(global_var, _) => {
118                        if let Some(const_key) = self.ecx.const_env.const_map_rev.get(global_var) {
119                            match const_key {
120                                ConstKey::RustConst(def_id) => Ok(rty::Expr::const_def_id(*def_id)),
121                                ConstKey::Alias(_flux_id, _args) => {
122                                    unreachable!("Should be special-cased as the head of an app")
123                                }
124                                ConstKey::Lambda(lambda) => Ok(rty::Expr::abs(lambda.clone())),
125                                ConstKey::PrimOp(bin_op) => {
126                                    Ok(rty::Expr::internal_func(InternalFuncKind::Rel(
127                                        bin_op.clone(),
128                                    )))
129                                }
130                                ConstKey::Cast(_sort, _sort1) => {
131                                    unreachable!(
132                                        "Should be specially handled as the head of a function app."
133                                    )
134                                }
135                            }
136                        } else {
137                            Err(FixpointParseError::NoGlobalVar(*global_var))
138                        }
139                    }
140                    fixpoint::Var::Local(fname) => {
141                        if let Some(expr) = self.ecx.local_var_env.reverse_map.get(fname) {
142                            return Ok(expr.clone());
143                        }
144
145                        for (depth, layer) in self.ecx.local_var_env.layers.iter().rev().enumerate()
146                        {
147                            for (idx, var) in layer.iter().enumerate() {
148                                if fname == var {
149                                    return Ok(rty::Expr::bvar(
150                                        rty::DebruijnIndex::from_usize(depth),
151                                        BoundVar::from_usize(idx),
152                                        rty::BoundReftKind::Anon,
153                                    ));
154                                }
155                            }
156                        }
157
158                        Err(FixpointParseError::NoLocalVar(*fname))
159                    }
160                    fixpoint::Var::DataCtor(adt_id, variant_idx) => {
161                        let def_id = self.scx.adt_sorts[adt_id.as_usize()];
162                        Ok(rty::Expr::ctor_enum(def_id, *variant_idx))
163                    }
164                    fixpoint::Var::TupleCtor { .. }
165                    | fixpoint::Var::TupleProj { .. }
166                    | fixpoint::Var::DataProj { .. }
167                    | fixpoint::Var::UIFRel(_) => {
168                        unreachable!(
169                            "Trying to convert an atomic var, but reached a var that should only occur as the head of an app (and be special-cased in conversion as a result)"
170                        )
171                    }
172                    fixpoint::Var::Param(EarlyReftParam { index, name }) => {
173                        Ok(rty::Expr::early_param(*index, *name))
174                    }
175                    fixpoint::Var::ConstGeneric(const_generic) => {
176                        Ok(rty::Expr::const_generic(*const_generic))
177                    }
178                }
179            }
180            fixpoint::Expr::App(fhead, _sort_args, fargs, _out_sort) => {
181                match &**fhead {
182                    fixpoint::Expr::Var(fixpoint::Var::TupleProj { arity, field }) => {
183                        if fargs.len() == 1 {
184                            let earg = self.fixpoint_to_expr(&fargs[0])?;
185                            Ok(rty::Expr::field_proj(
186                                earg,
187                                rty::FieldProj::Tuple { arity: *arity, field: *field },
188                            ))
189                        } else {
190                            Err(FixpointParseError::ProjArityMismatch(fargs.len()))
191                        }
192                    }
193                    fixpoint::Expr::Var(fixpoint::Var::DataProj { adt_id, field }) => {
194                        if fargs.len() == 1 {
195                            let earg = self.fixpoint_to_expr(&fargs[0])?;
196                            Ok(rty::Expr::field_proj(
197                                earg,
198                                rty::FieldProj::Adt {
199                                    def_id: self.scx.adt_sorts[adt_id.as_usize()],
200                                    field: *field,
201                                },
202                            ))
203                        } else {
204                            Err(FixpointParseError::ProjArityMismatch(fargs.len()))
205                        }
206                    }
207                    fixpoint::Expr::Var(fixpoint::Var::TupleCtor { arity }) => {
208                        if fargs.len() == *arity {
209                            let eargs = fargs
210                                .iter()
211                                .map(|farg| self.fixpoint_to_expr(farg))
212                                .try_collect()?;
213                            Ok(rty::Expr::tuple(eargs))
214                        } else {
215                            Err(FixpointParseError::TupleCtorArityMismatch(*arity, fargs.len()))
216                        }
217                    }
218                    fixpoint::Expr::Var(fixpoint::Var::UIFRel(fbinrel)) => {
219                        if fargs.len() == 2 {
220                            let e1 = self.fixpoint_to_expr(&fargs[0])?;
221                            let e2 = self.fixpoint_to_expr(&fargs[1])?;
222                            let binrel = match fbinrel {
223                                fixpoint::BinRel::Eq => rty::BinOp::Eq,
224                                fixpoint::BinRel::Ne => rty::BinOp::Ne,
225                                // FIXME: (ck) faked sort information
226                                //
227                                // This needs to be a sort that goes to the UIFRel
228                                // case in fixpoint conversion. Again, if we actually
229                                // need to inspect the sorts this will die unless the
230                                // arguments are actually Strs.
231                                fixpoint::BinRel::Gt => rty::BinOp::Gt(rty::Sort::Str),
232                                fixpoint::BinRel::Ge => rty::BinOp::Ge(rty::Sort::Str),
233                                fixpoint::BinRel::Lt => rty::BinOp::Lt(rty::Sort::Str),
234                                fixpoint::BinRel::Le => rty::BinOp::Le(rty::Sort::Str),
235                            };
236                            Ok(rty::Expr::binary_op(binrel, e1, e2))
237                        } else {
238                            Err(FixpointParseError::UIFRelArityMismatch(fargs.len()))
239                        }
240                    }
241                    fixpoint::Expr::Var(fixpoint::Var::Global(global_var, _))
242                    | fixpoint::Expr::Var(fixpoint::Var::Const(global_var, _)) => {
243                        if let Some(const_key) = self.ecx.const_env.const_map_rev.get(global_var) {
244                            match const_key {
245                                // NOTE: Only a few of these are meaningfully needed,
246                                // e.g. ConstKey::Alias because the rty Expr has its
247                                // args as a part of it.
248                                ConstKey::PrimOp(bin_op) => {
249                                    if fargs.len() != 2 {
250                                        Err(FixpointParseError::PrimOpArityMismatch(fargs.len()))
251                                    } else {
252                                        Ok(rty::Expr::prim_val(
253                                            bin_op.clone(),
254                                            self.fixpoint_to_expr(&fargs[0])?,
255                                            self.fixpoint_to_expr(&fargs[1])?,
256                                        ))
257                                    }
258                                }
259                                ConstKey::Cast(sort1, sort2) => {
260                                    if fargs.len() != 1 {
261                                        Err(FixpointParseError::CastArityMismatch(fargs.len()))
262                                    } else {
263                                        Ok(rty::Expr::cast(
264                                            sort1.clone(),
265                                            sort2.clone(),
266                                            self.fixpoint_to_expr(&fargs[0])?,
267                                        ))
268                                    }
269                                }
270                                ConstKey::Alias(assoc_id, generic_args) => {
271                                    let lowered_args: flux_rustc_bridge::ty::GenericArgs =
272                                        generic_args.lower(self.genv.tcx()).unwrap();
273                                    let generic_args = rty::refining::Refiner::default_for_item(
274                                        self.genv,
275                                        assoc_id.parent(),
276                                    )
277                                    .unwrap()
278                                    .refine_generic_args(assoc_id.parent(), &lowered_args)
279                                    .unwrap();
280                                    let alias_reft =
281                                        rty::AliasReft { assoc_id: *assoc_id, args: generic_args };
282                                    let args = fargs
283                                        .iter()
284                                        .map(|farg| self.fixpoint_to_expr(farg))
285                                        .try_collect()?;
286                                    Ok(rty::Expr::alias(alias_reft, args))
287                                }
288                                ConstKey::RustConst(..) | ConstKey::Lambda(..) => {
289                                    // These should be treated as a normal app.
290                                    self.fixpoint_app_to_expr(fhead, fargs)
291                                }
292                            }
293                        } else {
294                            Err(FixpointParseError::NoGlobalVar(*global_var))
295                        }
296                    }
297                    fhead => self.fixpoint_app_to_expr(fhead, fargs),
298                }
299            }
300            fixpoint::Expr::Neg(fexpr) => {
301                let e = self.fixpoint_to_expr(fexpr)?;
302                Ok(rty::Expr::neg(&e))
303            }
304            fixpoint::Expr::BinaryOp(fbinop, boxed_args) => {
305                let binop = match fbinop {
306                    // FIXME: (ck) faked sort information
307                    //
308                    // See what we do for binrel for an explanation.
309                    fixpoint::BinOp::Add => rty::BinOp::Add(rty::Sort::Int),
310                    fixpoint::BinOp::Sub => rty::BinOp::Sub(rty::Sort::Int),
311                    fixpoint::BinOp::Mul => rty::BinOp::Mul(rty::Sort::Int),
312                    fixpoint::BinOp::Div => rty::BinOp::Div(rty::Sort::Int),
313                    fixpoint::BinOp::Mod => rty::BinOp::Mod(rty::Sort::Int),
314                };
315                let [fe1, fe2] = &**boxed_args;
316                let e1 = self.fixpoint_to_expr(fe1)?;
317                let e2 = self.fixpoint_to_expr(fe2)?;
318                Ok(rty::Expr::binary_op(binop, e1, e2))
319            }
320            fixpoint::Expr::IfThenElse(boxed_args) => {
321                let [fe1, fe2, fe3] = &**boxed_args;
322                let e1 = self.fixpoint_to_expr(fe1)?;
323                let e2 = self.fixpoint_to_expr(fe2)?;
324                let e3 = self.fixpoint_to_expr(fe3)?;
325                Ok(rty::Expr::ite(e1, e2, e3))
326            }
327            fixpoint::Expr::And(fexprs) => {
328                let exprs: Vec<rty::Expr> = fexprs
329                    .iter()
330                    .map(|fexpr| self.fixpoint_to_expr(fexpr))
331                    .try_collect()?;
332                Ok(rty::Expr::and_from_iter(exprs))
333            }
334            fixpoint::Expr::Or(fexprs) => {
335                let exprs: Vec<rty::Expr> = fexprs
336                    .iter()
337                    .map(|fexpr| self.fixpoint_to_expr(fexpr))
338                    .try_collect()?;
339                Ok(rty::Expr::or_from_iter(exprs))
340            }
341            fixpoint::Expr::Not(fexpr) => {
342                let e = self.fixpoint_to_expr(fexpr)?;
343                Ok(rty::Expr::not(&e))
344            }
345            fixpoint::Expr::Imp(boxed_args) => {
346                let [fe1, fe2] = &**boxed_args;
347                let e1 = self.fixpoint_to_expr(fe1)?;
348                let e2 = self.fixpoint_to_expr(fe2)?;
349                Ok(rty::Expr::binary_op(rty::BinOp::Imp, e1, e2))
350            }
351            fixpoint::Expr::Iff(boxed_args) => {
352                let [fe1, fe2] = &**boxed_args;
353                let e1 = self.fixpoint_to_expr(fe1)?;
354                let e2 = self.fixpoint_to_expr(fe2)?;
355                Ok(rty::Expr::binary_op(rty::BinOp::Iff, e1, e2))
356            }
357            fixpoint::Expr::Atom(fbinrel, boxed_args) => {
358                let binrel = match fbinrel {
359                    fixpoint::BinRel::Eq => rty::BinOp::Eq,
360                    fixpoint::BinRel::Ne => rty::BinOp::Ne,
361                    // FIXME: (ck) faked sort information
362                    //
363                    // I'm pretty sure that it is OK to give `rty::Sort::Int`
364                    // because we only emit `fixpoint::BinRel::Gt`, etc. when we
365                    // have an Int/Real/Char sort (and further this sort info
366                    // isn't further used). But if we inspect this in other
367                    // places then things could break.
368                    fixpoint::BinRel::Gt => rty::BinOp::Gt(rty::Sort::Int),
369                    fixpoint::BinRel::Ge => rty::BinOp::Ge(rty::Sort::Int),
370                    fixpoint::BinRel::Lt => rty::BinOp::Lt(rty::Sort::Int),
371                    fixpoint::BinRel::Le => rty::BinOp::Le(rty::Sort::Int),
372                };
373                let [fe1, fe2] = &**boxed_args;
374                let e1 = self.fixpoint_to_expr(fe1)?;
375                let e2 = self.fixpoint_to_expr(fe2)?;
376                Ok(rty::Expr::binary_op(binrel, e1, e2))
377            }
378            fixpoint::Expr::Let(_var, _boxed_args) => {
379                // TODO: (ck) uncomment this and fix the missing code in the todo!()
380                //
381                // let [fe1, fe2] = &**boxed_args;
382                // let e1 = self.fixpoint_to_expr(fe1)?;
383                // let e2 = self.fixpoint_to_expr(fe2)?;
384                // let e2_binder =
385                todo!("Convert `var` in e2 to locally nameless var, then fill in sort");
386                // Ok(rty::Expr::let_(e1, e2_binder))
387            }
388            fixpoint::Expr::ThyFunc(itf) => Ok(rty::Expr::global_func(SpecFuncKind::Thy(*itf))),
389            fixpoint::Expr::IsCtor(var, fe) => {
390                let (def_id, variant_idx) = match var {
391                    fixpoint::Var::DataCtor(adt_id, variant_idx) => {
392                        let def_id = self.scx.adt_sorts[adt_id.as_usize()];
393                        Ok((def_id, *variant_idx))
394                    }
395                    _ => Err(FixpointParseError::WrongVarInIsCtor(*var)),
396                }?;
397                let e = self.fixpoint_to_expr(fe)?;
398                Ok(rty::Expr::is_ctor(def_id, variant_idx, e))
399            }
400            fixpoint::Expr::Exists(binder, body) => {
401                let mut vars = vec![];
402                let mut sorts = vec![];
403                for (var, sort) in binder {
404                    let fixpoint::Var::Local(local_var) = var else {
405                        return Err(FixpointParseError::WrongVarInBinder(*var));
406                    };
407                    vars.push(*local_var);
408                    sorts.push(self.fixpoint_to_sort(sort)?);
409                }
410                self.ecx.local_var_env.push_layer(vars);
411                let body = self.fixpoint_to_expr(body)?;
412                self.ecx.local_var_env.pop_layer();
413                Ok(rty::Expr::exists(Binder::bind_with_sorts(body, &sorts)))
414            }
415        }
416    }
417
418    fn fixpoint_app_to_expr(
419        &mut self,
420        fhead: &fixpoint::Expr,
421        fargs: &[fixpoint::Expr],
422    ) -> Result<rty::Expr, FixpointParseError> {
423        let head = self.fixpoint_to_expr(fhead)?;
424        let args = fargs
425            .iter()
426            .map(|farg| self.fixpoint_to_expr(farg))
427            .try_collect()?;
428        Ok(rty::Expr::app(head, List::empty(), args))
429    }
430}
431
432#[derive(Debug)]
433pub enum FixpointParseError {
434    /// UIFRels are encoded as Apps, but they are as of right now only binary
435    /// relations so they must have 2 arguments only.
436    UIFRelArityMismatch(usize),
437    /// Expected arity (based off of the tuple ctor), actual arity (the numer of args)
438    TupleCtorArityMismatch(usize, usize),
439    /// The number of arguments should only ever be 1 for a tuple proj
440    ProjArityMismatch(usize),
441    NoGlobalVar(fixpoint::GlobalVar),
442    /// Casts should only have 1 arg
443    CastArityMismatch(usize),
444    PrimOpArityMismatch(usize),
445    NoLocalVar(fixpoint::LocalVar),
446    /// Expecting fixpoint::Var::DataCtor
447    WrongVarInIsCtor(fixpoint::Var),
448    /// Expecting fixpoint::Var::LocalVar
449    WrongVarInBinder(fixpoint::Var),
450    UnknownAdt(DefId),
451}