flux_refineck/
ghost_statements.rs

1//! Ghost statements are statements that are not part of the original mir, but are added from information
2//! extracted from the compiler or some additional analysis.
3mod fold_unfold;
4mod points_to;
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
10use flux_rustc_bridge::{
11    lowering,
12    mir::{BasicBlock, Body, BodyRoot, Place},
13};
14use rustc_data_structures::unord::UnordMap;
15use rustc_hash::{FxHashMap, FxHashSet};
16use rustc_hir::{def::DefKind, def_id::LocalDefId};
17use rustc_middle::{
18    mir::{BorrowKind, Location, Promoted, START_BLOCK},
19    ty::TyCtxt,
20};
21
/// Ghost statements keyed by the MIR [`Location`] they are attached to.
type LocationMap = UnordMap<Location, Vec<GhostStatement>>;
/// Ghost statements attached to control-flow edges: the outer key is the source block of the
/// edge and the inner key is its target block.
type EdgeMap = UnordMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
24
/// A type to indicate _who_ the ghost statements are for: either a regular `DefId` (including
/// closures) or a promoted body.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum CheckerId {
    /// A regular function or closure
    DefId(LocalDefId),
    /// A promoted body (within a function or closure). The `LocalDefId` identifies the function
    /// or closure whose MIR holds the promoted body, and the `Promoted` index selects it.
    Promoted(LocalDefId, Promoted),
}
33
34impl CheckerId {
35    pub fn root_id(&self) -> LocalDefId {
36        match self {
37            CheckerId::DefId(def_id) => *def_id,
38            CheckerId::Promoted(def_id, _) => *def_id,
39        }
40    }
41
42    pub fn is_promoted(&self) -> bool {
43        matches!(self, CheckerId::Promoted(_, _))
44    }
45}
46
47pub(crate) fn compute_ghost_statements(
48    genv: GlobalEnv,
49    def_id: LocalDefId,
50) -> QueryResult<UnordMap<CheckerId, GhostStatements>> {
51    let mut data = UnordMap::default();
52    for def_id in all_nested_bodies(genv.tcx(), def_id) {
53        let key = CheckerId::DefId(def_id);
54        data.insert(key, GhostStatements::new(genv, key)?);
55        for promoted in genv.mir(def_id)?.promoted.indices() {
56            let key = CheckerId::Promoted(def_id, promoted);
57            data.insert(key, GhostStatements::new(genv, key)?);
58        }
59    }
60    Ok(data)
61}
62
/// A statement that is not part of the original MIR. Each variant carries the [`Place`] it acts
/// on.
pub(crate) enum GhostStatement {
    /// Printed as `fold(place)`. Presumably produced by the `fold_unfold` analysis — confirm
    /// against that module.
    Fold(Place),
    /// Printed as `unfold(place)`. Presumably produced by the `fold_unfold` analysis — confirm
    /// against that module.
    Unfold(Place),
    /// Printed as `unblock(place)`. Emitted by `GhostStatements::add_unblocks` for the place of
    /// a mutable borrow at the location where that borrow goes out of scope.
    Unblock(Place),
    /// Printed as `ptr_to_ref(place)`. Presumably produced by the `points_to` analysis — confirm
    /// against that module.
    PtrToRef(Place),
}
69
70impl fmt::Debug for GhostStatement {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        match self {
73            GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
74            GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
75            GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
76            GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
77        }
78    }
79}
80
/// The ghost statements computed for a single body, grouped by the [`Point`] at which they must
/// be executed.
pub(crate) struct GhostStatements {
    /// Statements for [`Point::FunEntry`].
    at_start: Vec<GhostStatement>,
    /// Statements for [`Point::BeforeLocation`], keyed by location.
    at_location: LocationMap,
    /// Statements for [`Point::Edge`], keyed by the source block and then by the target block.
    at_edge: EdgeMap,
}
86
87impl GhostStatements {
88    fn new(genv: GlobalEnv, checker_id: CheckerId) -> QueryResult<Self> {
89        let def_id = checker_id.root_id();
90        let body_root = genv.mir(def_id)?;
91        let body = match checker_id {
92            CheckerId::DefId(_) => &body_root.body,
93            CheckerId::Promoted(_, promoted) => &body_root.promoted[promoted],
94        };
95
96        bug::track_span(body.span(), || {
97            let mut stmts = Self {
98                at_start: Default::default(),
99                at_location: LocationMap::default(),
100                at_edge: EdgeMap::default(),
101            };
102
103            // We have fn_sig for function items, but not for closures, generators, statics, or promoteds.
104            let fn_sig =
105                if matches!(genv.def_kind(def_id), DefKind::Closure | DefKind::Static { .. })
106                    || checker_id.is_promoted()
107                {
108                    None
109                } else {
110                    Some(genv.fn_sig(def_id)?)
111                };
112
113            fold_unfold::add_ghost_statements(&mut stmts, genv, body, fn_sig.as_ref())?;
114            points_to::add_ghost_statements(&mut stmts, genv, body.rustc_body(), fn_sig.as_ref())?;
115            // We only add unblock statements for the main body because borrows in promoted constants
116            // have to be live in the main body so they never go out of scope in the promoted body.
117            if !checker_id.is_promoted() {
118                stmts.add_unblocks(genv.tcx(), &body_root);
119            }
120            stmts.dump_ghost_mir(genv.tcx(), body);
121
122            Ok(stmts)
123        })
124    }
125
126    fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body_root: &BodyRoot<'tcx>) {
127        for (location, borrows) in body_root.calculate_borrows_out_of_scope_at_location() {
128            let stmts = borrows
129                .into_iter()
130                .filter(|bidx| {
131                    // Only mutable borrows of owned places ever block the place (via the
132                    // `ptr(mut, ℓ)` → `&mut` conversion in `TypeEnv::ptr_to_ref`). Shared and
133                    // fake borrows never block, so emitting an `Unblock` for them is at best a
134                    // no-op and at worst an ICE: the borrowed place may have been folded back to
135                    // a struct (e.g. `&s` live alongside `&s.b`), which `PlacesTree::unblock`
136                    // cannot traverse.
137                    matches!(body_root.borrow_data(*bidx).kind(), BorrowKind::Mut { .. })
138                })
139                .map(|bidx| {
140                    let borrow = body_root.borrow_data(bidx);
141                    let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
142                    GhostStatement::Unblock(place)
143                });
144            self.at_location.entry(location).or_default().extend(stmts);
145        }
146    }
147
148    fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
149        self.extend_at(point, [stmt]);
150    }
151
152    fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
153        match point {
154            Point::FunEntry => {
155                self.at_start.extend(stmts);
156            }
157            Point::BeforeLocation(location) => {
158                self.at_location.entry(location).or_default().extend(stmts);
159            }
160            Point::Edge(from, to) => {
161                self.at_edge
162                    .entry(from)
163                    .or_default()
164                    .entry(to)
165                    .or_default()
166                    .extend(stmts);
167            }
168        }
169    }
170
171    fn at(&mut self, point: Point) -> StatementsAt<'_> {
172        StatementsAt { stmts: self, point }
173    }
174
175    pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
176        match point {
177            Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
178            Point::BeforeLocation(location) => {
179                self.at_location.get(&location).into_iter().flatten()
180            }
181            Point::Edge(from, to) => {
182                self.at_edge
183                    .get(&from)
184                    .and_then(|m| m.get(&to))
185                    .into_iter()
186                    .flatten()
187            }
188        }
189    }
190
    /// Dumps the MIR of `body` through rustc's `MirDumper` under the pass name `"ghost"`,
    /// interleaving the ghost statements stored in `self` with the regular MIR output.
    ///
    /// Nothing is written when `MirDumper::new` returns `None` (presumably when dumping is not
    /// enabled for this body — confirm against rustc's `MirDumper`).
    pub(crate) fn dump_ghost_mir<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
        use rustc_middle::mir::{PassWhere, pretty::MirDumper};
        if let Some(dumper) = MirDumper::new(tcx, "ghost", body.rustc_body()) {
            dumper
                .set_extra_data(&|pass, w| {
                    match pass {
                        // Function-entry statements are printed once, before the start block.
                        PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
                            for stmt in &self.at_start {
                                writeln!(w, "    {stmt:?};")?;
                            }
                        }
                        // Statements attached to a location are printed right before it.
                        PassWhere::BeforeLocation(location) => {
                            for stmt in self.statements_at(Point::BeforeLocation(location)) {
                                writeln!(w, "        {stmt:?};")?;
                            }
                        }
                        // Edge statements are printed after the terminator of the source block,
                        // as one `-> target { ... }` group per outgoing edge that has statements.
                        PassWhere::AfterTerminator(bb) => {
                            if let Some(map) = self.at_edge.get(&bb) {
                                writeln!(w)?;
                                for (target, stmts) in map {
                                    write!(w, "        -> {target:?} {{")?;
                                    for stmt in stmts {
                                        write!(w, "\n            {stmt:?};")?;
                                    }
                                    write!(w, "\n        }}")?;
                                }
                                writeln!(w)?;
                            }
                        }
                        _ => {}
                    }
                    Ok(())
                })
                .dump_mir(body.rustc_body());
        }
    }
227}
228
/// A point in the control flow graph where ghost statements can be inserted.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub(crate) enum Point {
    /// The entry of the function before the first basic block. This is not the same as the first
    /// location in the first basic block because, for some functions, the first basic block can have
    /// incoming edges, and we want to execute ghost statements only once.
    FunEntry,
    /// The point before a location in a basic block.
    BeforeLocation(Location),
    /// An edge between two basic blocks, given as the source block followed by the target block.
    Edge(BasicBlock, BasicBlock),
}
241
/// A mutable view of a [`GhostStatements`] fixed at one [`Point`], so that statements can be
/// inserted without repeating the point.
struct StatementsAt<'a> {
    /// The collection that receives the inserted statements.
    stmts: &'a mut GhostStatements,
    /// The point at which every statement is inserted.
    point: Point,
}
246
247impl StatementsAt<'_> {
248    fn insert(&mut self, stmt: GhostStatement) {
249        self.stmts.insert_at(self.point, stmt);
250    }
251}
252
253fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
254    use rustc_hir as hir;
255    struct ClosureFinder<'tcx> {
256        tcx: TyCtxt<'tcx>,
257        closures: FxHashSet<LocalDefId>,
258    }
259
260    impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
261        type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
262
263        fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
264            self.tcx
265        }
266
267        fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
268            if let hir::ExprKind::Closure(closure) = ex.kind {
269                self.closures.insert(closure.def_id);
270            }
271
272            hir::intravisit::walk_expr(self, ex);
273        }
274    }
275    let body = tcx.hir_body_owned_by(def_id).value;
276    let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
277    hir::intravisit::Visitor::visit_expr(&mut finder, body);
278    finder.closures.into_iter().chain(iter::once(def_id))
279}