flux_infer/
fixpoint_encoding.rs

1//! Encoding of the refinement tree into a fixpoint constraint.
2
3use std::{hash::Hash, iter};
4
5use fixpoint::AdtId;
6use flux_common::{
7    bug,
8    cache::QueryCache,
9    dbg,
10    index::{IndexGen, IndexVec},
11    iter::IterExt,
12    span_bug, tracked_span_bug,
13};
14use flux_config::{self as config};
15use flux_errors::Errors;
16use flux_middle::{
17    FixpointQueryKind,
18    def_id::{FluxDefId, MaybeExternId},
19    def_id_to_string,
20    fhir::SpecFuncKind,
21    global_env::GlobalEnv,
22    queries::QueryResult,
23    rty::{self, ESpan, GenericArgsExt, Lambda, List, VariantIdx},
24    timings::{self, TimingKind},
25};
26use itertools::Itertools;
27use liquid_fixpoint::{FixpointResult, SmtSolver};
28use rustc_data_structures::{
29    fx::{FxIndexMap, FxIndexSet},
30    unord::{UnordMap, UnordSet},
31};
32use rustc_hir::def_id::{DefId, LocalDefId};
33use rustc_index::newtype_index;
34use rustc_span::Span;
35use rustc_type_ir::{BoundVar, DebruijnIndex};
36use serde::{Deserialize, Deserializer, Serialize};
37
38pub mod fixpoint {
39    use std::fmt;
40
41    use flux_middle::rty::{EarlyReftParam, Real};
42    use liquid_fixpoint::{FixpointFmt, Identifier};
43    use rustc_abi::VariantIdx;
44    use rustc_index::newtype_index;
45    use rustc_middle::ty::ParamConst;
46    use rustc_span::Symbol;
47
48    newtype_index! {
49        pub struct KVid {}
50    }
51
52    impl Identifier for KVid {
53        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54            write!(f, "k{}", self.as_u32())
55        }
56    }
57
58    newtype_index! {
59        pub struct LocalVar {}
60    }
61
62    newtype_index! {
63        pub struct GlobalVar {}
64    }
65
66    newtype_index! {
67        /// Unique id assigned to each [`flux_middle::rty::AdtSortDef`] that needs to be encoded
68        /// into fixpoint
69        pub struct AdtId {}
70    }
71
72    #[derive(Hash, Copy, Clone)]
73    pub enum Var {
74        Underscore,
75        Global(GlobalVar, Option<Symbol>),
76        Local(LocalVar),
77        DataCtor(AdtId, VariantIdx),
78        TupleCtor {
79            arity: usize,
80        },
81        TupleProj {
82            arity: usize,
83            field: u32,
84        },
85        UIFRel(BinRel),
86        /// Interpreted theory function. This can be an arbitrary string, thus we are assuming the
87        /// name is different than the display implementation for the other variants.
88        Itf(liquid_fixpoint::ThyFunc),
89        Param(EarlyReftParam),
90        ConstGeneric(ParamConst),
91    }
92
93    impl From<GlobalVar> for Var {
94        fn from(v: GlobalVar) -> Self {
95            Self::Global(v, None)
96        }
97    }
98
99    impl From<LocalVar> for Var {
100        fn from(v: LocalVar) -> Self {
101            Self::Local(v)
102        }
103    }
104
105    impl Identifier for Var {
106        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107            match self {
108                Var::Global(v, None) => write!(f, "c{}", v.as_u32()),
109                Var::Global(v, Some(sym)) => write!(f, "f${}${}", sym, v.as_u32()),
110                Var::Local(v) => write!(f, "a{}", v.as_u32()),
111                Var::DataCtor(adt_id, variant_idx) => {
112                    write!(f, "mkadt{}${}", adt_id.as_u32(), variant_idx.as_u32())
113                }
114                Var::TupleCtor { arity } => write!(f, "mktuple{arity}"),
115                Var::TupleProj { arity, field } => write!(f, "tuple{arity}${field}"),
116                Var::Itf(name) => write!(f, "{name}"),
117                Var::UIFRel(BinRel::Gt) => write!(f, "gt"),
118                Var::UIFRel(BinRel::Ge) => write!(f, "ge"),
119                Var::UIFRel(BinRel::Lt) => write!(f, "lt"),
120                Var::UIFRel(BinRel::Le) => write!(f, "le"),
121                // these are actually not necessary because equality is interpreted for all sorts
122                Var::UIFRel(BinRel::Eq) => write!(f, "eq"),
123                Var::UIFRel(BinRel::Ne) => write!(f, "ne"),
124                Var::Underscore => write!(f, "_$"), // To avoid clashing with `_` used for `app (_ bv_op n)` for parametric SMT ops
125                Var::ConstGeneric(param) => {
126                    write!(f, "constgen${}${}", param.name, param.index)
127                }
128                Var::Param(param) => {
129                    write!(f, "reftgen${}${}", param.name, param.index)
130                }
131            }
132        }
133    }
134
135    #[derive(Clone, Hash)]
136    pub enum DataSort {
137        Tuple(usize),
138        Adt(AdtId),
139    }
140
141    impl Identifier for DataSort {
142        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143            match self {
144                DataSort::Tuple(arity) => {
145                    write!(f, "Tuple{arity}")
146                }
147                DataSort::Adt(adt_id) => {
148                    write!(f, "Adt{}", adt_id.as_u32())
149                }
150            }
151        }
152    }
153
154    #[derive(Hash)]
155    pub struct SymStr(pub Symbol);
156
157    impl FixpointFmt for SymStr {
158        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159            write!(f, "\"{}\"", self.0)
160        }
161    }
162
163    liquid_fixpoint::declare_types! {
164        type Sort = DataSort;
165        type KVar = KVid;
166        type Var = Var;
167        type Decimal = Real;
168        type String = SymStr;
169        type Tag = super::TagIdx;
170    }
171    pub use fixpoint_generated::*;
172}
173
174newtype_index! {
175    #[debug_format = "TagIdx({})"]
176    pub struct TagIdx {}
177}
178
179impl Serialize for TagIdx {
180    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
181        self.as_u32().serialize(serializer)
182    }
183}
184
185impl<'de> Deserialize<'de> for TagIdx {
186    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
187        let idx = usize::deserialize(deserializer)?;
188        Ok(TagIdx::from_u32(idx as u32))
189    }
190}
191
192/// Keep track of all the data sorts that we need to define in fixpoint to encode the constraint.
193#[derive(Default)]
194struct SortEncodingCtxt {
195    /// Set of all the tuple arities that need to be defined
196    tuples: UnordSet<usize>,
197    /// Set of all the [`AdtDefSortDef`](flux_middle::rty::AdtSortDef) that need to be declared as
198    /// Fixpoint data-decls
199    adt_sorts: FxIndexSet<DefId>,
200}
201
202impl SortEncodingCtxt {
203    fn sort_to_fixpoint(&mut self, sort: &rty::Sort) -> fixpoint::Sort {
204        match sort {
205            rty::Sort::Int => fixpoint::Sort::Int,
206            rty::Sort::Real => fixpoint::Sort::Real,
207            rty::Sort::Bool => fixpoint::Sort::Bool,
208            rty::Sort::Str => fixpoint::Sort::Str,
209            rty::Sort::Char => fixpoint::Sort::Int,
210            rty::Sort::BitVec(size) => fixpoint::Sort::BitVec(Box::new(bv_size_to_fixpoint(*size))),
211            // There's no way to declare opaque sorts in the fixpoint horn syntax so we encode user
212            // declared opaque sorts, type parameter sorts, and (unormalizable) type alias sorts as
213            // integers. Well-formedness should ensure values of these sorts are used "opaquely",
214            // i.e., the only values of these sorts are variables.
215            rty::Sort::App(rty::SortCtor::User { .. }, _)
216            | rty::Sort::Param(_)
217            | rty::Sort::Alias(rty::AliasKind::Opaque | rty::AliasKind::Projection, ..) => {
218                fixpoint::Sort::Int
219            }
220            rty::Sort::App(rty::SortCtor::Set, args) => {
221                let args = args.iter().map(|s| self.sort_to_fixpoint(s)).collect_vec();
222                fixpoint::Sort::App(fixpoint::SortCtor::Set, args)
223            }
224            rty::Sort::App(rty::SortCtor::Map, args) => {
225                let args = args.iter().map(|s| self.sort_to_fixpoint(s)).collect_vec();
226                fixpoint::Sort::App(fixpoint::SortCtor::Map, args)
227            }
228            rty::Sort::App(rty::SortCtor::Adt(sort_def), args) => {
229                if let Some(variant) = sort_def.opt_struct_variant() {
230                    let sorts = variant.field_sorts(args);
231                    // do not generate 1-tuples
232                    if let [sort] = &sorts[..] {
233                        self.sort_to_fixpoint(sort)
234                    } else {
235                        self.declare_tuple(sorts.len());
236                        let ctor = fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(sorts.len()));
237                        let args = sorts.iter().map(|s| self.sort_to_fixpoint(s)).collect_vec();
238                        fixpoint::Sort::App(ctor, args)
239                    }
240                } else {
241                    debug_assert!(args.is_empty());
242                    let adt_id = self.declare_adt(sort_def.did());
243                    fixpoint::Sort::App(
244                        fixpoint::SortCtor::Data(fixpoint::DataSort::Adt(adt_id)),
245                        vec![],
246                    )
247                }
248            }
249            rty::Sort::Tuple(sorts) => {
250                // do not generate 1-tuples
251                if let [sort] = &sorts[..] {
252                    self.sort_to_fixpoint(sort)
253                } else {
254                    self.declare_tuple(sorts.len());
255                    let ctor = fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(sorts.len()));
256                    let args = sorts.iter().map(|s| self.sort_to_fixpoint(s)).collect();
257                    fixpoint::Sort::App(ctor, args)
258                }
259            }
260            rty::Sort::Func(sort) => self.func_sort_to_fixpoint(sort),
261            rty::Sort::Var(k) => fixpoint::Sort::Var(k.index()),
262            rty::Sort::Err
263            | rty::Sort::Infer(_)
264            | rty::Sort::Loc
265            | rty::Sort::Alias(rty::AliasKind::Weak, _) => {
266                tracked_span_bug!("unexpected sort `{sort:?}`")
267            }
268        }
269    }
270
271    fn func_sort_to_fixpoint(&mut self, fsort: &rty::PolyFuncSort) -> fixpoint::Sort {
272        let params = fsort.params().len();
273        let fsort = fsort.skip_binders();
274        let output = self.sort_to_fixpoint(fsort.output());
275        fixpoint::Sort::mk_func(
276            params,
277            fsort.inputs().iter().map(|s| self.sort_to_fixpoint(s)),
278            output,
279        )
280    }
281
282    fn declare_tuple(&mut self, arity: usize) {
283        self.tuples.insert(arity);
284    }
285
286    pub fn declare_adt(&mut self, did: DefId) -> AdtId {
287        if let Some(idx) = self.adt_sorts.get_index_of(&did) {
288            AdtId::from_usize(idx)
289        } else {
290            let adt_id = AdtId::from_usize(self.adt_sorts.len());
291            self.adt_sorts.insert(did);
292            adt_id
293        }
294    }
295
296    fn append_adt_decls(
297        genv: GlobalEnv,
298        adt_sorts: FxIndexSet<DefId>,
299        decls: &mut Vec<fixpoint::DataDecl>,
300    ) -> QueryResult {
301        for (idx, adt_def_id) in adt_sorts.iter().enumerate() {
302            let adt_id = AdtId::from_usize(idx);
303            let adt_sort_def = genv.adt_sort_def_of(adt_def_id)?;
304            decls.push(fixpoint::DataDecl {
305                name: fixpoint::DataSort::Adt(adt_id),
306                vars: adt_sort_def.param_count(),
307                ctors: adt_sort_def
308                    .variants()
309                    .iter_enumerated()
310                    .map(|(idx, variant)| {
311                        debug_assert_eq!(variant.fields(), 0);
312                        fixpoint::DataCtor {
313                            name: fixpoint::Var::DataCtor(adt_id, idx),
314                            fields: vec![],
315                        }
316                    })
317                    .collect(),
318            });
319        }
320        Ok(())
321    }
322
323    fn append_tuple_decls(tuples: UnordSet<usize>, decls: &mut Vec<fixpoint::DataDecl>) {
324        decls.extend(
325            tuples
326                .into_items()
327                .into_sorted_stable_ord()
328                .into_iter()
329                .map(|arity| {
330                    fixpoint::DataDecl {
331                        name: fixpoint::DataSort::Tuple(arity),
332                        vars: arity,
333                        ctors: vec![fixpoint::DataCtor {
334                            name: fixpoint::Var::TupleCtor { arity },
335                            fields: (0..(arity as u32))
336                                .map(|field| {
337                                    fixpoint::DataField {
338                                        name: fixpoint::Var::TupleProj { arity, field },
339                                        sort: fixpoint::Sort::Var(field as usize),
340                                    }
341                                })
342                                .collect(),
343                        }],
344                    }
345                }),
346        );
347    }
348
349    fn into_data_decls(self, genv: GlobalEnv) -> QueryResult<Vec<fixpoint::DataDecl>> {
350        let mut decls = vec![];
351        Self::append_tuple_decls(self.tuples, &mut decls);
352        Self::append_adt_decls(genv, self.adt_sorts, &mut decls)?;
353        Ok(decls)
354    }
355}
356
357fn bv_size_to_fixpoint(size: rty::BvSize) -> fixpoint::Sort {
358    match size {
359        rty::BvSize::Fixed(size) => fixpoint::Sort::BvSize(size),
360        rty::BvSize::Param(_var) => {
361            // I think we could encode the size as a sort variable, but this would require some care
362            // because smtlib doesn't really support parametric sizes. Fixpoint is probably already
363            // too liberal about this and it'd be easy to make it crash.
364            // fixpoint::Sort::Var(var.index)
365            bug!("unexpected parametric bit-vector size")
366        }
367        rty::BvSize::Infer(_) => bug!("unexpected infer variable for bit-vector size"),
368    }
369}
370
/// Map from a [`FluxDefId`] to a [`fixpoint::Var`] (presumably the name under which the definition
/// is referred to in fixpoint — NOTE(review): confirm against its users).
type FunDefMap = FxIndexMap<FluxDefId, fixpoint::Var>;
/// Constant declarations collected during encoding, keyed by what each constant stands for. The
/// values end up among the `constants` of the fixpoint task.
type ConstMap<'tcx> = FxIndexMap<ConstKey<'tcx>, fixpoint::ConstDecl>;

/// Key identifying the entity a fixpoint constant declaration was generated for.
#[derive(Eq, Hash, PartialEq)]
enum ConstKey<'tcx> {
    Uif(FluxDefId),
    RustConst(DefId),
    Alias(FluxDefId, rustc_middle::ty::GenericArgsRef<'tcx>),
    Lambda(Lambda),
}
381
/// State used to encode a refinement tree into a fixpoint constraint and to check it.
pub struct FixpointCtxt<'genv, 'tcx, T: Eq + Hash> {
    /// Comments included in the fixpoint task (e.g., one line describing each tag).
    comments: Vec<String>,
    genv: GlobalEnv<'genv, 'tcx>,
    /// Declarations of the flux kvars being encoded.
    kvars: KVarGen,
    /// Tracks the data sorts that need to be declared in fixpoint.
    scx: SortEncodingCtxt,
    /// Tracks the fixpoint kvars generated for each flux kvar.
    kcx: KVarEncodingCtxt,
    /// Context used to encode expressions.
    ecx: ExprEncodingCtxt<'genv, 'tcx>,
    /// Tags, indexed by the [`TagIdx`] attached to the constraints sent to fixpoint.
    tags: IndexVec<TagIdx, T>,
    /// Inverse of `tags`, used to reuse the index of a tag that was already seen.
    tags_inv: UnordMap<T, TagIdx>,
    /// Id of the item being checked. This is a [`MaybeExternId`] because we can be checking invariants for
    /// an extern spec on an enum.
    def_id: MaybeExternId,
}

/// Cache of fixpoint results, looked up by query key and task hash.
pub type FixQueryCache = QueryCache<FixpointResult<TagIdx>>;
397
398impl<'genv, 'tcx, Tag> FixpointCtxt<'genv, 'tcx, Tag>
399where
400    Tag: std::hash::Hash + Eq + Copy,
401{
402    pub fn new(genv: GlobalEnv<'genv, 'tcx>, def_id: MaybeExternId, kvars: KVarGen) -> Self {
403        let def_span = genv.tcx().def_span(def_id);
404        Self {
405            comments: vec![],
406            kvars,
407            scx: SortEncodingCtxt::default(),
408            genv,
409            ecx: ExprEncodingCtxt::new(genv, def_span),
410            kcx: Default::default(),
411            tags: IndexVec::new(),
412            tags_inv: Default::default(),
413            def_id,
414        }
415    }
416
417    pub fn check(
418        mut self,
419        cache: &mut FixQueryCache,
420        constraint: fixpoint::Constraint,
421        kind: FixpointQueryKind,
422        scrape_quals: bool,
423        solver: SmtSolver,
424    ) -> QueryResult<Vec<Tag>> {
425        // skip checking trivial constraints
426        if !constraint.is_concrete() {
427            self.ecx.errors.into_result()?;
428            return Ok(vec![]);
429        }
430        let def_span = self.def_span();
431        let def_id = self.def_id;
432
433        let kvars = self.kcx.into_fixpoint();
434
435        let (define_funs, define_constants) = self.ecx.define_funs(def_id, &mut self.scx)?;
436        let qualifiers = self.ecx.qualifiers_for(def_id.local_id(), &mut self.scx)?;
437
438        // Assuming values should happen after all encoding is done so we are sure we've collected
439        // all constants.
440        let constraint = self.ecx.assume_const_values(constraint, &mut self.scx)?;
441
442        let mut constants = self.ecx.const_map.into_values().collect_vec();
443        constants.extend(define_constants);
444
445        for rel in fixpoint::BinRel::INEQUALITIES {
446            // ∀a. a -> a -> bool
447            let sort = fixpoint::Sort::mk_func(
448                1,
449                [fixpoint::Sort::Var(0), fixpoint::Sort::Var(0)],
450                fixpoint::Sort::Bool,
451            );
452            constants.push(fixpoint::ConstDecl {
453                name: fixpoint::Var::UIFRel(rel),
454                sort,
455                comment: None,
456            });
457        }
458
459        // We are done encoding expressions. Check if there are any errors.
460        self.ecx.errors.into_result()?;
461
462        let task = fixpoint::Task {
463            comments: self.comments,
464            constants,
465            kvars,
466            define_funs,
467            constraint,
468            qualifiers,
469            scrape_quals,
470            solver,
471            data_decls: self.scx.into_data_decls(self.genv)?,
472        };
473        if config::dump_constraint() {
474            dbg::dump_item_info(self.genv.tcx(), self.def_id.resolved_id(), "smt2", &task).unwrap();
475        }
476
477        match Self::run_task_with_cache(self.genv, task, self.def_id.resolved_id(), kind, cache) {
478            FixpointResult::Safe(_) => Ok(vec![]),
479            FixpointResult::Unsafe(_, errors) => {
480                Ok(errors
481                    .into_iter()
482                    .map(|err| self.tags[err.tag])
483                    .unique()
484                    .collect_vec())
485            }
486            FixpointResult::Crash(err) => span_bug!(def_span, "fixpoint crash: {err:?}"),
487        }
488    }
489
490    fn run_task_with_cache(
491        genv: GlobalEnv,
492        task: fixpoint::Task,
493        def_id: DefId,
494        kind: FixpointQueryKind,
495        cache: &mut FixQueryCache,
496    ) -> FixpointResult<TagIdx> {
497        let key = kind.task_key(genv.tcx(), def_id);
498
499        let hash = task.hash_with_default();
500
501        if config::is_cache_enabled()
502            && let Some(result) = cache.lookup(&key, hash)
503        {
504            return result.clone();
505        }
506        let result = timings::time_it(TimingKind::FixpointQuery(def_id, kind), || {
507            task.run()
508                .unwrap_or_else(|err| tracked_span_bug!("failed to run fixpoint: {err}"))
509        });
510
511        if config::is_cache_enabled() {
512            cache.insert(key, hash, result.clone());
513        }
514        result
515    }
516
517    fn tag_idx(&mut self, tag: Tag) -> TagIdx
518    where
519        Tag: std::fmt::Debug,
520    {
521        *self.tags_inv.entry(tag).or_insert_with(|| {
522            let idx = self.tags.push(tag);
523            self.comments.push(format!("Tag {idx}: {tag:?}"));
524            idx
525        })
526    }
527
528    pub(crate) fn with_name_map<R>(
529        &mut self,
530        name: rty::Name,
531        f: impl FnOnce(&mut Self, fixpoint::LocalVar) -> R,
532    ) -> R {
533        let fresh = self.ecx.local_var_env.insert_fvar_map(name);
534        let r = f(self, fresh);
535        self.ecx.local_var_env.remove_fvar_map(name);
536        r
537    }
538
539    pub(crate) fn sort_to_fixpoint(&mut self, sort: &rty::Sort) -> fixpoint::Sort {
540        self.scx.sort_to_fixpoint(sort)
541    }
542
543    pub(crate) fn var_to_fixpoint(&self, var: &rty::Var) -> fixpoint::Var {
544        self.ecx.var_to_fixpoint(var)
545    }
546
547    /// Encodes an expression in head position as a [`fixpoint::Constraint`] "peeling out"
548    /// implications and foralls.
549    ///
550    /// [`fixpoint::Constraint`]: liquid_fixpoint::Constraint
551    pub(crate) fn head_to_fixpoint(
552        &mut self,
553        expr: &rty::Expr,
554        mk_tag: impl Fn(Option<ESpan>) -> Tag + Copy,
555    ) -> QueryResult<fixpoint::Constraint>
556    where
557        Tag: std::fmt::Debug,
558    {
559        match expr.kind() {
560            rty::ExprKind::BinaryOp(rty::BinOp::And, ..) => {
561                // avoid creating nested conjunctions
562                let cstrs = expr
563                    .flatten_conjs()
564                    .into_iter()
565                    .map(|e| self.head_to_fixpoint(e, mk_tag))
566                    .try_collect()?;
567                Ok(fixpoint::Constraint::conj(cstrs))
568            }
569            rty::ExprKind::BinaryOp(rty::BinOp::Imp, e1, e2) => {
570                let (bindings, assumption) = self.assumption_to_fixpoint(e1)?;
571                let cstr = self.head_to_fixpoint(e2, mk_tag)?;
572                Ok(fixpoint::Constraint::foralls(bindings, mk_implies(assumption, cstr)))
573            }
574            rty::ExprKind::KVar(kvar) => {
575                let mut bindings = vec![];
576                let pred = self.kvar_to_fixpoint(kvar, &mut bindings)?;
577                Ok(fixpoint::Constraint::foralls(bindings, fixpoint::Constraint::Pred(pred, None)))
578            }
579            rty::ExprKind::ForAll(pred) => {
580                self.ecx
581                    .local_var_env
582                    .push_layer_with_fresh_names(pred.vars().len());
583                let cstr = self.head_to_fixpoint(pred.as_ref().skip_binder(), mk_tag)?;
584                let vars = self.ecx.local_var_env.pop_layer();
585
586                let bindings = iter::zip(vars, pred.vars())
587                    .map(|(var, kind)| {
588                        fixpoint::Bind {
589                            name: var.into(),
590                            sort: self.scx.sort_to_fixpoint(kind.expect_sort()),
591                            pred: fixpoint::Pred::TRUE,
592                        }
593                    })
594                    .collect_vec();
595
596                Ok(fixpoint::Constraint::foralls(bindings, cstr))
597            }
598            _ => {
599                let tag_idx = self.tag_idx(mk_tag(expr.span()));
600                let pred = fixpoint::Pred::Expr(self.ecx.expr_to_fixpoint(expr, &mut self.scx)?);
601                Ok(fixpoint::Constraint::Pred(pred, Some(tag_idx)))
602            }
603        }
604    }
605
606    /// Encodes an expression in assumptive position as a [`fixpoint::Pred`]. Returns the encoded
607    /// predicate and a list of bindings produced by ANF-ing kvars.
608    ///
609    /// [`fixpoint::Pred`]: liquid_fixpoint::Pred
610    pub(crate) fn assumption_to_fixpoint(
611        &mut self,
612        pred: &rty::Expr,
613    ) -> QueryResult<(Vec<fixpoint::Bind>, fixpoint::Pred)> {
614        let mut bindings = vec![];
615        let mut preds = vec![];
616        self.assumption_to_fixpoint_aux(pred, &mut bindings, &mut preds)?;
617        Ok((bindings, fixpoint::Pred::and(preds)))
618    }
619
620    /// Auxiliary function to merge nested conjunctions in a single predicate
621    fn assumption_to_fixpoint_aux(
622        &mut self,
623        expr: &rty::Expr,
624        bindings: &mut Vec<fixpoint::Bind>,
625        preds: &mut Vec<fixpoint::Pred>,
626    ) -> QueryResult {
627        match expr.kind() {
628            rty::ExprKind::BinaryOp(rty::BinOp::And, e1, e2) => {
629                self.assumption_to_fixpoint_aux(e1, bindings, preds)?;
630                self.assumption_to_fixpoint_aux(e2, bindings, preds)?;
631            }
632            rty::ExprKind::KVar(kvar) => {
633                preds.push(self.kvar_to_fixpoint(kvar, bindings)?);
634            }
635            rty::ExprKind::ForAll(_) => {
636                // If a forall appears in assumptive position replace it with true. This is sound
637                // because we are weakening the context, i.e., anything true without the assumption
638                // should remain true after adding it. Note that this relies on the predicate
639                // appearing negatively. This is guaranteed by the surface syntax because foralls
640                // can only appear at the top-level in a requires clause.
641                preds.push(fixpoint::Pred::TRUE);
642            }
643            _ => {
644                preds.push(fixpoint::Pred::Expr(self.ecx.expr_to_fixpoint(expr, &mut self.scx)?));
645            }
646        }
647        Ok(())
648    }
649
650    fn kvar_to_fixpoint(
651        &mut self,
652        kvar: &rty::KVar,
653        bindings: &mut Vec<fixpoint::Bind>,
654    ) -> QueryResult<fixpoint::Pred> {
655        let decl = self.kvars.get(kvar.kvid);
656        let kvids = self.kcx.encode(kvar.kvid, decl, &mut self.scx);
657
658        let all_args = iter::zip(&kvar.args, &decl.sorts)
659            .map(|(arg, sort)| self.ecx.imm(arg, sort, &mut self.scx, bindings))
660            .try_collect_vec()?;
661
662        // Fixpoint doesn't support kvars without arguments, which we do generate sometimes. To get
663        // around it, we encode `$k()` as ($k 0), or more precisely `(forall ((x int) (= x 0)) ... ($k x)`
664        // after ANF-ing.
665        if all_args.is_empty() {
666            let fresh = self.ecx.local_var_env.fresh_name();
667            let var = fixpoint::Var::Local(fresh);
668            bindings.push(fixpoint::Bind {
669                name: fresh.into(),
670                sort: fixpoint::Sort::Int,
671                pred: fixpoint::Pred::Expr(fixpoint::Expr::eq(
672                    fixpoint::Expr::Var(var),
673                    fixpoint::Expr::int(0),
674                )),
675            });
676            return Ok(fixpoint::Pred::KVar(kvids[0], vec![var]));
677        }
678
679        let kvars = kvids
680            .iter()
681            .enumerate()
682            .map(|(i, kvid)| {
683                let args = all_args[i..].to_vec();
684                fixpoint::Pred::KVar(*kvid, args)
685            })
686            .collect_vec();
687
688        Ok(fixpoint::Pred::And(kvars))
689    }
690
691    fn def_span(&self) -> Span {
692        self.genv.tcx().def_span(self.def_id)
693    }
694}
695
696fn const_to_fixpoint(cst: rty::Constant) -> fixpoint::Expr {
697    match cst {
698        rty::Constant::Int(i) => {
699            if i.is_negative() {
700                fixpoint::Expr::Neg(Box::new(fixpoint::Constant::Numeral(i.abs()).into()))
701            } else {
702                fixpoint::Constant::Numeral(i.abs()).into()
703            }
704        }
705        rty::Constant::Real(r) => fixpoint::Constant::Decimal(r).into(),
706        rty::Constant::Bool(b) => fixpoint::Constant::Boolean(b).into(),
707        rty::Constant::Char(c) => fixpoint::Constant::Numeral(u128::from(c)).into(),
708        rty::Constant::Str(s) => fixpoint::Constant::String(fixpoint::SymStr(s)).into(),
709        rty::Constant::BitVec(i, size) => fixpoint::Constant::BitVec(i, size).into(),
710    }
711}
712
/// A kvar to be declared in the fixpoint task.
struct FixpointKVar {
    /// Sorts of the kvar's arguments.
    sorts: Vec<fixpoint::Sort>,
    /// The flux kvar this fixpoint kvar was generated for (emitted as a comment in its
    /// declaration).
    orig: rty::KVid,
}
717
718/// During encoding into fixpoint we generate multiple fixpoint kvars per kvar in flux. A
719/// [`KVarEncodingCtxt`] is used to keep track of the state needed for this.
720#[derive(Default)]
721struct KVarEncodingCtxt {
722    /// List of all kvars that need to be defined in fixpoint
723    kvars: IndexVec<fixpoint::KVid, FixpointKVar>,
724    /// A mapping from [`rty::KVid`] to the list of [`fixpoint::KVid`]s encoding the kvar.
725    map: UnordMap<rty::KVid, Vec<fixpoint::KVid>>,
726}
727
728impl KVarEncodingCtxt {
729    fn encode(
730        &mut self,
731        kvid: rty::KVid,
732        decl: &KVarDecl,
733        scx: &mut SortEncodingCtxt,
734    ) -> &[fixpoint::KVid] {
735        self.map.entry(kvid).or_insert_with(|| {
736            let all_args = decl
737                .sorts
738                .iter()
739                .map(|s| scx.sort_to_fixpoint(s))
740                .collect_vec();
741
742            // See comment in `kvar_to_fixpoint`
743            if all_args.is_empty() {
744                let sorts = vec![fixpoint::Sort::Int];
745                let kvid = self.kvars.push(FixpointKVar::new(sorts, kvid));
746                return vec![kvid];
747            }
748
749            match decl.encoding {
750                KVarEncoding::Single => {
751                    let kvid = self.kvars.push(FixpointKVar::new(all_args, kvid));
752                    vec![kvid]
753                }
754                KVarEncoding::Conj => {
755                    let n = usize::max(decl.self_args, 1);
756                    (0..n)
757                        .map(|i| {
758                            let sorts = all_args[i..].to_vec();
759                            self.kvars.push(FixpointKVar::new(sorts, kvid))
760                        })
761                        .collect_vec()
762                }
763            }
764        })
765    }
766
767    fn into_fixpoint(self) -> Vec<fixpoint::KVarDecl> {
768        self.kvars
769            .into_iter_enumerated()
770            .map(|(kvid, kvar)| {
771                fixpoint::KVarDecl::new(kvid, kvar.sorts, format!("orig: {:?}", kvar.orig))
772            })
773            .collect()
774    }
775}
776
777/// Environment used to map from [`rty::Var`] to a [`fixpoint::LocalVar`].
778struct LocalVarEnv {
779    local_var_gen: IndexGen<fixpoint::LocalVar>,
780    fvars: UnordMap<rty::Name, fixpoint::LocalVar>,
781    /// Layers of late bound variables
782    layers: Vec<Vec<fixpoint::LocalVar>>,
783}
784
785impl LocalVarEnv {
786    fn new() -> Self {
787        Self { local_var_gen: IndexGen::new(), fvars: Default::default(), layers: Vec::new() }
788    }
789
790    // This doesn't require to be mutable because `IndexGen` uses atomics, but we make it mutable
791    // to better declare the intent.
792    fn fresh_name(&mut self) -> fixpoint::LocalVar {
793        self.local_var_gen.fresh()
794    }
795
796    fn insert_fvar_map(&mut self, name: rty::Name) -> fixpoint::LocalVar {
797        let fresh = self.fresh_name();
798        self.fvars.insert(name, fresh);
799        fresh
800    }
801
802    fn remove_fvar_map(&mut self, name: rty::Name) {
803        self.fvars.remove(&name);
804    }
805
806    /// Push a layer of bound variables assigning a fresh [`fixpoint::LocalVar`] to each one
807    fn push_layer_with_fresh_names(&mut self, count: usize) {
808        let layer = (0..count).map(|_| self.fresh_name()).collect();
809        self.layers.push(layer);
810    }
811
812    fn pop_layer(&mut self) -> Vec<fixpoint::LocalVar> {
813        self.layers.pop().unwrap()
814    }
815
816    fn get_fvar(&self, name: rty::Name) -> Option<fixpoint::LocalVar> {
817        self.fvars.get(&name).copied()
818    }
819
820    fn get_late_bvar(&self, debruijn: DebruijnIndex, var: BoundVar) -> Option<fixpoint::LocalVar> {
821        let depth = self.layers.len().checked_sub(debruijn.as_usize() + 1)?;
822        self.layers[depth].get(var.as_usize()).copied()
823    }
824}
825
826impl FixpointKVar {
827    fn new(sorts: Vec<fixpoint::Sort>, orig: rty::KVid) -> Self {
828        Self { sorts, orig }
829    }
830}
831
/// Generator of fresh [`rty::KVar`]s. Records a [`KVarDecl`] for every kvar it creates.
pub struct KVarGen {
    /// Declarations of every kvar generated so far, indexed by its [`rty::KVid`].
    kvars: IndexVec<rty::KVid, KVarDecl>,
    /// If true, generate dummy [holes] instead of kvars. Used during shape mode to avoid generating
    /// unnecessary kvars.
    ///
    /// [holes]: rty::ExprKind::Hole
    dummy: bool,
}
840
841impl KVarGen {
842    pub(crate) fn new(dummy: bool) -> Self {
843        Self { kvars: IndexVec::new(), dummy }
844    }
845
846    fn get(&self, kvid: rty::KVid) -> &KVarDecl {
847        &self.kvars[kvid]
848    }
849
850    /// Generate a fresh [kvar] under several layers of [binders]. Each layer may contain any kind
851    /// of bound variable, but variables that are not of kind [`BoundVariableKind::Refine`] will
852    /// be filtered out.
853    ///
854    /// The variables bound in the last layer (last element of the `binders` slice) is expected to
855    /// have only [`BoundVariableKind::Refine`] and all its elements are used as the [self arguments].
856    /// The rest of the binders are appended to the `scope`.
857    ///
858    /// Note that the returned expression will have escaping variables and it is up to the caller to
859    /// put it under an appropriate number of binders.
860    ///
861    /// Prefer using [`InferCtxt::fresh_kvar`] when possible.
862    ///
863    /// [binders]: rty::Binder
864    /// [kvar]: rty::KVar
865    /// [`InferCtxt::fresh_kvar`]: crate::infer::InferCtxt::fresh_kvar
866    /// [self arguments]: rty::KVar::self_args
867    pub fn fresh(
868        &mut self,
869        binders: &[rty::BoundVariableKinds],
870        scope: impl IntoIterator<Item = (rty::Var, rty::Sort)>,
871        encoding: KVarEncoding,
872    ) -> rty::Expr {
873        if self.dummy {
874            return rty::Expr::hole(rty::HoleKind::Pred);
875        }
876
877        let args = itertools::chain(
878            binders.iter().rev().enumerate().flat_map(|(level, vars)| {
879                let debruijn = DebruijnIndex::from_usize(level);
880                vars.iter()
881                    .cloned()
882                    .enumerate()
883                    .flat_map(move |(idx, var)| {
884                        if let rty::BoundVariableKind::Refine(sort, _, kind) = var {
885                            let br = rty::BoundReft { var: BoundVar::from_usize(idx), kind };
886                            Some((rty::Var::Bound(debruijn, br), sort))
887                        } else {
888                            None
889                        }
890                    })
891            }),
892            scope,
893        );
894        let [.., last] = binders else {
895            return self.fresh_inner(0, [], encoding);
896        };
897        let num_self_args = last
898            .iter()
899            .filter(|var| matches!(var, rty::BoundVariableKind::Refine(..)))
900            .count();
901        self.fresh_inner(num_self_args, args, encoding)
902    }
903
904    fn fresh_inner<A>(&mut self, self_args: usize, args: A, encoding: KVarEncoding) -> rty::Expr
905    where
906        A: IntoIterator<Item = (rty::Var, rty::Sort)>,
907    {
908        // asset last one has things
909        let mut sorts = vec![];
910        let mut exprs = vec![];
911
912        let mut flattened_self_args = 0;
913        for (i, (var, sort)) in args.into_iter().enumerate() {
914            let is_self_arg = i < self_args;
915            let var = var.to_expr();
916            sort.walk(|sort, proj| {
917                if !matches!(sort, rty::Sort::Loc) {
918                    flattened_self_args += is_self_arg as usize;
919                    sorts.push(sort.clone());
920                    exprs.push(rty::Expr::field_projs(&var, proj));
921                }
922            });
923        }
924
925        let kvid = self
926            .kvars
927            .push(KVarDecl { self_args: flattened_self_args, sorts, encoding });
928
929        let kvar = rty::KVar::new(kvid, flattened_self_args, exprs);
930        rty::Expr::kvar(kvar)
931    }
932}
933
/// Declaration recorded by [`KVarGen`] for every kvar it generates.
#[derive(Clone)]
struct KVarDecl {
    /// Number of leading entries of `sorts` that belong to self arguments, counted after
    /// flattening the argument sorts.
    self_args: usize,
    /// Sorts of all flattened arguments: self arguments first, then the remaining binders and
    /// the scope.
    sorts: Vec<rty::Sort>,
    /// How the kvar is to be encoded in the fixpoint constraint.
    encoding: KVarEncoding,
}
940
/// How an [`rty::KVar`] is encoded in the fixpoint constraint
#[derive(Clone, Copy)]
pub enum KVarEncoding {
    /// Generate a single kvar appending the self arguments and the scope, i.e.,
    /// a kvar `$k(a0, ...)[b0, ...]` becomes `$k(a0, ..., b0, ...)` in the fixpoint constraint.
    Single,
    /// Generate a conjunction of kvars, one per *self* argument of the [`rty::KVar`]. The i-th
    /// kvar takes the self arguments from position i onwards, followed by the whole scope.
    /// Concretely, a kvar `$k(a0, a1, ..., an)[b0, ...]` becomes
    /// `$k0(a0, a1, ..., an, b0, ...) ∧ $k1(a1, ..., an, b0, ...) ∧ ... ∧ $kn(an, b0, ...)`
    Conj,
}
952
953impl std::fmt::Display for TagIdx {
954    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
955        write!(f, "{}", self.as_u32())
956    }
957}
958
959impl std::str::FromStr for TagIdx {
960    type Err = std::num::ParseIntError;
961
962    fn from_str(s: &str) -> Result<Self, Self::Err> {
963        Ok(Self::from_u32(s.parse()?))
964    }
965}
966
/// State used while translating [`rty::Expr`]s into fixpoint expressions.
struct ExprEncodingCtxt<'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    /// Names for the free and late bound variables currently in scope.
    local_var_env: LocalVarEnv,
    /// Source of ids for global constants and Flux function definitions.
    global_var_gen: IndexGen<fixpoint::GlobalVar>,
    /// Global constants declared so far: Rust consts, uninterpreted functions, alias refts and
    /// lambdas, keyed by what they stand for.
    const_map: ConstMap<'tcx>,
    /// Flux function definitions referenced so far. Their bodies are encoded by `define_funs`.
    fun_def_map: FunDefMap,
    errors: Errors<'genv>,
    /// Span attached to internal bug reports (`span_bug!`) raised during encoding.
    def_span: Span,
}
976
977impl<'genv, 'tcx> ExprEncodingCtxt<'genv, 'tcx> {
978    fn new(genv: GlobalEnv<'genv, 'tcx>, def_span: Span) -> Self {
979        Self {
980            genv,
981            local_var_env: LocalVarEnv::new(),
982            global_var_gen: IndexGen::new(),
983            const_map: Default::default(),
984            fun_def_map: Default::default(),
985            errors: Errors::new(genv.sess()),
986            def_span,
987        }
988    }
989
990    fn var_to_fixpoint(&self, var: &rty::Var) -> fixpoint::Var {
991        match var {
992            rty::Var::Free(name) => {
993                self.local_var_env
994                    .get_fvar(*name)
995                    .unwrap_or_else(|| {
996                        span_bug!(self.def_span, "no entry found for name: `{name:?}`")
997                    })
998                    .into()
999            }
1000            rty::Var::Bound(debruijn, breft) => {
1001                self.local_var_env
1002                    .get_late_bvar(*debruijn, breft.var)
1003                    .unwrap_or_else(|| {
1004                        span_bug!(self.def_span, "no entry found for late bound var: `{breft:?}`")
1005                    })
1006                    .into()
1007            }
1008            rty::Var::ConstGeneric(param) => fixpoint::Var::ConstGeneric(*param),
1009            rty::Var::EarlyParam(param) => fixpoint::Var::Param(*param),
1010            rty::Var::EVar(_) => {
1011                span_bug!(self.def_span, "unexpected evar: `{var:?}`")
1012            }
1013        }
1014    }
1015
1016    fn variant_to_fixpoint(
1017        &self,
1018        scx: &mut SortEncodingCtxt,
1019        enum_def_id: &DefId,
1020        idx: VariantIdx,
1021    ) -> fixpoint::Expr {
1022        let adt_id = scx.declare_adt(*enum_def_id);
1023        let var = fixpoint::Var::DataCtor(adt_id, idx);
1024        fixpoint::Expr::Var(var)
1025    }
1026
1027    fn fields_to_fixpoint(
1028        &mut self,
1029        flds: &[rty::Expr],
1030        scx: &mut SortEncodingCtxt,
1031    ) -> QueryResult<fixpoint::Expr> {
1032        // do not generate 1-tuples
1033        if let [fld] = flds {
1034            self.expr_to_fixpoint(fld, scx)
1035        } else {
1036            scx.declare_tuple(flds.len());
1037            let ctor = fixpoint::Expr::Var(fixpoint::Var::TupleCtor { arity: flds.len() });
1038            let args = flds
1039                .iter()
1040                .map(|fld| self.expr_to_fixpoint(fld, scx))
1041                .try_collect()?;
1042            Ok(fixpoint::Expr::App(Box::new(ctor), args))
1043        }
1044    }
1045
1046    fn expr_to_fixpoint(
1047        &mut self,
1048        expr: &rty::Expr,
1049        scx: &mut SortEncodingCtxt,
1050    ) -> QueryResult<fixpoint::Expr> {
1051        let e = match expr.kind() {
1052            rty::ExprKind::Var(var) => fixpoint::Expr::Var(self.var_to_fixpoint(var)),
1053            rty::ExprKind::Constant(c) => const_to_fixpoint(*c),
1054            rty::ExprKind::BinaryOp(op, e1, e2) => self.bin_op_to_fixpoint(op, e1, e2, scx)?,
1055            rty::ExprKind::UnaryOp(op, e) => self.un_op_to_fixpoint(*op, e, scx)?,
1056            rty::ExprKind::FieldProj(e, proj) => self.proj_to_fixpoint(e, *proj, scx)?,
1057            rty::ExprKind::Tuple(flds) => self.fields_to_fixpoint(flds, scx)?,
1058            rty::ExprKind::Ctor(rty::Ctor::Struct(_), flds) => {
1059                self.fields_to_fixpoint(flds, scx)?
1060            }
1061            rty::ExprKind::Ctor(rty::Ctor::Enum(did, idx), _) => {
1062                self.variant_to_fixpoint(scx, did, *idx)
1063            }
1064            rty::ExprKind::ConstDefId(did) => {
1065                let var = self.define_const_for_rust_const(*did, scx);
1066                fixpoint::Expr::Var(var)
1067            }
1068            rty::ExprKind::App(func, args) => {
1069                let func = self.expr_to_fixpoint(func, scx)?;
1070                let args = self.exprs_to_fixpoint(args, scx)?;
1071                fixpoint::Expr::App(Box::new(func), args)
1072            }
1073            rty::ExprKind::IfThenElse(p, e1, e2) => {
1074                fixpoint::Expr::IfThenElse(Box::new([
1075                    self.expr_to_fixpoint(p, scx)?,
1076                    self.expr_to_fixpoint(e1, scx)?,
1077                    self.expr_to_fixpoint(e2, scx)?,
1078                ]))
1079            }
1080            rty::ExprKind::Alias(alias_reft, args) => {
1081                let sort = self.genv.sort_of_assoc_reft(alias_reft.assoc_id)?;
1082                let sort = sort.instantiate_identity();
1083                let func =
1084                    fixpoint::Expr::Var(self.define_const_for_alias_reft(alias_reft, sort, scx));
1085                let args = args
1086                    .iter()
1087                    .map(|expr| self.expr_to_fixpoint(expr, scx))
1088                    .try_collect()?;
1089                fixpoint::Expr::App(Box::new(func), args)
1090            }
1091            rty::ExprKind::Abs(lam) => {
1092                let var = self.define_const_for_lambda(lam, scx);
1093                fixpoint::Expr::Var(var)
1094            }
1095            rty::ExprKind::Let(init, body) => {
1096                debug_assert_eq!(body.vars().len(), 1);
1097                let init = self.expr_to_fixpoint(init, scx)?;
1098
1099                self.local_var_env.push_layer_with_fresh_names(1);
1100                let body = self.expr_to_fixpoint(body.skip_binder_ref(), scx)?;
1101                let vars = self.local_var_env.pop_layer();
1102
1103                fixpoint::Expr::Let(vars[0].into(), Box::new([init, body]))
1104            }
1105            rty::ExprKind::GlobalFunc(SpecFuncKind::Thy(itf)) => {
1106                fixpoint::Expr::Var(fixpoint::Var::Itf(*itf))
1107            }
1108            rty::ExprKind::GlobalFunc(SpecFuncKind::Uif(def_id)) => {
1109                fixpoint::Expr::Var(self.define_const_for_uif(*def_id, scx))
1110            }
1111            rty::ExprKind::GlobalFunc(SpecFuncKind::Def(def_id)) => {
1112                fixpoint::Expr::Var(self.declare_fun(*def_id))
1113            }
1114            rty::ExprKind::Hole(..)
1115            | rty::ExprKind::KVar(_)
1116            | rty::ExprKind::Local(_)
1117            | rty::ExprKind::PathProj(..)
1118            | rty::ExprKind::ForAll(_) => {
1119                span_bug!(self.def_span, "unexpected expr: `{expr:?}`")
1120            }
1121            rty::ExprKind::BoundedQuant(kind, rng, body) => {
1122                let exprs = (rng.start..rng.end).map(|i| {
1123                    let arg = rty::Expr::constant(rty::Constant::from(i));
1124                    body.replace_bound_reft(&arg)
1125                });
1126                let expr = match kind {
1127                    flux_middle::fhir::QuantKind::Forall => rty::Expr::and_from_iter(exprs),
1128                    flux_middle::fhir::QuantKind::Exists => rty::Expr::or_from_iter(exprs),
1129                };
1130                self.expr_to_fixpoint(&expr, scx)?
1131            }
1132        };
1133        Ok(e)
1134    }
1135
1136    fn exprs_to_fixpoint<'b>(
1137        &mut self,
1138        exprs: impl IntoIterator<Item = &'b rty::Expr>,
1139        scx: &mut SortEncodingCtxt,
1140    ) -> QueryResult<Vec<fixpoint::Expr>> {
1141        exprs
1142            .into_iter()
1143            .map(|e| self.expr_to_fixpoint(e, scx))
1144            .try_collect()
1145    }
1146
1147    fn proj_to_fixpoint(
1148        &mut self,
1149        e: &rty::Expr,
1150        proj: rty::FieldProj,
1151        scx: &mut SortEncodingCtxt,
1152    ) -> QueryResult<fixpoint::Expr> {
1153        let arity = proj.arity(self.genv)?;
1154        // we encode 1-tuples as the single element inside so no projection necessary here
1155        if arity == 1 {
1156            self.expr_to_fixpoint(e, scx)
1157        } else {
1158            let field = proj.field_idx();
1159            scx.declare_tuple(arity);
1160            let proj = fixpoint::Var::TupleProj { arity, field };
1161            let proj = fixpoint::Expr::Var(proj);
1162            Ok(fixpoint::Expr::App(Box::new(proj), vec![self.expr_to_fixpoint(e, scx)?]))
1163        }
1164    }
1165
1166    fn un_op_to_fixpoint(
1167        &mut self,
1168        op: rty::UnOp,
1169        e: &rty::Expr,
1170        scx: &mut SortEncodingCtxt,
1171    ) -> QueryResult<fixpoint::Expr> {
1172        match op {
1173            rty::UnOp::Not => Ok(fixpoint::Expr::Not(Box::new(self.expr_to_fixpoint(e, scx)?))),
1174            rty::UnOp::Neg => Ok(fixpoint::Expr::Neg(Box::new(self.expr_to_fixpoint(e, scx)?))),
1175        }
1176    }
1177
1178    fn bv_rel_to_fixpoint(&self, rel: &fixpoint::BinRel) -> fixpoint::Expr {
1179        let itf = match rel {
1180            fixpoint::BinRel::Gt => fixpoint::ThyFunc::BvUgt,
1181            fixpoint::BinRel::Ge => fixpoint::ThyFunc::BvUge,
1182            fixpoint::BinRel::Lt => fixpoint::ThyFunc::BvUlt,
1183            fixpoint::BinRel::Le => fixpoint::ThyFunc::BvUle,
1184            _ => span_bug!(self.def_span, "not a bitvector relation!"),
1185        };
1186        fixpoint::Expr::Var(fixpoint::Var::Itf(itf))
1187    }
1188
1189    fn bv_op_to_fixpoint(&self, op: &rty::BinOp) -> fixpoint::Expr {
1190        let itf = match op {
1191            rty::BinOp::Add(_) => fixpoint::ThyFunc::BvAdd,
1192            rty::BinOp::Sub(_) => fixpoint::ThyFunc::BvSub,
1193            rty::BinOp::Mul(_) => fixpoint::ThyFunc::BvMul,
1194            rty::BinOp::Div(_) => fixpoint::ThyFunc::BvUdiv,
1195            rty::BinOp::Mod(_) => fixpoint::ThyFunc::BvUrem,
1196            rty::BinOp::BitAnd => fixpoint::ThyFunc::BvAnd,
1197            rty::BinOp::BitOr => fixpoint::ThyFunc::BvOr,
1198            rty::BinOp::BitShl => fixpoint::ThyFunc::BvShl,
1199            rty::BinOp::BitShr => fixpoint::ThyFunc::BvLshr,
1200            _ => span_bug!(self.def_span, "not a bitvector operation!"),
1201        };
1202        fixpoint::Expr::Var(fixpoint::Var::Itf(itf))
1203    }
1204
1205    fn bin_op_to_fixpoint(
1206        &mut self,
1207        op: &rty::BinOp,
1208        e1: &rty::Expr,
1209        e2: &rty::Expr,
1210        scx: &mut SortEncodingCtxt,
1211    ) -> QueryResult<fixpoint::Expr> {
1212        let op = match op {
1213            rty::BinOp::Eq => {
1214                return Ok(fixpoint::Expr::Atom(
1215                    fixpoint::BinRel::Eq,
1216                    Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
1217                ));
1218            }
1219            rty::BinOp::Ne => {
1220                return Ok(fixpoint::Expr::Atom(
1221                    fixpoint::BinRel::Ne,
1222                    Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
1223                ));
1224            }
1225            rty::BinOp::Gt(sort) => {
1226                return self.bin_rel_to_fixpoint(sort, fixpoint::BinRel::Gt, e1, e2, scx);
1227            }
1228            rty::BinOp::Ge(sort) => {
1229                return self.bin_rel_to_fixpoint(sort, fixpoint::BinRel::Ge, e1, e2, scx);
1230            }
1231            rty::BinOp::Lt(sort) => {
1232                return self.bin_rel_to_fixpoint(sort, fixpoint::BinRel::Lt, e1, e2, scx);
1233            }
1234            rty::BinOp::Le(sort) => {
1235                return self.bin_rel_to_fixpoint(sort, fixpoint::BinRel::Le, e1, e2, scx);
1236            }
1237            rty::BinOp::And => {
1238                return Ok(fixpoint::Expr::And(vec![
1239                    self.expr_to_fixpoint(e1, scx)?,
1240                    self.expr_to_fixpoint(e2, scx)?,
1241                ]));
1242            }
1243            rty::BinOp::Or => {
1244                return Ok(fixpoint::Expr::Or(vec![
1245                    self.expr_to_fixpoint(e1, scx)?,
1246                    self.expr_to_fixpoint(e2, scx)?,
1247                ]));
1248            }
1249            rty::BinOp::Imp => {
1250                return Ok(fixpoint::Expr::Imp(Box::new([
1251                    self.expr_to_fixpoint(e1, scx)?,
1252                    self.expr_to_fixpoint(e2, scx)?,
1253                ])));
1254            }
1255            rty::BinOp::Iff => {
1256                return Ok(fixpoint::Expr::Iff(Box::new([
1257                    self.expr_to_fixpoint(e1, scx)?,
1258                    self.expr_to_fixpoint(e2, scx)?,
1259                ])));
1260            }
1261            rty::BinOp::Add(rty::Sort::BitVec(_))
1262            | rty::BinOp::Sub(rty::Sort::BitVec(_))
1263            | rty::BinOp::Mul(rty::Sort::BitVec(_))
1264            | rty::BinOp::Div(rty::Sort::BitVec(_))
1265            | rty::BinOp::Mod(rty::Sort::BitVec(_))
1266            | rty::BinOp::BitAnd
1267            | rty::BinOp::BitOr
1268            | rty::BinOp::BitShl
1269            | rty::BinOp::BitShr => {
1270                let bv_func = self.bv_op_to_fixpoint(op);
1271                return Ok(fixpoint::Expr::App(
1272                    Box::new(bv_func),
1273                    vec![self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?],
1274                ));
1275            }
1276            rty::BinOp::Add(_) => fixpoint::BinOp::Add,
1277            rty::BinOp::Sub(_) => fixpoint::BinOp::Sub,
1278            rty::BinOp::Mul(_) => fixpoint::BinOp::Mul,
1279            rty::BinOp::Div(_) => fixpoint::BinOp::Div,
1280            rty::BinOp::Mod(_) => fixpoint::BinOp::Mod,
1281        };
1282        Ok(fixpoint::Expr::BinaryOp(
1283            op,
1284            Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
1285        ))
1286    }
1287
1288    /// A binary relation is encoded as a structurally recursive relation between aggregate sorts.
1289    /// For "leaf" expressions, we encode them as an interpreted relation if the sort supports it,
1290    /// otherwise we use an uninterpreted function. For example, consider the following relation
1291    /// between two tuples of sort `(int, int -> int)`
1292    /// ```text
1293    /// (0, λv. v + 1) <= (1, λv. v + 1)
1294    /// ```
1295    /// The encoding in fixpoint will be
1296    ///
1297    /// ```text
1298    /// 0 <= 1 && (le (λv. v + 1) (λv. v + 1))
1299    /// ```
1300    /// Where `<=` is the (interpreted) less than or equal relation between integers and `le` is
1301    /// an uninterpreted relation between ([the encoding] of) lambdas.
1302    ///
1303    /// [the encoding]: Self::define_const_for_lambda
1304    fn bin_rel_to_fixpoint(
1305        &mut self,
1306        sort: &rty::Sort,
1307        rel: fixpoint::BinRel,
1308        e1: &rty::Expr,
1309        e2: &rty::Expr,
1310        scx: &mut SortEncodingCtxt,
1311    ) -> QueryResult<fixpoint::Expr> {
1312        let e = match sort {
1313            rty::Sort::Int | rty::Sort::Real | rty::Sort::Char => {
1314                fixpoint::Expr::Atom(
1315                    rel,
1316                    Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
1317                )
1318            }
1319            rty::Sort::BitVec(_) => {
1320                let e1 = self.expr_to_fixpoint(e1, scx)?;
1321                let e2 = self.expr_to_fixpoint(e2, scx)?;
1322                let rel = self.bv_rel_to_fixpoint(&rel);
1323                fixpoint::Expr::App(Box::new(rel), vec![e1, e2])
1324            }
1325            rty::Sort::Tuple(sorts) => {
1326                let arity = sorts.len();
1327                self.apply_bin_rel_rec(sorts, rel, e1, e2, scx, |field| {
1328                    rty::FieldProj::Tuple { arity, field }
1329                })?
1330            }
1331            rty::Sort::App(rty::SortCtor::Adt(sort_def), args)
1332                if let Some(variant) = sort_def.opt_struct_variant() =>
1333            {
1334                let def_id = sort_def.did();
1335                let sorts = variant.field_sorts(args);
1336                self.apply_bin_rel_rec(&sorts, rel, e1, e2, scx, |field| {
1337                    rty::FieldProj::Adt { def_id, field }
1338                })?
1339            }
1340            _ => {
1341                let rel = fixpoint::Expr::Var(fixpoint::Var::UIFRel(rel));
1342                fixpoint::Expr::App(
1343                    Box::new(rel),
1344                    vec![self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?],
1345                )
1346            }
1347        };
1348        Ok(e)
1349    }
1350
1351    /// Apply binary relation recursively over aggregate expressions
1352    fn apply_bin_rel_rec(
1353        &mut self,
1354        sorts: &[rty::Sort],
1355        rel: fixpoint::BinRel,
1356        e1: &rty::Expr,
1357        e2: &rty::Expr,
1358        scx: &mut SortEncodingCtxt,
1359        mk_proj: impl Fn(u32) -> rty::FieldProj,
1360    ) -> QueryResult<fixpoint::Expr> {
1361        Ok(fixpoint::Expr::and(
1362            sorts
1363                .iter()
1364                .enumerate()
1365                .map(|(idx, s)| {
1366                    let proj = mk_proj(idx as u32);
1367                    let e1 = e1.proj_and_reduce(proj);
1368                    let e2 = e2.proj_and_reduce(proj);
1369                    self.bin_rel_to_fixpoint(s, rel, &e1, &e2, scx)
1370                })
1371                .try_collect()?,
1372        ))
1373    }
1374
1375    fn imm(
1376        &mut self,
1377        arg: &rty::Expr,
1378        sort: &rty::Sort,
1379        scx: &mut SortEncodingCtxt,
1380        bindings: &mut Vec<fixpoint::Bind>,
1381    ) -> QueryResult<fixpoint::Var> {
1382        let arg = self.expr_to_fixpoint(arg, scx)?;
1383        // Check if it's a variable after encoding, in case the encoding produced a variable from a
1384        // non-variable.
1385        if let fixpoint::Expr::Var(var) = arg {
1386            Ok(var)
1387        } else {
1388            let fresh = self.local_var_env.fresh_name();
1389            let pred = fixpoint::Expr::eq(fixpoint::Expr::Var(fresh.into()), arg);
1390            bindings.push(fixpoint::Bind {
1391                name: fresh.into(),
1392                sort: scx.sort_to_fixpoint(sort),
1393                pred: fixpoint::Pred::Expr(pred),
1394            });
1395            Ok(fresh.into())
1396        }
1397    }
1398
1399    /// Declare that the `def_id` of a Flux function definition needs to be encoded and assigns
1400    /// a name to it if it hasn't yet been declared. The encoding of the function body happens
1401    /// in [`Self::define_funs`].
1402    fn declare_fun(&mut self, def_id: FluxDefId) -> fixpoint::Var {
1403        *self.fun_def_map.entry(def_id).or_insert_with(|| {
1404            let id = self.global_var_gen.fresh();
1405            fixpoint::Var::Global(id, Some(def_id.name()))
1406        })
1407    }
1408
1409    fn define_const_for_uif(
1410        &mut self,
1411        def_id: FluxDefId,
1412        scx: &mut SortEncodingCtxt,
1413    ) -> fixpoint::Var {
1414        let key = ConstKey::Uif(def_id);
1415        self.const_map
1416            .entry(key)
1417            .or_insert_with(|| {
1418                let sort = scx.func_sort_to_fixpoint(&self.genv.func_sort(def_id));
1419                fixpoint::ConstDecl {
1420                    name: fixpoint::Var::Global(self.global_var_gen.fresh(), Some(def_id.name())),
1421                    sort,
1422                    comment: Some(format!("uif: {def_id:?}")),
1423                }
1424            })
1425            .name
1426    }
1427
1428    fn define_const_for_rust_const(
1429        &mut self,
1430        def_id: DefId,
1431        scx: &mut SortEncodingCtxt,
1432    ) -> fixpoint::Var {
1433        let key = ConstKey::RustConst(def_id);
1434        self.const_map
1435            .entry(key)
1436            .or_insert_with(|| {
1437                let sort = self.genv.sort_of_def_id(def_id).unwrap().unwrap();
1438                fixpoint::ConstDecl {
1439                    name: fixpoint::Var::Global(self.global_var_gen.fresh(), None),
1440                    sort: scx.sort_to_fixpoint(&sort),
1441                    comment: Some(format!("rust const: {}", def_id_to_string(def_id))),
1442                }
1443            })
1444            .name
1445    }
1446
1447    /// returns the 'constant' UIF for Var used to represent the alias_pred, creating and adding it
1448    /// to the const_map if necessary
1449    fn define_const_for_alias_reft(
1450        &mut self,
1451        alias_reft: &rty::AliasReft,
1452        fsort: rty::FuncSort,
1453        scx: &mut SortEncodingCtxt,
1454    ) -> fixpoint::Var {
1455        let tcx = self.genv.tcx();
1456        let args = alias_reft
1457            .args
1458            .to_rustc(tcx)
1459            .truncate_to(tcx, tcx.generics_of(alias_reft.assoc_id.parent()));
1460        let key = ConstKey::Alias(alias_reft.assoc_id, args);
1461        self.const_map
1462            .entry(key)
1463            .or_insert_with(|| {
1464                let comment = Some(format!("alias reft: {alias_reft:?}"));
1465                let name = fixpoint::Var::Global(self.global_var_gen.fresh(), None);
1466                let fsort = rty::PolyFuncSort::new(List::empty(), fsort);
1467                let sort = scx.func_sort_to_fixpoint(&fsort);
1468                fixpoint::ConstDecl { name, comment, sort }
1469            })
1470            .name
1471    }
1472
1473    /// We encode lambdas with uninterpreted constant. Two syntactically equal lambdas will be encoded
1474    /// with the same constant.
1475    fn define_const_for_lambda(
1476        &mut self,
1477        lam: &rty::Lambda,
1478        scx: &mut SortEncodingCtxt,
1479    ) -> fixpoint::Var {
1480        let key = ConstKey::Lambda(lam.clone());
1481        self.const_map
1482            .entry(key)
1483            .or_insert_with(|| {
1484                let comment = Some(format!("lambda: {lam:?}"));
1485                let name = fixpoint::Var::Global(self.global_var_gen.fresh(), None);
1486                let sort = scx.func_sort_to_fixpoint(&lam.fsort().to_poly());
1487                fixpoint::ConstDecl { name, comment, sort }
1488            })
1489            .name
1490    }
1491
1492    fn assume_const_values(
1493        &mut self,
1494        mut constraint: fixpoint::Constraint,
1495        scx: &mut SortEncodingCtxt,
1496    ) -> QueryResult<fixpoint::Constraint> {
1497        // Encoding the value for a constant could in theory define more constants for which
1498        // we need to assume values, so we iterate until there are no more constants.
1499        let mut idx = 0;
1500        while let Some((key, const_)) = self.const_map.get_index(idx) {
1501            idx += 1;
1502
1503            let ConstKey::RustConst(def_id) = key else { continue };
1504            let info = self.genv.constant_info(def_id)?;
1505            match info {
1506                rty::ConstantInfo::Uninterpreted => {}
1507                rty::ConstantInfo::Interpreted(val, _) => {
1508                    let e1 = fixpoint::Expr::Var(const_.name);
1509                    let e2 = self.expr_to_fixpoint(&val, scx)?;
1510                    let pred = fixpoint::Pred::Expr(e1.eq(e2));
1511                    constraint = fixpoint::Constraint::ForAll(
1512                        fixpoint::Bind {
1513                            name: fixpoint::Var::Underscore,
1514                            sort: fixpoint::Sort::Int,
1515                            pred,
1516                        },
1517                        Box::new(constraint),
1518                    );
1519                }
1520            }
1521        }
1522        Ok(constraint)
1523    }
1524
1525    fn qualifiers_for(
1526        &mut self,
1527        def_id: LocalDefId,
1528        scx: &mut SortEncodingCtxt,
1529    ) -> QueryResult<Vec<fixpoint::Qualifier>> {
1530        self.genv
1531            .qualifiers_for(def_id)?
1532            .map(|qual| self.qualifier_to_fixpoint(qual, scx))
1533            .try_collect()
1534    }
1535
1536    fn define_funs(
1537        &mut self,
1538        def_id: MaybeExternId,
1539        scx: &mut SortEncodingCtxt,
1540    ) -> QueryResult<(Vec<fixpoint::FunDef>, Vec<fixpoint::ConstDecl>)> {
1541        let reveals: UnordSet<FluxDefId> = self.genv.reveals_for(def_id.local_id())?.collect();
1542        let mut consts = vec![];
1543        let mut defs = vec![];
1544
1545        // We iterate until encoding the body of functions doesn't require any more functions
1546        // to be encoded.
1547        let mut idx = 0;
1548        while let Some((&did, &name)) = self.fun_def_map.get_index(idx) {
1549            idx += 1;
1550
1551            let comment = format!("flux def: {did:?}");
1552            let info = self.genv.normalized_info(did);
1553            let revealed = reveals.contains(&did);
1554            if info.hide && !revealed {
1555                let sort = scx.func_sort_to_fixpoint(&self.genv.func_sort(did));
1556                consts.push(fixpoint::ConstDecl { name, sort, comment: Some(comment) });
1557            } else {
1558                let out = scx.sort_to_fixpoint(self.genv.func_sort(did).expect_mono().output());
1559                let (args, body) = self.body_to_fixpoint(&info.body, scx)?;
1560                let fun_def = fixpoint::FunDef { name, args, body, out, comment: Some(comment) };
1561                defs.push((info.rank, fun_def));
1562            };
1563        }
1564
1565        // we sort by rank so the definitions go out without any forward dependencies.
1566        let defs = defs
1567            .into_iter()
1568            .sorted_by_key(|(rank, _)| *rank)
1569            .map(|(_, def)| def)
1570            .collect();
1571
1572        Ok((defs, consts))
1573    }
1574
1575    fn body_to_fixpoint(
1576        &mut self,
1577        body: &rty::Binder<rty::Expr>,
1578        scx: &mut SortEncodingCtxt,
1579    ) -> QueryResult<(Vec<(fixpoint::Var, fixpoint::Sort)>, fixpoint::Expr)> {
1580        self.local_var_env
1581            .push_layer_with_fresh_names(body.vars().len());
1582
1583        let expr = self.expr_to_fixpoint(body.as_ref().skip_binder(), scx)?;
1584
1585        let args: Vec<(fixpoint::Var, fixpoint::Sort)> =
1586            iter::zip(self.local_var_env.pop_layer(), body.vars())
1587                .map(|(name, var)| (name.into(), scx.sort_to_fixpoint(var.expect_sort())))
1588                .collect();
1589
1590        Ok((args, expr))
1591    }
1592
1593    fn qualifier_to_fixpoint(
1594        &mut self,
1595        qualifier: &rty::Qualifier,
1596        scx: &mut SortEncodingCtxt,
1597    ) -> QueryResult<fixpoint::Qualifier> {
1598        let (args, body) = self.body_to_fixpoint(&qualifier.body, scx)?;
1599        let name = qualifier.def_id.name().to_string();
1600        Ok(fixpoint::Qualifier { name, args, body })
1601    }
1602}
1603
1604fn mk_implies(assumption: fixpoint::Pred, cstr: fixpoint::Constraint) -> fixpoint::Constraint {
1605    fixpoint::Constraint::ForAll(
1606        fixpoint::Bind {
1607            name: fixpoint::Var::Underscore,
1608            sort: fixpoint::Sort::Int,
1609            pred: assumption,
1610        },
1611        Box::new(cstr),
1612    )
1613}