1use std::{collections::hash_map::Entry, fmt, iter};
2
3use flux_common::{tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_middle::{
5 PlaceExt as _, def_id_to_string, global_env::GlobalEnv, queries::QueryResult, query_bug, rty,
6};
7use flux_rustc_bridge::{
8 mir::{
9 BasicBlock, Body, BorrowKind, FIRST_VARIANT, FieldIdx, Local, Location,
10 NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
11 StatementKind, Terminator, TerminatorKind, UnOp, VariantIdx,
12 },
13 ty::{AdtDef, GenericArgs, GenericArgsExt as _, List, Mutability, Ty, TyKind},
14};
15use itertools::{Itertools, repeat_n};
16use rustc_data_structures::{fx::FxHashMap, unord::UnordMap};
17use rustc_hir::def_id::DefId;
18use rustc_index::{Idx, IndexVec, bit_set::DenseBitSet};
19use rustc_middle::mir::{FakeReadCause, START_BLOCK};
20
21use super::{GhostStatements, StatementsAt};
22use crate::{
23 ghost_statements::{GhostStatement, Point},
24 queue::WorkQueue,
25};
26
27pub(crate) fn add_ghost_statements<'tcx>(
28 stmts: &mut GhostStatements,
29 genv: GlobalEnv<'_, 'tcx>,
30 body: &Body<'tcx>,
31 fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>,
32) -> QueryResult {
33 let mut bb_envs = UnordMap::default();
34 FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Infer).run(fn_sig)?;
35
36 FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Elaboration { stmts }).run(fn_sig)
37}
38
/// Fold/unfold state of every local of a body: one place tree per local, indexed by
/// [`Local`].
#[derive(Clone)]
struct Env {
    map: IndexVec<Local, PlaceNode>,
}
43
44impl Env {
45 fn new(body: &Body) -> Self {
46 Self {
47 map: body
48 .local_decls
49 .iter()
50 .map(|decl| PlaceNode::Ty(decl.ty.clone()))
51 .collect(),
52 }
53 }
54
55 fn projection<'a>(&mut self, genv: GlobalEnv, place: &'a Place) -> QueryResult<ProjResult<'a>> {
56 let (node, place, modified) = self.ensure_unfolded(genv, place)?;
57 if modified {
58 Ok(ProjResult::Unfold(place))
59 } else if node.ensure_folded() {
60 Ok(ProjResult::Fold(place))
61 } else {
62 Ok(ProjResult::None)
63 }
64 }
65
66 fn downcast(&mut self, genv: GlobalEnv, place: &Place, variant_idx: VariantIdx) -> QueryResult {
67 let (node, ..) = self.ensure_unfolded(genv, place)?;
68 node.downcast(genv, variant_idx)?;
69 Ok(())
70 }
71
72 fn ensure_unfolded<'a>(
73 &mut self,
74 genv: GlobalEnv,
75 place: &'a Place,
76 ) -> QueryResult<(&mut PlaceNode, PlaceRef<'a>, Modified)> {
77 let mut node = &mut self.map[place.local];
78 let mut modified = false;
79 let mut i = 0;
80 while i < place.projection.len() {
81 let elem = place.projection[i];
82 let (n, m) = match elem {
83 PlaceElem::Deref => node.deref(),
84 PlaceElem::Field(f) => node.field(genv, f)?,
85 PlaceElem::Downcast(_, idx) => node.downcast(genv, idx)?,
86 PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } => break,
87 };
88 node = n;
89 modified |= m;
90 i += 1;
91 }
92 Ok((node, place.as_ref().truncate(i), modified))
93 }
94
95 fn join(&mut self, genv: GlobalEnv, mut other: Env) -> QueryResult<Modified> {
96 let mut modified = false;
97 for (local, node) in self.map.iter_enumerated_mut() {
98 let (m, _) = node.join(genv, &mut other.map[local], false)?;
99 modified |= m;
100 }
101 Ok(modified)
102 }
103
104 fn collect_fold_unfolds_at_goto(&self, target: &Env, stmts: &mut StatementsAt) {
105 for (local, node) in self.map.iter_enumerated() {
106 node.collect_fold_unfolds(&target.map[local], &mut Place::new(local, vec![]), stmts);
107 }
108 }
109
110 fn collect_folds_at_ret(&self, body: &Body, stmts: &mut StatementsAt) {
111 for local in body.args_iter() {
112 self.map[local].collect_folds_at_ret(&mut Place::new(local, vec![]), stmts);
113 }
114 }
115}
116
/// Whether an operation changed the shape (fold/unfold state) of a place tree.
type Modified = bool;
118
/// State of one fold/unfold pass over a MIR body, parameterized by the pass `mode`
/// ([`Infer`] or [`Elaboration`]).
struct FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
    genv: GlobalEnv<'genv, 'tcx>,
    body: &'a Body<'tcx>,
    /// Environment at the entry of join-point blocks. The inference pass inserts and
    /// joins into it; both passes read it when a queued block is processed.
    bb_envs: &'a mut UnordMap<BasicBlock, Env>,
    /// Blocks processed so far in this pass.
    visited: DenseBitSet<BasicBlock>,
    /// Join-point blocks waiting to be (re)processed.
    queue: WorkQueue<'a>,
    /// Maps the destination of a `Discriminant` rvalue to the place whose discriminant
    /// was read, so that a later `SwitchInt` on that destination can downcast it.
    discriminants: UnordMap<Place, Place>,
    /// Current program point. The elaboration pass inserts ghost statements here.
    point: Point,
    mode: M,
}
129
/// Behavior that differs between the two passes of the fold/unfold analysis.
trait Mode: Sized {
    /// Name of the mode.
    /// NOTE(review): not referenced in this file (hence the leading underscore).
    const _NAME: &'static str;

    /// Called for every place the analysis accesses. Updates `env` so that `place` can
    /// be accessed.
    fn projection(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        env: &mut Env,
        place: &Place,
    ) -> QueryResult;

    /// Called when control flows into the join point `target` with environment `env`.
    /// Returns whether `target` must be added to the work queue.
    fn goto_join_point(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        target: BasicBlock,
        env: Env,
    ) -> QueryResult<bool>;

    /// Called at a `Return` terminator with the environment at that point.
    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env);
}
147
/// First pass: computes the environment at each join point; emits no statements.
struct Infer;
149
/// Second pass: replays the body against the inferred join-point environments and
/// records the resulting fold/unfold ghost statements.
struct Elaboration<'a> {
    /// Destination for the emitted ghost statements.
    stmts: &'a mut GhostStatements,
}
153
154impl Elaboration<'_> {
155 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
156 self.stmts.insert_at(point, stmt);
157 }
158}
159
/// Outcome of [`Env::projection`]. The carried place is the prefix of the accessed
/// place that was actually walked.
#[derive(Debug)]
enum ProjResult<'a> {
    /// The environment did not change.
    None,
    /// The node at the place was unfolded and has been folded back into a leaf.
    Fold(PlaceRef<'a>),
    /// Reaching the place required unfolding at least one node.
    Unfold(PlaceRef<'a>),
}
166
167impl Mode for Infer {
168 const _NAME: &'static str = "infer";
169
170 fn projection(
171 analysis: &mut FoldUnfoldAnalysis<Self>,
172 env: &mut Env,
173 place: &Place,
174 ) -> QueryResult {
175 env.projection(analysis.genv, place)?;
176 Ok(())
177 }
178
179 fn goto_join_point(
180 analysis: &mut FoldUnfoldAnalysis<Self>,
181 target: BasicBlock,
182 env: Env,
183 ) -> QueryResult<bool> {
184 let modified = match analysis.bb_envs.entry(target) {
185 Entry::Occupied(mut entry) => entry.get_mut().join(analysis.genv, env)?,
186 Entry::Vacant(entry) => {
187 entry.insert(env);
188 true
189 }
190 };
191 Ok(modified)
192 }
193
194 fn ret(_: &mut FoldUnfoldAnalysis<Self>, _: &Env) {}
195}
196
197impl Mode for Elaboration<'_> {
198 const _NAME: &'static str = "elaboration";
199
200 fn projection(
201 analysis: &mut FoldUnfoldAnalysis<Self>,
202 env: &mut Env,
203 place: &Place,
204 ) -> QueryResult {
205 match env.projection(analysis.genv, place)? {
206 ProjResult::None => {}
207 ProjResult::Fold(place_ref) => {
208 tracked_span_assert_eq!(place_ref, place.as_ref());
209 let place = place.clone();
210 analysis
211 .mode
212 .insert_at(analysis.point, GhostStatement::Fold(place));
213 }
214 ProjResult::Unfold(place_ref) => {
215 tracked_span_assert_eq!(place_ref, place.as_ref());
216 match place_ref.last_projection() {
217 Some((base, PlaceElem::Deref | PlaceElem::Field(..))) => {
218 analysis
219 .mode
220 .insert_at(analysis.point, GhostStatement::Unfold(base.to_place()));
221 }
222 _ => Err(query_bug!("invalid projection for unfolding {place_ref:?}"))?,
223 }
224 }
225 }
226 Ok(())
227 }
228
229 fn goto_join_point(
230 analysis: &mut FoldUnfoldAnalysis<Self>,
231 target: BasicBlock,
232 env: Env,
233 ) -> QueryResult<bool> {
234 env.collect_fold_unfolds_at_goto(
235 &analysis.bb_envs[&target],
236 &mut analysis.mode.stmts.at(analysis.point),
237 );
238 Ok(!analysis.visited.contains(target))
239 }
240
241 fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env) {
242 env.collect_folds_at_ret(analysis.body, &mut analysis.mode.stmts.at(analysis.point));
243 }
244}
245
/// A node of a place tree. `Ty` is a folded leaf; every other variant is an unfolded
/// node whose children describe the sub-places.
#[derive(Clone)]
enum PlaceNode {
    /// An unfolded pointer: the pointer type and the node for its pointee.
    Deref(Ty, Box<PlaceNode>),
    /// An unfolded ADT at a given variant (structs use `FIRST_VARIANT`), with one child
    /// per field of that variant.
    Downcast(AdtDef, GenericArgs, VariantIdx, Vec<PlaceNode>),
    /// An unfolded closure, with one child per upvar.
    Closure(DefId, GenericArgs, Vec<PlaceNode>),
    /// An unfolded coroutine, with one child per upvar.
    Generator(DefId, GenericArgs, Vec<PlaceNode>),
    /// An unfolded tuple: the element types and one child per element.
    Tuple(List<Ty>, Vec<PlaceNode>),
    /// A folded leaf of the given type.
    Ty(Ty),
}
255
impl<M: Mode> FoldUnfoldAnalysis<'_, '_, '_, M> {
    /// Runs this pass over the whole body.
    ///
    /// 1. If `fn_sig` is given, every argument whose input type is a `StrgRef` or a
    ///    mutable reference has its pointee (`*arg`) projected at function entry.
    /// 2. The start block is processed.
    /// 3. Join-point blocks are popped from the work queue and processed with the
    ///    environment stored for them in `bb_envs`, until the queue is empty.
    fn run(mut self, fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>) -> QueryResult {
        let mut env = Env::new(self.body);

        if let Some(fn_sig) = fn_sig {
            let fn_sig = fn_sig.as_ref().skip_binder().as_ref().skip_binder();
            for (local, ty) in iter::zip(self.body.args_iter(), fn_sig.inputs()) {
                if let rty::TyKind::StrgRef(..) | rty::Ref!(.., Mutability::Mut) = ty.kind() {
                    M::projection(&mut self, &mut env, &Place::new(local, vec![PlaceElem::Deref]))?;
                }
            }
        }
        self.goto(START_BLOCK, env)?;
        while let Some(bb) = self.queue.pop() {
            self.basic_block(bb, self.bb_envs[&bb].clone())?;
        }
        Ok(())
    }

    /// Processes block `bb` starting from `env`.
    ///
    /// Marks `bb` as visited, then processes each statement with `self.point` set just
    /// before it, then the terminator. Each successor edge is followed with
    /// `self.point` set to that edge.
    fn basic_block(&mut self, bb: BasicBlock, mut env: Env) -> QueryResult {
        self.visited.insert(bb);
        let data = &self.body.basic_blocks[bb];
        for (statement_index, stmt) in data.statements.iter().enumerate() {
            self.point = Point::BeforeLocation(Location { block: bb, statement_index });
            self.statement(stmt, &mut env)?;
        }
        if let Some(terminator) = &data.terminator {
            self.point = Point::BeforeLocation(self.body.terminator_loc(bb));
            let successors = self.terminator(terminator, env)?;
            for (env, target) in successors {
                self.point = Point::Edge(bb, target);
                self.goto(target, env)?;
            }
        }
        Ok(())
    }

    /// Processes one statement, projecting every place it accesses in `env`.
    fn statement(&mut self, stmt: &Statement, env: &mut Env) -> QueryResult {
        match &stmt.kind {
            StatementKind::FakeRead(box (FakeReadCause::ForIndex, place)) => {
                M::projection(self, env, place)?;
            }
            StatementKind::Assign(place, rvalue) => {
                // Places used by the rvalue are projected first; the destination last.
                match rvalue {
                    // Taking the metadata of a pointer projects its pointee (`*place`).
                    Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(place))
                    | Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(place)) => {
                        let deref_place = place.deref();
                        M::projection(self, env, &deref_place)?;
                    }
                    Rvalue::Use(op)
                    | Rvalue::Cast(_, op, _)
                    | Rvalue::UnaryOp(_, op)
                    | Rvalue::ShallowInitBox(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::Ref(.., bk, place) => {
                        // Fake borrows leave the environment untouched.
                        if !matches!(bk, BorrowKind::Fake(_)) {
                            M::projection(self, env, place)?;
                        }
                    }
                    Rvalue::RawPtr(_, place) => {
                        M::projection(self, env, place)?;
                    }
                    Rvalue::BinaryOp(_, op1, op2) => {
                        self.operand(op1, env)?;
                        self.operand(op2, env)?;
                    }
                    Rvalue::Aggregate(_, args) => {
                        for arg in args {
                            self.operand(arg, env)?;
                        }
                    }

                    Rvalue::Discriminant(discr) => {
                        M::projection(self, env, discr)?;
                        // Remember that the destination holds the discriminant of `discr`,
                        // so a later `SwitchInt` on it can downcast `discr` per branch.
                        self.discriminants.insert(place.clone(), discr.clone());
                    }
                    Rvalue::Repeat(op, _) => {
                        self.operand(op, env)?;
                    }
                }
                M::projection(self, env, place)?;
            }
            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(op)) => {
                self.operand(op, env)?;
            }
            StatementKind::SetDiscriminant(_, _)
            | StatementKind::FakeRead(_)
            | StatementKind::AscribeUserType(_, _)
            | StatementKind::PlaceMention(_)
            | StatementKind::Nop => {}
        }
        Ok(())
    }

    /// Projects the place read by a `Copy`/`Move` operand. Constants access no place.
    fn operand(&mut self, op: &Operand, env: &mut Env) -> QueryResult {
        match op {
            Operand::Copy(place) | Operand::Move(place) => {
                M::projection(self, env, place)?;
            }
            Operand::Constant(_) => {}
        }
        Ok(())
    }

    /// Processes a terminator and returns, for each successor, the environment that
    /// flows into it.
    fn terminator(
        &mut self,
        terminator: &Terminator,
        mut env: Env,
    ) -> QueryResult<Vec<(Env, BasicBlock)>> {
        let mut successors = vec![];
        match &terminator.kind {
            TerminatorKind::Return => {
                M::ret(self, &env);
            }
            TerminatorKind::Call { args, destination, target, .. } => {
                for arg in args {
                    self.operand(arg, &mut env)?;
                }
                M::projection(self, &mut env, destination)?;
                // A call without a target has no successor.
                if let Some(target) = target {
                    successors.push((env, *target));
                }
            }
            TerminatorKind::SwitchInt { discr, targets } => {
                // `is_match` is the place whose discriminant was switched on, if the
                // switched value came from a `Discriminant` rvalue.
                let is_match = match discr {
                    Operand::Copy(place) | Operand::Move(place) => {
                        M::projection(self, &mut env, place)?;
                        self.discriminants.remove(place)
                    }
                    Operand::Constant(_) => None,
                };
                if let Some(place) = is_match {
                    let discr_ty = place.ty(self.genv, &self.body.local_decls)?.ty;
                    let (adt, _) = discr_ty.expect_adt();

                    // Discriminant value -> variant, for variants not yet claimed by a target.
                    let mut remaining: FxHashMap<u128, VariantIdx> = adt
                        .discriminants()
                        .map(|(idx, discr)| (discr, idx))
                        .collect();
                    // Each explicit target sees the scrutinee downcast to its variant.
                    for (bits, target) in targets.iter() {
                        let variant_idx = remaining
                            .remove(&bits)
                            .expect("value doesn't correspond to any variant");

                        let mut env = env.clone();
                        env.downcast(self.genv, &place, variant_idx)?;
                        successors.push((env, target));
                    }
                    // The `otherwise` target is downcast only when exactly one variant is left.
                    if remaining.len() == 1 {
                        let (_, variant_idx) = remaining
                            .into_iter()
                            .next()
                            .unwrap_or_else(|| tracked_span_bug!());
                        env.downcast(self.genv, &place, variant_idx)?;
                    }
                    successors.push((env, targets.otherwise()));
                } else {
                    // Not a match on an enum: every target gets the same environment.
                    let n = targets.all_targets().len();
                    for (env, target) in iter::zip(repeat_n(env, n), targets.all_targets()) {
                        successors.push((env, *target));
                    }
                }
            }
            TerminatorKind::Goto { target } => {
                successors.push((env, *target));
            }
            TerminatorKind::Yield { resume, resume_arg, .. } => {
                M::projection(self, &mut env, resume_arg)?;
                successors.push((env, *resume));
            }
            TerminatorKind::Drop { place, target, .. } => {
                M::projection(self, &mut env, place)?;
                successors.push((env, *target));
            }
            TerminatorKind::Assert { cond, target, .. } => {
                self.operand(cond, &mut env)?;
                successors.push((env, *target));
            }
            // Only the real edge is followed.
            TerminatorKind::FalseEdge { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::FalseUnwind { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::Unreachable
            | TerminatorKind::UnwindResume
            | TerminatorKind::CoroutineDrop => {}
        }
        Ok(successors)
    }

    /// Transfers control to `target` with environment `env`.
    ///
    /// Join points are delegated to the mode, which decides whether `target` must be
    /// queued. Any other block is processed immediately.
    fn goto(&mut self, target: BasicBlock, env: Env) -> QueryResult {
        if self.body.is_join_point(target) {
            if M::goto_join_point(self, target, env)? {
                self.queue.insert(target);
            }
            Ok(())
        } else {
            self.basic_block(target, env)
        }
    }
}
462
463impl<'a, 'genv, 'tcx, M> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
464 pub(crate) fn new(
465 genv: GlobalEnv<'genv, 'tcx>,
466 body: &'a Body<'tcx>,
467 bb_envs: &'a mut UnordMap<BasicBlock, Env>,
468 mode: M,
469 ) -> Self {
470 Self {
471 genv,
472 body,
473 bb_envs,
474 discriminants: Default::default(),
475 point: Point::FunEntry,
476 visited: DenseBitSet::new_empty(body.basic_blocks.len()),
477 queue: WorkQueue::empty(body.basic_blocks.len(), &body.dominator_order_rank),
478 mode,
479 }
480 }
481}
482
impl PlaceNode {
    /// Returns the node for `*self`.
    ///
    /// A folded leaf is first unfolded into a `Deref` whose child is a leaf of the
    /// dereferenced type. The flag is `true` iff that unfold happened. Any other kind of
    /// node is a bug.
    fn deref(&mut self) -> (&mut PlaceNode, Modified) {
        match self {
            PlaceNode::Deref(_, node) => (node, false),
            PlaceNode::Ty(ty) => {
                *self = PlaceNode::Deref(ty.clone(), Box::new(PlaceNode::Ty(ty.deref())));
                let PlaceNode::Deref(_, node) = self else { unreachable!() };
                (node, true)
            }
            _ => tracked_span_bug!("deref of non-deref place: `{:?}`", self),
        }
    }

    /// Returns `self` downcast to variant `idx`.
    ///
    /// A folded leaf of ADT type is unfolded into a `Downcast` holding that variant's
    /// field types. An existing `Downcast` is returned unchanged (the variant is
    /// debug-asserted to match). Any other node is a bug.
    fn downcast(
        &mut self,
        genv: GlobalEnv,
        idx: VariantIdx,
    ) -> QueryResult<(&mut PlaceNode, Modified)> {
        match self {
            PlaceNode::Downcast(.., idx2, _) => {
                debug_assert_eq!(idx, *idx2);
                Ok((self, false))
            }
            PlaceNode::Ty(ty) => {
                if let TyKind::Adt(adt_def, args) = ty.kind() {
                    let fields = downcast(genv, adt_def, args, idx)?;
                    *self = PlaceNode::Downcast(adt_def.clone(), args.clone(), idx, fields);
                    Ok((self, true))
                } else {
                    tracked_span_bug!("invalid downcast `{self:?}`");
                }
            }
            _ => tracked_span_bug!("invalid downcast `{self:?}`"),
        }
    }

    /// Returns the `f`-th child of `self`, unfolding `self` first if it is a leaf.
    fn field(&mut self, genv: GlobalEnv, f: FieldIdx) -> QueryResult<(&mut PlaceNode, Modified)> {
        let (fields, unfolded) = self.fields(genv)?;
        Ok((&mut fields[f.as_usize()], unfolded))
    }

    /// Returns the children of `self`, unfolding a leaf first. The flag reports whether
    /// that unfold happened.
    ///
    /// A leaf can be unfolded when its type is:
    /// - an ADT (via its non-enum variant, recorded as `FIRST_VARIANT`);
    /// - a closure or coroutine (children are the upvar types);
    /// - a tuple.
    fn fields(&mut self, genv: GlobalEnv) -> QueryResult<(&mut Vec<PlaceNode>, bool)> {
        match self {
            PlaceNode::Ty(ty) => {
                let fields = match ty.kind() {
                    TyKind::Adt(adt_def, args) => {
                        let fields = downcast_struct(genv, adt_def, args)?;
                        *self = PlaceNode::Downcast(
                            adt_def.clone(),
                            args.clone(),
                            FIRST_VARIANT,
                            fields,
                        );
                        let PlaceNode::Downcast(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Closure(def_id, args) => {
                        let fields = args
                            .as_closure()
                            .upvar_tys()
                            .iter()
                            .cloned()
                            .map(PlaceNode::Ty)
                            .collect_vec();
                        *self = PlaceNode::Closure(*def_id, args.clone(), fields);
                        let PlaceNode::Closure(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Tuple(fields) => {
                        let node_fields = fields.iter().cloned().map(PlaceNode::Ty).collect();
                        *self = PlaceNode::Tuple(fields.clone(), node_fields);
                        let PlaceNode::Tuple(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Coroutine(def_id, args) => {
                        let fields = args
                            .as_coroutine()
                            .upvar_tys()
                            .cloned()
                            .map(PlaceNode::Ty)
                            .collect_vec();
                        *self = PlaceNode::Generator(*def_id, args.clone(), fields);
                        let PlaceNode::Generator(.., fields) = self else { unreachable!() };
                        fields
                    }
                    _ => tracked_span_bug!("implicit downcast of non-struct: `{ty:?}`"),
                };
                Ok((fields, true))
            }
            PlaceNode::Downcast(.., fields)
            | PlaceNode::Tuple(.., fields)
            | PlaceNode::Closure(.., fields)
            | PlaceNode::Generator(.., fields) => Ok((fields, false)),
            PlaceNode::Deref(..) => {
                tracked_span_bug!("projection field of non-adt non-tuple place: `{self:?}`")
            }
        }
    }

    /// Collapses `self` into a folded leaf of its type. Returns whether `self` was
    /// unfolded before the call.
    fn ensure_folded(&mut self) -> Modified {
        match self {
            PlaceNode::Deref(ty, _) => {
                *self = PlaceNode::Ty(ty.clone());
                true
            }
            PlaceNode::Downcast(adt, args, ..) => {
                *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
                true
            }
            PlaceNode::Closure(did, args, _) => {
                *self = PlaceNode::Ty(Ty::mk_closure(*did, args.clone()));
                true
            }
            PlaceNode::Generator(did, args, _) => {
                *self = PlaceNode::Ty(Ty::mk_coroutine(*did, args.clone()));
                true
            }
            PlaceNode::Tuple(fields, ..) => {
                *self = PlaceNode::Ty(Ty::mk_tuple(fields.clone()));
                true
            }
            PlaceNode::Ty(_) => false,
        }
    }

    /// Joins `self` and `other` in place so that both end up with a common shape.
    ///
    /// Returns `(self modified, other modified)`. `in_mut_ref` becomes `true` when
    /// recursing through a `Deref` whose pointer type is a mutable reference.
    ///
    /// - Matching nodes are joined child by child.
    /// - `Downcast`s to different variants are folded on both sides.
    /// - A leaf against an unfolded node is handled by swapping the operands.
    /// - A `Deref` against a different kind of node is folded.
    /// - A tuple, closure or generator against a different kind of node unfolds `other`.
    /// - A `Downcast` against a different kind of node unfolds `other` only for a struct
    ///   outside a mutable reference. Otherwise `self` is folded.
    fn join(
        &mut self,
        genv: GlobalEnv,
        other: &mut PlaceNode,
        in_mut_ref: bool,
    ) -> QueryResult<(bool, bool)> {
        let mut modified1 = false;
        let mut modified2 = false;

        let (fields1, fields2) = match (&mut *self, &mut *other) {
            (PlaceNode::Deref(ty1, node1), PlaceNode::Deref(ty2, node2)) => {
                debug_assert_eq!(ty1, ty2);
                return node1.join(genv, node2, in_mut_ref || ty1.is_mut_ref());
            }
            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2)) => {
                (fields1, fields2)
            }
            (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
                (fields1, fields2)
            }
            (
                PlaceNode::Downcast(adt1, args1, variant1, fields1),
                PlaceNode::Downcast(adt2, args2, variant2, fields2),
            ) => {
                debug_assert_eq!(adt1, adt2);
                if variant1 == variant2 {
                    (fields1, fields2)
                } else {
                    *self = PlaceNode::Ty(Ty::mk_adt(adt1.clone(), args1.clone()));
                    *other = PlaceNode::Ty(Ty::mk_adt(adt2.clone(), args2.clone()));
                    return Ok((true, true));
                }
            }
            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return Ok((false, false)),
            // Swap so that the unfolded node is on the left; swap the flags back.
            (PlaceNode::Ty(_), _) => {
                let (m1, m2) = other.join(genv, self, in_mut_ref)?;
                return Ok((m2, m1));
            }
            (PlaceNode::Deref(ty, _), _) => {
                *self = PlaceNode::Ty(ty.clone());
                return Ok((true, false));
            }
            (PlaceNode::Tuple(_, fields1), _) => {
                let (fields2, m) = other.fields(genv)?;
                modified2 |= m;
                (fields1, fields2)
            }
            (PlaceNode::Closure(.., fields1), _) | (PlaceNode::Generator(.., fields1), _) => {
                let (fields2, m) = other.fields(genv)?;
                modified2 |= m;
                (fields1, fields2)
            }

            (PlaceNode::Downcast(adt, args, .., fields1), _) => {
                if adt.is_struct() && !in_mut_ref {
                    let (fields2, m) = other.fields(genv)?;
                    modified2 |= m;
                    (fields1, fields2)
                } else {
                    *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
                    return Ok((true, false));
                }
            }
        };
        for (node1, node2) in iter::zip(fields1, fields2) {
            let (m1, m2) = node1.join(genv, node2, in_mut_ref)?;
            modified1 |= m1;
            modified2 |= m2;
        }
        Ok((modified1, modified2))
    }

    /// Records at `place` the ghost statements that turn the shape of `self` into the
    /// shape of `target`.
    ///
    /// A leaf in `self` that is unfolded in `target` yields unfolds; an unfolded node in
    /// `self` that is a leaf in `target` yields a fold. `place` is a scratch buffer and
    /// is restored before returning.
    ///
    /// NOTE(review): unlike `collect_unfolds` and `collect_folds_at_ret`, no `Downcast`
    /// projection is pushed before the field projections of an enum variant here.
    /// Confirm this is intended.
    fn collect_fold_unfolds(
        &self,
        target: &PlaceNode,
        place: &mut Place,
        stmts: &mut StatementsAt,
    ) {
        let (fields1, fields2) = match (self, target) {
            (PlaceNode::Deref(_, node1), PlaceNode::Deref(_, node2)) => {
                place.projection.push(PlaceElem::Deref);
                node1.collect_fold_unfolds(node2, place, stmts);
                place.projection.pop();
                return;
            }
            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2))
            | (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
                (fields1, fields2)
            }
            (
                PlaceNode::Downcast(adt1, .., idx1, fields1),
                PlaceNode::Downcast(adt2, .., idx2, fields2),
            ) => {
                tracked_span_dbg_assert_eq!(adt1.did(), adt2.did());
                tracked_span_dbg_assert_eq!(idx1, idx2);
                (fields1, fields2)
            }
            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return,
            (PlaceNode::Ty(_), _) => {
                target.collect_unfolds(place, stmts);
                return;
            }
            (_, PlaceNode::Ty(_)) => {
                stmts.insert(GhostStatement::Fold(place.clone()));
                return;
            }
            _ => tracked_span_bug!("{self:?} {target:?}"),
        };
        for (i, (node1, node2)) in iter::zip(fields1, fields2).enumerate() {
            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
            node1.collect_fold_unfolds(node2, place, stmts);
            place.projection.pop();
        }
    }

    /// Records `Unfold` statements that bring a folded `place` into the shape of `self`.
    ///
    /// An unfolded node whose children are all leaves yields a single `Unfold` of
    /// `place`. Otherwise the children are visited recursively, through a `Downcast`
    /// projection for enum variants. `place` is restored before returning.
    fn collect_unfolds(&self, place: &mut Place, stmts: &mut StatementsAt) {
        match self {
            PlaceNode::Ty(_) => {}
            PlaceNode::Deref(_, node) => {
                if node.is_ty() {
                    stmts.insert(GhostStatement::Unfold(place.clone()));
                } else {
                    place.projection.push(PlaceElem::Deref);
                    node.collect_unfolds(place, stmts);
                    place.projection.pop();
                }
            }
            PlaceNode::Downcast(.., fields)
            | PlaceNode::Closure(.., fields)
            | PlaceNode::Generator(.., fields)
            | PlaceNode::Tuple(.., fields) => {
                let all_leaves = fields.iter().all(PlaceNode::is_ty);
                if all_leaves {
                    stmts.insert(GhostStatement::Unfold(place.clone()));
                } else {
                    if let Some(idx) = self.enum_variant() {
                        place.projection.push(PlaceElem::Downcast(None, idx));
                    }
                    for (i, node) in fields.iter().enumerate() {
                        place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
                        node.collect_unfolds(place, stmts);
                        place.projection.pop();
                    }
                    if self.enum_variant().is_some() {
                        place.projection.pop();
                    }
                }
            }
        }
    }

    /// Records the folds required at a function return for the tree rooted at `place`.
    ///
    /// The pointee of every mutable reference reached gets a `Fold`. Boxes are looked
    /// through; other pointers are not descended into. Aggregates are traversed child by
    /// child, through a `Downcast` projection for enum variants. `place` is restored
    /// before returning.
    fn collect_folds_at_ret(&self, place: &mut Place, stmts: &mut StatementsAt) {
        let fields = match self {
            PlaceNode::Deref(ty, deref_ty) => {
                place.projection.push(PlaceElem::Deref);
                if ty.is_mut_ref() {
                    stmts.insert(GhostStatement::Fold(place.clone()));
                } else if ty.is_box() {
                    deref_ty.collect_folds_at_ret(place, stmts);
                }
                place.projection.pop();
                return;
            }
            PlaceNode::Downcast(adt, _, idx, fields) => {
                if adt.is_enum() {
                    place.projection.push(PlaceElem::Downcast(None, *idx));
                }
                fields
            }
            PlaceNode::Closure(_, _, fields)
            | PlaceNode::Generator(_, _, fields)
            | PlaceNode::Tuple(_, fields) => fields,
            PlaceNode::Ty(_) => return,
        };
        for (i, node) in fields.iter().enumerate() {
            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
            node.collect_folds_at_ret(place, stmts);
            place.projection.pop();
        }
        if let PlaceNode::Downcast(adt, ..) = self
            && adt.is_enum()
        {
            place.projection.pop();
        }
    }

    /// Returns the variant index if `self` is a `Downcast` of an enum, `None` otherwise.
    fn enum_variant(&self) -> Option<VariantIdx> {
        if let PlaceNode::Downcast(adt, _, idx, _) = self
            && adt.is_enum()
        {
            Some(*idx)
        } else {
            None
        }
    }

    /// Returns whether `self` is a folded leaf.
    #[must_use]
    fn is_ty(&self) -> bool {
        matches!(self, Self::Ty(..))
    }
}
815
816fn downcast(
817 genv: GlobalEnv,
818 adt_def: &AdtDef,
819 args: &GenericArgs,
820 variant: VariantIdx,
821) -> QueryResult<Vec<PlaceNode>> {
822 adt_def
823 .variant(variant)
824 .fields
825 .iter()
826 .map(|field| {
827 let ty = genv.lower_type_of(field.did)?.subst(args);
828 QueryResult::Ok(PlaceNode::Ty(ty))
829 })
830 .try_collect()
831}
832
833fn downcast_struct(
834 genv: GlobalEnv,
835 adt_def: &AdtDef,
836 args: &GenericArgs,
837) -> QueryResult<Vec<PlaceNode>> {
838 adt_def
839 .non_enum_variant()
840 .fields
841 .iter()
842 .map(|field| {
843 let ty = genv.lower_type_of(field.did)?.subst(args);
844 QueryResult::Ok(PlaceNode::Ty(ty))
845 })
846 .try_collect()
847}
848
849impl fmt::Debug for Env {
850 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
851 write!(
852 f,
853 "{}",
854 self.map
855 .iter_enumerated()
856 .format_with(", ", |(local, node), f| f(&format_args!("{local:?}: {node:?}")))
857 )
858 }
859}
860
861impl fmt::Debug for PlaceNode {
862 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
863 match self {
864 PlaceNode::Deref(_, node) => write!(f, "*({node:?})"),
865 PlaceNode::Downcast(adt, args, variant, fields) => {
866 write!(f, "{}", def_id_to_string(adt.did()))?;
867 if !args.is_empty() {
868 write!(f, "<{:?}>", args.iter().format(", "),)?;
869 }
870 write!(f, "::{}", adt.variant(*variant).name)?;
871 if !fields.is_empty() {
872 write!(f, "({:?})", fields.iter().format(", "),)?;
873 }
874 Ok(())
875 }
876 PlaceNode::Closure(did, args, fields) => {
877 write!(f, "Closure {}", def_id_to_string(*did))?;
878 if !args.is_empty() {
879 write!(f, "<{:?}>", args.iter().format(", "),)?;
880 }
881 if !fields.is_empty() {
882 write!(f, "({:?})", fields.iter().format(", "),)?;
883 }
884 Ok(())
885 }
886 PlaceNode::Generator(did, args, fields) => {
887 write!(f, "Generator {}", def_id_to_string(*did))?;
888 if !args.is_empty() {
889 write!(f, "<{:?}>", args.iter().format(", "),)?;
890 }
891 if !fields.is_empty() {
892 write!(f, "({:?})", fields.iter().format(", "),)?;
893 }
894 Ok(())
895 }
896 PlaceNode::Tuple(_, fields) => write!(f, "({:?})", fields.iter().format(", ")),
897 PlaceNode::Ty(ty) => write!(f, "•{ty:?}"),
898 }
899 }
900}