liquid_fixpoint/
constraint.rs

1use std::{collections::HashSet, hash::Hash};
2
3use derive_where::derive_where;
4use indexmap::IndexSet;
5
6use crate::{ThyFunc, Types};
7
8#[derive_where(Hash, Clone, Debug)]
9pub struct Bind<T: Types> {
10    pub name: T::Var,
11    pub sort: Sort<T>,
12    pub pred: Pred<T>,
13}
14
15#[derive_where(Hash, Clone, Debug)]
16pub enum Constraint<T: Types> {
17    Pred(Pred<T>, #[derive_where(skip)] Option<T::Tag>),
18    Conj(Vec<Self>),
19    ForAll(Bind<T>, Box<Self>),
20}
21
22impl<T: Types> Constraint<T> {
23    pub const TRUE: Self = Self::Pred(Pred::TRUE, None);
24
25    pub fn foralls(bindings: Vec<Bind<T>>, c: Self) -> Self {
26        bindings
27            .into_iter()
28            .rev()
29            .fold(c, |c, bind| Constraint::ForAll(bind, Box::new(c)))
30    }
31
32    pub fn conj(mut cstrs: Vec<Self>) -> Self {
33        if cstrs.len() == 1 { cstrs.remove(0) } else { Self::Conj(cstrs) }
34    }
35
36    /// Returns true if the constraint has at least one concrete RHS ("head") predicates.
37    /// If `!c.is_concrete` then `c` is trivially satisfiable and we can avoid calling fixpoint.
38    /// Returns the number of concrete, non-trivial head predicates in the constraint.
39    pub fn concrete_head_count(&self) -> usize {
40        fn go<T: Types>(c: &Constraint<T>, count: &mut usize) {
41            match c {
42                Constraint::Conj(cs) => cs.iter().for_each(|c| go(c, count)),
43                Constraint::ForAll(_, c) => go(c, count),
44                Constraint::Pred(p, _) => {
45                    if p.is_concrete() && !p.is_trivially_true() {
46                        *count += 1;
47                    }
48                }
49            }
50        }
51        let mut count = 0;
52        go(self, &mut count);
53        count
54    }
55}
56
57#[derive_where(Hash, Clone, Debug)]
58pub struct DataDecl<T: Types> {
59    pub name: T::Sort,
60    pub vars: usize,
61    pub ctors: Vec<DataCtor<T>>,
62}
63
64impl<T: Types> DataDecl<T> {
65    pub fn deps(&self, acc: &mut Vec<T::Sort>) {
66        for ctor in &self.ctors {
67            for field in &ctor.fields {
68                field.sort.deps(acc);
69            }
70        }
71    }
72}
73
74#[derive_where(Hash, Clone, Debug)]
75pub struct SortDecl<T: Types> {
76    pub name: T::Sort,
77    pub vars: usize,
78}
79
80#[derive_where(Hash, Clone, Debug)]
81pub struct DataCtor<T: Types> {
82    pub name: T::Var,
83    pub fields: Vec<DataField<T>>,
84}
85
86#[derive_where(Hash, Clone, Debug)]
87pub struct DataField<T: Types> {
88    pub name: T::Var,
89    pub sort: Sort<T>,
90}
91
92#[derive_where(Hash, Clone, Debug)]
93pub enum Sort<T: Types> {
94    Int,
95    Bool,
96    Real,
97    Str,
98    BitVec(Box<Sort<T>>),
99    BvSize(u32),
100    Var(usize),
101    Func(Box<[Self; 2]>),
102    Abs(usize, Box<Self>),
103    App(SortCtor<T>, Vec<Self>),
104}
105
106impl<T: Types> Sort<T> {
107    pub fn deps(&self, acc: &mut Vec<T::Sort>) {
108        match self {
109            Sort::App(SortCtor::Data(dt_name), args) => {
110                acc.push(dt_name.clone());
111                for arg in args {
112                    arg.deps(acc);
113                }
114            }
115            Sort::Func(input_and_output) => {
116                let [input, output] = &**input_and_output;
117                input.deps(acc);
118                output.deps(acc);
119            }
120            Sort::Abs(_, sort) => {
121                sort.deps(acc);
122            }
123            _ => {}
124        }
125    }
126
127    pub fn mk_func<I>(params: usize, inputs: I, output: Sort<T>) -> Sort<T>
128    where
129        I: IntoIterator<Item = Sort<T>>,
130        I::IntoIter: DoubleEndedIterator,
131    {
132        let sort = inputs
133            .into_iter()
134            .rev()
135            .fold(output, |output, input| Sort::Func(Box::new([input, output])));
136
137        (0..params)
138            .rev()
139            .fold(sort, |sort, i| Sort::Abs(i, Box::new(sort)))
140    }
141
142    pub(crate) fn peel_out_abs(&self) -> (usize, &Sort<T>) {
143        let mut n = 0;
144        let mut curr = self;
145        while let Sort::Abs(i, sort) = curr {
146            assert_eq!(*i, n);
147            n += 1;
148            curr = sort;
149        }
150        (n, curr)
151    }
152
153    fn free_var_sorts_to_int_help(&mut self, bound: &mut HashSet<usize>) {
154        match self {
155            Sort::Int
156            | Sort::Real
157            | Sort::Bool
158            | Sort::Str
159            | Sort::BvSize(..)
160            | Sort::BitVec(..) => {}
161            Sort::Abs(var, inner) => {
162                bound.insert(*var);
163                inner.free_var_sorts_to_int_help(bound);
164                bound.remove(var);
165            }
166            Sort::App(_, args) => {
167                for arg in args {
168                    arg.free_var_sorts_to_int_help(bound);
169                }
170            }
171            Sort::Func(inner) => {
172                let [arg, out] = &mut **inner;
173                arg.free_var_sorts_to_int_help(bound);
174                out.free_var_sorts_to_int_help(bound);
175            }
176            Sort::Var(v) => {
177                if !bound.contains(v) {
178                    *self = Sort::Int;
179                }
180            }
181        }
182    }
183
184    pub(crate) fn free_var_sorts_to_int(&mut self) {
185        let mut bound = HashSet::new();
186        self.free_var_sorts_to_int_help(&mut bound);
187    }
188}
189
190#[derive_where(Hash, Debug)]
191pub struct FunSort<T: Types> {
192    pub params: usize,
193    pub inputs: Vec<Sort<T>>,
194    pub output: Sort<T>,
195}
196
197impl<T: Types> FunSort<T> {
198    pub fn deps(&self, acc: &mut Vec<T::Sort>) {
199        for sort in &self.inputs {
200            sort.deps(acc);
201        }
202        self.output.deps(acc);
203    }
204
205    pub fn into_sort(self) -> Sort<T> {
206        Sort::mk_func(self.params, self.inputs, self.output)
207    }
208}
209
210#[derive_where(Hash, Clone, Debug)]
211pub enum SortCtor<T: Types> {
212    Set,
213    Map,
214    Data(T::Sort),
215}
216
217#[derive_where(Hash, Clone, Debug)]
218pub enum Pred<T: Types> {
219    And(Vec<Self>),
220    KVar(T::KVar, Vec<Expr<T>>),
221    Expr(Expr<T>),
222}
223
224impl<T: Types> Pred<T> {
225    pub const TRUE: Self = Pred::Expr(Expr::Constant(Constant::Boolean(true)));
226
227    pub fn and(mut preds: Vec<Self>) -> Self {
228        if preds.is_empty() {
229            Pred::TRUE
230        } else if preds.len() == 1 {
231            preds.remove(0)
232        } else {
233            Self::And(preds)
234        }
235    }
236
237    pub fn is_trivially_true(&self) -> bool {
238        match self {
239            Pred::Expr(Expr::Constant(Constant::Boolean(true))) => true,
240            Pred::And(ps) => ps.is_empty(),
241            _ => false,
242        }
243    }
244
245    pub fn is_concrete(&self) -> bool {
246        match self {
247            Pred::And(ps) => ps.iter().any(Pred::is_concrete),
248            Pred::KVar(_, _) => false,
249            Pred::Expr(_) => true,
250        }
251    }
252
253    #[cfg(feature = "rust-fixpoint")]
254    pub(crate) fn simplify(&mut self) {
255        if let Pred::And(conjuncts) = self {
256            if conjuncts.is_empty() {
257                *self = Pred::TRUE;
258            } else if conjuncts.len() == 1 {
259                *self = conjuncts[0].clone();
260            } else {
261                conjuncts.iter_mut().for_each(|pred| pred.simplify());
262            }
263        }
264    }
265}
266
/// Binary relations between two expressions (see `Expr::Atom`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinRel {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
}

impl BinRel {
    /// The ordering relations (everything except `Eq` and `Ne`), listed as
    /// `Gt`, `Ge`, `Lt`, `Le`.
    pub const INEQUALITIES: [Self; 4] = [Self::Gt, Self::Ge, Self::Lt, Self::Le];
}
280
/// A reference to a bound variable, identified by a binder `level` and an `idx` within
/// that binder (presumably de Bruijn-style — confirm against the users of this type).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct BoundVar {
    /// Which binder the variable refers to.
    pub level: usize,
    /// Position of the variable within its binder.
    pub idx: usize,
}

impl BoundVar {
    /// Creates a `BoundVar` from its `level` and `idx`.
    pub fn new(level: usize, idx: usize) -> Self {
        BoundVar { idx, level }
    }
}
292
293#[derive_where(Hash, Clone, Debug)]
294pub enum Expr<T: Types> {
295    Constant(Constant<T>),
296    Var(T::Var),
297    App(Box<Self>, Option<Vec<Sort<T>>>, Vec<Self>, Option<Sort<T>>),
298    Neg(Box<Self>),
299    BinaryOp(BinOp, Box<[Self; 2]>),
300    IfThenElse(Box<[Self; 3]>),
301    And(Vec<Self>),
302    Or(Vec<Self>),
303    Not(Box<Self>),
304    Imp(Box<[Self; 2]>),
305    Iff(Box<[Self; 2]>),
306    Atom(BinRel, Box<[Self; 2]>),
307    Let(T::Var, Box<[Self; 2]>),
308    ThyFunc(ThyFunc),
309    IsCtor(T::Var, Box<Self>),
310    Quantifier(Quantifier, Vec<(T::Var, Sort<T>)>, Box<Self>),
311}
312
313impl<T: Types> From<Constant<T>> for Expr<T> {
314    fn from(v: Constant<T>) -> Self {
315        Self::Constant(v)
316    }
317}
318
319impl<T: Types> Expr<T> {
320    pub const fn int(val: u128) -> Expr<T> {
321        Expr::Constant(Constant::Numeral(val))
322    }
323
324    pub fn eq(self, other: Self) -> Self {
325        Expr::Atom(BinRel::Eq, Box::new([self, other]))
326    }
327
328    pub fn and(mut exprs: Vec<Self>) -> Self {
329        if exprs.len() == 1 { exprs.remove(0) } else { Self::And(exprs) }
330    }
331
332    pub fn var_sorts_to_int(&mut self) {
333        match self {
334            Expr::Constant(_) | Expr::ThyFunc(_) | Expr::Var(_) => {}
335            Expr::App(func, sort_args, args, out_sort) => {
336                func.var_sorts_to_int();
337                for arg in args {
338                    arg.var_sorts_to_int();
339                }
340                if let Some(sort_args) = sort_args {
341                    for sort_arg in sort_args {
342                        sort_arg.free_var_sorts_to_int();
343                    }
344                }
345                if let Some(out_sort) = out_sort {
346                    out_sort.free_var_sorts_to_int();
347                }
348            }
349            Expr::Neg(e) | Expr::Not(e) => {
350                e.var_sorts_to_int();
351            }
352            Expr::BinaryOp(_, exprs)
353            | Expr::Imp(exprs)
354            | Expr::Iff(exprs)
355            | Expr::Atom(_, exprs) => {
356                let [e1, e2] = &mut **exprs;
357                e1.var_sorts_to_int();
358                e2.var_sorts_to_int();
359            }
360            Expr::IfThenElse(exprs) => {
361                let [p, e1, e2] = &mut **exprs;
362                p.var_sorts_to_int();
363                e1.var_sorts_to_int();
364                e2.var_sorts_to_int();
365            }
366            Expr::And(exprs) | Expr::Or(exprs) => {
367                for e in exprs {
368                    e.var_sorts_to_int();
369                }
370            }
371            Expr::Let(_, exprs) => {
372                let [var_e, body_e] = &mut **exprs;
373                var_e.var_sorts_to_int();
374                body_e.var_sorts_to_int();
375            }
376            Expr::IsCtor(_v, expr) => {
377                expr.var_sorts_to_int();
378            }
379            Expr::Quantifier(_, binder, expr) => {
380                for (_, sort) in binder {
381                    sort.free_var_sorts_to_int();
382                }
383                expr.var_sorts_to_int();
384            }
385        }
386    }
387
388    pub fn free_vars(&self) -> IndexSet<T::Var> {
389        let mut vars = IndexSet::new();
390        match self {
391            Expr::Constant(_) | Expr::ThyFunc(_) => {}
392            Expr::Var(x) => {
393                vars.insert(x.clone());
394            }
395            Expr::App(func, _sort_args, args, _out_sort) => {
396                vars.extend(func.free_vars());
397                for arg in args {
398                    vars.extend(arg.free_vars());
399                }
400            }
401            Expr::Neg(e) | Expr::Not(e) => {
402                vars = e.free_vars();
403            }
404            Expr::BinaryOp(_, exprs)
405            | Expr::Imp(exprs)
406            | Expr::Iff(exprs)
407            | Expr::Atom(_, exprs) => {
408                let [e1, e2] = &**exprs;
409                vars.extend(e1.free_vars());
410                vars.extend(e2.free_vars());
411            }
412            Expr::IfThenElse(exprs) => {
413                let [p, e1, e2] = &**exprs;
414                vars.extend(p.free_vars());
415                vars.extend(e1.free_vars());
416                vars.extend(e2.free_vars());
417            }
418            Expr::And(exprs) | Expr::Or(exprs) => {
419                for e in exprs {
420                    vars.extend(e.free_vars());
421                }
422            }
423            Expr::Let(name, exprs) => {
424                // Fixpoint only support one binder per let expressions, but it parses a singleton
425                // list of binders to be forward-compatible
426                let [var_e, body_e] = &**exprs;
427                vars.extend(var_e.free_vars());
428                let mut body_vars = body_e.free_vars();
429                body_vars.swap_remove(name);
430                vars.extend(body_vars);
431            }
432            Expr::IsCtor(_v, expr) => {
433                // NOTE: (ck) I'm pretty sure this isn't a binder so I'm not going to
434                // bother with `v`.
435                vars.extend(expr.free_vars());
436            }
437            Expr::Quantifier(_, binder, expr) => {
438                let mut inner = expr.free_vars();
439                for (var, _sort) in binder {
440                    inner.swap_remove(var);
441                }
442                vars.extend(inner);
443            }
444        };
445        vars
446    }
447}
448
449#[derive_where(Hash, Clone, Debug)]
450pub enum Constant<T: Types> {
451    Numeral(u128),
452    Real(T::Real),
453    Boolean(bool),
454    String(T::String),
455    BitVec(u128, u32),
456}
457
458#[derive_where(Debug, Clone, Hash)]
459pub struct Qualifier<T: Types> {
460    pub name: String,
461    pub args: Vec<(T::Var, Sort<T>)>,
462    pub body: Expr<T>,
463}
464
/// Binary operators used in `Expr::BinaryOp`.
///
/// `Debug` is not derived here. `Expr`'s derived `Debug` needs it, so it is presumably
/// implemented by hand elsewhere in the crate.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
}

/// Kind of quantifier in `Expr::Quantifier`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Quantifier {
    Exists,
    Forall,
}