flux_infer/
refine_tree.rs

1use std::{
2    cell::RefCell,
3    ops::ControlFlow,
4    rc::{Rc, Weak},
5};
6
7use flux_common::{index::IndexVec, iter::IterExt, tracked_span_bug};
8use flux_config::OverflowMode;
9use flux_macros::DebugAsJson;
10use flux_middle::{
11    global_env::GlobalEnv,
12    pretty::{PrettyCx, PrettyNested, format_cx},
13    queries::QueryResult,
14    rty::{
15        BaseTy, EVid, Expr, ExprKind, KVid, Name, Sort, Ty, TyKind, Var,
16        fold::{TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor},
17    },
18};
19use itertools::Itertools;
20use rustc_data_structures::snapshot_map::SnapshotMap;
21use rustc_hash::{FxHashMap, FxHashSet};
22use rustc_index::newtype_index;
23use rustc_middle::ty::TyCtxt;
24use serde::Serialize;
25
26use crate::{
27    evars::EVarStore,
28    fixpoint_encoding::{FixpointCtxt, fixpoint},
29    infer::{Tag, TypeTrace},
30};
31
32/// A *refine*ment *tree* tracks the "tree-like structure" of refinement variables and predicates
33/// generated during refinement type-checking. This tree can be encoded as a fixpoint constraint
34/// whose satisfiability implies the safety of a function.
35///
36/// We try to hide the representation of the tree as much as possible and only a couple of operations
37/// can be used to manipulate the structure of the tree explicitly. Instead, the tree is mostly
38/// constructed implicitly via a restricted api provided by [`Cursor`]. Some methods operate on *nodes*
39/// of the tree which we try to keep abstract, but it is important to remember that there's an
40/// underlying tree.
41///
42/// The current implementation uses [`Rc`] and [`RefCell`] to represent the tree, but we ensure
43/// statically that the [`RefineTree`] is the single owner of the data and require a mutable reference
44/// to it for all mutations, i.e., we could in theory replace the [`RefCell`] with an [`UnsafeCell`]
45/// (or a [`GhostCell`]).
46///
47/// [`UnsafeCell`]: std::cell::UnsafeCell
48/// [`GhostCell`]: https://docs.rs/ghost-cell/0.2.3/ghost_cell/ghost_cell/struct.GhostCell.html
49pub struct RefineTree {
50    root: NodePtr,
51}
52
53impl RefineTree {
54    pub(crate) fn new(params: Vec<(Var, Sort)>) -> RefineTree {
55        let root =
56            Node { kind: NodeKind::Root(params), nbindings: 0, parent: None, children: vec![] };
57        let root = NodePtr(Rc::new(RefCell::new(root)));
58        RefineTree { root }
59    }
60
61    pub(crate) fn simplify(&mut self, genv: GlobalEnv) {
62        self.root
63            .borrow_mut()
64            .simplify(SimplifyPhase::Full(genv), &mut SnapshotMap::default());
65        self.root.borrow_mut().simplify_bot();
66        self.root.borrow_mut().simplify_top();
67    }
68
69    pub(crate) fn into_fixpoint(
70        self,
71        cx: &mut FixpointCtxt<Tag>,
72    ) -> QueryResult<fixpoint::Constraint> {
73        Ok(self
74            .root
75            .borrow()
76            .to_fixpoint(cx)?
77            .unwrap_or(fixpoint::Constraint::TRUE))
78    }
79
80    pub(crate) fn cursor_at_root(&mut self) -> Cursor<'_> {
81        Cursor { ptr: NodePtr(Rc::clone(&self.root)), tree: self }
82    }
83
84    pub(crate) fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
85        self.root.borrow_mut().replace_evars(evars)
86    }
87}
88
89/// A cursor into the [refinement tree]. More specifically, a [`Cursor`] represents a path from the
90/// root to some internal node in a [refinement tree].
91///
92/// [refinement tree]: RefineTree
93pub struct Cursor<'a> {
94    tree: &'a mut RefineTree,
95    ptr: NodePtr,
96}
97
98impl Cursor<'_> {
99    /// Moves the cursor to the specified [marker]. If `clear_children` is `true`, all children of
100    /// the node are removed after moving the cursor, invalidating any markers pointing to a node
101    /// within those children.
102    ///
103    /// [marker]: Marker
104    pub(crate) fn move_to(&mut self, marker: &Marker, clear_children: bool) -> Option<Cursor<'_>> {
105        let ptr = marker.ptr.upgrade()?;
106        if clear_children {
107            ptr.borrow_mut().children.clear();
108        }
109        Some(Cursor { ptr, tree: self.tree })
110    }
111
112    /// Returns a marker to the current node
113    #[must_use]
114    pub(crate) fn marker(&self) -> Marker {
115        Marker { ptr: NodePtr::downgrade(&self.ptr) }
116    }
117
118    #[must_use]
119    pub(crate) fn branch(&mut self) -> Cursor<'_> {
120        Cursor { tree: self.tree, ptr: NodePtr::clone(&self.ptr) }
121    }
122
123    pub(crate) fn vars(&self) -> impl Iterator<Item = (Var, Sort)> {
124        // TODO(nilehmann) we could incrementally cache the scope
125        self.ptr.scope().into_iter()
126    }
127
128    #[expect(dead_code, reason = "used for debugging")]
129    pub(crate) fn push_trace(&mut self, trace: TypeTrace) {
130        self.ptr = self.ptr.push_node(NodeKind::Trace(trace));
131    }
132
133    /// Defines a fresh refinement variable with the given `sort` and advance the cursor to the new
134    /// node. It returns the freshly generated name for the variable.
135    pub(crate) fn define_var(&mut self, sort: &Sort) -> Name {
136        let fresh = Name::from_usize(self.ptr.next_name_idx());
137        self.ptr = self.ptr.push_node(NodeKind::ForAll(fresh, sort.clone()));
138        fresh
139    }
140
141    /// Pushes an [assumption] and moves the cursor into the new node.
142    ///
143    /// [assumption]: NodeKind::Assumption
144    pub(crate) fn assume_pred(&mut self, pred: impl Into<Expr>) {
145        let pred = pred.into();
146        if !pred.is_trivially_true() {
147            self.ptr = self.ptr.push_node(NodeKind::Assumption(pred));
148        }
149    }
150
151    /// Pushes a predicate that must be true assuming variables and predicates in the current branch
152    /// of the tree (i.e., it pushes a [`NodeKind::Head`]). This methods does not advance the cursor.
153    pub(crate) fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
154        let pred = pred.into();
155        if !pred.is_trivially_true() {
156            self.ptr.push_node(NodeKind::Head(pred, tag));
157        }
158    }
159
160    /// Convenience method to push an assumption followed by a predicate that needs to be checked.
161    /// This method does not advance the cursor.
162    pub(crate) fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
163        self.ptr
164            .push_node(NodeKind::Assumption(pred1.into()))
165            .push_node(NodeKind::Head(pred2.into(), tag));
166    }
167
168    pub(crate) fn assume_invariants(
169        &mut self,
170        tcx: TyCtxt,
171        ty: &Ty,
172        overflow_checking: OverflowMode,
173    ) {
174        struct Visitor<'a, 'b, 'tcx> {
175            tcx: TyCtxt<'tcx>,
176            cursor: &'a mut Cursor<'b>,
177            overflow_mode: OverflowMode,
178        }
179        impl TypeVisitor for Visitor<'_, '_, '_> {
180            fn visit_bty(&mut self, bty: &BaseTy) -> ControlFlow<!> {
181                match bty {
182                    BaseTy::Adt(adt_def, substs) if adt_def.is_box() => substs.visit_with(self),
183                    BaseTy::Ref(_, ty, _) => ty.visit_with(self),
184                    BaseTy::Tuple(tys) => tys.visit_with(self),
185                    _ => ControlFlow::Continue(()),
186                }
187            }
188
189            fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<!> {
190                if let TyKind::Indexed(bty, idx) = ty.kind()
191                    && !idx.has_escaping_bvars()
192                {
193                    for invariant in bty.invariants(self.tcx, self.overflow_mode) {
194                        let invariant = invariant.apply(idx);
195                        self.cursor.assume_pred(&invariant);
196                    }
197                }
198                ty.super_visit_with(self)
199            }
200        }
201        let _ = ty.visit_with(&mut Visitor { tcx, cursor: self, overflow_mode: overflow_checking });
202    }
203}
204
205/// A marker is a pointer to a node in the [refinement tree] that can be used to query information
206/// about that node or to move the cursor. A marker may become invalid if the underlying node is
207/// [cleared].
208///
209/// [cleared]: Cursor::move_to
210/// [refinement tree]: RefineTree
211pub struct Marker {
212    ptr: WeakNodePtr,
213}
214
215impl Marker {
216    /// Returns the [`scope`] at the marker if it is still valid or [`None`] otherwise.
217    ///
218    /// [`scope`]: Scope
219    pub fn scope(&self) -> Option<Scope> {
220        Some(self.ptr.upgrade()?.scope())
221    }
222
223    pub fn has_free_vars<T: TypeVisitable>(&self, t: &T) -> bool {
224        let ptr = self
225            .ptr
226            .upgrade()
227            .unwrap_or_else(|| tracked_span_bug!("`has_free_vars` called on invalid `Marker`"));
228
229        let nbindings = ptr.next_name_idx();
230
231        !t.fvars().into_iter().all(|name| name.index() < nbindings)
232    }
233}
234
235/// A list of refinement variables and their sorts.
236#[derive(PartialEq, Eq)]
237pub struct Scope {
238    params: Vec<(Var, Sort)>,
239    bindings: IndexVec<Name, Sort>,
240}
241
242impl Scope {
243    pub(crate) fn iter(&self) -> impl Iterator<Item = (Var, Sort)> + '_ {
244        itertools::chain(
245            self.params.iter().cloned(),
246            self.bindings
247                .iter_enumerated()
248                .map(|(name, sort)| (Var::Free(name), sort.clone())),
249        )
250    }
251
252    fn into_iter(self) -> impl Iterator<Item = (Var, Sort)> {
253        itertools::chain(
254            self.params,
255            self.bindings
256                .into_iter_enumerated()
257                .map(|(name, sort)| (Var::Free(name), sort.clone())),
258        )
259    }
260
261    /// Whether `t` has any free variables not in this scope
262    pub fn has_free_vars<T: TypeFoldable>(&self, t: &T) -> bool {
263        !self.contains_all(t.fvars())
264    }
265
266    fn contains_all(&self, iter: impl IntoIterator<Item = Name>) -> bool {
267        iter.into_iter().all(|name| self.contains(name))
268    }
269
270    fn contains(&self, name: Name) -> bool {
271        name.index() < self.bindings.len()
272    }
273}
274
275struct Node {
276    kind: NodeKind,
277    /// Number of bindings between the root and this node's parent, i.e., we have
278    /// as an invariant that `nbindings` equals the number of [`NodeKind::ForAll`]
279    /// nodes from the parent of this node to the root.
280    nbindings: usize,
281    parent: Option<WeakNodePtr>,
282    children: Vec<NodePtr>,
283}
284
285#[derive(Clone)]
286struct NodePtr(Rc<RefCell<Node>>);
287
288impl NodePtr {
289    fn downgrade(this: &Self) -> WeakNodePtr {
290        WeakNodePtr(Rc::downgrade(&this.0))
291    }
292
293    fn push_node(&mut self, kind: NodeKind) -> NodePtr {
294        debug_assert!(!matches!(self.borrow().kind, NodeKind::Head(..)));
295        let node = Node {
296            kind,
297            nbindings: self.next_name_idx(),
298            parent: Some(NodePtr::downgrade(self)),
299            children: vec![],
300        };
301        let node = NodePtr(Rc::new(RefCell::new(node)));
302        self.borrow_mut().children.push(NodePtr::clone(&node));
303        node
304    }
305
306    fn next_name_idx(&self) -> usize {
307        self.borrow().nbindings + usize::from(self.borrow().is_forall())
308    }
309
310    fn scope(&self) -> Scope {
311        let mut params = None;
312        let parents = ParentsIter::new(self.clone());
313        let bindings = parents
314            .filter_map(|node| {
315                let node = node.borrow();
316                match &node.kind {
317                    NodeKind::Root(p) => {
318                        params = Some(p.clone());
319                        None
320                    }
321                    NodeKind::ForAll(_, sort) => Some(sort.clone()),
322                    _ => None,
323                }
324            })
325            .collect_vec()
326            .into_iter()
327            .rev()
328            .collect();
329        Scope { bindings, params: params.unwrap_or_default() }
330    }
331}
332
333struct WeakNodePtr(Weak<RefCell<Node>>);
334
335impl WeakNodePtr {
336    fn upgrade(&self) -> Option<NodePtr> {
337        Some(NodePtr(self.0.upgrade()?))
338    }
339}
340
341enum NodeKind {
342    /// List of const and refinement generics
343    Root(Vec<(Var, Sort)>),
344    ForAll(Name, Sort),
345    Assumption(Expr),
346    Head(Expr, Tag),
347    True,
348    /// Used for debugging. See [`TypeTrace`]
349    Trace(TypeTrace),
350}
351
352impl std::ops::Index<Name> for Scope {
353    type Output = Sort;
354
355    fn index(&self, name: Name) -> &Self::Output {
356        &self.bindings[name]
357    }
358}
359
360impl std::ops::Deref for NodePtr {
361    type Target = Rc<RefCell<Node>>;
362
363    fn deref(&self) -> &Self::Target {
364        &self.0
365    }
366}
367
368#[derive(Clone, Copy)]
369enum SimplifyPhase<'genv, 'tcx> {
370    /// Normalize and simplify inner `Expr`
371    Full(GlobalEnv<'genv, 'tcx>),
372    /// Only propagate `true` (TOP) and `false` (BOT)
373    Partial,
374}
375
376impl Node {
377    fn simplify(&mut self, phase: SimplifyPhase, assumed_preds: &mut SnapshotMap<Expr, ()>) {
378        // First, simplify the node itself
379        match &mut self.kind {
380            NodeKind::Head(pred, tag) => {
381                let pred = match phase {
382                    SimplifyPhase::Full(genv) => pred.normalize(genv).simplify(assumed_preds),
383                    SimplifyPhase::Partial => pred.clone(),
384                };
385                if pred.is_trivially_true() {
386                    self.kind = NodeKind::True;
387                } else {
388                    self.kind = NodeKind::Head(pred, *tag);
389                }
390            }
391            NodeKind::Assumption(pred) => {
392                if let SimplifyPhase::Full(genv) = phase {
393                    *pred = pred.normalize(genv).simplify(assumed_preds);
394                }
395                pred.visit_conj(|conjunct| {
396                    assumed_preds.insert(conjunct.erase_spans(), ());
397                });
398            }
399            _ => {}
400        }
401        let is_false_asm =
402            matches!(&self.kind, NodeKind::Assumption(pred) if pred.is_trivially_false());
403
404        // Then simplify the children
405        // (the order matters here because we need to collect assumed preds first)
406        for child in &self.children {
407            let current_version = assumed_preds.snapshot();
408            child.borrow_mut().simplify(phase, assumed_preds);
409            assumed_preds.rollback_to(current_version);
410        }
411
412        // Then remove any unnecessary children
413        match &mut self.kind {
414            NodeKind::Head(..) | NodeKind::True => {}
415            NodeKind::Assumption(_)
416            | NodeKind::Trace(_)
417            | NodeKind::Root(_)
418            | NodeKind::ForAll(..) => {
419                self.children
420                    .extract_if(.., |child| {
421                        is_false_asm || matches!(&child.borrow().kind, NodeKind::True)
422                    })
423                    .for_each(drop);
424            }
425        }
426        if !self.is_leaf() && self.children.is_empty() {
427            self.kind = NodeKind::True;
428        }
429    }
430
431    fn is_leaf(&self) -> bool {
432        matches!(self.kind, NodeKind::Head(..) | NodeKind::True)
433    }
434
435    fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
436        for child in &self.children {
437            child.borrow_mut().replace_evars(evars)?;
438        }
439        match &mut self.kind {
440            NodeKind::Assumption(pred) => *pred = evars.replace_evars(pred)?,
441            NodeKind::Head(pred, _) => {
442                *pred = evars.replace_evars(pred)?;
443            }
444            NodeKind::Trace(trace) => {
445                evars.replace_evars(trace)?;
446            }
447            NodeKind::Root(_) | NodeKind::ForAll(..) | NodeKind::True => {}
448        }
449        Ok(())
450    }
451
452    fn to_fixpoint(&self, cx: &mut FixpointCtxt<Tag>) -> QueryResult<Option<fixpoint::Constraint>> {
453        let cstr = match &self.kind {
454            NodeKind::Trace(_) | NodeKind::ForAll(_, Sort::Loc) => {
455                children_to_fixpoint(cx, &self.children)?
456            }
457
458            NodeKind::Root(params) => {
459                let Some(children) = children_to_fixpoint(cx, &self.children)? else {
460                    return Ok(None);
461                };
462                let mut constr = children;
463                for (var, sort) in params.iter().rev() {
464                    if sort.is_loc() {
465                        continue;
466                    }
467                    constr = fixpoint::Constraint::ForAll(
468                        fixpoint::Bind {
469                            name: cx.var_to_fixpoint(var),
470                            sort: cx.sort_to_fixpoint(sort),
471                            pred: fixpoint::Pred::TRUE,
472                        },
473                        Box::new(constr),
474                    );
475                }
476                Some(constr)
477            }
478            NodeKind::ForAll(name, sort) => {
479                cx.with_name_map(*name, |cx, fresh| -> QueryResult<_> {
480                    let Some(children) = children_to_fixpoint(cx, &self.children)? else {
481                        return Ok(None);
482                    };
483                    Ok(Some(fixpoint::Constraint::ForAll(
484                        fixpoint::Bind {
485                            name: fixpoint::Var::Local(fresh),
486                            sort: cx.sort_to_fixpoint(sort),
487                            pred: fixpoint::Pred::TRUE,
488                        },
489                        Box::new(children),
490                    )))
491                })?
492            }
493            NodeKind::Assumption(pred) => {
494                let (mut bindings, pred) = cx.assumption_to_fixpoint(pred)?;
495                let Some(cstr) = children_to_fixpoint(cx, &self.children)? else {
496                    return Ok(None);
497                };
498                bindings.push(fixpoint::Bind {
499                    name: fixpoint::Var::Underscore,
500                    sort: fixpoint::Sort::Int,
501                    pred,
502                });
503                Some(fixpoint::Constraint::foralls(bindings, cstr))
504            }
505            NodeKind::Head(pred, tag) => {
506                Some(cx.head_to_fixpoint(pred, |span| tag.with_dst(span))?)
507            }
508            NodeKind::True => None,
509        };
510        Ok(cstr)
511    }
512
513    /// Returns `true` if the node kind is [`ForAll`].
514    ///
515    /// [`ForAll`]: NodeKind::ForAll
516    fn is_forall(&self) -> bool {
517        matches!(self.kind, NodeKind::ForAll(..))
518    }
519
520    /// Returns `true` if the node kind is [`Head`].
521    ///
522    /// [`Head`]: NodeKind::Head
523    fn is_head(&self) -> bool {
524        matches!(self.kind, NodeKind::Head(..))
525    }
526}
527
528fn children_to_fixpoint(
529    cx: &mut FixpointCtxt<Tag>,
530    children: &[NodePtr],
531) -> QueryResult<Option<fixpoint::Constraint>> {
532    let mut children = children
533        .iter()
534        .filter_map(|node| node.borrow().to_fixpoint(cx).transpose())
535        .try_collect_vec()?;
536    let cstr = match children.len() {
537        0 => None,
538        1 => children.pop(),
539        _ => Some(fixpoint::Constraint::Conj(children)),
540    };
541    Ok(cstr)
542}
543
544struct ParentsIter {
545    ptr: Option<NodePtr>,
546}
547
548impl ParentsIter {
549    fn new(ptr: NodePtr) -> Self {
550        Self { ptr: Some(ptr) }
551    }
552}
553
554impl Iterator for ParentsIter {
555    type Item = NodePtr;
556
557    fn next(&mut self) -> Option<Self::Item> {
558        if let Some(ptr) = self.ptr.take() {
559            self.ptr = ptr.borrow().parent.as_ref().and_then(WeakNodePtr::upgrade);
560            Some(ptr)
561        } else {
562            None
563        }
564    }
565}
566
mod pretty {
    use std::fmt::{self, Write};

    use flux_middle::pretty::*;
    use pad_adapter::PadAdapter;

    use super::*;

    /// Collects the run of consecutive `ForAll` nodes starting at `ptr`, following nodes that have
    /// exactly one child. Returns the collected bindings and the nodes that follow the run. If
    /// `ptr` is not a `ForAll` node, the bindings are empty and `ptr` itself is the only node
    /// returned.
    fn bindings_chain(ptr: &NodePtr) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut bindings: Vec<(Name, Sort)>) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::ForAll(name, sort) = &node.kind {
                bindings.push((*name, sort.clone()));
                if let [child] = &node.children[..] {
                    go(child, bindings)
                } else {
                    (bindings, node.children.clone())
                }
            } else {
                (bindings, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    /// Same as `bindings_chain`, but for runs of consecutive `Assumption` nodes. Returns the
    /// assumed predicates and the nodes that follow the run.
    fn preds_chain(ptr: &NodePtr) -> (Vec<Expr>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut preds: Vec<Expr>) -> (Vec<Expr>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::Assumption(pred) = &node.kind {
                preds.push(pred.clone());
                if let [child] = &node.children[..] {
                    go(child, preds)
                } else {
                    (preds, node.children.clone())
                }
            } else {
                (preds, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    /// A tree prints as its root node.
    impl Pretty for RefineTree {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?}", &self.root)
        }
    }

    /// Prints a node and, recursively, its subtree. Depending on `cx.bindings_chain` /
    /// `cx.preds_chain`, runs of binders or assumptions are merged into a single line.
    impl Pretty for NodePtr {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let node = self.borrow();
            match &node.kind {
                NodeKind::Trace(trace) => {
                    w!(cx, f, "@ {:?}", ^trace)?;
                    w!(cx, with_padding(f), "\n{:?}", join!("\n", &node.children))
                }
                NodeKind::Root(bindings) => {
                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&node.children, cx, f)
                }
                NodeKind::ForAll(name, sort) => {
                    let (bindings, children) = if cx.bindings_chain {
                        bindings_chain(self)
                    } else {
                        (vec![(*name, sort.clone())], node.children.clone())
                    };

                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .into_iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Assumption(pred) => {
                    let (preds, children) = if cx.preds_chain {
                        preds_chain(self)
                    } else {
                        (vec![pred.clone()], node.children.clone())
                    };
                    // The (possibly merged) assumptions are printed as one simplified conjunction.
                    let guard = Expr::and_from_iter(preds).simplify(&SnapshotMap::default());
                    w!(cx, f, "{:?} =>", parens!(guard, !guard.is_atom()))?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Head(pred, tag) => {
                    let pred = if cx.simplify_exprs {
                        pred.simplify(&SnapshotMap::default())
                    } else {
                        pred.clone()
                    };
                    w!(cx, f, "{:?}", parens!(pred, !pred.is_atom()))?;
                    if cx.tags {
                        w!(cx, f, " ~ {:?}", tag)?;
                    }
                    Ok(())
                }
                NodeKind::True => {
                    w!(cx, f, "true")
                }
            }
        }
    }

    /// Prints the children of a node. A node without children prints ` true`. A single head
    /// child goes on the same line. Any other children go on new, indented lines.
    fn fmt_children(
        children: &[NodePtr],
        cx: &PrettyCx,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        match children {
            [] => w!(cx, f, " true"),
            [n] => {
                if n.borrow().is_head() {
                    w!(cx, f, " {:?}", n)
                } else {
                    w!(cx, with_padding(f), "\n{:?}", n)
                }
            }
            _ => w!(cx, with_padding(f), "\n{:?}", join!("\n", children)),
        }
    }

    /// Prints the path from the root to the cursor as `{binding, ..., assumption, ...}` in
    /// root-to-cursor order.
    impl Pretty for Cursor<'_> {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let mut elements = vec![];
            for node in ParentsIter::new(NodePtr::clone(&self.ptr)) {
                let n = node.borrow();
                match &n.kind {
                    NodeKind::Root(bindings) => {
                        // We reverse here because the list is reversed again at the end
                        for (name, sort) in bindings.iter().rev() {
                            elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                        }
                    }
                    NodeKind::ForAll(name, sort) => {
                        elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                    }
                    NodeKind::Assumption(pred) => {
                        elements.push(format_cx!(cx, "{:?}", pred));
                    }
                    _ => {}
                }
            }
            // Elements were collected from the cursor upwards; print them root first.
            write!(f, "{{{}}}", elements.into_iter().rev().format(", "))
        }
    }

    /// Prints the `ForAll` bindings followed by the root params.
    impl Pretty for Scope {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(
                f,
                "[bindings = {}, reftgenerics = {}]",
                self.bindings
                    .iter_enumerated()
                    .format_with(", ", |(name, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                    }),
                self.params
                    .iter()
                    .format_with(", ", |(param_const, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^param_const, sort))
                    }),
            )
        }
    }

    /// Wraps `f` so that everything written through it is indented by two spaces.
    fn with_padding<'a, 'b>(f: &'a mut fmt::Formatter<'b>) -> PadAdapter<'a, 'b, 'static> {
        PadAdapter::with_padding(f, "  ")
    }

    // `Debug` impls that go through the `Pretty` impls above with a default context.
    impl_debug_with_default_cx!(
        RefineTree => "refine_tree",
        Cursor<'_> => "cursor",
        Scope,
    );
}
752
753/// An explicit representation of a path in the [`RefineTree`] used for debugging/tracing/serialization ONLY.
754#[derive(Serialize, DebugAsJson)]
755pub struct RefineCtxtTrace {
756    bindings: Vec<RcxBind>,
757    exprs: Vec<String>,
758}
759
760#[derive(Serialize)]
761struct RcxBind {
762    name: String,
763    sort: String,
764}
765
766impl RefineCtxtTrace {
767    pub fn new(genv: GlobalEnv, cursor: &Cursor) -> Self {
768        let parents = ParentsIter::new(NodePtr::clone(&cursor.ptr)).collect_vec();
769        let mut bindings = vec![];
770        let mut exprs = vec![];
771        let cx = &PrettyCx::default(genv);
772
773        parents.into_iter().rev().for_each(|ptr| {
774            let node = ptr.borrow();
775            match &node.kind {
776                NodeKind::ForAll(name, sort) => {
777                    let bind = RcxBind {
778                        name: format_cx!(cx, "{:?}", ^name),
779                        sort: format_cx!(cx, "{:?}", sort),
780                    };
781                    bindings.push(bind);
782                }
783                NodeKind::Assumption(e)
784                    if !e.simplify(&SnapshotMap::default()).is_trivially_true() =>
785                {
786                    e.visit_conj(|e| {
787                        exprs.push(e.nested_string(cx));
788                    });
789                }
790                NodeKind::Root(binds) => {
791                    for (name, sort) in binds {
792                        let bind = RcxBind {
793                            name: format_cx!(cx, "{:?}", name),
794                            sort: format_cx!(cx, "{:?}", sort),
795                        };
796                        bindings.push(bind);
797                    }
798                }
799                _ => (),
800            }
801        });
802        Self { bindings, exprs }
803    }
804}
805
806impl Node {
807    /// replace bot-kvars with false
808    fn simplify_bot(&mut self) {
809        let graph = ConstraintDeps::new(self);
810        let bots = graph.bot_kvars();
811        self.simplify_with_assignment(&bots);
812        self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
813    }
814
815    /// replace top-kvars with true
816    fn simplify_top(&mut self) {
817        let graph = ConstraintDeps::new(self);
818        let tops = graph.top_kvars();
819        self.simplify_with_assignment(&tops);
820        self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
821    }
822
823    /// simplifies assumptions and heads using the TOP/BOT kvar assignment; follow
824    /// with a call to `simplify` to delete constraints with FALSE assm.
825    fn simplify_with_assignment(&mut self, assignment: &Assignment) {
826        match &mut self.kind {
827            NodeKind::Head(pred, tag) => {
828                let pred = assignment.simplify(pred);
829                self.kind = NodeKind::Head(pred, *tag);
830            }
831            NodeKind::Assumption(pred) => {
832                let pred = assignment.simplify(pred);
833                self.kind = NodeKind::Assumption(pred);
834            }
835            _ => {}
836        }
837        for child in &self.children {
838            child.borrow_mut().simplify_with_assignment(assignment);
839        }
840    }
841}
842
843#[derive(Debug)]
844struct ConstraintDeps {
845    /// assumptions for each clause
846    assumptions: IndexVec<ClauseId, FxHashSet<KVid>>,
847    /// head of each clause
848    heads: IndexVec<ClauseId, Head>,
849}
850
851impl ConstraintDeps {
852    fn new(node: &Node) -> Self {
853        let mut graph = Self { assumptions: IndexVec::default(), heads: IndexVec::default() };
854        graph.build(node, &mut vec![]);
855        graph
856    }
857
858    fn insert_clause(&mut self, assumptions: &[KVid], head: Head) {
859        self.assumptions.push(assumptions.iter().copied().collect());
860        self.heads.push(head);
861    }
862
863    fn build(&mut self, node: &Node, assumptions: &mut Vec<KVid>) {
864        let n = assumptions.len();
865        match &node.kind {
866            NodeKind::Head(expr, _) => {
867                expr.visit_conj(|e| {
868                    if let ExprKind::KVar(kvar) = e.kind() {
869                        self.insert_clause(assumptions, Head::KVar(kvar.kvid));
870                    } else {
871                        self.insert_clause(assumptions, Head::Conc);
872                    }
873                });
874            }
875            NodeKind::Assumption(expr) => {
876                expr.visit_conj(|e| {
877                    if let ExprKind::KVar(kvar) = e.kind() {
878                        assumptions.push(kvar.kvid);
879                    }
880                });
881            }
882            _ => {}
883        };
884
885        for child in &node.children {
886            self.build(&child.borrow(), assumptions);
887        }
888
889        assumptions.truncate(n); // restore ctx
890    }
891
892    /// set of edges where kvid appears as ASSM
893    fn kv_lhs(&self) -> FxHashMap<KVid, Vec<ClauseId>> {
894        let mut res: FxHashMap<KVid, Vec<ClauseId>> = FxHashMap::default();
895        for (clause_id, kvids) in self.assumptions.iter_enumerated() {
896            for kvid in kvids {
897                res.entry(*kvid).or_default().push(clause_id);
898            }
899        }
900        res
901    }
902
903    /// set of edges where kvid appears as HEAD
904    fn kv_rhs(&self) -> FxHashMap<KVid, Vec<ClauseId>> {
905        let mut res: FxHashMap<KVid, Vec<ClauseId>> = FxHashMap::default();
906        for (clause_id, head) in self.heads.iter_enumerated() {
907            if let Head::KVar(kvid) = head {
908                res.entry(*kvid).or_default().push(clause_id);
909            }
910        }
911        res
912    }
913
914    /// Computes the set of all kvars that can be assigned to Bot (False),
915    /// because they are not (transitively) reachable from any concrete ASSUMPTION.
916    fn bot_kvars(self) -> Assignment {
917        // set of BOT kvars (initially, all)
918        let mut assignment = Assignment::new(Label::Bot);
919
920        let kv_lhs = self.kv_lhs();
921
922        // set of BOT kvars in LHS of each constraint with KVar HEAD
923        let mut bot_assms: IndexVec<ClauseId, FxHashSet<KVid>> = self.assumptions;
924
925        // set of constraints `cid` whose bot-assms is empty
926        let mut candidates: Vec<ClauseId> = bot_assms
927            .iter_enumerated()
928            .filter_map(|(cid, lhs)| if lhs.is_empty() { Some(cid) } else { None })
929            .collect();
930
931        // while there is a candidate constraint, that has NO BOT kvars in lhs
932        while let Some(cid) = candidates.pop() {
933            if let Head::KVar(kvid) = self.heads[cid] {
934                // un-BOT the head kvar
935                assignment.remove(kvid);
936                // remove the head kvar from all (bot) assumptions where it currently occurs
937                for cid in kv_lhs.get(&kvid).unwrap_or(&vec![]) {
938                    // if cid HEAD is a kvar
939                    if let Head::KVar(rhs_kvid) = self.heads[*cid] {
940                        let assms = &mut bot_assms[*cid];
941                        assms.remove(&kvid);
942                        if assignment.has_label(rhs_kvid) && assms.is_empty() {
943                            candidates.push(*cid);
944                        }
945                    };
946                }
947            }
948        }
949
950        assignment
951    }
952
953    /// Computes the set of all kvars that can be assigned to Top (True),
954    /// because they do not (transitively) reach any concrete HEAD.
955    fn top_kvars(self) -> Assignment {
956        // initialize
957        let mut assignment = Assignment::new(Label::Top);
958
959        let kv_rhs = self.kv_rhs();
960
961        // set of kvar {k | cid in graph.edges, c.rhs is concrete, k in c.lhs }
962        let mut candidates = vec![];
963        for (cid, head) in self.heads.iter_enumerated() {
964            if matches!(head, Head::Conc) {
965                for kvid in &self.assumptions[cid] {
966                    candidates.push(*kvid);
967                }
968            }
969        }
970
971        // set each kvar that transitively reaches a concrete HEAD to NON-BOT
972        while let Some(kvid) = candidates.pop() {
973            // set that kvar to non-top
974            assignment.remove(kvid);
975
976            // for each constraint where kvid appears as head
977            for cid in kv_rhs.get(&kvid).unwrap_or(&vec![]) {
978                // add kvars in lhs to candidates (if they have not already been solved to non-BOT)
979                for lhs_kvid in &self.assumptions[*cid] {
980                    if assignment.has_label(*lhs_kvid) {
981                        candidates.push(*lhs_kvid);
982                    }
983                }
984            }
985        }
986
987        assignment
988    }
989}
990
newtype_index! {
    /// Index of a clause in `ConstraintDeps` (used for both its `assumptions` and `heads`).
    struct ClauseId {}
}
994
/// The head (right-hand side) of a clause in `ConstraintDeps`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Head {
    /// The head conjunct is the given kvar.
    KVar(KVid),
    /// A *conc*rete predicate, i.e., an [`Expr`] that's not a kvar. We don't need to know
    /// the exact expression, only that it's concrete.
    Conc,
}
1003
/// The label an `Assignment` gives to every kvar it has not excluded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Label {
    /// Kvar can be solved to false
    Bot,
    /// Kvar can be solved to true
    Top,
}
1011
/// A TOP or BOT kvar assignment, stored in complement form: every kvar carries `label`
/// except those recorded in `vars`.
struct Assignment {
    /// These vars are NOT assigned `label`,
    /// all other `KVid` implicitly have assignment `label`
    vars: FxHashSet<KVid>,
    /// The label implicitly given to every kvar not in `vars`.
    label: Label,
}
1018
1019impl Assignment {
1020    fn new(label: Label) -> Self {
1021        let vars = FxHashSet::default();
1022        Self { vars, label }
1023    }
1024
1025    fn has_label(&self, kvid: KVid) -> bool {
1026        !self.vars.contains(&kvid)
1027    }
1028
1029    fn remove(&mut self, kvid: KVid) {
1030        self.vars.insert(kvid);
1031    }
1032
1033    /// simplifies the given predicate expression by replacing
1034    /// kvid assigned to TOP with True, BOT with false.
1035    fn simplify(&self, pred: &Expr) -> Expr {
1036        let mut preds = vec![];
1037        for p in pred.flatten_conjs() {
1038            if let ExprKind::KVar(kvar) = p.kind()
1039                && self.has_label(kvar.kvid)
1040            {
1041                if self.label == Label::Bot {
1042                    return Expr::ff();
1043                } // else, skip pushing `p` into `preds`
1044            } else {
1045                preds.push(p.clone());
1046            }
1047        }
1048        Expr::and_from_iter(preds)
1049    }
1050}