1use std::{
2 cell::RefCell,
3 ops::ControlFlow,
4 rc::{Rc, Weak},
5};
6
7use flux_common::{index::IndexVec, iter::IterExt, tracked_span_bug};
8use flux_config::OverflowMode;
9use flux_macros::DebugAsJson;
10use flux_middle::{
11 global_env::GlobalEnv,
12 pretty::{PrettyCx, PrettyNested, format_cx},
13 queries::QueryResult,
14 rty::{
15 BaseTy, EVid, Expr, ExprKind, KVid, Name, NameProvenance, PrettyVar, Sort, Ty, TyKind, Var,
16 fold::{TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor},
17 },
18};
19use itertools::Itertools;
20use rustc_data_structures::{
21 snapshot_map::SnapshotMap,
22 unord::{UnordMap, UnordSet},
23};
24use rustc_hash::FxHashSet;
25use rustc_index::newtype_index;
26use rustc_middle::ty::TyCtxt;
27use serde::Serialize;
28
29use crate::{
30 evars::EVarStore,
31 fixpoint_encoding::{FixpointCtxt, fixpoint},
32 infer::{Tag, TypeTrace},
33};
34
35pub struct RefineTree {
53 root: NodePtr,
54}
55
56impl RefineTree {
57 pub(crate) fn new(params: Vec<(Var, Sort)>) -> RefineTree {
58 let root =
59 Node { kind: NodeKind::Root(params), nbindings: 0, parent: None, children: vec![] };
60 let root = NodePtr(Rc::new(RefCell::new(root)));
61 RefineTree { root }
62 }
63
64 pub(crate) fn simplify(&mut self, genv: GlobalEnv) {
65 self.root
66 .borrow_mut()
67 .simplify(SimplifyPhase::Full(genv), &mut SnapshotMap::default());
68 self.root.borrow_mut().simplify_bot();
69 self.root.borrow_mut().simplify_top();
70 }
71
72 pub(crate) fn to_fixpoint(
73 &self,
74 cx: &mut FixpointCtxt<Tag>,
75 ) -> QueryResult<fixpoint::Constraint> {
76 Ok(self
77 .root
78 .borrow()
79 .to_fixpoint(cx)?
80 .unwrap_or(fixpoint::Constraint::TRUE))
81 }
82
83 pub(crate) fn cursor_at_root(&mut self) -> Cursor<'_> {
84 Cursor { ptr: NodePtr(Rc::clone(&self.root)), tree: self }
85 }
86
87 pub(crate) fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
88 self.root.borrow_mut().replace_evars(evars)
89 }
90}
91
92pub struct Cursor<'a> {
97 tree: &'a mut RefineTree,
98 ptr: NodePtr,
99}
100
101impl Cursor<'_> {
102 pub(crate) fn move_to(&mut self, marker: &Marker, clear_children: bool) -> Option<Cursor<'_>> {
108 let ptr = marker.ptr.upgrade()?;
109 if clear_children {
110 ptr.borrow_mut().children.clear();
111 }
112 Some(Cursor { ptr, tree: self.tree })
113 }
114
115 #[must_use]
117 pub(crate) fn marker(&self) -> Marker {
118 Marker { ptr: NodePtr::downgrade(&self.ptr) }
119 }
120
121 #[must_use]
122 pub(crate) fn branch(&mut self) -> Cursor<'_> {
123 Cursor { tree: self.tree, ptr: NodePtr::clone(&self.ptr) }
124 }
125
126 pub(crate) fn vars(&self) -> impl Iterator<Item = (Var, Sort)> {
127 self.ptr.scope().into_iter()
129 }
130
131 #[expect(dead_code, reason = "used for debugging")]
132 pub(crate) fn push_trace(&mut self, trace: TypeTrace) {
133 self.ptr = self.ptr.push_node(NodeKind::Trace(trace));
134 }
135
136 pub(crate) fn define_var(&mut self, sort: &Sort, provenance: NameProvenance) -> Name {
139 let fresh = Name::from_usize(self.ptr.next_name_idx());
140 self.ptr = self
141 .ptr
142 .push_node(NodeKind::ForAll(fresh, sort.clone(), provenance));
143 fresh
144 }
145
146 pub(crate) fn assume_pred(&mut self, pred: impl Into<Expr>) {
150 let pred = pred.into();
151 if !pred.is_trivially_true() {
152 self.ptr = self.ptr.push_node(NodeKind::Assumption(pred));
153 }
154 }
155
156 pub(crate) fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
159 let pred = pred.into();
160 if !pred.is_trivially_true() {
161 self.ptr.push_node(NodeKind::Head(pred, tag));
162 }
163 }
164
165 pub(crate) fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
168 self.ptr
169 .push_node(NodeKind::Assumption(pred1.into()))
170 .push_node(NodeKind::Head(pred2.into(), tag));
171 }
172
173 pub(crate) fn assume_invariants(
174 &mut self,
175 tcx: TyCtxt,
176 ty: &Ty,
177 overflow_checking: OverflowMode,
178 ) {
179 struct Visitor<'a, 'b, 'tcx> {
180 tcx: TyCtxt<'tcx>,
181 cursor: &'a mut Cursor<'b>,
182 overflow_mode: OverflowMode,
183 }
184 impl TypeVisitor for Visitor<'_, '_, '_> {
185 fn visit_bty(&mut self, bty: &BaseTy) -> ControlFlow<!> {
186 match bty {
187 BaseTy::Adt(adt_def, substs) if adt_def.is_box() => substs.visit_with(self),
188 BaseTy::Ref(_, ty, _) => ty.visit_with(self),
189 BaseTy::Tuple(tys) => tys.visit_with(self),
190 _ => ControlFlow::Continue(()),
191 }
192 }
193
194 fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<!> {
195 if let TyKind::Indexed(bty, idx) = ty.kind()
196 && !idx.has_escaping_bvars()
197 {
198 for invariant in bty.invariants(self.tcx, self.overflow_mode) {
199 let invariant = invariant.apply(idx);
200 self.cursor.assume_pred(&invariant);
201 }
202 }
203 ty.super_visit_with(self)
204 }
205 }
206 let _ = ty.visit_with(&mut Visitor { tcx, cursor: self, overflow_mode: overflow_checking });
207 }
208}
209
210pub struct Marker {
217 ptr: WeakNodePtr,
218}
219
220impl Marker {
221 pub fn scope(&self) -> Option<Scope> {
225 Some(self.ptr.upgrade()?.scope())
226 }
227
228 pub fn has_free_vars<T: TypeVisitable>(&self, t: &T) -> bool {
229 let ptr = self
230 .ptr
231 .upgrade()
232 .unwrap_or_else(|| tracked_span_bug!("`has_free_vars` called on invalid `Marker`"));
233
234 let nbindings = ptr.next_name_idx();
235
236 !t.fvars().into_iter().all(|name| name.index() < nbindings)
237 }
238}
239
240#[derive(PartialEq, Eq)]
242pub struct Scope {
243 params: Vec<(Var, Sort)>,
244 bindings: IndexVec<Name, Sort>,
245}
246
247impl Scope {
248 pub(crate) fn iter(&self) -> impl Iterator<Item = (Var, Sort)> + '_ {
249 itertools::chain(
250 self.params.iter().cloned(),
251 self.bindings
252 .iter_enumerated()
253 .map(|(name, sort)| (Var::Free(name), sort.clone())),
254 )
255 }
256
257 fn into_iter(self) -> impl Iterator<Item = (Var, Sort)> {
258 itertools::chain(
259 self.params,
260 self.bindings
261 .into_iter_enumerated()
262 .map(|(name, sort)| (Var::Free(name), sort.clone())),
263 )
264 }
265
266 pub fn has_free_vars<T: TypeFoldable>(&self, t: &T) -> bool {
268 !self.contains_all(t.fvars())
269 }
270
271 fn contains_all(&self, iter: impl IntoIterator<Item = Name>) -> bool {
272 iter.into_iter().all(|name| self.contains(name))
273 }
274
275 fn contains(&self, name: Name) -> bool {
276 name.index() < self.bindings.len()
277 }
278}
279
/// A node of a [`RefineTree`].
///
/// Children are owned through strong [`NodePtr`]s. The parent is referenced through a
/// [`WeakNodePtr`], so the parent/child links do not form a strong reference cycle.
struct Node {
    kind: NodeKind,
    /// Number of `∀` binders strictly above this node: the parent's `next_name_idx()` at the
    /// time this node was pushed (the root starts at 0).
    nbindings: usize,
    /// `None` only for the root.
    parent: Option<WeakNodePtr>,
    children: Vec<NodePtr>,
}
289
290#[derive(Clone)]
291struct NodePtr(Rc<RefCell<Node>>);
292
293impl NodePtr {
294 fn downgrade(this: &Self) -> WeakNodePtr {
295 WeakNodePtr(Rc::downgrade(&this.0))
296 }
297
298 fn push_node(&mut self, kind: NodeKind) -> NodePtr {
299 debug_assert!(!matches!(self.borrow().kind, NodeKind::Head(..)));
300 let node = Node {
301 kind,
302 nbindings: self.next_name_idx(),
303 parent: Some(NodePtr::downgrade(self)),
304 children: vec![],
305 };
306 let node = NodePtr(Rc::new(RefCell::new(node)));
307 self.borrow_mut().children.push(NodePtr::clone(&node));
308 node
309 }
310
311 fn next_name_idx(&self) -> usize {
312 self.borrow().nbindings + usize::from(self.borrow().is_forall())
313 }
314
315 fn scope(&self) -> Scope {
316 let mut params = None;
317 let parents = ParentsIter::new(self.clone());
318 let bindings = parents
319 .filter_map(|node| {
320 let node = node.borrow();
321 match &node.kind {
322 NodeKind::Root(p) => {
323 params = Some(p.clone());
324 None
325 }
326 NodeKind::ForAll(_, sort, _) => Some(sort.clone()),
327 _ => None,
328 }
329 })
330 .collect_vec()
331 .into_iter()
332 .rev()
333 .collect();
334 Scope { bindings, params: params.unwrap_or_default() }
335 }
336}
337
338struct WeakNodePtr(Weak<RefCell<Node>>);
339
340impl WeakNodePtr {
341 fn upgrade(&self) -> Option<NodePtr> {
342 Some(NodePtr(self.0.upgrade()?))
343 }
344}
345
/// The different kinds of nodes of a [`RefineTree`].
enum NodeKind {
    /// The root of the tree, binding the given parameters for the whole tree.
    Root(Vec<(Var, Sort)>),
    /// Universally quantifies a name of the given sort over the node's children. Binders of
    /// sort `Loc` are not emitted when lowering to fixpoint.
    ForAll(Name, Sort, NameProvenance),
    /// Assumes a predicate for the node's children.
    Assumption(Expr),
    /// A predicate that must be proven, with the tag attached to it when lowered to fixpoint.
    /// Always a leaf.
    Head(Expr, Tag),
    /// A trivially satisfied constraint. Produced by simplification and dropped on lowering.
    True,
    /// A trace of a type-level operation, kept for debugging.
    Trace(TypeTrace),
}
356
/// Looks up the sort of a `∀`-bound name. Like any `IndexVec` indexing, this panics if `name`
/// is not bound in the scope.
impl std::ops::Index<Name> for Scope {
    type Output = Sort;

    fn index(&self, name: Name) -> &Self::Output {
        &self.bindings[name]
    }
}
364
/// Lets a `NodePtr` be used as its underlying `Rc<RefCell<Node>>`, e.g. `ptr.borrow()`.
impl std::ops::Deref for NodePtr {
    type Target = Rc<RefCell<Node>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
372
/// Selects how much work [`Node::simplify`] does.
#[derive(Clone, Copy)]
enum SimplifyPhase<'genv, 'tcx> {
    /// Normalize and simplify every head and assumption, then prune trivial nodes.
    Full(GlobalEnv<'genv, 'tcx>),
    /// Leave predicates as they are and only prune nodes that are already trivially true or
    /// sit below a trivially false assumption.
    Partial,
}
380
impl Node {
    /// Simplifies the subtree rooted at this node in place.
    ///
    /// `assumed_preds` holds the span-erased conjuncts of the assumptions between the root and
    /// this node. In [`SimplifyPhase::Full`], heads and assumptions are normalized and
    /// simplified against it.
    ///
    /// In both phases:
    /// - heads that are trivially true become [`NodeKind::True`];
    /// - `True` children are pruned;
    /// - all children of a trivially false assumption are pruned;
    /// - a non-leaf node left without children becomes `True` itself.
    fn simplify(&mut self, phase: SimplifyPhase, assumed_preds: &mut SnapshotMap<Expr, ()>) {
        // First, simplify the node itself.
        match &mut self.kind {
            NodeKind::Head(pred, tag) => {
                let pred = match phase {
                    SimplifyPhase::Full(genv) => pred.normalize(genv).simplify(assumed_preds),
                    SimplifyPhase::Partial => pred.clone(),
                };
                if pred.is_trivially_true() {
                    self.kind = NodeKind::True;
                } else {
                    self.kind = NodeKind::Head(pred, *tag);
                }
            }
            NodeKind::Assumption(pred) => {
                if let SimplifyPhase::Full(genv) = phase {
                    *pred = pred.normalize(genv).simplify(assumed_preds);
                }
                // Make every conjunct available while simplifying the descendants.
                pred.visit_conj(|conjunct| {
                    assumed_preds.insert(conjunct.erase_spans(), ());
                });
            }
            _ => {}
        }
        // Everything below a false assumption holds vacuously.
        let is_false_asm =
            matches!(&self.kind, NodeKind::Assumption(pred) if pred.is_trivially_false());

        // Then simplify the children. Every child starts from the same set of assumptions, so
        // whatever one child's subtree inserts is rolled back before visiting the next child.
        for child in &self.children {
            let current_version = assumed_preds.snapshot();
            child.borrow_mut().simplify(phase, assumed_preds);
            assumed_preds.rollback_to(current_version);
        }

        // Finally, prune `True` children (or every child below a false assumption).
        match &mut self.kind {
            // Leaves: nothing to prune.
            NodeKind::Head(..) | NodeKind::True => {}
            NodeKind::Assumption(_)
            | NodeKind::Trace(_)
            | NodeKind::Root(_)
            | NodeKind::ForAll(..) => {
                self.children
                    .extract_if(.., |child| {
                        is_false_asm || matches!(&child.borrow().kind, NodeKind::True)
                    })
                    .for_each(drop);
            }
        }
        // An interior node with nothing left to check is trivially true.
        if !self.is_leaf() && self.children.is_empty() {
            self.kind = NodeKind::True;
        }
    }

    /// Returns `true` for heads and `True` nodes, the two kinds that never have children.
    fn is_leaf(&self) -> bool {
        matches!(self.kind, NodeKind::Head(..) | NodeKind::True)
    }

    /// Replaces existential variables in the subtree rooted at this node, children first.
    ///
    /// # Errors
    /// Stops at, and returns, the first `EVid` error reported by `evars`.
    fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
        for child in &self.children {
            child.borrow_mut().replace_evars(evars)?;
        }
        match &mut self.kind {
            NodeKind::Assumption(pred) => *pred = evars.replace_evars(pred)?,
            NodeKind::Head(pred, _) => {
                *pred = evars.replace_evars(pred)?;
            }
            NodeKind::Trace(trace) => {
                // NOTE(review): the replaced trace is discarded, so the trace itself keeps its
                // evars and only an error is propagated. Confirm whether
                // `*trace = evars.replace_evars(trace)?` was intended.
                evars.replace_evars(trace)?;
            }
            NodeKind::Root(_) | NodeKind::ForAll(..) | NodeKind::True => {}
        }
        Ok(())
    }

    /// Lowers the subtree rooted at this node to a fixpoint constraint.
    ///
    /// Returns `Ok(None)` when there is nothing left to check: a `True` node, or an interior
    /// node none of whose children lowers to a constraint.
    fn to_fixpoint(&self, cx: &mut FixpointCtxt<Tag>) -> QueryResult<Option<fixpoint::Constraint>> {
        let cstr = match &self.kind {
            // Traces and location-sorted binders have no fixpoint counterpart; only their
            // children are lowered.
            NodeKind::Trace(_) | NodeKind::ForAll(_, Sort::Loc, _) => {
                children_to_fixpoint(cx, &self.children)?
            }

            NodeKind::Root(params) => {
                // Register the non-location early parameters with `cx` before lowering the
                // children.
                for (var, sort) in params {
                    if let Var::EarlyParam(param) = var
                        && !sort.is_loc()
                    {
                        cx.with_early_param(param);
                    }
                }

                let Some(children) = children_to_fixpoint(cx, &self.children)? else {
                    return Ok(None);
                };
                // Wrap the children in one binder per non-location parameter. Iterating in
                // reverse makes the first parameter the outermost binder.
                let mut constr = children;
                for (var, sort) in params.iter().rev() {
                    if sort.is_loc() {
                        continue;
                    }
                    constr = fixpoint::Constraint::ForAll(
                        fixpoint::Bind {
                            name: cx.var_to_fixpoint(var),
                            sort: cx.sort_to_fixpoint(sort),
                            pred: fixpoint::Pred::TRUE,
                        },
                        Box::new(constr),
                    );
                }
                Some(constr)
            }
            NodeKind::ForAll(name, sort, provenance) => {
                // `with_name_map` supplies the fresh fixpoint local that `name` is lowered to
                // while the children are lowered.
                cx.with_name_map(*name, *provenance, |cx, fresh| -> QueryResult<_> {
                    let Some(children) = children_to_fixpoint(cx, &self.children)? else {
                        return Ok(None);
                    };
                    Ok(Some(fixpoint::Constraint::ForAll(
                        fixpoint::Bind {
                            name: fixpoint::Var::Local(fresh),
                            sort: cx.sort_to_fixpoint(sort),
                            pred: fixpoint::Pred::TRUE,
                        },
                        Box::new(children),
                    )))
                })?
            }
            NodeKind::Assumption(pred) => {
                // Lowering the assumption may yield auxiliary bindings. The predicate itself
                // is attached to an extra `_: int` binder.
                let (mut bindings, pred) = cx.assumption_to_fixpoint(pred)?;
                let Some(cstr) = children_to_fixpoint(cx, &self.children)? else {
                    return Ok(None);
                };
                bindings.push(fixpoint::Bind {
                    name: fixpoint::Var::Underscore,
                    sort: fixpoint::Sort::Int,
                    pred,
                });
                Some(fixpoint::Constraint::foralls(bindings, cstr))
            }
            NodeKind::Head(pred, tag) => {
                Some(cx.head_to_fixpoint(pred, |span| tag.with_dst(span))?)
            }
            NodeKind::True => None,
        };
        Ok(cstr)
    }

    /// Returns `true` if this node is a `∀` binder, including location-sorted ones.
    fn is_forall(&self) -> bool {
        matches!(self.kind, NodeKind::ForAll(..))
    }

    /// Returns `true` if this node is a proof obligation.
    fn is_head(&self) -> bool {
        matches!(self.kind, NodeKind::Head(..))
    }
}
541
542fn children_to_fixpoint(
543 cx: &mut FixpointCtxt<Tag>,
544 children: &[NodePtr],
545) -> QueryResult<Option<fixpoint::Constraint>> {
546 let mut children = children
547 .iter()
548 .filter_map(|node| node.borrow().to_fixpoint(cx).transpose())
549 .try_collect_vec()?;
550 let cstr = match children.len() {
551 0 => None,
552 1 => children.pop(),
553 _ => Some(fixpoint::Constraint::Conj(children)),
554 };
555 Ok(cstr)
556}
557
558struct ParentsIter {
559 ptr: Option<NodePtr>,
560}
561
562impl ParentsIter {
563 fn new(ptr: NodePtr) -> Self {
564 Self { ptr: Some(ptr) }
565 }
566}
567
568impl Iterator for ParentsIter {
569 type Item = NodePtr;
570
571 fn next(&mut self) -> Option<Self::Item> {
572 if let Some(ptr) = self.ptr.take() {
573 self.ptr = ptr.borrow().parent.as_ref().and_then(WeakNodePtr::upgrade);
574 Some(ptr)
575 } else {
576 None
577 }
578 }
579}
580
// Pretty-printing of refinement trees, cursors and scopes.
mod pretty {
    use std::fmt::{self, Write};

    use flux_middle::pretty::*;
    use pad_adapter::PadAdapter;

    use super::*;

    /// Collects the chain of directly nested `∀` binders starting at `ptr`.
    ///
    /// Returns the bound names with their sorts, together with the nodes where the chain stops:
    /// - the first non-binder node reached;
    /// - or the children of a binder that has zero or several children.
    fn bindings_chain(ptr: &NodePtr) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut bindings: Vec<(Name, Sort)>) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::ForAll(name, sort, _) = &node.kind {
                bindings.push((*name, sort.clone()));
                // Only a binder with exactly one child continues the chain.
                if let [child] = &node.children[..] {
                    go(child, bindings)
                } else {
                    (bindings, node.children.clone())
                }
            } else {
                (bindings, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    /// Collects the chain of directly nested assumptions starting at `ptr`.
    ///
    /// Works like `bindings_chain`, but for [`NodeKind::Assumption`] nodes.
    fn preds_chain(ptr: &NodePtr) -> (Vec<Expr>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut preds: Vec<Expr>) -> (Vec<Expr>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::Assumption(pred) = &node.kind {
                preds.push(pred.clone());
                if let [child] = &node.children[..] {
                    go(child, preds)
                } else {
                    (preds, node.children.clone())
                }
            } else {
                (preds, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    impl Pretty for RefineTree {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?}", &self.root)
        }
    }

    /// Prints a node and, indented below it, its children.
    ///
    /// When the `PrettyCx` flags `bindings_chain` / `preds_chain` are set, chains of nested
    /// binders or assumptions are collapsed into a single line.
    impl Pretty for NodePtr {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let node = self.borrow();
            match &node.kind {
                NodeKind::Trace(trace) => {
                    w!(cx, f, "@ {:?}", ^trace)?;
                    w!(cx, with_padding(f), "\n{:?}", join!("\n", &node.children))
                }
                NodeKind::Root(bindings) => {
                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&node.children, cx, f)
                }
                NodeKind::ForAll(name, sort, _) => {
                    let (bindings, children) = if cx.bindings_chain {
                        bindings_chain(self)
                    } else {
                        (vec![(*name, sort.clone())], node.children.clone())
                    };

                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .into_iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Assumption(pred) => {
                    let (preds, children) = if cx.preds_chain {
                        preds_chain(self)
                    } else {
                        (vec![pred.clone()], node.children.clone())
                    };
                    // Print the collected assumptions as one simplified conjunction.
                    let guard = Expr::and_from_iter(preds).simplify(&SnapshotMap::default());
                    w!(cx, f, "{:?} =>", parens!(guard, !guard.is_atom()))?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Head(pred, tag) => {
                    let pred = if cx.simplify_exprs {
                        pred.simplify(&SnapshotMap::default())
                    } else {
                        pred.clone()
                    };
                    w!(cx, f, "{:?}", parens!(pred, !pred.is_atom()))?;
                    if cx.tags {
                        w!(cx, f, " ~ {:?}", tag)?;
                    }
                    Ok(())
                }
                NodeKind::True => {
                    w!(cx, f, "true")
                }
            }
        }
    }

    /// Prints the children of a node.
    ///
    /// - No children prints ` true`.
    /// - A single head prints on the same line.
    /// - Anything else prints on new, padded lines.
    fn fmt_children(
        children: &[NodePtr],
        cx: &PrettyCx,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        match children {
            [] => w!(cx, f, " true"),
            [n] => {
                if n.borrow().is_head() {
                    w!(cx, f, " {:?}", n)
                } else {
                    w!(cx, with_padding(f), "\n{:?}", n)
                }
            }
            _ => w!(cx, with_padding(f), "\n{:?}", join!("\n", children)),
        }
    }

    /// Prints the bindings and assumptions on the path from the root to the cursor, in
    /// root-to-cursor order, as `{...}`.
    impl Pretty for Cursor<'_> {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let mut elements = vec![];
            // `ParentsIter` walks upwards, so elements are collected in reverse order and
            // flipped once at the end.
            for node in ParentsIter::new(NodePtr::clone(&self.ptr)) {
                let n = node.borrow();
                match &n.kind {
                    NodeKind::Root(bindings) => {
                        for (name, sort) in bindings.iter().rev() {
                            // Reversed so the root parameters end up in declaration order.
                            elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                        }
                    }
                    NodeKind::ForAll(name, sort, _) => {
                        elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                    }
                    NodeKind::Assumption(pred) => {
                        elements.push(format_cx!(cx, "{:?}", pred));
                    }
                    _ => {}
                }
            }
            write!(f, "{{{}}}", elements.into_iter().rev().format(", "))
        }
    }

    impl Pretty for Scope {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(
                f,
                "[bindings = {}, reftgenerics = {}]",
                self.bindings
                    .iter_enumerated()
                    .format_with(", ", |(name, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                    }),
                self.params
                    .iter()
                    .format_with(", ", |(param_const, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^param_const, sort))
                    }),
            )
        }
    }

    /// Wraps `f` so that everything written through it is indented.
    fn with_padding<'a, 'b>(f: &'a mut fmt::Formatter<'b>) -> PadAdapter<'a, 'b, 'static> {
        PadAdapter::with_padding(f, " ")
    }

    impl_debug_with_default_cx!(
        RefineTree => "refine_tree",
        Cursor<'_> => "cursor",
        Scope,
    );
}
766
/// A serializable snapshot of the refinement context at a [`Cursor`]: the bindings and
/// assumptions on the path from the root to it. `Debug` output comes from `DebugAsJson`.
#[derive(Serialize, DebugAsJson)]
pub struct RefineCtxtTrace {
    /// Bound variables, from the root down to the cursor.
    bindings: Vec<RcxBind>,
    /// Pretty-printed conjuncts of the assumptions that are not trivially true, from the root
    /// down to the cursor.
    exprs: Vec<String>,
}

/// A pretty-printed variable binding of a [`RefineCtxtTrace`].
#[derive(Serialize)]
struct RcxBind {
    name: String,
    sort: String,
}
779
780impl RefineCtxtTrace {
781 pub fn new(cx: &mut PrettyCx, cursor: &Cursor) -> Self {
782 let parents = ParentsIter::new(NodePtr::clone(&cursor.ptr)).collect_vec();
783 let mut bindings = vec![];
784 let mut exprs = vec![];
785
786 parents.into_iter().rev().for_each(|ptr| {
787 let node = ptr.borrow();
788 match &node.kind {
789 NodeKind::ForAll(name, sort, provenance) => {
790 let name = cx
791 .pretty_var_env
792 .set(PrettyVar::Local(*name), provenance.opt_symbol());
793 let sort = format_cx!(cx, "{:?}", sort);
794 let bind = RcxBind { name, sort };
795 bindings.push(bind);
796 }
797 NodeKind::Assumption(e)
798 if !e.simplify(&SnapshotMap::default()).is_trivially_true() =>
799 {
800 e.visit_conj(|e| {
801 exprs.push(e.nested_string(cx));
802 });
803 }
804 NodeKind::Root(binds) => {
805 for (name, sort) in binds {
806 let name = if let Var::EarlyParam(param) = name {
807 cx.pretty_var_env
808 .set(PrettyVar::Param(*param), Some(param.name))
809 } else {
810 format_cx!(cx, "{:?}", name)
811 };
812 let sort = format_cx!(cx, "{:?}", sort);
813 let bind = RcxBind { name, sort };
814 bindings.push(bind);
815 }
816 }
817 _ => (),
818 }
819 });
820 Self { bindings, exprs }
821 }
822}
823
824impl Node {
825 fn simplify_bot(&mut self) {
827 let graph = ConstraintDeps::new(self);
828 let bots = graph.bot_kvars();
829 self.simplify_with_assignment(&bots);
830 self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
831 }
832
833 fn simplify_top(&mut self) {
835 let graph = ConstraintDeps::new(self);
836 let tops = graph.top_kvars();
837 self.simplify_with_assignment(&tops);
838 self.simplify(SimplifyPhase::Partial, &mut SnapshotMap::default());
839 }
840
841 fn simplify_with_assignment(&mut self, assignment: &Assignment) {
844 match &mut self.kind {
845 NodeKind::Head(pred, tag) => {
846 let pred = assignment.simplify(pred);
847 self.kind = NodeKind::Head(pred, *tag);
848 }
849 NodeKind::Assumption(pred) => {
850 let pred = assignment.simplify(pred);
851 self.kind = NodeKind::Assumption(pred);
852 }
853 _ => {}
854 }
855 for child in &self.children {
856 child.borrow_mut().simplify_with_assignment(assignment);
857 }
858 }
859}
860
/// The constraints of a tree flattened into clauses, used to find kvars with trivial
/// solutions.
///
/// Clause `i` assumes the kvars in `assumptions[i]` and concludes `heads[i]`.
#[derive(Debug)]
struct ConstraintDeps {
    /// For each clause, the kvars assumed on the path from the root to its head.
    assumptions: IndexVec<ClauseId, FxHashSet<KVid>>,
    /// For each clause, the conjunct it has to prove.
    heads: IndexVec<ClauseId, Head>,
}
868
869impl ConstraintDeps {
870 fn new(node: &Node) -> Self {
871 let mut graph = Self { assumptions: IndexVec::default(), heads: IndexVec::default() };
872 graph.build(node, &mut vec![]);
873 graph
874 }
875
876 fn insert_clause(&mut self, assumptions: &[KVid], head: Head) {
877 self.assumptions.push(assumptions.iter().copied().collect());
878 self.heads.push(head);
879 }
880
881 fn build(&mut self, node: &Node, assumptions: &mut Vec<KVid>) {
882 let n = assumptions.len();
883 match &node.kind {
884 NodeKind::Head(expr, _) => {
885 expr.visit_conj(|e| {
886 if let ExprKind::KVar(kvar) = e.kind() {
887 self.insert_clause(assumptions, Head::KVar(kvar.kvid));
888 } else {
889 self.insert_clause(assumptions, Head::Conc);
890 }
891 });
892 }
893 NodeKind::Assumption(expr) => {
894 expr.visit_conj(|e| {
895 if let ExprKind::KVar(kvar) = e.kind() {
896 assumptions.push(kvar.kvid);
897 }
898 });
899 }
900 _ => {}
901 };
902
903 for child in &node.children {
904 self.build(&child.borrow(), assumptions);
905 }
906
907 assumptions.truncate(n); }
909
910 fn kv_lhs(&self) -> UnordMap<KVid, Vec<ClauseId>> {
912 let mut res: UnordMap<KVid, Vec<ClauseId>> = UnordMap::default();
913 for (clause_id, kvids) in self.assumptions.iter_enumerated() {
914 for kvid in kvids {
915 res.entry(*kvid).or_default().push(clause_id);
916 }
917 }
918 res
919 }
920
921 fn kv_rhs(&self) -> UnordMap<KVid, Vec<ClauseId>> {
923 let mut res: UnordMap<KVid, Vec<ClauseId>> = UnordMap::default();
924 for (clause_id, head) in self.heads.iter_enumerated() {
925 if let Head::KVar(kvid) = head {
926 res.entry(*kvid).or_default().push(clause_id);
927 }
928 }
929 res
930 }
931
932 fn bot_kvars(self) -> Assignment {
935 let mut assignment = Assignment::new(Label::Bot);
937
938 let kv_lhs = self.kv_lhs();
939
940 let mut bot_assms: IndexVec<ClauseId, FxHashSet<KVid>> = self.assumptions;
942
943 let mut candidates: Vec<ClauseId> = bot_assms
945 .iter_enumerated()
946 .filter_map(|(cid, lhs)| if lhs.is_empty() { Some(cid) } else { None })
947 .collect();
948
949 while let Some(cid) = candidates.pop() {
951 if let Head::KVar(kvid) = self.heads[cid] {
952 assignment.remove(kvid);
954 for cid in kv_lhs.get(&kvid).unwrap_or(&vec![]) {
956 if let Head::KVar(rhs_kvid) = self.heads[*cid] {
958 let assms = &mut bot_assms[*cid];
959 assms.remove(&kvid);
960 if assignment.has_label(rhs_kvid) && assms.is_empty() {
961 candidates.push(*cid);
962 }
963 };
964 }
965 }
966 }
967
968 assignment
969 }
970
971 fn top_kvars(self) -> Assignment {
974 let mut assignment = Assignment::new(Label::Top);
976
977 let kv_rhs = self.kv_rhs();
978
979 let mut candidates = vec![];
981 for (cid, head) in self.heads.iter_enumerated() {
982 if matches!(head, Head::Conc) {
983 for kvid in &self.assumptions[cid] {
984 candidates.push(*kvid);
985 }
986 }
987 }
988
989 while let Some(kvid) = candidates.pop() {
991 assignment.remove(kvid);
993
994 for cid in kv_rhs.get(&kvid).unwrap_or(&vec![]) {
996 for lhs_kvid in &self.assumptions[*cid] {
998 if assignment.has_label(*lhs_kvid) {
999 candidates.push(*lhs_kvid);
1000 }
1001 }
1002 }
1003 }
1004
1005 assignment
1006 }
1007}
1008
// Index of a clause in `ConstraintDeps`.
newtype_index! {
    struct ClauseId {}
}
1012
/// The right-hand side of a clause in [`ConstraintDeps`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Head {
    /// The clause concludes a kvar.
    KVar(KVid),
    /// The clause concludes a conjunct that is not a kvar (a concrete constraint).
    Conc,
}
1021
/// The value an [`Assignment`] gives to the kvars it labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Label {
    /// Labeled kvars are solved to `false`.
    Bot,
    /// Labeled kvars are solved to `true`.
    Top,
}
1029
1030struct Assignment {
1031 vars: UnordSet<KVid>,
1034 label: Label,
1035}
1036
1037impl Assignment {
1038 fn new(label: Label) -> Self {
1039 let vars = UnordSet::new();
1040 Self { vars, label }
1041 }
1042
1043 fn has_label(&self, kvid: KVid) -> bool {
1044 !self.vars.contains(&kvid)
1045 }
1046
1047 fn remove(&mut self, kvid: KVid) {
1048 self.vars.insert(kvid);
1049 }
1050
1051 fn simplify(&self, pred: &Expr) -> Expr {
1054 let mut preds = vec![];
1055 for p in pred.flatten_conjs() {
1056 if let ExprKind::KVar(kvar) = p.kind()
1057 && self.has_label(kvar.kvid)
1058 {
1059 if self.label == Label::Bot {
1060 return Expr::ff();
1061 } } else {
1063 preds.push(p.clone());
1064 }
1065 }
1066 Expr::and_from_iter(preds)
1067 }
1068}