1use std::{
2 cell::RefCell,
3 ops::ControlFlow,
4 rc::{Rc, Weak},
5};
6
7use flux_common::{index::IndexVec, iter::IterExt, tracked_span_bug};
8use flux_macros::DebugAsJson;
9use flux_middle::{
10 global_env::GlobalEnv,
11 pretty::{PrettyCx, PrettyNested, format_cx},
12 queries::QueryResult,
13 rty::{
14 BaseTy, EVid, Expr, Name, Sort, Ty, TyKind, Var,
15 fold::{TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor},
16 },
17};
18use itertools::Itertools;
19use rustc_data_structures::snapshot_map::SnapshotMap;
20use rustc_middle::ty::TyCtxt;
21use serde::Serialize;
22
23use crate::{
24 evars::EVarStore,
25 fixpoint_encoding::{FixpointCtxt, fixpoint},
26 infer::{Tag, TypeTrace},
27};
28
29pub struct RefineTree {
47 root: NodePtr,
48}
49
50impl RefineTree {
51 pub(crate) fn new(params: Vec<(Var, Sort)>) -> RefineTree {
52 let root =
53 Node { kind: NodeKind::Root(params), nbindings: 0, parent: None, children: vec![] };
54 let root = NodePtr(Rc::new(RefCell::new(root)));
55 RefineTree { root }
56 }
57
58 pub(crate) fn simplify(&mut self, genv: GlobalEnv) {
59 self.root
60 .borrow_mut()
61 .simplify(genv, &mut SnapshotMap::default());
62 }
63
64 pub(crate) fn into_fixpoint(
65 self,
66 cx: &mut FixpointCtxt<Tag>,
67 ) -> QueryResult<fixpoint::Constraint> {
68 Ok(self
69 .root
70 .borrow()
71 .to_fixpoint(cx)?
72 .unwrap_or(fixpoint::Constraint::TRUE))
73 }
74
75 pub(crate) fn cursor_at_root(&mut self) -> Cursor<'_> {
76 Cursor { ptr: NodePtr(Rc::clone(&self.root)), tree: self }
77 }
78
79 pub(crate) fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
80 self.root.borrow_mut().replace_evars(evars)
81 }
82}
83
84pub struct Cursor<'a> {
89 tree: &'a mut RefineTree,
90 ptr: NodePtr,
91}
92
93impl Cursor<'_> {
94 pub(crate) fn move_to(&mut self, marker: &Marker, clear_children: bool) -> Option<Cursor<'_>> {
100 let ptr = marker.ptr.upgrade()?;
101 if clear_children {
102 ptr.borrow_mut().children.clear();
103 }
104 Some(Cursor { ptr, tree: self.tree })
105 }
106
107 #[must_use]
109 pub(crate) fn marker(&self) -> Marker {
110 Marker { ptr: NodePtr::downgrade(&self.ptr) }
111 }
112
113 #[must_use]
114 pub(crate) fn branch(&mut self) -> Cursor<'_> {
115 Cursor { tree: self.tree, ptr: NodePtr::clone(&self.ptr) }
116 }
117
118 pub(crate) fn vars(&self) -> impl Iterator<Item = (Var, Sort)> {
119 self.ptr.scope().into_iter()
121 }
122
123 #[expect(dead_code, reason = "used for debugging")]
124 pub(crate) fn push_trace(&mut self, trace: TypeTrace) {
125 self.ptr = self.ptr.push_node(NodeKind::Trace(trace));
126 }
127
128 pub(crate) fn define_var(&mut self, sort: &Sort) -> Name {
131 let fresh = Name::from_usize(self.ptr.next_name_idx());
132 self.ptr = self.ptr.push_node(NodeKind::ForAll(fresh, sort.clone()));
133 fresh
134 }
135
136 pub(crate) fn assume_pred(&mut self, pred: impl Into<Expr>) {
140 let pred = pred.into();
141 if !pred.is_trivially_true() {
142 self.ptr = self.ptr.push_node(NodeKind::Assumption(pred));
143 }
144 }
145
146 pub(crate) fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
149 let pred = pred.into();
150 if !pred.is_trivially_true() {
151 self.ptr.push_node(NodeKind::Head(pred, tag));
152 }
153 }
154
155 pub(crate) fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
158 self.ptr
159 .push_node(NodeKind::Assumption(pred1.into()))
160 .push_node(NodeKind::Head(pred2.into(), tag));
161 }
162
163 pub(crate) fn assume_invariants(&mut self, tcx: TyCtxt, ty: &Ty, overflow_checking: bool) {
164 struct Visitor<'a, 'b, 'tcx> {
165 tcx: TyCtxt<'tcx>,
166 cursor: &'a mut Cursor<'b>,
167 overflow_checking: bool,
168 }
169 impl TypeVisitor for Visitor<'_, '_, '_> {
170 fn visit_bty(&mut self, bty: &BaseTy) -> ControlFlow<!> {
171 match bty {
172 BaseTy::Adt(adt_def, substs) if adt_def.is_box() => substs.visit_with(self),
173 BaseTy::Ref(_, ty, _) => ty.visit_with(self),
174 BaseTy::Tuple(tys) => tys.visit_with(self),
175 _ => ControlFlow::Continue(()),
176 }
177 }
178
179 fn visit_ty(&mut self, ty: &Ty) -> ControlFlow<!> {
180 if let TyKind::Indexed(bty, idx) = ty.kind()
181 && !idx.has_escaping_bvars()
182 {
183 for invariant in bty.invariants(self.tcx, self.overflow_checking) {
184 let invariant = invariant.apply(idx);
185 self.cursor.assume_pred(&invariant);
186 }
187 }
188 ty.super_visit_with(self)
189 }
190 }
191 let _ = ty.visit_with(&mut Visitor { tcx, cursor: self, overflow_checking });
192 }
193}
194
195pub struct Marker {
202 ptr: WeakNodePtr,
203}
204
205impl Marker {
206 pub fn scope(&self) -> Option<Scope> {
210 Some(self.ptr.upgrade()?.scope())
211 }
212
213 pub fn has_free_vars<T: TypeVisitable>(&self, t: &T) -> bool {
214 let ptr = self
215 .ptr
216 .upgrade()
217 .unwrap_or_else(|| tracked_span_bug!("`has_free_vars` called on invalid `Marker`"));
218
219 let nbindings = ptr.next_name_idx();
220
221 !t.fvars().into_iter().all(|name| name.index() < nbindings)
222 }
223}
224
225#[derive(PartialEq, Eq)]
227pub struct Scope {
228 params: Vec<(Var, Sort)>,
229 bindings: IndexVec<Name, Sort>,
230}
231
232impl Scope {
233 pub(crate) fn iter(&self) -> impl Iterator<Item = (Var, Sort)> + '_ {
234 itertools::chain(
235 self.params.iter().cloned(),
236 self.bindings
237 .iter_enumerated()
238 .map(|(name, sort)| (Var::Free(name), sort.clone())),
239 )
240 }
241
242 fn into_iter(self) -> impl Iterator<Item = (Var, Sort)> {
243 itertools::chain(
244 self.params,
245 self.bindings
246 .into_iter_enumerated()
247 .map(|(name, sort)| (Var::Free(name), sort.clone())),
248 )
249 }
250
251 pub fn has_free_vars<T: TypeFoldable>(&self, t: &T) -> bool {
253 !self.contains_all(t.fvars())
254 }
255
256 fn contains_all(&self, iter: impl IntoIterator<Item = Name>) -> bool {
257 iter.into_iter().all(|name| self.contains(name))
258 }
259
260 fn contains(&self, name: Name) -> bool {
261 name.index() < self.bindings.len()
262 }
263}
264
265struct Node {
266 kind: NodeKind,
267 nbindings: usize,
271 parent: Option<WeakNodePtr>,
272 children: Vec<NodePtr>,
273}
274
275#[derive(Clone)]
276struct NodePtr(Rc<RefCell<Node>>);
277
278impl NodePtr {
279 fn downgrade(this: &Self) -> WeakNodePtr {
280 WeakNodePtr(Rc::downgrade(&this.0))
281 }
282
283 fn push_node(&mut self, kind: NodeKind) -> NodePtr {
284 debug_assert!(!matches!(self.borrow().kind, NodeKind::Head(..)));
285 let node = Node {
286 kind,
287 nbindings: self.next_name_idx(),
288 parent: Some(NodePtr::downgrade(self)),
289 children: vec![],
290 };
291 let node = NodePtr(Rc::new(RefCell::new(node)));
292 self.borrow_mut().children.push(NodePtr::clone(&node));
293 node
294 }
295
296 fn next_name_idx(&self) -> usize {
297 self.borrow().nbindings + usize::from(self.borrow().is_forall())
298 }
299
300 fn scope(&self) -> Scope {
301 let mut params = None;
302 let parents = ParentsIter::new(self.clone());
303 let bindings = parents
304 .filter_map(|node| {
305 let node = node.borrow();
306 match &node.kind {
307 NodeKind::Root(p) => {
308 params = Some(p.clone());
309 None
310 }
311 NodeKind::ForAll(_, sort) => Some(sort.clone()),
312 _ => None,
313 }
314 })
315 .collect_vec()
316 .into_iter()
317 .rev()
318 .collect();
319 Scope { bindings, params: params.unwrap_or_default() }
320 }
321}
322
323struct WeakNodePtr(Weak<RefCell<Node>>);
324
325impl WeakNodePtr {
326 fn upgrade(&self) -> Option<NodePtr> {
327 Some(NodePtr(self.0.upgrade()?))
328 }
329}
330
331enum NodeKind {
332 Root(Vec<(Var, Sort)>),
334 ForAll(Name, Sort),
335 Assumption(Expr),
336 Head(Expr, Tag),
337 True,
338 Trace(TypeTrace),
340}
341
342impl std::ops::Index<Name> for Scope {
343 type Output = Sort;
344
345 fn index(&self, name: Name) -> &Self::Output {
346 &self.bindings[name]
347 }
348}
349
350impl std::ops::Deref for NodePtr {
351 type Target = Rc<RefCell<Node>>;
352
353 fn deref(&self) -> &Self::Target {
354 &self.0
355 }
356}
357
358impl Node {
359 fn simplify(&mut self, genv: GlobalEnv, assumed_preds: &mut SnapshotMap<Expr, ()>) {
360 match &mut self.kind {
362 NodeKind::Head(pred, tag) => {
363 let pred = pred.normalize(genv).simplify(assumed_preds);
364 if pred.is_trivially_true() {
365 self.kind = NodeKind::True;
366 } else {
367 self.kind = NodeKind::Head(pred, *tag);
368 }
369 }
370 NodeKind::Assumption(pred) => {
371 *pred = pred.normalize(genv).simplify(assumed_preds);
372 pred.flatten_conjs().into_iter().for_each(|conjunct| {
373 assumed_preds.insert(conjunct.erase_spans(), ());
374 });
375 }
376 _ => {}
377 }
378
379 for child in &self.children {
382 let current_version = assumed_preds.snapshot();
383 child.borrow_mut().simplify(genv, assumed_preds);
384 assumed_preds.rollback_to(current_version);
385 }
386
387 match &mut self.kind {
389 NodeKind::Head(..) | NodeKind::True => {}
390 NodeKind::Assumption(_)
391 | NodeKind::Trace(_)
392 | NodeKind::Root(_)
393 | NodeKind::ForAll(..) => {
394 self.children
395 .extract_if(.., |child| matches!(&child.borrow().kind, NodeKind::True))
396 .for_each(drop);
397 }
398 }
399 if !self.is_leaf() && self.children.is_empty() {
400 self.kind = NodeKind::True;
401 }
402 }
403
404 fn is_leaf(&self) -> bool {
405 matches!(self.kind, NodeKind::Head(..) | NodeKind::True)
406 }
407
408 fn replace_evars(&mut self, evars: &EVarStore) -> Result<(), EVid> {
409 for child in &self.children {
410 child.borrow_mut().replace_evars(evars)?;
411 }
412 match &mut self.kind {
413 NodeKind::Assumption(pred) => *pred = evars.replace_evars(pred)?,
414 NodeKind::Head(pred, _) => {
415 *pred = evars.replace_evars(pred)?;
416 }
417 NodeKind::Trace(trace) => {
418 evars.replace_evars(trace)?;
419 }
420 NodeKind::Root(_) | NodeKind::ForAll(..) | NodeKind::True => {}
421 }
422 Ok(())
423 }
424
425 fn to_fixpoint(&self, cx: &mut FixpointCtxt<Tag>) -> QueryResult<Option<fixpoint::Constraint>> {
426 let cstr = match &self.kind {
427 NodeKind::Trace(_) | NodeKind::ForAll(_, Sort::Loc) => {
428 children_to_fixpoint(cx, &self.children)?
429 }
430
431 NodeKind::Root(params) => {
432 let Some(children) = children_to_fixpoint(cx, &self.children)? else {
433 return Ok(None);
434 };
435 let mut constr = children;
436 for (var, sort) in params.iter().rev() {
437 if sort.is_loc() {
438 continue;
439 }
440 constr = fixpoint::Constraint::ForAll(
441 fixpoint::Bind {
442 name: cx.var_to_fixpoint(var),
443 sort: cx.sort_to_fixpoint(sort),
444 pred: fixpoint::Pred::TRUE,
445 },
446 Box::new(constr),
447 );
448 }
449 Some(constr)
450 }
451 NodeKind::ForAll(name, sort) => {
452 cx.with_name_map(*name, |cx, fresh| -> QueryResult<_> {
453 let Some(children) = children_to_fixpoint(cx, &self.children)? else {
454 return Ok(None);
455 };
456 Ok(Some(fixpoint::Constraint::ForAll(
457 fixpoint::Bind {
458 name: fixpoint::Var::Local(fresh),
459 sort: cx.sort_to_fixpoint(sort),
460 pred: fixpoint::Pred::TRUE,
461 },
462 Box::new(children),
463 )))
464 })?
465 }
466 NodeKind::Assumption(pred) => {
467 let (mut bindings, pred) = cx.assumption_to_fixpoint(pred)?;
468 let Some(cstr) = children_to_fixpoint(cx, &self.children)? else {
469 return Ok(None);
470 };
471 bindings.push(fixpoint::Bind {
472 name: fixpoint::Var::Underscore,
473 sort: fixpoint::Sort::Int,
474 pred,
475 });
476 Some(fixpoint::Constraint::foralls(bindings, cstr))
477 }
478 NodeKind::Head(pred, tag) => {
479 Some(cx.head_to_fixpoint(pred, |span| tag.with_dst(span))?)
480 }
481 NodeKind::True => None,
482 };
483 Ok(cstr)
484 }
485
486 fn is_forall(&self) -> bool {
490 matches!(self.kind, NodeKind::ForAll(..))
491 }
492
493 fn is_head(&self) -> bool {
497 matches!(self.kind, NodeKind::Head(..))
498 }
499}
500
501fn children_to_fixpoint(
502 cx: &mut FixpointCtxt<Tag>,
503 children: &[NodePtr],
504) -> QueryResult<Option<fixpoint::Constraint>> {
505 let mut children = children
506 .iter()
507 .filter_map(|node| node.borrow().to_fixpoint(cx).transpose())
508 .try_collect_vec()?;
509 let cstr = match children.len() {
510 0 => None,
511 1 => children.pop(),
512 _ => Some(fixpoint::Constraint::Conj(children)),
513 };
514 Ok(cstr)
515}
516
517struct ParentsIter {
518 ptr: Option<NodePtr>,
519}
520
521impl ParentsIter {
522 fn new(ptr: NodePtr) -> Self {
523 Self { ptr: Some(ptr) }
524 }
525}
526
527impl Iterator for ParentsIter {
528 type Item = NodePtr;
529
530 fn next(&mut self) -> Option<Self::Item> {
531 if let Some(ptr) = self.ptr.take() {
532 self.ptr = ptr.borrow().parent.as_ref().and_then(WeakNodePtr::upgrade);
533 Some(ptr)
534 } else {
535 None
536 }
537 }
538}
539
/// Pretty-printing for refinement trees, cursors and scopes.
mod pretty {
    use std::fmt::{self, Write};

    use flux_middle::pretty::*;
    use pad_adapter::PadAdapter;

    use super::*;

    /// Collects the maximal chain of nested `ForAll` binders starting at `ptr`.
    ///
    /// Returns the collected `(name, sort)` pairs and the nodes that follow the chain. Those
    /// are either the children of the last binder (when it doesn't have exactly one child),
    /// or the first non-`ForAll` node reached.
    fn bindings_chain(ptr: &NodePtr) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut bindings: Vec<(Name, Sort)>) -> (Vec<(Name, Sort)>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::ForAll(name, sort) = &node.kind {
                bindings.push((*name, sort.clone()));
                // Keep following the chain only while there is a single child.
                if let [child] = &node.children[..] {
                    go(child, bindings)
                } else {
                    (bindings, node.children.clone())
                }
            } else {
                (bindings, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    /// Collects the maximal chain of nested `Assumption` nodes starting at `ptr`.
    ///
    /// Like `bindings_chain`, but for assumptions: returns the collected predicates and the
    /// nodes that follow the chain.
    fn preds_chain(ptr: &NodePtr) -> (Vec<Expr>, Vec<NodePtr>) {
        fn go(ptr: &NodePtr, mut preds: Vec<Expr>) -> (Vec<Expr>, Vec<NodePtr>) {
            let node = ptr.borrow();
            if let NodeKind::Assumption(pred) = &node.kind {
                preds.push(pred.clone());
                // Keep following the chain only while there is a single child.
                if let [child] = &node.children[..] {
                    go(child, preds)
                } else {
                    (preds, node.children.clone())
                }
            } else {
                (preds, vec![NodePtr::clone(ptr)])
            }
        }
        go(ptr, vec![])
    }

    /// A tree prints as its root node.
    impl Pretty for RefineTree {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            w!(cx, f, "{:?}", &self.root)
        }
    }

    /// Prints a node and, recursively, its subtree.
    ///
    /// Depending on the `bindings_chain` / `preds_chain` options of `cx`, runs of nested
    /// binders or assumptions are collapsed into a single line.
    impl Pretty for NodePtr {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let node = self.borrow();
            match &node.kind {
                NodeKind::Trace(trace) => {
                    w!(cx, f, "@ {:?}", ^trace)?;
                    w!(cx, with_padding(f), "\n{:?}", join!("\n", &node.children))
                }
                NodeKind::Root(bindings) => {
                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&node.children, cx, f)
                }
                NodeKind::ForAll(name, sort) => {
                    let (bindings, children) = if cx.bindings_chain {
                        bindings_chain(self)
                    } else {
                        (vec![(*name, sort.clone())], node.children.clone())
                    };

                    w!(cx, f,
                        "∀ {}.",
                        ^bindings
                            .into_iter()
                            .format_with(", ", |(name, sort), f| {
                                f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                            })
                    )?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Assumption(pred) => {
                    let (preds, children) = if cx.preds_chain {
                        preds_chain(self)
                    } else {
                        (vec![pred.clone()], node.children.clone())
                    };
                    // The collected predicates are printed as one simplified conjunction.
                    let guard = Expr::and_from_iter(preds).simplify(&SnapshotMap::default());
                    w!(cx, f, "{:?} =>", parens!(guard, !guard.is_atom()))?;
                    fmt_children(&children, cx, f)
                }
                NodeKind::Head(pred, tag) => {
                    let pred = if cx.simplify_exprs {
                        pred.simplify(&SnapshotMap::default())
                    } else {
                        pred.clone()
                    };
                    w!(cx, f, "{:?}", parens!(pred, !pred.is_atom()))?;
                    if cx.tags {
                        w!(cx, f, " ~ {:?}", tag)?;
                    }
                    Ok(())
                }
                NodeKind::True => {
                    w!(cx, f, "true")
                }
            }
        }
    }

    /// Writes the part that follows a binder or guard.
    ///
    /// With no children this prints ` true`. A single head child goes on the same line.
    /// Anything else is printed on new lines through the `with_padding` adapter.
    fn fmt_children(
        children: &[NodePtr],
        cx: &PrettyCx,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        match children {
            [] => w!(cx, f, " true"),
            [n] => {
                if n.borrow().is_head() {
                    w!(cx, f, " {:?}", n)
                } else {
                    w!(cx, with_padding(f), "\n{:?}", n)
                }
            }
            _ => w!(cx, with_padding(f), "\n{:?}", join!("\n", children)),
        }
    }

    /// Prints the context at a cursor as `{...}`: the bindings and assumptions on the path from
    /// the root to the cursor, outermost first.
    impl Pretty for Cursor<'_> {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let mut elements = vec![];
            // `ParentsIter` walks towards the root, so elements are gathered innermost-first
            // and reversed when printed.
            for node in ParentsIter::new(NodePtr::clone(&self.ptr)) {
                let n = node.borrow();
                match &n.kind {
                    NodeKind::Root(bindings) => {
                        // Pushed in reverse so the final reversal restores declaration order.
                        for (name, sort) in bindings.iter().rev() {
                            elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                        }
                    }
                    NodeKind::ForAll(name, sort) => {
                        elements.push(format_cx!(cx, "{:?}: {:?}", ^name, sort));
                    }
                    NodeKind::Assumption(pred) => {
                        elements.push(format_cx!(cx, "{:?}", pred));
                    }
                    _ => {}
                }
            }
            write!(f, "{{{}}}", elements.into_iter().rev().format(", "))
        }
    }

    /// Prints the bound names with their sorts, followed by the root parameters.
    impl Pretty for Scope {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(
                f,
                "[bindings = {}, reftgenerics = {}]",
                self.bindings
                    .iter_enumerated()
                    .format_with(", ", |(name, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^name, sort))
                    }),
                self.params
                    .iter()
                    .format_with(", ", |(param_const, sort), f| {
                        f(&format_args_cx!(cx, "{:?}: {:?}", ^param_const, sort))
                    }),
            )
        }
    }

    /// Wraps `f` in a `PadAdapter` so that output written through it is indented by the
    /// given padding.
    fn with_padding<'a, 'b>(f: &'a mut fmt::Formatter<'b>) -> PadAdapter<'a, 'b, 'static> {
        PadAdapter::with_padding(f, " ")
    }

    impl_debug_with_default_cx!(
        RefineTree => "refine_tree",
        Cursor<'_> => "cursor",
        Scope,
    );
}
725
726#[derive(Serialize, DebugAsJson)]
728pub struct RefineCtxtTrace {
729 bindings: Vec<RcxBind>,
730 exprs: Vec<String>,
731}
732
733#[derive(Serialize)]
734struct RcxBind {
735 name: String,
736 sort: String,
737}
738
739impl RefineCtxtTrace {
740 pub fn new(genv: GlobalEnv, cursor: &Cursor) -> Self {
741 let parents = ParentsIter::new(NodePtr::clone(&cursor.ptr)).collect_vec();
742 let mut bindings = vec![];
743 let mut exprs = vec![];
744 let cx = &PrettyCx::default(genv);
745
746 parents.into_iter().rev().for_each(|ptr| {
747 let node = ptr.borrow();
748 match &node.kind {
749 NodeKind::ForAll(name, sort) => {
750 let bind = RcxBind {
751 name: format_cx!(cx, "{:?}", ^name),
752 sort: format_cx!(cx, "{:?}", sort),
753 };
754 bindings.push(bind);
755 }
756 NodeKind::Assumption(e)
757 if !e.simplify(&SnapshotMap::default()).is_trivially_true() =>
758 {
759 let e = e.nested_string(cx);
760 exprs.push(e);
761 }
762 NodeKind::Root(binds) => {
763 for (name, sort) in binds {
764 let bind = RcxBind {
765 name: format_cx!(cx, "{:?}", name),
766 sort: format_cx!(cx, "{:?}", sort),
767 };
768 bindings.push(bind);
769 }
770 }
771 _ => (),
772 }
773 });
774 Self { bindings, exprs }
775 }
776}