flux_infer/
refine_tree.rs

1use std::{
2    cell::RefCell,
3    ops::ControlFlow,
4    rc::{Rc, Weak},
5};
6
7use flux_common::{index::IndexVec, iter::IterExt, tracked_span_bug};
8use flux_config::OverflowMode;
9use flux_macros::DebugAsJson;
10use flux_middle::{
11    global_env::GlobalEnv,
12    pretty::{PrettyCx, PrettyNested, format_cx},
13    queries::QueryResult,
14    rty::{
15        BaseTy, EVid, Expr, ExprKind, KVid, Name, NameProvenance, PrettyVar, Sort, Ty, TyKind, Var,
16        fold::{TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor},
17    },
18};
19use itertools::Itertools;
20use rustc_data_structures::snapshot_map::SnapshotMap;
21use rustc_hash::{FxHashMap, FxHashSet};
22use rustc_index::newtype_index;
23use rustc_middle::ty::TyCtxt;
24use serde::Serialize;
25
26use crate::{
27    evars::EVarStore,
28    fixpoint_encoding::{FixpointCtxt, fixpoint},
29    infer::{Tag, TypeTrace},
30};
31
32/// A *refine*ment *tree* tracks the "tree-like structure" of refinement variables and predicates
33/// generated during refinement type-checking. This tree can be encoded as a fixpoint constraint
34/// whose satisfiability implies the safety of a function.
35///
36/// We try to hide the representation of the tree as much as possible and only a couple of operations
37/// can be used to manipulate the structure of the tree explicitly. Instead, the tree is mostly
38/// constructed implicitly via a restricted api provided by [`Cursor`]. Some methods operate on *nodes*
39/// of the tree which we try to keep abstract, but it is important to remember that there's an
40/// underlying tree.
41///
42/// The current implementation uses [`Rc`] and [`RefCell`] to represent the tree, but we ensure
43/// statically that the [`RefineTree`] is the single owner of the data and require a mutable reference
44/// to it for all mutations, i.e., we could in theory replace the [`RefCell`] with an [`UnsafeCell`]
45/// (or a [`GhostCell`]).
46///
47/// [`UnsafeCell`]: std::cell::UnsafeCell
48/// [`GhostCell`]: https://docs.rs/ghost-cell/0.2.3/ghost_cell/ghost_cell/struct.GhostCell.html
49pub struct RefineTree {
50    root: NodePtr,
51}
52
53impl RefineTree {
54    pub(crate) fn new(params: Vec<(Var, Sort)>) -> RefineTree {
55        let root =
56            Node { kind: NodeKind::Root(params), nbindings: 0, parent: None, children: vec![] };
57        let root = NodePtr(Rc::new(RefCell::new(root)));
58        RefineTree { root }
59    }
60
61    pub(crate) fn simplify(&mut self, genv: GlobalEnv) {
62        self.root
63            .borrow_mut()
64            .simplify(SimplifyPhase::Full(genv), &mut SnapshotMap::default());
65        self.root.borrow_mut().simplify_bot();
66        self.root.borrow_mut().simplify_top();
67    }
68
69    pub(crate) fn into_fixpoint(
70        self,
71        cx: &mut FixpointCtxt<Tag>,
72    ) -> QueryResult<fixpoint::Constraint> {
73        Ok(self
74            .root
75            .borrow()
76            .to_fixpoint(cx)?
77            .unwrap_or(fixpoint::Constraint::TRUE))
78    }
79
80    pub(crate) fn cursor_at_root(&mut self) -> Cursor<'_> {
81        Cursor { ptr: NodePtr(Rc::clone(&self.root)), tree: self }
82    }
83
84    pub(crate) fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
85        self.root.borrow_mut().replace_evars(evars)
86    }
87}
88
89/// A cursor into the [refinement tree]. More specifically, a [`Cursor`] represents a path from the
90/// root to some internal node in a [refinement tree].
91///
92/// [refinement tree]: RefineTree
93pub struct Cursor<'a> {
94    tree: &'a mut RefineTree,
95    ptr: NodePtr,
96}
97
98impl Cursor<'_> {
99    /// Moves the cursor to the specified [marker]. If `clear_children` is `true`, all children of
100    /// the node are removed after moving the cursor, invalidating any markers pointing to a node
101    /// within those children.
102    ///
103    /// [marker]: Marker
104    pub(crate) fn move_to(&mut self, marker: &Marker, clear_children: bool) -> Option<Cursor<'_>> {
105        let ptr = marker.ptr.upgrade()?;
106        if clear_children {
107            ptr.borrow_mut().children.clear();
108        }
109        Some(Cursor { ptr, tree: self.tree })
110    }
111
112    /// Returns a marker to the current node
113    #[must_use]
114    pub(crate) fn marker(&self) -> Marker {
115        Marker { ptr: NodePtr::downgrade(&self.ptr) }
116    }
117
118    #[must_use]
119    pub(crate) fn branch(&mut self) -> Cursor<'_> {
120        Cursor { tree: self.tree, ptr: NodePtr::clone(&self.ptr) }
121    }
122
123    pub(crate) fn vars(&self) -> impl Iterator<Item = (Var, Sort)> {
124        // TODO(nilehmann) we could incrementally cache the scope
125        self.ptr.scope().into_iter()
126    }
127
128    #[expect(dead_code, reason = "used for debugging")]
129    pub(crate) fn push_trace(&mut self, trace: TypeTrace) {
130        self.ptr = self.ptr.push_node(NodeKind::Trace(trace));
131    }
132
133    /// Defines a fresh refinement variable with the given `sort` and advance the cursor to the new
134    /// node. It returns the freshly generated name for the variable.
135    pub(crate) fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
136        let fresh = Name::from_usize(self.ptr.next_name_idx());
137        self.ptr = self
138            .ptr
139            .push_node(NodeKind::ForAll(fresh, sort.clone(), provenance));
140        fresh
141    }
142
143    /// Pushes an [assumption] and moves the cursor into the new node.
144    ///
145    /// [assumption]: NodeKind::Assumption
146    pub(crate) fn assume_pred(&mut self, pred: impl Into<Expr>) {
147        let pred = pred.into();
148        if !pred.is_trivially_true() {
149            self.ptr = self.ptr.push_node(NodeKind::Assumption(pred));
150        }
151    }
152
153    /// Pushes a predicate that must be true assuming variables and predicates in the current branch
154    /// of the tree (i.e., it pushes a [`NodeKind::Head`]). This methods does not advance the cursor.
155    pub(crate) fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
156        let pred = pred.into();
157        if !pred.is_trivially_true() {
158            self.ptr.push_node(NodeKind::Head(pred, tag));
159        }
160    }
161
162    /// Convenience method to push an assumption followed by a predicate that needs to be checked.
163    /// This method does not advance the cursor.
164    pub(crate) fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
165        self.ptr
166            .push_node(NodeKind::Assumption(pred1.into()))
167            .push_node(NodeKind::Head(pred2.into(), tag));
168    }
169
170    pub(crate) fn assume_invariants(
171        &mut self,
172        tcx: TyCtxt,
173        ty: &Ty,
174        overflow_checking: OverflowMode,
175    ) {
176        struct Visitor<'a, 'b, 'tcx> {
177            tcx: TyCtxt<'tcx>,
178            cursor: &'a mut Cursor<'b>,
179            overflow_mode: OverflowMode,
180        }
181        impl TypeVisitor for Visitor<'_, '_, '_> {
182            fn visit_bty(&mut self, bty: &BaseTy) -> ControlFlow<!> {
183                match bty {
184                    BaseTy::Adt(adt_def, substs) if adt_def.is_box() => substs.visit_with(self),
185                    BaseTy::Ref(_, ty, _) => ty.visit_with(self),
186                    BaseTy::Tuple(tys) => tys.visit_with(self),
187                    _ => ControlFlow::Continue(()),
188                }
189            }
190
191            fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<!> {
192                if let TyKind::Indexed(bty, idx) = ty.kind()
193                    && !idx.has_escaping_bvars()
194                {
195                    for invariant in bty.invariants(self.tcx, self.overflow_mode) {
196                        let invariant = invariant.apply(idx);
197                        self.cursor.assume_pred(&invariant);
198                    }
199                }
200                ty.super_visit_with(self)
201            }
202        }
203        let _ = ty.visit_with(&mut Visitor { tcx, cursor: self, overflow_mode: overflow_checking });
204    }
205}
206
207/// A marker is a pointer to a node in the [refinement tree] that can be used to query information
208/// about that node or to move the cursor. A marker may become invalid if the underlying node is
209/// [cleared].
210///
211/// [cleared]: Cursor::move_to
212/// [refinement tree]: RefineTree
213pub struct Marker {
214    ptr: WeakNodePtr,
215}
216
217impl Marker {
218    /// Returns the [`scope`] at the marker if it is still valid or [`None`] otherwise.
219    ///
220    /// [`scope`]: Scope
221    pub fn scope(&self) -> Option<Scope> {
222        Some(self.ptr.upgrade()?.scope())
223    }
224
225    pub fn has_free_vars<T: TypeVisitable>(&self, t: &T) -> bool {
226        let ptr = self
227            .ptr
228            .upgrade()
229            .unwrap_or_else(|| tracked_span_bug!("`has_free_vars` called on invalid `Marker`"));
230
231        let nbindings = ptr.next_name_idx();
232
233        !t.fvars().into_iter().all(|name| name.index() < nbindings)
234    }
235}
236
237/// A list of refinement variables and their sorts.
238#[derive(PartialEq, Eq)]
239pub struct Scope {
240    params: Vec<(Var, Sort)>,
241    bindings: IndexVec<Name, Sort>,
242}
243
244impl Scope {
245    pub(crate) fn iter(&self) -> impl Iterator<Item = (Var, Sort)> + '_ {
246        itertools::chain(
247            self.params.iter().cloned(),
248            self.bindings
249                .iter_enumerated()
250                .map(|(name, sort)| (Var::Free(name), sort.clone())),
251        )
252    }
253
254    fn into_iter(self) -> impl Iterator<Item = (Var, Sort)> {
255        itertools::chain(
256            self.params,
257            self.bindings
258                .into_iter_enumerated()
259                .map(|(name, sort)| (Var::Free(name), sort.clone())),
260        )
261    }
262
263    /// Whether `t` has any free variables not in this scope
264    pub fn has_free_vars<T: TypeFoldable>(&self, t: &T) -> bool {
265        !self.contains_all(t.fvars())
266    }
267
268    fn contains_all(&self, iter: impl IntoIterator<Item = Name>) -> bool {
269        iter.into_iter().all(|name| self.contains(name))
270    }
271
272    fn contains(&self, name: Name) -> bool {
273        name.index() < self.bindings.len()
274    }
275}
276
277struct Node {
278    kind: NodeKind,
279    /// Number of bindings between the root and this node's parent, i.e., we have
280    /// as an invariant that `nbindings` equals the number of [`NodeKind::ForAll`]
281    /// nodes from the parent of this node to the root.
282    nbindings: usize,
283    parent: Option<WeakNodePtr>,
284    children: Vec<NodePtr>,
285}
286
287#[derive(Clone)]
288struct NodePtr(Rc<RefCell<Node>>);
289
290impl NodePtr {
291    fn downgrade(this: &Self) -> WeakNodePtr {
292        WeakNodePtr(Rc::downgrade(&this.0))
293    }
294
295    fn push_node(&mut self, kind: NodeKind) -> NodePtr {
296        debug_assert!(!matches!(self.borrow().kind, NodeKind::Head(..)));
297        let node = Node {
298            kind,
299            nbindings: self.next_name_idx(),
300            parent: Some(NodePtr::downgrade(self)),
301            children: vec![],
302        };
303        let node = NodePtr(Rc::new(RefCell::new(node)));
304        self.borrow_mut().children.push(NodePtr::clone(&node));
305        node
306    }
307
308    fn next_name_idx(&self) -> usize {
309        self.borrow().nbindings + usize::from(self.borrow().is_forall())
310    }
311
312    fn scope(&self) -> Scope {
313        let mut params = None;
314        let parents = ParentsIter::new(self.clone());
315        let bindings = parents
316            .filter_map(|node| {
317                let node = node.borrow();
318                match &node.kind {
319                    NodeKind::Root(p) => {
320                        params = Some(p.clone());
321                        None
322                    }
323                    NodeKind::ForAll(_, sort, _) => Some(sort.clone()),
324                    _ => None,
325                }
326            })
327            .collect_vec()
328            .into_iter()
329            .rev()
330            .collect();
331        Scope { bindings, params: params.unwrap_or_default() }
332    }
333}
334
335struct WeakNodePtr(Weak<RefCell<Node>>);
336
337impl WeakNodePtr {
338    fn upgrade(&self) -> Option<NodePtr> {
339        Some(NodePtr(self.0.upgrade()?))
340    }
341}
342
343enum NodeKind {
344    /// List of const and refinement generics
345    Root(Vec<(Var, Sort)>),
346    ForAll(Name, Sort, NameProvenance),
347    Assumption(Expr),
348    Head(Expr, Tag),
349    True,
350    /// Used for debugging. See [`TypeTrace`]
351    Trace(TypeTrace),
352}
353
354impl std::ops::Index<Name> for Scope {
355    type Output = Sort;
356
357    fn index(&self, name: Name) -> &Self::Output {
358        &self.bindings[name]
359    }
360}
361
362impl std::ops::Deref for NodePtr {
363    type Target = Rc<RefCell<Node>>;
364
365    fn deref(&self) -> &Self::Target {
366        &self.0
367    }
368}
369
370#[derive(Clone, Copy)]
371enum SimplifyPhase<'genv, 'tcx> {
372    /// Normalize and simplify inner `Expr`
373    Full(GlobalEnv<'genv, 'tcx>),
374    /// Only propagate `true` (TOP) and `false` (BOT)
375    Partial,
376}
377
378impl Node {
379    fn simplify(&mut self, phase: SimplifyPhase, assumed_preds: &mut SnapshotMap<Expr, ()>) {
380        // First, simplify the node itself
381        match &mut self.kind {
382            NodeKind::Head(pred, tag) => {
383                let pred = match phase {
384                    SimplifyPhase::Full(genv) => pred.normalize(genv).simplify(assumed_preds),
385                    SimplifyPhase::Partial => pred.clone(),
386                };
387                if pred.is_trivially_true() {
388                    self.kind = NodeKind::True;
389                } else {
390                    self.kind = NodeKind::Head(pred, *tag);
391                }
392            }
393            NodeKind::Assumption(pred) => {
394                if let SimplifyPhase::Full(genv) = phase {
395                    *pred = pred.normalize(genv).simplify(assumed_preds);
396                }
397                pred.visit_conj(|conjunct| {
398                    assumed_preds.insert(conjunct.erase_spans(), ());
399                });
400            }
401            _ => {}
402        }
403        let is_false_asm =
404            matches!(&self.kind, NodeKind::Assumption(pred) if pred.is_trivially_false());
405
406        // Then simplify the children
407        // (the order matters here because we need to collect assumed preds first)
408        for child in &self.children {
409            let current_version = assumed_preds.snapshot();
410            child.borrow_mut().simplify(phase, assumed_preds);
411            assumed_preds.rollback_to(current_version);
412        }
413
414        // Then remove any unnecessary children
415        match &mut self.kind {
416            NodeKind::Head(..) | NodeKind::True => {}
417            NodeKind::Assumption(_)
418            | NodeKind::Trace(_)
419            | NodeKind::Root(_)
420            | NodeKind::ForAll(..) => {
421                self.children
422                    .extract_if(.., |child| {
423                        is_false_asm || matches!(&child.borrow().kind, NodeKind::True)
424                    })
425                    .for_each(drop);
426            }
427        }
428        if !self.is_leaf() && self.children.is_empty() {
429            self.kind = NodeKind::True;
430        }
431    }
432
433    fn is_leaf(&self) -> bool {
434        matches!(self.kind, NodeKind::Head(..) | NodeKind::True)
435    }
436
437    fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
438        for child in &self.children {
439            child.borrow_mut().replace_evars(evars)?;
440        }
441        match &mut self.kind {
442            NodeKind::Assumption(pred) => *pred = evars.replace_evars(pred)?,
443            NodeKind::Head(pred, _) => {
444                *pred = evars.replace_evars(pred)?;
445            }
446            NodeKind::Trace(trace) => {
447                evars.replace_evars(trace)?;
448            }
449            NodeKind::Root(_) | NodeKind::ForAll(..) | NodeKind::True => {}
450        }
451        Ok(())
452    }
453
454    fn to_fixpoint(&self, cx: &mut FixpointCtxt<Tag>) -> QueryResult<Option<fixpoint::Constraint>> {
455        let cstr = match &self.kind {
456            NodeKind::Trace(_) | NodeKind::ForAll(_, Sort::Loc, _) => {
457                children_to_fixpoint(cx, &self.children)?
458            }
459
460            NodeKind::Root(params) => {
461                // declare pretty-vars for params
462                for (var, sort) in params {
463                    if let Var::EarlyParam(param) = var
464                        && !sort.is_loc()
465                    {
466                        cx.with_early_param(param);
467                    }
468                }
469
470                let Some(children) = children_to_fixpoint(cx, &self.children)? else {
471                    return Ok(None);
472                };
473                let mut constr = children;
474                for (var, sort) in params.iter().rev() {
475                    if sort.is_loc() {
476                        continue;
477                    }
478                    constr = fixpoint::Constraint::ForAll(
479                        fixpoint::Bind {
480                            name: cx.var_to_fixpoint(var),
481                            sort: cx.sort_to_fixpoint(sort),
482                            pred: fixpoint::Pred::TRUE,
483                        },
484                        Box::new(constr),
485                    );
486                }
487                Some(constr)
488            }
489            NodeKind::ForAll(name, sort, provenance) => {
490                cx.with_name_map(*name, *provenance, |cx, fresh| -> QueryResult<_> {
491                    let Some(children) = children_to_fixpoint(cx, &self.children)? else {
492                        return Ok(None);
493                    };
494                    Ok(Some(fixpoint::Constraint::ForAll(
495                        fixpoint::Bind {
496                            name: fixpoint::Var::Local(fresh),
497                            sort: cx.sort_to_fixpoint(sort),
498                            pred: fixpoint::Pred::TRUE,
499                        },
500                        Box::new(children),
501                    )))
502                })?
503            }
504            NodeKind::Assumption(pred) => {
505                let (mut bindings, pred) = cx.assumption_to_fixpoint(pred)?;
506                let Some(cstr) = children_to_fixpoint(cx, &self.children)? else {
507                    return Ok(None);
508                };
509                bindings.push(fixpoint::Bind {
510                    name: fixpoint::Var::Underscore,
511                    sort: fixpoint::Sort::Int,
512                    pred,
513                });
514                Some(fixpoint::Constraint::foralls(bindings, cstr))
515            }
516            NodeKind::Head(pred, tag) => {
517                Some(cx.head_to_fixpoint(pred, |span| tag.with_dst(span))?)
518            }
519            NodeKind::True => None,
520        };
521        Ok(cstr)
522    }
523
524    /// Returns `true` if the node kind is [`ForAll`].
525    ///
526    /// [`ForAll`]: NodeKind::ForAll
527    fn is_forall(&self) -> bool {
528        matches!(self.kind, NodeKind::ForAll(..))
529    }
530
531    /// Returns `true` if the node kind is [`Head`].
532    ///
533    /// [`Head`]: NodeKind::Head
534    fn is_head(&self) -> bool {
535        matches!(self.kind, NodeKind::Head(..))
536    }
537}
538
539fn children_to_fixpoint(
540    cx: &mut FixpointCtxt<Tag>,
541    children: &[NodePtr],
542) -> QueryResult<Option<fixpoint::Constraint>> {
543    let mut children = children
544        .iter()
545        .filter_map(|node| node.borrow().to_fixpoint(cx).transpose())
546        .try_collect_vec()?;
547    let cstr = match children.len() {
548        0 => None,
549        1 => children.pop(),
550        _ => Some(fixpoint::Constraint::Conj(children)),
551    };
552    Ok(cstr)
553}
554
555struct ParentsIter {
556    ptr: Option<NodePtr>,
557}
558
559impl ParentsIter {
560    fn new(ptr: NodePtr) -> Self {
561        Self { ptr: Some(ptr) }
562    }
563}
564
565impl Iterator for ParentsIter {
566    type Item = NodePtr;
567
568    fn next(&mut self) -> Option<Self::Item> {
569        if let Some(ptr) = self.ptr.take() {
570            self.ptr = ptr.borrow().parent.as_ref().and_then(WeakNodePtr::upgrade);
571            Some(ptr)
572        } else {
573            None
574        }
575    }
576}
577
578mod pretty {
579    use std::fmt::{self, Write};
580
581    use flux_middle::pretty::*;
582    use pad_adapter::PadAdapter;
583
584    use super::*;
585
586    fn bindings_chain(ptr: &NodePtr) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
587        fn go(ptr: &NodePtr, mut bindings: Vec<(Name, Sort)>) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
588            let node = ptr.borrow();
589            if let NodeKind::ForAll(name, sort, _) = &node.kind {
590                bindings.push((*name, sort.clone()));
591                if let [child] = &node.children[..] {
592                    go(child, bindings)
593                } else {
594                    (bindings, node.children.clone())
595                }
596            } else {
597                (bindings, vec![NodePtr::clone(ptr)])
598            }
599        }
600        go(ptr, vec![])
601    }
602
603    fn preds_chain(ptr: &NodePtr) -> (Vec<Expr>, Vec<NodePtr>) {
604        fn go(ptr: &NodePtr, mut preds: Vec<Expr>) -> (Vec<Expr>, Vec<NodePtr>) {
605            let node = ptr.borrow();
606            if let NodeKind::Assumption(pred) = &node.kind {
607                preds.push(pred.clone());
608                if let [child] = &node.children[..] {
609                    go(child, preds)
610                } else {
611                    (preds, node.children.clone())
612                }
613            } else {
614                (preds, vec![NodePtr::clone(ptr)])
615            }
616        }
617        go(ptr, vec![])
618    }
619
620    impl Pretty for RefineTree {
621        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
622            w!(cx, f, "{:?}", &self.root)
623        }
624    }
625
626    impl Pretty for NodePtr {
627        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
628            let node = self.borrow();
629            match &node.kind {
630                NodeKind::Trace(trace) => {
631                    w!(cx, f, "@ {:?}", ^trace)?;
632                    w!(cx, with_padding(f), "\n{:?}", join!("\n", &node.children))
633                }
634                NodeKind::Root(bindings) => {
635                    w!(cx, f,
636                        "∀ {}.",
637                        ^bindings
638                            .iter()
639                            .format_with(", ", |(name, sort), f| {
640                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
641                            })
642                    )?;
643                    fmt_children(&node.children, cx, f)
644                }
645                NodeKind::ForAll(name, sort, _) => {
646                    let (bindings, children) = if cx.bindings_chain {
647                        bindings_chain(self)
648                    } else {
649                        (vec![(*name, sort.clone())], node.children.clone())
650                    };
651
652                    w!(cx, f,
653                        "∀ {}.",
654                        ^bindings
655                            .into_iter()
656                            .format_with(", ", |(name, sort), f| {
657                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
658                            })
659                    )?;
660                    fmt_children(&children, cx, f)
661                }
662                NodeKind::Assumption(pred) => {
663                    let (preds, children) = if cx.preds_chain {
664                        preds_chain(self)
665                    } else {
666                        (vec![pred.clone()], node.children.clone())
667                    };
668                    let guard = Expr::and_from_iter(preds).simplify(&SnapshotMap::default());
669                    w!(cx, f, "{:?} =>", parens!(guard, !guard.is_atom()))?;
670                    fmt_children(&children, cx, f)
671                }
672                NodeKind::Head(pred, tag) => {
673                    let pred = if cx.simplify_exprs {
674                        pred.simplify(&SnapshotMap::default())
675                    } else {
676                        pred.clone()
677                    };
678                    w!(cx, f, "{:?}", parens!(pred, !pred.is_atom()))?;
679                    if cx.tags {
680                        w!(cx, f, " ~ {:?}", tag)?;
681                    }
682                    Ok(())
683                }
684                NodeKind::True => {
685                    w!(cx, f, "true")
686                }
687            }
688        }
689    }
690
691    fn fmt_children(
692        children: &[NodePtr],
693        cx: &PrettyCx,
694        f: &mut fmt::Formatter<'_>,
695    ) -> fmt::Result {
696        match children {
697            [] => w!(cx, f, " true"),
698            [n] => {
699                if n.borrow().is_head() {
700                    w!(cx, f, " {:?}", n)
701                } else {
702                    w!(cx, with_padding(f), "\n{:?}", n)
703                }
704            }
705            _ => w!(cx, with_padding(f), "\n{:?}", join!("\n", children)),
706        }
707    }
708
709    impl Pretty for Cursor<'_> {
710        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
711            let mut elements = vec![];
712            for node in ParentsIter::new(NodePtr::clone(&self.ptr)) {
713                let n = node.borrow();
714                match &n.kind {
715                    NodeKind::Root(bindings) => {
716                        // We reverse here because is reversed again at the end
717                        for (name, sort) in bindings.iter().rev() {
718                            elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
719                        }
720                    }
721                    NodeKind::ForAll(name, sort, _) => {
722                        elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
723                    }
724                    NodeKind::Assumption(pred) => {
725                        elements.push(format_cx!(cx, "{:?}", pred));
726                    }
727                    _ => {}
728                }
729            }
730            write!(f, "{{{}}}", elements.into_iter().rev().format(", "))
731        }
732    }
733
734    impl Pretty for Scope {
735        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
736            write!(
737                f,
738                "[bindings = {}, reftgenerics = {}]",
739                self.bindings
740                    .iter_enumerated()
741                    .format_with(", ", |(name, sort), f| {
742                        f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
743                    }),
744                self.params
745                    .iter()
746                    .format_with(", ", |(param_const, sort), f| {
747                        f(&format_args_cx!(cx, "{:?}: {:?}", ^param_const, sort))
748                    }),
749            )
750        }
751    }
752
753    fn with_padding<'a, 'b>(f: &'a mut fmt::Formatter<'b>) -> PadAdapter<'a, 'b, 'static> {
754        PadAdapter::with_padding(f, "  ")
755    }
756
757    impl_debug_with_default_cx!(
758        RefineTree => "refine_tree",
759        Cursor<'_> => "cursor",
760        Scope,
761    );
762}
763
764/// An explicit representation of a path in the [`RefineTree`] used for debugging/tracing/serialization ONLY.
765#[derive(Serialize, DebugAsJson)]
766pub struct RefineCtxtTrace {
767    bindings: Vec<RcxBind>,
768    exprs: Vec<String>,
769}
770
771#[derive(Serialize)]
772struct RcxBind {
773    name: String,
774    sort: String,
775}
776
777impl RefineCtxtTrace {
778    pub fn new(cx: &mut PrettyCx, cursor: &Cursor) -> Self {
779        let parents = ParentsIter::new(NodePtr::clone(&cursor.ptr)).collect_vec();
780        let mut bindings = vec![];
781        let mut exprs = vec![];
782
783        parents.into_iter().rev().for_each(|ptr| {
784            let node = ptr.borrow();
785            match &node.kind {
786                NodeKind::ForAll(name, sort, provenance) => {
787                    let name = cx
788                        .pretty_var_env
789                        .set(PrettyVar::Local(*name), provenance.opt_symbol());
790                    let sort = format_cx!(cx, "{:?}", sort);
791                    let bind = RcxBind { name, sort };
792                    bindings.push(bind);
793                }
794                NodeKind::Assumption(e)
795                    if !e.simplify(&SnapshotMap::default()).is_trivially_true() =>
796                {
797                    e.visit_conj(|e| {
798                        exprs.push(e.nested_string(cx));
799                    });
800                }
801                NodeKind::Root(binds) => {
802                    for (name, sort) in binds {
803                        let name = if let Var::EarlyParam(param) = name {
804                            cx.pretty_var_env
805                                .set(PrettyVar::Param(*param), Some(param.name))
806                        } else {
807                            format_cx!(cx, "{:?}", name)
808                        };
809                        let sort = format_cx!(cx, "{:?}", sort);
810                        let bind = RcxBind { name, sort };
811                        bindings.push(bind);
812                    }
813                }
814                _ => (),
815            }
816        });
817        Self { bindings, exprs }
818    }
819}
820
821impl Node {
822    /// replace bot-kvars with false
823    fn simplify_bot(&mut self) {
824        let graph = ConstraintDeps::new(self);
825        let bots = graph.bot_kvars();
826        self.simplify_with_assignment(&bots);
827        self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
828    }
829
830    /// replace top-kvars with true
831    fn simplify_top(&mut self) {
832        let graph = ConstraintDeps::new(self);
833        let tops = graph.top_kvars();
834        self.simplify_with_assignment(&tops);
835        self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
836    }
837
838    /// simplifies assumptions and heads using the TOP/BOT kvar assignment; follow
839    /// with a call to `simplify` to delete constraints with FALSE assm.
840    fn simplify_with_assignment(&mut self, assignment: &Assignment) {
841        match &mut self.kind {
842            NodeKind::Head(pred, tag) => {
843                let pred = assignment.simplify(pred);
844                self.kind = NodeKind::Head(pred, *tag);
845            }
846            NodeKind::Assumption(pred) => {
847                let pred = assignment.simplify(pred);
848                self.kind = NodeKind::Assumption(pred);
849            }
850            _ => {}
851        }
852        for child in &self.children {
853            child.borrow_mut().simplify_with_assignment(assignment);
854        }
855    }
856}
857
858#[derive(Debug)]
859struct ConstraintDeps {
860    /// assumptions for each clause
861    assumptions: IndexVec<ClauseId, FxHashSet<KVid>>,
862    /// head of each clause
863    heads: IndexVec<ClauseId, Head>,
864}
865
866impl ConstraintDeps {
867    fn new(node: &Node) -> Self {
868        let mut graph = Self { assumptions: IndexVec::default(), heads: IndexVec::default() };
869        graph.build(node, &mut vec![]);
870        graph
871    }
872
873    fn insert_clause(&mut self, assumptions: &[KVid], head: Head) {
874        self.assumptions.push(assumptions.iter().copied().collect());
875        self.heads.push(head);
876    }
877
878    fn build(&mut self, node: &Node, assumptions: &mut Vec<KVid>) {
879        let n = assumptions.len();
880        match &node.kind {
881            NodeKind::Head(expr, _) => {
882                expr.visit_conj(|e| {
883                    if let ExprKind::KVar(kvar) = e.kind() {
884                        self.insert_clause(assumptions, Head::KVar(kvar.kvid));
885                    } else {
886                        self.insert_clause(assumptions, Head::Conc);
887                    }
888                });
889            }
890            NodeKind::Assumption(expr) => {
891                expr.visit_conj(|e| {
892                    if let ExprKind::KVar(kvar) = e.kind() {
893                        assumptions.push(kvar.kvid);
894                    }
895                });
896            }
897            _ => {}
898        };
899
900        for child in &node.children {
901            self.build(&child.borrow(), assumptions);
902        }
903
904        assumptions.truncate(n); // restore ctx
905    }
906
907    /// set of edges where kvid appears as ASSM
908    fn kv_lhs(&self) -> FxHashMap<KVid, Vec<ClauseId>> {
909        let mut res: FxHashMap<KVid, Vec<ClauseId>> = FxHashMap::default();
910        for (clause_id, kvids) in self.assumptions.iter_enumerated() {
911            for kvid in kvids {
912                res.entry(*kvid).or_default().push(clause_id);
913            }
914        }
915        res
916    }
917
918    /// set of edges where kvid appears as HEAD
919    fn kv_rhs(&self) -> FxHashMap<KVid, Vec<ClauseId>> {
920        let mut res: FxHashMap<KVid, Vec<ClauseId>> = FxHashMap::default();
921        for (clause_id, head) in self.heads.iter_enumerated() {
922            if let Head::KVar(kvid) = head {
923                res.entry(*kvid).or_default().push(clause_id);
924            }
925        }
926        res
927    }
928
929    /// Computes the set of all kvars that can be assigned to Bot (False),
930    /// because they are not (transitively) reachable from any concrete ASSUMPTION.
931    fn bot_kvars(self) -> Assignment {
932        // set of BOT kvars (initially, all)
933        let mut assignment = Assignment::new(Label::Bot);
934
935        let kv_lhs = self.kv_lhs();
936
937        // set of BOT kvars in LHS of each constraint with KVar HEAD
938        let mut bot_assms: IndexVec<ClauseId, FxHashSet<KVid>> = self.assumptions;
939
940        // set of constraints `cid` whose bot-assms is empty
941        let mut candidates: Vec<ClauseId> = bot_assms
942            .iter_enumerated()
943            .filter_map(|(cid, lhs)| if lhs.is_empty() { Some(cid) } else { None })
944            .collect();
945
946        // while there is a candidate constraint, that has NO BOT kvars in lhs
947        while let Some(cid) = candidates.pop() {
948            if let Head::KVar(kvid) = self.heads[cid] {
949                // un-BOT the head kvar
950                assignment.remove(kvid);
951                // remove the head kvar from all (bot) assumptions where it currently occurs
952                for cid in kv_lhs.get(&kvid).unwrap_or(&vec![]) {
953                    // if cid HEAD is a kvar
954                    if let Head::KVar(rhs_kvid) = self.heads[*cid] {
955                        let assms = &mut bot_assms[*cid];
956                        assms.remove(&kvid);
957                        if assignment.has_label(rhs_kvid) && assms.is_empty() {
958                            candidates.push(*cid);
959                        }
960                    };
961                }
962            }
963        }
964
965        assignment
966    }
967
968    /// Computes the set of all kvars that can be assigned to Top (True),
969    /// because they do not (transitively) reach any concrete HEAD.
970    fn top_kvars(self) -> Assignment {
971        // initialize
972        let mut assignment = Assignment::new(Label::Top);
973
974        let kv_rhs = self.kv_rhs();
975
976        // set of kvar {k | cid in graph.edges, c.rhs is concrete, k in c.lhs }
977        let mut candidates = vec![];
978        for (cid, head) in self.heads.iter_enumerated() {
979            if matches!(head, Head::Conc) {
980                for kvid in &self.assumptions[cid] {
981                    candidates.push(*kvid);
982                }
983            }
984        }
985
986        // set each kvar that transitively reaches a concrete HEAD to NON-BOT
987        while let Some(kvid) = candidates.pop() {
988            // set that kvar to non-top
989            assignment.remove(kvid);
990
991            // for each constraint where kvid appears as head
992            for cid in kv_rhs.get(&kvid).unwrap_or(&vec![]) {
993                // add kvars in lhs to candidates (if they have not already been solved to non-BOT)
994                for lhs_kvid in &self.assumptions[*cid] {
995                    if assignment.has_label(*lhs_kvid) {
996                        candidates.push(*lhs_kvid);
997                    }
998                }
999            }
1000        }
1001
1002        assignment
1003    }
1004}
1005
// Index of a clause in `ConstraintDeps`: one clause per head conjunct of the refinement tree,
// used to index both its assumption set and its head.
newtype_index! {
    struct ClauseId {}
}
1009
1010#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1011enum Head {
1012    /// KVar
1013    KVar(KVid),
1014    /// A *conc*rete predicate, i.e., an [`Expr`] that's not a kvar. We don't need to know
1015    /// the exact expression, only that it's concrete.
1016    Conc,
1017}
1018
/// The solution that an `Assignment` implicitly gives to every kvar it has not excluded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Label {
    /// The kvar may be solved to `false`.
    Bot,
    /// The kvar may be solved to `true`.
    Top,
}
1026
1027struct Assignment {
1028    /// These vars are NOT assigned `label`,
1029    /// all other `KVid` implicitly have assignment `label`
1030    vars: FxHashSet<KVid>,
1031    label: Label,
1032}
1033
1034impl Assignment {
1035    fn new(label: Label) -> Self {
1036        let vars = FxHashSet::default();
1037        Self { vars, label }
1038    }
1039
1040    fn has_label(&self, kvid: KVid) -> bool {
1041        !self.vars.contains(&kvid)
1042    }
1043
1044    fn remove(&mut self, kvid: KVid) {
1045        self.vars.insert(kvid);
1046    }
1047
1048    /// simplifies the given predicate expression by replacing
1049    /// kvid assigned to TOP with True, BOT with false.
1050    fn simplify(&self, pred: &Expr) -> Expr {
1051        let mut preds = vec![];
1052        for p in pred.flatten_conjs() {
1053            if let ExprKind::KVar(kvar) = p.kind()
1054                && self.has_label(kvar.kvid)
1055            {
1056                if self.label == Label::Bot {
1057                    return Expr::ff();
1058                } // else, skip pushing `p` into `preds`
1059            } else {
1060                preds.push(p.clone());
1061            }
1062        }
1063        Expr::and_from_iter(preds)
1064    }
1065}