1use std::{
2 cell::RefCell,
3 ops::ControlFlow,
4 rc::{Rc, Weak},
5};
6
7use flux_common::{index::IndexVec, iter::IterExt, tracked_span_bug};
8use flux_config::OverflowMode;
9use flux_macros::DebugAsJson;
10use flux_middle::{
11 global_env::GlobalEnv,
12 pretty::{PrettyCx, PrettyNested, format_cx},
13 queries::QueryResult,
14 rty::{
15 BaseTy, EVid, Expr, ExprKind, KVid, Name, NameProvenance, PrettyVar, Sort, Ty, TyKind, Var,
16 fold::{TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor},
17 },
18};
19use itertools::Itertools;
20use rustc_data_structures::snapshot_map::SnapshotMap;
21use rustc_hash::{FxHashMap, FxHashSet};
22use rustc_index::newtype_index;
23use rustc_middle::ty::TyCtxt;
24use serde::Serialize;
25
26use crate::{
27 evars::EVarStore,
28 fixpoint_encoding::{FixpointCtxt, fixpoint},
29 infer::{Tag, TypeTrace},
30};
31
/// A tree of refinement nodes, held through a shared pointer to its root [`Node`].
pub struct RefineTree {
    /// The root of the tree; [`RefineTree::new`] creates it with kind [`NodeKind::Root`].
    root: NodePtr,
}
52
53impl RefineTree {
54 pub(crate) fn new(params: Vec<(Var, Sort)>) -> RefineTree {
55 let root =
56 Node { kind: NodeKind::Root(params), nbindings: 0, parent: None, children: vec![] };
57 let root = NodePtr(Rc::new(RefCell::new(root)));
58 RefineTree { root }
59 }
60
61 pub(crate) fn simplify(&mut self, genv: GlobalEnv) {
62 self.root
63 .borrow_mut()
64 .simplify(SimplifyPhase::Full(genv), &mut SnapshotMap::default());
65 self.root.borrow_mut().simplify_bot();
66 self.root.borrow_mut().simplify_top();
67 }
68
69 pub(crate) fn into_fixpoint(
70 self,
71 cx: &mut FixpointCtxt<Tag>,
72 ) -> QueryResult<fixpoint::Constraint> {
73 Ok(self
74 .root
75 .borrow()
76 .to_fixpoint(cx)?
77 .unwrap_or(fixpoint::Constraint::TRUE))
78 }
79
80 pub(crate) fn cursor_at_root(&mut self) -> Cursor<'_> {
81 Cursor { ptr: NodePtr(Rc::clone(&self.root)), tree: self }
82 }
83
84 pub(crate) fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
85 self.root.borrow_mut().replace_evars(evars)
86 }
87}
88
/// A position inside a mutably borrowed [`RefineTree`]; new nodes are pushed below it.
pub struct Cursor<'a> {
    /// The tree the cursor belongs to, borrowed mutably for the lifetime of the cursor.
    tree: &'a mut RefineTree,
    /// The node the cursor currently points to.
    ptr: NodePtr,
}
97
98impl Cursor<'_> {
99 pub(crate) fn move_to(&mut self, marker: &Marker, clear_children: bool) -> Option<Cursor<'_>> {
105 let ptr = marker.ptr.upgrade()?;
106 if clear_children {
107 ptr.borrow_mut().children.clear();
108 }
109 Some(Cursor { ptr, tree: self.tree })
110 }
111
112 #[must_use]
114 pub(crate) fn marker(&self) -> Marker {
115 Marker { ptr: NodePtr::downgrade(&self.ptr) }
116 }
117
118 #[must_use]
119 pub(crate) fn branch(&mut self) -> Cursor<'_> {
120 Cursor { tree: self.tree, ptr: NodePtr::clone(&self.ptr) }
121 }
122
123 pub(crate) fn vars(&self) -> impl Iterator<Item = (Var, Sort)> {
124 self.ptr.scope().into_iter()
126 }
127
128 #[expect(dead_code, reason = "used for debugging")]
129 pub(crate) fn push_trace(&mut self, trace: TypeTrace) {
130 self.ptr = self.ptr.push_node(NodeKind::Trace(trace));
131 }
132
133 pub(crate) fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
136 let fresh = Name::from_usize(self.ptr.next_name_idx());
137 self.ptr = self
138 .ptr
139 .push_node(NodeKind::ForAll(fresh, sort.clone(), provenance));
140 fresh
141 }
142
143 pub(crate) fn assume_pred(&mut self, pred: impl Into<Expr>) {
147 let pred = pred.into();
148 if !pred.is_trivially_true() {
149 self.ptr = self.ptr.push_node(NodeKind::Assumption(pred));
150 }
151 }
152
153 pub(crate) fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
156 let pred = pred.into();
157 if !pred.is_trivially_true() {
158 self.ptr.push_node(NodeKind::Head(pred, tag));
159 }
160 }
161
162 pub(crate) fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
165 self.ptr
166 .push_node(NodeKind::Assumption(pred1.into()))
167 .push_node(NodeKind::Head(pred2.into(), tag));
168 }
169
170 pub(crate) fn assume_invariants(
171 &mut self,
172 tcx: TyCtxt,
173 ty: &Ty,
174 overflow_checking: OverflowMode,
175 ) {
176 struct Visitor<'a, 'b, 'tcx> {
177 tcx: TyCtxt<'tcx>,
178 cursor: &'a mut Cursor<'b>,
179 overflow_mode: OverflowMode,
180 }
181 impl TypeVisitor for Visitor<'_, '_, '_> {
182 fn visit_bty(&mut self, bty: &BaseTy) -> ControlFlow<!> {
183 match bty {
184 BaseTy::Adt(adt_def, substs) if adt_def.is_box() => substs.visit_with(self),
185 BaseTy::Ref(_, ty, _) => ty.visit_with(self),
186 BaseTy::Tuple(tys) => tys.visit_with(self),
187 _ => ControlFlow::Continue(()),
188 }
189 }
190
191 fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<!> {
192 if let TyKind::Indexed(bty, idx) = ty.kind()
193 && !idx.has_escaping_bvars()
194 {
195 for invariant in bty.invariants(self.tcx, self.overflow_mode) {
196 let invariant = invariant.apply(idx);
197 self.cursor.assume_pred(&invariant);
198 }
199 }
200 ty.super_visit_with(self)
201 }
202 }
203 let _ = ty.visit_with(&mut Visitor { tcx, cursor: self, overflow_mode: overflow_checking });
204 }
205}
206
/// A weak reference to a node of a [`RefineTree`], created by [`Cursor::marker`]. The node
/// may no longer exist when the marker is used.
pub struct Marker {
    /// Weak pointer to the marked node.
    ptr: WeakNodePtr,
}
216
217impl Marker {
218 pub fn scope(&self) -> Option<Scope> {
222 Some(self.ptr.upgrade()?.scope())
223 }
224
225 pub fn has_free_vars<T: TypeVisitable>(&self, t: &T) -> bool {
226 let ptr = self
227 .ptr
228 .upgrade()
229 .unwrap_or_else(|| tracked_span_bug!("`has_free_vars` called on invalid `Marker`"));
230
231 let nbindings = ptr.next_name_idx();
232
233 !t.fvars().into_iter().all(|name| name.index() < nbindings)
234 }
235}
236
/// The variables available at a node: the parameters recorded at the root plus one sort per
/// `ForAll` node on the path from the root to that node.
#[derive(PartialEq, Eq)]
pub struct Scope {
    /// Parameters copied from the root node (empty if no root was found on the path).
    params: Vec<(Var, Sort)>,
    /// Sorts of the names bound along the path, indexed by name, outermost binder first.
    bindings: IndexVec<Name, Sort>,
}
243
244impl Scope {
245 pub(crate) fn iter(&self) -> impl Iterator<Item = (Var, Sort)> + '_ {
246 itertools::chain(
247 self.params.iter().cloned(),
248 self.bindings
249 .iter_enumerated()
250 .map(|(name, sort)| (Var::Free(name), sort.clone())),
251 )
252 }
253
254 fn into_iter(self) -> impl Iterator<Item = (Var, Sort)> {
255 itertools::chain(
256 self.params,
257 self.bindings
258 .into_iter_enumerated()
259 .map(|(name, sort)| (Var::Free(name), sort.clone())),
260 )
261 }
262
263 pub fn has_free_vars<T: TypeFoldable>(&self, t: &T) -> bool {
265 !self.contains_all(t.fvars())
266 }
267
268 fn contains_all(&self, iter: impl IntoIterator<Item = Name>) -> bool {
269 iter.into_iter().all(|name| self.contains(name))
270 }
271
272 fn contains(&self, name: Name) -> bool {
273 name.index() < self.bindings.len()
274 }
275}
276
/// A node of the refinement tree.
struct Node {
    /// What this node represents.
    kind: NodeKind,
    /// Number of `ForAll` nodes strictly above this node: it is set on creation to the
    /// parent's `nbindings`, plus one when the parent is itself a `ForAll`. The root has 0.
    nbindings: usize,
    /// Weak pointer to the parent; `None` for the root.
    parent: Option<WeakNodePtr>,
    /// Strong pointers to the children, in insertion order.
    children: Vec<NodePtr>,
}
286
/// Shared, interior-mutable strong pointer to a [`Node`]. Cloning copies the pointer, not the
/// node.
#[derive(Clone)]
struct NodePtr(Rc<RefCell<Node>>);
289
290impl NodePtr {
291 fn downgrade(this: &Self) -> WeakNodePtr {
292 WeakNodePtr(Rc::downgrade(&this.0))
293 }
294
295 fn push_node(&mut self, kind: NodeKind) -> NodePtr {
296 debug_assert!(!matches!(self.borrow().kind, NodeKind::Head(..)));
297 let node = Node {
298 kind,
299 nbindings: self.next_name_idx(),
300 parent: Some(NodePtr::downgrade(self)),
301 children: vec![],
302 };
303 let node = NodePtr(Rc::new(RefCell::new(node)));
304 self.borrow_mut().children.push(NodePtr::clone(&node));
305 node
306 }
307
308 fn next_name_idx(&self) -> usize {
309 self.borrow().nbindings + usize::from(self.borrow().is_forall())
310 }
311
312 fn scope(&self) -> Scope {
313 let mut params = None;
314 let parents = ParentsIter::new(self.clone());
315 let bindings = parents
316 .filter_map(|node| {
317 let node = node.borrow();
318 match &node.kind {
319 NodeKind::Root(p) => {
320 params = Some(p.clone());
321 None
322 }
323 NodeKind::ForAll(_, sort, _) => Some(sort.clone()),
324 _ => None,
325 }
326 })
327 .collect_vec()
328 .into_iter()
329 .rev()
330 .collect();
331 Scope { bindings, params: params.unwrap_or_default() }
332 }
333}
334
/// Weak counterpart of [`NodePtr`]; it does not keep the node alive.
struct WeakNodePtr(Weak<RefCell<Node>>);
336
337impl WeakNodePtr {
338 fn upgrade(&self) -> Option<NodePtr> {
339 Some(NodePtr(self.0.upgrade()?))
340 }
341}
342
/// The different kinds of node in the refinement tree.
enum NodeKind {
    /// The root of the tree, carrying the parameters it was created with.
    Root(Vec<(Var, Sort)>),
    /// Binds a name of the given sort for the subtree below.
    ForAll(Name, Sort, NameProvenance),
    /// A predicate assumed for the subtree below.
    Assumption(Expr),
    /// A predicate to be checked; treated as a leaf.
    Head(Expr, Tag),
    /// A leaf that produces no constraint; simplification replaces nodes with it.
    True,
    /// A trace node; it is transparent when translating to fixpoint.
    Trace(TypeTrace),
}
353
354impl std::ops::Index<Name> for Scope {
355 type Output = Sort;
356
357 fn index(&self, name: Name) -> &Self::Output {
358 &self.bindings[name]
359 }
360}
361
362impl std::ops::Deref for NodePtr {
363 type Target = Rc<RefCell<Node>>;
364
365 fn deref(&self) -> &Self::Target {
366 &self.0
367 }
368}
369
/// Selects how much work [`Node::simplify`] does on predicates.
#[derive(Clone, Copy)]
enum SimplifyPhase<'genv, 'tcx> {
    /// Normalize predicates with the global environment and simplify them.
    Full(GlobalEnv<'genv, 'tcx>),
    /// Leave predicates as they are; only the tree structure is pruned.
    Partial,
}
377
378impl Node {
379 fn simplify(&mut self, phase: SimplifyPhase, assumed_preds: &mut SnapshotMap<Expr, ()>) {
380 match &mut self.kind {
382 NodeKind::Head(pred, tag) => {
383 let pred = match phase {
384 SimplifyPhase::Full(genv) => pred.normalize(genv).simplify(assumed_preds),
385 SimplifyPhase::Partial => pred.clone(),
386 };
387 if pred.is_trivially_true() {
388 self.kind = NodeKind::True;
389 } else {
390 self.kind = NodeKind::Head(pred, *tag);
391 }
392 }
393 NodeKind::Assumption(pred) => {
394 if let SimplifyPhase::Full(genv) = phase {
395 *pred = pred.normalize(genv).simplify(assumed_preds);
396 }
397 pred.visit_conj(|conjunct| {
398 assumed_preds.insert(conjunct.erase_spans(), ());
399 });
400 }
401 _ => {}
402 }
403 let is_false_asm =
404 matches!(&self.kind, NodeKind::Assumption(pred) if pred.is_trivially_false());
405
406 for child in &self.children {
409 let current_version = assumed_preds.snapshot();
410 child.borrow_mut().simplify(phase, assumed_preds);
411 assumed_preds.rollback_to(current_version);
412 }
413
414 match &mut self.kind {
416 NodeKind::Head(..) | NodeKind::True => {}
417 NodeKind::Assumption(_)
418 | NodeKind::Trace(_)
419 | NodeKind::Root(_)
420 | NodeKind::ForAll(..) => {
421 self.children
422 .extract_if(.., |child| {
423 is_false_asm || matches!(&child.borrow().kind, NodeKind::True)
424 })
425 .for_each(drop);
426 }
427 }
428 if !self.is_leaf() && self.children.is_empty() {
429 self.kind = NodeKind::True;
430 }
431 }
432
433 fn is_leaf(&self) -> bool {
434 matches!(self.kind, NodeKind::Head(..) | NodeKind::True)
435 }
436
437 fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
438 for child in &self.children {
439 child.borrow_mut().replace_evars(evars)?;
440 }
441 match &mut self.kind {
442 NodeKind::Assumption(pred) => *pred = evars.replace_evars(pred)?,
443 NodeKind::Head(pred, _) => {
444 *pred = evars.replace_evars(pred)?;
445 }
446 NodeKind::Trace(trace) => {
447 evars.replace_evars(trace)?;
448 }
449 NodeKind::Root(_) | NodeKind::ForAll(..) | NodeKind::True => {}
450 }
451 Ok(())
452 }
453
454 fn to_fixpoint(&self, cx: &mut FixpointCtxt<Tag>) -> QueryResult<Option<fixpoint::Constraint>> {
455 let cstr = match &self.kind {
456 NodeKind::Trace(_) | NodeKind::ForAll(_, Sort::Loc, _) => {
457 children_to_fixpoint(cx, &self.children)?
458 }
459
460 NodeKind::Root(params) => {
461 for (var, sort) in params {
463 if let Var::EarlyParam(param) = var
464 && !sort.is_loc()
465 {
466 cx.with_early_param(param);
467 }
468 }
469
470 let Some(children) = children_to_fixpoint(cx, &self.children)? else {
471 return Ok(None);
472 };
473 let mut constr = children;
474 for (var, sort) in params.iter().rev() {
475 if sort.is_loc() {
476 continue;
477 }
478 constr = fixpoint::Constraint::ForAll(
479 fixpoint::Bind {
480 name: cx.var_to_fixpoint(var),
481 sort: cx.sort_to_fixpoint(sort),
482 pred: fixpoint::Pred::TRUE,
483 },
484 Box::new(constr),
485 );
486 }
487 Some(constr)
488 }
489 NodeKind::ForAll(name, sort, provenance) => {
490 cx.with_name_map(*name, *provenance, |cx, fresh| -> QueryResult<_> {
491 let Some(children) = children_to_fixpoint(cx, &self.children)? else {
492 return Ok(None);
493 };
494 Ok(Some(fixpoint::Constraint::ForAll(
495 fixpoint::Bind {
496 name: fixpoint::Var::Local(fresh),
497 sort: cx.sort_to_fixpoint(sort),
498 pred: fixpoint::Pred::TRUE,
499 },
500 Box::new(children),
501 )))
502 })?
503 }
504 NodeKind::Assumption(pred) => {
505 let (mut bindings, pred) = cx.assumption_to_fixpoint(pred)?;
506 let Some(cstr) = children_to_fixpoint(cx, &self.children)? else {
507 return Ok(None);
508 };
509 bindings.push(fixpoint::Bind {
510 name: fixpoint::Var::Underscore,
511 sort: fixpoint::Sort::Int,
512 pred,
513 });
514 Some(fixpoint::Constraint::foralls(bindings, cstr))
515 }
516 NodeKind::Head(pred, tag) => {
517 Some(cx.head_to_fixpoint(pred, |span| tag.with_dst(span))?)
518 }
519 NodeKind::True => None,
520 };
521 Ok(cstr)
522 }
523
524 fn is_forall(&self) -> bool {
528 matches!(self.kind, NodeKind::ForAll(..))
529 }
530
531 fn is_head(&self) -> bool {
535 matches!(self.kind, NodeKind::Head(..))
536 }
537}
538
539fn children_to_fixpoint(
540 cx: &mut FixpointCtxt<Tag>,
541 children: &[NodePtr],
542) -> QueryResult<Option<fixpoint::Constraint>> {
543 let mut children = children
544 .iter()
545 .filter_map(|node| node.borrow().to_fixpoint(cx).transpose())
546 .try_collect_vec()?;
547 let cstr = match children.len() {
548 0 => None,
549 1 => children.pop(),
550 _ => Some(fixpoint::Constraint::Conj(children)),
551 };
552 Ok(cstr)
553}
554
/// Iterator that yields a node and then each of its ancestors, following parent pointers
/// until a node without a (live) parent is reached.
struct ParentsIter {
    /// The next node to yield; `None` once the iteration is over.
    ptr: Option<NodePtr>,
}
558
559impl ParentsIter {
560 fn new(ptr: NodePtr) -> Self {
561 Self { ptr: Some(ptr) }
562 }
563}
564
565impl Iterator for ParentsIter {
566 type Item = NodePtr;
567
568 fn next(&mut self) -> Option<Self::Item> {
569 if let Some(ptr) = self.ptr.take() {
570 self.ptr = ptr.borrow().parent.as_ref().and_then(WeakNodePtr::upgrade);
571 Some(ptr)
572 } else {
573 None
574 }
575 }
576}
577
/// Pretty-printing of the refinement tree, cursors and scopes.
mod pretty {
    use std::fmt::{self, Write};

    use flux_middle::pretty::*;
    use pad_adapter::PadAdapter;

    use super::*;

    /// Starting at `ptr`, collects the bindings of consecutive `ForAll` nodes that each have
    /// exactly one child. Returns the bindings and the nodes at which the chain stops: the
    /// children of the last `ForAll`, or the first non-`ForAll` node itself.
    fn bindings_chain(ptr: &NodePtr) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut bindings: Vec<(Name, Sort)>) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::ForAll(name, sort, _) = &node.kind {
                bindings.push((*name, sort.clone()));
                if let [child] = &node.children[..] {
                    go(child, bindings)
                } else {
                    (bindings, node.children.clone())
                }
            } else {
                (bindings, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    /// Same as `bindings_chain`, but for consecutive `Assumption` nodes: collects their
    /// predicates and returns the nodes at which the chain stops.
    fn preds_chain(ptr: &NodePtr) -> (Vec<Expr>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut preds: Vec<Expr>) -> (Vec<Expr>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::Assumption(pred) = &node.kind {
                preds.push(pred.clone());
                if let [child] = &node.children[..] {
                    go(child, preds)
                } else {
                    (preds, node.children.clone())
                }
            } else {
                (preds, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    impl Pretty for RefineTree {
        /// Prints the tree by printing its root node.
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?}", &self.root)
        }
    }

    impl Pretty for NodePtr {
        /// Prints the node and, recursively, its children.
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let node = self.borrow();
            match &node.kind {
                NodeKind::Trace(trace) => {
                    w!(cx, f, "@ {:?}", ^trace)?;
                    w!(cx, with_padding(f), "\n{:?}", join!("\n", &node.children))
                }
                NodeKind::Root(bindings) => {
                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&node.children, cx, f)
                }
                NodeKind::ForAll(name, sort, _) => {
                    // With `cx.bindings_chain`, consecutive binders are printed as one
                    // quantifier.
                    let (bindings, children) = if cx.bindings_chain {
                        bindings_chain(self)
                    } else {
                        (vec![(*name, sort.clone())], node.children.clone())
                    };

                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .into_iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Assumption(pred) => {
                    // With `cx.preds_chain`, consecutive assumptions are printed as one
                    // guard.
                    let (preds, children) = if cx.preds_chain {
                        preds_chain(self)
                    } else {
                        (vec![pred.clone()], node.children.clone())
                    };
                    let guard = Expr::and_from_iter(preds).simplify(&SnapshotMap::default());
                    w!(cx, f, "{:?} =>", parens!(guard, !guard.is_atom()))?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Head(pred, tag) => {
                    let pred = if cx.simplify_exprs {
                        pred.simplify(&SnapshotMap::default())
                    } else {
                        pred.clone()
                    };
                    w!(cx, f, "{:?}", parens!(pred, !pred.is_atom()))?;
                    if cx.tags {
                        w!(cx, f, " ~ {:?}", tag)?;
                    }
                    Ok(())
                }
                NodeKind::True => {
                    w!(cx, f, "true")
                }
            }
        }
    }

    /// Prints the children of a node: ` true` when there are none, a single head child on
    /// the same line, and anything else on new, padded lines.
    fn fmt_children(
        children: &[NodePtr],
        cx: &PrettyCx,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        match children {
            [] => w!(cx, f, " true"),
            [n] => {
                if n.borrow().is_head() {
                    w!(cx, f, " {:?}", n)
                } else {
                    w!(cx, with_padding(f), "\n{:?}", n)
                }
            }
            _ => w!(cx, with_padding(f), "\n{:?}", join!("\n", children)),
        }
    }

    impl Pretty for Cursor<'_> {
        /// Prints the bindings and assumptions on the path from the root to the cursor,
        /// root first, as a comma-separated list in braces.
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let mut elements = vec![];
            // The walk goes from the cursor up to the root, so elements are gathered in
            // reverse and flipped when written.
            for node in ParentsIter::new(NodePtr::clone(&self.ptr)) {
                let n = node.borrow();
                match &n.kind {
                    NodeKind::Root(bindings) => {
                        for (name, sort) in bindings.iter().rev() {
                            elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                        }
                    }
                    NodeKind::ForAll(name, sort, _) => {
                        elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                    }
                    NodeKind::Assumption(pred) => {
                        elements.push(format_cx!(cx, "{:?}", pred));
                    }
                    _ => {}
                }
            }
            write!(f, "{{{}}}", elements.into_iter().rev().format(", "))
        }
    }

    impl Pretty for Scope {
        /// Prints the bound names and then the root parameters, each with its sort.
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(
                f,
                "[bindings = {}, reftgenerics = {}]",
                self.bindings
                    .iter_enumerated()
                    .format_with(", ", |(name, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                    }),
                self.params
                    .iter()
                    .format_with(", ", |(param_const, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^param_const, sort))
                    }),
            )
        }
    }

    /// Wraps `f` in a `PadAdapter` using the padding string below.
    fn with_padding<'a, 'b>(f: &'a mut fmt::Formatter<'b>) -> PadAdapter<'a, 'b, 'static> {
        PadAdapter::with_padding(f, " ")
    }

    impl_debug_with_default_cx!(
        RefineTree => "refine_tree",
        Cursor<'_> => "cursor",
        Scope,
    );
}
763
/// Serializable summary of what is in scope at a cursor: the bindings and the assumed
/// expressions found on the path from the root to the cursor.
#[derive(Serialize, DebugAsJson)]
pub struct RefineCtxtTrace {
    /// Root parameters and bound names, outermost first.
    bindings: Vec<RcxBind>,
    /// Conjuncts of the assumptions on the path, rendered as strings.
    exprs: Vec<String>,
}
770
/// A single binding of a [`RefineCtxtTrace`], with both parts already rendered as strings.
#[derive(Serialize)]
struct RcxBind {
    /// Rendered name of the variable.
    name: String,
    /// Rendered sort of the variable.
    sort: String,
}
776
777impl RefineCtxtTrace {
778 pub fn new(cx: &mut PrettyCx, cursor: &Cursor) -> Self {
779 let parents = ParentsIter::new(NodePtr::clone(&cursor.ptr)).collect_vec();
780 let mut bindings = vec![];
781 let mut exprs = vec![];
782
783 parents.into_iter().rev().for_each(|ptr| {
784 let node = ptr.borrow();
785 match &node.kind {
786 NodeKind::ForAll(name, sort, provenance) => {
787 let name = cx
788 .pretty_var_env
789 .set(PrettyVar::Local(*name), provenance.opt_symbol());
790 let sort = format_cx!(cx, "{:?}", sort);
791 let bind = RcxBind { name, sort };
792 bindings.push(bind);
793 }
794 NodeKind::Assumption(e)
795 if !e.simplify(&SnapshotMap::default()).is_trivially_true() =>
796 {
797 e.visit_conj(|e| {
798 exprs.push(e.nested_string(cx));
799 });
800 }
801 NodeKind::Root(binds) => {
802 for (name, sort) in binds {
803 let name = if let Var::EarlyParam(param) = name {
804 cx.pretty_var_env
805 .set(PrettyVar::Param(*param), Some(param.name))
806 } else {
807 format_cx!(cx, "{:?}", name)
808 };
809 let sort = format_cx!(cx, "{:?}", sort);
810 let bind = RcxBind { name, sort };
811 bindings.push(bind);
812 }
813 }
814 _ => (),
815 }
816 });
817 Self { bindings, exprs }
818 }
819}
820
821impl Node {
822 fn simplify_bot(&mut self) {
824 let graph = ConstraintDeps::new(self);
825 let bots = graph.bot_kvars();
826 self.simplify_with_assignment(&bots);
827 self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
828 }
829
830 fn simplify_top(&mut self) {
832 let graph = ConstraintDeps::new(self);
833 let tops = graph.top_kvars();
834 self.simplify_with_assignment(&tops);
835 self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
836 }
837
838 fn simplify_with_assignment(&mut self, assignment: &Assignment) {
841 match &mut self.kind {
842 NodeKind::Head(pred, tag) => {
843 let pred = assignment.simplify(pred);
844 self.kind = NodeKind::Head(pred, *tag);
845 }
846 NodeKind::Assumption(pred) => {
847 let pred = assignment.simplify(pred);
848 self.kind = NodeKind::Assumption(pred);
849 }
850 _ => {}
851 }
852 for child in &self.children {
853 child.borrow_mut().simplify_with_assignment(assignment);
854 }
855 }
856}
857
/// The kvar dependencies of a refinement tree, flattened into clauses. One clause is
/// recorded per conjunct of every head node; both vectors are indexed by clause.
#[derive(Debug)]
struct ConstraintDeps {
    /// For each clause, the kvars occurring as conjuncts of the assumptions above its head.
    assumptions: IndexVec<ClauseId, FxHashSet<KVid>>,
    /// For each clause, its head.
    heads: IndexVec<ClauseId, Head>,
}
865
866impl ConstraintDeps {
867 fn new(node: &Node) -> Self {
868 let mut graph = Self { assumptions: IndexVec::default(), heads: IndexVec::default() };
869 graph.build(node, &mut vec![]);
870 graph
871 }
872
873 fn insert_clause(&mut self, assumptions: &[KVid], head: Head) {
874 self.assumptions.push(assumptions.iter().copied().collect());
875 self.heads.push(head);
876 }
877
878 fn build(&mut self, node: &Node, assumptions: &mut Vec<KVid>) {
879 let n = assumptions.len();
880 match &node.kind {
881 NodeKind::Head(expr, _) => {
882 expr.visit_conj(|e| {
883 if let ExprKind::KVar(kvar) = e.kind() {
884 self.insert_clause(assumptions, Head::KVar(kvar.kvid));
885 } else {
886 self.insert_clause(assumptions, Head::Conc);
887 }
888 });
889 }
890 NodeKind::Assumption(expr) => {
891 expr.visit_conj(|e| {
892 if let ExprKind::KVar(kvar) = e.kind() {
893 assumptions.push(kvar.kvid);
894 }
895 });
896 }
897 _ => {}
898 };
899
900 for child in &node.children {
901 self.build(&child.borrow(), assumptions);
902 }
903
904 assumptions.truncate(n); }
906
907 fn kv_lhs(&self) -> FxHashMap<KVid, Vec<ClauseId>> {
909 let mut res: FxHashMap<KVid, Vec<ClauseId>> = FxHashMap::default();
910 for (clause_id, kvids) in self.assumptions.iter_enumerated() {
911 for kvid in kvids {
912 res.entry(*kvid).or_default().push(clause_id);
913 }
914 }
915 res
916 }
917
918 fn kv_rhs(&self) -> FxHashMap<KVid, Vec<ClauseId>> {
920 let mut res: FxHashMap<KVid, Vec<ClauseId>> = FxHashMap::default();
921 for (clause_id, head) in self.heads.iter_enumerated() {
922 if let Head::KVar(kvid) = head {
923 res.entry(*kvid).or_default().push(clause_id);
924 }
925 }
926 res
927 }
928
929 fn bot_kvars(self) -> Assignment {
932 let mut assignment = Assignment::new(Label::Bot);
934
935 let kv_lhs = self.kv_lhs();
936
937 let mut bot_assms: IndexVec<ClauseId, FxHashSet<KVid>> = self.assumptions;
939
940 let mut candidates: Vec<ClauseId> = bot_assms
942 .iter_enumerated()
943 .filter_map(|(cid, lhs)| if lhs.is_empty() { Some(cid) } else { None })
944 .collect();
945
946 while let Some(cid) = candidates.pop() {
948 if let Head::KVar(kvid) = self.heads[cid] {
949 assignment.remove(kvid);
951 for cid in kv_lhs.get(&kvid).unwrap_or(&vec![]) {
953 if let Head::KVar(rhs_kvid) = self.heads[*cid] {
955 let assms = &mut bot_assms[*cid];
956 assms.remove(&kvid);
957 if assignment.has_label(rhs_kvid) && assms.is_empty() {
958 candidates.push(*cid);
959 }
960 };
961 }
962 }
963 }
964
965 assignment
966 }
967
968 fn top_kvars(self) -> Assignment {
971 let mut assignment = Assignment::new(Label::Top);
973
974 let kv_rhs = self.kv_rhs();
975
976 let mut candidates = vec![];
978 for (cid, head) in self.heads.iter_enumerated() {
979 if matches!(head, Head::Conc) {
980 for kvid in &self.assumptions[cid] {
981 candidates.push(*kvid);
982 }
983 }
984 }
985
986 while let Some(kvid) = candidates.pop() {
988 assignment.remove(kvid);
990
991 for cid in kv_rhs.get(&kvid).unwrap_or(&vec![]) {
993 for lhs_kvid in &self.assumptions[*cid] {
995 if assignment.has_label(*lhs_kvid) {
996 candidates.push(*lhs_kvid);
997 }
998 }
999 }
1000 }
1001
1002 assignment
1003 }
1004}
1005
// Index of a clause recorded in `ConstraintDeps`.
newtype_index! {
    struct ClauseId {}
}
1009
/// The head of a clause in [`ConstraintDeps`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Head {
    /// The head conjunct is the given kvar.
    KVar(KVid),
    /// The head conjunct is anything other than a kvar.
    Conc,
}
1018
/// The label an [`Assignment`] gives to kvars.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Label {
    /// A conjunction containing a kvar with this label is rewritten to `Expr::ff()`.
    Bot,
    /// A kvar with this label is dropped from the conjunction it appears in.
    Top,
}
1026
/// Assigns a single [`Label`] to every kvar except those explicitly excluded.
struct Assignment {
    /// The kvars that do NOT have the label; a fresh assignment excludes none.
    vars: FxHashSet<KVid>,
    /// The label given to every kvar not in `vars`.
    label: Label,
}
1033
1034impl Assignment {
1035 fn new(label: Label) -> Self {
1036 let vars = FxHashSet::default();
1037 Self { vars, label }
1038 }
1039
1040 fn has_label(&self, kvid: KVid) -> bool {
1041 !self.vars.contains(&kvid)
1042 }
1043
1044 fn remove(&mut self, kvid: KVid) {
1045 self.vars.insert(kvid);
1046 }
1047
1048 fn simplify(&self, pred: &Expr) -> Expr {
1051 let mut preds = vec![];
1052 for p in pred.flatten_conjs() {
1053 if let ExprKind::KVar(kvar) = p.kind()
1054 && self.has_label(kvar.kvid)
1055 {
1056 if self.label == Label::Bot {
1057 return Expr::ff();
1058 } } else {
1060 preds.push(p.clone());
1061 }
1062 }
1063 Expr::and_from_iter(preds)
1064 }
1065}