1use std::{collections::hash_map::Entry, fmt, iter};
2
3use flux_common::{tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_middle::{
5 PlaceExt as _, def_id_to_string, global_env::GlobalEnv, queries::QueryResult, query_bug, rty,
6};
7use flux_rustc_bridge::{
8 mir::{
9 BasicBlock, Body, BorrowKind, FIRST_VARIANT, FieldIdx, Local, Location,
10 NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
11 StatementKind, Terminator, TerminatorKind, UnOp, VariantIdx,
12 },
13 ty::{AdtDef, GenericArgs, GenericArgsExt as _, List, Mutability, Ty, TyKind},
14};
15use itertools::{Itertools, repeat_n};
16use rustc_data_structures::unord::UnordMap;
17use rustc_hash::FxHashMap;
18use rustc_hir::def_id::DefId;
19use rustc_index::{Idx, IndexVec, bit_set::DenseBitSet};
20use rustc_middle::mir::{FakeReadCause, START_BLOCK};
21
22use super::{GhostStatements, StatementsAt};
23use crate::{
24 ghost_statements::{GhostStatement, Point},
25 queue::WorkQueue,
26};
27
28pub(crate) fn add_ghost_statements<'tcx>(
29 stmts: &mut GhostStatements,
30 genv: GlobalEnv<'_, 'tcx>,
31 body: &Body<'tcx>,
32 fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>,
33) -> QueryResult {
34 let mut bb_envs = FxHashMap::default();
35 FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Infer).run(fn_sig)?;
36
37 FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Elaboration { stmts }).run(fn_sig)
38}
39
/// Fold/unfold state of every local of a body.
///
/// Each local is mapped to a `PlaceNode` tree that describes how far that
/// local is currently unfolded.
#[derive(Clone)]
struct Env {
    map: IndexVec<Local, PlaceNode>,
}
44
45impl Env {
46 fn new(body: &Body) -> Self {
47 Self {
48 map: body
49 .local_decls
50 .iter()
51 .map(|decl| PlaceNode::Ty(decl.ty.clone()))
52 .collect(),
53 }
54 }
55
56 fn projection<'a>(&mut self, genv: GlobalEnv, place: &'a Place) -> QueryResult<ProjResult<'a>> {
57 let (node, place, modified) = self.ensure_unfolded(genv, place)?;
58 if modified {
59 Ok(ProjResult::Unfold(place))
60 } else if node.ensure_folded() {
61 Ok(ProjResult::Fold(place))
62 } else {
63 Ok(ProjResult::None)
64 }
65 }
66
67 fn downcast(&mut self, genv: GlobalEnv, place: &Place, variant_idx: VariantIdx) -> QueryResult {
68 let (node, ..) = self.ensure_unfolded(genv, place)?;
69 node.downcast(genv, variant_idx)?;
70 Ok(())
71 }
72
73 fn ensure_unfolded<'a>(
74 &mut self,
75 genv: GlobalEnv,
76 place: &'a Place,
77 ) -> QueryResult<(&mut PlaceNode, PlaceRef<'a>, Modified)> {
78 let mut node = &mut self.map[place.local];
79 let mut modified = false;
80 let mut i = 0;
81 while i < place.projection.len() {
82 let elem = place.projection[i];
83 let (n, m) = match elem {
84 PlaceElem::Deref => node.deref(),
85 PlaceElem::Field(f) => node.field(genv, f)?,
86 PlaceElem::Downcast(_, idx) => node.downcast(genv, idx)?,
87 PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } => break,
88 };
89 node = n;
90 modified |= m;
91 i += 1;
92 }
93 Ok((node, place.as_ref().truncate(i), modified))
94 }
95
96 fn join(&mut self, genv: GlobalEnv, mut other: Env) -> QueryResult<Modified> {
97 let mut modified = false;
98 for (local, node) in self.map.iter_enumerated_mut() {
99 let (m, _) = node.join(genv, &mut other.map[local], false)?;
100 modified |= m;
101 }
102 Ok(modified)
103 }
104
105 fn collect_fold_unfolds_at_goto(&self, target: &Env, stmts: &mut StatementsAt) {
106 for (local, node) in self.map.iter_enumerated() {
107 node.collect_fold_unfolds(&target.map[local], &mut Place::new(local, vec![]), stmts);
108 }
109 }
110
111 fn collect_folds_at_ret(&self, body: &Body, stmts: &mut StatementsAt) {
112 for local in body.args_iter() {
113 self.map[local].collect_folds_at_ret(&mut Place::new(local, vec![]), stmts);
114 }
115 }
116}
117
/// Whether an operation changed the fold/unfold state it was applied to.
type Modified = bool;
119
/// Worklist traversal of a MIR body that tracks the fold/unfold state of
/// places.
///
/// The mode `M` decides what happens at projections, at jumps into join
/// points, and at returns.
struct FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
    genv: GlobalEnv<'genv, 'tcx>,
    body: &'a Body<'tcx>,
    // Environment associated with each join point (shared between passes).
    bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
    // Blocks processed at least once by this pass.
    visited: DenseBitSet<BasicBlock>,
    // Join points waiting to be (re)processed.
    queue: WorkQueue<'a>,
    // Destination of a `Discriminant` rvalue -> place whose discriminant was
    // read, so that a later `SwitchInt` on it can downcast that place.
    discriminants: UnordMap<Place, Place>,
    // Program point where ghost statements are currently being inserted.
    point: Point,
    mode: M,
}
130
/// Strategy that specializes `FoldUnfoldAnalysis` for one of its passes.
trait Mode: Sized {
    /// Human-readable name of the mode.
    const _NAME: &'static str;

    /// Makes `place` accessible in `env` at the analysis' current point.
    fn projection(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        env: &mut Env,
        place: &Place,
    ) -> QueryResult;

    /// Handles a jump carrying `env` into the join point `target`.
    /// Returning `true` makes the analysis enqueue `target`.
    fn goto_join_point(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        target: BasicBlock,
        env: Env,
    ) -> QueryResult<bool>;

    /// Handles a `Return` terminator reached with environment `env`.
    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env);
}
148
/// Mode that only computes the environment at each join point (stored in
/// `bb_envs`) without emitting ghost statements.
struct Infer;
150
/// Mode that replays the analysis against the join-point environments
/// computed by `Infer` and records the required ghost statements in `stmts`.
struct Elaboration<'a> {
    stmts: &'a mut GhostStatements,
}
154
155impl Elaboration<'_> {
156 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
157 self.stmts.insert_at(point, stmt);
158 }
159}
160
/// Outcome of `Env::projection`, carrying the (possibly truncated) place that
/// was reached.
#[derive(Debug)]
enum ProjResult<'a> {
    /// The environment did not change.
    None,
    /// The node at the reached place was unfolded and has been folded back.
    Fold(PlaceRef<'a>),
    /// Reaching the place required unfolding at least one node on its path.
    Unfold(PlaceRef<'a>),
}
167
168impl Mode for Infer {
169 const _NAME: &'static str = "infer";
170
171 fn projection(
172 analysis: &mut FoldUnfoldAnalysis<Self>,
173 env: &mut Env,
174 place: &Place,
175 ) -> QueryResult {
176 env.projection(analysis.genv, place)?;
177 Ok(())
178 }
179
180 fn goto_join_point(
181 analysis: &mut FoldUnfoldAnalysis<Self>,
182 target: BasicBlock,
183 env: Env,
184 ) -> QueryResult<bool> {
185 let modified = match analysis.bb_envs.entry(target) {
186 Entry::Occupied(mut entry) => entry.get_mut().join(analysis.genv, env)?,
187 Entry::Vacant(entry) => {
188 entry.insert(env);
189 true
190 }
191 };
192 Ok(modified)
193 }
194
195 fn ret(_: &mut FoldUnfoldAnalysis<Self>, _: &Env) {}
196}
197
198impl Mode for Elaboration<'_> {
199 const _NAME: &'static str = "elaboration";
200
201 fn projection(
202 analysis: &mut FoldUnfoldAnalysis<Self>,
203 env: &mut Env,
204 place: &Place,
205 ) -> QueryResult {
206 match env.projection(analysis.genv, place)? {
207 ProjResult::None => {}
208 ProjResult::Fold(place_ref) => {
209 tracked_span_assert_eq!(place_ref, place.as_ref());
210 let place = place.clone();
211 analysis
212 .mode
213 .insert_at(analysis.point, GhostStatement::Fold(place));
214 }
215 ProjResult::Unfold(place_ref) => {
216 tracked_span_assert_eq!(place_ref, place.as_ref());
217 match place_ref.last_projection() {
218 Some((base, PlaceElem::Deref | PlaceElem::Field(..))) => {
219 analysis
220 .mode
221 .insert_at(analysis.point, GhostStatement::Unfold(base.to_place()));
222 }
223 _ => Err(query_bug!("invalid projection for unfolding {place_ref:?}"))?,
224 }
225 }
226 }
227 Ok(())
228 }
229
230 fn goto_join_point(
231 analysis: &mut FoldUnfoldAnalysis<Self>,
232 target: BasicBlock,
233 env: Env,
234 ) -> QueryResult<bool> {
235 env.collect_fold_unfolds_at_goto(
236 &analysis.bb_envs[&target],
237 &mut analysis.mode.stmts.at(analysis.point),
238 );
239 Ok(!analysis.visited.contains(target))
240 }
241
242 fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env) {
243 env.collect_folds_at_ret(analysis.body, &mut analysis.mode.stmts.at(analysis.point));
244 }
245}
246
/// Tree describing how far a place has been unfolded.
///
/// `Ty` leaves are folded places. Every other variant is an unfolded place
/// whose children describe its sub-places.
#[derive(Clone)]
enum PlaceNode {
    /// Unfolded pointer: the pointer type and the node for its pointee.
    Deref(Ty, Box<PlaceNode>),
    /// Unfolded ADT value. Carries four things:
    /// - the ADT;
    /// - its generic arguments;
    /// - the variant it was downcast to (`FIRST_VARIANT` for structs);
    /// - one node per field of that variant.
    Downcast(AdtDef, GenericArgs, VariantIdx, Vec<PlaceNode>),
    /// Unfolded closure: its `DefId`, generic arguments and one node per upvar.
    Closure(DefId, GenericArgs, Vec<PlaceNode>),
    /// Unfolded coroutine: its `DefId`, generic arguments and one node per upvar.
    Generator(DefId, GenericArgs, Vec<PlaceNode>),
    /// Unfolded tuple: the element types and one node per element.
    Tuple(List<Ty>, Vec<PlaceNode>),
    /// Folded place of the given type.
    Ty(Ty),
}
256
impl<M: Mode> FoldUnfoldAnalysis<'_, '_, '_, M> {
    /// Runs the analysis over the whole body.
    ///
    /// The entry environment has every local folded. When `fn_sig` is given,
    /// each argument whose declared type is a strong reference or a mutable
    /// reference additionally has `*arg` projected before the first block.
    /// Processing starts at `START_BLOCK`. Join points put on the work queue
    /// are then processed with the environment recorded for them in
    /// `bb_envs`, until the queue is empty.
    fn run(mut self, fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>) -> QueryResult {
        let mut env = Env::new(self.body);

        if let Some(fn_sig) = fn_sig {
            let fn_sig = fn_sig.as_ref().skip_binder().as_ref().skip_binder();
            for (local, ty) in iter::zip(self.body.args_iter(), fn_sig.inputs()) {
                // Strong and mutable references start out unfolded through
                // their pointer.
                if let rty::TyKind::StrgRef(..) | rty::Ref!(.., Mutability::Mut) = ty.kind() {
                    M::projection(&mut self, &mut env, &Place::new(local, vec![PlaceElem::Deref]))?;
                }
            }
        }
        self.goto(START_BLOCK, env)?;
        while let Some(bb) = self.queue.pop() {
            self.basic_block(bb, self.bb_envs[&bb].clone())?;
        }
        Ok(())
    }

    /// Processes block `bb` starting from `env`.
    ///
    /// Marks `bb` as visited, applies each statement, then the terminator.
    /// The environment produced for each successor is sent along the
    /// corresponding `Point::Edge`.
    fn basic_block(&mut self, bb: BasicBlock, mut env: Env) -> QueryResult {
        self.visited.insert(bb);
        let data = &self.body.basic_blocks[bb];
        for (statement_index, stmt) in data.statements.iter().enumerate() {
            self.point = Point::BeforeLocation(Location { block: bb, statement_index });
            self.statement(stmt, &mut env)?;
        }
        if let Some(terminator) = &data.terminator {
            self.point = Point::BeforeLocation(self.body.terminator_loc(bb));
            let successors = self.terminator(terminator, env)?;
            for (env, target) in successors {
                self.point = Point::Edge(bb, target);
                self.goto(target, env)?;
            }
        }
        Ok(())
    }

    /// Projects every place read or written by `stmt`.
    ///
    /// For assignments, the places read by the rvalue are handled first and
    /// the destination last. Each `Discriminant` read is remembered in
    /// `discriminants` under its destination.
    fn statement(&mut self, stmt: &Statement, env: &mut Env) -> QueryResult {
        match &stmt.kind {
            StatementKind::FakeRead(box (FakeReadCause::ForIndex, place)) => {
                M::projection(self, env, place)?;
            }
            StatementKind::Assign(place, rvalue) => {
                match rvalue {
                    // Reading pointer metadata projects the pointee `*place`.
                    Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(place))
                    | Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(place)) => {
                        let deref_place = place.deref();
                        M::projection(self, env, &deref_place)?;
                    }
                    Rvalue::Use(op)
                    | Rvalue::Cast(_, op, _)
                    | Rvalue::UnaryOp(_, op)
                    | Rvalue::ShallowInitBox(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::Ref(.., bk, place) => {
                        // Fake borrows are skipped and do not project the place.
                        if !matches!(bk, BorrowKind::Fake(_)) {
                            M::projection(self, env, place)?;
                        }
                    }
                    Rvalue::RawPtr(_, place) => {
                        M::projection(self, env, place)?;
                    }
                    Rvalue::BinaryOp(_, op1, op2) => {
                        self.operand(op1, env)?;
                        self.operand(op2, env)?;
                    }
                    Rvalue::Aggregate(_, args) => {
                        for arg in args {
                            self.operand(arg, env)?;
                        }
                    }

                    Rvalue::Discriminant(discr) => {
                        M::projection(self, env, discr)?;
                        // Remembered so a `SwitchInt` on `place` can downcast `discr`.
                        self.discriminants.insert(place.clone(), discr.clone());
                    }
                    Rvalue::Repeat(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::NullaryOp(_, _) => {}
                }
                // `place` here is the assignment destination.
                M::projection(self, env, place)?;
            }
            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(op)) => {
                self.operand(op, env)?;
            }
            StatementKind::SetDiscriminant(_, _)
            | StatementKind::FakeRead(_)
            | StatementKind::AscribeUserType(_, _)
            | StatementKind::PlaceMention(_)
            | StatementKind::Nop => {}
        }
        Ok(())
    }

    /// Projects the place of a `Copy`/`Move` operand; constants are ignored.
    fn operand(&mut self, op: &Operand, env: &mut Env) -> QueryResult {
        match op {
            Operand::Copy(place) | Operand::Move(place) => {
                M::projection(self, env, place)?;
            }
            Operand::Constant(_) => {}
        }
        Ok(())
    }

    /// Applies `terminator` to `env` and returns one `(env, target)` pair per
    /// successor that should be visited.
    fn terminator(
        &mut self,
        terminator: &Terminator,
        mut env: Env,
    ) -> QueryResult<Vec<(Env, BasicBlock)>> {
        let mut successors = vec![];
        match &terminator.kind {
            TerminatorKind::Return => {
                M::ret(self, &env);
            }
            TerminatorKind::Call { args, destination, target, .. } => {
                for arg in args {
                    self.operand(arg, &mut env)?;
                }
                M::projection(self, &mut env, destination)?;
                if let Some(target) = target {
                    successors.push((env, *target));
                }
            }
            TerminatorKind::SwitchInt { discr, targets } => {
                // `is_match` is the place whose discriminant was read into
                // `discr` (if `discr` came from a `Discriminant` rvalue).
                let is_match = match discr {
                    Operand::Copy(place) | Operand::Move(place) => {
                        M::projection(self, &mut env, place)?;
                        self.discriminants.remove(place)
                    }
                    Operand::Constant(_) => None,
                };
                if let Some(place) = is_match {
                    let discr_ty = place.ty(self.genv, &self.body.local_decls)?.ty;
                    let (adt, _) = discr_ty.expect_adt();

                    // Variants not yet claimed by an explicit target, keyed by
                    // discriminant value.
                    let mut remaining: FxHashMap<u128, VariantIdx> = adt
                        .discriminants()
                        .map(|(idx, discr)| (discr, idx))
                        .collect();
                    for (bits, target) in targets.iter() {
                        let variant_idx = remaining
                            .remove(&bits)
                            .expect("value doesn't correspond to any variant");

                        // Each target gets its own environment with `place`
                        // downcast to the variant matching `bits`.
                        let mut env = env.clone();
                        env.downcast(self.genv, &place, variant_idx)?;
                        successors.push((env, target));
                    }
                    // If exactly one variant is left, the `otherwise` edge can
                    // only be taken for that variant.
                    if remaining.len() == 1 {
                        let (_, variant_idx) = remaining
                            .into_iter()
                            .next()
                            .unwrap_or_else(|| tracked_span_bug!());
                        env.downcast(self.genv, &place, variant_idx)?;
                    }
                    successors.push((env, targets.otherwise()));
                } else {
                    let n = targets.all_targets().len();
                    for (env, target) in iter::zip(repeat_n(env, n), targets.all_targets()) {
                        successors.push((env, *target));
                    }
                }
            }
            TerminatorKind::Goto { target } => {
                successors.push((env, *target));
            }
            TerminatorKind::Yield { resume, resume_arg, .. } => {
                M::projection(self, &mut env, resume_arg)?;
                successors.push((env, *resume));
            }
            TerminatorKind::Drop { place, target, .. } => {
                M::projection(self, &mut env, place)?;
                successors.push((env, *target));
            }
            TerminatorKind::Assert { cond, target, .. } => {
                self.operand(cond, &mut env)?;
                successors.push((env, *target));
            }
            TerminatorKind::FalseEdge { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::FalseUnwind { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::Unreachable
            | TerminatorKind::UnwindResume
            | TerminatorKind::CoroutineDrop => {}
        }
        Ok(successors)
    }

    /// Transfers `env` to `target`.
    ///
    /// At a join point the mode decides, and `target` is enqueued if the mode
    /// returns `true`. Any other block is processed immediately (recursively)
    /// with `env`.
    fn goto(&mut self, target: BasicBlock, env: Env) -> QueryResult {
        if self.body.is_join_point(target) {
            if M::goto_join_point(self, target, env)? {
                self.queue.insert(target);
            }
            Ok(())
        } else {
            self.basic_block(target, env)
        }
    }
}
464
465impl<'a, 'genv, 'tcx, M> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
466 pub(crate) fn new(
467 genv: GlobalEnv<'genv, 'tcx>,
468 body: &'a Body<'tcx>,
469 bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
470 mode: M,
471 ) -> Self {
472 Self {
473 genv,
474 body,
475 bb_envs,
476 discriminants: Default::default(),
477 point: Point::FunEntry,
478 visited: DenseBitSet::new_empty(body.basic_blocks.len()),
479 queue: WorkQueue::empty(body.basic_blocks.len(), &body.dominator_order_rank),
480 mode,
481 }
482 }
483}
484
485impl PlaceNode {
486 fn deref(&mut self) -> (&mut PlaceNode, Modified) {
487 match self {
488 PlaceNode::Deref(_, node) => (node, false),
489 PlaceNode::Ty(ty) => {
490 *self = PlaceNode::Deref(ty.clone(), Box::new(PlaceNode::Ty(ty.deref())));
491 let PlaceNode::Deref(_, node) = self else { unreachable!() };
492 (node, true)
493 }
494 _ => tracked_span_bug!("deref of non-deref place: `{:?}`", self),
495 }
496 }
497
498 fn downcast(
499 &mut self,
500 genv: GlobalEnv,
501 idx: VariantIdx,
502 ) -> QueryResult<(&mut PlaceNode, Modified)> {
503 match self {
504 PlaceNode::Downcast(.., idx2, _) => {
505 debug_assert_eq!(idx, *idx2);
506 Ok((self, false))
507 }
508 PlaceNode::Ty(ty) => {
509 if let TyKind::Adt(adt_def, args) = ty.kind() {
510 let fields = downcast(genv, adt_def, args, idx)?;
511 *self = PlaceNode::Downcast(adt_def.clone(), args.clone(), idx, fields);
512 Ok((self, true))
513 } else {
514 tracked_span_bug!("invalid downcast `{self:?}`");
515 }
516 }
517 _ => tracked_span_bug!("invalid downcast `{self:?}`"),
518 }
519 }
520
521 fn field(&mut self, genv: GlobalEnv, f: FieldIdx) -> QueryResult<(&mut PlaceNode, Modified)> {
522 let (fields, unfolded) = self.fields(genv)?;
523 Ok((&mut fields[f.as_usize()], unfolded))
524 }
525
526 fn fields(&mut self, genv: GlobalEnv) -> QueryResult<(&mut Vec<PlaceNode>, bool)> {
527 match self {
528 PlaceNode::Ty(ty) => {
529 let fields = match ty.kind() {
530 TyKind::Adt(adt_def, args) => {
531 let fields = downcast_struct(genv, adt_def, args)?;
532 *self = PlaceNode::Downcast(
533 adt_def.clone(),
534 args.clone(),
535 FIRST_VARIANT,
536 fields,
537 );
538 let PlaceNode::Downcast(.., fields) = self else { unreachable!() };
539 fields
540 }
541 TyKind::Closure(def_id, args) => {
542 let fields = args
543 .as_closure()
544 .upvar_tys()
545 .iter()
546 .cloned()
547 .map(PlaceNode::Ty)
548 .collect_vec();
549 *self = PlaceNode::Closure(*def_id, args.clone(), fields);
550 let PlaceNode::Closure(.., fields) = self else { unreachable!() };
551 fields
552 }
553 TyKind::Tuple(fields) => {
554 let node_fields = fields.iter().cloned().map(PlaceNode::Ty).collect();
555 *self = PlaceNode::Tuple(fields.clone(), node_fields);
556 let PlaceNode::Tuple(.., fields) = self else { unreachable!() };
557 fields
558 }
559 TyKind::Coroutine(def_id, args) => {
560 let fields = args
561 .as_coroutine()
562 .upvar_tys()
563 .cloned()
564 .map(PlaceNode::Ty)
565 .collect_vec();
566 *self = PlaceNode::Generator(*def_id, args.clone(), fields);
567 let PlaceNode::Generator(.., fields) = self else { unreachable!() };
568 fields
569 }
570 _ => tracked_span_bug!("implicit downcast of non-struct: `{ty:?}`"),
571 };
572 Ok((fields, true))
573 }
574 PlaceNode::Downcast(.., fields)
575 | PlaceNode::Tuple(.., fields)
576 | PlaceNode::Closure(.., fields)
577 | PlaceNode::Generator(.., fields) => Ok((fields, false)),
578 PlaceNode::Deref(..) => {
579 tracked_span_bug!("projection field of non-adt non-tuple place: `{self:?}`")
580 }
581 }
582 }
583
584 fn ensure_folded(&mut self) -> Modified {
585 match self {
586 PlaceNode::Deref(ty, _) => {
587 *self = PlaceNode::Ty(ty.clone());
588 true
589 }
590 PlaceNode::Downcast(adt, args, ..) => {
591 *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
592 true
593 }
594 PlaceNode::Closure(did, args, _) => {
595 *self = PlaceNode::Ty(Ty::mk_closure(*did, args.clone()));
596 true
597 }
598 PlaceNode::Generator(did, args, _) => {
599 *self = PlaceNode::Ty(Ty::mk_coroutine(*did, args.clone()));
600 true
601 }
602 PlaceNode::Tuple(fields, ..) => {
603 *self = PlaceNode::Ty(Ty::mk_tuple(fields.clone()));
604 true
605 }
606 PlaceNode::Ty(_) => false,
607 }
608 }
609
610 fn join(
611 &mut self,
612 genv: GlobalEnv,
613 other: &mut PlaceNode,
614 in_mut_ref: bool,
615 ) -> QueryResult<(bool, bool)> {
616 let mut modified1 = false;
617 let mut modified2 = false;
618
619 let (fields1, fields2) = match (&mut *self, &mut *other) {
620 (PlaceNode::Deref(ty1, node1), PlaceNode::Deref(ty2, node2)) => {
621 debug_assert_eq!(ty1, ty2);
622 return node1.join(genv, node2, in_mut_ref || ty1.is_mut_ref());
623 }
624 (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
625 (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2)) => {
626 (fields1, fields2)
627 }
628 (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
629 (fields1, fields2)
630 }
631 (
632 PlaceNode::Downcast(adt1, args1, variant1, fields1),
633 PlaceNode::Downcast(adt2, args2, variant2, fields2),
634 ) => {
635 debug_assert_eq!(adt1, adt2);
636 if variant1 == variant2 {
637 (fields1, fields2)
638 } else {
639 *self = PlaceNode::Ty(Ty::mk_adt(adt1.clone(), args1.clone()));
640 *other = PlaceNode::Ty(Ty::mk_adt(adt2.clone(), args2.clone()));
641 return Ok((true, true));
642 }
643 }
644 (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return Ok((false, false)),
645 (PlaceNode::Ty(_), _) => {
646 let (m1, m2) = other.join(genv, self, in_mut_ref)?;
647 return Ok((m2, m1));
648 }
649 (PlaceNode::Deref(ty, _), _) => {
650 *self = PlaceNode::Ty(ty.clone());
651 return Ok((true, false));
652 }
653 (PlaceNode::Tuple(_, fields1), _) => {
654 let (fields2, m) = other.fields(genv)?;
655 modified2 |= m;
656 (fields1, fields2)
657 }
658 (PlaceNode::Closure(.., fields1), _) | (PlaceNode::Generator(.., fields1), _) => {
659 let (fields2, m) = other.fields(genv)?;
660 modified2 |= m;
661 (fields1, fields2)
662 }
663
664 (PlaceNode::Downcast(adt, args, .., fields1), _) => {
665 if adt.is_struct() && !in_mut_ref {
666 let (fields2, m) = other.fields(genv)?;
667 modified2 |= m;
668 (fields1, fields2)
669 } else {
670 *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
671 return Ok((true, false));
672 }
673 }
674 };
675 for (node1, node2) in iter::zip(fields1, fields2) {
676 let (m1, m2) = node1.join(genv, node2, in_mut_ref)?;
677 modified1 |= m1;
678 modified2 |= m2;
679 }
680 Ok((modified1, modified2))
681 }
682
683 fn collect_fold_unfolds(
685 &self,
686 target: &PlaceNode,
687 place: &mut Place,
688 stmts: &mut StatementsAt,
689 ) {
690 let (fields1, fields2) = match (self, target) {
691 (PlaceNode::Deref(_, node1), PlaceNode::Deref(_, node2)) => {
692 place.projection.push(PlaceElem::Deref);
693 node1.collect_fold_unfolds(node2, place, stmts);
694 place.projection.pop();
695 return;
696 }
697 (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
698 (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2))
699 | (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
700 (fields1, fields2)
701 }
702 (
703 PlaceNode::Downcast(adt1, .., idx1, fields1),
704 PlaceNode::Downcast(adt2, .., idx2, fields2),
705 ) => {
706 tracked_span_dbg_assert_eq!(adt1.did(), adt2.did());
707 tracked_span_dbg_assert_eq!(idx1, idx2);
708 (fields1, fields2)
709 }
710 (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return,
711 (PlaceNode::Ty(_), _) => {
712 target.collect_unfolds(place, stmts);
713 return;
714 }
715 (_, PlaceNode::Ty(_)) => {
716 stmts.insert(GhostStatement::Fold(place.clone()));
717 return;
718 }
719 _ => tracked_span_bug!("{self:?} {target:?}"),
720 };
721 for (i, (node1, node2)) in iter::zip(fields1, fields2).enumerate() {
722 place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
723 node1.collect_fold_unfolds(node2, place, stmts);
724 place.projection.pop();
725 }
726 }
727
728 fn collect_unfolds(&self, place: &mut Place, stmts: &mut StatementsAt) {
729 match self {
730 PlaceNode::Ty(_) => {}
731 PlaceNode::Deref(_, node) => {
732 if node.is_ty() {
733 stmts.insert(GhostStatement::Unfold(place.clone()));
734 } else {
735 place.projection.push(PlaceElem::Deref);
736 node.collect_unfolds(place, stmts);
737 place.projection.pop();
738 }
739 }
740 PlaceNode::Downcast(.., fields)
741 | PlaceNode::Closure(.., fields)
742 | PlaceNode::Generator(.., fields)
743 | PlaceNode::Tuple(.., fields) => {
744 let all_leaves = fields.iter().all(PlaceNode::is_ty);
745 if all_leaves {
746 stmts.insert(GhostStatement::Unfold(place.clone()));
747 } else {
748 if let Some(idx) = self.enum_variant() {
749 place.projection.push(PlaceElem::Downcast(None, idx));
750 }
751 for (i, node) in fields.iter().enumerate() {
752 place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
753 node.collect_unfolds(place, stmts);
754 place.projection.pop();
755 }
756 if self.enum_variant().is_some() {
757 place.projection.pop();
758 }
759 }
760 }
761 }
762 }
763
764 fn collect_folds_at_ret(&self, place: &mut Place, stmts: &mut StatementsAt) {
765 let fields = match self {
766 PlaceNode::Deref(ty, deref_ty) => {
767 place.projection.push(PlaceElem::Deref);
768 if ty.is_mut_ref() {
769 stmts.insert(GhostStatement::Fold(place.clone()));
770 } else if ty.is_box() {
771 deref_ty.collect_folds_at_ret(place, stmts);
772 }
773 place.projection.pop();
774 return;
775 }
776 PlaceNode::Downcast(adt, _, idx, fields) => {
777 if adt.is_enum() {
778 place.projection.push(PlaceElem::Downcast(None, *idx));
779 }
780 fields
781 }
782 PlaceNode::Closure(_, _, fields)
783 | PlaceNode::Generator(_, _, fields)
784 | PlaceNode::Tuple(_, fields) => fields,
785 PlaceNode::Ty(_) => return,
786 };
787 for (i, node) in fields.iter().enumerate() {
788 place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
789 node.collect_folds_at_ret(place, stmts);
790 place.projection.pop();
791 }
792 if let PlaceNode::Downcast(adt, ..) = self
793 && adt.is_enum()
794 {
795 place.projection.pop();
796 }
797 }
798
799 fn enum_variant(&self) -> Option<VariantIdx> {
800 if let PlaceNode::Downcast(adt, _, idx, _) = self
801 && adt.is_enum()
802 {
803 Some(*idx)
804 } else {
805 None
806 }
807 }
808
809 #[must_use]
813 fn is_ty(&self) -> bool {
814 matches!(self, Self::Ty(..))
815 }
816}
817
818fn downcast(
819 genv: GlobalEnv,
820 adt_def: &AdtDef,
821 args: &GenericArgs,
822 variant: VariantIdx,
823) -> QueryResult<Vec<PlaceNode>> {
824 adt_def
825 .variant(variant)
826 .fields
827 .iter()
828 .map(|field| {
829 let ty = genv.lower_type_of(field.did)?.subst(args);
830 QueryResult::Ok(PlaceNode::Ty(ty))
831 })
832 .try_collect()
833}
834
835fn downcast_struct(
836 genv: GlobalEnv,
837 adt_def: &AdtDef,
838 args: &GenericArgs,
839) -> QueryResult<Vec<PlaceNode>> {
840 adt_def
841 .non_enum_variant()
842 .fields
843 .iter()
844 .map(|field| {
845 let ty = genv.lower_type_of(field.did)?.subst(args);
846 QueryResult::Ok(PlaceNode::Ty(ty))
847 })
848 .try_collect()
849}
850
851impl fmt::Debug for Env {
852 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
853 write!(
854 f,
855 "{}",
856 self.map
857 .iter_enumerated()
858 .format_with(", ", |(local, node), f| f(&format_args!("{local:?}: {node:?}")))
859 )
860 }
861}
862
863impl fmt::Debug for PlaceNode {
864 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
865 match self {
866 PlaceNode::Deref(_, node) => write!(f, "*({node:?})"),
867 PlaceNode::Downcast(adt, args, variant, fields) => {
868 write!(f, "{}", def_id_to_string(adt.did()))?;
869 if !args.is_empty() {
870 write!(f, "<{:?}>", args.iter().format(", "),)?;
871 }
872 write!(f, "::{}", adt.variant(*variant).name)?;
873 if !fields.is_empty() {
874 write!(f, "({:?})", fields.iter().format(", "),)?;
875 }
876 Ok(())
877 }
878 PlaceNode::Closure(did, args, fields) => {
879 write!(f, "Closure {}", def_id_to_string(*did))?;
880 if !args.is_empty() {
881 write!(f, "<{:?}>", args.iter().format(", "),)?;
882 }
883 if !fields.is_empty() {
884 write!(f, "({:?})", fields.iter().format(", "),)?;
885 }
886 Ok(())
887 }
888 PlaceNode::Generator(did, args, fields) => {
889 write!(f, "Generator {}", def_id_to_string(*did))?;
890 if !args.is_empty() {
891 write!(f, "<{:?}>", args.iter().format(", "),)?;
892 }
893 if !fields.is_empty() {
894 write!(f, "({:?})", fields.iter().format(", "),)?;
895 }
896 Ok(())
897 }
898 PlaceNode::Tuple(_, fields) => write!(f, "({:?})", fields.iter().format(", ")),
899 PlaceNode::Ty(ty) => write!(f, "•{ty:?}"),
900 }
901 }
902}