// flux_refineck/ghost_statements/fold_unfold.rs

1use std::{collections::hash_map::Entry, fmt, iter};
2
3use flux_common::{tracked_span_assert_eq, tracked_span_bug, tracked_span_dbg_assert_eq};
4use flux_middle::{
5    PlaceExt as _, def_id_to_string, global_env::GlobalEnv, queries::QueryResult, query_bug, rty,
6};
7use flux_rustc_bridge::{
8    mir::{
9        BasicBlock, Body, BorrowKind, FIRST_VARIANT, FieldIdx, Local, Location,
10        NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
11        StatementKind, Terminator, TerminatorKind, UnOp, VariantIdx,
12    },
13    ty::{AdtDef, GenericArgs, GenericArgsExt as _, List, Mutability, Ty, TyKind},
14};
15use itertools::{Itertools, repeat_n};
16use rustc_data_structures::{fx::FxHashMap, unord::UnordMap};
17use rustc_hir::def_id::DefId;
18use rustc_index::{Idx, IndexVec, bit_set::DenseBitSet};
19use rustc_middle::mir::{FakeReadCause, START_BLOCK};
20
21use super::{GhostStatements, StatementsAt};
22use crate::{
23    ghost_statements::{GhostStatement, Point},
24    queue::WorkQueue,
25};
26
27pub(crate) fn add_ghost_statements<'tcx>(
28    stmts: &mut GhostStatements,
29    genv: GlobalEnv<'_, 'tcx>,
30    body: &Body<'tcx>,
31    fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>,
32) -> QueryResult {
33    let mut bb_envs = UnordMap::default();
34    FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Infer).run(fn_sig)?;
35
36    FoldUnfoldAnalysis::new(genv, body, &mut bb_envs, Elaboration { stmts }).run(fn_sig)
37}
38
39#[derive(Clone)]
40struct Env {
41    map: IndexVec<Local, PlaceNode>,
42}
43
44impl Env {
45    fn new(body: &Body) -> Self {
46        Self {
47            map: body
48                .local_decls
49                .iter()
50                .map(|decl| PlaceNode::Ty(decl.ty.clone()))
51                .collect(),
52        }
53    }
54
55    fn projection<'a>(&mut self, genv: GlobalEnv, place: &'a Place) -> QueryResult<ProjResult<'a>> {
56        let (node, place, modified) = self.ensure_unfolded(genv, place)?;
57        if modified {
58            Ok(ProjResult::Unfold(place))
59        } else if node.ensure_folded() {
60            Ok(ProjResult::Fold(place))
61        } else {
62            Ok(ProjResult::None)
63        }
64    }
65
66    fn downcast(&mut self, genv: GlobalEnv, place: &Place, variant_idx: VariantIdx) -> QueryResult {
67        let (node, ..) = self.ensure_unfolded(genv, place)?;
68        node.downcast(genv, variant_idx)?;
69        Ok(())
70    }
71
72    fn ensure_unfolded<'a>(
73        &mut self,
74        genv: GlobalEnv,
75        place: &'a Place,
76    ) -> QueryResult<(&mut PlaceNode, PlaceRef<'a>, Modified)> {
77        let mut node = &mut self.map[place.local];
78        let mut modified = false;
79        let mut i = 0;
80        while i < place.projection.len() {
81            let elem = place.projection[i];
82            let (n, m) = match elem {
83                PlaceElem::Deref => node.deref(),
84                PlaceElem::Field(f) => node.field(genv, f)?,
85                PlaceElem::Downcast(_, idx) => node.downcast(genv, idx)?,
86                PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } => break,
87            };
88            node = n;
89            modified |= m;
90            i += 1;
91        }
92        Ok((node, place.as_ref().truncate(i), modified))
93    }
94
95    fn join(&mut self, genv: GlobalEnv, mut other: Env) -> QueryResult<Modified> {
96        let mut modified = false;
97        for (local, node) in self.map.iter_enumerated_mut() {
98            let (m, _) = node.join(genv, &mut other.map[local], false)?;
99            modified |= m;
100        }
101        Ok(modified)
102    }
103
104    fn collect_fold_unfolds_at_goto(&self, target: &Env, stmts: &mut StatementsAt) {
105        for (local, node) in self.map.iter_enumerated() {
106            node.collect_fold_unfolds(&target.map[local], &mut Place::new(local, vec![]), stmts);
107        }
108    }
109
110    fn collect_folds_at_ret(&self, body: &Body, stmts: &mut StatementsAt) {
111        for local in body.args_iter() {
112            self.map[local].collect_folds_at_ret(&mut Place::new(local, vec![]), stmts);
113        }
114    }
115}
116
/// Flag returned by fold/unfold operations: `true` when the tree was changed.
type Modified = bool;
118
119struct FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
120    genv: GlobalEnv<'genv, 'tcx>,
121    body: &'a Body<'tcx>,
122    bb_envs: &'a mut UnordMap<BasicBlock, Env>,
123    visited: DenseBitSet<BasicBlock>,
124    queue: WorkQueue<'a>,
125    discriminants: UnordMap<Place, Place>,
126    point: Point,
127    mode: M,
128}
129
130trait Mode: Sized {
131    const _NAME: &'static str;
132
133    fn projection(
134        analysis: &mut FoldUnfoldAnalysis<Self>,
135        env: &mut Env,
136        place: &Place,
137    ) -> QueryResult;
138
139    fn goto_join_point(
140        analysis: &mut FoldUnfoldAnalysis<Self>,
141        target: BasicBlock,
142        env: Env,
143    ) -> QueryResult<bool>;
144
145    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env);
146}
147
148struct Infer;
149
150struct Elaboration<'a> {
151    stmts: &'a mut GhostStatements,
152}
153
154impl Elaboration<'_> {
155    fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
156        self.stmts.insert_at(point, stmt);
157    }
158}
159
160#[derive(Debug)]
161enum ProjResult<'a> {
162    None,
163    Fold(PlaceRef<'a>),
164    Unfold(PlaceRef<'a>),
165}
166
167impl Mode for Infer {
168    const _NAME: &'static str = "infer";
169
170    fn projection(
171        analysis: &mut FoldUnfoldAnalysis<Self>,
172        env: &mut Env,
173        place: &Place,
174    ) -> QueryResult {
175        env.projection(analysis.genv, place)?;
176        Ok(())
177    }
178
179    fn goto_join_point(
180        analysis: &mut FoldUnfoldAnalysis<Self>,
181        target: BasicBlock,
182        env: Env,
183    ) -> QueryResult<bool> {
184        let modified = match analysis.bb_envs.entry(target) {
185            Entry::Occupied(mut entry) => entry.get_mut().join(analysis.genv, env)?,
186            Entry::Vacant(entry) => {
187                entry.insert(env);
188                true
189            }
190        };
191        Ok(modified)
192    }
193
194    fn ret(_: &mut FoldUnfoldAnalysis<Self>, _: &Env) {}
195}
196
197impl Mode for Elaboration<'_> {
198    const _NAME: &'static str = "elaboration";
199
200    fn projection(
201        analysis: &mut FoldUnfoldAnalysis<Self>,
202        env: &mut Env,
203        place: &Place,
204    ) -> QueryResult {
205        match env.projection(analysis.genv, place)? {
206            ProjResult::None => {}
207            ProjResult::Fold(place_ref) => {
208                tracked_span_assert_eq!(place_ref, place.as_ref());
209                let place = place.clone();
210                analysis
211                    .mode
212                    .insert_at(analysis.point, GhostStatement::Fold(place));
213            }
214            ProjResult::Unfold(place_ref) => {
215                tracked_span_assert_eq!(place_ref, place.as_ref());
216                match place_ref.last_projection() {
217                    Some((base, PlaceElem::Deref | PlaceElem::Field(..))) => {
218                        analysis
219                            .mode
220                            .insert_at(analysis.point, GhostStatement::Unfold(base.to_place()));
221                    }
222                    _ => Err(query_bug!("invalid projection for unfolding {place_ref:?}"))?,
223                }
224            }
225        }
226        Ok(())
227    }
228
229    fn goto_join_point(
230        analysis: &mut FoldUnfoldAnalysis<Self>,
231        target: BasicBlock,
232        env: Env,
233    ) -> QueryResult<bool> {
234        env.collect_fold_unfolds_at_goto(
235            &analysis.bb_envs[&target],
236            &mut analysis.mode.stmts.at(analysis.point),
237        );
238        Ok(!analysis.visited.contains(target))
239    }
240
241    fn ret(analysis: &mut FoldUnfoldAnalysis<Self>, env: &Env) {
242        env.collect_folds_at_ret(analysis.body, &mut analysis.mode.stmts.at(analysis.point));
243    }
244}
245
246#[derive(Clone)]
247enum PlaceNode {
248    Deref(Ty, Box<PlaceNode>),
249    Downcast(AdtDef, GenericArgs, VariantIdx, Vec<PlaceNode>),
250    Closure(DefId, GenericArgs, Vec<PlaceNode>),
251    Generator(DefId, GenericArgs, Vec<PlaceNode>),
252    Tuple(List<Ty>, Vec<PlaceNode>),
253    Ty(Ty),
254}
255
256impl<M: Mode> FoldUnfoldAnalysis<'_, '_, '_, M> {
257    fn run(mut self, fn_sig: Option<&rty::EarlyBinder<rty::PolyFnSig>>) -> QueryResult {
258        let mut env = Env::new(self.body);
259
260        if let Some(fn_sig) = fn_sig {
261            let fn_sig = fn_sig.as_ref().skip_binder().as_ref().skip_binder();
262            for (local, ty) in iter::zip(self.body.args_iter(), fn_sig.inputs()) {
263                if let rty::TyKind::StrgRef(..) | rty::Ref!(.., Mutability::Mut) = ty.kind() {
264                    M::projection(&mut self, &mut env, &Place::new(local, vec![PlaceElem::Deref]))?;
265                }
266            }
267        }
268        self.goto(START_BLOCK, env)?;
269        while let Some(bb) = self.queue.pop() {
270            self.basic_block(bb, self.bb_envs[&bb].clone())?;
271        }
272        Ok(())
273    }
274
275    fn basic_block(&mut self, bb: BasicBlock, mut env: Env) -> QueryResult {
276        self.visited.insert(bb);
277        let data = &self.body.basic_blocks[bb];
278        for (statement_index, stmt) in data.statements.iter().enumerate() {
279            self.point = Point::BeforeLocation(Location { block: bb, statement_index });
280            self.statement(stmt, &mut env)?;
281        }
282        if let Some(terminator) = &data.terminator {
283            self.point = Point::BeforeLocation(self.body.terminator_loc(bb));
284            let successors = self.terminator(terminator, env)?;
285            for (env, target) in successors {
286                self.point = Point::Edge(bb, target);
287                self.goto(target, env)?;
288            }
289        }
290        Ok(())
291    }
292
293    fn statement(&mut self, stmt: &Statement, env: &mut Env) -> QueryResult {
294        match &stmt.kind {
295            StatementKind::FakeRead(box (FakeReadCause::ForIndex, place)) => {
296                M::projection(self, env, place)?;
297            }
298            StatementKind::Assign(place, rvalue) => {
299                match rvalue {
300                    Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(place))
301                    | Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(place)) => {
302                        let deref_place = place.deref();
303                        M::projection(self, env, &deref_place)?;
304                    }
305                    Rvalue::Use(op)
306                    | Rvalue::Cast(_, op, _)
307                    | Rvalue::UnaryOp(_, op)
308                    | Rvalue::ShallowInitBox(op, _) => {
309                        self.operand(op, env)?;
310                    }
311                    Rvalue::Ref(.., bk, place) => {
312                        // Fake borrows should not cause the place to fold
313                        if !matches!(bk, BorrowKind::Fake(_)) {
314                            M::projection(self, env, place)?;
315                        }
316                    }
317                    Rvalue::RawPtr(_, place) => {
318                        M::projection(self, env, place)?;
319                    }
320                    Rvalue::BinaryOp(_, op1, op2) => {
321                        self.operand(op1, env)?;
322                        self.operand(op2, env)?;
323                    }
324                    Rvalue::Aggregate(_, args) => {
325                        for arg in args {
326                            self.operand(arg, env)?;
327                        }
328                    }
329
330                    Rvalue::Discriminant(discr) => {
331                        M::projection(self, env, discr)?;
332                        self.discriminants.insert(place.clone(), discr.clone());
333                    }
334                    Rvalue::Repeat(op, _) => {
335                        self.operand(op, env)?;
336                    }
337                }
338                M::projection(self, env, place)?;
339            }
340            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(op)) => {
341                self.operand(op, env)?;
342            }
343            StatementKind::SetDiscriminant(_, _)
344            | StatementKind::FakeRead(_)
345            | StatementKind::AscribeUserType(_, _)
346            | StatementKind::PlaceMention(_)
347            | StatementKind::Nop => {}
348        }
349        Ok(())
350    }
351
352    fn operand(&mut self, op: &Operand, env: &mut Env) -> QueryResult {
353        match op {
354            Operand::Copy(place) | Operand::Move(place) => {
355                M::projection(self, env, place)?;
356            }
357            Operand::Constant(_) => {}
358        }
359        Ok(())
360    }
361
362    fn terminator(
363        &mut self,
364        terminator: &Terminator,
365        mut env: Env,
366    ) -> QueryResult<Vec<(Env, BasicBlock)>> {
367        let mut successors = vec![];
368        match &terminator.kind {
369            TerminatorKind::Return => {
370                M::ret(self, &env);
371            }
372            TerminatorKind::Call { args, destination, target, .. } => {
373                for arg in args {
374                    self.operand(arg, &mut env)?;
375                }
376                M::projection(self, &mut env, destination)?;
377                if let Some(target) = target {
378                    successors.push((env, *target));
379                }
380            }
381            TerminatorKind::SwitchInt { discr, targets } => {
382                let is_match = match discr {
383                    Operand::Copy(place) | Operand::Move(place) => {
384                        M::projection(self, &mut env, place)?;
385                        self.discriminants.remove(place)
386                    }
387                    Operand::Constant(_) => None,
388                };
389                if let Some(place) = is_match {
390                    let discr_ty = place.ty(self.genv, &self.body.local_decls)?.ty;
391                    let (adt, _) = discr_ty.expect_adt();
392
393                    let mut remaining: FxHashMap<u128, VariantIdx> = adt
394                        .discriminants()
395                        .map(|(idx, discr)| (discr, idx))
396                        .collect();
397                    for (bits, target) in targets.iter() {
398                        let variant_idx = remaining
399                            .remove(&bits)
400                            .expect("value doesn't correspond to any variant");
401
402                        // We do not insert unfolds in match arms because they are explicit
403                        // unfold points.
404                        let mut env = env.clone();
405                        env.downcast(self.genv, &place, variant_idx)?;
406                        successors.push((env, target));
407                    }
408                    if remaining.len() == 1 {
409                        let (_, variant_idx) = remaining
410                            .into_iter()
411                            .next()
412                            .unwrap_or_else(|| tracked_span_bug!());
413                        env.downcast(self.genv, &place, variant_idx)?;
414                    }
415                    successors.push((env, targets.otherwise()));
416                } else {
417                    let n = targets.all_targets().len();
418                    for (env, target) in iter::zip(repeat_n(env, n), targets.all_targets()) {
419                        successors.push((env, *target));
420                    }
421                }
422            }
423            TerminatorKind::Goto { target } => {
424                successors.push((env, *target));
425            }
426            TerminatorKind::Yield { resume, resume_arg, .. } => {
427                M::projection(self, &mut env, resume_arg)?;
428                successors.push((env, *resume));
429            }
430            TerminatorKind::Drop { place, target, .. } => {
431                M::projection(self, &mut env, place)?;
432                successors.push((env, *target));
433            }
434            TerminatorKind::Assert { cond, target, .. } => {
435                self.operand(cond, &mut env)?;
436                successors.push((env, *target));
437            }
438            TerminatorKind::FalseEdge { real_target, .. } => {
439                successors.push((env, *real_target));
440            }
441            TerminatorKind::FalseUnwind { real_target, .. } => {
442                successors.push((env, *real_target));
443            }
444            TerminatorKind::Unreachable
445            | TerminatorKind::UnwindResume
446            | TerminatorKind::CoroutineDrop => {}
447        }
448        Ok(successors)
449    }
450
451    fn goto(&mut self, target: BasicBlock, env: Env) -> QueryResult {
452        if self.body.is_join_point(target) {
453            if M::goto_join_point(self, target, env)? {
454                self.queue.insert(target);
455            }
456            Ok(())
457        } else {
458            self.basic_block(target, env)
459        }
460    }
461}
462
463impl<'a, 'genv, 'tcx, M> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
464    pub(crate) fn new(
465        genv: GlobalEnv<'genv, 'tcx>,
466        body: &'a Body<'tcx>,
467        bb_envs: &'a mut UnordMap<BasicBlock, Env>,
468        mode: M,
469    ) -> Self {
470        Self {
471            genv,
472            body,
473            bb_envs,
474            discriminants: Default::default(),
475            point: Point::FunEntry,
476            visited: DenseBitSet::new_empty(body.basic_blocks.len()),
477            queue: WorkQueue::empty(body.basic_blocks.len(), &body.dominator_order_rank),
478            mode,
479        }
480    }
481}
482
483impl PlaceNode {
484    fn deref(&mut self) -> (&mut PlaceNode, Modified) {
485        match self {
486            PlaceNode::Deref(_, node) => (node, false),
487            PlaceNode::Ty(ty) => {
488                *self = PlaceNode::Deref(ty.clone(), Box::new(PlaceNode::Ty(ty.deref())));
489                let PlaceNode::Deref(_, node) = self else { unreachable!() };
490                (node, true)
491            }
492            _ => tracked_span_bug!("deref of non-deref place: `{:?}`", self),
493        }
494    }
495
496    fn downcast(
497        &mut self,
498        genv: GlobalEnv,
499        idx: VariantIdx,
500    ) -> QueryResult<(&mut PlaceNode, Modified)> {
501        match self {
502            PlaceNode::Downcast(.., idx2, _) => {
503                debug_assert_eq!(idx, *idx2);
504                Ok((self, false))
505            }
506            PlaceNode::Ty(ty) => {
507                if let TyKind::Adt(adt_def, args) = ty.kind() {
508                    let fields = downcast(genv, adt_def, args, idx)?;
509                    *self = PlaceNode::Downcast(adt_def.clone(), args.clone(), idx, fields);
510                    Ok((self, true))
511                } else {
512                    tracked_span_bug!("invalid downcast `{self:?}`");
513                }
514            }
515            _ => tracked_span_bug!("invalid downcast `{self:?}`"),
516        }
517    }
518
519    fn field(&mut self, genv: GlobalEnv, f: FieldIdx) -> QueryResult<(&mut PlaceNode, Modified)> {
520        let (fields, unfolded) = self.fields(genv)?;
521        Ok((&mut fields[f.as_usize()], unfolded))
522    }
523
524    fn fields(&mut self, genv: GlobalEnv) -> QueryResult<(&mut Vec<PlaceNode>, bool)> {
525        match self {
526            PlaceNode::Ty(ty) => {
527                let fields = match ty.kind() {
528                    TyKind::Adt(adt_def, args) => {
529                        let fields = downcast_struct(genv, adt_def, args)?;
530                        *self = PlaceNode::Downcast(
531                            adt_def.clone(),
532                            args.clone(),
533                            FIRST_VARIANT,
534                            fields,
535                        );
536                        let PlaceNode::Downcast(.., fields) = self else { unreachable!() };
537                        fields
538                    }
539                    TyKind::Closure(def_id, args) => {
540                        let fields = args
541                            .as_closure()
542                            .upvar_tys()
543                            .iter()
544                            .cloned()
545                            .map(PlaceNode::Ty)
546                            .collect_vec();
547                        *self = PlaceNode::Closure(*def_id, args.clone(), fields);
548                        let PlaceNode::Closure(.., fields) = self else { unreachable!() };
549                        fields
550                    }
551                    TyKind::Tuple(fields) => {
552                        let node_fields = fields.iter().cloned().map(PlaceNode::Ty).collect();
553                        *self = PlaceNode::Tuple(fields.clone(), node_fields);
554                        let PlaceNode::Tuple(.., fields) = self else { unreachable!() };
555                        fields
556                    }
557                    TyKind::Coroutine(def_id, args) => {
558                        let fields = args
559                            .as_coroutine()
560                            .upvar_tys()
561                            .cloned()
562                            .map(PlaceNode::Ty)
563                            .collect_vec();
564                        *self = PlaceNode::Generator(*def_id, args.clone(), fields);
565                        let PlaceNode::Generator(.., fields) = self else { unreachable!() };
566                        fields
567                    }
568                    _ => tracked_span_bug!("implicit downcast of non-struct: `{ty:?}`"),
569                };
570                Ok((fields, true))
571            }
572            PlaceNode::Downcast(.., fields)
573            | PlaceNode::Tuple(.., fields)
574            | PlaceNode::Closure(.., fields)
575            | PlaceNode::Generator(.., fields) => Ok((fields, false)),
576            PlaceNode::Deref(..) => {
577                tracked_span_bug!("projection field of non-adt non-tuple place: `{self:?}`")
578            }
579        }
580    }
581
582    fn ensure_folded(&mut self) -> Modified {
583        match self {
584            PlaceNode::Deref(ty, _) => {
585                *self = PlaceNode::Ty(ty.clone());
586                true
587            }
588            PlaceNode::Downcast(adt, args, ..) => {
589                *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
590                true
591            }
592            PlaceNode::Closure(did, args, _) => {
593                *self = PlaceNode::Ty(Ty::mk_closure(*did, args.clone()));
594                true
595            }
596            PlaceNode::Generator(did, args, _) => {
597                *self = PlaceNode::Ty(Ty::mk_coroutine(*did, args.clone()));
598                true
599            }
600            PlaceNode::Tuple(fields, ..) => {
601                *self = PlaceNode::Ty(Ty::mk_tuple(fields.clone()));
602                true
603            }
604            PlaceNode::Ty(_) => false,
605        }
606    }
607
608    fn join(
609        &mut self,
610        genv: GlobalEnv,
611        other: &mut PlaceNode,
612        in_mut_ref: bool,
613    ) -> QueryResult<(bool, bool)> {
614        let mut modified1 = false;
615        let mut modified2 = false;
616
617        let (fields1, fields2) = match (&mut *self, &mut *other) {
618            (PlaceNode::Deref(ty1, node1), PlaceNode::Deref(ty2, node2)) => {
619                debug_assert_eq!(ty1, ty2);
620                return node1.join(genv, node2, in_mut_ref || ty1.is_mut_ref());
621            }
622            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
623            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2)) => {
624                (fields1, fields2)
625            }
626            (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
627                (fields1, fields2)
628            }
629            (
630                PlaceNode::Downcast(adt1, args1, variant1, fields1),
631                PlaceNode::Downcast(adt2, args2, variant2, fields2),
632            ) => {
633                debug_assert_eq!(adt1, adt2);
634                if variant1 == variant2 {
635                    (fields1, fields2)
636                } else {
637                    *self = PlaceNode::Ty(Ty::mk_adt(adt1.clone(), args1.clone()));
638                    *other = PlaceNode::Ty(Ty::mk_adt(adt2.clone(), args2.clone()));
639                    return Ok((true, true));
640                }
641            }
642            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return Ok((false, false)),
643            (PlaceNode::Ty(_), _) => {
644                let (m1, m2) = other.join(genv, self, in_mut_ref)?;
645                return Ok((m2, m1));
646            }
647            (PlaceNode::Deref(ty, _), _) => {
648                *self = PlaceNode::Ty(ty.clone());
649                return Ok((true, false));
650            }
651            (PlaceNode::Tuple(_, fields1), _) => {
652                let (fields2, m) = other.fields(genv)?;
653                modified2 |= m;
654                (fields1, fields2)
655            }
656            (PlaceNode::Closure(.., fields1), _) | (PlaceNode::Generator(.., fields1), _) => {
657                let (fields2, m) = other.fields(genv)?;
658                modified2 |= m;
659                (fields1, fields2)
660            }
661
662            (PlaceNode::Downcast(adt, args, .., fields1), _) => {
663                if adt.is_struct() && !in_mut_ref {
664                    let (fields2, m) = other.fields(genv)?;
665                    modified2 |= m;
666                    (fields1, fields2)
667                } else {
668                    *self = PlaceNode::Ty(Ty::mk_adt(adt.clone(), args.clone()));
669                    return Ok((true, false));
670                }
671            }
672        };
673        for (node1, node2) in iter::zip(fields1, fields2) {
674            let (m1, m2) = node1.join(genv, node2, in_mut_ref)?;
675            modified1 |= m1;
676            modified2 |= m2;
677        }
678        Ok((modified1, modified2))
679    }
680
681    /// Collect necessary fold/unfold operations such that `self` is unfolded at the same level than `target`
682    fn collect_fold_unfolds(
683        &self,
684        target: &PlaceNode,
685        place: &mut Place,
686        stmts: &mut StatementsAt,
687    ) {
688        let (fields1, fields2) = match (self, target) {
689            (PlaceNode::Deref(_, node1), PlaceNode::Deref(_, node2)) => {
690                place.projection.push(PlaceElem::Deref);
691                node1.collect_fold_unfolds(node2, place, stmts);
692                place.projection.pop();
693                return;
694            }
695            (PlaceNode::Tuple(_, fields1), PlaceNode::Tuple(_, fields2)) => (fields1, fields2),
696            (PlaceNode::Closure(.., fields1), PlaceNode::Closure(.., fields2))
697            | (PlaceNode::Generator(.., fields1), PlaceNode::Generator(.., fields2)) => {
698                (fields1, fields2)
699            }
700            (
701                PlaceNode::Downcast(adt1, .., idx1, fields1),
702                PlaceNode::Downcast(adt2, .., idx2, fields2),
703            ) => {
704                tracked_span_dbg_assert_eq!(adt1.did(), adt2.did());
705                tracked_span_dbg_assert_eq!(idx1, idx2);
706                (fields1, fields2)
707            }
708            (PlaceNode::Ty(_), PlaceNode::Ty(_)) => return,
709            (PlaceNode::Ty(_), _) => {
710                target.collect_unfolds(place, stmts);
711                return;
712            }
713            (_, PlaceNode::Ty(_)) => {
714                stmts.insert(GhostStatement::Fold(place.clone()));
715                return;
716            }
717            _ => tracked_span_bug!("{self:?} {target:?}"),
718        };
719        for (i, (node1, node2)) in iter::zip(fields1, fields2).enumerate() {
720            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
721            node1.collect_fold_unfolds(node2, place, stmts);
722            place.projection.pop();
723        }
724    }
725
726    fn collect_unfolds(&self, place: &mut Place, stmts: &mut StatementsAt) {
727        match self {
728            PlaceNode::Ty(_) => {}
729            PlaceNode::Deref(_, node) => {
730                if node.is_ty() {
731                    stmts.insert(GhostStatement::Unfold(place.clone()));
732                } else {
733                    place.projection.push(PlaceElem::Deref);
734                    node.collect_unfolds(place, stmts);
735                    place.projection.pop();
736                }
737            }
738            PlaceNode::Downcast(.., fields)
739            | PlaceNode::Closure(.., fields)
740            | PlaceNode::Generator(.., fields)
741            | PlaceNode::Tuple(.., fields) => {
742                let all_leaves = fields.iter().all(PlaceNode::is_ty);
743                if all_leaves {
744                    stmts.insert(GhostStatement::Unfold(place.clone()));
745                } else {
746                    if let Some(idx) = self.enum_variant() {
747                        place.projection.push(PlaceElem::Downcast(None, idx));
748                    }
749                    for (i, node) in fields.iter().enumerate() {
750                        place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
751                        node.collect_unfolds(place, stmts);
752                        place.projection.pop();
753                    }
754                    if self.enum_variant().is_some() {
755                        place.projection.pop();
756                    }
757                }
758            }
759        }
760    }
761
762    fn collect_folds_at_ret(&self, place: &mut Place, stmts: &mut StatementsAt) {
763        let fields = match self {
764            PlaceNode::Deref(ty, deref_ty) => {
765                place.projection.push(PlaceElem::Deref);
766                if ty.is_mut_ref() {
767                    stmts.insert(GhostStatement::Fold(place.clone()));
768                } else if ty.is_box() {
769                    deref_ty.collect_folds_at_ret(place, stmts);
770                }
771                place.projection.pop();
772                return;
773            }
774            PlaceNode::Downcast(adt, _, idx, fields) => {
775                if adt.is_enum() {
776                    place.projection.push(PlaceElem::Downcast(None, *idx));
777                }
778                fields
779            }
780            PlaceNode::Closure(_, _, fields)
781            | PlaceNode::Generator(_, _, fields)
782            | PlaceNode::Tuple(_, fields) => fields,
783            PlaceNode::Ty(_) => return,
784        };
785        for (i, node) in fields.iter().enumerate() {
786            place.projection.push(PlaceElem::Field(FieldIdx::new(i)));
787            node.collect_folds_at_ret(place, stmts);
788            place.projection.pop();
789        }
790        if let PlaceNode::Downcast(adt, ..) = self
791            && adt.is_enum()
792        {
793            place.projection.pop();
794        }
795    }
796
797    fn enum_variant(&self) -> Option<VariantIdx> {
798        if let PlaceNode::Downcast(adt, _, idx, _) = self
799            && adt.is_enum()
800        {
801            Some(*idx)
802        } else {
803            None
804        }
805    }
806
807    /// Returns `true` if the place node is [`Ty`].
808    ///
809    /// [`Ty`]: PlaceNode::Ty
810    #[must_use]
811    fn is_ty(&self) -> bool {
812        matches!(self, Self::Ty(..))
813    }
814}
815
816fn downcast(
817    genv: GlobalEnv,
818    adt_def: &AdtDef,
819    args: &GenericArgs,
820    variant: VariantIdx,
821) -> QueryResult<Vec<PlaceNode>> {
822    adt_def
823        .variant(variant)
824        .fields
825        .iter()
826        .map(|field| {
827            let ty = genv.lower_type_of(field.did)?.subst(args);
828            QueryResult::Ok(PlaceNode::Ty(ty))
829        })
830        .try_collect()
831}
832
833fn downcast_struct(
834    genv: GlobalEnv,
835    adt_def: &AdtDef,
836    args: &GenericArgs,
837) -> QueryResult<Vec<PlaceNode>> {
838    adt_def
839        .non_enum_variant()
840        .fields
841        .iter()
842        .map(|field| {
843            let ty = genv.lower_type_of(field.did)?.subst(args);
844            QueryResult::Ok(PlaceNode::Ty(ty))
845        })
846        .try_collect()
847}
848
849impl fmt::Debug for Env {
850    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
851        write!(
852            f,
853            "{}",
854            self.map
855                .iter_enumerated()
856                .format_with(", ", |(local, node), f| f(&format_args!("{local:?}: {node:?}")))
857        )
858    }
859}
860
861impl fmt::Debug for PlaceNode {
862    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
863        match self {
864            PlaceNode::Deref(_, node) => write!(f, "*({node:?})"),
865            PlaceNode::Downcast(adt, args, variant, fields) => {
866                write!(f, "{}", def_id_to_string(adt.did()))?;
867                if !args.is_empty() {
868                    write!(f, "<{:?}>", args.iter().format(", "),)?;
869                }
870                write!(f, "::{}", adt.variant(*variant).name)?;
871                if !fields.is_empty() {
872                    write!(f, "({:?})", fields.iter().format(", "),)?;
873                }
874                Ok(())
875            }
876            PlaceNode::Closure(did, args, fields) => {
877                write!(f, "Closure {}", def_id_to_string(*did))?;
878                if !args.is_empty() {
879                    write!(f, "<{:?}>", args.iter().format(", "),)?;
880                }
881                if !fields.is_empty() {
882                    write!(f, "({:?})", fields.iter().format(", "),)?;
883                }
884                Ok(())
885            }
886            PlaceNode::Generator(did, args, fields) => {
887                write!(f, "Generator {}", def_id_to_string(*did))?;
888                if !args.is_empty() {
889                    write!(f, "<{:?}>", args.iter().format(", "),)?;
890                }
891                if !fields.is_empty() {
892                    write!(f, "({:?})", fields.iter().format(", "),)?;
893                }
894                Ok(())
895            }
896            PlaceNode::Tuple(_, fields) => write!(f, "({:?})", fields.iter().format(", ")),
897            PlaceNode::Ty(ty) => write!(f, "•{ty:?}"),
898        }
899    }
900}