flux_refineck/ghost_statements/fold_unfold.rs

1use std::{collections::hash_map::Entry, fmt, iter};
2
3use flux_common::{tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_middle::{
5    PlaceExt as _, def_id_to_string, global_env::GlobalEnv, queries::QueryResult, query_bug, rty,
6};
7use flux_rustc_bridge::{
8    mir::{
9        BasicBlock, Body, BorrowKind, FIRST_VARIANT, FieldIdx, Local, Location,
10        NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
11        StatementKind, Terminator, TerminatorKind, UnOp, VariantIdx,
12    },
13    ty::{AdtDef, GenericArgs, GenericArgsExt as _, List, Mutability, Ty, TyKind},
14};
15use itertools::{Itertools, repeat_n};
16use rustc_data_structures::unord::UnordMap;
17use rustc_hash::FxHashMap;
18use rustc_hir::def_id::DefId;
19use rustc_index::{Idx, IndexVec, bit_set::DenseBitSet};
20use rustc_middle::mir::{FakeReadCause, START_BLOCK};
21
22use super::{GhostStatements, StatementsAt};
23use crate::{
24    ghost_statements::{GhostStatement, Point},
25    queue::WorkQueue,
26};
27
28pub(crate) fn add_ghost_statements<'tcx>(
29    stmts: &mut GhostStatements,
30    genv: GlobalEnv<'_, 'tcx>,
31    body: &Body<'tcx>,
32    fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>,
33) -> QueryResult {
34    let mut bb_envs = FxHashMap::default();
35    FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Infer).run(fn_sig)?;
36
37    FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Elaboration { stmts }).run(fn_sig)
38}
39
40#[derive(Clone)]
41struct Env {
42    map: IndexVec<Local, PlaceNode>,
43}
44
45impl Env {
46    fn new(body: &Body) -> Self {
47        Self {
48            map: body
49                .local_decls
50                .iter()
51                .map(|decl| PlaceNode::Ty(decl.ty.clone()))
52                .collect(),
53        }
54    }
55
56    fn projection<'a>(&mut self, genv: GlobalEnv, place: &'a Place) -> QueryResult<ProjResult<'a>> {
57        let (node, place, modified) = self.ensure_unfolded(genv, place)?;
58        if modified {
59            Ok(ProjResult::Unfold(place))
60        } else if node.ensure_folded() {
61            Ok(ProjResult::Fold(place))
62        } else {
63            Ok(ProjResult::None)
64        }
65    }
66
67    fn downcast(&mut self, genv: GlobalEnv, place: &Place, variant_idx: VariantIdx) -> QueryResult {
68        let (node, ..) = self.ensure_unfolded(genv, place)?;
69        node.downcast(genv, variant_idx)?;
70        Ok(())
71    }
72
73    fn ensure_unfolded<'a>(
74        &mut self,
75        genv: GlobalEnv,
76        place: &'a Place,
77    ) -> QueryResult<(&mut PlaceNode, PlaceRef<'a>, Modified)> {
78        let mut node = &mut self.map[place.local];
79        let mut modified = false;
80        let mut i = 0;
81        while i < place.projection.len() {
82            let elem = place.projection[i];
83            let (n, m) = match elem {
84                PlaceElem::Deref => node.deref(),
85                PlaceElem::Field(f) => node.field(genv, f)?,
86                PlaceElem::Downcast(_, idx) => node.downcast(genv, idx)?,
87                PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } => break,
88            };
89            node = n;
90            modified |= m;
91            i += 1;
92        }
93        Ok((node, place.as_ref().truncate(i), modified))
94    }
95
96    fn join(&mut self, genv: GlobalEnv, mut other: Env) -> QueryResult<Modified> {
97        let mut modified = false;
98        for (local, node) in self.map.iter_enumerated_mut() {
99            let (m, _) = node.join(genv, &mut other.map[local], false)?;
100            modified |= m;
101        }
102        Ok(modified)
103    }
104
105    fn collect_fold_unfolds_at_goto(&self, target: &Env, stmts: &mut StatementsAt) {
106        for (local, node) in self.map.iter_enumerated() {
107            node.collect_fold_unfolds(&target.map[local], &mut Place::new(local, vec![]), stmts);
108        }
109    }
110
111    fn collect_folds_at_ret(&self, body: &Body, stmts: &mut StatementsAt) {
112        for local in body.args_iter() {
113            self.map[local].collect_folds_at_ret(&mut Place::new(local, vec![]), stmts);
114        }
115    }
116}
117
118type Modified = bool;
119
120struct FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
121    genv: GlobalEnv<'genv, 'tcx>,
122    body: &'a Body<'tcx>,
123    bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
124    visited: DenseBitSet<BasicBlock>,
125    queue: WorkQueue<'a>,
126    discriminants: UnordMap<Place, Place>,
127    point: Point,
128    mode: M,
129}
130
131trait Mode: Sized {
132    const NAME: &'static str;
133
134    fn projection(
135        analysis: &mut FoldUnfoldAnalysis<Self>,
136        env: &mut Env,
137        place: &Place,
138    ) -> QueryResult;
139
140    fn goto_join_point(
141        analysis: &mut FoldUnfoldAnalysis<Self>,
142        target: BasicBlock,
143        env: Env,
144    ) -> QueryResult<bool>;
145
146    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env);
147}
148
149struct Infer;
150
151struct Elaboration<'a> {
152    stmts: &'a mut GhostStatements,
153}
154
155impl Elaboration<'_> {
156    fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
157        self.stmts.insert_at(point, stmt);
158    }
159}
160
161#[derive(Debug)]
162enum ProjResult<'a> {
163    None,
164    Fold(PlaceRef<'a>),
165    Unfold(PlaceRef<'a>),
166}
167
168impl Mode for Infer {
169    const NAME: &'static str = "infer";
170
171    fn projection(
172        analysis: &mut FoldUnfoldAnalysis<Self>,
173        env: &mut Env,
174        place: &Place,
175    ) -> QueryResult {
176        env.projection(analysis.genv, place)?;
177        Ok(())
178    }
179
180    fn goto_join_point(
181        analysis: &mut FoldUnfoldAnalysis<Self>,
182        target: BasicBlock,
183        env: Env,
184    ) -> QueryResult<bool> {
185        let modified = match analysis.bb_envs.entry(target) {
186            Entry::Occupied(mut entry) => entry.get_mut().join(analysis.genv, env)?,
187            Entry::Vacant(entry) => {
188                entry.insert(env);
189                true
190            }
191        };
192        Ok(modified)
193    }
194
195    fn ret(_: &mut FoldUnfoldAnalysis<Self>, _: &Env) {}
196}
197
198impl Mode for Elaboration<'_> {
199    const NAME: &'static str = "elaboration";
200
201    fn projection(
202        analysis: &mut FoldUnfoldAnalysis<Self>,
203        env: &mut Env,
204        place: &Place,
205    ) -> QueryResult {
206        match env.projection(analysis.genv, place)? {
207            ProjResult::None => {}
208            ProjResult::Fold(place_ref) => {
209                tracked_span_assert_eq!(place_ref, place.as_ref());
210                let place = place.clone();
211                analysis
212                    .mode
213                    .insert_at(analysis.point, GhostStatement::Fold(place));
214            }
215            ProjResult::Unfold(place_ref) => {
216                tracked_span_assert_eq!(place_ref, place.as_ref());
217                match place_ref.last_projection() {
218                    Some((base, PlaceElem::Deref | PlaceElem::Field(..))) => {
219                        analysis
220                            .mode
221                            .insert_at(analysis.point, GhostStatement::Unfold(base.to_place()));
222                    }
223                    _ => Err(query_bug!("invalid projection for unfolding {place_ref:?}"))?,
224                }
225            }
226        }
227        Ok(())
228    }
229
230    fn goto_join_point(
231        analysis: &mut FoldUnfoldAnalysis<Self>,
232        target: BasicBlock,
233        env: Env,
234    ) -> QueryResult<bool> {
235        env.collect_fold_unfolds_at_goto(
236            &analysis.bb_envs[&target],
237            &mut analysis.mode.stmts.at(analysis.point),
238        );
239        Ok(!analysis.visited.contains(target))
240    }
241
242    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env) {
243        env.collect_folds_at_ret(analysis.body, &mut analysis.mode.stmts.at(analysis.point));
244    }
245}
246
247#[derive(Clone)]
248enum PlaceNode {
249    Deref(Ty, Box<PlaceNode>),
250    Downcast(AdtDef, GenericArgs, VariantIdx, Vec<PlaceNode>),
251    Closure(DefId, GenericArgs, Vec<PlaceNode>),
252    Generator(DefId, GenericArgs, Vec<PlaceNode>),
253    Tuple(List<Ty>, Vec<PlaceNode>),
254    Ty(Ty),
255}
256
257impl<M: Mode> FoldUnfoldAnalysis<'_, '_, '_, M> {
258    fn run(mut self, fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>) -> QueryResult {
259        let mut env = Env::new(self.body);
260
261        if let Some(fn_sig) = fn_sig {
262            let fn_sig = fn_sig.as_ref().skip_binder().as_ref().skip_binder();
263            for (local, ty) in iter::zip(self.body.args_iter(), fn_sig.inputs()) {
264                if let rty::TyKind::StrgRef(..) | rty::Ref!(.., Mutability::Mut) = ty.kind() {
265                    M::projection(&mut self, &mut env, &Place::new(local, vec![PlaceElem::Deref]))?;
266                }
267            }
268        }
269        self.goto(START_BLOCK, env)?;
270        while let Some(bb) = self.queue.pop() {
271            self.basic_block(bb, self.bb_envs[&bb].clone())?;
272        }
273        Ok(())
274    }
275
276    fn basic_block(&mut self, bb: BasicBlock, mut env: Env) -> QueryResult {
277        self.visited.insert(bb);
278        let data = &self.body.basic_blocks[bb];
279        for (statement_index, stmt) in data.statements.iter().enumerate() {
280            self.point = Point::BeforeLocation(Location { block: bb, statement_index });
281            self.statement(stmt, &mut env)?;
282        }
283        if let Some(terminator) = &data.terminator {
284            self.point = Point::BeforeLocation(self.body.terminator_loc(bb));
285            let successors = self.terminator(terminator, env)?;
286            for (env, target) in successors {
287                self.point = Point::Edge(bb, target);
288                self.goto(target, env)?;
289            }
290        }
291        Ok(())
292    }
293
294    fn statement(&mut self, stmt: &Statement, env: &mut Env) -> QueryResult {
295        match &stmt.kind {
296            StatementKind::FakeRead(box (FakeReadCause::ForIndex, place)) => {
297                M::projection(self, env, place)?;
298            }
299            StatementKind::Assign(place, rvalue) => {
300                match rvalue {
301                    Rvalue::Len(place) => {
302                        M::projection(self, env, place)?;
303                    }
304                    Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(place))
305                    | Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(place)) => {
306                        let deref_place = place.deref();
307                        M::projection(self, env, &deref_place)?;
308                    }
309                    Rvalue::Use(op)
310                    | Rvalue::Cast(_, op, _)
311                    | Rvalue::UnaryOp(_, op)
312                    | Rvalue::ShallowInitBox(op, _) => {
313                        self.operand(op, env)?;
314                    }
315                    Rvalue::Ref(.., bk, place) => {
316                        // Fake borrows should not cause the place to fold
317                        if !matches!(bk, BorrowKind::Fake(_)) {
318                            M::projection(self, env, place)?;
319                        }
320                    }
321                    Rvalue::RawPtr(_, place) => {
322                        M::projection(self, env, place)?;
323                    }
324                    Rvalue::BinaryOp(_, op1, op2) => {
325                        self.operand(op1, env)?;
326                        self.operand(op2, env)?;
327                    }
328                    Rvalue::Aggregate(_, args) => {
329                        for arg in args {
330                            self.operand(arg, env)?;
331                        }
332                    }
333
334                    Rvalue::Discriminant(discr) => {
335                        M::projection(self, env, discr)?;
336                        self.discriminants.insert(place.clone(), discr.clone());
337                    }
338                    Rvalue::Repeat(op, _) => {
339                        self.operand(op, env)?;
340                    }
341                    Rvalue::NullaryOp(_, _) => {}
342                }
343                M::projection(self, env, place)?;
344            }
345            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(op)) => {
346                self.operand(op, env)?;
347            }
348            StatementKind::SetDiscriminant(_, _)
349            | StatementKind::FakeRead(_)
350            | StatementKind::AscribeUserType(_, _)
351            | StatementKind::PlaceMention(_)
352            | StatementKind::Nop => {}
353        }
354        Ok(())
355    }
356
357    fn operand(&mut self, op: &Operand, env: &mut Env) -> QueryResult {
358        match op {
359            Operand::Copy(place) | Operand::Move(place) => {
360                M::projection(self, env, place)?;
361            }
362            Operand::Constant(_) => {}
363        }
364        Ok(())
365    }
366
367    fn terminator(
368        &mut self,
369        terminator: &Terminator,
370        mut env: Env,
371    ) -> QueryResult<Vec<(Env, BasicBlock)>> {
372        let mut successors = vec![];
373        match &terminator.kind {
374            TerminatorKind::Return => {
375                M::ret(self, &env);
376            }
377            TerminatorKind::Call { args, destination, target, .. } => {
378                for arg in args {
379                    self.operand(arg, &mut env)?;
380                }
381                M::projection(self, &mut env, destination)?;
382                if let Some(target) = target {
383                    successors.push((env, *target));
384                }
385            }
386            TerminatorKind::SwitchInt { discr, targets } => {
387                let is_match = match discr {
388                    Operand::Copy(place) | Operand::Move(place) => {
389                        M::projection(self, &mut env, place)?;
390                        self.discriminants.remove(place)
391                    }
392                    Operand::Constant(_) => None,
393                };
394                if let Some(place) = is_match {
395                    let discr_ty = place.ty(self.genv, &self.body.local_decls)?.ty;
396                    let (adt, _) = discr_ty.expect_adt();
397
398                    let mut remaining: FxHashMap<u128, VariantIdx> = adt
399                        .discriminants()
400                        .map(|(idx, discr)| (discr, idx))
401                        .collect();
402                    for (bits, target) in targets.iter() {
403                        let variant_idx = remaining
404                            .remove(&bits)
405                            .expect("value doesn't correspond to any variant");
406
407                        // We do not insert unfolds in match arms because they are explicit
408                        // unfold points.
409                        let mut env = env.clone();
410                        env.downcast(self.genv, &place, variant_idx)?;
411                        successors.push((env, target));
412                    }
413                    if remaining.len() == 1 {
414                        let (_, variant_idx) = remaining
415                            .into_iter()
416                            .next()
417                            .unwrap_or_else(|| tracked_span_bug!());
418                        env.downcast(self.genv, &place, variant_idx)?;
419                    }
420                    successors.push((env, targets.otherwise()));
421                } else {
422                    let n = targets.all_targets().len();
423                    for (env, target) in iter::zip(repeat_n(env, n), targets.all_targets()) {
424                        successors.push((env, *target));
425                    }
426                }
427            }
428            TerminatorKind::Goto { target } => {
429                successors.push((env, *target));
430            }
431            TerminatorKind::Yield { resume, resume_arg, .. } => {
432                M::projection(self, &mut env, resume_arg)?;
433                successors.push((env, *resume));
434            }
435            TerminatorKind::Drop { place, target, .. } => {
436                M::projection(self, &mut env, place)?;
437                successors.push((env, *target));
438            }
439            TerminatorKind::Assert { cond, target, .. } => {
440                self.operand(cond, &mut env)?;
441                successors.push((env, *target));
442            }
443            TerminatorKind::FalseEdge { real_target, .. } => {
444                successors.push((env, *real_target));
445            }
446            TerminatorKind::FalseUnwind { real_target, .. } => {
447                successors.push((env, *real_target));
448            }
449            TerminatorKind::Unreachable
450            | TerminatorKind::UnwindResume
451            | TerminatorKind::CoroutineDrop => {}
452        }
453        Ok(successors)
454    }
455
456    fn goto(&mut self, target: BasicBlock, env: Env) -> QueryResult {
457        if self.body.is_join_point(target) {
458            if M::goto_join_point(self, target, env)? {
459                self.queue.insert(target);
460            }
461            Ok(())
462        } else {
463            self.basic_block(target, env)
464        }
465    }
466}
467
468impl<'a, 'genv, 'tcx, M> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
469    pub(crate) fn new(
470        genv: GlobalEnv<'genv, 'tcx>,
471        body: &'a Body<'tcx>,
472        bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
473        mode: M,
474    ) -> Self {
475        Self {
476            genv,
477            body,
478            bb_envs,
479            discriminants: Default::default(),
480            point: Point::FunEntry,
481            visited: DenseBitSet::new_empty(body.basic_blocks.len()),
482            queue: WorkQueue::empty(body.basic_blocks.len(), &body.dominator_order_rank),
483            mode,
484        }
485    }
486}
487
488impl PlaceNode {
489    fn deref(&mut self) -> (&mut PlaceNode, Modified) {
490        match self {
491            PlaceNode::Deref(_, node) => (node, false),
492            PlaceNode::Ty(ty) => {
493                *self = PlaceNode::Deref(ty.clone(), Box::new(PlaceNode::Ty(ty.deref())));
494                let PlaceNode::Deref(_, node) = self else { unreachable!() };
495                (node, true)
496            }
497            _ => tracked_span_bug!("deref of non-deref place: `{:?}`", self),
498        }
499    }
500
501    fn downcast(
502        &mut self,
503        genv: GlobalEnv,
504        idx: VariantIdx,
505    ) -> QueryResult<(&mut PlaceNode, Modified)> {
506        match self {
507            PlaceNode::Downcast(.., idx2, _) => {
508                debug_assert_eq!(idx, *idx2);
509                Ok((self, false))
510            }
511            PlaceNode::Ty(ty) => {
512                if let TyKind::Adt(adt_def, args) = ty.kind() {
513                    let fields = downcast(genv, adt_def, args, idx)?;
514                    *self = PlaceNode::Downcast(adt_def.clone(), args.clone(), idx, fields);
515                    Ok((self, true))
516                } else {
517                    tracked_span_bug!("invalid downcast `{self:?}`");
518                }
519            }
520            _ => tracked_span_bug!("invalid downcast `{self:?}`"),
521        }
522    }
523
524    fn field(&mut self, genv: GlobalEnv, f: FieldIdx) -> QueryResult<(&mut PlaceNode, Modified)> {
525        let (fields, unfolded) = self.fields(genv)?;
526        Ok((&mut fields[f.as_usize()], unfolded))
527    }
528
529    fn fields(&mut self, genv: GlobalEnv) -> QueryResult<(&mut Vec<PlaceNode>, bool)> {
530        match self {
531            PlaceNode::Ty(ty) => {
532                let fields = match ty.kind() {
533                    TyKind::Adt(adt_def, args) => {
534                        let fields = downcast_struct(genv, adt_def, args)?;
535                        *self = PlaceNode::Downcast(
536                            adt_def.clone(),
537                            args.clone(),
538                            FIRST_VARIANT,
539                            fields,
540                        );
541                        let PlaceNode::Downcast(.., fields) = self else { unreachable!() };
542                        fields
543                    }
544                    TyKind::Closure(def_id, args) => {
545                        let fields = args
546                            .as_closure()
547                            .upvar_tys()
548                            .iter()
549                            .cloned()
550                            .map(PlaceNode::Ty)
551                            .collect_vec();
552                        *self = PlaceNode::Closure(*def_id, args.clone(), fields);
553                        let PlaceNode::Closure(.., fields) = self else { unreachable!() };
554                        fields
555                    }
556                    TyKind::Tuple(fields) => {
557                        let node_fields = fields.iter().cloned().map(PlaceNode::Ty).collect();
558                        *self = PlaceNode::Tuple(fields.clone(), node_fields);
559                        let PlaceNode::Tuple(.., fields) = self else { unreachable!() };
560                        fields
561                    }
562                    TyKind::Coroutine(def_id, args) => {
563                        let fields = args
564                            .as_coroutine()
565                            .upvar_tys()
566                            .cloned()
567                            .map(PlaceNode::Ty)
568                            .collect_vec();
569                        *self = PlaceNode::Generator(*def_id, args.clone(), fields);
570                        let PlaceNode::Generator(.., fields) = self else { unreachable!() };
571                        fields
572                    }
573                    _ => tracked_span_bug!("implicit downcast of non-struct: `{ty:?}`"),
574                };
575                Ok((fields, true))
576            }
577            PlaceNode::Downcast(.., fields)
578            | PlaceNode::Tuple(.., fields)
579            | PlaceNode::Closure(.., fields)
580            | PlaceNode::Generator(.., fields) => Ok((fields, false)),
581            PlaceNode::Deref(..) => {
582                tracked_span_bug!("projection field of non-adt non-tuple place: `{self:?}`")
583            }
584        }
585    }
586
587    fn ensure_folded(&mut self) -> Modified {
588        match self {
589            PlaceNode::Deref(ty, _) => {
590                *self = PlaceNode::Ty(ty.clone());
591                true
592            }
593            PlaceNode::Downcast(adt, args, ..) => {
594                *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
595                true
596            }
597            PlaceNode::Closure(did, args, _) => {
598                *self = PlaceNode::Ty(Ty::mk_closure(*did, args.clone()));
599                true
600            }
601            PlaceNode::Generator(did, args, _) => {
602                *self = PlaceNode::Ty(Ty::mk_coroutine(*did, args.clone()));
603                true
604            }
605            PlaceNode::Tuple(fields, ..) => {
606                *self = PlaceNode::Ty(Ty::mk_tuple(fields.clone()));
607                true
608            }
609            PlaceNode::Ty(_) => false,
610        }
611    }
612
613    fn join(
614        &mut self,
615        genv: GlobalEnv,
616        other: &mut PlaceNode,
617        in_mut_ref: bool,
618    ) -> QueryResult<(bool, bool)> {
619        let mut modified1 = false;
620        let mut modified2 = false;
621
622        let (fields1, fields2) = match (&mut *self, &mut *other) {
623            (PlaceNode::Deref(ty1, node1), PlaceNode::Deref(ty2, node2)) => {
624                debug_assert_eq!(ty1, ty2);
625                return node1.join(genv, node2, in_mut_ref || ty1.is_mut_ref());
626            }
627            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
628            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2)) => {
629                (fields1, fields2)
630            }
631            (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
632                (fields1, fields2)
633            }
634            (
635                PlaceNode::Downcast(adt1, args1, variant1, fields1),
636                PlaceNode::Downcast(adt2, args2, variant2, fields2),
637            ) => {
638                debug_assert_eq!(adt1, adt2);
639                if variant1 == variant2 {
640                    (fields1, fields2)
641                } else {
642                    *self = PlaceNode::Ty(Ty::mk_adt(adt1.clone(), args1.clone()));
643                    *other = PlaceNode::Ty(Ty::mk_adt(adt2.clone(), args2.clone()));
644                    return Ok((true, true));
645                }
646            }
647            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return Ok((false, false)),
648            (PlaceNode::Ty(_), _) => {
649                let (m1, m2) = other.join(genv, self, in_mut_ref)?;
650                return Ok((m2, m1));
651            }
652            (PlaceNode::Deref(ty, _), _) => {
653                *self = PlaceNode::Ty(ty.clone());
654                return Ok((true, false));
655            }
656            (PlaceNode::Tuple(_, fields1), _) => {
657                let (fields2, m) = other.fields(genv)?;
658                modified2 |= m;
659                (fields1, fields2)
660            }
661            (PlaceNode::Closure(.., fields1), _) | (PlaceNode::Generator(.., fields1), _) => {
662                let (fields2, m) = other.fields(genv)?;
663                modified2 |= m;
664                (fields1, fields2)
665            }
666
667            (PlaceNode::Downcast(adt, args, .., fields1), _) => {
668                if adt.is_struct() && !in_mut_ref {
669                    let (fields2, m) = other.fields(genv)?;
670                    modified2 |= m;
671                    (fields1, fields2)
672                } else {
673                    *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
674                    return Ok((true, false));
675                }
676            }
677        };
678        for (node1, node2) in iter::zip(fields1, fields2) {
679            let (m1, m2) = node1.join(genv, node2, in_mut_ref)?;
680            modified1 |= m1;
681            modified2 |= m2;
682        }
683        Ok((modified1, modified2))
684    }
685
686    /// Collect necessary fold/unfold operations such that `self` is unfolded at the same level than `target`
687    fn collect_fold_unfolds(
688        &self,
689        target: &PlaceNode,
690        place: &mut Place,
691        stmts: &mut StatementsAt,
692    ) {
693        let (fields1, fields2) = match (self, target) {
694            (PlaceNode::Deref(_, node1), PlaceNode::Deref(_, node2)) => {
695                place.projection.push(PlaceElem::Deref);
696                node1.collect_fold_unfolds(node2, place, stmts);
697                place.projection.pop();
698                return;
699            }
700            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
701            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2))
702            | (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
703                (fields1, fields2)
704            }
705            (
706                PlaceNode::Downcast(adt1, .., idx1, fields1),
707                PlaceNode::Downcast(adt2, .., idx2, fields2),
708            ) => {
709                tracked_span_dbg_assert_eq!(adt1.did(), adt2.did());
710                tracked_span_dbg_assert_eq!(idx1, idx2);
711                (fields1, fields2)
712            }
713            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return,
714            (PlaceNode::Ty(_), _) => {
715                target.collect_unfolds(place, stmts);
716                return;
717            }
718            (_, PlaceNode::Ty(_)) => {
719                stmts.insert(GhostStatement::Fold(place.clone()));
720                return;
721            }
722            _ => tracked_span_bug!("{self:?} {target:?}"),
723        };
724        for (i, (node1, node2)) in iter::zip(fields1, fields2).enumerate() {
725            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
726            node1.collect_fold_unfolds(node2, place, stmts);
727            place.projection.pop();
728        }
729    }
730
731    fn collect_unfolds(&self, place: &mut Place, stmts: &mut StatementsAt) {
732        match self {
733            PlaceNode::Ty(_) => {}
734            PlaceNode::Deref(_, node) => {
735                if node.is_ty() {
736                    stmts.insert(GhostStatement::Unfold(place.clone()));
737                } else {
738                    place.projection.push(PlaceElem::Deref);
739                    node.collect_unfolds(place, stmts);
740                    place.projection.pop();
741                }
742            }
743            PlaceNode::Downcast(.., fields)
744            | PlaceNode::Closure(.., fields)
745            | PlaceNode::Generator(.., fields)
746            | PlaceNode::Tuple(.., fields) => {
747                let all_leaves = fields.iter().all(PlaceNode::is_ty);
748                if all_leaves {
749                    stmts.insert(GhostStatement::Unfold(place.clone()));
750                } else {
751                    if let Some(idx) = self.enum_variant() {
752                        place.projection.push(PlaceElem::Downcast(None, idx));
753                    }
754                    for (i, node) in fields.iter().enumerate() {
755                        place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
756                        node.collect_unfolds(place, stmts);
757                        place.projection.pop();
758                    }
759                    if self.enum_variant().is_some() {
760                        place.projection.pop();
761                    }
762                }
763            }
764        }
765    }
766
767    fn collect_folds_at_ret(&self, place: &mut Place, stmts: &mut StatementsAt) {
768        let fields = match self {
769            PlaceNode::Deref(ty, deref_ty) => {
770                place.projection.push(PlaceElem::Deref);
771                if ty.is_mut_ref() {
772                    stmts.insert(GhostStatement::Fold(place.clone()));
773                } else if ty.is_box() {
774                    deref_ty.collect_folds_at_ret(place, stmts);
775                }
776                place.projection.pop();
777                return;
778            }
779            PlaceNode::Downcast(adt, _, idx, fields) => {
780                if adt.is_enum() {
781                    place.projection.push(PlaceElem::Downcast(None, *idx));
782                }
783                fields
784            }
785            PlaceNode::Closure(_, _, fields)
786            | PlaceNode::Generator(_, _, fields)
787            | PlaceNode::Tuple(_, fields) => fields,
788            PlaceNode::Ty(_) => return,
789        };
790        for (i, node) in fields.iter().enumerate() {
791            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
792            node.collect_folds_at_ret(place, stmts);
793            place.projection.pop();
794        }
795        if let PlaceNode::Downcast(adt, ..) = self
796            && adt.is_enum()
797        {
798            place.projection.pop();
799        }
800    }
801
802    fn enum_variant(&self) -> Option<VariantIdx> {
803        if let PlaceNode::Downcast(adt, _, idx, _) = self
804            && adt.is_enum()
805        {
806            Some(*idx)
807        } else {
808            None
809        }
810    }
811
812    /// Returns `true` if the place node is [`Ty`].
813    ///
814    /// [`Ty`]: PlaceNode::Ty
815    #[must_use]
816    fn is_ty(&self) -> bool {
817        matches!(self, Self::Ty(..))
818    }
819}
820
821fn downcast(
822    genv: GlobalEnv,
823    adt_def: &AdtDef,
824    args: &GenericArgs,
825    variant: VariantIdx,
826) -> QueryResult<Vec<PlaceNode>> {
827    adt_def
828        .variant(variant)
829        .fields
830        .iter()
831        .map(|field| {
832            let ty = genv.lower_type_of(field.did)?.subst(args);
833            QueryResult::Ok(PlaceNode::Ty(ty))
834        })
835        .try_collect()
836}
837
838fn downcast_struct(
839    genv: GlobalEnv,
840    adt_def: &AdtDef,
841    args: &GenericArgs,
842) -> QueryResult<Vec<PlaceNode>> {
843    adt_def
844        .non_enum_variant()
845        .fields
846        .iter()
847        .map(|field| {
848            let ty = genv.lower_type_of(field.did)?.subst(args);
849            QueryResult::Ok(PlaceNode::Ty(ty))
850        })
851        .try_collect()
852}
853
854impl fmt::Debug for Env {
855    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
856        write!(
857            f,
858            "{}",
859            self.map
860                .iter_enumerated()
861                .format_with(", ", |(local, node), f| f(&format_args!("{local:?}: {node:?}")))
862        )
863    }
864}
865
866impl fmt::Debug for PlaceNode {
867    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
868        match self {
869            PlaceNode::Deref(_, node) => write!(f, "*({node:?})"),
870            PlaceNode::Downcast(adt, args, variant, fields) => {
871                write!(f, "{}", def_id_to_string(adt.did()))?;
872                if !args.is_empty() {
873                    write!(f, "<{:?}>", args.iter().format(", "),)?;
874                }
875                write!(f, "::{}", adt.variant(*variant).name)?;
876                if !fields.is_empty() {
877                    write!(f, "({:?})", fields.iter().format(", "),)?;
878                }
879                Ok(())
880            }
881            PlaceNode::Closure(did, args, fields) => {
882                write!(f, "Closure {}", def_id_to_string(*did))?;
883                if !args.is_empty() {
884                    write!(f, "<{:?}>", args.iter().format(", "),)?;
885                }
886                if !fields.is_empty() {
887                    write!(f, "({:?})", fields.iter().format(", "),)?;
888                }
889                Ok(())
890            }
891            PlaceNode::Generator(did, args, fields) => {
892                write!(f, "Generator {}", def_id_to_string(*did))?;
893                if !args.is_empty() {
894                    write!(f, "<{:?}>", args.iter().format(", "),)?;
895                }
896                if !fields.is_empty() {
897                    write!(f, "({:?})", fields.iter().format(", "),)?;
898                }
899                Ok(())
900            }
901            PlaceNode::Tuple(_, fields) => write!(f, "({:?})", fields.iter().format(", ")),
902            PlaceNode::Ty(ty) => write!(f, "•{ty:?}"),
903        }
904    }
905}