1use std::{collections::hash_map::Entry, fmt, iter};
2
3use flux_common::{tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_middle::{
5 PlaceExt as _, def_id_to_string, global_env::GlobalEnv, queries::QueryResult, query_bug, rty,
6};
7use flux_rustc_bridge::{
8 mir::{
9 BasicBlock, Body, BorrowKind, FIRST_VARIANT, FieldIdx, Local, Location,
10 NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
11 StatementKind, Terminator, TerminatorKind, UnOp, VariantIdx,
12 },
13 ty::{AdtDef, GenericArgs, GenericArgsExt as _, List, Mutability, Ty, TyKind},
14};
15use itertools::{Itertools, repeat_n};
16use rustc_data_structures::unord::UnordMap;
17use rustc_hash::FxHashMap;
18use rustc_hir::def_id::DefId;
19use rustc_index::{Idx, IndexVec, bit_set::DenseBitSet};
20use rustc_middle::mir::{FakeReadCause, START_BLOCK};
21
22use super::{GhostStatements, StatementsAt};
23use crate::{
24 ghost_statements::{GhostStatement, Point},
25 queue::WorkQueue,
26};
27
28pub(crate) fn add_ghost_statements<'tcx>(
29 stmts: &mut GhostStatements,
30 genv: GlobalEnv<'_, 'tcx>,
31 body: &Body<'tcx>,
32 fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>,
33) -> QueryResult {
34 let mut bb_envs = FxHashMap::default();
35 FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Infer).run(fn_sig)?;
36
37 FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Elaboration { stmts }).run(fn_sig)
38}
39
40#[derive(Clone)]
41struct Env {
42 map: IndexVec<Local, PlaceNode>,
43}
44
45impl Env {
46 fn new(body: &Body) -> Self {
47 Self {
48 map: body
49 .local_decls
50 .iter()
51 .map(|decl| PlaceNode::Ty(decl.ty.clone()))
52 .collect(),
53 }
54 }
55
56 fn projection<'a>(&mut self, genv: GlobalEnv, place: &'a Place) -> QueryResult<ProjResult<'a>> {
57 let (node, place, modified) = self.ensure_unfolded(genv, place)?;
58 if modified {
59 Ok(ProjResult::Unfold(place))
60 } else if node.ensure_folded() {
61 Ok(ProjResult::Fold(place))
62 } else {
63 Ok(ProjResult::None)
64 }
65 }
66
67 fn downcast(&mut self, genv: GlobalEnv, place: &Place, variant_idx: VariantIdx) -> QueryResult {
68 let (node, ..) = self.ensure_unfolded(genv, place)?;
69 node.downcast(genv, variant_idx)?;
70 Ok(())
71 }
72
73 fn ensure_unfolded<'a>(
74 &mut self,
75 genv: GlobalEnv,
76 place: &'a Place,
77 ) -> QueryResult<(&mut PlaceNode, PlaceRef<'a>, Modified)> {
78 let mut node = &mut self.map[place.local];
79 let mut modified = false;
80 let mut i = 0;
81 while i < place.projection.len() {
82 let elem = place.projection[i];
83 let (n, m) = match elem {
84 PlaceElem::Deref => node.deref(),
85 PlaceElem::Field(f) => node.field(genv, f)?,
86 PlaceElem::Downcast(_, idx) => node.downcast(genv, idx)?,
87 PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } => break,
88 };
89 node = n;
90 modified |= m;
91 i += 1;
92 }
93 Ok((node, place.as_ref().truncate(i), modified))
94 }
95
96 fn join(&mut self, genv: GlobalEnv, mut other: Env) -> QueryResult<Modified> {
97 let mut modified = false;
98 for (local, node) in self.map.iter_enumerated_mut() {
99 let (m, _) = node.join(genv, &mut other.map[local], false)?;
100 modified |= m;
101 }
102 Ok(modified)
103 }
104
105 fn collect_fold_unfolds_at_goto(&self, target: &Env, stmts: &mut StatementsAt) {
106 for (local, node) in self.map.iter_enumerated() {
107 node.collect_fold_unfolds(&target.map[local], &mut Place::new(local, vec![]), stmts);
108 }
109 }
110
111 fn collect_folds_at_ret(&self, body: &Body, stmts: &mut StatementsAt) {
112 for local in body.args_iter() {
113 self.map[local].collect_folds_at_ret(&mut Place::new(local, vec![]), stmts);
114 }
115 }
116}
117
/// `true` when an operation changed the shape of a `PlaceNode` tree (or of an `Env`).
type Modified = bool;
119
120struct FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
121 genv: GlobalEnv<'genv, 'tcx>,
122 body: &'a Body<'tcx>,
123 bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
124 visited: DenseBitSet<BasicBlock>,
125 queue: WorkQueue<'a>,
126 discriminants: UnordMap<Place, Place>,
127 point: Point,
128 mode: M,
129}
130
131trait Mode: Sized {
132 const _NAME: &'static str;
133
134 fn projection(
135 analysis: &mut FoldUnfoldAnalysis<Self>,
136 env: &mut Env,
137 place: &Place,
138 ) -> QueryResult;
139
140 fn goto_join_point(
141 analysis: &mut FoldUnfoldAnalysis<Self>,
142 target: BasicBlock,
143 env: Env,
144 ) -> QueryResult<bool>;
145
146 fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env);
147}
148
/// Mode that only computes the environments at join points (stored in `bb_envs`).
struct Infer;
150
151struct Elaboration<'a> {
152 stmts: &'a mut GhostStatements,
153}
154
155impl Elaboration<'_> {
156 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
157 self.stmts.insert_at(point, stmt);
158 }
159}
160
161#[derive(Debug)]
162enum ProjResult<'a> {
163 None,
164 Fold(PlaceRef<'a>),
165 Unfold(PlaceRef<'a>),
166}
167
168impl Mode for Infer {
169 const _NAME: &'static str = "infer";
170
171 fn projection(
172 analysis: &mut FoldUnfoldAnalysis<Self>,
173 env: &mut Env,
174 place: &Place,
175 ) -> QueryResult {
176 env.projection(analysis.genv, place)?;
177 Ok(())
178 }
179
180 fn goto_join_point(
181 analysis: &mut FoldUnfoldAnalysis<Self>,
182 target: BasicBlock,
183 env: Env,
184 ) -> QueryResult<bool> {
185 let modified = match analysis.bb_envs.entry(target) {
186 Entry::Occupied(mut entry) => entry.get_mut().join(analysis.genv, env)?,
187 Entry::Vacant(entry) => {
188 entry.insert(env);
189 true
190 }
191 };
192 Ok(modified)
193 }
194
195 fn ret(_: &mut FoldUnfoldAnalysis<Self>, _: &Env) {}
196}
197
198impl Mode for Elaboration<'_> {
199 const _NAME: &'static str = "elaboration";
200
201 fn projection(
202 analysis: &mut FoldUnfoldAnalysis<Self>,
203 env: &mut Env,
204 place: &Place,
205 ) -> QueryResult {
206 match env.projection(analysis.genv, place)? {
207 ProjResult::None => {}
208 ProjResult::Fold(place_ref) => {
209 tracked_span_assert_eq!(place_ref, place.as_ref());
210 let place = place.clone();
211 analysis
212 .mode
213 .insert_at(analysis.point, GhostStatement::Fold(place));
214 }
215 ProjResult::Unfold(place_ref) => {
216 tracked_span_assert_eq!(place_ref, place.as_ref());
217 match place_ref.last_projection() {
218 Some((base, PlaceElem::Deref | PlaceElem::Field(..))) => {
219 analysis
220 .mode
221 .insert_at(analysis.point, GhostStatement::Unfold(base.to_place()));
222 }
223 _ => Err(query_bug!("invalid projection for unfolding {place_ref:?}"))?,
224 }
225 }
226 }
227 Ok(())
228 }
229
230 fn goto_join_point(
231 analysis: &mut FoldUnfoldAnalysis<Self>,
232 target: BasicBlock,
233 env: Env,
234 ) -> QueryResult<bool> {
235 env.collect_fold_unfolds_at_goto(
236 &analysis.bb_envs[&target],
237 &mut analysis.mode.stmts.at(analysis.point),
238 );
239 Ok(!analysis.visited.contains(target))
240 }
241
242 fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env) {
243 env.collect_folds_at_ret(analysis.body, &mut analysis.mode.stmts.at(analysis.point));
244 }
245}
246
247#[derive(Clone)]
248enum PlaceNode {
249 Deref(Ty, Box<PlaceNode>),
250 Downcast(AdtDef, GenericArgs, VariantIdx, Vec<PlaceNode>),
251 Closure(DefId, GenericArgs, Vec<PlaceNode>),
252 Generator(DefId, GenericArgs, Vec<PlaceNode>),
253 Tuple(List<Ty>, Vec<PlaceNode>),
254 Ty(Ty),
255}
256
impl<M: Mode> FoldUnfoldAnalysis<'_, '_, '_, M> {
    /// Runs this pass over the whole body.
    ///
    /// Before entering `START_BLOCK`, the dereference of every argument whose input type in
    /// `fn_sig` is a strong reference or a mutable reference is projected. That projection
    /// happens at whatever `self.point` is on entry; `new` sets it to `Point::FunEntry`.
    /// Blocks queued through join points are then processed, each starting from its
    /// environment in `bb_envs`, until the queue is empty.
    fn run(mut self, fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>) -> QueryResult {
        let mut env = Env::new(self.body);

        if let Some(fn_sig) = fn_sig {
            let fn_sig = fn_sig.as_ref().skip_binder().as_ref().skip_binder();
            for (local, ty) in iter::zip(self.body.args_iter(), fn_sig.inputs()) {
                if let rty::TyKind::StrgRef(..) | rty::Ref!(.., Mutability::Mut) = ty.kind() {
                    M::projection(&mut self, &mut env, &Place::new(local, vec![PlaceElem::Deref]))?;
                }
            }
        }
        self.goto(START_BLOCK, env)?;
        while let Some(bb) = self.queue.pop() {
            self.basic_block(bb, self.bb_envs[&bb].clone())?;
        }
        Ok(())
    }

    /// Processes block `bb` starting from `env`.
    ///
    /// Statements are handled first, then the terminator, then each outgoing edge via
    /// `goto`. `self.point` is updated before each step so that statements emitted by the
    /// mode are attached to the right location or edge.
    fn basic_block(&mut self, bb: BasicBlock, mut env: Env) -> QueryResult {
        self.visited.insert(bb);
        let data = &self.body.basic_blocks[bb];
        for (statement_index, stmt) in data.statements.iter().enumerate() {
            self.point = Point::BeforeLocation(Location { block: bb, statement_index });
            self.statement(stmt, &mut env)?;
        }
        if let Some(terminator) = &data.terminator {
            self.point = Point::BeforeLocation(self.body.terminator_loc(bb));
            let successors = self.terminator(terminator, env)?;
            for (env, target) in successors {
                // Statements produced while jumping belong to the edge, not to a location.
                self.point = Point::Edge(bb, target);
                self.goto(target, env)?;
            }
        }
        Ok(())
    }

    /// Applies the places accessed by a single MIR statement to `env`.
    ///
    /// For assignments, the places read by the rvalue are projected before the assigned place.
    /// A `Discriminant` rvalue also records which place's discriminant was read, so that a
    /// later `SwitchInt` on the assigned place can downcast it.
    fn statement(&mut self, stmt: &Statement, env: &mut Env) -> QueryResult {
        match &stmt.kind {
            StatementKind::FakeRead(box (FakeReadCause::ForIndex, place)) => {
                M::projection(self, env, place)?;
            }
            StatementKind::Assign(place, rvalue) => {
                match rvalue {
                    // Reading pointer metadata goes through the pointee, so project `*place`.
                    Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(place))
                    | Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(place)) => {
                        let deref_place = place.deref();
                        M::projection(self, env, &deref_place)?;
                    }
                    Rvalue::Use(op)
                    | Rvalue::Cast(_, op, _)
                    | Rvalue::UnaryOp(_, op)
                    | Rvalue::ShallowInitBox(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::Ref(.., bk, place) => {
                        // Fake borrows do not touch the environment.
                        if !matches!(bk, BorrowKind::Fake(_)) {
                            M::projection(self, env, place)?;
                        }
                    }
                    Rvalue::RawPtr(_, place) => {
                        M::projection(self, env, place)?;
                    }
                    Rvalue::BinaryOp(_, op1, op2) => {
                        self.operand(op1, env)?;
                        self.operand(op2, env)?;
                    }
                    Rvalue::Aggregate(_, args) => {
                        for arg in args {
                            self.operand(arg, env)?;
                        }
                    }

                    Rvalue::Discriminant(discr) => {
                        M::projection(self, env, discr)?;
                        // `place` here is the assigned place (the outer binding), not `discr`.
                        self.discriminants.insert(place.clone(), discr.clone());
                    }
                    Rvalue::Repeat(op, _) => {
                        self.operand(op, env)?;
                    }
                }
                M::projection(self, env, place)?;
            }
            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(op)) => {
                self.operand(op, env)?;
            }
            StatementKind::SetDiscriminant(_, _)
            | StatementKind::FakeRead(_)
            | StatementKind::AscribeUserType(_, _)
            | StatementKind::PlaceMention(_)
            | StatementKind::Nop => {}
        }
        Ok(())
    }

    /// Projects the place read by `op`, if any. Constants access no place.
    fn operand(&mut self, op: &Operand, env: &mut Env) -> QueryResult {
        match op {
            Operand::Copy(place) | Operand::Move(place) => {
                M::projection(self, env, place)?;
            }
            Operand::Constant(_) => {}
        }
        Ok(())
    }

    /// Applies the places accessed by `terminator` to `env`.
    ///
    /// Returns the environment to propagate along each outgoing edge, paired with its target.
    /// For a `SwitchInt` on a place recorded in `self.discriminants`:
    /// - each explicit target gets a copy of `env` downcast to the variant with the
    ///   matching discriminant value;
    /// - the `otherwise` target is downcast too when exactly one variant remains.
    fn terminator(
        &mut self,
        terminator: &Terminator,
        mut env: Env,
    ) -> QueryResult<Vec<(Env, BasicBlock)>> {
        let mut successors = vec![];
        match &terminator.kind {
            TerminatorKind::Return => {
                M::ret(self, &env);
            }
            TerminatorKind::Call { args, destination, target, .. } => {
                for arg in args {
                    self.operand(arg, &mut env)?;
                }
                M::projection(self, &mut env, destination)?;
                // A call without a target has no successor.
                if let Some(target) = target {
                    successors.push((env, *target));
                }
            }
            TerminatorKind::SwitchInt { discr, targets } => {
                // `Some(place)` when switching on a value produced by `Discriminant(place)`.
                let is_match = match discr {
                    Operand::Copy(place) | Operand::Move(place) => {
                        M::projection(self, &mut env, place)?;
                        self.discriminants.remove(place)
                    }
                    Operand::Constant(_) => None,
                };
                if let Some(place) = is_match {
                    let discr_ty = place.ty(self.genv, &self.body.local_decls)?.ty;
                    let (adt, _) = discr_ty.expect_adt();

                    // Discriminant value -> variant, for variants not yet matched by a target.
                    let mut remaining: FxHashMap<u128, VariantIdx> = adt
                        .discriminants()
                        .map(|(idx, discr)| (discr, idx))
                        .collect();
                    for (bits, target) in targets.iter() {
                        let variant_idx = remaining
                            .remove(&bits)
                            .expect("value doesn't correspond to any variant");

                        let mut env = env.clone();
                        env.downcast(self.genv, &place, variant_idx)?;
                        successors.push((env, target));
                    }
                    // If a single variant is left, `otherwise` can only be that variant.
                    if remaining.len() == 1 {
                        let (_, variant_idx) = remaining
                            .into_iter()
                            .next()
                            .unwrap_or_else(|| tracked_span_bug!());
                        env.downcast(self.genv, &place, variant_idx)?;
                    }
                    successors.push((env, targets.otherwise()));
                } else {
                    let n = targets.all_targets().len();
                    for (env, target) in iter::zip(repeat_n(env, n), targets.all_targets()) {
                        successors.push((env, *target));
                    }
                }
            }
            TerminatorKind::Goto { target } => {
                successors.push((env, *target));
            }
            TerminatorKind::Yield { resume, resume_arg, .. } => {
                M::projection(self, &mut env, resume_arg)?;
                successors.push((env, *resume));
            }
            TerminatorKind::Drop { place, target, .. } => {
                M::projection(self, &mut env, place)?;
                successors.push((env, *target));
            }
            TerminatorKind::Assert { cond, target, .. } => {
                self.operand(cond, &mut env)?;
                successors.push((env, *target));
            }
            TerminatorKind::FalseEdge { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::FalseUnwind { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::Unreachable
            | TerminatorKind::UnwindResume
            | TerminatorKind::CoroutineDrop => {}
        }
        Ok(successors)
    }

    /// Transfers control to `target` with `env`.
    ///
    /// For a join point, the mode decides (via `goto_join_point`) whether `target` must be
    /// queued. Any other block is processed immediately.
    fn goto(&mut self, target: BasicBlock, env: Env) -> QueryResult {
        if self.body.is_join_point(target) {
            if M::goto_join_point(self, target, env)? {
                self.queue.insert(target);
            }
            Ok(())
        } else {
            self.basic_block(target, env)
        }
    }
}
463
464impl<'a, 'genv, 'tcx, M> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
465 pub(crate) fn new(
466 genv: GlobalEnv<'genv, 'tcx>,
467 body: &'a Body<'tcx>,
468 bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
469 mode: M,
470 ) -> Self {
471 Self {
472 genv,
473 body,
474 bb_envs,
475 discriminants: Default::default(),
476 point: Point::FunEntry,
477 visited: DenseBitSet::new_empty(body.basic_blocks.len()),
478 queue: WorkQueue::empty(body.basic_blocks.len(), &body.dominator_order_rank),
479 mode,
480 }
481 }
482}
483
impl PlaceNode {
    /// Returns the pointee node, unfolding a folded node into a `Deref` first.
    ///
    /// The flag reports whether an unfold happened. Panics (via `tracked_span_bug!`) if the
    /// node is already unfolded as something other than a dereference.
    fn deref(&mut self) -> (&mut PlaceNode, Modified) {
        match self {
            PlaceNode::Deref(_, node) => (node, false),
            PlaceNode::Ty(ty) => {
                *self = PlaceNode::Deref(ty.clone(), Box::new(PlaceNode::Ty(ty.deref())));
                // Re-borrow the freshly written node to hand out its child.
                let PlaceNode::Deref(_, node) = self else { unreachable!() };
                (node, true)
            }
            _ => tracked_span_bug!("deref of non-deref place: `{:?}`", self),
        }
    }

    /// Unfolds a folded ADT node into variant `idx` and returns the node itself.
    ///
    /// The flag reports whether an unfold happened. If the node is already downcast, it is
    /// returned unchanged; in debug builds its variant is asserted to equal `idx`. Any other
    /// shape is a bug.
    fn downcast(
        &mut self,
        genv: GlobalEnv,
        idx: VariantIdx,
    ) -> QueryResult<(&mut PlaceNode, Modified)> {
        match self {
            PlaceNode::Downcast(.., idx2, _) => {
                debug_assert_eq!(idx, *idx2);
                Ok((self, false))
            }
            PlaceNode::Ty(ty) => {
                if let TyKind::Adt(adt_def, args) = ty.kind() {
                    let fields = downcast(genv, adt_def, args, idx)?;
                    *self = PlaceNode::Downcast(adt_def.clone(), args.clone(), idx, fields);
                    Ok((self, true))
                } else {
                    tracked_span_bug!("invalid downcast `{self:?}`");
                }
            }
            _ => tracked_span_bug!("invalid downcast `{self:?}`"),
        }
    }

    /// Returns the node of field `f`, unfolding this node into its fields if necessary.
    fn field(&mut self, genv: GlobalEnv, f: FieldIdx) -> QueryResult<(&mut PlaceNode, Modified)> {
        let (fields, unfolded) = self.fields(genv)?;
        Ok((&mut fields[f.as_usize()], unfolded))
    }

    /// Returns the children of an aggregate node, unfolding a folded node first.
    ///
    /// Folded nodes that can be unfolded here:
    /// - ADTs, through their non-enum variant, recorded as `FIRST_VARIANT`;
    /// - closures;
    /// - tuples;
    /// - coroutines.
    ///
    /// The flag reports whether an unfold happened. Calling this on a `Deref` node, or on a
    /// folded node of any other type, is a bug.
    fn fields(&mut self, genv: GlobalEnv) -> QueryResult<(&mut Vec<PlaceNode>, bool)> {
        match self {
            PlaceNode::Ty(ty) => {
                let fields = match ty.kind() {
                    TyKind::Adt(adt_def, args) => {
                        let fields = downcast_struct(genv, adt_def, args)?;
                        *self = PlaceNode::Downcast(
                            adt_def.clone(),
                            args.clone(),
                            FIRST_VARIANT,
                            fields,
                        );
                        let PlaceNode::Downcast(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Closure(def_id, args) => {
                        let fields = args
                            .as_closure()
                            .upvar_tys()
                            .iter()
                            .cloned()
                            .map(PlaceNode::Ty)
                            .collect_vec();
                        *self = PlaceNode::Closure(*def_id, args.clone(), fields);
                        let PlaceNode::Closure(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Tuple(fields) => {
                        let node_fields = fields.iter().cloned().map(PlaceNode::Ty).collect();
                        *self = PlaceNode::Tuple(fields.clone(), node_fields);
                        let PlaceNode::Tuple(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Coroutine(def_id, args) => {
                        let fields = args
                            .as_coroutine()
                            .upvar_tys()
                            .cloned()
                            .map(PlaceNode::Ty)
                            .collect_vec();
                        *self = PlaceNode::Generator(*def_id, args.clone(), fields);
                        let PlaceNode::Generator(.., fields) = self else { unreachable!() };
                        fields
                    }
                    _ => tracked_span_bug!("implicit downcast of non-struct: `{ty:?}`"),
                };
                Ok((fields, true))
            }
            PlaceNode::Downcast(.., fields)
            | PlaceNode::Tuple(.., fields)
            | PlaceNode::Closure(.., fields)
            | PlaceNode::Generator(.., fields) => Ok((fields, false)),
            PlaceNode::Deref(..) => {
                tracked_span_bug!("projection field of non-adt non-tuple place: `{self:?}`")
            }
        }
    }

    /// Folds an unfolded node back into a `Ty` leaf of the corresponding type.
    ///
    /// Returns whether the node changed; a `Ty` leaf is left alone.
    fn ensure_folded(&mut self) -> Modified {
        match self {
            PlaceNode::Deref(ty, _) => {
                *self = PlaceNode::Ty(ty.clone());
                true
            }
            PlaceNode::Downcast(adt, args, ..) => {
                *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
                true
            }
            PlaceNode::Closure(did, args, _) => {
                *self = PlaceNode::Ty(Ty::mk_closure(*did, args.clone()));
                true
            }
            PlaceNode::Generator(did, args, _) => {
                *self = PlaceNode::Ty(Ty::mk_coroutine(*did, args.clone()));
                true
            }
            PlaceNode::Tuple(fields, ..) => {
                *self = PlaceNode::Ty(Ty::mk_tuple(fields.clone()));
                true
            }
            PlaceNode::Ty(_) => false,
        }
    }

    /// Reshapes `self` and `other` in place so that both end up with the same shape.
    ///
    /// Returns `(self changed, other changed)`. When the shapes differ:
    /// - tuples, closures and coroutines unfold the other side to match;
    /// - a struct downcast does too, unless under a mutable reference;
    /// - dereferences, enum downcasts, structs under a mutable reference, and downcasts to
    ///   different variants are folded instead.
    ///
    /// `in_mut_ref` is set once the recursion has gone through the dereference of a mutable
    /// reference.
    fn join(
        &mut self,
        genv: GlobalEnv,
        other: &mut PlaceNode,
        in_mut_ref: bool,
    ) -> QueryResult<(bool, bool)> {
        let mut modified1 = false;
        let mut modified2 = false;

        let (fields1, fields2) = match (&mut *self, &mut *other) {
            (PlaceNode::Deref(ty1, node1), PlaceNode::Deref(ty2, node2)) => {
                debug_assert_eq!(ty1, ty2);
                return node1.join(genv, node2, in_mut_ref || ty1.is_mut_ref());
            }
            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2)) => {
                (fields1, fields2)
            }
            (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
                (fields1, fields2)
            }
            (
                PlaceNode::Downcast(adt1, args1, variant1, fields1),
                PlaceNode::Downcast(adt2, args2, variant2, fields2),
            ) => {
                debug_assert_eq!(adt1, adt2);
                if variant1 == variant2 {
                    (fields1, fields2)
                } else {
                    // Different variants cannot be reconciled: fold both sides.
                    *self = PlaceNode::Ty(Ty::mk_adt(adt1.clone(), args1.clone()));
                    *other = PlaceNode::Ty(Ty::mk_adt(adt2.clone(), args2.clone()));
                    return Ok((true, true));
                }
            }
            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return Ok((false, false)),
            // Handle the folded-vs-unfolded case from the other side and swap the flags back.
            (PlaceNode::Ty(_), _) => {
                let (m1, m2) = other.join(genv, self, in_mut_ref)?;
                return Ok((m2, m1));
            }
            (PlaceNode::Deref(ty, _), _) => {
                *self = PlaceNode::Ty(ty.clone());
                return Ok((true, false));
            }
            (PlaceNode::Tuple(_, fields1), _) => {
                let (fields2, m) = other.fields(genv)?;
                modified2 |= m;
                (fields1, fields2)
            }
            (PlaceNode::Closure(.., fields1), _) | (PlaceNode::Generator(.., fields1), _) => {
                let (fields2, m) = other.fields(genv)?;
                modified2 |= m;
                (fields1, fields2)
            }

            (PlaceNode::Downcast(adt, args, .., fields1), _) => {
                if adt.is_struct() && !in_mut_ref {
                    let (fields2, m) = other.fields(genv)?;
                    modified2 |= m;
                    (fields1, fields2)
                } else {
                    *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
                    return Ok((true, false));
                }
            }
        };
        // Same shape at this level: join the children pairwise.
        for (node1, node2) in iter::zip(fields1, fields2) {
            let (m1, m2) = node1.join(genv, node2, in_mut_ref)?;
            modified1 |= m1;
            modified2 |= m2;
        }
        Ok((modified1, modified2))
    }

    /// Records in `stmts` the folds/unfolds that turn this node (at `place`) into `target`.
    ///
    /// `place` is used as a scratch path: projections are pushed and popped around the
    /// recursive calls, so it is unchanged on return.
    ///
    /// NOTE(review): unlike `collect_unfolds` and `collect_folds_at_ret`, no
    /// `PlaceElem::Downcast` is pushed before the fields of an enum variant here. Confirm
    /// this is intended.
    fn collect_fold_unfolds(
        &self,
        target: &PlaceNode,
        place: &mut Place,
        stmts: &mut StatementsAt,
    ) {
        let (fields1, fields2) = match (self, target) {
            (PlaceNode::Deref(_, node1), PlaceNode::Deref(_, node2)) => {
                place.projection.push(PlaceElem::Deref);
                node1.collect_fold_unfolds(node2, place, stmts);
                place.projection.pop();
                return;
            }
            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2))
            | (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
                (fields1, fields2)
            }
            (
                PlaceNode::Downcast(adt1, .., idx1, fields1),
                PlaceNode::Downcast(adt2, .., idx2, fields2),
            ) => {
                tracked_span_dbg_assert_eq!(adt1.did(), adt2.did());
                tracked_span_dbg_assert_eq!(idx1, idx2);
                (fields1, fields2)
            }
            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return,
            // Folded here but unfolded in the target: unfold to the target's shape.
            (PlaceNode::Ty(_), _) => {
                target.collect_unfolds(place, stmts);
                return;
            }
            // Unfolded here but folded in the target: fold the whole subtree.
            (_, PlaceNode::Ty(_)) => {
                stmts.insert(GhostStatement::Fold(place.clone()));
                return;
            }
            _ => tracked_span_bug!("{self:?} {target:?}"),
        };
        for (i, (node1, node2)) in iter::zip(fields1, fields2).enumerate() {
            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
            node1.collect_fold_unfolds(node2, place, stmts);
            place.projection.pop();
        }
    }

    /// Records in `stmts` the unfolds that turn a folded place at `place` into this shape.
    ///
    /// An `Unfold` is emitted only at the deepest unfolded nodes, i.e. those whose children
    /// are all folded. Presumably the checker also unfolds the intermediate levels on the way
    /// (TODO confirm). For enum variants, a `Downcast` projection is pushed before the fields.
    fn collect_unfolds(&self, place: &mut Place, stmts: &mut StatementsAt) {
        match self {
            PlaceNode::Ty(_) => {}
            PlaceNode::Deref(_, node) => {
                if node.is_ty() {
                    stmts.insert(GhostStatement::Unfold(place.clone()));
                } else {
                    place.projection.push(PlaceElem::Deref);
                    node.collect_unfolds(place, stmts);
                    place.projection.pop();
                }
            }
            PlaceNode::Downcast(.., fields)
            | PlaceNode::Closure(.., fields)
            | PlaceNode::Generator(.., fields)
            | PlaceNode::Tuple(.., fields) => {
                let all_leaves = fields.iter().all(PlaceNode::is_ty);
                if all_leaves {
                    stmts.insert(GhostStatement::Unfold(place.clone()));
                } else {
                    if let Some(idx) = self.enum_variant() {
                        place.projection.push(PlaceElem::Downcast(None, idx));
                    }
                    for (i, node) in fields.iter().enumerate() {
                        place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
                        node.collect_unfolds(place, stmts);
                        place.projection.pop();
                    }
                    if self.enum_variant().is_some() {
                        place.projection.pop();
                    }
                }
            }
        }
    }

    /// Records in `stmts` the folds needed for this node (at `place`) before returning.
    ///
    /// The pointee of every unfolded mutable reference is folded. Boxes are traversed. Other
    /// dereferences are left as they are. Aggregates are traversed, with a `Downcast`
    /// projection pushed for enum variants.
    fn collect_folds_at_ret(&self, place: &mut Place, stmts: &mut StatementsAt) {
        let fields = match self {
            PlaceNode::Deref(ty, deref_ty) => {
                place.projection.push(PlaceElem::Deref);
                if ty.is_mut_ref() {
                    stmts.insert(GhostStatement::Fold(place.clone()));
                } else if ty.is_box() {
                    deref_ty.collect_folds_at_ret(place, stmts);
                }
                place.projection.pop();
                return;
            }
            PlaceNode::Downcast(adt, _, idx, fields) => {
                if adt.is_enum() {
                    place.projection.push(PlaceElem::Downcast(None, *idx));
                }
                fields
            }
            PlaceNode::Closure(_, _, fields)
            | PlaceNode::Generator(_, _, fields)
            | PlaceNode::Tuple(_, fields) => fields,
            PlaceNode::Ty(_) => return,
        };
        for (i, node) in fields.iter().enumerate() {
            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
            node.collect_folds_at_ret(place, stmts);
            place.projection.pop();
        }
        // Undo the `Downcast` pushed above for enums.
        if let PlaceNode::Downcast(adt, ..) = self
            && adt.is_enum()
        {
            place.projection.pop();
        }
    }

    /// Returns the variant index if this node is an enum downcast; `None` otherwise
    /// (including struct downcasts).
    fn enum_variant(&self) -> Option<VariantIdx> {
        if let PlaceNode::Downcast(adt, _, idx, _) = self
            && adt.is_enum()
        {
            Some(*idx)
        } else {
            None
        }
    }

    /// Returns `true` if this node is a folded `Ty` leaf.
    #[must_use]
    fn is_ty(&self) -> bool {
        matches!(self, Self::Ty(..))
    }
}
816
817fn downcast(
818 genv: GlobalEnv,
819 adt_def: &AdtDef,
820 args: &GenericArgs,
821 variant: VariantIdx,
822) -> QueryResult<Vec<PlaceNode>> {
823 adt_def
824 .variant(variant)
825 .fields
826 .iter()
827 .map(|field| {
828 let ty = genv.lower_type_of(field.did)?.subst(args);
829 QueryResult::Ok(PlaceNode::Ty(ty))
830 })
831 .try_collect()
832}
833
834fn downcast_struct(
835 genv: GlobalEnv,
836 adt_def: &AdtDef,
837 args: &GenericArgs,
838) -> QueryResult<Vec<PlaceNode>> {
839 adt_def
840 .non_enum_variant()
841 .fields
842 .iter()
843 .map(|field| {
844 let ty = genv.lower_type_of(field.did)?.subst(args);
845 QueryResult::Ok(PlaceNode::Ty(ty))
846 })
847 .try_collect()
848}
849
850impl fmt::Debug for Env {
851 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
852 write!(
853 f,
854 "{}",
855 self.map
856 .iter_enumerated()
857 .format_with(", ", |(local, node), f| f(&format_args!("{local:?}: {node:?}")))
858 )
859 }
860}
861
862impl fmt::Debug for PlaceNode {
863 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
864 match self {
865 PlaceNode::Deref(_, node) => write!(f, "*({node:?})"),
866 PlaceNode::Downcast(adt, args, variant, fields) => {
867 write!(f, "{}", def_id_to_string(adt.did()))?;
868 if !args.is_empty() {
869 write!(f, "<{:?}>", args.iter().format(", "),)?;
870 }
871 write!(f, "::{}", adt.variant(*variant).name)?;
872 if !fields.is_empty() {
873 write!(f, "({:?})", fields.iter().format(", "),)?;
874 }
875 Ok(())
876 }
877 PlaceNode::Closure(did, args, fields) => {
878 write!(f, "Closure {}", def_id_to_string(*did))?;
879 if !args.is_empty() {
880 write!(f, "<{:?}>", args.iter().format(", "),)?;
881 }
882 if !fields.is_empty() {
883 write!(f, "({:?})", fields.iter().format(", "),)?;
884 }
885 Ok(())
886 }
887 PlaceNode::Generator(did, args, fields) => {
888 write!(f, "Generator {}", def_id_to_string(*did))?;
889 if !args.is_empty() {
890 write!(f, "<{:?}>", args.iter().format(", "),)?;
891 }
892 if !fields.is_empty() {
893 write!(f, "({:?})", fields.iter().format(", "),)?;
894 }
895 Ok(())
896 }
897 PlaceNode::Tuple(_, fields) => write!(f, "({:?})", fields.iter().format(", ")),
898 PlaceNode::Ty(ty) => write!(f, "•{ty:?}"),
899 }
900 }
901}