1use std::{
2 cell::RefCell,
3 ops::ControlFlow,
4 rc::{Rc, Weak},
5};
6
7use flux_common::{index::IndexVec, iter::IterExt, tracked_span_bug};
8use flux_config::OverflowMode;
9use flux_macros::DebugAsJson;
10use flux_middle::{
11 global_env::GlobalEnv,
12 pretty::{PrettyCx, PrettyNested, format_cx},
13 queries::QueryResult,
14 rty::{
15 BaseTy, EVid, Expr, ExprKind, KVid, Name, Sort, Ty, TyKind, Var,
16 fold::{TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor},
17 },
18};
19use itertools::Itertools;
20use rustc_data_structures::snapshot_map::SnapshotMap;
21use rustc_hash::{FxHashMap, FxHashSet};
22use rustc_index::newtype_index;
23use rustc_middle::ty::TyCtxt;
24use serde::Serialize;
25
26use crate::{
27 evars::EVarStore,
28 fixpoint_encoding::{FixpointCtxt, fixpoint},
29 infer::{Tag, TypeTrace},
30};
31
32pub struct RefineTree {
50 root: NodePtr,
51}
52
53impl RefineTree {
54 pub(crate) fn new(params: Vec<(Var, Sort)>) -> RefineTree {
55 let root =
56 Node { kind: NodeKind::Root(params), nbindings: 0, parent: None, children: vec![] };
57 let root = NodePtr(Rc::new(RefCell::new(root)));
58 RefineTree { root }
59 }
60
61 pub(crate) fn simplify(&mut self, genv: GlobalEnv) {
62 self.root
63 .borrow_mut()
64 .simplify(SimplifyPhase::Full(genv), &mut SnapshotMap::default());
65 self.root.borrow_mut().simplify_bot();
66 self.root.borrow_mut().simplify_top();
67 }
68
69 pub(crate) fn into_fixpoint(
70 self,
71 cx: &mut FixpointCtxt<Tag>,
72 ) -> QueryResult<fixpoint::Constraint> {
73 Ok(self
74 .root
75 .borrow()
76 .to_fixpoint(cx)?
77 .unwrap_or(fixpoint::Constraint::TRUE))
78 }
79
80 pub(crate) fn cursor_at_root(&mut self) -> Cursor<'_> {
81 Cursor { ptr: NodePtr(Rc::clone(&self.root)), tree: self }
82 }
83
84 pub(crate) fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
85 self.root.borrow_mut().replace_evars(evars)
86 }
87}
88
89pub struct Cursor<'a> {
94 tree: &'a mut RefineTree,
95 ptr: NodePtr,
96}
97
98impl Cursor<'_> {
99 pub(crate) fn move_to(&mut self, marker: &Marker, clear_children: bool) -> Option<Cursor<'_>> {
105 let ptr = marker.ptr.upgrade()?;
106 if clear_children {
107 ptr.borrow_mut().children.clear();
108 }
109 Some(Cursor { ptr, tree: self.tree })
110 }
111
112 #[must_use]
114 pub(crate) fn marker(&self) -> Marker {
115 Marker { ptr: NodePtr::downgrade(&self.ptr) }
116 }
117
118 #[must_use]
119 pub(crate) fn branch(&mut self) -> Cursor<'_> {
120 Cursor { tree: self.tree, ptr: NodePtr::clone(&self.ptr) }
121 }
122
123 pub(crate) fn vars(&self) -> impl Iterator<Item = (Var, Sort)> {
124 self.ptr.scope().into_iter()
126 }
127
128 #[expect(dead_code, reason = "used for debugging")]
129 pub(crate) fn push_trace(&mut self, trace: TypeTrace) {
130 self.ptr = self.ptr.push_node(NodeKind::Trace(trace));
131 }
132
133 pub(crate) fn define_var(&mut self, sort: &Sort) -> Name {
136 let fresh = Name::from_usize(self.ptr.next_name_idx());
137 self.ptr = self.ptr.push_node(NodeKind::ForAll(fresh, sort.clone()));
138 fresh
139 }
140
141 pub(crate) fn assume_pred(&mut self, pred: impl Into<Expr>) {
145 let pred = pred.into();
146 if !pred.is_trivially_true() {
147 self.ptr = self.ptr.push_node(NodeKind::Assumption(pred));
148 }
149 }
150
151 pub(crate) fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
154 let pred = pred.into();
155 if !pred.is_trivially_true() {
156 self.ptr.push_node(NodeKind::Head(pred, tag));
157 }
158 }
159
160 pub(crate) fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
163 self.ptr
164 .push_node(NodeKind::Assumption(pred1.into()))
165 .push_node(NodeKind::Head(pred2.into(), tag));
166 }
167
168 pub(crate) fn assume_invariants(
169 &mut self,
170 tcx: TyCtxt,
171 ty: &Ty,
172 overflow_checking: OverflowMode,
173 ) {
174 struct Visitor<'a, 'b, 'tcx> {
175 tcx: TyCtxt<'tcx>,
176 cursor: &'a mut Cursor<'b>,
177 overflow_mode: OverflowMode,
178 }
179 impl TypeVisitor for Visitor<'_, '_, '_> {
180 fn visit_bty(&mut self, bty: &BaseTy) -> ControlFlow<!> {
181 match bty {
182 BaseTy::Adt(adt_def, substs) if adt_def.is_box() => substs.visit_with(self),
183 BaseTy::Ref(_, ty, _) => ty.visit_with(self),
184 BaseTy::Tuple(tys) => tys.visit_with(self),
185 _ => ControlFlow::Continue(()),
186 }
187 }
188
189 fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<!> {
190 if let TyKind::Indexed(bty, idx) = ty.kind()
191 && !idx.has_escaping_bvars()
192 {
193 for invariant in bty.invariants(self.tcx, self.overflow_mode) {
194 let invariant = invariant.apply(idx);
195 self.cursor.assume_pred(&invariant);
196 }
197 }
198 ty.super_visit_with(self)
199 }
200 }
201 let _ = ty.visit_with(&mut Visitor { tcx, cursor: self, overflow_mode: overflow_checking });
202 }
203}
204
205pub struct Marker {
212 ptr: WeakNodePtr,
213}
214
215impl Marker {
216 pub fn scope(&self) -> Option<Scope> {
220 Some(self.ptr.upgrade()?.scope())
221 }
222
223 pub fn has_free_vars<T: TypeVisitable>(&self, t: &T) -> bool {
224 let ptr = self
225 .ptr
226 .upgrade()
227 .unwrap_or_else(|| tracked_span_bug!("`has_free_vars` called on invalid `Marker`"));
228
229 let nbindings = ptr.next_name_idx();
230
231 !t.fvars().into_iter().all(|name| name.index() < nbindings)
232 }
233}
234
235#[derive(PartialEq, Eq)]
237pub struct Scope {
238 params: Vec<(Var, Sort)>,
239 bindings: IndexVec<Name, Sort>,
240}
241
242impl Scope {
243 pub(crate) fn iter(&self) -> impl Iterator<Item = (Var, Sort)> + '_ {
244 itertools::chain(
245 self.params.iter().cloned(),
246 self.bindings
247 .iter_enumerated()
248 .map(|(name, sort)| (Var::Free(name), sort.clone())),
249 )
250 }
251
252 fn into_iter(self) -> impl Iterator<Item = (Var, Sort)> {
253 itertools::chain(
254 self.params,
255 self.bindings
256 .into_iter_enumerated()
257 .map(|(name, sort)| (Var::Free(name), sort.clone())),
258 )
259 }
260
261 pub fn has_free_vars<T: TypeFoldable>(&self, t: &T) -> bool {
263 !self.contains_all(t.fvars())
264 }
265
266 fn contains_all(&self, iter: impl IntoIterator<Item = Name>) -> bool {
267 iter.into_iter().all(|name| self.contains(name))
268 }
269
270 fn contains(&self, name: Name) -> bool {
271 name.index() < self.bindings.len()
272 }
273}
274
/// A node of a [`RefineTree`].
struct Node {
    /// What the node binds, assumes or checks.
    kind: NodeKind,
    /// Number of `ForAll` nodes among this node's strict ancestors.
    ///
    /// For a `ForAll` node this is also the index of the name it binds
    /// (see `NodePtr::push_node` / `Cursor::define_var`).
    nbindings: usize,
    /// Weak link to the parent; `None` for the root.
    parent: Option<WeakNodePtr>,
    /// Owned links to the children, in insertion order.
    children: Vec<NodePtr>,
}
284
285#[derive(Clone)]
286struct NodePtr(Rc<RefCell<Node>>);
287
288impl NodePtr {
289 fn downgrade(this: &Self) -> WeakNodePtr {
290 WeakNodePtr(Rc::downgrade(&this.0))
291 }
292
293 fn push_node(&mut self, kind: NodeKind) -> NodePtr {
294 debug_assert!(!matches!(self.borrow().kind, NodeKind::Head(..)));
295 let node = Node {
296 kind,
297 nbindings: self.next_name_idx(),
298 parent: Some(NodePtr::downgrade(self)),
299 children: vec![],
300 };
301 let node = NodePtr(Rc::new(RefCell::new(node)));
302 self.borrow_mut().children.push(NodePtr::clone(&node));
303 node
304 }
305
306 fn next_name_idx(&self) -> usize {
307 self.borrow().nbindings + usize::from(self.borrow().is_forall())
308 }
309
310 fn scope(&self) -> Scope {
311 let mut params = None;
312 let parents = ParentsIter::new(self.clone());
313 let bindings = parents
314 .filter_map(|node| {
315 let node = node.borrow();
316 match &node.kind {
317 NodeKind::Root(p) => {
318 params = Some(p.clone());
319 None
320 }
321 NodeKind::ForAll(_, sort) => Some(sort.clone()),
322 _ => None,
323 }
324 })
325 .collect_vec()
326 .into_iter()
327 .rev()
328 .collect();
329 Scope { bindings, params: params.unwrap_or_default() }
330 }
331}
332
333struct WeakNodePtr(Weak<RefCell<Node>>);
334
335impl WeakNodePtr {
336 fn upgrade(&self) -> Option<NodePtr> {
337 Some(NodePtr(self.0.upgrade()?))
338 }
339}
340
/// The kinds of nodes that make up a [`RefineTree`].
enum NodeKind {
    /// The top-level node. It binds the tree's parameters; parameters of sort `loc` are
    /// skipped when lowering to fixpoint.
    Root(Vec<(Var, Sort)>),
    /// Universally quantifies a fresh name of the given sort over the node's children.
    ForAll(Name, Sort),
    /// A predicate assumed in all of the node's descendants.
    Assumption(Expr),
    /// A predicate to be proven. Heads are always leaves.
    /// The `Tag` is passed to `head_to_fixpoint` (presumably for error reporting — confirm).
    Head(Expr, Tag),
    /// A trivially satisfied leaf produced by simplification; it lowers to no constraint.
    True,
    /// A debugging node recording a type trace; it is transparent when lowering to fixpoint.
    Trace(TypeTrace),
}
351
352impl std::ops::Index<Name> for Scope {
353 type Output = Sort;
354
355 fn index(&self, name: Name) -> &Self::Output {
356 &self.bindings[name]
357 }
358}
359
360impl std::ops::Deref for NodePtr {
361 type Target = Rc<RefCell<Node>>;
362
363 fn deref(&self) -> &Self::Target {
364 &self.0
365 }
366}
367
/// Controls how much work `Node::simplify` does.
#[derive(Clone, Copy)]
enum SimplifyPhase<'genv, 'tcx> {
    /// Normalize and simplify every assumption and head, using the global environment and
    /// the assumptions in scope.
    Full(GlobalEnv<'genv, 'tcx>),
    /// Leave predicates as they are. Only turn trivially true heads into `True` and prune
    /// `True` / empty subtrees.
    Partial,
}
375
376impl Node {
377 fn simplify(&mut self, phase: SimplifyPhase, assumed_preds: &mut SnapshotMap<Expr, ()>) {
378 match &mut self.kind {
380 NodeKind::Head(pred, tag) => {
381 let pred = match phase {
382 SimplifyPhase::Full(genv) => pred.normalize(genv).simplify(assumed_preds),
383 SimplifyPhase::Partial => pred.clone(),
384 };
385 if pred.is_trivially_true() {
386 self.kind = NodeKind::True;
387 } else {
388 self.kind = NodeKind::Head(pred, *tag);
389 }
390 }
391 NodeKind::Assumption(pred) => {
392 if let SimplifyPhase::Full(genv) = phase {
393 *pred = pred.normalize(genv).simplify(assumed_preds);
394 }
395 pred.visit_conj(|conjunct| {
396 assumed_preds.insert(conjunct.erase_spans(), ());
397 });
398 }
399 _ => {}
400 }
401 let is_false_asm =
402 matches!(&self.kind, NodeKind::Assumption(pred) if pred.is_trivially_false());
403
404 for child in &self.children {
407 let current_version = assumed_preds.snapshot();
408 child.borrow_mut().simplify(phase, assumed_preds);
409 assumed_preds.rollback_to(current_version);
410 }
411
412 match &mut self.kind {
414 NodeKind::Head(..) | NodeKind::True => {}
415 NodeKind::Assumption(_)
416 | NodeKind::Trace(_)
417 | NodeKind::Root(_)
418 | NodeKind::ForAll(..) => {
419 self.children
420 .extract_if(.., |child| {
421 is_false_asm || matches!(&child.borrow().kind, NodeKind::True)
422 })
423 .for_each(drop);
424 }
425 }
426 if !self.is_leaf() && self.children.is_empty() {
427 self.kind = NodeKind::True;
428 }
429 }
430
431 fn is_leaf(&self) -> bool {
432 matches!(self.kind, NodeKind::Head(..) | NodeKind::True)
433 }
434
435 fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
436 for child in &self.children {
437 child.borrow_mut().replace_evars(evars)?;
438 }
439 match &mut self.kind {
440 NodeKind::Assumption(pred) => *pred = evars.replace_evars(pred)?,
441 NodeKind::Head(pred, _) => {
442 *pred = evars.replace_evars(pred)?;
443 }
444 NodeKind::Trace(trace) => {
445 evars.replace_evars(trace)?;
446 }
447 NodeKind::Root(_) | NodeKind::ForAll(..) | NodeKind::True => {}
448 }
449 Ok(())
450 }
451
452 fn to_fixpoint(&self, cx: &mut FixpointCtxt<Tag>) -> QueryResult<Option<fixpoint::Constraint>> {
453 let cstr = match &self.kind {
454 NodeKind::Trace(_) | NodeKind::ForAll(_, Sort::Loc) => {
455 children_to_fixpoint(cx, &self.children)?
456 }
457
458 NodeKind::Root(params) => {
459 let Some(children) = children_to_fixpoint(cx, &self.children)? else {
460 return Ok(None);
461 };
462 let mut constr = children;
463 for (var, sort) in params.iter().rev() {
464 if sort.is_loc() {
465 continue;
466 }
467 constr = fixpoint::Constraint::ForAll(
468 fixpoint::Bind {
469 name: cx.var_to_fixpoint(var),
470 sort: cx.sort_to_fixpoint(sort),
471 pred: fixpoint::Pred::TRUE,
472 },
473 Box::new(constr),
474 );
475 }
476 Some(constr)
477 }
478 NodeKind::ForAll(name, sort) => {
479 cx.with_name_map(*name, |cx, fresh| -> QueryResult<_> {
480 let Some(children) = children_to_fixpoint(cx, &self.children)? else {
481 return Ok(None);
482 };
483 Ok(Some(fixpoint::Constraint::ForAll(
484 fixpoint::Bind {
485 name: fixpoint::Var::Local(fresh),
486 sort: cx.sort_to_fixpoint(sort),
487 pred: fixpoint::Pred::TRUE,
488 },
489 Box::new(children),
490 )))
491 })?
492 }
493 NodeKind::Assumption(pred) => {
494 let (mut bindings, pred) = cx.assumption_to_fixpoint(pred)?;
495 let Some(cstr) = children_to_fixpoint(cx, &self.children)? else {
496 return Ok(None);
497 };
498 bindings.push(fixpoint::Bind {
499 name: fixpoint::Var::Underscore,
500 sort: fixpoint::Sort::Int,
501 pred,
502 });
503 Some(fixpoint::Constraint::foralls(bindings, cstr))
504 }
505 NodeKind::Head(pred, tag) => {
506 Some(cx.head_to_fixpoint(pred, |span| tag.with_dst(span))?)
507 }
508 NodeKind::True => None,
509 };
510 Ok(cstr)
511 }
512
513 fn is_forall(&self) -> bool {
517 matches!(self.kind, NodeKind::ForAll(..))
518 }
519
520 fn is_head(&self) -> bool {
524 matches!(self.kind, NodeKind::Head(..))
525 }
526}
527
528fn children_to_fixpoint(
529 cx: &mut FixpointCtxt<Tag>,
530 children: &[NodePtr],
531) -> QueryResult<Option<fixpoint::Constraint>> {
532 let mut children = children
533 .iter()
534 .filter_map(|node| node.borrow().to_fixpoint(cx).transpose())
535 .try_collect_vec()?;
536 let cstr = match children.len() {
537 0 => None,
538 1 => children.pop(),
539 _ => Some(fixpoint::Constraint::Conj(children)),
540 };
541 Ok(cstr)
542}
543
544struct ParentsIter {
545 ptr: Option<NodePtr>,
546}
547
548impl ParentsIter {
549 fn new(ptr: NodePtr) -> Self {
550 Self { ptr: Some(ptr) }
551 }
552}
553
554impl Iterator for ParentsIter {
555 type Item = NodePtr;
556
557 fn next(&mut self) -> Option<Self::Item> {
558 if let Some(ptr) = self.ptr.take() {
559 self.ptr = ptr.borrow().parent.as_ref().and_then(WeakNodePtr::upgrade);
560 Some(ptr)
561 } else {
562 None
563 }
564 }
565}
566
mod pretty {
    //! Pretty-printing of refinement trees, cursors and scopes through flux's `Pretty` trait.
    use std::fmt::{self, Write};

    use flux_middle::pretty::*;
    use pad_adapter::PadAdapter;

    use super::*;

    /// Starting at `ptr`, collects the bindings of consecutive `ForAll` nodes.
    ///
    /// The chain is followed only while each `ForAll` has exactly one child. Returns the
    /// collected `(name, sort)` pairs together with the nodes where the chain stopped:
    /// - the first non-`ForAll` node reached (`[ptr]` itself if `ptr` is not a `ForAll`), or
    /// - the children of the last `ForAll`, when it has zero or several children.
    fn bindings_chain(ptr: &NodePtr) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut bindings: Vec<(Name, Sort)>) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::ForAll(name, sort) = &node.kind {
                bindings.push((*name, sort.clone()));
                if let [child] = &node.children[..] {
                    go(child, bindings)
                } else {
                    (bindings, node.children.clone())
                }
            } else {
                (bindings, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    /// Same as `bindings_chain`, but for runs of consecutive `Assumption` nodes.
    /// Returns their predicates.
    fn preds_chain(ptr: &NodePtr) -> (Vec<Expr>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut preds: Vec<Expr>) -> (Vec<Expr>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::Assumption(pred) = &node.kind {
                preds.push(pred.clone());
                if let [child] = &node.children[..] {
                    go(child, preds)
                } else {
                    (preds, node.children.clone())
                }
            } else {
                (preds, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    impl Pretty for RefineTree {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?}", &self.root)
        }
    }

    impl Pretty for NodePtr {
        /// Prints the subtree rooted at this node, one node per line, indenting children.
        ///
        /// The `PrettyCx` flags `bindings_chain` and `preds_chain` collapse runs of single-child
        /// binders/assumptions into one line. `simplify_exprs` simplifies heads before printing,
        /// and `tags` appends each head's tag.
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let node = self.borrow();
            match &node.kind {
                NodeKind::Trace(trace) => {
                    w!(cx, f, "@ {:?}", ^trace)?;
                    w!(cx, with_padding(f), "\n{:?}", join!("\n", &node.children))
                }
                NodeKind::Root(bindings) => {
                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&node.children, cx, f)
                }
                NodeKind::ForAll(name, sort) => {
                    let (bindings, children) = if cx.bindings_chain {
                        bindings_chain(self)
                    } else {
                        (vec![(*name, sort.clone())], node.children.clone())
                    };

                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .into_iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Assumption(pred) => {
                    let (preds, children) = if cx.preds_chain {
                        preds_chain(self)
                    } else {
                        (vec![pred.clone()], node.children.clone())
                    };
                    // Print the collected assumptions as a single simplified conjunction.
                    let guard = Expr::and_from_iter(preds).simplify(&SnapshotMap::default());
                    w!(cx, f, "{:?} =>", parens!(guard, !guard.is_atom()))?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Head(pred, tag) => {
                    let pred = if cx.simplify_exprs {
                        pred.simplify(&SnapshotMap::default())
                    } else {
                        pred.clone()
                    };
                    w!(cx, f, "{:?}", parens!(pred, !pred.is_atom()))?;
                    if cx.tags {
                        w!(cx, f, " ~ {:?}", tag)?;
                    }
                    Ok(())
                }
                NodeKind::True => {
                    w!(cx, f, "true")
                }
            }
        }
    }

    /// Prints the children of a binder or assumption.
    ///
    /// - No children prints ` true`.
    /// - A single head child stays on the same line.
    /// - Anything else goes on new, indented lines.
    fn fmt_children(
        children: &[NodePtr],
        cx: &PrettyCx,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        match children {
            [] => w!(cx, f, " true"),
            [n] => {
                if n.borrow().is_head() {
                    w!(cx, f, " {:?}", n)
                } else {
                    w!(cx, with_padding(f), "\n{:?}", n)
                }
            }
            _ => w!(cx, with_padding(f), "\n{:?}", join!("\n", children)),
        }
    }

    impl Pretty for Cursor<'_> {
        /// Prints the context at the cursor as `{b1, ..., p1, ...}`, outermost first.
        ///
        /// Elements are root parameters, bound names and assumptions.
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // Collected bottom-up, then reversed when printing.
            let mut elements = vec![];
            for node in ParentsIter::new(NodePtr::clone(&self.ptr)) {
                let n = node.borrow();
                match &n.kind {
                    NodeKind::Root(bindings) => {
                        // Reversed here so the final reversal restores declaration order.
                        for (name, sort) in bindings.iter().rev() {
                            elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                        }
                    }
                    NodeKind::ForAll(name, sort) => {
                        elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                    }
                    NodeKind::Assumption(pred) => {
                        elements.push(format_cx!(cx, "{:?}", pred));
                    }
                    _ => {}
                }
            }
            write!(f, "{{{}}}", elements.into_iter().rev().format(", "))
        }
    }

    impl Pretty for Scope {
        /// Prints the bound names and then the root parameters.
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(
                f,
                "[bindings = {}, reftgenerics = {}]",
                self.bindings
                    .iter_enumerated()
                    .format_with(", ", |(name, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                    }),
                self.params
                    .iter()
                    .format_with(", ", |(param_const, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^param_const, sort))
                    }),
            )
        }
    }

    /// Wraps `f` so that everything written through it is indented by the given padding string.
    fn with_padding<'a, 'b>(f: &'a mut fmt::Formatter<'b>) -> PadAdapter<'a, 'b, 'static> {
        PadAdapter::with_padding(f, " ")
    }

    impl_debug_with_default_cx!(
        RefineTree => "refine_tree",
        Cursor<'_> => "cursor",
        Scope,
    );
}
752
753#[derive(Serialize, DebugAsJson)]
755pub struct RefineCtxtTrace {
756 bindings: Vec<RcxBind>,
757 exprs: Vec<String>,
758}
759
760#[derive(Serialize)]
761struct RcxBind {
762 name: String,
763 sort: String,
764}
765
766impl RefineCtxtTrace {
767 pub fn new(genv: GlobalEnv, cursor: &Cursor) -> Self {
768 let parents = ParentsIter::new(NodePtr::clone(&cursor.ptr)).collect_vec();
769 let mut bindings = vec![];
770 let mut exprs = vec![];
771 let cx = &PrettyCx::default(genv);
772
773 parents.into_iter().rev().for_each(|ptr| {
774 let node = ptr.borrow();
775 match &node.kind {
776 NodeKind::ForAll(name, sort) => {
777 let bind = RcxBind {
778 name: format_cx!(cx, "{:?}", ^name),
779 sort: format_cx!(cx, "{:?}", sort),
780 };
781 bindings.push(bind);
782 }
783 NodeKind::Assumption(e)
784 if !e.simplify(&SnapshotMap::default()).is_trivially_true() =>
785 {
786 e.visit_conj(|e| {
787 exprs.push(e.nested_string(cx));
788 });
789 }
790 NodeKind::Root(binds) => {
791 for (name, sort) in binds {
792 let bind = RcxBind {
793 name: format_cx!(cx, "{:?}", name),
794 sort: format_cx!(cx, "{:?}", sort),
795 };
796 bindings.push(bind);
797 }
798 }
799 _ => (),
800 }
801 });
802 Self { bindings, exprs }
803 }
804}
805
806impl Node {
807 fn simplify_bot(&mut self) {
809 let graph = ConstraintDeps::new(self);
810 let bots = graph.bot_kvars();
811 self.simplify_with_assignment(&bots);
812 self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
813 }
814
815 fn simplify_top(&mut self) {
817 let graph = ConstraintDeps::new(self);
818 let tops = graph.top_kvars();
819 self.simplify_with_assignment(&tops);
820 self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
821 }
822
823 fn simplify_with_assignment(&mut self, assignment: &Assignment) {
826 match &mut self.kind {
827 NodeKind::Head(pred, tag) => {
828 let pred = assignment.simplify(pred);
829 self.kind = NodeKind::Head(pred, *tag);
830 }
831 NodeKind::Assumption(pred) => {
832 let pred = assignment.simplify(pred);
833 self.kind = NodeKind::Assumption(pred);
834 }
835 _ => {}
836 }
837 for child in &self.children {
838 child.borrow_mut().simplify_with_assignment(assignment);
839 }
840 }
841}
842
843#[derive(Debug)]
844struct ConstraintDeps {
845 assumptions: IndexVec<ClauseId, FxHashSet<KVid>>,
847 heads: IndexVec<ClauseId, Head>,
849}
850
851impl ConstraintDeps {
852 fn new(node: &Node) -> Self {
853 let mut graph = Self { assumptions: IndexVec::default(), heads: IndexVec::default() };
854 graph.build(node, &mut vec![]);
855 graph
856 }
857
858 fn insert_clause(&mut self, assumptions: &[KVid], head: Head) {
859 self.assumptions.push(assumptions.iter().copied().collect());
860 self.heads.push(head);
861 }
862
863 fn build(&mut self, node: &Node, assumptions: &mut Vec<KVid>) {
864 let n = assumptions.len();
865 match &node.kind {
866 NodeKind::Head(expr, _) => {
867 expr.visit_conj(|e| {
868 if let ExprKind::KVar(kvar) = e.kind() {
869 self.insert_clause(assumptions, Head::KVar(kvar.kvid));
870 } else {
871 self.insert_clause(assumptions, Head::Conc);
872 }
873 });
874 }
875 NodeKind::Assumption(expr) => {
876 expr.visit_conj(|e| {
877 if let ExprKind::KVar(kvar) = e.kind() {
878 assumptions.push(kvar.kvid);
879 }
880 });
881 }
882 _ => {}
883 };
884
885 for child in &node.children {
886 self.build(&child.borrow(), assumptions);
887 }
888
889 assumptions.truncate(n); }
891
892 fn kv_lhs(&self) -> FxHashMap<KVid, Vec<ClauseId>> {
894 let mut res: FxHashMap<KVid, Vec<ClauseId>> = FxHashMap::default();
895 for (clause_id, kvids) in self.assumptions.iter_enumerated() {
896 for kvid in kvids {
897 res.entry(*kvid).or_default().push(clause_id);
898 }
899 }
900 res
901 }
902
903 fn kv_rhs(&self) -> FxHashMap<KVid, Vec<ClauseId>> {
905 let mut res: FxHashMap<KVid, Vec<ClauseId>> = FxHashMap::default();
906 for (clause_id, head) in self.heads.iter_enumerated() {
907 if let Head::KVar(kvid) = head {
908 res.entry(*kvid).or_default().push(clause_id);
909 }
910 }
911 res
912 }
913
914 fn bot_kvars(self) -> Assignment {
917 let mut assignment = Assignment::new(Label::Bot);
919
920 let kv_lhs = self.kv_lhs();
921
922 let mut bot_assms: IndexVec<ClauseId, FxHashSet<KVid>> = self.assumptions;
924
925 let mut candidates: Vec<ClauseId> = bot_assms
927 .iter_enumerated()
928 .filter_map(|(cid, lhs)| if lhs.is_empty() { Some(cid) } else { None })
929 .collect();
930
931 while let Some(cid) = candidates.pop() {
933 if let Head::KVar(kvid) = self.heads[cid] {
934 assignment.remove(kvid);
936 for cid in kv_lhs.get(&kvid).unwrap_or(&vec![]) {
938 if let Head::KVar(rhs_kvid) = self.heads[*cid] {
940 let assms = &mut bot_assms[*cid];
941 assms.remove(&kvid);
942 if assignment.has_label(rhs_kvid) && assms.is_empty() {
943 candidates.push(*cid);
944 }
945 };
946 }
947 }
948 }
949
950 assignment
951 }
952
953 fn top_kvars(self) -> Assignment {
956 let mut assignment = Assignment::new(Label::Top);
958
959 let kv_rhs = self.kv_rhs();
960
961 let mut candidates = vec![];
963 for (cid, head) in self.heads.iter_enumerated() {
964 if matches!(head, Head::Conc) {
965 for kvid in &self.assumptions[cid] {
966 candidates.push(*kvid);
967 }
968 }
969 }
970
971 while let Some(kvid) = candidates.pop() {
973 assignment.remove(kvid);
975
976 for cid in kv_rhs.get(&kvid).unwrap_or(&vec![]) {
978 for lhs_kvid in &self.assumptions[*cid] {
980 if assignment.has_label(*lhs_kvid) {
981 candidates.push(*lhs_kvid);
982 }
983 }
984 }
985 }
986
987 assignment
988 }
989}
990
newtype_index! {
    /// Index of a clause in `ConstraintDeps`.
    struct ClauseId {}
}
994
/// The head of a clause in `ConstraintDeps`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Head {
    /// The head conjunct is a kvar.
    KVar(KVid),
    /// The head conjunct is any other (concrete) predicate.
    Conc,
}

/// The value an `Assignment` gives to the kvars that carry its label.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Label {
    /// Labelled kvars are replaced by `false`. A conjunction containing one becomes `false`.
    Bot,
    /// Labelled kvars are replaced by `true`, i.e. dropped from conjunctions.
    Top,
}
1011
1012struct Assignment {
1013 vars: FxHashSet<KVid>,
1016 label: Label,
1017}
1018
1019impl Assignment {
1020 fn new(label: Label) -> Self {
1021 let vars = FxHashSet::default();
1022 Self { vars, label }
1023 }
1024
1025 fn has_label(&self, kvid: KVid) -> bool {
1026 !self.vars.contains(&kvid)
1027 }
1028
1029 fn remove(&mut self, kvid: KVid) {
1030 self.vars.insert(kvid);
1031 }
1032
1033 fn simplify(&self, pred: &Expr) -> Expr {
1036 let mut preds = vec![];
1037 for p in pred.flatten_conjs() {
1038 if let ExprKind::KVar(kvar) = p.kind()
1039 && self.has_label(kvar.kvid)
1040 {
1041 if self.label == Label::Bot {
1042 return Expr::ff();
1043 } } else {
1045 preds.push(p.clone());
1046 }
1047 }
1048 Expr::and_from_iter(preds)
1049 }
1050}