flux_infer/
refine_tree.rs

1use std::{
2    cell::RefCell,
3    ops::ControlFlow,
4    rc::{Rc, Weak},
5};
6
7use flux_common::{index::IndexVec, iter::IterExt, tracked_span_bug};
8use flux_config::OverflowMode;
9use flux_macros::DebugAsJson;
10use flux_middle::{
11    global_env::GlobalEnv,
12    pretty::{PrettyCx, PrettyNested, format_cx},
13    queries::QueryResult,
14    rty::{
15        BaseTy, EVid, Expr, ExprKind, KVid, Name, NameProvenance, PrettyVar, Sort, Ty, TyKind, Var,
16        fold::{TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor},
17    },
18};
19use itertools::Itertools;
20use rustc_data_structures::{
21    snapshot_map::SnapshotMap,
22    unord::{UnordMap, UnordSet},
23};
24use rustc_hash::FxHashSet;
25use rustc_index::newtype_index;
26use rustc_middle::ty::TyCtxt;
27use serde::Serialize;
28
29use crate::{
30    evars::EVarStore,
31    fixpoint_encoding::{FixpointCtxt, fixpoint},
32    infer::{Tag, TypeTrace},
33};
34
35/// A *refine*ment *tree* tracks the "tree-like structure" of refinement variables and predicates
36/// generated during refinement type-checking. This tree can be encoded as a fixpoint constraint
37/// whose satisfiability implies the safety of a function.
38///
39/// We try to hide the representation of the tree as much as possible and only a couple of operations
40/// can be used to manipulate the structure of the tree explicitly. Instead, the tree is mostly
41/// constructed implicitly via a restricted api provided by [`Cursor`]. Some methods operate on *nodes*
42/// of the tree which we try to keep abstract, but it is important to remember that there's an
43/// underlying tree.
44///
45/// The current implementation uses [`Rc`] and [`RefCell`] to represent the tree, but we ensure
46/// statically that the [`RefineTree`] is the single owner of the data and require a mutable reference
47/// to it for all mutations, i.e., we could in theory replace the [`RefCell`] with an [`UnsafeCell`]
48/// (or a [`GhostCell`]).
49///
50/// [`UnsafeCell`]: std::cell::UnsafeCell
51/// [`GhostCell`]: https://docs.rs/ghost-cell/0.2.3/ghost_cell/ghost_cell/struct.GhostCell.html
52pub struct RefineTree {
53    root: NodePtr,
54}
55
56impl RefineTree {
57    pub(crate) fn new(params: Vec<(Var, Sort)>) -> RefineTree {
58        let root =
59            Node { kind: NodeKind::Root(params), nbindings: 0, parent: None, children: vec![] };
60        let root = NodePtr(Rc::new(RefCell::new(root)));
61        RefineTree { root }
62    }
63
64    pub(crate) fn simplify(&mut self, genv: GlobalEnv) {
65        self.root
66            .borrow_mut()
67            .simplify(SimplifyPhase::Full(genv), &mut SnapshotMap::default());
68        self.root.borrow_mut().simplify_bot();
69        self.root.borrow_mut().simplify_top();
70    }
71
72    pub(crate) fn to_fixpoint(
73        &self,
74        cx: &mut FixpointCtxt<Tag>,
75    ) -> QueryResult<fixpoint::Constraint> {
76        Ok(self
77            .root
78            .borrow()
79            .to_fixpoint(cx)?
80            .unwrap_or(fixpoint::Constraint::TRUE))
81    }
82
83    pub(crate) fn cursor_at_root(&mut self) -> Cursor<'_> {
84        Cursor { ptr: NodePtr(Rc::clone(&self.root)), tree: self }
85    }
86
87    pub(crate) fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
88        self.root.borrow_mut().replace_evars(evars)
89    }
90}
91
92/// A cursor into the [refinement tree]. More specifically, a [`Cursor`] represents a path from the
93/// root to some internal node in a [refinement tree].
94///
95/// [refinement tree]: RefineTree
96pub struct Cursor<'a> {
97    tree: &'a mut RefineTree,
98    ptr: NodePtr,
99}
100
101impl Cursor<'_> {
102    /// Moves the cursor to the specified [marker]. If `clear_children` is `true`, all children of
103    /// the node are removed after moving the cursor, invalidating any markers pointing to a node
104    /// within those children.
105    ///
106    /// [marker]: Marker
107    pub(crate) fn move_to(&mut self, marker: &Marker, clear_children: bool) -> Option<Cursor<'_>> {
108        let ptr = marker.ptr.upgrade()?;
109        if clear_children {
110            ptr.borrow_mut().children.clear();
111        }
112        Some(Cursor { ptr, tree: self.tree })
113    }
114
115    /// Returns a marker to the current node
116    #[must_use]
117    pub(crate) fn marker(&self) -> Marker {
118        Marker { ptr: NodePtr::downgrade(&self.ptr) }
119    }
120
121    #[must_use]
122    pub(crate) fn branch(&mut self) -> Cursor<'_> {
123        Cursor { tree: self.tree, ptr: NodePtr::clone(&self.ptr) }
124    }
125
126    pub(crate) fn vars(&self) -> impl Iterator<Item = (Var, Sort)> {
127        // TODO(nilehmann) we could incrementally cache the scope
128        self.ptr.scope().into_iter()
129    }
130
131    #[expect(dead_code, reason = "used for debugging")]
132    pub(crate) fn push_trace(&mut self, trace: TypeTrace) {
133        self.ptr = self.ptr.push_node(NodeKind::Trace(trace));
134    }
135
136    /// Defines a fresh refinement variable with the given `sort` and advance the cursor to the new
137    /// node. It returns the freshly generated name for the variable.
138    pub(crate) fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
139        let fresh = Name::from_usize(self.ptr.next_name_idx());
140        self.ptr = self
141            .ptr
142            .push_node(NodeKind::ForAll(fresh, sort.clone(), provenance));
143        fresh
144    }
145
146    /// Pushes an [assumption] and moves the cursor into the new node.
147    ///
148    /// [assumption]: NodeKind::Assumption
149    pub(crate) fn assume_pred(&mut self, pred: impl Into<Expr>) {
150        let pred = pred.into();
151        if !pred.is_trivially_true() {
152            self.ptr = self.ptr.push_node(NodeKind::Assumption(pred));
153        }
154    }
155
156    /// Pushes a predicate that must be true assuming variables and predicates in the current branch
157    /// of the tree (i.e., it pushes a [`NodeKind::Head`]). This methods does not advance the cursor.
158    pub(crate) fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
159        let pred = pred.into();
160        if !pred.is_trivially_true() {
161            self.ptr.push_node(NodeKind::Head(pred, tag));
162        }
163    }
164
165    /// Convenience method to push an assumption followed by a predicate that needs to be checked.
166    /// This method does not advance the cursor.
167    pub(crate) fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
168        self.ptr
169            .push_node(NodeKind::Assumption(pred1.into()))
170            .push_node(NodeKind::Head(pred2.into(), tag));
171    }
172
173    pub(crate) fn assume_invariants(
174        &mut self,
175        tcx: TyCtxt,
176        ty: &Ty,
177        overflow_checking: OverflowMode,
178    ) {
179        struct Visitor<'a, 'b, 'tcx> {
180            tcx: TyCtxt<'tcx>,
181            cursor: &'a mut Cursor<'b>,
182            overflow_mode: OverflowMode,
183        }
184        impl TypeVisitor for Visitor<'_, '_, '_> {
185            fn visit_bty(&mut self, bty: &BaseTy) -> ControlFlow<!> {
186                match bty {
187                    BaseTy::Adt(adt_def, substs) if adt_def.is_box() => substs.visit_with(self),
188                    BaseTy::Ref(_, ty, _) => ty.visit_with(self),
189                    BaseTy::Tuple(tys) => tys.visit_with(self),
190                    _ => ControlFlow::Continue(()),
191                }
192            }
193
194            fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<!> {
195                if let TyKind::Indexed(bty, idx) = ty.kind()
196                    && !idx.has_escaping_bvars()
197                {
198                    for invariant in bty.invariants(self.tcx, self.overflow_mode) {
199                        let invariant = invariant.apply(idx);
200                        self.cursor.assume_pred(&invariant);
201                    }
202                }
203                ty.super_visit_with(self)
204            }
205        }
206        let _ = ty.visit_with(&mut Visitor { tcx, cursor: self, overflow_mode: overflow_checking });
207    }
208}
209
210/// A marker is a pointer to a node in the [refinement tree] that can be used to query information
211/// about that node or to move the cursor. A marker may become invalid if the underlying node is
212/// [cleared].
213///
214/// [cleared]: Cursor::move_to
215/// [refinement tree]: RefineTree
216pub struct Marker {
217    ptr: WeakNodePtr,
218}
219
220impl Marker {
221    /// Returns the [`scope`] at the marker if it is still valid or [`None`] otherwise.
222    ///
223    /// [`scope`]: Scope
224    pub fn scope(&self) -> Option<Scope> {
225        Some(self.ptr.upgrade()?.scope())
226    }
227
228    pub fn has_free_vars<T: TypeVisitable>(&self, t: &T) -> bool {
229        let ptr = self
230            .ptr
231            .upgrade()
232            .unwrap_or_else(|| tracked_span_bug!("`has_free_vars` called on invalid `Marker`"));
233
234        let nbindings = ptr.next_name_idx();
235
236        !t.fvars().into_iter().all(|name| name.index() < nbindings)
237    }
238}
239
240/// A list of refinement variables and their sorts.
241#[derive(PartialEq, Eq)]
242pub struct Scope {
243    params: Vec<(Var, Sort)>,
244    bindings: IndexVec<Name, Sort>,
245}
246
247impl Scope {
248    pub(crate) fn iter(&self) -> impl Iterator<Item = (Var, Sort)> + '_ {
249        itertools::chain(
250            self.params.iter().cloned(),
251            self.bindings
252                .iter_enumerated()
253                .map(|(name, sort)| (Var::Free(name), sort.clone())),
254        )
255    }
256
257    fn into_iter(self) -> impl Iterator<Item = (Var, Sort)> {
258        itertools::chain(
259            self.params,
260            self.bindings
261                .into_iter_enumerated()
262                .map(|(name, sort)| (Var::Free(name), sort.clone())),
263        )
264    }
265
266    /// Whether `t` has any free variables not in this scope
267    pub fn has_free_vars<T: TypeFoldable>(&self, t: &T) -> bool {
268        !self.contains_all(t.fvars())
269    }
270
271    fn contains_all(&self, iter: impl IntoIterator<Item = Name>) -> bool {
272        iter.into_iter().all(|name| self.contains(name))
273    }
274
275    fn contains(&self, name: Name) -> bool {
276        name.index() < self.bindings.len()
277    }
278}
279
280struct Node {
281    kind: NodeKind,
282    /// Number of bindings between the root and this node's parent, i.e., we have
283    /// as an invariant that `nbindings` equals the number of [`NodeKind::ForAll`]
284    /// nodes from the parent of this node to the root.
285    nbindings: usize,
286    parent: Option<WeakNodePtr>,
287    children: Vec<NodePtr>,
288}
289
290#[derive(Clone)]
291struct NodePtr(Rc<RefCell<Node>>);
292
293impl NodePtr {
294    fn downgrade(this: &Self) -> WeakNodePtr {
295        WeakNodePtr(Rc::downgrade(&this.0))
296    }
297
298    fn push_node(&mut self, kind: NodeKind) -> NodePtr {
299        debug_assert!(!matches!(self.borrow().kind, NodeKind::Head(..)));
300        let node = Node {
301            kind,
302            nbindings: self.next_name_idx(),
303            parent: Some(NodePtr::downgrade(self)),
304            children: vec![],
305        };
306        let node = NodePtr(Rc::new(RefCell::new(node)));
307        self.borrow_mut().children.push(NodePtr::clone(&node));
308        node
309    }
310
311    fn next_name_idx(&self) -> usize {
312        self.borrow().nbindings + usize::from(self.borrow().is_forall())
313    }
314
315    fn scope(&self) -> Scope {
316        let mut params = None;
317        let parents = ParentsIter::new(self.clone());
318        let bindings = parents
319            .filter_map(|node| {
320                let node = node.borrow();
321                match &node.kind {
322                    NodeKind::Root(p) => {
323                        params = Some(p.clone());
324                        None
325                    }
326                    NodeKind::ForAll(_, sort, _) => Some(sort.clone()),
327                    _ => None,
328                }
329            })
330            .collect_vec()
331            .into_iter()
332            .rev()
333            .collect();
334        Scope { bindings, params: params.unwrap_or_default() }
335    }
336}
337
338struct WeakNodePtr(Weak<RefCell<Node>>);
339
340impl WeakNodePtr {
341    fn upgrade(&self) -> Option<NodePtr> {
342        Some(NodePtr(self.0.upgrade()?))
343    }
344}
345
346enum NodeKind {
347    /// List of const and refinement generics
348    Root(Vec<(Var, Sort)>),
349    ForAll(Name, Sort, NameProvenance),
350    Assumption(Expr),
351    Head(Expr, Tag),
352    True,
353    /// Used for debugging. See [`TypeTrace`]
354    Trace(TypeTrace),
355}
356
357impl std::ops::Index<Name> for Scope {
358    type Output = Sort;
359
360    fn index(&self, name: Name) -> &Self::Output {
361        &self.bindings[name]
362    }
363}
364
365impl std::ops::Deref for NodePtr {
366    type Target = Rc<RefCell<Node>>;
367
368    fn deref(&self) -> &Self::Target {
369        &self.0
370    }
371}
372
373#[derive(Clone, Copy)]
374enum SimplifyPhase<'genv, 'tcx> {
375    /// Normalize and simplify inner `Expr`
376    Full(GlobalEnv<'genv, 'tcx>),
377    /// Only propagate `true` (TOP) and `false` (BOT)
378    Partial,
379}
380
381impl Node {
382    fn simplify(&mut self, phase: SimplifyPhase, assumed_preds: &mut SnapshotMap<Expr, ()>) {
383        // First, simplify the node itself
384        match &mut self.kind {
385            NodeKind::Head(pred, tag) => {
386                let pred = match phase {
387                    SimplifyPhase::Full(genv) => pred.normalize(genv).simplify(assumed_preds),
388                    SimplifyPhase::Partial => pred.clone(),
389                };
390                if pred.is_trivially_true() {
391                    self.kind = NodeKind::True;
392                } else {
393                    self.kind = NodeKind::Head(pred, *tag);
394                }
395            }
396            NodeKind::Assumption(pred) => {
397                if let SimplifyPhase::Full(genv) = phase {
398                    *pred = pred.normalize(genv).simplify(assumed_preds);
399                }
400                pred.visit_conj(|conjunct| {
401                    assumed_preds.insert(conjunct.erase_spans(), ());
402                });
403            }
404            _ => {}
405        }
406        let is_false_asm =
407            matches!(&self.kind, NodeKind::Assumption(pred) if pred.is_trivially_false());
408
409        // Then simplify the children
410        // (the order matters here because we need to collect assumed preds first)
411        for child in &self.children {
412            let current_version = assumed_preds.snapshot();
413            child.borrow_mut().simplify(phase, assumed_preds);
414            assumed_preds.rollback_to(current_version);
415        }
416
417        // Then remove any unnecessary children
418        match &mut self.kind {
419            NodeKind::Head(..) | NodeKind::True => {}
420            NodeKind::Assumption(_)
421            | NodeKind::Trace(_)
422            | NodeKind::Root(_)
423            | NodeKind::ForAll(..) => {
424                self.children
425                    .extract_if(.., |child| {
426                        is_false_asm || matches!(&child.borrow().kind, NodeKind::True)
427                    })
428                    .for_each(drop);
429            }
430        }
431        if !self.is_leaf() && self.children.is_empty() {
432            self.kind = NodeKind::True;
433        }
434    }
435
436    fn is_leaf(&self) -> bool {
437        matches!(self.kind, NodeKind::Head(..) | NodeKind::True)
438    }
439
440    fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
441        for child in &self.children {
442            child.borrow_mut().replace_evars(evars)?;
443        }
444        match &mut self.kind {
445            NodeKind::Assumption(pred) => *pred = evars.replace_evars(pred)?,
446            NodeKind::Head(pred, _) => {
447                *pred = evars.replace_evars(pred)?;
448            }
449            NodeKind::Trace(trace) => {
450                evars.replace_evars(trace)?;
451            }
452            NodeKind::Root(_) | NodeKind::ForAll(..) | NodeKind::True => {}
453        }
454        Ok(())
455    }
456
457    fn to_fixpoint(&self, cx: &mut FixpointCtxt<Tag>) -> QueryResult<Option<fixpoint::Constraint>> {
458        let cstr = match &self.kind {
459            NodeKind::Trace(_) | NodeKind::ForAll(_, Sort::Loc, _) => {
460                children_to_fixpoint(cx, &self.children)?
461            }
462
463            NodeKind::Root(params) => {
464                // declare pretty-vars for params
465                for (var, sort) in params {
466                    if let Var::EarlyParam(param) = var
467                        && !sort.is_loc()
468                    {
469                        cx.with_early_param(param);
470                    }
471                }
472
473                let Some(children) = children_to_fixpoint(cx, &self.children)? else {
474                    return Ok(None);
475                };
476                let mut constr = children;
477                for (var, sort) in params.iter().rev() {
478                    if sort.is_loc() {
479                        continue;
480                    }
481                    constr = fixpoint::Constraint::ForAll(
482                        fixpoint::Bind {
483                            name: cx.var_to_fixpoint(var),
484                            sort: cx.sort_to_fixpoint(sort),
485                            pred: fixpoint::Pred::TRUE,
486                        },
487                        Box::new(constr),
488                    );
489                }
490                Some(constr)
491            }
492            NodeKind::ForAll(name, sort, provenance) => {
493                cx.with_name_map(*name, *provenance, |cx, fresh| -> QueryResult<_> {
494                    let Some(children) = children_to_fixpoint(cx, &self.children)? else {
495                        return Ok(None);
496                    };
497                    Ok(Some(fixpoint::Constraint::ForAll(
498                        fixpoint::Bind {
499                            name: fixpoint::Var::Local(fresh),
500                            sort: cx.sort_to_fixpoint(sort),
501                            pred: fixpoint::Pred::TRUE,
502                        },
503                        Box::new(children),
504                    )))
505                })?
506            }
507            NodeKind::Assumption(pred) => {
508                let (mut bindings, pred) = cx.assumption_to_fixpoint(pred)?;
509                let Some(cstr) = children_to_fixpoint(cx, &self.children)? else {
510                    return Ok(None);
511                };
512                bindings.push(fixpoint::Bind {
513                    name: fixpoint::Var::Underscore,
514                    sort: fixpoint::Sort::Int,
515                    pred,
516                });
517                Some(fixpoint::Constraint::foralls(bindings, cstr))
518            }
519            NodeKind::Head(pred, tag) => {
520                Some(cx.head_to_fixpoint(pred, |span| tag.with_dst(span))?)
521            }
522            NodeKind::True => None,
523        };
524        Ok(cstr)
525    }
526
527    /// Returns `true` if the node kind is [`ForAll`].
528    ///
529    /// [`ForAll`]: NodeKind::ForAll
530    fn is_forall(&self) -> bool {
531        matches!(self.kind, NodeKind::ForAll(..))
532    }
533
534    /// Returns `true` if the node kind is [`Head`].
535    ///
536    /// [`Head`]: NodeKind::Head
537    fn is_head(&self) -> bool {
538        matches!(self.kind, NodeKind::Head(..))
539    }
540}
541
542fn children_to_fixpoint(
543    cx: &mut FixpointCtxt<Tag>,
544    children: &[NodePtr],
545) -> QueryResult<Option<fixpoint::Constraint>> {
546    let mut children = children
547        .iter()
548        .filter_map(|node| node.borrow().to_fixpoint(cx).transpose())
549        .try_collect_vec()?;
550    let cstr = match children.len() {
551        0 => None,
552        1 => children.pop(),
553        _ => Some(fixpoint::Constraint::Conj(children)),
554    };
555    Ok(cstr)
556}
557
558struct ParentsIter {
559    ptr: Option<NodePtr>,
560}
561
562impl ParentsIter {
563    fn new(ptr: NodePtr) -> Self {
564        Self { ptr: Some(ptr) }
565    }
566}
567
568impl Iterator for ParentsIter {
569    type Item = NodePtr;
570
571    fn next(&mut self) -> Option<Self::Item> {
572        if let Some(ptr) = self.ptr.take() {
573            self.ptr = ptr.borrow().parent.as_ref().and_then(WeakNodePtr::upgrade);
574            Some(ptr)
575        } else {
576            None
577        }
578    }
579}
580
mod pretty {
    use std::fmt::{self, Write};

    use flux_middle::pretty::*;
    use pad_adapter::PadAdapter;

    use super::*;

    /// Starting at `ptr`, collects the names and sorts of a run of nested `ForAll` nodes. The
    /// chain is followed only while a `ForAll` node has exactly one child. Returns the collected
    /// bindings together with the nodes to print after them: the children of the last `ForAll`
    /// in the run, or the first non-`ForAll` node reached.
    fn bindings_chain(ptr: &NodePtr) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut bindings: Vec<(Name, Sort)>) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::ForAll(name, sort, _) = &node.kind {
                bindings.push((*name, sort.clone()));
                if let [child] = &node.children[..] {
                    go(child, bindings)
                } else {
                    (bindings, node.children.clone())
                }
            } else {
                (bindings, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    /// Same as `bindings_chain`, but for a run of nested `Assumption` nodes. Returns their
    /// predicates and the nodes that follow the run.
    fn preds_chain(ptr: &NodePtr) -> (Vec<Expr>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut preds: Vec<Expr>) -> (Vec<Expr>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::Assumption(pred) = &node.kind {
                preds.push(pred.clone());
                if let [child] = &node.children[..] {
                    go(child, preds)
                } else {
                    (preds, node.children.clone())
                }
            } else {
                (preds, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    impl Pretty for RefineTree {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?}", &self.root)
        }
    }

    impl Pretty for NodePtr {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let node = self.borrow();
            match &node.kind {
                NodeKind::Trace(trace) => {
                    w!(cx, f, "@ {:?}", ^trace)?;
                    w!(cx, with_padding(f), "\n{:?}", join!("\n", &node.children))
                }
                NodeKind::Root(bindings) => {
                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&node.children, cx, f)
                }
                NodeKind::ForAll(name, sort, _) => {
                    // With `cx.bindings_chain`, a run of nested binders is printed under a
                    // single `∀`.
                    let (bindings, children) = if cx.bindings_chain {
                        bindings_chain(self)
                    } else {
                        (vec![(*name, sort.clone())], node.children.clone())
                    };

                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .into_iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Assumption(pred) => {
                    // With `cx.preds_chain`, a run of nested assumptions is printed as one
                    // (simplified) conjunction.
                    let (preds, children) = if cx.preds_chain {
                        preds_chain(self)
                    } else {
                        (vec![pred.clone()], node.children.clone())
                    };
                    let guard = Expr::and_from_iter(preds).simplify(&SnapshotMap::default());
                    w!(cx, f, "{:?} =>", parens!(guard, !guard.is_atom()))?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Head(pred, tag) => {
                    let pred = if cx.simplify_exprs {
                        pred.simplify(&SnapshotMap::default())
                    } else {
                        pred.clone()
                    };
                    w!(cx, f, "{:?}", parens!(pred, !pred.is_atom()))?;
                    if cx.tags {
                        w!(cx, f, " ~ {:?}", tag)?;
                    }
                    Ok(())
                }
                NodeKind::True => {
                    w!(cx, f, "true")
                }
            }
        }
    }

    /// Formats the nodes that follow a binder or guard. An empty list prints ` true`, a single
    /// head is printed inline, and anything else goes on new padded lines.
    fn fmt_children(
        children: &[NodePtr],
        cx: &PrettyCx,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        match children {
            [] => w!(cx, f, " true"),
            [n] => {
                if n.borrow().is_head() {
                    w!(cx, f, " {:?}", n)
                } else {
                    w!(cx, with_padding(f), "\n{:?}", n)
                }
            }
            _ => w!(cx, with_padding(f), "\n{:?}", join!("\n", children)),
        }
    }

    /// Prints the path from the root to the cursor as `{...}`. Root parameters, bound names and
    /// assumptions are listed outermost first.
    impl Pretty for Cursor<'_> {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let mut elements = vec![];
            for node in ParentsIter::new(NodePtr::clone(&self.ptr)) {
                let n = node.borrow();
                match &n.kind {
                    NodeKind::Root(bindings) => {
                        // We reverse here because the whole list is reversed again at the end
                        for (name, sort) in bindings.iter().rev() {
                            elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                        }
                    }
                    NodeKind::ForAll(name, sort, _) => {
                        elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                    }
                    NodeKind::Assumption(pred) => {
                        elements.push(format_cx!(cx, "{:?}", pred));
                    }
                    _ => {}
                }
            }
            write!(f, "{{{}}}", elements.into_iter().rev().format(", "))
        }
    }

    impl Pretty for Scope {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(
                f,
                "[bindings = {}, reftgenerics = {}]",
                self.bindings
                    .iter_enumerated()
                    .format_with(", ", |(name, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                    }),
                self.params
                    .iter()
                    .format_with(", ", |(param_const, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^param_const, sort))
                    }),
            )
        }
    }

    /// Wraps `f` in a `PadAdapter` with a two-space padding string (presumably indenting every
    /// new line of the wrapped output).
    fn with_padding<'a, 'b>(f: &'a mut fmt::Formatter<'b>) -> PadAdapter<'a, 'b, 'static> {
        PadAdapter::with_padding(f, "  ")
    }

    impl_debug_with_default_cx!(
        RefineTree => "refine_tree",
        Cursor<'_> => "cursor",
        Scope,
    );
}
766
767/// An explicit representation of a path in the [`RefineTree`] used for debugging/tracing/serialization ONLY.
768#[derive(Serialize, DebugAsJson)]
769pub struct RefineCtxtTrace {
770    bindings: Vec<RcxBind>,
771    exprs: Vec<String>,
772}
773
774#[derive(Serialize)]
775struct RcxBind {
776    name: String,
777    sort: String,
778}
779
780impl RefineCtxtTrace {
781    pub fn new(cx: &mut PrettyCx, cursor: &Cursor) -> Self {
782        let parents = ParentsIter::new(NodePtr::clone(&cursor.ptr)).collect_vec();
783        let mut bindings = vec![];
784        let mut exprs = vec![];
785
786        parents.into_iter().rev().for_each(|ptr| {
787            let node = ptr.borrow();
788            match &node.kind {
789                NodeKind::ForAll(name, sort, provenance) => {
790                    let name = cx
791                        .pretty_var_env
792                        .set(PrettyVar::Local(*name), provenance.opt_symbol());
793                    let sort = format_cx!(cx, "{:?}", sort);
794                    let bind = RcxBind { name, sort };
795                    bindings.push(bind);
796                }
797                NodeKind::Assumption(e)
798                    if !e.simplify(&SnapshotMap::default()).is_trivially_true() =>
799                {
800                    e.visit_conj(|e| {
801                        exprs.push(e.nested_string(cx));
802                    });
803                }
804                NodeKind::Root(binds) => {
805                    for (name, sort) in binds {
806                        let name = if let Var::EarlyParam(param) = name {
807                            cx.pretty_var_env
808                                .set(PrettyVar::Param(*param), Some(param.name))
809                        } else {
810                            format_cx!(cx, "{:?}", name)
811                        };
812                        let sort = format_cx!(cx, "{:?}", sort);
813                        let bind = RcxBind { name, sort };
814                        bindings.push(bind);
815                    }
816                }
817                _ => (),
818            }
819        });
820        Self { bindings, exprs }
821    }
822}
823
824impl Node {
825    /// replace bot-kvars with false
826    fn simplify_bot(&mut self) {
827        let graph = ConstraintDeps::new(self);
828        let bots = graph.bot_kvars();
829        self.simplify_with_assignment(&bots);
830        self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
831    }
832
833    /// replace top-kvars with true
834    fn simplify_top(&mut self) {
835        let graph = ConstraintDeps::new(self);
836        let tops = graph.top_kvars();
837        self.simplify_with_assignment(&tops);
838        self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
839    }
840
841    /// simplifies assumptions and heads using the TOP/BOT kvar assignment; follow
842    /// with a call to `simplify` to delete constraints with FALSE assm.
843    fn simplify_with_assignment(&mut self, assignment: &Assignment) {
844        match &mut self.kind {
845            NodeKind::Head(pred, tag) => {
846                let pred = assignment.simplify(pred);
847                self.kind = NodeKind::Head(pred, *tag);
848            }
849            NodeKind::Assumption(pred) => {
850                let pred = assignment.simplify(pred);
851                self.kind = NodeKind::Assumption(pred);
852            }
853            _ => {}
854        }
855        for child in &self.children {
856            child.borrow_mut().simplify_with_assignment(assignment);
857        }
858    }
859}
860
/// Dependency graph between kvars extracted from a constraint tree. Each clause (indexed by
/// [`ClauseId`]) records the kvars assumed on its left-hand side and whether its head is a kvar
/// or a concrete predicate.
#[derive(Debug)]
struct ConstraintDeps {
    /// kvars assumed in the LHS of each clause (non-kvar assumptions are not tracked)
    assumptions: IndexVec<ClauseId, FxHashSet<KVid>>,
    /// head (RHS) of each clause
    heads: IndexVec<ClauseId, Head>,
}
868
869impl ConstraintDeps {
870    fn new(node: &Node) -> Self {
871        let mut graph = Self { assumptions: IndexVec::default(), heads: IndexVec::default() };
872        graph.build(node, &mut vec![]);
873        graph
874    }
875
876    fn insert_clause(&mut self, assumptions: &[KVid], head: Head) {
877        self.assumptions.push(assumptions.iter().copied().collect());
878        self.heads.push(head);
879    }
880
881    fn build(&mut self, node: &Node, assumptions: &mut Vec<KVid>) {
882        let n = assumptions.len();
883        match &node.kind {
884            NodeKind::Head(expr, _) => {
885                expr.visit_conj(|e| {
886                    if let ExprKind::KVar(kvar) = e.kind() {
887                        self.insert_clause(assumptions, Head::KVar(kvar.kvid));
888                    } else {
889                        self.insert_clause(assumptions, Head::Conc);
890                    }
891                });
892            }
893            NodeKind::Assumption(expr) => {
894                expr.visit_conj(|e| {
895                    if let ExprKind::KVar(kvar) = e.kind() {
896                        assumptions.push(kvar.kvid);
897                    }
898                });
899            }
900            _ => {}
901        };
902
903        for child in &node.children {
904            self.build(&child.borrow(), assumptions);
905        }
906
907        assumptions.truncate(n); // restore ctx
908    }
909
910    /// set of edges where kvid appears as ASSM
911    fn kv_lhs(&self) -> UnordMap<KVid, Vec<ClauseId>> {
912        let mut res: UnordMap<KVid, Vec<ClauseId>> = UnordMap::default();
913        for (clause_id, kvids) in self.assumptions.iter_enumerated() {
914            for kvid in kvids {
915                res.entry(*kvid).or_default().push(clause_id);
916            }
917        }
918        res
919    }
920
921    /// set of edges where kvid appears as HEAD
922    fn kv_rhs(&self) -> UnordMap<KVid, Vec<ClauseId>> {
923        let mut res: UnordMap<KVid, Vec<ClauseId>> = UnordMap::default();
924        for (clause_id, head) in self.heads.iter_enumerated() {
925            if let Head::KVar(kvid) = head {
926                res.entry(*kvid).or_default().push(clause_id);
927            }
928        }
929        res
930    }
931
932    /// Computes the set of all kvars that can be assigned to Bot (False),
933    /// because they are not (transitively) reachable from any concrete ASSUMPTION.
934    fn bot_kvars(self) -> Assignment {
935        // set of BOT kvars (initially, all)
936        let mut assignment = Assignment::new(Label::Bot);
937
938        let kv_lhs = self.kv_lhs();
939
940        // set of BOT kvars in LHS of each constraint with KVar HEAD
941        let mut bot_assms: IndexVec<ClauseId, FxHashSet<KVid>> = self.assumptions;
942
943        // set of constraints `cid` whose bot-assms is empty
944        let mut candidates: Vec<ClauseId> = bot_assms
945            .iter_enumerated()
946            .filter_map(|(cid, lhs)| if lhs.is_empty() { Some(cid) } else { None })
947            .collect();
948
949        // while there is a candidate constraint, that has NO BOT kvars in lhs
950        while let Some(cid) = candidates.pop() {
951            if let Head::KVar(kvid) = self.heads[cid] {
952                // un-BOT the head kvar
953                assignment.remove(kvid);
954                // remove the head kvar from all (bot) assumptions where it currently occurs
955                for cid in kv_lhs.get(&kvid).unwrap_or(&vec![]) {
956                    // if cid HEAD is a kvar
957                    if let Head::KVar(rhs_kvid) = self.heads[*cid] {
958                        let assms = &mut bot_assms[*cid];
959                        assms.remove(&kvid);
960                        if assignment.has_label(rhs_kvid) && assms.is_empty() {
961                            candidates.push(*cid);
962                        }
963                    };
964                }
965            }
966        }
967
968        assignment
969    }
970
971    /// Computes the set of all kvars that can be assigned to Top (True),
972    /// because they do not (transitively) reach any concrete HEAD.
973    fn top_kvars(self) -> Assignment {
974        // initialize
975        let mut assignment = Assignment::new(Label::Top);
976
977        let kv_rhs = self.kv_rhs();
978
979        // set of kvar {k | cid in graph.edges, c.rhs is concrete, k in c.lhs }
980        let mut candidates = vec![];
981        for (cid, head) in self.heads.iter_enumerated() {
982            if matches!(head, Head::Conc) {
983                for kvid in &self.assumptions[cid] {
984                    candidates.push(*kvid);
985                }
986            }
987        }
988
989        // set each kvar that transitively reaches a concrete HEAD to NON-BOT
990        while let Some(kvid) = candidates.pop() {
991            // set that kvar to non-top
992            assignment.remove(kvid);
993
994            // for each constraint where kvid appears as head
995            for cid in kv_rhs.get(&kvid).unwrap_or(&vec![]) {
996                // add kvars in lhs to candidates (if they have not already been solved to non-BOT)
997                for lhs_kvid in &self.assumptions[*cid] {
998                    if assignment.has_label(*lhs_kvid) {
999                        candidates.push(*lhs_kvid);
1000                    }
1001                }
1002            }
1003        }
1004
1005        assignment
1006    }
1007}
1008
// Index of a clause in `ConstraintDeps`; indices are handed out in push order by `IndexVec`.
newtype_index! {
    struct ClauseId {}
}
1012
/// The right-hand side (head) of a clause in [`ConstraintDeps`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Head {
    /// The head conjunct is a kvar application; only its [`KVid`] is kept.
    KVar(KVid),
    /// A *conc*rete predicate, i.e., an [`Expr`] that's not a kvar. We don't need to know
    /// the exact expression, only that it's concrete.
    Conc,
}
1021
/// The trivial solution an [`Assignment`] gives to the kvars it labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Label {
    /// Kvar can be solved to false
    Bot,
    /// Kvar can be solved to true
    Top,
}
1029
/// A set of kvars assigned a single `label`, stored as its complement: every kvar is labeled
/// unless it has been added to `vars`.
struct Assignment {
    /// These vars are NOT assigned `label`,
    /// all other `KVid` implicitly have assignment `label`
    vars: UnordSet<KVid>,
    /// The solution (true/false) given to every kvar not in `vars`.
    label: Label,
}
1036
1037impl Assignment {
1038    fn new(label: Label) -> Self {
1039        let vars = UnordSet::new();
1040        Self { vars, label }
1041    }
1042
1043    fn has_label(&self, kvid: KVid) -> bool {
1044        !self.vars.contains(&kvid)
1045    }
1046
1047    fn remove(&mut self, kvid: KVid) {
1048        self.vars.insert(kvid);
1049    }
1050
1051    /// simplifies the given predicate expression by replacing
1052    /// kvid assigned to TOP with True, BOT with false.
1053    fn simplify(&self, pred: &Expr) -> Expr {
1054        let mut preds = vec![];
1055        for p in pred.flatten_conjs() {
1056            if let ExprKind::KVar(kvar) = p.kind()
1057                && self.has_label(kvar.kvid)
1058            {
1059                if self.label == Label::Bot {
1060                    return Expr::ff();
1061                } // else, skip pushing `p` into `preds`
1062            } else {
1063                preds.push(p.clone());
1064            }
1065        }
1066        Expr::and_from_iter(preds)
1067    }
1068}