flux_refineck/ghost_statements/fold_unfold.rs

1use std::{collections::hash_map::Entry, fmt, iter};
2
3use flux_common::{tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_middle::{
5    PlaceExt as _, def_id_to_string, global_env::GlobalEnv, queries::QueryResult, query_bug, rty,
6};
7use flux_rustc_bridge::{
8    mir::{
9        BasicBlock, Body, BorrowKind, FIRST_VARIANT, FieldIdx, Local, Location,
10        NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
11        StatementKind, Terminator, TerminatorKind, UnOp, VariantIdx,
12    },
13    ty::{AdtDef, GenericArgs, GenericArgsExt as _, List, Mutability, Ty, TyKind},
14};
15use itertools::{Itertools, repeat_n};
16use rustc_data_structures::unord::UnordMap;
17use rustc_hash::FxHashMap;
18use rustc_hir::def_id::DefId;
19use rustc_index::{Idx, IndexVec, bit_set::DenseBitSet};
20use rustc_middle::mir::{FakeReadCause, START_BLOCK};
21
22use super::{GhostStatements, StatementsAt};
23use crate::{
24    ghost_statements::{GhostStatement, Point},
25    queue::WorkQueue,
26};
27
28pub(crate) fn add_ghost_statements<'tcx>(
29    stmts: &mut GhostStatements,
30    genv: GlobalEnv<'_, 'tcx>,
31    body: &Body<'tcx>,
32    fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>,
33) -> QueryResult {
34    let mut bb_envs = FxHashMap::default();
35    FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Infer).run(fn_sig)?;
36
37    FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Elaboration { stmts }).run(fn_sig)
38}
39
40#[derive(Clone)]
41struct Env {
42    map: IndexVec<Local, PlaceNode>,
43}
44
45impl Env {
46    fn new(body: &Body) -> Self {
47        Self {
48            map: body
49                .local_decls
50                .iter()
51                .map(|decl| PlaceNode::Ty(decl.ty.clone()))
52                .collect(),
53        }
54    }
55
56    fn projection<'a>(&mut self, genv: GlobalEnv, place: &'a Place) -> QueryResult<ProjResult<'a>> {
57        let (node, place, modified) = self.ensure_unfolded(genv, place)?;
58        if modified {
59            Ok(ProjResult::Unfold(place))
60        } else if node.ensure_folded() {
61            Ok(ProjResult::Fold(place))
62        } else {
63            Ok(ProjResult::None)
64        }
65    }
66
67    fn downcast(&mut self, genv: GlobalEnv, place: &Place, variant_idx: VariantIdx) -> QueryResult {
68        let (node, ..) = self.ensure_unfolded(genv, place)?;
69        node.downcast(genv, variant_idx)?;
70        Ok(())
71    }
72
73    fn ensure_unfolded<'a>(
74        &mut self,
75        genv: GlobalEnv,
76        place: &'a Place,
77    ) -> QueryResult<(&mut PlaceNode, PlaceRef<'a>, Modified)> {
78        let mut node = &mut self.map[place.local];
79        let mut modified = false;
80        let mut i = 0;
81        while i < place.projection.len() {
82            let elem = place.projection[i];
83            let (n, m) = match elem {
84                PlaceElem::Deref => node.deref(),
85                PlaceElem::Field(f) => node.field(genv, f)?,
86                PlaceElem::Downcast(_, idx) => node.downcast(genv, idx)?,
87                PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } => break,
88            };
89            node = n;
90            modified |= m;
91            i += 1;
92        }
93        Ok((node, place.as_ref().truncate(i), modified))
94    }
95
96    fn join(&mut self, genv: GlobalEnv, mut other: Env) -> QueryResult<Modified> {
97        let mut modified = false;
98        for (local, node) in self.map.iter_enumerated_mut() {
99            let (m, _) = node.join(genv, &mut other.map[local], false)?;
100            modified |= m;
101        }
102        Ok(modified)
103    }
104
105    fn collect_fold_unfolds_at_goto(&self, target: &Env, stmts: &mut StatementsAt) {
106        for (local, node) in self.map.iter_enumerated() {
107            node.collect_fold_unfolds(&target.map[local], &mut Place::new(local, vec![]), stmts);
108        }
109    }
110
111    fn collect_folds_at_ret(&self, body: &Body, stmts: &mut StatementsAt) {
112        for local in body.args_iter() {
113            self.map[local].collect_folds_at_ret(&mut Place::new(local, vec![]), stmts);
114        }
115    }
116}
117
/// Whether an operation changed the place tree it was applied to.
type Modified = bool;

/// State of one pass of the fold/unfold analysis over a MIR body.
///
/// The pass-specific behavior is supplied by `mode` (see [`Mode`]).
struct FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
    genv: GlobalEnv<'genv, 'tcx>,
    body: &'a Body<'tcx>,
    /// Environments at join points, shared between passes. `Infer` creates and joins the
    /// entries; `Elaboration` only reads them.
    bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
    /// Blocks already processed in this pass.
    visited: DenseBitSet<BasicBlock>,
    /// Join points waiting to be (re)processed.
    queue: WorkQueue<'a>,
    /// Maps the destination of an `Rvalue::Discriminant` to the place whose discriminant was
    /// read. A later `SwitchInt` on that destination can then downcast the scrutinee per arm.
    discriminants: UnordMap<Place, Place>,
    /// Program point at which ghost statements are currently recorded.
    point: Point,
    mode: M,
}
130
/// Pass-specific behavior plugged into [`FoldUnfoldAnalysis`].
///
/// [`Infer`] computes the environments at join points. [`Elaboration`] uses them to record
/// ghost fold/unfold statements.
trait Mode: Sized {
    /// Name of the mode (not referenced in this file, hence the leading underscore).
    const _NAME: &'static str;

    /// Called whenever `place` is accessed; updates `env` so that `place` is accessible.
    fn projection(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        env: &mut Env,
        place: &Place,
    ) -> QueryResult;

    /// Called when control flows into the join point `target` with environment `env`.
    ///
    /// Returns `true` if `target` should be added to the work queue.
    fn goto_join_point(
        analysis: &mut FoldUnfoldAnalysis<Self>,
        target: BasicBlock,
        env: Env,
    ) -> QueryResult<bool>;

    /// Called at a `Return` terminator with the environment at that point.
    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env);
}
148
/// First pass: computes the environment at every join point by joining the environments
/// flowing into it. Records no ghost statements.
struct Infer;

/// Second pass: replays the analysis against the join-point environments computed by
/// [`Infer`] and records fold/unfold ghost statements into `stmts`.
struct Elaboration<'a> {
    stmts: &'a mut GhostStatements,
}

impl Elaboration<'_> {
    /// Records the ghost statement `stmt` at `point`.
    fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
        self.stmts.insert_at(point, stmt);
    }
}

/// Outcome of [`Env::projection`].
#[derive(Debug)]
enum ProjResult<'a> {
    /// The place was already accessible; nothing changed.
    None,
    /// The node at this place had been unfolded further down and was folded back.
    Fold(PlaceRef<'a>),
    /// Reaching this place required unfolding some prefix of it.
    Unfold(PlaceRef<'a>),
}
167
168impl Mode for Infer {
169    const _NAME: &'static str = "infer";
170
171    fn projection(
172        analysis: &mut FoldUnfoldAnalysis<Self>,
173        env: &mut Env,
174        place: &Place,
175    ) -> QueryResult {
176        env.projection(analysis.genv, place)?;
177        Ok(())
178    }
179
180    fn goto_join_point(
181        analysis: &mut FoldUnfoldAnalysis<Self>,
182        target: BasicBlock,
183        env: Env,
184    ) -> QueryResult<bool> {
185        let modified = match analysis.bb_envs.entry(target) {
186            Entry::Occupied(mut entry) => entry.get_mut().join(analysis.genv, env)?,
187            Entry::Vacant(entry) => {
188                entry.insert(env);
189                true
190            }
191        };
192        Ok(modified)
193    }
194
195    fn ret(_: &mut FoldUnfoldAnalysis<Self>, _: &Env) {}
196}
197
198impl Mode for Elaboration<'_> {
199    const _NAME: &'static str = "elaboration";
200
201    fn projection(
202        analysis: &mut FoldUnfoldAnalysis<Self>,
203        env: &mut Env,
204        place: &Place,
205    ) -> QueryResult {
206        match env.projection(analysis.genv, place)? {
207            ProjResult::None => {}
208            ProjResult::Fold(place_ref) => {
209                tracked_span_assert_eq!(place_ref, place.as_ref());
210                let place = place.clone();
211                analysis
212                    .mode
213                    .insert_at(analysis.point, GhostStatement::Fold(place));
214            }
215            ProjResult::Unfold(place_ref) => {
216                tracked_span_assert_eq!(place_ref, place.as_ref());
217                match place_ref.last_projection() {
218                    Some((base, PlaceElem::Deref | PlaceElem::Field(..))) => {
219                        analysis
220                            .mode
221                            .insert_at(analysis.point, GhostStatement::Unfold(base.to_place()));
222                    }
223                    _ => Err(query_bug!("invalid projection for unfolding {place_ref:?}"))?,
224                }
225            }
226        }
227        Ok(())
228    }
229
230    fn goto_join_point(
231        analysis: &mut FoldUnfoldAnalysis<Self>,
232        target: BasicBlock,
233        env: Env,
234    ) -> QueryResult<bool> {
235        env.collect_fold_unfolds_at_goto(
236            &analysis.bb_envs[&target],
237            &mut analysis.mode.stmts.at(analysis.point),
238        );
239        Ok(!analysis.visited.contains(target))
240    }
241
242    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env) {
243        env.collect_folds_at_ret(analysis.body, &mut analysis.mode.stmts.at(analysis.point));
244    }
245}
246
/// Tree describing how far a place has been unfolded.
///
/// Leaves ([`PlaceNode::Ty`]) are folded places. Inner nodes record one level of unfolding,
/// together with what is needed to fold it back into a type (see `ensure_folded`).
#[derive(Clone)]
enum PlaceNode {
    /// Unfolded dereference: the type being dereferenced and the node for the pointee.
    Deref(Ty, Box<PlaceNode>),
    /// ADT unfolded to one variant (structs use `FIRST_VARIANT`), one node per field.
    Downcast(AdtDef, GenericArgs, VariantIdx, Vec<PlaceNode>),
    /// Closure unfolded into one node per upvar.
    Closure(DefId, GenericArgs, Vec<PlaceNode>),
    /// Coroutine unfolded into one node per upvar.
    Generator(DefId, GenericArgs, Vec<PlaceNode>),
    /// Tuple unfolded into one node per field.
    Tuple(List<Ty>, Vec<PlaceNode>),
    /// Folded place of the given type.
    Ty(Ty),
}
256
impl<M: Mode> FoldUnfoldAnalysis<'_, '_, '_, M> {
    /// Runs this pass over the whole body.
    ///
    /// First, `*arg` is projected for every argument whose signature type is a strong
    /// reference or a `&mut`, so those pointees are unfolded at function entry. Then the CFG
    /// is walked from `START_BLOCK`, and queued join points are processed until the work
    /// queue is empty.
    fn run(mut self, fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>) -> QueryResult {
        let mut env = Env::new(self.body);

        if let Some(fn_sig) = fn_sig {
            let fn_sig = fn_sig.as_ref().skip_binder().as_ref().skip_binder();
            for (local, ty) in iter::zip(self.body.args_iter(), fn_sig.inputs()) {
                if let rty::TyKind::StrgRef(..) | rty::Ref!(.., Mutability::Mut) = ty.kind() {
                    M::projection(&mut self, &mut env, &Place::new(local, vec![PlaceElem::Deref]))?;
                }
            }
        }
        self.goto(START_BLOCK, env)?;
        // Join points start from the environment stored for them in `bb_envs`.
        while let Some(bb) = self.queue.pop() {
            self.basic_block(bb, self.bb_envs[&bb].clone())?;
        }
        Ok(())
    }

    /// Processes `bb` starting from `env`, then hands the resulting environments to the
    /// successors through `goto`.
    ///
    /// `self.point` is updated before each statement, before the terminator, and for each
    /// outgoing edge, so the mode records ghost statements at the right location.
    fn basic_block(&mut self, bb: BasicBlock, mut env: Env) -> QueryResult {
        self.visited.insert(bb);
        let data = &self.body.basic_blocks[bb];
        for (statement_index, stmt) in data.statements.iter().enumerate() {
            self.point = Point::BeforeLocation(Location { block: bb, statement_index });
            self.statement(stmt, &mut env)?;
        }
        if let Some(terminator) = &data.terminator {
            self.point = Point::BeforeLocation(self.body.terminator_loc(bb));
            let successors = self.terminator(terminator, env)?;
            for (env, target) in successors {
                self.point = Point::Edge(bb, target);
                self.goto(target, env)?;
            }
        }
        Ok(())
    }

    /// Applies the place accesses performed by `stmt` to `env`.
    ///
    /// For assignments, the places read by the rvalue are projected before the destination.
    fn statement(&mut self, stmt: &Statement, env: &mut Env) -> QueryResult {
        match &stmt.kind {
            StatementKind::FakeRead(box (FakeReadCause::ForIndex, place)) => {
                M::projection(self, env, place)?;
            }
            StatementKind::Assign(place, rvalue) => {
                match rvalue {
                    // Reading pointer metadata accesses the pointee, so project `*place`.
                    Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(place))
                    | Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(place)) => {
                        let deref_place = place.deref();
                        M::projection(self, env, &deref_place)?;
                    }
                    Rvalue::Use(op)
                    | Rvalue::Cast(_, op, _)
                    | Rvalue::UnaryOp(_, op)
                    | Rvalue::ShallowInitBox(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::Ref(.., bk, place) => {
                        // Fake borrows should not cause the place to fold
                        if !matches!(bk, BorrowKind::Fake(_)) {
                            M::projection(self, env, place)?;
                        }
                    }
                    Rvalue::RawPtr(_, place) => {
                        M::projection(self, env, place)?;
                    }
                    Rvalue::BinaryOp(_, op1, op2) => {
                        self.operand(op1, env)?;
                        self.operand(op2, env)?;
                    }
                    Rvalue::Aggregate(_, args) => {
                        for arg in args {
                            self.operand(arg, env)?;
                        }
                    }

                    // Remember which place the discriminant came from so that a later
                    // `SwitchInt` on the destination can downcast it in each arm.
                    Rvalue::Discriminant(discr) => {
                        M::projection(self, env, discr)?;
                        self.discriminants.insert(place.clone(), discr.clone());
                    }
                    Rvalue::Repeat(op, _) => {
                        self.operand(op, env)?;
                    }
                    Rvalue::NullaryOp(_, _) => {}
                }
                // `place` here is the assignment destination (the arm bindings above are scoped).
                M::projection(self, env, place)?;
            }
            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(op)) => {
                self.operand(op, env)?;
            }
            StatementKind::SetDiscriminant(_, _)
            | StatementKind::FakeRead(_)
            | StatementKind::AscribeUserType(_, _)
            | StatementKind::PlaceMention(_)
            | StatementKind::Nop => {}
        }
        Ok(())
    }

    /// Projects the place read by `op`, if any; constants access no place.
    fn operand(&mut self, op: &Operand, env: &mut Env) -> QueryResult {
        match op {
            Operand::Copy(place) | Operand::Move(place) => {
                M::projection(self, env, place)?;
            }
            Operand::Constant(_) => {}
        }
        Ok(())
    }

    /// Applies the place accesses performed by `terminator` and returns, for each successor
    /// that is followed, the environment to enter it with.
    ///
    /// Only the targets matched below are followed. Fields skipped with `..` (for example
    /// unwind targets, if the terminator has them) are ignored.
    fn terminator(
        &mut self,
        terminator: &Terminator,
        mut env: Env,
    ) -> QueryResult<Vec<(Env, BasicBlock)>> {
        let mut successors = vec![];
        match &terminator.kind {
            TerminatorKind::Return => {
                M::ret(self, &env);
            }
            TerminatorKind::Call { args, destination, target, .. } => {
                for arg in args {
                    self.operand(arg, &mut env)?;
                }
                M::projection(self, &mut env, destination)?;
                if let Some(target) = target {
                    successors.push((env, *target));
                }
            }
            TerminatorKind::SwitchInt { discr, targets } => {
                // `Some(p)` when switching on a value that was assigned `Discriminant(p)`,
                // i.e. this switch is a `match` on `p`.
                let is_match = match discr {
                    Operand::Copy(place) | Operand::Move(place) => {
                        M::projection(self, &mut env, place)?;
                        self.discriminants.remove(place)
                    }
                    Operand::Constant(_) => None,
                };
                if let Some(place) = is_match {
                    let discr_ty = place.ty(self.genv, &self.body.local_decls)?.ty;
                    let (adt, _) = discr_ty.expect_adt();

                    // Variants not yet assigned to an explicit target, keyed by discriminant.
                    let mut remaining: FxHashMap<u128, VariantIdx> = adt
                        .discriminants()
                        .map(|(idx, discr)| (discr, idx))
                        .collect();
                    for (bits, target) in targets.iter() {
                        let variant_idx = remaining
                            .remove(&bits)
                            .expect("value doesn't correspond to any variant");

                        // We do not insert unfolds in match arms because they are explicit
                        // unfold points.
                        let mut env = env.clone();
                        env.downcast(self.genv, &place, variant_idx)?;
                        successors.push((env, target));
                    }
                    // If exactly one variant is left, the `otherwise` edge can only be taken
                    // for that variant, so downcast to it as well.
                    if remaining.len() == 1 {
                        let (_, variant_idx) = remaining
                            .into_iter()
                            .next()
                            .unwrap_or_else(|| tracked_span_bug!());
                        env.downcast(self.genv, &place, variant_idx)?;
                    }
                    successors.push((env, targets.otherwise()));
                } else {
                    // Plain integer switch: every target gets an unchanged copy of `env`.
                    let n = targets.all_targets().len();
                    for (env, target) in iter::zip(repeat_n(env, n), targets.all_targets()) {
                        successors.push((env, *target));
                    }
                }
            }
            TerminatorKind::Goto { target } => {
                successors.push((env, *target));
            }
            TerminatorKind::Yield { resume, resume_arg, .. } => {
                M::projection(self, &mut env, resume_arg)?;
                successors.push((env, *resume));
            }
            TerminatorKind::Drop { place, target, .. } => {
                M::projection(self, &mut env, place)?;
                successors.push((env, *target));
            }
            TerminatorKind::Assert { cond, target, .. } => {
                self.operand(cond, &mut env)?;
                successors.push((env, *target));
            }
            TerminatorKind::FalseEdge { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::FalseUnwind { real_target, .. } => {
                successors.push((env, *real_target));
            }
            TerminatorKind::Unreachable
            | TerminatorKind::UnwindResume
            | TerminatorKind::CoroutineDrop => {}
        }
        Ok(successors)
    }

    /// Transfers control to `target` with environment `env`.
    ///
    /// A block that is not a join point is processed immediately (recursively). For a join
    /// point, the mode decides whether `target` is added to the work queue.
    fn goto(&mut self, target: BasicBlock, env: Env) -> QueryResult {
        if self.body.is_join_point(target) {
            if M::goto_join_point(self, target, env)? {
                self.queue.insert(target);
            }
            Ok(())
        } else {
            self.basic_block(target, env)
        }
    }
}
464
465impl<'a, 'genv, 'tcx, M> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
466    pub(crate) fn new(
467        genv: GlobalEnv<'genv, 'tcx>,
468        body: &'a Body<'tcx>,
469        bb_envs: &'a mut FxHashMap<BasicBlock, Env>,
470        mode: M,
471    ) -> Self {
472        Self {
473            genv,
474            body,
475            bb_envs,
476            discriminants: Default::default(),
477            point: Point::FunEntry,
478            visited: DenseBitSet::new_empty(body.basic_blocks.len()),
479            queue: WorkQueue::empty(body.basic_blocks.len(), &body.dominator_order_rank),
480            mode,
481        }
482    }
483}
484
485impl PlaceNode {
486    fn deref(&mut self) -> (&mut PlaceNode, Modified) {
487        match self {
488            PlaceNode::Deref(_, node) => (node, false),
489            PlaceNode::Ty(ty) => {
490                *self = PlaceNode::Deref(ty.clone(), Box::new(PlaceNode::Ty(ty.deref())));
491                let PlaceNode::Deref(_, node) = self else { unreachable!() };
492                (node, true)
493            }
494            _ => tracked_span_bug!("deref of non-deref place: `{:?}`", self),
495        }
496    }
497
498    fn downcast(
499        &mut self,
500        genv: GlobalEnv,
501        idx: VariantIdx,
502    ) -> QueryResult<(&mut PlaceNode, Modified)> {
503        match self {
504            PlaceNode::Downcast(.., idx2, _) => {
505                debug_assert_eq!(idx, *idx2);
506                Ok((self, false))
507            }
508            PlaceNode::Ty(ty) => {
509                if let TyKind::Adt(adt_def, args) = ty.kind() {
510                    let fields = downcast(genv, adt_def, args, idx)?;
511                    *self = PlaceNode::Downcast(adt_def.clone(), args.clone(), idx, fields);
512                    Ok((self, true))
513                } else {
514                    tracked_span_bug!("invalid downcast `{self:?}`");
515                }
516            }
517            _ => tracked_span_bug!("invalid downcast `{self:?}`"),
518        }
519    }
520
521    fn field(&mut self, genv: GlobalEnv, f: FieldIdx) -> QueryResult<(&mut PlaceNode, Modified)> {
522        let (fields, unfolded) = self.fields(genv)?;
523        Ok((&mut fields[f.as_usize()], unfolded))
524    }
525
526    fn fields(&mut self, genv: GlobalEnv) -> QueryResult<(&mut Vec<PlaceNode>, bool)> {
527        match self {
528            PlaceNode::Ty(ty) => {
529                let fields = match ty.kind() {
530                    TyKind::Adt(adt_def, args) => {
531                        let fields = downcast_struct(genv, adt_def, args)?;
532                        *self = PlaceNode::Downcast(
533                            adt_def.clone(),
534                            args.clone(),
535                            FIRST_VARIANT,
536                            fields,
537                        );
538                        let PlaceNode::Downcast(.., fields) = self else { unreachable!() };
539                        fields
540                    }
541                    TyKind::Closure(def_id, args) => {
542                        let fields = args
543                            .as_closure()
544                            .upvar_tys()
545                            .iter()
546                            .cloned()
547                            .map(PlaceNode::Ty)
548                            .collect_vec();
549                        *self = PlaceNode::Closure(*def_id, args.clone(), fields);
550                        let PlaceNode::Closure(.., fields) = self else { unreachable!() };
551                        fields
552                    }
553                    TyKind::Tuple(fields) => {
554                        let node_fields = fields.iter().cloned().map(PlaceNode::Ty).collect();
555                        *self = PlaceNode::Tuple(fields.clone(), node_fields);
556                        let PlaceNode::Tuple(.., fields) = self else { unreachable!() };
557                        fields
558                    }
559                    TyKind::Coroutine(def_id, args) => {
560                        let fields = args
561                            .as_coroutine()
562                            .upvar_tys()
563                            .cloned()
564                            .map(PlaceNode::Ty)
565                            .collect_vec();
566                        *self = PlaceNode::Generator(*def_id, args.clone(), fields);
567                        let PlaceNode::Generator(.., fields) = self else { unreachable!() };
568                        fields
569                    }
570                    _ => tracked_span_bug!("implicit downcast of non-struct: `{ty:?}`"),
571                };
572                Ok((fields, true))
573            }
574            PlaceNode::Downcast(.., fields)
575            | PlaceNode::Tuple(.., fields)
576            | PlaceNode::Closure(.., fields)
577            | PlaceNode::Generator(.., fields) => Ok((fields, false)),
578            PlaceNode::Deref(..) => {
579                tracked_span_bug!("projection field of non-adt non-tuple place: `{self:?}`")
580            }
581        }
582    }
583
584    fn ensure_folded(&mut self) -> Modified {
585        match self {
586            PlaceNode::Deref(ty, _) => {
587                *self = PlaceNode::Ty(ty.clone());
588                true
589            }
590            PlaceNode::Downcast(adt, args, ..) => {
591                *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
592                true
593            }
594            PlaceNode::Closure(did, args, _) => {
595                *self = PlaceNode::Ty(Ty::mk_closure(*did, args.clone()));
596                true
597            }
598            PlaceNode::Generator(did, args, _) => {
599                *self = PlaceNode::Ty(Ty::mk_coroutine(*did, args.clone()));
600                true
601            }
602            PlaceNode::Tuple(fields, ..) => {
603                *self = PlaceNode::Ty(Ty::mk_tuple(fields.clone()));
604                true
605            }
606            PlaceNode::Ty(_) => false,
607        }
608    }
609
610    fn join(
611        &mut self,
612        genv: GlobalEnv,
613        other: &mut PlaceNode,
614        in_mut_ref: bool,
615    ) -> QueryResult<(bool, bool)> {
616        let mut modified1 = false;
617        let mut modified2 = false;
618
619        let (fields1, fields2) = match (&mut *self, &mut *other) {
620            (PlaceNode::Deref(ty1, node1), PlaceNode::Deref(ty2, node2)) => {
621                debug_assert_eq!(ty1, ty2);
622                return node1.join(genv, node2, in_mut_ref || ty1.is_mut_ref());
623            }
624            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
625            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2)) => {
626                (fields1, fields2)
627            }
628            (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
629                (fields1, fields2)
630            }
631            (
632                PlaceNode::Downcast(adt1, args1, variant1, fields1),
633                PlaceNode::Downcast(adt2, args2, variant2, fields2),
634            ) => {
635                debug_assert_eq!(adt1, adt2);
636                if variant1 == variant2 {
637                    (fields1, fields2)
638                } else {
639                    *self = PlaceNode::Ty(Ty::mk_adt(adt1.clone(), args1.clone()));
640                    *other = PlaceNode::Ty(Ty::mk_adt(adt2.clone(), args2.clone()));
641                    return Ok((true, true));
642                }
643            }
644            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return Ok((false, false)),
645            (PlaceNode::Ty(_), _) => {
646                let (m1, m2) = other.join(genv, self, in_mut_ref)?;
647                return Ok((m2, m1));
648            }
649            (PlaceNode::Deref(ty, _), _) => {
650                *self = PlaceNode::Ty(ty.clone());
651                return Ok((true, false));
652            }
653            (PlaceNode::Tuple(_, fields1), _) => {
654                let (fields2, m) = other.fields(genv)?;
655                modified2 |= m;
656                (fields1, fields2)
657            }
658            (PlaceNode::Closure(.., fields1), _) | (PlaceNode::Generator(.., fields1), _) => {
659                let (fields2, m) = other.fields(genv)?;
660                modified2 |= m;
661                (fields1, fields2)
662            }
663
664            (PlaceNode::Downcast(adt, args, .., fields1), _) => {
665                if adt.is_struct() && !in_mut_ref {
666                    let (fields2, m) = other.fields(genv)?;
667                    modified2 |= m;
668                    (fields1, fields2)
669                } else {
670                    *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
671                    return Ok((true, false));
672                }
673            }
674        };
675        for (node1, node2) in iter::zip(fields1, fields2) {
676            let (m1, m2) = node1.join(genv, node2, in_mut_ref)?;
677            modified1 |= m1;
678            modified2 |= m2;
679        }
680        Ok((modified1, modified2))
681    }
682
683    /// Collect necessary fold/unfold operations such that `self` is unfolded at the same level than `target`
684    fn collect_fold_unfolds(
685        &self,
686        target: &PlaceNode,
687        place: &mut Place,
688        stmts: &mut StatementsAt,
689    ) {
690        let (fields1, fields2) = match (self, target) {
691            (PlaceNode::Deref(_, node1), PlaceNode::Deref(_, node2)) => {
692                place.projection.push(PlaceElem::Deref);
693                node1.collect_fold_unfolds(node2, place, stmts);
694                place.projection.pop();
695                return;
696            }
697            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
698            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2))
699            | (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
700                (fields1, fields2)
701            }
702            (
703                PlaceNode::Downcast(adt1, .., idx1, fields1),
704                PlaceNode::Downcast(adt2, .., idx2, fields2),
705            ) => {
706                tracked_span_dbg_assert_eq!(adt1.did(), adt2.did());
707                tracked_span_dbg_assert_eq!(idx1, idx2);
708                (fields1, fields2)
709            }
710            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return,
711            (PlaceNode::Ty(_), _) => {
712                target.collect_unfolds(place, stmts);
713                return;
714            }
715            (_, PlaceNode::Ty(_)) => {
716                stmts.insert(GhostStatement::Fold(place.clone()));
717                return;
718            }
719            _ => tracked_span_bug!("{self:?} {target:?}"),
720        };
721        for (i, (node1, node2)) in iter::zip(fields1, fields2).enumerate() {
722            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
723            node1.collect_fold_unfolds(node2, place, stmts);
724            place.projection.pop();
725        }
726    }
727
728    fn collect_unfolds(&self, place: &mut Place, stmts: &mut StatementsAt) {
729        match self {
730            PlaceNode::Ty(_) => {}
731            PlaceNode::Deref(_, node) => {
732                if node.is_ty() {
733                    stmts.insert(GhostStatement::Unfold(place.clone()));
734                } else {
735                    place.projection.push(PlaceElem::Deref);
736                    node.collect_unfolds(place, stmts);
737                    place.projection.pop();
738                }
739            }
740            PlaceNode::Downcast(.., fields)
741            | PlaceNode::Closure(.., fields)
742            | PlaceNode::Generator(.., fields)
743            | PlaceNode::Tuple(.., fields) => {
744                let all_leaves = fields.iter().all(PlaceNode::is_ty);
745                if all_leaves {
746                    stmts.insert(GhostStatement::Unfold(place.clone()));
747                } else {
748                    if let Some(idx) = self.enum_variant() {
749                        place.projection.push(PlaceElem::Downcast(None, idx));
750                    }
751                    for (i, node) in fields.iter().enumerate() {
752                        place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
753                        node.collect_unfolds(place, stmts);
754                        place.projection.pop();
755                    }
756                    if self.enum_variant().is_some() {
757                        place.projection.pop();
758                    }
759                }
760            }
761        }
762    }
763
764    fn collect_folds_at_ret(&self, place: &mut Place, stmts: &mut StatementsAt) {
765        let fields = match self {
766            PlaceNode::Deref(ty, deref_ty) => {
767                place.projection.push(PlaceElem::Deref);
768                if ty.is_mut_ref() {
769                    stmts.insert(GhostStatement::Fold(place.clone()));
770                } else if ty.is_box() {
771                    deref_ty.collect_folds_at_ret(place, stmts);
772                }
773                place.projection.pop();
774                return;
775            }
776            PlaceNode::Downcast(adt, _, idx, fields) => {
777                if adt.is_enum() {
778                    place.projection.push(PlaceElem::Downcast(None, *idx));
779                }
780                fields
781            }
782            PlaceNode::Closure(_, _, fields)
783            | PlaceNode::Generator(_, _, fields)
784            | PlaceNode::Tuple(_, fields) => fields,
785            PlaceNode::Ty(_) => return,
786        };
787        for (i, node) in fields.iter().enumerate() {
788            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
789            node.collect_folds_at_ret(place, stmts);
790            place.projection.pop();
791        }
792        if let PlaceNode::Downcast(adt, ..) = self
793            && adt.is_enum()
794        {
795            place.projection.pop();
796        }
797    }
798
799    fn enum_variant(&self) -> Option<VariantIdx> {
800        if let PlaceNode::Downcast(adt, _, idx, _) = self
801            && adt.is_enum()
802        {
803            Some(*idx)
804        } else {
805            None
806        }
807    }
808
809    /// Returns `true` if the place node is [`Ty`].
810    ///
811    /// [`Ty`]: PlaceNode::Ty
812    #[must_use]
813    fn is_ty(&self) -> bool {
814        matches!(self, Self::Ty(..))
815    }
816}
817
818fn downcast(
819    genv: GlobalEnv,
820    adt_def: &AdtDef,
821    args: &GenericArgs,
822    variant: VariantIdx,
823) -> QueryResult<Vec<PlaceNode>> {
824    adt_def
825        .variant(variant)
826        .fields
827        .iter()
828        .map(|field| {
829            let ty = genv.lower_type_of(field.did)?.subst(args);
830            QueryResult::Ok(PlaceNode::Ty(ty))
831        })
832        .try_collect()
833}
834
835fn downcast_struct(
836    genv: GlobalEnv,
837    adt_def: &AdtDef,
838    args: &GenericArgs,
839) -> QueryResult<Vec<PlaceNode>> {
840    adt_def
841        .non_enum_variant()
842        .fields
843        .iter()
844        .map(|field| {
845            let ty = genv.lower_type_of(field.did)?.subst(args);
846            QueryResult::Ok(PlaceNode::Ty(ty))
847        })
848        .try_collect()
849}
850
851impl fmt::Debug for Env {
852    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
853        write!(
854            f,
855            "{}",
856            self.map
857                .iter_enumerated()
858                .format_with(", ", |(local, node), f| f(&format_args!("{local:?}: {node:?}")))
859        )
860    }
861}
862
863impl fmt::Debug for PlaceNode {
864    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
865        match self {
866            PlaceNode::Deref(_, node) => write!(f, "*({node:?})"),
867            PlaceNode::Downcast(adt, args, variant, fields) => {
868                write!(f, "{}", def_id_to_string(adt.did()))?;
869                if !args.is_empty() {
870                    write!(f, "<{:?}>", args.iter().format(", "),)?;
871                }
872                write!(f, "::{}", adt.variant(*variant).name)?;
873                if !fields.is_empty() {
874                    write!(f, "({:?})", fields.iter().format(", "),)?;
875                }
876                Ok(())
877            }
878            PlaceNode::Closure(did, args, fields) => {
879                write!(f, "Closure {}", def_id_to_string(*did))?;
880                if !args.is_empty() {
881                    write!(f, "<{:?}>", args.iter().format(", "),)?;
882                }
883                if !fields.is_empty() {
884                    write!(f, "({:?})", fields.iter().format(", "),)?;
885                }
886                Ok(())
887            }
888            PlaceNode::Generator(did, args, fields) => {
889                write!(f, "Generator {}", def_id_to_string(*did))?;
890                if !args.is_empty() {
891                    write!(f, "<{:?}>", args.iter().format(", "),)?;
892                }
893                if !fields.is_empty() {
894                    write!(f, "({:?})", fields.iter().format(", "),)?;
895                }
896                Ok(())
897            }
898            PlaceNode::Tuple(_, fields) => write!(f, "({:?})", fields.iter().format(", ")),
899            PlaceNode::Ty(ty) => write!(f, "•{ty:?}"),
900        }
901    }
902}