flux_infer/
fixpoint_encoding.rs

1//! Encoding of the refinement tree into a fixpoint constraint.
2
3use std::{hash::Hash, iter};
4
5use fixpoint::AdtId;
6use flux_common::{
7    bug,
8    cache::QueryCache,
9    dbg,
10    index::{IndexGen, IndexVec},
11    iter::IterExt,
12    span_bug, tracked_span_bug,
13};
14use flux_config::{self as config};
15use flux_errors::Errors;
16use flux_middle::{
17    FixpointQueryKind,
18    def_id::{FluxDefId, MaybeExternId},
19    def_id_to_string,
20    global_env::GlobalEnv,
21    queries::QueryResult,
22    rty::{self, ESpan, GenericArgsExt, InternalFuncKind, Lambda, List, SpecFuncKind, VariantIdx},
23    timings::{self, TimingKind},
24};
25use itertools::Itertools;
26use liquid_fixpoint::{FixpointResult, SmtSolver};
27use rustc_data_structures::{
28    fx::{FxIndexMap, FxIndexSet},
29    unord::{UnordMap, UnordSet},
30};
31use rustc_hir::def_id::{DefId, LocalDefId};
32use rustc_index::newtype_index;
33use rustc_infer::infer::TyCtxtInferExt as _;
34use rustc_middle::ty::TypingMode;
35use rustc_span::Span;
36use rustc_type_ir::{BoundVar, DebruijnIndex};
37use serde::{Deserialize, Deserializer, Serialize};
38
39use crate::projections::structurally_normalize_expr;
40
41pub mod fixpoint {
42    use std::fmt;
43
44    use flux_middle::rty::{EarlyReftParam, Real};
45    use liquid_fixpoint::{FixpointFmt, Identifier};
46    use rustc_abi::VariantIdx;
47    use rustc_index::newtype_index;
48    use rustc_middle::ty::ParamConst;
49    use rustc_span::Symbol;
50
51    newtype_index! {
52        pub struct KVid {}
53    }
54
55    impl Identifier for KVid {
56        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57            write!(f, "k{}", self.as_u32())
58        }
59    }
60
61    newtype_index! {
62        pub struct LocalVar {}
63    }
64
65    newtype_index! {
66        pub struct GlobalVar {}
67    }
68
69    newtype_index! {
70        /// Unique id assigned to each [`flux_middle::rty::AdtSortDef`] that needs to be encoded
71        /// into fixpoint
72        pub struct AdtId {}
73    }
74
75    #[derive(Hash, Copy, Clone, Debug, PartialEq)]
76    pub enum Var {
77        Underscore,
78        Global(GlobalVar, Option<Symbol>),
79        Local(LocalVar),
80        DataCtor(AdtId, VariantIdx),
81        TupleCtor {
82            arity: usize,
83        },
84        TupleProj {
85            arity: usize,
86            field: u32,
87        },
88        UIFRel(BinRel),
89        /// Interpreted theory function. This can be an arbitrary string, thus we are assuming the
90        /// name is different than the display implementation for the other variants.
91        Itf(liquid_fixpoint::ThyFunc),
92        Param(EarlyReftParam),
93        ConstGeneric(ParamConst),
94    }
95
96    impl From<GlobalVar> for Var {
97        fn from(v: GlobalVar) -> Self {
98            Self::Global(v, None)
99        }
100    }
101
102    impl From<LocalVar> for Var {
103        fn from(v: LocalVar) -> Self {
104            Self::Local(v)
105        }
106    }
107
108    impl Identifier for Var {
109        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110            match self {
111                Var::Global(v, None) => write!(f, "c{}", v.as_u32()),
112                Var::Global(v, Some(sym)) => write!(f, "f${}${}", sym, v.as_u32()),
113                Var::Local(v) => write!(f, "a{}", v.as_u32()),
114                Var::DataCtor(adt_id, variant_idx) => {
115                    write!(f, "mkadt{}${}", adt_id.as_u32(), variant_idx.as_u32())
116                }
117                Var::TupleCtor { arity } => write!(f, "mktuple{arity}"),
118                Var::TupleProj { arity, field } => write!(f, "tuple{arity}${field}"),
119                Var::Itf(name) => write!(f, "{name}"),
120                Var::UIFRel(BinRel::Gt) => write!(f, "gt"),
121                Var::UIFRel(BinRel::Ge) => write!(f, "ge"),
122                Var::UIFRel(BinRel::Lt) => write!(f, "lt"),
123                Var::UIFRel(BinRel::Le) => write!(f, "le"),
124                // these are actually not necessary because equality is interpreted for all sorts
125                Var::UIFRel(BinRel::Eq) => write!(f, "eq"),
126                Var::UIFRel(BinRel::Ne) => write!(f, "ne"),
127                Var::Underscore => write!(f, "_$"), // To avoid clashing with `_` used for `app (_ bv_op n)` for parametric SMT ops
128                Var::ConstGeneric(param) => {
129                    write!(f, "constgen${}${}", param.name, param.index)
130                }
131                Var::Param(param) => {
132                    write!(f, "reftgen${}${}", param.name, param.index)
133                }
134            }
135        }
136    }
137
138    #[derive(Clone, Hash, Debug)]
139    pub enum DataSort {
140        Tuple(usize),
141        Adt(AdtId),
142    }
143
144    impl Identifier for DataSort {
145        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146            match self {
147                DataSort::Tuple(arity) => {
148                    write!(f, "Tuple{arity}")
149                }
150                DataSort::Adt(adt_id) => {
151                    write!(f, "Adt{}", adt_id.as_u32())
152                }
153            }
154        }
155    }
156
157    #[derive(Hash, Clone, Debug)]
158    pub struct SymStr(pub Symbol);
159
160    impl FixpointFmt for SymStr {
161        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162            write!(f, "\"{}\"", self.0)
163        }
164    }
165
166    liquid_fixpoint::declare_types! {
167        type Sort = DataSort;
168        type KVar = KVid;
169        type Var = Var;
170        type Decimal = Real;
171        type String = SymStr;
172        type Tag = super::TagIdx;
173    }
174    pub use fixpoint_generated::*;
175}
176
177newtype_index! {
178    #[debug_format = "TagIdx({})"]
179    pub struct TagIdx {}
180}
181
182impl Serialize for TagIdx {
183    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
184        self.as_u32().serialize(serializer)
185    }
186}
187
188impl<'de> Deserialize<'de> for TagIdx {
189    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
190        let idx = usize::deserialize(deserializer)?;
191        Ok(TagIdx::from_u32(idx as u32))
192    }
193}
194
195/// Keep track of all the data sorts that we need to define in fixpoint to encode the constraint.
196#[derive(Default)]
197struct SortEncodingCtxt {
198    /// Set of all the tuple arities that need to be defined
199    tuples: UnordSet<usize>,
200    /// Set of all the [`AdtDefSortDef`](flux_middle::rty::AdtSortDef) that need to be declared as
201    /// Fixpoint data-decls
202    adt_sorts: FxIndexSet<DefId>,
203}
204
205impl SortEncodingCtxt {
206    fn sort_to_fixpoint(&mut self, sort: &rty::Sort) -> fixpoint::Sort {
207        match sort {
208            rty::Sort::Int => fixpoint::Sort::Int,
209            rty::Sort::Real => fixpoint::Sort::Real,
210            rty::Sort::Bool => fixpoint::Sort::Bool,
211            rty::Sort::Str => fixpoint::Sort::Str,
212            rty::Sort::Char => fixpoint::Sort::Int,
213            rty::Sort::BitVec(size) => fixpoint::Sort::BitVec(Box::new(bv_size_to_fixpoint(*size))),
214            // There's no way to declare opaque sorts in the fixpoint horn syntax so we encode user
215            // declared opaque sorts, type parameter sorts, and (unormalizable) type alias sorts as
216            // integers. Well-formedness should ensure values of these sorts are used "opaquely",
217            // i.e., the only values of these sorts are variables.
218            rty::Sort::App(rty::SortCtor::User { .. }, _)
219            | rty::Sort::Param(_)
220            | rty::Sort::Alias(rty::AliasKind::Opaque | rty::AliasKind::Projection, ..) => {
221                fixpoint::Sort::Int
222            }
223            rty::Sort::App(rty::SortCtor::Set, args) => {
224                let args = args.iter().map(|s| self.sort_to_fixpoint(s)).collect_vec();
225                fixpoint::Sort::App(fixpoint::SortCtor::Set, args)
226            }
227            rty::Sort::App(rty::SortCtor::Map, args) => {
228                let args = args.iter().map(|s| self.sort_to_fixpoint(s)).collect_vec();
229                fixpoint::Sort::App(fixpoint::SortCtor::Map, args)
230            }
231            rty::Sort::App(rty::SortCtor::Adt(sort_def), args) => {
232                if let Some(variant) = sort_def.opt_struct_variant() {
233                    let sorts = variant.field_sorts(args);
234                    // do not generate 1-tuples
235                    if let [sort] = &sorts[..] {
236                        self.sort_to_fixpoint(sort)
237                    } else {
238                        self.declare_tuple(sorts.len());
239                        let ctor = fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(sorts.len()));
240                        let args = sorts.iter().map(|s| self.sort_to_fixpoint(s)).collect_vec();
241                        fixpoint::Sort::App(ctor, args)
242                    }
243                } else {
244                    debug_assert!(args.is_empty());
245                    let adt_id = self.declare_adt(sort_def.did());
246                    fixpoint::Sort::App(
247                        fixpoint::SortCtor::Data(fixpoint::DataSort::Adt(adt_id)),
248                        vec![],
249                    )
250                }
251            }
252            rty::Sort::Tuple(sorts) => {
253                // do not generate 1-tuples
254                if let [sort] = &sorts[..] {
255                    self.sort_to_fixpoint(sort)
256                } else {
257                    self.declare_tuple(sorts.len());
258                    let ctor = fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(sorts.len()));
259                    let args = sorts.iter().map(|s| self.sort_to_fixpoint(s)).collect();
260                    fixpoint::Sort::App(ctor, args)
261                }
262            }
263            rty::Sort::Func(sort) => self.func_sort_to_fixpoint(sort),
264            rty::Sort::Var(k) => fixpoint::Sort::Var(k.index()),
265            rty::Sort::Err
266            | rty::Sort::Infer(_)
267            | rty::Sort::Loc
268            | rty::Sort::Alias(rty::AliasKind::Free, _) => {
269                tracked_span_bug!("unexpected sort `{sort:?}`")
270            }
271        }
272    }
273
274    fn func_sort_to_fixpoint(&mut self, fsort: &rty::PolyFuncSort) -> fixpoint::Sort {
275        let params = fsort.params().len();
276        let fsort = fsort.skip_binders();
277        let output = self.sort_to_fixpoint(fsort.output());
278        fixpoint::Sort::mk_func(
279            params,
280            fsort.inputs().iter().map(|s| self.sort_to_fixpoint(s)),
281            output,
282        )
283    }
284
285    fn declare_tuple(&mut self, arity: usize) {
286        self.tuples.insert(arity);
287    }
288
289    pub fn declare_adt(&mut self, did: DefId) -> AdtId {
290        if let Some(idx) = self.adt_sorts.get_index_of(&did) {
291            AdtId::from_usize(idx)
292        } else {
293            let adt_id = AdtId::from_usize(self.adt_sorts.len());
294            self.adt_sorts.insert(did);
295            adt_id
296        }
297    }
298
299    fn append_adt_decls(
300        genv: GlobalEnv,
301        adt_sorts: FxIndexSet<DefId>,
302        decls: &mut Vec<fixpoint::DataDecl>,
303    ) -> QueryResult {
304        for (idx, adt_def_id) in adt_sorts.iter().enumerate() {
305            let adt_id = AdtId::from_usize(idx);
306            let adt_sort_def = genv.adt_sort_def_of(adt_def_id)?;
307            decls.push(fixpoint::DataDecl {
308                name: fixpoint::DataSort::Adt(adt_id),
309                vars: adt_sort_def.param_count(),
310                ctors: adt_sort_def
311                    .variants()
312                    .iter_enumerated()
313                    .map(|(idx, variant)| {
314                        debug_assert_eq!(variant.fields(), 0);
315                        fixpoint::DataCtor {
316                            name: fixpoint::Var::DataCtor(adt_id, idx),
317                            fields: vec![],
318                        }
319                    })
320                    .collect(),
321            });
322        }
323        Ok(())
324    }
325
326    fn append_tuple_decls(tuples: UnordSet<usize>, decls: &mut Vec<fixpoint::DataDecl>) {
327        decls.extend(
328            tuples
329                .into_items()
330                .into_sorted_stable_ord()
331                .into_iter()
332                .map(|arity| {
333                    fixpoint::DataDecl {
334                        name: fixpoint::DataSort::Tuple(arity),
335                        vars: arity,
336                        ctors: vec![fixpoint::DataCtor {
337                            name: fixpoint::Var::TupleCtor { arity },
338                            fields: (0..(arity as u32))
339                                .map(|field| {
340                                    fixpoint::DataField {
341                                        name: fixpoint::Var::TupleProj { arity, field },
342                                        sort: fixpoint::Sort::Var(field as usize),
343                                    }
344                                })
345                                .collect(),
346                        }],
347                    }
348                }),
349        );
350    }
351
352    fn into_data_decls(self, genv: GlobalEnv) -> QueryResult<Vec<fixpoint::DataDecl>> {
353        let mut decls = vec![];
354        Self::append_tuple_decls(self.tuples, &mut decls);
355        Self::append_adt_decls(genv, self.adt_sorts, &mut decls)?;
356        Ok(decls)
357    }
358}
359
360fn bv_size_to_fixpoint(size: rty::BvSize) -> fixpoint::Sort {
361    match size {
362        rty::BvSize::Fixed(size) => fixpoint::Sort::BvSize(size),
363        rty::BvSize::Param(_var) => {
364            // I think we could encode the size as a sort variable, but this would require some care
365            // because smtlib doesn't really support parametric sizes. Fixpoint is probably already
366            // too liberal about this and it'd be easy to make it crash.
367            // fixpoint::Sort::Var(var.index)
368            bug!("unexpected parametric bit-vector size")
369        }
370        rty::BvSize::Infer(_) => bug!("unexpected infer variable for bit-vector size"),
371    }
372}
373
374type FunDefMap = FxIndexMap<FluxDefId, fixpoint::Var>;
375type ConstMap<'tcx> = FxIndexMap<ConstKey<'tcx>, fixpoint::ConstDecl>;
376
377#[derive(Eq, Hash, PartialEq)]
378enum ConstKey<'tcx> {
379    Uif(FluxDefId),
380    RustConst(DefId),
381    Alias(FluxDefId, rustc_middle::ty::GenericArgsRef<'tcx>),
382    Lambda(Lambda),
383    PrimOp(rty::BinOp),
384    Cast(rty::Sort, rty::Sort),
385}
386
387pub struct FixpointCtxt<'genv, 'tcx, T: Eq + Hash> {
388    comments: Vec<String>,
389    genv: GlobalEnv<'genv, 'tcx>,
390    kvars: KVarGen,
391    scx: SortEncodingCtxt,
392    kcx: KVarEncodingCtxt,
393    ecx: ExprEncodingCtxt<'genv, 'tcx>,
394    tags: IndexVec<TagIdx, T>,
395    tags_inv: UnordMap<T, TagIdx>,
396}
397
398pub type FixQueryCache = QueryCache<FixpointResult<TagIdx>>;
399
400impl<'genv, 'tcx, Tag> FixpointCtxt<'genv, 'tcx, Tag>
401where
402    Tag: std::hash::Hash + Eq + Copy,
403{
404    pub fn new(genv: GlobalEnv<'genv, 'tcx>, def_id: MaybeExternId, kvars: KVarGen) -> Self {
405        Self {
406            comments: vec![],
407            kvars,
408            scx: SortEncodingCtxt::default(),
409            genv,
410            ecx: ExprEncodingCtxt::new(genv, def_id),
411            kcx: Default::default(),
412            tags: IndexVec::new(),
413            tags_inv: Default::default(),
414        }
415    }
416
417    pub fn check(
418        mut self,
419        cache: &mut FixQueryCache,
420        constraint: fixpoint::Constraint,
421        kind: FixpointQueryKind,
422        scrape_quals: bool,
423        solver: SmtSolver,
424    ) -> QueryResult<Vec<Tag>> {
425        // skip checking trivial constraints
426        if !constraint.is_concrete() {
427            self.ecx.errors.into_result()?;
428            return Ok(vec![]);
429        }
430        let def_span = self.ecx.def_span();
431        let def_id = self.ecx.def_id;
432
433        let kvars = self.kcx.into_fixpoint();
434
435        let (define_funs, define_constants) = self.ecx.define_funs(def_id, &mut self.scx)?;
436        let qualifiers = self.ecx.qualifiers_for(def_id.local_id(), &mut self.scx)?;
437
438        // Assuming values should happen after all encoding is done so we are sure we've collected
439        // all constants.
440        let constraint = self.ecx.assume_const_values(constraint, &mut self.scx)?;
441
442        let mut constants = self.ecx.const_map.into_values().collect_vec();
443        constants.extend(define_constants);
444
445        for rel in fixpoint::BinRel::INEQUALITIES {
446            // ∀a. a -> a -> bool
447            let sort = fixpoint::Sort::mk_func(
448                1,
449                [fixpoint::Sort::Var(0), fixpoint::Sort::Var(0)],
450                fixpoint::Sort::Bool,
451            );
452            constants.push(fixpoint::ConstDecl {
453                name: fixpoint::Var::UIFRel(rel),
454                sort,
455                comment: None,
456            });
457        }
458
459        // We are done encoding expressions. Check if there are any errors.
460        self.ecx.errors.into_result()?;
461
462        let task = fixpoint::Task {
463            comments: self.comments,
464            constants,
465            kvars,
466            define_funs,
467            constraint,
468            qualifiers,
469            scrape_quals,
470            solver,
471            data_decls: self.scx.into_data_decls(self.genv)?,
472        };
473        if config::dump_constraint() {
474            dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), "smt2", &task).unwrap();
475        }
476
477        match Self::run_task_with_cache(self.genv, task, def_id.resolved_id(), kind, cache) {
478            FixpointResult::Safe(_) => Ok(vec![]),
479            FixpointResult::Unsafe(_, errors) => {
480                Ok(errors
481                    .into_iter()
482                    .map(|err| self.tags[err.tag])
483                    .unique()
484                    .collect_vec())
485            }
486            FixpointResult::Crash(err) => span_bug!(def_span, "fixpoint crash: {err:?}"),
487        }
488    }
489
490    fn run_task_with_cache(
491        genv: GlobalEnv,
492        task: fixpoint::Task,
493        def_id: DefId,
494        kind: FixpointQueryKind,
495        cache: &mut FixQueryCache,
496    ) -> FixpointResult<TagIdx> {
497        let key = kind.task_key(genv.tcx(), def_id);
498
499        let hash = task.hash_with_default();
500
501        if config::is_cache_enabled()
502            && let Some(result) = cache.lookup(&key, hash)
503        {
504            return result.clone();
505        }
506        let result = timings::time_it(TimingKind::FixpointQuery(def_id, kind), || {
507            task.run()
508                .unwrap_or_else(|err| tracked_span_bug!("failed to run fixpoint: {err}"))
509        });
510
511        if config::is_cache_enabled() {
512            cache.insert(key, hash, result.clone());
513        }
514        result
515    }
516
517    fn tag_idx(&mut self, tag: Tag) -> TagIdx
518    where
519        Tag: std::fmt::Debug,
520    {
521        *self.tags_inv.entry(tag).or_insert_with(|| {
522            let idx = self.tags.push(tag);
523            self.comments.push(format!("Tag {idx}: {tag:?}"));
524            idx
525        })
526    }
527
528    pub(crate) fn with_name_map<R>(
529        &mut self,
530        name: rty::Name,
531        f: impl FnOnce(&mut Self, fixpoint::LocalVar) -> R,
532    ) -> R {
533        let fresh = self.ecx.local_var_env.insert_fvar_map(name);
534        let r = f(self, fresh);
535        self.ecx.local_var_env.remove_fvar_map(name);
536        r
537    }
538
539    pub(crate) fn sort_to_fixpoint(&mut self, sort: &rty::Sort) -> fixpoint::Sort {
540        self.scx.sort_to_fixpoint(sort)
541    }
542
543    pub(crate) fn var_to_fixpoint(&self, var: &rty::Var) -> fixpoint::Var {
544        self.ecx.var_to_fixpoint(var)
545    }
546
547    /// Encodes an expression in head position as a [`fixpoint::Constraint`] "peeling out"
548    /// implications and foralls.
549    ///
550    /// [`fixpoint::Constraint`]: liquid_fixpoint::Constraint
551    pub(crate) fn head_to_fixpoint(
552        &mut self,
553        expr: &rty::Expr,
554        mk_tag: impl Fn(Option<ESpan>) -> Tag + Copy,
555    ) -> QueryResult<fixpoint::Constraint>
556    where
557        Tag: std::fmt::Debug,
558    {
559        match expr.kind() {
560            rty::ExprKind::BinaryOp(rty::BinOp::And, ..) => {
561                // avoid creating nested conjunctions
562                let cstrs = expr
563                    .flatten_conjs()
564                    .into_iter()
565                    .map(|e| self.head_to_fixpoint(e, mk_tag))
566                    .try_collect()?;
567                Ok(fixpoint::Constraint::conj(cstrs))
568            }
569            rty::ExprKind::BinaryOp(rty::BinOp::Imp, e1, e2) => {
570                let (bindings, assumption) = self.assumption_to_fixpoint(e1)?;
571                let cstr = self.head_to_fixpoint(e2, mk_tag)?;
572                Ok(fixpoint::Constraint::foralls(bindings, mk_implies(assumption, cstr)))
573            }
574            rty::ExprKind::KVar(kvar) => {
575                let mut bindings = vec![];
576                let pred = self.kvar_to_fixpoint(kvar, &mut bindings)?;
577                Ok(fixpoint::Constraint::foralls(bindings, fixpoint::Constraint::Pred(pred, None)))
578            }
579            rty::ExprKind::ForAll(pred) => {
580                self.ecx
581                    .local_var_env
582                    .push_layer_with_fresh_names(pred.vars().len());
583                let cstr = self.head_to_fixpoint(pred.as_ref().skip_binder(), mk_tag)?;
584                let vars = self.ecx.local_var_env.pop_layer();
585
586                let bindings = iter::zip(vars, pred.vars())
587                    .map(|(var, kind)| {
588                        fixpoint::Bind {
589                            name: var.into(),
590                            sort: self.scx.sort_to_fixpoint(kind.expect_sort()),
591                            pred: fixpoint::Pred::TRUE,
592                        }
593                    })
594                    .collect_vec();
595
596                Ok(fixpoint::Constraint::foralls(bindings, cstr))
597            }
598            _ => {
599                let tag_idx = self.tag_idx(mk_tag(expr.span()));
600                let pred = fixpoint::Pred::Expr(self.ecx.expr_to_fixpoint(expr, &mut self.scx)?);
601                Ok(fixpoint::Constraint::Pred(pred, Some(tag_idx)))
602            }
603        }
604    }
605
606    /// Encodes an expression in assumptive position as a [`fixpoint::Pred`]. Returns the encoded
607    /// predicate and a list of bindings produced by ANF-ing kvars.
608    ///
609    /// [`fixpoint::Pred`]: liquid_fixpoint::Pred
610    pub(crate) fn assumption_to_fixpoint(
611        &mut self,
612        pred: &rty::Expr,
613    ) -> QueryResult<(Vec<fixpoint::Bind>, fixpoint::Pred)> {
614        let mut bindings = vec![];
615        let mut preds = vec![];
616        self.assumption_to_fixpoint_aux(pred, &mut bindings, &mut preds)?;
617        Ok((bindings, fixpoint::Pred::and(preds)))
618    }
619
620    /// Auxiliary function to merge nested conjunctions in a single predicate
621    fn assumption_to_fixpoint_aux(
622        &mut self,
623        expr: &rty::Expr,
624        bindings: &mut Vec<fixpoint::Bind>,
625        preds: &mut Vec<fixpoint::Pred>,
626    ) -> QueryResult {
627        match expr.kind() {
628            rty::ExprKind::BinaryOp(rty::BinOp::And, e1, e2) => {
629                self.assumption_to_fixpoint_aux(e1, bindings, preds)?;
630                self.assumption_to_fixpoint_aux(e2, bindings, preds)?;
631            }
632            rty::ExprKind::KVar(kvar) => {
633                preds.push(self.kvar_to_fixpoint(kvar, bindings)?);
634            }
635            rty::ExprKind::ForAll(_) => {
636                // If a forall appears in assumptive position replace it with true. This is sound
637                // because we are weakening the context, i.e., anything true without the assumption
638                // should remain true after adding it. Note that this relies on the predicate
639                // appearing negatively. This is guaranteed by the surface syntax because foralls
640                // can only appear at the top-level in a requires clause.
641                preds.push(fixpoint::Pred::TRUE);
642            }
643            _ => {
644                preds.push(fixpoint::Pred::Expr(self.ecx.expr_to_fixpoint(expr, &mut self.scx)?));
645            }
646        }
647        Ok(())
648    }
649
650    fn kvar_to_fixpoint(
651        &mut self,
652        kvar: &rty::KVar,
653        bindings: &mut Vec<fixpoint::Bind>,
654    ) -> QueryResult<fixpoint::Pred> {
655        let decl = self.kvars.get(kvar.kvid);
656        let kvids = self.kcx.encode(kvar.kvid, decl, &mut self.scx);
657
658        let all_args = iter::zip(&kvar.args, &decl.sorts)
659            .map(|(arg, sort)| self.ecx.imm(arg, sort, &mut self.scx, bindings))
660            .try_collect_vec()?;
661
662        // Fixpoint doesn't support kvars without arguments, which we do generate sometimes. To get
663        // around it, we encode `$k()` as ($k 0), or more precisely `(forall ((x int) (= x 0)) ... ($k x)`
664        // after ANF-ing.
665        if all_args.is_empty() {
666            let fresh = self.ecx.local_var_env.fresh_name();
667            let var = fixpoint::Var::Local(fresh);
668            bindings.push(fixpoint::Bind {
669                name: fresh.into(),
670                sort: fixpoint::Sort::Int,
671                pred: fixpoint::Pred::Expr(fixpoint::Expr::eq(
672                    fixpoint::Expr::Var(var),
673                    fixpoint::Expr::int(0),
674                )),
675            });
676            return Ok(fixpoint::Pred::KVar(kvids[0], vec![var]));
677        }
678
679        let kvars = kvids
680            .iter()
681            .enumerate()
682            .map(|(i, kvid)| {
683                let args = all_args[i..].to_vec();
684                fixpoint::Pred::KVar(*kvid, args)
685            })
686            .collect_vec();
687
688        Ok(fixpoint::Pred::And(kvars))
689    }
690}
691
692fn const_to_fixpoint(cst: rty::Constant) -> fixpoint::Expr {
693    match cst {
694        rty::Constant::Int(i) => {
695            if i.is_negative() {
696                fixpoint::Expr::Neg(Box::new(fixpoint::Constant::Numeral(i.abs()).into()))
697            } else {
698                fixpoint::Constant::Numeral(i.abs()).into()
699            }
700        }
701        rty::Constant::Real(r) => fixpoint::Constant::Decimal(r).into(),
702        rty::Constant::Bool(b) => fixpoint::Constant::Boolean(b).into(),
703        rty::Constant::Char(c) => fixpoint::Constant::Numeral(u128::from(c)).into(),
704        rty::Constant::Str(s) => fixpoint::Constant::String(fixpoint::SymStr(s)).into(),
705        rty::Constant::BitVec(i, size) => fixpoint::Constant::BitVec(i, size).into(),
706    }
707}
708
709struct FixpointKVar {
710    sorts: Vec<fixpoint::Sort>,
711    orig: rty::KVid,
712}
713
714/// During encoding into fixpoint we generate multiple fixpoint kvars per kvar in flux. A
715/// [`KVarEncodingCtxt`] is used to keep track of the state needed for this.
716#[derive(Default)]
717struct KVarEncodingCtxt {
718    /// List of all kvars that need to be defined in fixpoint
719    kvars: IndexVec<fixpoint::KVid, FixpointKVar>,
720    /// A mapping from [`rty::KVid`] to the list of [`fixpoint::KVid`]s encoding the kvar.
721    map: UnordMap<rty::KVid, Vec<fixpoint::KVid>>,
722}
723
724impl KVarEncodingCtxt {
725    fn encode(
726        &mut self,
727        kvid: rty::KVid,
728        decl: &KVarDecl,
729        scx: &mut SortEncodingCtxt,
730    ) -> &[fixpoint::KVid] {
731        self.map.entry(kvid).or_insert_with(|| {
732            let all_args = decl
733                .sorts
734                .iter()
735                .map(|s| scx.sort_to_fixpoint(s))
736                .collect_vec();
737
738            // See comment in `kvar_to_fixpoint`
739            if all_args.is_empty() {
740                let sorts = vec![fixpoint::Sort::Int];
741                let kvid = self.kvars.push(FixpointKVar::new(sorts, kvid));
742                return vec![kvid];
743            }
744
745            match decl.encoding {
746                KVarEncoding::Single => {
747                    let kvid = self.kvars.push(FixpointKVar::new(all_args, kvid));
748                    vec![kvid]
749                }
750                KVarEncoding::Conj => {
751                    let n = usize::max(decl.self_args, 1);
752                    (0..n)
753                        .map(|i| {
754                            let sorts = all_args[i..].to_vec();
755                            self.kvars.push(FixpointKVar::new(sorts, kvid))
756                        })
757                        .collect_vec()
758                }
759            }
760        })
761    }
762
763    fn into_fixpoint(self) -> Vec<fixpoint::KVarDecl> {
764        self.kvars
765            .into_iter_enumerated()
766            .map(|(kvid, kvar)| {
767                fixpoint::KVarDecl::new(kvid, kvar.sorts, format!("orig: {:?}", kvar.orig))
768            })
769            .collect()
770    }
771}
772
773/// Environment used to map from [`rty::Var`] to a [`fixpoint::LocalVar`].
774struct LocalVarEnv {
775    local_var_gen: IndexGen<fixpoint::LocalVar>,
776    fvars: UnordMap<rty::Name, fixpoint::LocalVar>,
777    /// Layers of late bound variables
778    layers: Vec<Vec<fixpoint::LocalVar>>,
779}
780
781impl LocalVarEnv {
782    fn new() -> Self {
783        Self { local_var_gen: IndexGen::new(), fvars: Default::default(), layers: Vec::new() }
784    }
785
786    // This doesn't require to be mutable because `IndexGen` uses atomics, but we make it mutable
787    // to better declare the intent.
788    fn fresh_name(&mut self) -> fixpoint::LocalVar {
789        self.local_var_gen.fresh()
790    }
791
792    fn insert_fvar_map(&mut self, name: rty::Name) -> fixpoint::LocalVar {
793        let fresh = self.fresh_name();
794        self.fvars.insert(name, fresh);
795        fresh
796    }
797
798    fn remove_fvar_map(&mut self, name: rty::Name) {
799        self.fvars.remove(&name);
800    }
801
802    /// Push a layer of bound variables assigning a fresh [`fixpoint::LocalVar`] to each one
803    fn push_layer_with_fresh_names(&mut self, count: usize) {
804        let layer = (0..count).map(|_| self.fresh_name()).collect();
805        self.layers.push(layer);
806    }
807
808    fn pop_layer(&mut self) -> Vec<fixpoint::LocalVar> {
809        self.layers.pop().unwrap()
810    }
811
812    fn get_fvar(&self, name: rty::Name) -> Option<fixpoint::LocalVar> {
813        self.fvars.get(&name).copied()
814    }
815
816    fn get_late_bvar(&self, debruijn: DebruijnIndex, var: BoundVar) -> Option<fixpoint::LocalVar> {
817        let depth = self.layers.len().checked_sub(debruijn.as_usize() + 1)?;
818        self.layers[depth].get(var.as_usize()).copied()
819    }
820}
821
822impl FixpointKVar {
823    fn new(sorts: Vec<fixpoint::Sort>, orig: rty::KVid) -> Self {
824        Self { sorts, orig }
825    }
826}
827
/// Generator of fresh kvars. It records the declaration ([`KVarDecl`]) of every kvar it generates,
/// indexed by the [`rty::KVid`] assigned to it.
pub struct KVarGen {
    /// Declarations of the kvars generated so far.
    kvars: IndexVec<rty::KVid, KVarDecl>,
    /// If true, generate dummy [holes] instead of kvars. Used during shape mode to avoid generating
    /// unnecessary kvars.
    ///
    /// [holes]: rty::ExprKind::Hole
    dummy: bool,
}
836
837impl KVarGen {
838    pub(crate) fn new(dummy: bool) -> Self {
839        Self { kvars: IndexVec::new(), dummy }
840    }
841
842    fn get(&self, kvid: rty::KVid) -> &KVarDecl {
843        &self.kvars[kvid]
844    }
845
846    /// Generate a fresh [kvar] under several layers of [binders]. Each layer may contain any kind
847    /// of bound variable, but variables that are not of kind [`BoundVariableKind::Refine`] will
848    /// be filtered out.
849    ///
850    /// The variables bound in the last layer (last element of the `binders` slice) is expected to
851    /// have only [`BoundVariableKind::Refine`] and all its elements are used as the [self arguments].
852    /// The rest of the binders are appended to the `scope`.
853    ///
854    /// Note that the returned expression will have escaping variables and it is up to the caller to
855    /// put it under an appropriate number of binders.
856    ///
857    /// Prefer using [`InferCtxt::fresh_kvar`] when possible.
858    ///
859    /// [binders]: rty::Binder
860    /// [kvar]: rty::KVar
861    /// [`InferCtxt::fresh_kvar`]: crate::infer::InferCtxt::fresh_kvar
862    /// [self arguments]: rty::KVar::self_args
863    /// [`BoundVariableKind::Refine`]: rty::BoundVariableKind::Refine
864    pub fn fresh(
865        &mut self,
866        binders: &[rty::BoundVariableKinds],
867        scope: impl IntoIterator<Item = (rty::Var, rty::Sort)>,
868        encoding: KVarEncoding,
869    ) -> rty::Expr {
870        if self.dummy {
871            return rty::Expr::hole(rty::HoleKind::Pred);
872        }
873
874        let args = itertools::chain(
875            binders.iter().rev().enumerate().flat_map(|(level, vars)| {
876                let debruijn = DebruijnIndex::from_usize(level);
877                vars.iter()
878                    .cloned()
879                    .enumerate()
880                    .flat_map(move |(idx, var)| {
881                        if let rty::BoundVariableKind::Refine(sort, _, kind) = var {
882                            let br = rty::BoundReft { var: BoundVar::from_usize(idx), kind };
883                            Some((rty::Var::Bound(debruijn, br), sort))
884                        } else {
885                            None
886                        }
887                    })
888            }),
889            scope,
890        );
891        let [.., last] = binders else {
892            return self.fresh_inner(0, [], encoding);
893        };
894        let num_self_args = last
895            .iter()
896            .filter(|var| matches!(var, rty::BoundVariableKind::Refine(..)))
897            .count();
898        self.fresh_inner(num_self_args, args, encoding)
899    }
900
901    fn fresh_inner<A>(&mut self, self_args: usize, args: A, encoding: KVarEncoding) -> rty::Expr
902    where
903        A: IntoIterator<Item = (rty::Var, rty::Sort)>,
904    {
905        // asset last one has things
906        let mut sorts = vec![];
907        let mut exprs = vec![];
908
909        let mut flattened_self_args = 0;
910        for (i, (var, sort)) in args.into_iter().enumerate() {
911            let is_self_arg = i < self_args;
912            let var = var.to_expr();
913            sort.walk(|sort, proj| {
914                if !matches!(sort, rty::Sort::Loc) {
915                    flattened_self_args += is_self_arg as usize;
916                    sorts.push(sort.clone());
917                    exprs.push(rty::Expr::field_projs(&var, proj));
918                }
919            });
920        }
921
922        let kvid = self
923            .kvars
924            .push(KVarDecl { self_args: flattened_self_args, sorts, encoding });
925
926        let kvar = rty::KVar::new(kvid, flattened_self_args, exprs);
927        rty::Expr::kvar(kvar)
928    }
929}
930
/// Declaration of a kvar as recorded by [`KVarGen`].
#[derive(Clone)]
struct KVarDecl {
    /// Number of leading entries in `sorts` that correspond to self arguments (counted after
    /// flattening the sorts of the arguments).
    self_args: usize,
    /// Sort of each (flattened) argument of the kvar.
    sorts: Vec<rty::Sort>,
    /// How the kvar is encoded in the fixpoint constraint.
    encoding: KVarEncoding,
}
937
/// How an [`rty::KVar`] is encoded in the fixpoint constraint
///
/// In the examples below, `a0, ...` stand for the self arguments of the kvar and `b0, ...` for
/// its scope.
#[derive(Clone, Copy)]
pub enum KVarEncoding {
    /// Generate a single kvar appending the self arguments and the scope, i.e.,
    /// a kvar `$k(a0, ...)[b0, ...]` becomes `$k(a0, ..., b0, ...)` in the fixpoint constraint.
    Single,
    /// Generate a conjunction of kvars, one per argument in [`rty::KVar::args`].
    /// Concretely, a kvar `$k(a0, a1, ..., an)[b0, ...]` becomes
    /// `$k0(a0, a1, ..., an, b0, ...) ∧ $k1(a1, ..., an, b0, ...) ∧ ... ∧ $kn(an, b0, ...)`
    Conj,
}
949
950impl std::fmt::Display for TagIdx {
951    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
952        write!(f, "{}", self.as_u32())
953    }
954}
955
956impl std::str::FromStr for TagIdx {
957    type Err = std::num::ParseIntError;
958
959    fn from_str(s: &str) -> Result<Self, Self::Err> {
960        Ok(Self::from_u32(s.parse()?))
961    }
962}
963
/// Context used to encode an [`rty::Expr`] into a [`fixpoint::Expr`]. Besides the mapping for
/// local variables, it keeps track of the global constants and functions that had to be declared
/// for the expressions encoded so far.
struct ExprEncodingCtxt<'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    /// Mapping from free and late bound variables to fixpoint local variables.
    local_var_env: LocalVarEnv,
    /// Generator of fresh global variables, used to name declared constants and functions.
    global_var_gen: IndexGen<fixpoint::GlobalVar>,
    /// Uninterpreted constants declared so far, keyed by what they stand for (a cast, a primitive
    /// operation, an uninterpreted function, a Rust constant, an alias refinement or a lambda).
    const_map: ConstMap<'tcx>,
    /// Names assigned to the Flux function definitions that have been declared so far.
    fun_def_map: FunDefMap,
    /// Error collector created from the session of `genv`.
    errors: Errors<'genv>,
    /// Id of the item being checked. This is a [`MaybeExternId`] because we may be encoding
    /// invariants for an extern spec on an enum.
    def_id: MaybeExternId,
    /// Inference context passed to the structural normalization of expressions.
    infcx: rustc_infer::infer::InferCtxt<'tcx>,
}
976
977impl<'genv, 'tcx> ExprEncodingCtxt<'genv, 'tcx> {
978    fn new(genv: GlobalEnv<'genv, 'tcx>, def_id: MaybeExternId) -> Self {
979        Self {
980            genv,
981            local_var_env: LocalVarEnv::new(),
982            global_var_gen: IndexGen::new(),
983            const_map: Default::default(),
984            fun_def_map: Default::default(),
985            errors: Errors::new(genv.sess()),
986            def_id,
987            infcx: genv
988                .tcx()
989                .infer_ctxt()
990                .with_next_trait_solver(true)
991                .build(TypingMode::non_body_analysis()),
992        }
993    }
994
995    fn def_span(&self) -> Span {
996        self.genv.tcx().def_span(self.def_id)
997    }
998
999    fn var_to_fixpoint(&self, var: &rty::Var) -> fixpoint::Var {
1000        match var {
1001            rty::Var::Free(name) => {
1002                self.local_var_env
1003                    .get_fvar(*name)
1004                    .unwrap_or_else(|| {
1005                        span_bug!(self.def_span(), "no entry found for name: `{name:?}`")
1006                    })
1007                    .into()
1008            }
1009            rty::Var::Bound(debruijn, breft) => {
1010                self.local_var_env
1011                    .get_late_bvar(*debruijn, breft.var)
1012                    .unwrap_or_else(|| {
1013                        span_bug!(self.def_span(), "no entry found for late bound var: `{breft:?}`")
1014                    })
1015                    .into()
1016            }
1017            rty::Var::ConstGeneric(param) => fixpoint::Var::ConstGeneric(*param),
1018            rty::Var::EarlyParam(param) => fixpoint::Var::Param(*param),
1019            rty::Var::EVar(_) => {
1020                span_bug!(self.def_span(), "unexpected evar: `{var:?}`")
1021            }
1022        }
1023    }
1024
1025    fn variant_to_fixpoint(
1026        &self,
1027        scx: &mut SortEncodingCtxt,
1028        enum_def_id: &DefId,
1029        idx: VariantIdx,
1030    ) -> fixpoint::Expr {
1031        let adt_id = scx.declare_adt(*enum_def_id);
1032        let var = fixpoint::Var::DataCtor(adt_id, idx);
1033        fixpoint::Expr::Var(var)
1034    }
1035
1036    fn fields_to_fixpoint(
1037        &mut self,
1038        flds: &[rty::Expr],
1039        scx: &mut SortEncodingCtxt,
1040    ) -> QueryResult<fixpoint::Expr> {
1041        // do not generate 1-tuples
1042        if let [fld] = flds {
1043            self.expr_to_fixpoint(fld, scx)
1044        } else {
1045            scx.declare_tuple(flds.len());
1046            let ctor = fixpoint::Expr::Var(fixpoint::Var::TupleCtor { arity: flds.len() });
1047            let args = flds
1048                .iter()
1049                .map(|fld| self.expr_to_fixpoint(fld, scx))
1050                .try_collect()?;
1051            Ok(fixpoint::Expr::App(Box::new(ctor), args))
1052        }
1053    }
1054
1055    fn internal_func_to_fixpoint(
1056        &mut self,
1057        internal_func: &InternalFuncKind,
1058        sort_args: &[rty::SortArg],
1059        args: &[rty::Expr],
1060        scx: &mut SortEncodingCtxt,
1061    ) -> QueryResult<fixpoint::Expr> {
1062        match internal_func {
1063            InternalFuncKind::Val(op) => {
1064                let func = fixpoint::Expr::Var(self.define_const_for_prim_op(op, scx));
1065                let args = self.exprs_to_fixpoint(args, scx)?;
1066                Ok(fixpoint::Expr::App(Box::new(func), args))
1067            }
1068            InternalFuncKind::Rel(op) => {
1069                let expr = if let Some(prim_rel) = self.genv.prim_rel_for(op)? {
1070                    prim_rel.body.replace_bound_refts(args)
1071                } else {
1072                    rty::Expr::tt()
1073                };
1074                self.expr_to_fixpoint(&expr, scx)
1075            }
1076            InternalFuncKind::Cast => {
1077                let [rty::SortArg::Sort(from), rty::SortArg::Sort(to)] = &sort_args else {
1078                    span_bug!(self.def_span(), "unexpected cast")
1079                };
1080                match from.cast_kind(to) {
1081                    rty::CastKind::Identity => self.expr_to_fixpoint(&args[0], scx),
1082                    rty::CastKind::BoolToInt => {
1083                        Ok(fixpoint::Expr::IfThenElse(Box::new([
1084                            self.expr_to_fixpoint(&args[0], scx)?,
1085                            fixpoint::Expr::int(1),
1086                            fixpoint::Expr::int(0),
1087                        ])))
1088                    }
1089                    rty::CastKind::IntoUnit => self.expr_to_fixpoint(&rty::Expr::unit(), scx),
1090                    rty::CastKind::Uninterpreted => {
1091                        let func = fixpoint::Expr::Var(self.define_const_for_cast(from, to, scx));
1092                        let args = self.exprs_to_fixpoint(args, scx)?;
1093                        Ok(fixpoint::Expr::App(Box::new(func), args))
1094                    }
1095                }
1096            }
1097        }
1098    }
1099
1100    fn structurally_normalize_expr(&self, expr: &rty::Expr) -> QueryResult<rty::Expr> {
1101        structurally_normalize_expr(self.genv, self.def_id.resolved_id(), &self.infcx, expr)
1102    }
1103
    /// Encodes an [`rty::Expr`] into a [`fixpoint::Expr`].
    ///
    /// The expression is structurally normalized first and then encoded by cases on its kind.
    /// Constants and functions that need a declaration in the fixpoint constraint (Rust constants,
    /// uninterpreted functions, alias refinements, lambdas, Flux function definitions) are
    /// declared on demand.
    ///
    /// It is a bug to call this function with an expression containing holes, kvars, locals, path
    /// projections, `forall`s or an internal function outside of the function position of an
    /// application.
    fn expr_to_fixpoint(
        &mut self,
        expr: &rty::Expr,
        scx: &mut SortEncodingCtxt,
    ) -> QueryResult<fixpoint::Expr> {
        let expr = self.structurally_normalize_expr(expr)?;
        let e = match expr.kind() {
            rty::ExprKind::Var(var) => fixpoint::Expr::Var(self.var_to_fixpoint(var)),
            rty::ExprKind::Constant(c) => const_to_fixpoint(*c),
            rty::ExprKind::BinaryOp(op, e1, e2) => self.bin_op_to_fixpoint(op, e1, e2, scx)?,
            rty::ExprKind::UnaryOp(op, e) => self.un_op_to_fixpoint(*op, e, scx)?,
            rty::ExprKind::FieldProj(e, proj) => self.proj_to_fixpoint(e, *proj, scx)?,
            rty::ExprKind::Tuple(flds) => self.fields_to_fixpoint(flds, scx)?,
            // Struct constructors are encoded in the same way as tuples
            rty::ExprKind::Ctor(rty::Ctor::Struct(_), flds) => {
                self.fields_to_fixpoint(flds, scx)?
            }
            // Enum constructors are encoded as the data constructor of the variant; the fields
            // are not used.
            rty::ExprKind::Ctor(rty::Ctor::Enum(did, idx), _) => {
                self.variant_to_fixpoint(scx, did, *idx)
            }
            rty::ExprKind::ConstDefId(did) => {
                let var = self.define_const_for_rust_const(*did, scx);
                fixpoint::Expr::Var(var)
            }
            rty::ExprKind::App(func, sort_args, args) => {
                // Internal functions get a dedicated encoding that depends on their arguments
                if let rty::ExprKind::InternalFunc(func) = func.kind() {
                    self.internal_func_to_fixpoint(func, sort_args, args, scx)?
                } else {
                    let func = self.expr_to_fixpoint(func, scx)?;
                    let args = self.exprs_to_fixpoint(args, scx)?;
                    fixpoint::Expr::App(Box::new(func), args)
                }
            }
            rty::ExprKind::IfThenElse(p, e1, e2) => {
                fixpoint::Expr::IfThenElse(Box::new([
                    self.expr_to_fixpoint(p, scx)?,
                    self.expr_to_fixpoint(e1, scx)?,
                    self.expr_to_fixpoint(e2, scx)?,
                ]))
            }
            // An alias refinement is encoded as the application of an uninterpreted constant
            rty::ExprKind::Alias(alias_reft, args) => {
                let sort = self.genv.sort_of_assoc_reft(alias_reft.assoc_id)?;
                let sort = sort.instantiate_identity();
                let func =
                    fixpoint::Expr::Var(self.define_const_for_alias_reft(alias_reft, sort, scx));
                let args = args
                    .iter()
                    .map(|expr| self.expr_to_fixpoint(expr, scx))
                    .try_collect()?;
                fixpoint::Expr::App(Box::new(func), args)
            }
            // A lambda is encoded as an uninterpreted constant
            rty::ExprKind::Abs(lam) => {
                let var = self.define_const_for_lambda(lam, scx);
                fixpoint::Expr::Var(var)
            }
            rty::ExprKind::Let(init, body) => {
                debug_assert_eq!(body.vars().len(), 1);
                // The initializer is encoded outside the scope of the variable being bound
                let init = self.expr_to_fixpoint(init, scx)?;

                // The body is encoded under a layer containing a fresh name for the bound
                // variable. The layer must be popped before leaving this arm.
                self.local_var_env.push_layer_with_fresh_names(1);
                let body = self.expr_to_fixpoint(body.skip_binder_ref(), scx)?;
                let vars = self.local_var_env.pop_layer();

                fixpoint::Expr::Let(vars[0].into(), Box::new([init, body]))
            }
            rty::ExprKind::GlobalFunc(SpecFuncKind::Thy(itf)) => {
                fixpoint::Expr::Var(fixpoint::Var::Itf(*itf))
            }
            rty::ExprKind::GlobalFunc(SpecFuncKind::Uif(def_id)) => {
                fixpoint::Expr::Var(self.define_const_for_uif(*def_id, scx))
            }
            rty::ExprKind::GlobalFunc(SpecFuncKind::Def(def_id)) => {
                fixpoint::Expr::Var(self.declare_fun(*def_id))
            }
            // None of these are expected to reach the encoding
            rty::ExprKind::Hole(..)
            | rty::ExprKind::KVar(_)
            | rty::ExprKind::Local(_)
            | rty::ExprKind::PathProj(..)
            | rty::ExprKind::ForAll(_)
            | rty::ExprKind::InternalFunc(_) => {
                span_bug!(self.def_span(), "unexpected expr: `{expr:?}`")
            }
            // A bounded quantifier is unrolled: the body is instantiated with every value in the
            // range `rng.start..rng.end` and the instances are put in a conjunction (`forall`) or
            // a disjunction (`exists`).
            rty::ExprKind::BoundedQuant(kind, rng, body) => {
                let exprs = (rng.start..rng.end).map(|i| {
                    let arg = rty::Expr::constant(rty::Constant::from(i));
                    body.replace_bound_reft(&arg)
                });
                let expr = match kind {
                    flux_middle::fhir::QuantKind::Forall => rty::Expr::and_from_iter(exprs),
                    flux_middle::fhir::QuantKind::Exists => rty::Expr::or_from_iter(exprs),
                };
                self.expr_to_fixpoint(&expr, scx)?
            }
        };
        Ok(e)
    }
1199
1200    fn exprs_to_fixpoint<'b>(
1201        &mut self,
1202        exprs: impl IntoIterator<Item = &'b rty::Expr>,
1203        scx: &mut SortEncodingCtxt,
1204    ) -> QueryResult<Vec<fixpoint::Expr>> {
1205        exprs
1206            .into_iter()
1207            .map(|e| self.expr_to_fixpoint(e, scx))
1208            .try_collect()
1209    }
1210
1211    fn proj_to_fixpoint(
1212        &mut self,
1213        e: &rty::Expr,
1214        proj: rty::FieldProj,
1215        scx: &mut SortEncodingCtxt,
1216    ) -> QueryResult<fixpoint::Expr> {
1217        let arity = proj.arity(self.genv)?;
1218        // we encode 1-tuples as the single element inside so no projection necessary here
1219        if arity == 1 {
1220            self.expr_to_fixpoint(e, scx)
1221        } else {
1222            let field = proj.field_idx();
1223            scx.declare_tuple(arity);
1224            let proj = fixpoint::Var::TupleProj { arity, field };
1225            let proj = fixpoint::Expr::Var(proj);
1226            Ok(fixpoint::Expr::App(Box::new(proj), vec![self.expr_to_fixpoint(e, scx)?]))
1227        }
1228    }
1229
1230    fn un_op_to_fixpoint(
1231        &mut self,
1232        op: rty::UnOp,
1233        e: &rty::Expr,
1234        scx: &mut SortEncodingCtxt,
1235    ) -> QueryResult<fixpoint::Expr> {
1236        match op {
1237            rty::UnOp::Not => Ok(fixpoint::Expr::Not(Box::new(self.expr_to_fixpoint(e, scx)?))),
1238            rty::UnOp::Neg => Ok(fixpoint::Expr::Neg(Box::new(self.expr_to_fixpoint(e, scx)?))),
1239        }
1240    }
1241
1242    fn bv_rel_to_fixpoint(&self, rel: &fixpoint::BinRel) -> fixpoint::Expr {
1243        let itf = match rel {
1244            fixpoint::BinRel::Gt => fixpoint::ThyFunc::BvUgt,
1245            fixpoint::BinRel::Ge => fixpoint::ThyFunc::BvUge,
1246            fixpoint::BinRel::Lt => fixpoint::ThyFunc::BvUlt,
1247            fixpoint::BinRel::Le => fixpoint::ThyFunc::BvUle,
1248            _ => span_bug!(self.def_span(), "not a bitvector relation!"),
1249        };
1250        fixpoint::Expr::Var(fixpoint::Var::Itf(itf))
1251    }
1252
1253    fn bv_op_to_fixpoint(&self, op: &rty::BinOp) -> fixpoint::Expr {
1254        let itf = match op {
1255            rty::BinOp::Add(_) => fixpoint::ThyFunc::BvAdd,
1256            rty::BinOp::Sub(_) => fixpoint::ThyFunc::BvSub,
1257            rty::BinOp::Mul(_) => fixpoint::ThyFunc::BvMul,
1258            rty::BinOp::Div(_) => fixpoint::ThyFunc::BvUdiv,
1259            rty::BinOp::Mod(_) => fixpoint::ThyFunc::BvUrem,
1260            rty::BinOp::BitAnd => fixpoint::ThyFunc::BvAnd,
1261            rty::BinOp::BitOr => fixpoint::ThyFunc::BvOr,
1262            rty::BinOp::BitXor => fixpoint::ThyFunc::BvXor,
1263            rty::BinOp::BitShl => fixpoint::ThyFunc::BvShl,
1264            rty::BinOp::BitShr => fixpoint::ThyFunc::BvLshr,
1265            _ => span_bug!(self.def_span(), "not a bitvector operation!"),
1266        };
1267        fixpoint::Expr::Var(fixpoint::Var::Itf(itf))
1268    }
1269
1270    fn bin_op_to_fixpoint(
1271        &mut self,
1272        op: &rty::BinOp,
1273        e1: &rty::Expr,
1274        e2: &rty::Expr,
1275        scx: &mut SortEncodingCtxt,
1276    ) -> QueryResult<fixpoint::Expr> {
1277        let op = match op {
1278            rty::BinOp::Eq => {
1279                return Ok(fixpoint::Expr::Atom(
1280                    fixpoint::BinRel::Eq,
1281                    Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
1282                ));
1283            }
1284            rty::BinOp::Ne => {
1285                return Ok(fixpoint::Expr::Atom(
1286                    fixpoint::BinRel::Ne,
1287                    Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
1288                ));
1289            }
1290            rty::BinOp::Gt(sort) => {
1291                return self.bin_rel_to_fixpoint(sort, fixpoint::BinRel::Gt, e1, e2, scx);
1292            }
1293            rty::BinOp::Ge(sort) => {
1294                return self.bin_rel_to_fixpoint(sort, fixpoint::BinRel::Ge, e1, e2, scx);
1295            }
1296            rty::BinOp::Lt(sort) => {
1297                return self.bin_rel_to_fixpoint(sort, fixpoint::BinRel::Lt, e1, e2, scx);
1298            }
1299            rty::BinOp::Le(sort) => {
1300                return self.bin_rel_to_fixpoint(sort, fixpoint::BinRel::Le, e1, e2, scx);
1301            }
1302            rty::BinOp::And => {
1303                return Ok(fixpoint::Expr::And(vec![
1304                    self.expr_to_fixpoint(e1, scx)?,
1305                    self.expr_to_fixpoint(e2, scx)?,
1306                ]));
1307            }
1308            rty::BinOp::Or => {
1309                return Ok(fixpoint::Expr::Or(vec![
1310                    self.expr_to_fixpoint(e1, scx)?,
1311                    self.expr_to_fixpoint(e2, scx)?,
1312                ]));
1313            }
1314            rty::BinOp::Imp => {
1315                return Ok(fixpoint::Expr::Imp(Box::new([
1316                    self.expr_to_fixpoint(e1, scx)?,
1317                    self.expr_to_fixpoint(e2, scx)?,
1318                ])));
1319            }
1320            rty::BinOp::Iff => {
1321                return Ok(fixpoint::Expr::Iff(Box::new([
1322                    self.expr_to_fixpoint(e1, scx)?,
1323                    self.expr_to_fixpoint(e2, scx)?,
1324                ])));
1325            }
1326            rty::BinOp::Add(rty::Sort::BitVec(_))
1327            | rty::BinOp::Sub(rty::Sort::BitVec(_))
1328            | rty::BinOp::Mul(rty::Sort::BitVec(_))
1329            | rty::BinOp::Div(rty::Sort::BitVec(_))
1330            | rty::BinOp::Mod(rty::Sort::BitVec(_))
1331            | rty::BinOp::BitAnd
1332            | rty::BinOp::BitOr
1333            | rty::BinOp::BitXor
1334            | rty::BinOp::BitShl
1335            | rty::BinOp::BitShr => {
1336                let bv_func = self.bv_op_to_fixpoint(op);
1337                return Ok(fixpoint::Expr::App(
1338                    Box::new(bv_func),
1339                    vec![self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?],
1340                ));
1341            }
1342            rty::BinOp::Add(_) => fixpoint::BinOp::Add,
1343            rty::BinOp::Sub(_) => fixpoint::BinOp::Sub,
1344            rty::BinOp::Mul(_) => fixpoint::BinOp::Mul,
1345            rty::BinOp::Div(_) => fixpoint::BinOp::Div,
1346            rty::BinOp::Mod(_) => fixpoint::BinOp::Mod,
1347        };
1348        Ok(fixpoint::Expr::BinaryOp(
1349            op,
1350            Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
1351        ))
1352    }
1353
1354    /// A binary relation is encoded as a structurally recursive relation between aggregate sorts.
1355    /// For "leaf" expressions, we encode them as an interpreted relation if the sort supports it,
1356    /// otherwise we use an uninterpreted function. For example, consider the following relation
1357    /// between two tuples of sort `(int, int -> int)`
1358    /// ```text
1359    /// (0, λv. v + 1) <= (1, λv. v + 1)
1360    /// ```
1361    /// The encoding in fixpoint will be
1362    ///
1363    /// ```text
1364    /// 0 <= 1 && (le (λv. v + 1) (λv. v + 1))
1365    /// ```
1366    /// Where `<=` is the (interpreted) less than or equal relation between integers and `le` is
1367    /// an uninterpreted relation between ([the encoding] of) lambdas.
1368    ///
1369    /// [the encoding]: Self::define_const_for_lambda
1370    fn bin_rel_to_fixpoint(
1371        &mut self,
1372        sort: &rty::Sort,
1373        rel: fixpoint::BinRel,
1374        e1: &rty::Expr,
1375        e2: &rty::Expr,
1376        scx: &mut SortEncodingCtxt,
1377    ) -> QueryResult<fixpoint::Expr> {
1378        let e = match sort {
1379            rty::Sort::Int | rty::Sort::Real | rty::Sort::Char => {
1380                fixpoint::Expr::Atom(
1381                    rel,
1382                    Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
1383                )
1384            }
1385            rty::Sort::BitVec(_) => {
1386                let e1 = self.expr_to_fixpoint(e1, scx)?;
1387                let e2 = self.expr_to_fixpoint(e2, scx)?;
1388                let rel = self.bv_rel_to_fixpoint(&rel);
1389                fixpoint::Expr::App(Box::new(rel), vec![e1, e2])
1390            }
1391            rty::Sort::Tuple(sorts) => {
1392                let arity = sorts.len();
1393                self.apply_bin_rel_rec(sorts, rel, e1, e2, scx, |field| {
1394                    rty::FieldProj::Tuple { arity, field }
1395                })?
1396            }
1397            rty::Sort::App(rty::SortCtor::Adt(sort_def), args)
1398                if let Some(variant) = sort_def.opt_struct_variant() =>
1399            {
1400                let def_id = sort_def.did();
1401                let sorts = variant.field_sorts(args);
1402                self.apply_bin_rel_rec(&sorts, rel, e1, e2, scx, |field| {
1403                    rty::FieldProj::Adt { def_id, field }
1404                })?
1405            }
1406            _ => {
1407                let rel = fixpoint::Expr::Var(fixpoint::Var::UIFRel(rel));
1408                fixpoint::Expr::App(
1409                    Box::new(rel),
1410                    vec![self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?],
1411                )
1412            }
1413        };
1414        Ok(e)
1415    }
1416
1417    /// Apply binary relation recursively over aggregate expressions
1418    fn apply_bin_rel_rec(
1419        &mut self,
1420        sorts: &[rty::Sort],
1421        rel: fixpoint::BinRel,
1422        e1: &rty::Expr,
1423        e2: &rty::Expr,
1424        scx: &mut SortEncodingCtxt,
1425        mk_proj: impl Fn(u32) -> rty::FieldProj,
1426    ) -> QueryResult<fixpoint::Expr> {
1427        Ok(fixpoint::Expr::and(
1428            sorts
1429                .iter()
1430                .enumerate()
1431                .map(|(idx, s)| {
1432                    let proj = mk_proj(idx as u32);
1433                    let e1 = e1.proj_and_reduce(proj);
1434                    let e2 = e2.proj_and_reduce(proj);
1435                    self.bin_rel_to_fixpoint(s, rel, &e1, &e2, scx)
1436                })
1437                .try_collect()?,
1438        ))
1439    }
1440
1441    fn imm(
1442        &mut self,
1443        arg: &rty::Expr,
1444        sort: &rty::Sort,
1445        scx: &mut SortEncodingCtxt,
1446        bindings: &mut Vec<fixpoint::Bind>,
1447    ) -> QueryResult<fixpoint::Var> {
1448        let arg = self.expr_to_fixpoint(arg, scx)?;
1449        // Check if it's a variable after encoding, in case the encoding produced a variable from a
1450        // non-variable.
1451        if let fixpoint::Expr::Var(var) = arg {
1452            Ok(var)
1453        } else {
1454            let fresh = self.local_var_env.fresh_name();
1455            let pred = fixpoint::Expr::eq(fixpoint::Expr::Var(fresh.into()), arg);
1456            bindings.push(fixpoint::Bind {
1457                name: fresh.into(),
1458                sort: scx.sort_to_fixpoint(sort),
1459                pred: fixpoint::Pred::Expr(pred),
1460            });
1461            Ok(fresh.into())
1462        }
1463    }
1464
1465    /// Declare that the `def_id` of a Flux function definition needs to be encoded and assigns
1466    /// a name to it if it hasn't yet been declared. The encoding of the function body happens
1467    /// in [`Self::define_funs`].
1468    fn declare_fun(&mut self, def_id: FluxDefId) -> fixpoint::Var {
1469        *self.fun_def_map.entry(def_id).or_insert_with(|| {
1470            let id = self.global_var_gen.fresh();
1471            fixpoint::Var::Global(id, Some(def_id.name()))
1472        })
1473    }
1474
1475    /// The logic below is a bit "duplicated" with the `[`prim_op_sort`] in sortck.rs;
1476    /// They are not exactly the same because this is on rty and the other one on fhir.
1477    /// We should make sure these two remain in sync.
1478    ///
1479    /// (NOTE:PrimOpSort) We are somewhat "overloading" the BinOps: as we are using them
1480    /// for (a) interpreted operations on bit vectors AND (b) uninterpreted functions on integers.
1481    /// So when Binop::BitShr (a) appears in a ExprKind::BinOp, it means bit vectors, but
1482    /// (b) inside ExprKind::InternalFunc it means int.
1483    fn prim_op_sort(op: &rty::BinOp, span: Span) -> rty::PolyFuncSort {
1484        match op {
1485            rty::BinOp::BitAnd
1486            | rty::BinOp::BitOr
1487            | rty::BinOp::BitXor
1488            | rty::BinOp::BitShl
1489            | rty::BinOp::BitShr => {
1490                let fsort =
1491                    rty::FuncSort::new(vec![rty::Sort::Int, rty::Sort::Int], rty::Sort::Int);
1492                rty::PolyFuncSort::new(List::empty(), fsort)
1493            }
1494            _ => span_bug!(span, "unexpected prim op: {op:?} in `prim_op_sort`"),
1495        }
1496    }
1497
1498    fn define_const_for_cast(
1499        &mut self,
1500        from: &rty::Sort,
1501        to: &rty::Sort,
1502        scx: &mut SortEncodingCtxt,
1503    ) -> fixpoint::Var {
1504        let key = ConstKey::Cast(from.clone(), to.clone());
1505        self.const_map
1506            .entry(key)
1507            .or_insert_with(|| {
1508                let fsort = rty::FuncSort::new(vec![from.clone()], to.clone());
1509                let fsort = rty::PolyFuncSort::new(List::empty(), fsort);
1510                let sort = scx.func_sort_to_fixpoint(&fsort);
1511                fixpoint::ConstDecl {
1512                    name: fixpoint::Var::Global(self.global_var_gen.fresh(), None),
1513                    sort,
1514                    comment: Some(format!("cast uif: ({from:?}) -> {to:?}")),
1515                }
1516            })
1517            .name
1518    }
1519
1520    fn define_const_for_prim_op(
1521        &mut self,
1522        op: &rty::BinOp,
1523        scx: &mut SortEncodingCtxt,
1524    ) -> fixpoint::Var {
1525        let key = ConstKey::PrimOp(op.clone());
1526        let span = self.def_span();
1527        self.const_map
1528            .entry(key)
1529            .or_insert_with(|| {
1530                let sort = scx.func_sort_to_fixpoint(&Self::prim_op_sort(op, span));
1531                fixpoint::ConstDecl {
1532                    name: fixpoint::Var::Global(self.global_var_gen.fresh(), None),
1533                    sort,
1534                    comment: Some(format!("prim op uif: {op:?}")),
1535                }
1536            })
1537            .name
1538    }
1539
1540    fn define_const_for_uif(
1541        &mut self,
1542        def_id: FluxDefId,
1543        scx: &mut SortEncodingCtxt,
1544    ) -> fixpoint::Var {
1545        let key = ConstKey::Uif(def_id);
1546        self.const_map
1547            .entry(key)
1548            .or_insert_with(|| {
1549                let sort = scx.func_sort_to_fixpoint(&self.genv.func_sort(def_id));
1550                fixpoint::ConstDecl {
1551                    name: fixpoint::Var::Global(self.global_var_gen.fresh(), Some(def_id.name())),
1552                    sort,
1553                    comment: Some(format!("uif: {def_id:?}")),
1554                }
1555            })
1556            .name
1557    }
1558
1559    fn define_const_for_rust_const(
1560        &mut self,
1561        def_id: DefId,
1562        scx: &mut SortEncodingCtxt,
1563    ) -> fixpoint::Var {
1564        let key = ConstKey::RustConst(def_id);
1565        self.const_map
1566            .entry(key)
1567            .or_insert_with(|| {
1568                let sort = self.genv.sort_of_def_id(def_id).unwrap().unwrap();
1569                fixpoint::ConstDecl {
1570                    name: fixpoint::Var::Global(self.global_var_gen.fresh(), None),
1571                    sort: scx.sort_to_fixpoint(&sort),
1572                    comment: Some(format!("rust const: {}", def_id_to_string(def_id))),
1573                }
1574            })
1575            .name
1576    }
1577
1578    /// returns the 'constant' UIF for Var used to represent the alias_pred, creating and adding it
1579    /// to the const_map if necessary
1580    fn define_const_for_alias_reft(
1581        &mut self,
1582        alias_reft: &rty::AliasReft,
1583        fsort: rty::FuncSort,
1584        scx: &mut SortEncodingCtxt,
1585    ) -> fixpoint::Var {
1586        let tcx = self.genv.tcx();
1587        let args = alias_reft
1588            .args
1589            .to_rustc(tcx)
1590            .truncate_to(tcx, tcx.generics_of(alias_reft.assoc_id.parent()));
1591        let key = ConstKey::Alias(alias_reft.assoc_id, args);
1592        self.const_map
1593            .entry(key)
1594            .or_insert_with(|| {
1595                let comment = Some(format!("alias reft: {alias_reft:?}"));
1596                let name = fixpoint::Var::Global(self.global_var_gen.fresh(), None);
1597                let fsort = rty::PolyFuncSort::new(List::empty(), fsort);
1598                let sort = scx.func_sort_to_fixpoint(&fsort);
1599                fixpoint::ConstDecl { name, comment, sort }
1600            })
1601            .name
1602    }
1603
1604    /// We encode lambdas with uninterpreted constant. Two syntactically equal lambdas will be encoded
1605    /// with the same constant.
1606    fn define_const_for_lambda(
1607        &mut self,
1608        lam: &rty::Lambda,
1609        scx: &mut SortEncodingCtxt,
1610    ) -> fixpoint::Var {
1611        let key = ConstKey::Lambda(lam.clone());
1612        self.const_map
1613            .entry(key)
1614            .or_insert_with(|| {
1615                let comment = Some(format!("lambda: {lam:?}"));
1616                let name = fixpoint::Var::Global(self.global_var_gen.fresh(), None);
1617                let sort = scx.func_sort_to_fixpoint(&lam.fsort().to_poly());
1618                fixpoint::ConstDecl { name, comment, sort }
1619            })
1620            .name
1621    }
1622
1623    fn assume_const_values(
1624        &mut self,
1625        mut constraint: fixpoint::Constraint,
1626        scx: &mut SortEncodingCtxt,
1627    ) -> QueryResult<fixpoint::Constraint> {
1628        // Encoding the value for a constant could in theory define more constants for which
1629        // we need to assume values, so we iterate until there are no more constants.
1630        let mut idx = 0;
1631        while let Some((key, const_)) = self.const_map.get_index(idx) {
1632            idx += 1;
1633
1634            let ConstKey::RustConst(def_id) = key else { continue };
1635            let info = self.genv.constant_info(def_id)?;
1636            match info {
1637                rty::ConstantInfo::Uninterpreted => {}
1638                rty::ConstantInfo::Interpreted(val, _) => {
1639                    let e1 = fixpoint::Expr::Var(const_.name);
1640                    let e2 = self.expr_to_fixpoint(&val, scx)?;
1641                    let pred = fixpoint::Pred::Expr(e1.eq(e2));
1642                    constraint = fixpoint::Constraint::ForAll(
1643                        fixpoint::Bind {
1644                            name: fixpoint::Var::Underscore,
1645                            sort: fixpoint::Sort::Int,
1646                            pred,
1647                        },
1648                        Box::new(constraint),
1649                    );
1650                }
1651            }
1652        }
1653        Ok(constraint)
1654    }
1655
1656    fn qualifiers_for(
1657        &mut self,
1658        def_id: LocalDefId,
1659        scx: &mut SortEncodingCtxt,
1660    ) -> QueryResult<Vec<fixpoint::Qualifier>> {
1661        self.genv
1662            .qualifiers_for(def_id)?
1663            .map(|qual| self.qualifier_to_fixpoint(qual, scx))
1664            .try_collect()
1665    }
1666
1667    fn define_funs(
1668        &mut self,
1669        def_id: MaybeExternId,
1670        scx: &mut SortEncodingCtxt,
1671    ) -> QueryResult<(Vec<fixpoint::FunDef>, Vec<fixpoint::ConstDecl>)> {
1672        let reveals: UnordSet<FluxDefId> = self.genv.reveals_for(def_id.local_id())?.collect();
1673        let mut consts = vec![];
1674        let mut defs = vec![];
1675
1676        // We iterate until encoding the body of functions doesn't require any more functions
1677        // to be encoded.
1678        let mut idx = 0;
1679        while let Some((&did, &name)) = self.fun_def_map.get_index(idx) {
1680            idx += 1;
1681
1682            let comment = format!("flux def: {did:?}");
1683            let info = self.genv.normalized_info(did);
1684            let revealed = reveals.contains(&did);
1685            if info.hide && !revealed {
1686                let sort = scx.func_sort_to_fixpoint(&self.genv.func_sort(did));
1687                consts.push(fixpoint::ConstDecl { name, sort, comment: Some(comment) });
1688            } else {
1689                let out = scx.sort_to_fixpoint(self.genv.func_sort(did).expect_mono().output());
1690                let (args, body) = self.body_to_fixpoint(&info.body, scx)?;
1691                let fun_def = fixpoint::FunDef { name, args, body, out, comment: Some(comment) };
1692                defs.push((info.rank, fun_def));
1693            };
1694        }
1695
1696        // we sort by rank so the definitions go out without any forward dependencies.
1697        let defs = defs
1698            .into_iter()
1699            .sorted_by_key(|(rank, _)| *rank)
1700            .map(|(_, def)| def)
1701            .collect();
1702
1703        Ok((defs, consts))
1704    }
1705
1706    fn body_to_fixpoint(
1707        &mut self,
1708        body: &rty::Binder<rty::Expr>,
1709        scx: &mut SortEncodingCtxt,
1710    ) -> QueryResult<(Vec<(fixpoint::Var, fixpoint::Sort)>, fixpoint::Expr)> {
1711        self.local_var_env
1712            .push_layer_with_fresh_names(body.vars().len());
1713
1714        let expr = self.expr_to_fixpoint(body.as_ref().skip_binder(), scx)?;
1715
1716        let args: Vec<(fixpoint::Var, fixpoint::Sort)> =
1717            iter::zip(self.local_var_env.pop_layer(), body.vars())
1718                .map(|(name, var)| (name.into(), scx.sort_to_fixpoint(var.expect_sort())))
1719                .collect();
1720
1721        Ok((args, expr))
1722    }
1723
1724    fn qualifier_to_fixpoint(
1725        &mut self,
1726        qualifier: &rty::Qualifier,
1727        scx: &mut SortEncodingCtxt,
1728    ) -> QueryResult<fixpoint::Qualifier> {
1729        let (args, body) = self.body_to_fixpoint(&qualifier.body, scx)?;
1730        let name = qualifier.def_id.name().to_string();
1731        Ok(fixpoint::Qualifier { name, args, body })
1732    }
1733}
1734
1735fn mk_implies(assumption: fixpoint::Pred, cstr: fixpoint::Constraint) -> fixpoint::Constraint {
1736    fixpoint::Constraint::ForAll(
1737        fixpoint::Bind {
1738            name: fixpoint::Var::Underscore,
1739            sort: fixpoint::Sort::Int,
1740            pred: assumption,
1741        },
1742        Box::new(cstr),
1743    )
1744}