flux_refineck/ghost_statements/
fold_unfold.rs

1use std::{collections::hash_map::Entry, fmt, iter};
2
3use flux_common::{tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_middle::{
5    PlaceExt as _, def_id_to_string, global_env::GlobalEnv, queries::QueryResult, query_bug, rty,
6};
7use flux_rustc_bridge::{
8    mir::{
9        BasicBlock, Body, BorrowKind, FIRST_VARIANT, FieldIdx, Local, Location,
10        NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
11        StatementKind, Terminator, TerminatorKind, UnOp, VariantIdx,
12    },
13    ty::{AdtDef, GenericArgs, GenericArgsExt as _, List, Mutability, Ty, TyKind},
14};
15use itertools::{Itertools, repeat_n};
16use rustc_data_structures::unord::UnordMap;
17use rustc_hash::FxHashMap;
18use rustc_hir::def_id::DefId;
19use rustc_index::{Idx, IndexVec, bit_set::DenseBitSet};
20use rustc_middle::mir::{FakeReadCause, START_BLOCK};
21
22use super::{GhostStatements, StatementsAt};
23use crate::{
24    ghost_statements::{GhostStatement, Point},
25    queue::WorkQueue,
26};
27
28pub(crate) fn add_ghost_statements<'tcx>(
29    stmts: &mut GhostStatements,
30    genv: GlobalEnv<'_, 'tcx>,
31    body: &Body<'tcx>,
32    fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>,
33) -> QueryResult {
34    let mut bb_envs = FxHashMap::default();
35    FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Infer).run(fn_sig)?;
36
37    FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Elaboration { stmts }).run(fn_sig)
38}
39
/// Fold/unfold state of every local of the body at some program point.
///
/// Each local maps to a [`PlaceNode`] tree describing how far the place rooted at that local
/// is currently unfolded. A [`PlaceNode::Ty`] leaf represents a folded place.
#[derive(Clone)]
struct Env {
    map: IndexVec<Local, PlaceNode>,
}
44
45impl Env {
46    fn new(body: &Body) -> Self {
47        Self {
48            map: body
49                .local_decls
50                .iter()
51                .map(|decl| PlaceNode::Ty(decl.ty.clone()))
52                .collect(),
53        }
54    }
55
56    fn projection<'a>(&mut self, genv: GlobalEnv, place: &'a Place) -> QueryResult<ProjResult<'a>> {
57        let (node, place, modified) = self.ensure_unfolded(genv, place)?;
58        if modified {
59            Ok(ProjResult::Unfold(place))
60        } else if node.ensure_folded() {
61            Ok(ProjResult::Fold(place))
62        } else {
63            Ok(ProjResult::None)
64        }
65    }
66
67    fn downcast(&mut self, genv: GlobalEnv, place: &Place, variant_idx: VariantIdx) -> QueryResult {
68        let (node, ..) = self.ensure_unfolded(genv, place)?;
69        node.downcast(genv, variant_idx)?;
70        Ok(())
71    }
72
73    fn ensure_unfolded<'a>(
74        &mut self,
75        genv: GlobalEnv,
76        place: &'a Place,
77    ) -> QueryResult<(&mut PlaceNode, PlaceRef<'a>, Modified)> {
78        let mut node = &mut self.map[place.local];
79        let mut modified = false;
80        let mut i = 0;
81        while i < place.projection.len() {
82            let elem = place.projection[i];
83            let (n, m) = match elem {
84                PlaceElem::Deref => node.deref(),
85                PlaceElem::Field(f) => node.field(genv, f)?,
86                PlaceElem::Downcast(_, idx) => node.downcast(genv, idx)?,
87                PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } => break,
88            };
89            node = n;
90            modified |= m;
91            i += 1;
92        }
93        Ok((node, place.as_ref().truncate(i), modified))
94    }
95
96    fn join(&mut self, genv: GlobalEnv, mut other: Env) -> QueryResult<Modified> {
97        let mut modified = false;
98        for (local, node) in self.map.iter_enumerated_mut() {
99            let (m, _) = node.join(genv, &mut other.map[local], false)?;
100            modified |= m;
101        }
102        Ok(modified)
103    }
104
105    fn collect_fold_unfolds_at_goto(&self, target: &Env, stmts: &mut StatementsAt) {
106        for (local, node) in self.map.iter_enumerated() {
107            node.collect_fold_unfolds(&target.map[local], &mut Place::new(local, vec![]), stmts);
108        }
109    }
110
111    fn collect_folds_at_ret(&self, body: &Body, stmts: &mut StatementsAt) {
112        for local in body.args_iter() {
113            self.map[local].collect_folds_at_ret(&mut Place::new(local, vec![]), stmts);
114        }
115    }
116}
117
/// Whether an operation changed the fold/unfold state it was applied to.
type Modified = bool;
119
/// Forward analysis over a MIR body tracking how far each place is folded/unfolded.
///
/// The behavior at projections, join points, and returns is delegated to the mode `M`.
struct FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
    genv: GlobalEnv<'genv, 'tcx>,
    /// The body being analyzed.
    body: &'a Body<'tcx>,
    /// Environment at the entry of each join point. [`Infer`] fills this in;
    /// [`Elaboration`] reads it.
    bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
    /// Blocks already processed by `basic_block`.
    visited: DenseBitSet<BasicBlock>,
    /// Join points waiting to be processed.
    queue: WorkQueue<'a>,
    /// Maps a place assigned by `Rvalue::Discriminant` to the place whose discriminant was
    /// read, so that a subsequent `SwitchInt` on it can downcast the matched place.
    discriminants: UnordMap<Place, Place>,
    /// Program point at which ghost statements are currently being emitted.
    point: Point,
    /// The pass being run ([`Infer`] or [`Elaboration`]).
    mode: M,
}
130
/// Hooks that distinguish the passes of the fold/unfold analysis
/// ([`Infer`] and [`Elaboration`]).
trait Mode: Sized {
    /// Human-readable name of the pass (not referenced in this file).
    const _NAME: &'static str;

    /// Called whenever `place` is accessed; updates `env` so that `place` is reachable.
    fn projection(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        env: &mut Env,
        place: &Place,
    ) -> QueryResult;

    /// Called when control flows into the join point `target` with environment `env`.
    /// Returns `true` if `target` must be queued for (re)processing.
    fn goto_join_point(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        target: BasicBlock,
        env: Env,
    ) -> QueryResult<bool>;

    /// Called on a `Return` terminator with the environment at that point.
    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env);
}
148
/// First pass: computes the environment at each join point by joining incoming
/// environments. Emits no ghost statements.
struct Infer;
150
151struct Elaboration<'a> {
152    stmts: &'a mut GhostStatements,
153}
154
155impl Elaboration<'_> {
156    fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
157        self.stmts.insert_at(point, stmt);
158    }
159}
160
/// Outcome of [`Env::projection`]. Each non-`None` variant carries the prefix of the
/// projected place that was reached (projection stops at the first index projection).
#[derive(Debug)]
enum ProjResult<'a> {
    /// The environment already had the required shape; nothing changed.
    None,
    /// The path was already unfolded, but the node at the place was unfolded and has been
    /// folded back.
    Fold(PlaceRef<'a>),
    /// Some node along the path had to be unfolded.
    Unfold(PlaceRef<'a>),
}
167
168impl Mode for Infer {
169    const _NAME: &'static str = "infer";
170
171    fn projection(
172        analysis: &mut FoldUnfoldAnalysis<Self>,
173        env: &mut Env,
174        place: &Place,
175    ) -> QueryResult {
176        env.projection(analysis.genv, place)?;
177        Ok(())
178    }
179
180    fn goto_join_point(
181        analysis: &mut FoldUnfoldAnalysis<Self>,
182        target: BasicBlock,
183        env: Env,
184    ) -> QueryResult<bool> {
185        let modified = match analysis.bb_envs.entry(target) {
186            Entry::Occupied(mut entry) => entry.get_mut().join(analysis.genv, env)?,
187            Entry::Vacant(entry) => {
188                entry.insert(env);
189                true
190            }
191        };
192        Ok(modified)
193    }
194
195    fn ret(_: &mut FoldUnfoldAnalysis<Self>, _: &Env) {}
196}
197
198impl Mode for Elaboration<'_> {
199    const _NAME: &'static str = "elaboration";
200
201    fn projection(
202        analysis: &mut FoldUnfoldAnalysis<Self>,
203        env: &mut Env,
204        place: &Place,
205    ) -> QueryResult {
206        match env.projection(analysis.genv, place)? {
207            ProjResult::None => {}
208            ProjResult::Fold(place_ref) => {
209                tracked_span_assert_eq!(place_ref, place.as_ref());
210                let place = place.clone();
211                analysis
212                    .mode
213                    .insert_at(analysis.point, GhostStatement::Fold(place));
214            }
215            ProjResult::Unfold(place_ref) => {
216                tracked_span_assert_eq!(place_ref, place.as_ref());
217                match place_ref.last_projection() {
218                    Some((base, PlaceElem::Deref | PlaceElem::Field(..))) => {
219                        analysis
220                            .mode
221                            .insert_at(analysis.point, GhostStatement::Unfold(base.to_place()));
222                    }
223                    _ => Err(query_bug!("invalid projection for unfolding {place_ref:?}"))?,
224                }
225            }
226        }
227        Ok(())
228    }
229
230    fn goto_join_point(
231        analysis: &mut FoldUnfoldAnalysis<Self>,
232        target: BasicBlock,
233        env: Env,
234    ) -> QueryResult<bool> {
235        env.collect_fold_unfolds_at_goto(
236            &analysis.bb_envs[&target],
237            &mut analysis.mode.stmts.at(analysis.point),
238        );
239        Ok(!analysis.visited.contains(target))
240    }
241
242    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env) {
243        env.collect_folds_at_ret(analysis.body, &mut analysis.mode.stmts.at(analysis.point));
244    }
245}
246
/// Tree describing how far a place has been unfolded.
///
/// Leaves are [`PlaceNode::Ty`] (a folded place). Inner nodes record an unfolding, together
/// with the information needed to fold the place back into a single type.
#[derive(Clone)]
enum PlaceNode {
    /// A dereferenced place. The `Ty` is the type of the place being dereferenced (used to
    /// fold back); the boxed node describes the pointee.
    Deref(Ty, Box<PlaceNode>),
    /// An ADT unfolded into the given variant, with one node per field. Structs use
    /// `FIRST_VARIANT`.
    Downcast(AdtDef, GenericArgs, VariantIdx, Vec<PlaceNode>),
    /// A closure unfolded into one node per upvar.
    Closure(DefId, GenericArgs, Vec<PlaceNode>),
    /// A coroutine unfolded into one node per upvar.
    Generator(DefId, GenericArgs, Vec<PlaceNode>),
    /// A tuple unfolded into one node per field; the type list is kept to fold back.
    Tuple(List<Ty>, Vec<PlaceNode>),
    /// A folded place of the given type.
    Ty(Ty),
}
256
impl<M: Mode> FoldUnfoldAnalysis<'_, '_, '_, M> {
    /// Runs the analysis over the whole body in mode `M`.
    ///
    /// Every local starts folded. When `fn_sig` is available, each argument whose type is a
    /// `StrgRef` or a `&mut` reference is first projected through a dereference (`*arg`). The
    /// start block is then entered via [`Self::goto`]. Join points queued along the way are
    /// processed until the work queue is empty, each starting from its environment in
    /// `bb_envs`.
    fn run(mut self, fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>) -> QueryResult {
        let mut env = Env::new(self.body);

        if let Some(fn_sig) = fn_sig {
            let fn_sig = fn_sig.as_ref().skip_binder().as_ref().skip_binder();
            for (local, ty) in iter::zip(self.body.args_iter(), fn_sig.inputs()) {
                if let rty::TyKind::StrgRef(..) | rty::Ref!(.., Mutability::Mut) = ty.kind() {
                    M::projection(&mut self, &mut env, &Place::new(local, vec![PlaceElem::Deref]))?;
                }
            }
        }
        self.goto(START_BLOCK, env)?;
        while let Some(bb) = self.queue.pop() {
            self.basic_block(bb, self.bb_envs[&bb].clone())?;
        }
        Ok(())
    }

    /// Processes block `bb` starting from `env`.
    ///
    /// Marks `bb` as visited and walks its statements and terminator, then propagates the
    /// resulting environments to the successors. `self.point` is updated before each
    /// statement, before the terminator, and before each outgoing edge. This attaches any
    /// ghost statement emitted by the mode to that program point.
    fn basic_block(&mut self, bb: BasicBlock, mut env: Env) -> QueryResult {
        self.visited.insert(bb);
        let data = &self.body.basic_blocks[bb];
        for (statement_index, stmt) in data.statements.iter().enumerate() {
            self.point = Point::BeforeLocation(Location { block: bb, statement_index });
            self.statement(stmt, &mut env)?;
        }
        if let Some(terminator) = &data.terminator {
            self.point = Point::BeforeLocation(self.body.terminator_loc(bb));
            let successors = self.terminator(terminator, env)?;
            for (env, target) in successors {
                self.point = Point::Edge(bb, target);
                self.goto(target, env)?;
            }
        }
        Ok(())
    }

    /// Projects every place accessed by `stmt`.
    ///
    /// For assignments, the places read by the rvalue are projected before the destination.
    /// An `Rvalue::Discriminant` additionally records `destination -> discriminated place` in
    /// `self.discriminants` for use by a later `SwitchInt`.
    fn statement(&mut self, stmt: &Statement, env: &mut Env) -> QueryResult {
        match &stmt.kind {
            StatementKind::FakeRead(box (FakeReadCause::ForIndex, place)) => {
                M::projection(self, env, place)?;
            }
            StatementKind::Assign(place, rvalue) => {
                match rvalue {
                    // Reading pointer metadata accesses the pointee, so project `*place`.
                    Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(place))
                    | Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(place)) => {
                        let deref_place = place.deref();
                        M::projection(self, env, &deref_place)?;
                    }
                    Rvalue::Use(op)
                    | Rvalue::Cast(_, op, _)
                    | Rvalue::UnaryOp(_, op)
                    | Rvalue::ShallowInitBox(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::Ref(.., bk, place) => {
                        // Fake borrows should not cause the place to fold
                        if !matches!(bk, BorrowKind::Fake(_)) {
                            M::projection(self, env, place)?;
                        }
                    }
                    Rvalue::RawPtr(_, place) => {
                        M::projection(self, env, place)?;
                    }
                    Rvalue::BinaryOp(_, op1, op2) => {
                        self.operand(op1, env)?;
                        self.operand(op2, env)?;
                    }
                    Rvalue::Aggregate(_, args) => {
                        for arg in args {
                            self.operand(arg, env)?;
                        }
                    }

                    Rvalue::Discriminant(discr) => {
                        M::projection(self, env, discr)?;
                        // Remember which place this discriminant came from so that a
                        // `SwitchInt` on `place` can downcast `discr` in each arm.
                        self.discriminants.insert(place.clone(), discr.clone());
                    }
                    Rvalue::Repeat(op, _) => {
                        self.operand(op, env)?;
                    }
                }
                // The destination is projected after everything the rvalue reads.
                M::projection(self, env, place)?;
            }
            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(op)) => {
                self.operand(op, env)?;
            }
            StatementKind::SetDiscriminant(_, _)
            | StatementKind::FakeRead(_)
            | StatementKind::AscribeUserType(_, _)
            | StatementKind::PlaceMention(_)
            | StatementKind::Nop => {}
        }
        Ok(())
    }

    /// Projects the place of a `Copy`/`Move` operand; constants access no place.
    fn operand(&mut self, op: &Operand, env: &mut Env) -> QueryResult {
        match op {
            Operand::Copy(place) | Operand::Move(place) => {
                M::projection(self, env, place)?;
            }
            Operand::Constant(_) => {}
        }
        Ok(())
    }

    /// Projects the places accessed by `terminator` and returns the environment to propagate
    /// along each successor edge.
    ///
    /// A `SwitchInt` whose scrutinee was produced by `Rvalue::Discriminant` is treated as a
    /// match. Each arm's environment is downcast to the variant whose discriminant it tests.
    /// The `otherwise` edge is downcast only when exactly one variant is left uncovered.
    fn terminator(
        &mut self,
        terminator: &Terminator,
        mut env: Env,
    ) -> QueryResult<Vec<(Env, BasicBlock)>> {
        let mut successors = vec![];
        match &terminator.kind {
            TerminatorKind::Return => {
                M::ret(self, &env);
            }
            TerminatorKind::Call { args, destination, target, .. } => {
                for arg in args {
                    self.operand(arg, &mut env)?;
                }
                M::projection(self, &mut env, destination)?;
                if let Some(target) = target {
                    successors.push((env, *target));
                }
            }
            TerminatorKind::SwitchInt { discr, targets } => {
                // `Some(place)` if the scrutinee is the discriminant of `place`.
                let is_match = match discr {
                    Operand::Copy(place) | Operand::Move(place) => {
                        M::projection(self, &mut env, place)?;
                        self.discriminants.remove(place)
                    }
                    Operand::Constant(_) => None,
                };
                if let Some(place) = is_match {
                    let discr_ty = place.ty(self.genv, &self.body.local_decls)?.ty;
                    let (adt, _) = discr_ty.expect_adt();

                    // Variants not yet covered by an explicit arm, keyed by discriminant value.
                    let mut remaining: FxHashMap<u128, VariantIdx> = adt
                        .discriminants()
                        .map(|(idx, discr)| (discr, idx))
                        .collect();
                    for (bits, target) in targets.iter() {
                        let variant_idx = remaining
                            .remove(&bits)
                            .expect("value doesn't correspond to any variant");

                        // We do not insert unfolds in match arms because they are explicit
                        // unfold points.
                        let mut env = env.clone();
                        env.downcast(self.genv, &place, variant_idx)?;
                        successors.push((env, target));
                    }
                    if remaining.len() == 1 {
                        let (_, variant_idx) = remaining
                            .into_iter()
                            .next()
                            .unwrap_or_else(|| tracked_span_bug!());
                        env.downcast(self.genv, &place, variant_idx)?;
                    }
                    successors.push((env, targets.otherwise()));
                } else {
                    // Not a match: every target receives the same environment.
                    let n = targets.all_targets().len();
                    for (env, target) in iter::zip(repeat_n(env, n), targets.all_targets()) {
                        successors.push((env, *target));
                    }
                }
            }
            TerminatorKind::Goto { target } => {
                successors.push((env, *target));
            }
            TerminatorKind::Yield { resume, resume_arg, .. } => {
                M::projection(self, &mut env, resume_arg)?;
                successors.push((env, *resume));
            }
            TerminatorKind::Drop { place, target, .. } => {
                M::projection(self, &mut env, place)?;
                successors.push((env, *target));
            }
            TerminatorKind::Assert { cond, target, .. } => {
                self.operand(cond, &mut env)?;
                successors.push((env, *target));
            }
            TerminatorKind::FalseEdge { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::FalseUnwind { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::Unreachable
            | TerminatorKind::UnwindResume
            | TerminatorKind::CoroutineDrop => {}
        }
        Ok(successors)
    }

    /// Transfers control to `target` with environment `env`.
    ///
    /// Join points are delegated to `M::goto_join_point` and queued when it returns `true`.
    /// Any other block is processed immediately (recursively).
    fn goto(&mut self, target: BasicBlock, env: Env) -> QueryResult {
        if self.body.is_join_point(target) {
            if M::goto_join_point(self, target, env)? {
                self.queue.insert(target);
            }
            Ok(())
        } else {
            self.basic_block(target, env)
        }
    }
}
463
464impl<'a, 'genv, 'tcx, M> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
465    pub(crate) fn new(
466        genv: GlobalEnv<'genv, 'tcx>,
467        body: &'a Body<'tcx>,
468        bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
469        mode: M,
470    ) -> Self {
471        Self {
472            genv,
473            body,
474            bb_envs,
475            discriminants: Default::default(),
476            point: Point::FunEntry,
477            visited: DenseBitSet::new_empty(body.basic_blocks.len()),
478            queue: WorkQueue::empty(body.basic_blocks.len(), &body.dominator_order_rank),
479            mode,
480        }
481    }
482}
483
impl PlaceNode {
    /// Returns the node for `*self`, unfolding a folded node into a [`PlaceNode::Deref`] if
    /// needed. The flag is `true` when an unfold happened.
    ///
    /// Calls `tracked_span_bug!` on any other node kind.
    fn deref(&mut self) -> (&mut PlaceNode, Modified) {
        match self {
            PlaceNode::Deref(_, node) => (node, false),
            PlaceNode::Ty(ty) => {
                *self = PlaceNode::Deref(ty.clone(), Box::new(PlaceNode::Ty(ty.deref())));
                let PlaceNode::Deref(_, node) = self else { unreachable!() };
                (node, true)
            }
            _ => tracked_span_bug!("deref of non-deref place: `{:?}`", self),
        }
    }

    /// Downcasts `self` to variant `idx`.
    ///
    /// A folded ADT is unfolded into a [`PlaceNode::Downcast`] whose fields are the variant's
    /// field types. An existing `Downcast` is returned unchanged; debug builds assert that it
    /// has the same variant. The flag is `true` when an unfold happened.
    fn downcast(
        &mut self,
        genv: GlobalEnv,
        idx: VariantIdx,
    ) -> QueryResult<(&mut PlaceNode, Modified)> {
        match self {
            PlaceNode::Downcast(.., idx2, _) => {
                debug_assert_eq!(idx, *idx2);
                Ok((self, false))
            }
            PlaceNode::Ty(ty) => {
                if let TyKind::Adt(adt_def, args) = ty.kind() {
                    let fields = downcast(genv, adt_def, args, idx)?;
                    *self = PlaceNode::Downcast(adt_def.clone(), args.clone(), idx, fields);
                    Ok((self, true))
                } else {
                    tracked_span_bug!("invalid downcast `{self:?}`");
                }
            }
            _ => tracked_span_bug!("invalid downcast `{self:?}`"),
        }
    }

    /// Returns the node for field `f` of `self`, unfolding `self` if needed.
    fn field(&mut self, genv: GlobalEnv, f: FieldIdx) -> QueryResult<(&mut PlaceNode, Modified)> {
        let (fields, unfolded) = self.fields(genv)?;
        Ok((&mut fields[f.as_usize()], unfolded))
    }

    /// Returns the field nodes of `self`, plus whether `self` had to be unfolded.
    ///
    /// A folded ADT is unfolded as a struct (`FIRST_VARIANT`), closures and coroutines into
    /// their upvars, and tuples into their fields. Calls `tracked_span_bug!` on a `Deref` node
    /// or on a folded type of any other kind.
    fn fields(&mut self, genv: GlobalEnv) -> QueryResult<(&mut Vec<PlaceNode>, bool)> {
        match self {
            PlaceNode::Ty(ty) => {
                let fields = match ty.kind() {
                    TyKind::Adt(adt_def, args) => {
                        let fields = downcast_struct(genv, adt_def, args)?;
                        *self = PlaceNode::Downcast(
                            adt_def.clone(),
                            args.clone(),
                            FIRST_VARIANT,
                            fields,
                        );
                        let PlaceNode::Downcast(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Closure(def_id, args) => {
                        let fields = args
                            .as_closure()
                            .upvar_tys()
                            .iter()
                            .cloned()
                            .map(PlaceNode::Ty)
                            .collect_vec();
                        *self = PlaceNode::Closure(*def_id, args.clone(), fields);
                        let PlaceNode::Closure(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Tuple(fields) => {
                        let node_fields = fields.iter().cloned().map(PlaceNode::Ty).collect();
                        *self = PlaceNode::Tuple(fields.clone(), node_fields);
                        let PlaceNode::Tuple(.., fields) = self else { unreachable!() };
                        fields
                    }
                    TyKind::Coroutine(def_id, args) => {
                        let fields = args
                            .as_coroutine()
                            .upvar_tys()
                            .cloned()
                            .map(PlaceNode::Ty)
                            .collect_vec();
                        *self = PlaceNode::Generator(*def_id, args.clone(), fields);
                        let PlaceNode::Generator(.., fields) = self else { unreachable!() };
                        fields
                    }
                    _ => tracked_span_bug!("implicit downcast of non-struct: `{ty:?}`"),
                };
                Ok((fields, true))
            }
            PlaceNode::Downcast(.., fields)
            | PlaceNode::Tuple(.., fields)
            | PlaceNode::Closure(.., fields)
            | PlaceNode::Generator(.., fields) => Ok((fields, false)),
            PlaceNode::Deref(..) => {
                tracked_span_bug!("projection field of non-adt non-tuple place: `{self:?}`")
            }
        }
    }

    /// Collapses `self` into a folded [`PlaceNode::Ty`] node.
    ///
    /// Returns `true` if `self` was unfolded before the call.
    fn ensure_folded(&mut self) -> Modified {
        match self {
            PlaceNode::Deref(ty, _) => {
                *self = PlaceNode::Ty(ty.clone());
                true
            }
            PlaceNode::Downcast(adt, args, ..) => {
                *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
                true
            }
            PlaceNode::Closure(did, args, _) => {
                *self = PlaceNode::Ty(Ty::mk_closure(*did, args.clone()));
                true
            }
            PlaceNode::Generator(did, args, _) => {
                *self = PlaceNode::Ty(Ty::mk_coroutine(*did, args.clone()));
                true
            }
            PlaceNode::Tuple(fields, ..) => {
                *self = PlaceNode::Ty(Ty::mk_tuple(fields.clone()));
                true
            }
            PlaceNode::Ty(_) => false,
        }
    }

    /// Joins `self` and `other` in place so both end up with the same shape.
    ///
    /// Returns `(self modified, other modified)`. `in_mut_ref` becomes `true` once the walk has
    /// gone through the dereference of a mutable reference.
    ///
    /// Shape mismatches are resolved as follows:
    /// - A `Deref` facing a non-`Deref` is folded.
    /// - `Downcast`s of different variants are both folded.
    /// - A folded side facing a tuple, closure, or coroutine is unfolded.
    /// - A folded side facing a struct `Downcast` is unfolded, except under a mutable
    ///   reference, where the `Downcast` is folded instead.
    /// - A folded side facing an enum `Downcast` also causes the `Downcast` to be folded.
    fn join(
        &mut self,
        genv: GlobalEnv,
        other: &mut PlaceNode,
        in_mut_ref: bool,
    ) -> QueryResult<(bool, bool)> {
        let mut modified1 = false;
        let mut modified2 = false;

        let (fields1, fields2) = match (&mut *self, &mut *other) {
            (PlaceNode::Deref(ty1, node1), PlaceNode::Deref(ty2, node2)) => {
                debug_assert_eq!(ty1, ty2);
                return node1.join(genv, node2, in_mut_ref || ty1.is_mut_ref());
            }
            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2)) => {
                (fields1, fields2)
            }
            (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
                (fields1, fields2)
            }
            (
                PlaceNode::Downcast(adt1, args1, variant1, fields1),
                PlaceNode::Downcast(adt2, args2, variant2, fields2),
            ) => {
                debug_assert_eq!(adt1, adt2);
                if variant1 == variant2 {
                    (fields1, fields2)
                } else {
                    *self = PlaceNode::Ty(Ty::mk_adt(adt1.clone(), args1.clone()));
                    *other = PlaceNode::Ty(Ty::mk_adt(adt2.clone(), args2.clone()));
                    return Ok((true, true));
                }
            }
            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return Ok((false, false)),
            // Handle the folded side being `self` by swapping roles, then swap the flags back.
            (PlaceNode::Ty(_), _) => {
                let (m1, m2) = other.join(genv, self, in_mut_ref)?;
                return Ok((m2, m1));
            }
            (PlaceNode::Deref(ty, _), _) => {
                *self = PlaceNode::Ty(ty.clone());
                return Ok((true, false));
            }
            (PlaceNode::Tuple(_, fields1), _) => {
                let (fields2, m) = other.fields(genv)?;
                modified2 |= m;
                (fields1, fields2)
            }
            (PlaceNode::Closure(.., fields1), _) | (PlaceNode::Generator(.., fields1), _) => {
                let (fields2, m) = other.fields(genv)?;
                modified2 |= m;
                (fields1, fields2)
            }

            (PlaceNode::Downcast(adt, args, .., fields1), _) => {
                if adt.is_struct() && !in_mut_ref {
                    let (fields2, m) = other.fields(genv)?;
                    modified2 |= m;
                    (fields1, fields2)
                } else {
                    *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
                    return Ok((true, false));
                }
            }
        };
        for (node1, node2) in iter::zip(fields1, fields2) {
            let (m1, m2) = node1.join(genv, node2, in_mut_ref)?;
            modified1 |= m1;
            modified2 |= m2;
        }
        Ok((modified1, modified2))
    }

    /// Collects the fold/unfold operations needed so that `self` is unfolded to the same level
    /// as `target`.
    ///
    /// `place` is the path to `self`; it is extended while recursing and restored before
    /// returning. Where `self` is folded but `target` is not, the unfolds of `target` are
    /// collected. Where `target` is folded but `self` is not, a fold of `place` is emitted.
    fn collect_fold_unfolds(
        &self,
        target: &PlaceNode,
        place: &mut Place,
        stmts: &mut StatementsAt,
    ) {
        let (fields1, fields2) = match (self, target) {
            (PlaceNode::Deref(_, node1), PlaceNode::Deref(_, node2)) => {
                place.projection.push(PlaceElem::Deref);
                node1.collect_fold_unfolds(node2, place, stmts);
                place.projection.pop();
                return;
            }
            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2))
            | (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
                (fields1, fields2)
            }
            (
                PlaceNode::Downcast(adt1, .., idx1, fields1),
                PlaceNode::Downcast(adt2, .., idx2, fields2),
            ) => {
                tracked_span_dbg_assert_eq!(adt1.did(), adt2.did());
                tracked_span_dbg_assert_eq!(idx1, idx2);
                (fields1, fields2)
            }
            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return,
            (PlaceNode::Ty(_), _) => {
                target.collect_unfolds(place, stmts);
                return;
            }
            (_, PlaceNode::Ty(_)) => {
                stmts.insert(GhostStatement::Fold(place.clone()));
                return;
            }
            _ => tracked_span_bug!("{self:?} {target:?}"),
        };
        for (i, (node1, node2)) in iter::zip(fields1, fields2).enumerate() {
            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
            node1.collect_fold_unfolds(node2, place, stmts);
            place.projection.pop();
        }
    }

    /// Collects the unfolds needed to take the folded place `place` to the shape of `self`.
    ///
    /// A node whose children are all folded leaves yields a single `Unfold(place)`. Otherwise
    /// the children are visited, so unfolds are emitted only at the deepest unfolded nodes.
    /// For enum `Downcast`s, a `Downcast` projection is added to the path of the fields.
    fn collect_unfolds(&self, place: &mut Place, stmts: &mut StatementsAt) {
        match self {
            PlaceNode::Ty(_) => {}
            PlaceNode::Deref(_, node) => {
                if node.is_ty() {
                    stmts.insert(GhostStatement::Unfold(place.clone()));
                } else {
                    place.projection.push(PlaceElem::Deref);
                    node.collect_unfolds(place, stmts);
                    place.projection.pop();
                }
            }
            PlaceNode::Downcast(.., fields)
            | PlaceNode::Closure(.., fields)
            | PlaceNode::Generator(.., fields)
            | PlaceNode::Tuple(.., fields) => {
                let all_leaves = fields.iter().all(PlaceNode::is_ty);
                if all_leaves {
                    stmts.insert(GhostStatement::Unfold(place.clone()));
                } else {
                    if let Some(idx) = self.enum_variant() {
                        place.projection.push(PlaceElem::Downcast(None, idx));
                    }
                    for (i, node) in fields.iter().enumerate() {
                        place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
                        node.collect_unfolds(place, stmts);
                        place.projection.pop();
                    }
                    if self.enum_variant().is_some() {
                        place.projection.pop();
                    }
                }
            }
        }
    }

    /// Collects the folds needed at function return for the place `place` described by `self`.
    ///
    /// Every dereference of a mutable reference reachable through unfolded aggregates and
    /// boxes is folded. Dereferences of other pointer types are left alone.
    fn collect_folds_at_ret(&self, place: &mut Place, stmts: &mut StatementsAt) {
        let fields = match self {
            PlaceNode::Deref(ty, deref_ty) => {
                place.projection.push(PlaceElem::Deref);
                if ty.is_mut_ref() {
                    stmts.insert(GhostStatement::Fold(place.clone()));
                } else if ty.is_box() {
                    deref_ty.collect_folds_at_ret(place, stmts);
                }
                place.projection.pop();
                return;
            }
            PlaceNode::Downcast(adt, _, idx, fields) => {
                if adt.is_enum() {
                    place.projection.push(PlaceElem::Downcast(None, *idx));
                }
                fields
            }
            PlaceNode::Closure(_, _, fields)
            | PlaceNode::Generator(_, _, fields)
            | PlaceNode::Tuple(_, fields) => fields,
            PlaceNode::Ty(_) => return,
        };
        for (i, node) in fields.iter().enumerate() {
            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
            node.collect_folds_at_ret(place, stmts);
            place.projection.pop();
        }
        // Undo the `Downcast` projection pushed above for enums.
        if let PlaceNode::Downcast(adt, ..) = self
            && adt.is_enum()
        {
            place.projection.pop();
        }
    }

    /// Returns the variant index if `self` is a `Downcast` of an enum, and `None` otherwise.
    fn enum_variant(&self) -> Option<VariantIdx> {
        if let PlaceNode::Downcast(adt, _, idx, _) = self
            && adt.is_enum()
        {
            Some(*idx)
        } else {
            None
        }
    }

    /// Returns `true` if the place node is [`Ty`].
    ///
    /// [`Ty`]: PlaceNode::Ty
    #[must_use]
    fn is_ty(&self) -> bool {
        matches!(self, Self::Ty(..))
    }
}
816
817fn downcast(
818    genv: GlobalEnv,
819    adt_def: &AdtDef,
820    args: &GenericArgs,
821    variant: VariantIdx,
822) -> QueryResult<Vec<PlaceNode>> {
823    adt_def
824        .variant(variant)
825        .fields
826        .iter()
827        .map(|field| {
828            let ty = genv.lower_type_of(field.did)?.subst(args);
829            QueryResult::Ok(PlaceNode::Ty(ty))
830        })
831        .try_collect()
832}
833
834fn downcast_struct(
835    genv: GlobalEnv,
836    adt_def: &AdtDef,
837    args: &GenericArgs,
838) -> QueryResult<Vec<PlaceNode>> {
839    adt_def
840        .non_enum_variant()
841        .fields
842        .iter()
843        .map(|field| {
844            let ty = genv.lower_type_of(field.did)?.subst(args);
845            QueryResult::Ok(PlaceNode::Ty(ty))
846        })
847        .try_collect()
848}
849
850impl fmt::Debug for Env {
851    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
852        write!(
853            f,
854            "{}",
855            self.map
856                .iter_enumerated()
857                .format_with(", ", |(local, node), f| f(&format_args!("{local:?}: {node:?}")))
858        )
859    }
860}
861
862impl fmt::Debug for PlaceNode {
863    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
864        match self {
865            PlaceNode::Deref(_, node) => write!(f, "*({node:?})"),
866            PlaceNode::Downcast(adt, args, variant, fields) => {
867                write!(f, "{}", def_id_to_string(adt.did()))?;
868                if !args.is_empty() {
869                    write!(f, "<{:?}>", args.iter().format(", "),)?;
870                }
871                write!(f, "::{}", adt.variant(*variant).name)?;
872                if !fields.is_empty() {
873                    write!(f, "({:?})", fields.iter().format(", "),)?;
874                }
875                Ok(())
876            }
877            PlaceNode::Closure(did, args, fields) => {
878                write!(f, "Closure {}", def_id_to_string(*did))?;
879                if !args.is_empty() {
880                    write!(f, "<{:?}>", args.iter().format(", "),)?;
881                }
882                if !fields.is_empty() {
883                    write!(f, "({:?})", fields.iter().format(", "),)?;
884                }
885                Ok(())
886            }
887            PlaceNode::Generator(did, args, fields) => {
888                write!(f, "Generator {}", def_id_to_string(*did))?;
889                if !args.is_empty() {
890                    write!(f, "<{:?}>", args.iter().format(", "),)?;
891                }
892                if !fields.is_empty() {
893                    write!(f, "({:?})", fields.iter().format(", "),)?;
894                }
895                Ok(())
896            }
897            PlaceNode::Tuple(_, fields) => write!(f, "({:?})", fields.iter().format(", ")),
898            PlaceNode::Ty(ty) => write!(f, "•{ty:?}"),
899        }
900    }
901}