1use std::{collections::hash_map::Entry, fmt, iter};
2
3use flux_common::{tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_middle::{
5 PlaceExt as _, def_id_to_string, global_env::GlobalEnv, queries::QueryResult, query_bug, rty,
6};
7use flux_rustc_bridge::{
8 mir::{
9 BasicBlock, Body, BorrowKind, FIRST_VARIANT, FieldIdx, Local, Location,
10 NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
11 StatementKind, Terminator, TerminatorKind, UnOp, VariantIdx,
12 },
13 ty::{AdtDef, GenericArgs, GenericArgsExt as _, List, Mutability, Ty, TyKind},
14};
15use itertools::{Itertools, repeat_n};
16use rustc_data_structures::unord::UnordMap;
17use rustc_hash::FxHashMap;
18use rustc_hir::def_id::DefId;
19use rustc_index::{Idx, IndexVec, bit_set::DenseBitSet};
20use rustc_middle::mir::{FakeReadCause, START_BLOCK};
21
22use super::{GhostStatements, StatementsAt};
23use crate::{
24 ghost_statements::{GhostStatement, Point},
25 queue::WorkQueue,
26};
27
28pub(crate) fn add_ghost_statements<'tcx>(
29 stmts: &mut GhostStatements,
30 genv: GlobalEnv<'_, 'tcx>,
31 body: &Body<'tcx>,
32 fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>,
33) -> QueryResult {
34 let mut bb_envs = FxHashMap::default();
35 FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Infer).run(fn_sig)?;
36
37 FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Elaboration { stmts }).run(fn_sig)
38}
39
/// Fold/unfold state of a whole body: one `PlaceNode` tree per local, indexed by `Local`.
#[derive(Clone)]
struct Env {
    map: IndexVec<Local, PlaceNode>,
}
44
45impl Env {
46 fn new(body: &Body) -> Self {
47 Self {
48 map: body
49 .local_decls
50 .iter()
51 .map(|decl| PlaceNode::Ty(decl.ty.clone()))
52 .collect(),
53 }
54 }
55
56 fn projection<'a>(&mut self, genv: GlobalEnv, place: &'a Place) -> QueryResult<ProjResult<'a>> {
57 let (node, place, modified) = self.ensure_unfolded(genv, place)?;
58 if modified {
59 Ok(ProjResult::Unfold(place))
60 } else if node.ensure_folded() {
61 Ok(ProjResult::Fold(place))
62 } else {
63 Ok(ProjResult::None)
64 }
65 }
66
67 fn downcast(&mut self, genv: GlobalEnv, place: &Place, variant_idx: VariantIdx) -> QueryResult {
68 let (node, ..) = self.ensure_unfolded(genv, place)?;
69 node.downcast(genv, variant_idx)?;
70 Ok(())
71 }
72
73 fn ensure_unfolded<'a>(
74 &mut self,
75 genv: GlobalEnv,
76 place: &'a Place,
77 ) -> QueryResult<(&mut PlaceNode, PlaceRef<'a>, Modified)> {
78 let mut node = &mut self.map[place.local];
79 let mut modified = false;
80 let mut i = 0;
81 while i < place.projection.len() {
82 let elem = place.projection[i];
83 let (n, m) = match elem {
84 PlaceElem::Deref => node.deref(),
85 PlaceElem::Field(f) => node.field(genv, f)?,
86 PlaceElem::Downcast(_, idx) => node.downcast(genv, idx)?,
87 PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } => break,
88 };
89 node = n;
90 modified |= m;
91 i += 1;
92 }
93 Ok((node, place.as_ref().truncate(i), modified))
94 }
95
96 fn join(&mut self, genv: GlobalEnv, mut other: Env) -> QueryResult<Modified> {
97 let mut modified = false;
98 for (local, node) in self.map.iter_enumerated_mut() {
99 let (m, _) = node.join(genv, &mut other.map[local], false)?;
100 modified |= m;
101 }
102 Ok(modified)
103 }
104
105 fn collect_fold_unfolds_at_goto(&self, target: &Env, stmts: &mut StatementsAt) {
106 for (local, node) in self.map.iter_enumerated() {
107 node.collect_fold_unfolds(&target.map[local], &mut Place::new(local, vec![]), stmts);
108 }
109 }
110
111 fn collect_folds_at_ret(&self, body: &Body, stmts: &mut StatementsAt) {
112 for local in body.args_iter() {
113 self.map[local].collect_folds_at_ret(&mut Place::new(local, vec![]), stmts);
114 }
115 }
116}
117
/// Whether an operation changed the fold/unfold state of a node or environment.
type Modified = bool;
119
/// Forward analysis over a MIR body tracking which places are folded/unfolded; the mode `M`
/// decides what is done with the information (see `Mode`).
struct FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
    genv: GlobalEnv<'genv, 'tcx>,
    body: &'a Body<'tcx>,
    /// Environment at the entry of each join-point block (shared between passes).
    bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
    /// Blocks already processed by `basic_block`.
    visited: DenseBitSet<BasicBlock>,
    /// Join-point blocks waiting to be (re)processed.
    queue: WorkQueue<'a>,
    /// Maps a place assigned by `Rvalue::Discriminant` to the place whose discriminant it
    /// holds; consumed by `SwitchInt`.
    discriminants: UnordMap<Place, Place>,
    /// Program point at which ghost statements are currently being recorded.
    point: Point,
    mode: M,
}
130
/// Strategy plugged into `FoldUnfoldAnalysis`. It decides what happens at:
/// - each place access;
/// - each jump into a join point;
/// - each return.
trait Mode: Sized {
    /// Short name of the mode (`"infer"` or `"elaboration"` in this file).
    const NAME: &'static str;

    /// Called for every place accessed by a statement or terminator; updates `env` in place.
    fn projection(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        env: &mut Env,
        place: &Place,
    ) -> QueryResult;

    /// Called when control reaches join point `target` with `env`; returns whether `target`
    /// must be (re)added to the work queue.
    fn goto_join_point(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        target: BasicBlock,
        env: Env,
    ) -> QueryResult<bool>;

    /// Called on a `Return` terminator with the environment at that point.
    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env);
}
148
/// First pass: computes the environments at join points, re-queueing a join point whenever
/// its environment changes.
struct Infer;
150
/// Second pass: replays the analysis against the join-point environments computed by `Infer`
/// and records the resulting ghost statements.
struct Elaboration<'a> {
    /// Destination of the emitted ghost statements.
    stmts: &'a mut GhostStatements,
}
154
155impl Elaboration<'_> {
156 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
157 self.stmts.insert_at(point, stmt);
158 }
159}
160
/// Outcome of `Env::projection`.
///
/// The carried place is the walked prefix of the accessed place. The walk stops at the first
/// index projection, so the prefix can be shorter than the accessed place.
#[derive(Debug)]
enum ProjResult<'a> {
    /// The environment did not change.
    None,
    /// The node at the given place was unfolded and has been folded back.
    Fold(PlaceRef<'a>),
    /// Reaching the given place required unfolding at least one node.
    Unfold(PlaceRef<'a>),
}
167
168impl Mode for Infer {
169 const NAME: &'static str = "infer";
170
171 fn projection(
172 analysis: &mut FoldUnfoldAnalysis<Self>,
173 env: &mut Env,
174 place: &Place,
175 ) -> QueryResult {
176 env.projection(analysis.genv, place)?;
177 Ok(())
178 }
179
180 fn goto_join_point(
181 analysis: &mut FoldUnfoldAnalysis<Self>,
182 target: BasicBlock,
183 env: Env,
184 ) -> QueryResult<bool> {
185 let modified = match analysis.bb_envs.entry(target) {
186 Entry::Occupied(mut entry) => entry.get_mut().join(analysis.genv, env)?,
187 Entry::Vacant(entry) => {
188 entry.insert(env);
189 true
190 }
191 };
192 Ok(modified)
193 }
194
195 fn ret(_: &mut FoldUnfoldAnalysis<Self>, _: &Env) {}
196}
197
198impl Mode for Elaboration<'_> {
199 const NAME: &'static str = "elaboration";
200
201 fn projection(
202 analysis: &mut FoldUnfoldAnalysis<Self>,
203 env: &mut Env,
204 place: &Place,
205 ) -> QueryResult {
206 match env.projection(analysis.genv, place)? {
207 ProjResult::None => {}
208 ProjResult::Fold(place_ref) => {
209 tracked_span_assert_eq!(place_ref, place.as_ref());
210 let place = place.clone();
211 analysis
212 .mode
213 .insert_at(analysis.point, GhostStatement::Fold(place));
214 }
215 ProjResult::Unfold(place_ref) => {
216 tracked_span_assert_eq!(place_ref, place.as_ref());
217 match place_ref.last_projection() {
218 Some((base, PlaceElem::Deref | PlaceElem::Field(..))) => {
219 analysis
220 .mode
221 .insert_at(analysis.point, GhostStatement::Unfold(base.to_place()));
222 }
223 _ => Err(query_bug!("invalid projection for unfolding {place_ref:?}"))?,
224 }
225 }
226 }
227 Ok(())
228 }
229
230 fn goto_join_point(
231 analysis: &mut FoldUnfoldAnalysis<Self>,
232 target: BasicBlock,
233 env: Env,
234 ) -> QueryResult<bool> {
235 env.collect_fold_unfolds_at_goto(
236 &analysis.bb_envs[&target],
237 &mut analysis.mode.stmts.at(analysis.point),
238 );
239 Ok(!analysis.visited.contains(target))
240 }
241
242 fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env) {
243 env.collect_folds_at_ret(analysis.body, &mut analysis.mode.stmts.at(analysis.point));
244 }
245}
246
/// One node of the tree tracking how far a local has been unfolded.
///
/// `Ty` is a folded leaf. Every other variant records an unfolding together with the data
/// needed to fold it back into a single type (see `PlaceNode::ensure_folded`).
#[derive(Clone)]
enum PlaceNode {
    /// A dereference; the `Ty` is the type of the pointer being dereferenced.
    Deref(Ty, Box<PlaceNode>),
    /// An ADT unfolded into the fields of the given variant.
    Downcast(AdtDef, GenericArgs, VariantIdx, Vec<PlaceNode>),
    /// A closure unfolded into its upvars.
    Closure(DefId, GenericArgs, Vec<PlaceNode>),
    /// A coroutine unfolded into its upvars.
    Generator(DefId, GenericArgs, Vec<PlaceNode>),
    /// A tuple unfolded into its components.
    Tuple(List<Ty>, Vec<PlaceNode>),
    /// A folded place of the given type.
    Ty(Ty),
}
256
impl<M: Mode> FoldUnfoldAnalysis<'_, '_, '_, M> {
    /// Runs the analysis over the whole body, consuming it.
    ///
    /// If `fn_sig` is given, each argument whose signature type is a strong reference or a
    /// mutable reference is first projected through a deref (`*arg`). Then:
    /// - the entry block is processed;
    /// - every join-point block popped from the work queue is processed, starting from the
    ///   environment stored for it in `bb_envs`.
    fn run(mut self, fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>) -> QueryResult {
        let mut env = Env::new(self.body);

        if let Some(fn_sig) = fn_sig {
            let fn_sig = fn_sig.as_ref().skip_binder().as_ref().skip_binder();
            for (local, ty) in iter::zip(self.body.args_iter(), fn_sig.inputs()) {
                if let rty::TyKind::StrgRef(..) | rty::Ref!(.., Mutability::Mut) = ty.kind() {
                    M::projection(&mut self, &mut env, &Place::new(local, vec![PlaceElem::Deref]))?;
                }
            }
        }
        self.goto(START_BLOCK, env)?;
        while let Some(bb) = self.queue.pop() {
            self.basic_block(bb, self.bb_envs[&bb].clone())?;
        }
        Ok(())
    }

    /// Processes block `bb` starting from `env`, in this order:
    /// - marks `bb` visited;
    /// - applies every statement;
    /// - applies the terminator;
    /// - sends the resulting environments to the successors.
    ///
    /// `self.point` is set before each statement, before the terminator, and to the edge
    /// `bb -> target` before each jump. Ghost statements recorded by the mode are attached there.
    fn basic_block(&mut self, bb: BasicBlock, mut env: Env) -> QueryResult {
        self.visited.insert(bb);
        let data = &self.body.basic_blocks[bb];
        for (statement_index, stmt) in data.statements.iter().enumerate() {
            self.point = Point::BeforeLocation(Location { block: bb, statement_index });
            self.statement(stmt, &mut env)?;
        }
        if let Some(terminator) = &data.terminator {
            self.point = Point::BeforeLocation(self.body.terminator_loc(bb));
            let successors = self.terminator(terminator, env)?;
            for (env, target) in successors {
                self.point = Point::Edge(bb, target);
                self.goto(target, env)?;
            }
        }
        Ok(())
    }

    /// Applies the place accesses performed by `stmt` to `env`.
    ///
    /// For an assignment, the places read by the rvalue are projected first and the assigned
    /// place last. A `Discriminant` rvalue also records which place's discriminant ends up in
    /// the assigned place, for use by a later `SwitchInt`.
    fn statement(&mut self, stmt: &Statement, env: &mut Env) -> QueryResult {
        match &stmt.kind {
            StatementKind::FakeRead(box (FakeReadCause::ForIndex, place)) => {
                M::projection(self, env, place)?;
            }
            StatementKind::Assign(place, rvalue) => {
                match rvalue {
                    Rvalue::Len(place) => {
                        M::projection(self, env, place)?;
                    }
                    // For `PtrMetadata` the dereferenced place `*place` is projected, not
                    // `place` itself.
                    Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(place))
                    | Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(place)) => {
                        let deref_place = place.deref();
                        M::projection(self, env, &deref_place)?;
                    }
                    Rvalue::Use(op)
                    | Rvalue::Cast(_, op, _)
                    | Rvalue::UnaryOp(_, op)
                    | Rvalue::ShallowInitBox(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::Ref(.., bk, place) => {
                        // Fake borrows are ignored: they cause no projection.
                        if !matches!(bk, BorrowKind::Fake(_)) {
                            M::projection(self, env, place)?;
                        }
                    }
                    Rvalue::RawPtr(_, place) => {
                        M::projection(self, env, place)?;
                    }
                    Rvalue::BinaryOp(_, op1, op2) => {
                        self.operand(op1, env)?;
                        self.operand(op2, env)?;
                    }
                    Rvalue::Aggregate(_, args) => {
                        for arg in args {
                            self.operand(arg, env)?;
                        }
                    }

                    Rvalue::Discriminant(discr) => {
                        M::projection(self, env, discr)?;
                        // Remember that `place` holds the discriminant of `discr`.
                        self.discriminants.insert(place.clone(), discr.clone());
                    }
                    Rvalue::Repeat(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::NullaryOp(_, _) => {}
                }
                M::projection(self, env, place)?;
            }
            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(op)) => {
                self.operand(op, env)?;
            }
            StatementKind::SetDiscriminant(_, _)
            | StatementKind::FakeRead(_)
            | StatementKind::AscribeUserType(_, _)
            | StatementKind::PlaceMention(_)
            | StatementKind::Nop => {}
        }
        Ok(())
    }

    /// Projects the place read by `op`, if any (constants access no place).
    fn operand(&mut self, op: &Operand, env: &mut Env) -> QueryResult {
        match op {
            Operand::Copy(place) | Operand::Move(place) => {
                M::projection(self, env, place)?;
            }
            Operand::Constant(_) => {}
        }
        Ok(())
    }

    /// Applies the place accesses performed by `terminator` to `env`. Returns the environment
    /// to send along each successor edge.
    ///
    /// `Return` hands the final environment to `M::ret`.
    ///
    /// A `SwitchInt` whose operand was produced by `Rvalue::Discriminant` is handled specially:
    /// - each explicit target gets a copy of `env` in which the discriminated place is
    ///   downcast to the variant with that discriminant;
    /// - the `otherwise` target gets `env` downcast only when exactly one variant is left
    ///   unmatched.
    fn terminator(
        &mut self,
        terminator: &Terminator,
        mut env: Env,
    ) -> QueryResult<Vec<(Env, BasicBlock)>> {
        let mut successors = vec![];
        match &terminator.kind {
            TerminatorKind::Return => {
                M::ret(self, &env);
            }
            TerminatorKind::Call { args, destination, target, .. } => {
                for arg in args {
                    self.operand(arg, &mut env)?;
                }
                M::projection(self, &mut env, destination)?;
                if let Some(target) = target {
                    successors.push((env, *target));
                }
            }
            TerminatorKind::SwitchInt { discr, targets } => {
                // `Some(place)` when the switched-on value is the discriminant of `place`.
                let is_match = match discr {
                    Operand::Copy(place) | Operand::Move(place) => {
                        M::projection(self, &mut env, place)?;
                        self.discriminants.remove(place)
                    }
                    Operand::Constant(_) => None,
                };
                if let Some(place) = is_match {
                    let discr_ty = place.ty(self.genv, &self.body.local_decls)?.ty;
                    let (adt, _) = discr_ty.expect_adt();

                    // Discriminant value -> variant, for the variants not matched yet.
                    let mut remaining: FxHashMap<u128, VariantIdx> = adt
                        .discriminants()
                        .map(|(idx, discr)| (discr, idx))
                        .collect();
                    for (bits, target) in targets.iter() {
                        let variant_idx = remaining
                            .remove(&bits)
                            .expect("value doesn't correspond to any variant");

                        let mut env = env.clone();
                        env.downcast(self.genv, &place, variant_idx)?;
                        successors.push((env, target));
                    }
                    if remaining.len() == 1 {
                        let (_, variant_idx) = remaining
                            .into_iter()
                            .next()
                            .unwrap_or_else(|| tracked_span_bug!());
                        env.downcast(self.genv, &place, variant_idx)?;
                    }
                    successors.push((env, targets.otherwise()));
                } else {
                    let n = targets.all_targets().len();
                    for (env, target) in iter::zip(repeat_n(env, n), targets.all_targets()) {
                        successors.push((env, *target));
                    }
                }
            }
            TerminatorKind::Goto { target } => {
                successors.push((env, *target));
            }
            TerminatorKind::Yield { resume, resume_arg, .. } => {
                M::projection(self, &mut env, resume_arg)?;
                successors.push((env, *resume));
            }
            TerminatorKind::Drop { place, target, .. } => {
                M::projection(self, &mut env, place)?;
                successors.push((env, *target));
            }
            TerminatorKind::Assert { cond, target, .. } => {
                self.operand(cond, &mut env)?;
                successors.push((env, *target));
            }
            TerminatorKind::FalseEdge { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::FalseUnwind { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::Unreachable
            | TerminatorKind::UnwindResume
            | TerminatorKind::CoroutineDrop => {}
        }
        Ok(successors)
    }

    /// Transfers `env` to `target`.
    ///
    /// For join points, `M::goto_join_point` decides whether the block must be (re)queued.
    /// Any other block is processed immediately with `env`.
    fn goto(&mut self, target: BasicBlock, env: Env) -> QueryResult {
        if self.body.is_join_point(target) {
            if M::goto_join_point(self, target, env)? {
                self.queue.insert(target);
            }
            Ok(())
        } else {
            self.basic_block(target, env)
        }
    }
}
467
468impl<'a, 'genv, 'tcx, M> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
469 pub(crate) fn new(
470 genv: GlobalEnv<'genv, 'tcx>,
471 body: &'a Body<'tcx>,
472 bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
473 mode: M,
474 ) -> Self {
475 Self {
476 genv,
477 body,
478 bb_envs,
479 discriminants: Default::default(),
480 point: Point::FunEntry,
481 visited: DenseBitSet::new_empty(body.basic_blocks.len()),
482 queue: WorkQueue::empty(body.basic_blocks.len(), &body.dominator_order_rank),
483 mode,
484 }
485 }
486}
487
impl PlaceNode {
    /// Returns the node for the dereference of `self`, and whether `self` had to be unfolded.
    ///
    /// A folded `Ty` node is replaced by a `Deref` node whose child is the folded pointee type.
    /// Calling this on any other kind of node is a bug.
    fn deref(&mut self) -> (&mut PlaceNode, Modified) {
        match self {
            PlaceNode::Deref(_, node) => (node, false),
            PlaceNode::Ty(ty) => {
                *self = PlaceNode::Deref(ty.clone(), Box::new(PlaceNode::Ty(ty.deref())));
                let PlaceNode::Deref(_, node) = self else { unreachable!() };
                (node, true)
            }
            _ => tracked_span_bug!("deref of non-deref place: `{:?}`", self),
        }
    }

    /// Views `self` as variant `idx` of an ADT. Returns `self` and whether it was unfolded.
    ///
    /// A folded ADT node is replaced by a `Downcast` node holding the folded fields of variant
    /// `idx`. An existing `Downcast` node is returned unchanged; debug builds check that it is
    /// for the same variant.
    fn downcast(
        &mut self,
        genv: GlobalEnv,
        idx: VariantIdx,
    ) -> QueryResult<(&mut PlaceNode, Modified)> {
        match self {
            PlaceNode::Downcast(.., idx2, _) => {
                debug_assert_eq!(idx, *idx2);
                Ok((self, false))
            }
            PlaceNode::Ty(ty) => {
                if let TyKind::Adt(adt_def, args) = ty.kind() {
                    let fields = downcast(genv, adt_def, args, idx)?;
                    *self = PlaceNode::Downcast(adt_def.clone(), args.clone(), idx, fields);
                    Ok((self, true))
                } else {
                    tracked_span_bug!("invalid downcast `{self:?}`");
                }
            }
            _ => tracked_span_bug!("invalid downcast `{self:?}`"),
        }
    }

    /// Returns the node of field `f`, unfolding `self` first if needed; the flag reports
    /// whether `self` was unfolded.
    fn field(&mut self, genv: GlobalEnv, f: FieldIdx) -> QueryResult<(&mut PlaceNode, Modified)> {
        let (fields, unfolded) = self.fields(genv)?;
        Ok((&mut fields[f.as_usize()], unfolded))
    }

    /// Returns the child nodes of `self`, and whether `self` had to be unfolded to get them.
    ///
    /// A folded node is unfolded as follows:
    /// - an ADT becomes its single (non-enum) variant;
    /// - a closure or coroutine becomes its upvars;
    /// - a tuple becomes its components.
    ///
    /// A `Deref` node has no fields; asking for them is a bug.
    fn fields(&mut self, genv: GlobalEnv) -> QueryResult<(&mut Vec<PlaceNode>, bool)> {
        match self {
            PlaceNode::Ty(ty) => {
                let fields = match ty.kind() {
                    TyKind::Adt(adt_def, args) => {
                        let fields = downcast_struct(genv, adt_def, args)?;
                        *self = PlaceNode::Downcast(
                            adt_def.clone(),
                            args.clone(),
                            FIRST_VARIANT,
                            fields,
                        );
                        let PlaceNode::Downcast(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Closure(def_id, args) => {
                        let fields = args
                            .as_closure()
                            .upvar_tys()
                            .iter()
                            .cloned()
                            .map(PlaceNode::Ty)
                            .collect_vec();
                        *self = PlaceNode::Closure(*def_id, args.clone(), fields);
                        let PlaceNode::Closure(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Tuple(fields) => {
                        let node_fields = fields.iter().cloned().map(PlaceNode::Ty).collect();
                        *self = PlaceNode::Tuple(fields.clone(), node_fields);
                        let PlaceNode::Tuple(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Coroutine(def_id, args) => {
                        let fields = args
                            .as_coroutine()
                            .upvar_tys()
                            .cloned()
                            .map(PlaceNode::Ty)
                            .collect_vec();
                        *self = PlaceNode::Generator(*def_id, args.clone(), fields);
                        let PlaceNode::Generator(.., fields) = self else { unreachable!() };
                        fields
                    }
                    _ => tracked_span_bug!("implicit downcast of non-struct: `{ty:?}`"),
                };
                Ok((fields, true))
            }
            PlaceNode::Downcast(.., fields)
            | PlaceNode::Tuple(.., fields)
            | PlaceNode::Closure(.., fields)
            | PlaceNode::Generator(.., fields) => Ok((fields, false)),
            PlaceNode::Deref(..) => {
                tracked_span_bug!("projection field of non-adt non-tuple place: `{self:?}`")
            }
        }
    }

    /// Collapses an unfolded node back into a folded `Ty` node of the same type. Returns
    /// whether `self` changed (false if it was already folded).
    fn ensure_folded(&mut self) -> Modified {
        match self {
            PlaceNode::Deref(ty, _) => {
                *self = PlaceNode::Ty(ty.clone());
                true
            }
            PlaceNode::Downcast(adt, args, ..) => {
                *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
                true
            }
            PlaceNode::Closure(did, args, _) => {
                *self = PlaceNode::Ty(Ty::mk_closure(*did, args.clone()));
                true
            }
            PlaceNode::Generator(did, args, _) => {
                *self = PlaceNode::Ty(Ty::mk_coroutine(*did, args.clone()));
                true
            }
            PlaceNode::Tuple(fields, ..) => {
                *self = PlaceNode::Ty(Ty::mk_tuple(fields.clone()));
                true
            }
            PlaceNode::Ty(_) => false,
        }
    }

    /// Makes `self` and `other` agree on a common shape, modifying both in place. Returns
    /// `(self modified, other modified)`.
    ///
    /// `in_mut_ref` becomes true once a `Deref` of a mutable reference has been crossed. When a
    /// `Downcast` node meets a node of a different shape:
    /// - it is folded if it is an enum, or if `in_mut_ref` is set;
    /// - otherwise (a struct outside a mutable reference) the other side is unfolded to match.
    fn join(
        &mut self,
        genv: GlobalEnv,
        other: &mut PlaceNode,
        in_mut_ref: bool,
    ) -> QueryResult<(bool, bool)> {
        let mut modified1 = false;
        let mut modified2 = false;

        let (fields1, fields2) = match (&mut *self, &mut *other) {
            (PlaceNode::Deref(ty1, node1), PlaceNode::Deref(ty2, node2)) => {
                debug_assert_eq!(ty1, ty2);
                return node1.join(genv, node2, in_mut_ref || ty1.is_mut_ref());
            }
            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2)) => {
                (fields1, fields2)
            }
            (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
                (fields1, fields2)
            }
            (
                PlaceNode::Downcast(adt1, args1, variant1, fields1),
                PlaceNode::Downcast(adt2, args2, variant2, fields2),
            ) => {
                debug_assert_eq!(adt1, adt2);
                if variant1 == variant2 {
                    (fields1, fields2)
                } else {
                    // Different variants: fold both sides back to the ADT type.
                    *self = PlaceNode::Ty(Ty::mk_adt(adt1.clone(), args1.clone()));
                    *other = PlaceNode::Ty(Ty::mk_adt(adt2.clone(), args2.clone()));
                    return Ok((true, true));
                }
            }
            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return Ok((false, false)),
            // Symmetric case: join from the other side and swap the flags back.
            (PlaceNode::Ty(_), _) => {
                let (m1, m2) = other.join(genv, self, in_mut_ref)?;
                return Ok((m2, m1));
            }
            (PlaceNode::Deref(ty, _), _) => {
                *self = PlaceNode::Ty(ty.clone());
                return Ok((true, false));
            }
            (PlaceNode::Tuple(_, fields1), _) => {
                let (fields2, m) = other.fields(genv)?;
                modified2 |= m;
                (fields1, fields2)
            }
            (PlaceNode::Closure(.., fields1), _) | (PlaceNode::Generator(.., fields1), _) => {
                let (fields2, m) = other.fields(genv)?;
                modified2 |= m;
                (fields1, fields2)
            }

            (PlaceNode::Downcast(adt, args, .., fields1), _) => {
                if adt.is_struct() && !in_mut_ref {
                    let (fields2, m) = other.fields(genv)?;
                    modified2 |= m;
                    (fields1, fields2)
                } else {
                    *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
                    return Ok((true, false));
                }
            }
        };
        for (node1, node2) in iter::zip(fields1, fields2) {
            let (m1, m2) = node1.join(genv, node2, in_mut_ref)?;
            modified1 |= m1;
            modified2 |= m2;
        }
        Ok((modified1, modified2))
    }

    /// Emits into `stmts` the folds/unfolds that turn the shape of `self` into the shape of
    /// `target`. `place` is the path to `self`; it is pushed to and popped from while
    /// recursing, and ends up as it started.
    fn collect_fold_unfolds(
        &self,
        target: &PlaceNode,
        place: &mut Place,
        stmts: &mut StatementsAt,
    ) {
        let (fields1, fields2) = match (self, target) {
            (PlaceNode::Deref(_, node1), PlaceNode::Deref(_, node2)) => {
                place.projection.push(PlaceElem::Deref);
                node1.collect_fold_unfolds(node2, place, stmts);
                place.projection.pop();
                return;
            }
            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2))
            | (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
                (fields1, fields2)
            }
            (
                PlaceNode::Downcast(adt1, .., idx1, fields1),
                PlaceNode::Downcast(adt2, .., idx2, fields2),
            ) => {
                tracked_span_dbg_assert_eq!(adt1.did(), adt2.did());
                tracked_span_dbg_assert_eq!(idx1, idx2);
                (fields1, fields2)
            }
            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return,
            // `self` is folded but `target` is not: unfold down to `target`'s shape.
            (PlaceNode::Ty(_), _) => {
                target.collect_unfolds(place, stmts);
                return;
            }
            // `target` is folded but `self` is not: fold the whole subtree.
            (_, PlaceNode::Ty(_)) => {
                stmts.insert(GhostStatement::Fold(place.clone()));
                return;
            }
            _ => tracked_span_bug!("{self:?} {target:?}"),
        };
        for (i, (node1, node2)) in iter::zip(fields1, fields2).enumerate() {
            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
            node1.collect_fold_unfolds(node2, place, stmts);
            place.projection.pop();
        }
    }

    /// Emits into `stmts` the unfolds that take a folded `place` to the shape of `self`.
    ///
    /// An `Unfold(place)` is emitted for a node whose children are all folded leaves.
    /// Otherwise the function recurses into the children; for enum downcasts it goes through
    /// `place as Variant`. `place` is restored before returning.
    fn collect_unfolds(&self, place: &mut Place, stmts: &mut StatementsAt) {
        match self {
            PlaceNode::Ty(_) => {}
            PlaceNode::Deref(_, node) => {
                if node.is_ty() {
                    stmts.insert(GhostStatement::Unfold(place.clone()));
                } else {
                    place.projection.push(PlaceElem::Deref);
                    node.collect_unfolds(place, stmts);
                    place.projection.pop();
                }
            }
            PlaceNode::Downcast(.., fields)
            | PlaceNode::Closure(.., fields)
            | PlaceNode::Generator(.., fields)
            | PlaceNode::Tuple(.., fields) => {
                let all_leaves = fields.iter().all(PlaceNode::is_ty);
                if all_leaves {
                    stmts.insert(GhostStatement::Unfold(place.clone()));
                } else {
                    if let Some(idx) = self.enum_variant() {
                        place.projection.push(PlaceElem::Downcast(None, idx));
                    }
                    for (i, node) in fields.iter().enumerate() {
                        place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
                        node.collect_unfolds(place, stmts);
                        place.projection.pop();
                    }
                    if self.enum_variant().is_some() {
                        place.projection.pop();
                    }
                }
            }
        }
    }

    /// Emits into `stmts` the folds needed at a return for the tree rooted at `place`.
    ///
    /// Emits a `Fold(*p)` for every dereferenced mutable reference `p` in the tree. The search
    /// descends through box derefs and through the children of tuples, closures, coroutines and
    /// downcasts; dereferences of other pointer types produce nothing. `place` is restored
    /// before returning.
    fn collect_folds_at_ret(&self, place: &mut Place, stmts: &mut StatementsAt) {
        let fields = match self {
            PlaceNode::Deref(ty, deref_ty) => {
                place.projection.push(PlaceElem::Deref);
                if ty.is_mut_ref() {
                    stmts.insert(GhostStatement::Fold(place.clone()));
                } else if ty.is_box() {
                    deref_ty.collect_folds_at_ret(place, stmts);
                }
                place.projection.pop();
                return;
            }
            PlaceNode::Downcast(adt, _, idx, fields) => {
                if adt.is_enum() {
                    place.projection.push(PlaceElem::Downcast(None, *idx));
                }
                fields
            }
            PlaceNode::Closure(_, _, fields)
            | PlaceNode::Generator(_, _, fields)
            | PlaceNode::Tuple(_, fields) => fields,
            PlaceNode::Ty(_) => return,
        };
        for (i, node) in fields.iter().enumerate() {
            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
            node.collect_folds_at_ret(place, stmts);
            place.projection.pop();
        }
        // Undo the `Downcast` pushed above for enums.
        if let PlaceNode::Downcast(adt, ..) = self
            && adt.is_enum()
        {
            place.projection.pop();
        }
    }

    /// Returns the variant index if `self` is a downcast of an enum, `None` otherwise.
    fn enum_variant(&self) -> Option<VariantIdx> {
        if let PlaceNode::Downcast(adt, _, idx, _) = self
            && adt.is_enum()
        {
            Some(*idx)
        } else {
            None
        }
    }

    /// Whether `self` is a folded leaf (`PlaceNode::Ty`).
    #[must_use]
    fn is_ty(&self) -> bool {
        matches!(self, Self::Ty(..))
    }
}
820
821fn downcast(
822 genv: GlobalEnv,
823 adt_def: &AdtDef,
824 args: &GenericArgs,
825 variant: VariantIdx,
826) -> QueryResult<Vec<PlaceNode>> {
827 adt_def
828 .variant(variant)
829 .fields
830 .iter()
831 .map(|field| {
832 let ty = genv.lower_type_of(field.did)?.subst(args);
833 QueryResult::Ok(PlaceNode::Ty(ty))
834 })
835 .try_collect()
836}
837
838fn downcast_struct(
839 genv: GlobalEnv,
840 adt_def: &AdtDef,
841 args: &GenericArgs,
842) -> QueryResult<Vec<PlaceNode>> {
843 adt_def
844 .non_enum_variant()
845 .fields
846 .iter()
847 .map(|field| {
848 let ty = genv.lower_type_of(field.did)?.subst(args);
849 QueryResult::Ok(PlaceNode::Ty(ty))
850 })
851 .try_collect()
852}
853
854impl fmt::Debug for Env {
855 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
856 write!(
857 f,
858 "{}",
859 self.map
860 .iter_enumerated()
861 .format_with(", ", |(local, node), f| f(&format_args!("{local:?}: {node:?}")))
862 )
863 }
864}
865
866impl fmt::Debug for PlaceNode {
867 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
868 match self {
869 PlaceNode::Deref(_, node) => write!(f, "*({node:?})"),
870 PlaceNode::Downcast(adt, args, variant, fields) => {
871 write!(f, "{}", def_id_to_string(adt.did()))?;
872 if !args.is_empty() {
873 write!(f, "<{:?}>", args.iter().format(", "),)?;
874 }
875 write!(f, "::{}", adt.variant(*variant).name)?;
876 if !fields.is_empty() {
877 write!(f, "({:?})", fields.iter().format(", "),)?;
878 }
879 Ok(())
880 }
881 PlaceNode::Closure(did, args, fields) => {
882 write!(f, "Closure {}", def_id_to_string(*did))?;
883 if !args.is_empty() {
884 write!(f, "<{:?}>", args.iter().format(", "),)?;
885 }
886 if !fields.is_empty() {
887 write!(f, "({:?})", fields.iter().format(", "),)?;
888 }
889 Ok(())
890 }
891 PlaceNode::Generator(did, args, fields) => {
892 write!(f, "Generator {}", def_id_to_string(*did))?;
893 if !args.is_empty() {
894 write!(f, "<{:?}>", args.iter().format(", "),)?;
895 }
896 if !fields.is_empty() {
897 write!(f, "({:?})", fields.iter().format(", "),)?;
898 }
899 Ok(())
900 }
901 PlaceNode::Tuple(_, fields) => write!(f, "({:?})", fields.iter().format(", ")),
902 PlaceNode::Ty(ty) => write!(f, "•{ty:?}"),
903 }
904 }
905}