// flux_refineck/ghost_statements.rs

//! Ghost statements are statements that are not part of the original MIR but are inserted using
//! information extracted from the compiler or from additional analyses.
3mod fold_unfold;
4mod points_to;
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
10use flux_rustc_bridge::{
11    lowering,
12    mir::{BasicBlock, Body, Place},
13};
14use rustc_data_structures::unord::UnordMap;
15use rustc_hash::{FxHashMap, FxHashSet};
16use rustc_hir::{def::DefKind, def_id::LocalDefId};
17use rustc_middle::{
18    mir::{Location, START_BLOCK},
19    ty::TyCtxt,
20};
21
22type LocationMap = FxHashMap<Location, Vec<GhostStatement>>;
23type EdgeMap = FxHashMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
24
25pub(crate) fn compute_ghost_statements(
26    genv: GlobalEnv,
27    def_id: LocalDefId,
28) -> QueryResult<UnordMap<LocalDefId, GhostStatements>> {
29    let mut data = UnordMap::default();
30    for def_id in all_nested_bodies(genv.tcx(), def_id) {
31        data.insert(def_id, GhostStatements::new(genv, def_id)?);
32    }
33    Ok(data)
34}
35
36pub(crate) enum GhostStatement {
37    Fold(Place),
38    Unfold(Place),
39    Unblock(Place),
40    PtrToRef(Place),
41}
42
43impl fmt::Debug for GhostStatement {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
47            GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
48            GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
49            GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
50        }
51    }
52}
53
54pub(crate) struct GhostStatements {
55    at_start: Vec<GhostStatement>,
56    at_location: LocationMap,
57    at_edge: EdgeMap,
58}
59
60impl GhostStatements {
61    fn new(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult<Self> {
62        let body = genv.mir(def_id)?;
63
64        bug::track_span(body.span(), || {
65            let mut stmts = Self {
66                at_start: Default::default(),
67                at_location: LocationMap::default(),
68                at_edge: EdgeMap::default(),
69            };
70
71            // We have fn_sig for function items, but not for closures or generators.
72            let fn_sig = if genv.def_kind(def_id) == DefKind::Closure {
73                None
74            } else {
75                Some(genv.fn_sig(def_id)?)
76            };
77
78            fold_unfold::add_ghost_statements(&mut stmts, genv, &body, fn_sig.as_ref())?;
79            points_to::add_ghost_statements(&mut stmts, genv, body.rustc_body(), fn_sig.as_ref())?;
80            stmts.add_unblocks(genv.tcx(), &body);
81
82            stmts.dump_ghost_mir(genv.tcx(), &body);
83
84            Ok(stmts)
85        })
86    }
87
88    fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
89        for (location, borrows) in body.calculate_borrows_out_of_scope_at_location() {
90            let stmts = borrows.into_iter().map(|bidx| {
91                let borrow = body.borrow_data(bidx);
92                let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
93                GhostStatement::Unblock(place)
94            });
95            self.at_location.entry(location).or_default().extend(stmts);
96        }
97    }
98
99    fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
100        self.extend_at(point, [stmt]);
101    }
102
103    fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
104        match point {
105            Point::FunEntry => {
106                self.at_start.extend(stmts);
107            }
108            Point::BeforeLocation(location) => {
109                self.at_location.entry(location).or_default().extend(stmts);
110            }
111            Point::Edge(from, to) => {
112                self.at_edge
113                    .entry(from)
114                    .or_default()
115                    .entry(to)
116                    .or_default()
117                    .extend(stmts);
118            }
119        }
120    }
121
122    fn at(&mut self, point: Point) -> StatementsAt<'_> {
123        StatementsAt { stmts: self, point }
124    }
125
126    pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
127        match point {
128            Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
129            Point::BeforeLocation(location) => {
130                self.at_location.get(&location).into_iter().flatten()
131            }
132            Point::Edge(from, to) => {
133                self.at_edge
134                    .get(&from)
135                    .and_then(|m| m.get(&to))
136                    .into_iter()
137                    .flatten()
138            }
139        }
140    }
141
142    pub(crate) fn dump_ghost_mir<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
143        use rustc_middle::mir::{PassWhere, pretty::MirDumper};
144        if let Some(dumper) = MirDumper::new(tcx, "ghost", body.inner()) {
145            dumper
146                .set_extra_data(&|pass, w| {
147                    match pass {
148                        PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
149                            for stmt in &self.at_start {
150                                writeln!(w, "    {stmt:?};")?;
151                            }
152                        }
153                        PassWhere::BeforeLocation(location) => {
154                            for stmt in self.statements_at(Point::BeforeLocation(location)) {
155                                writeln!(w, "        {stmt:?};")?;
156                            }
157                        }
158                        PassWhere::AfterTerminator(bb) => {
159                            if let Some(map) = self.at_edge.get(&bb) {
160                                writeln!(w)?;
161                                for (target, stmts) in map {
162                                    write!(w, "        -> {target:?} {{")?;
163                                    for stmt in stmts {
164                                        write!(w, "\n            {stmt:?};")?;
165                                    }
166                                    write!(w, "\n        }}")?;
167                                }
168                                writeln!(w)?;
169                            }
170                        }
171                        _ => {}
172                    }
173                    Ok(())
174                })
175                .dump_mir(body.inner());
176        }
177    }
178}
179
180/// A point in the control flow graph where ghost statements can be inserted.
181#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
182pub(crate) enum Point {
183    /// The entry of the function before the first basic block. This is not the same as the first
184    /// location in the first basic block because, for some functions, the first basic block can have
185    /// incoming edges, and we want to execute ghost statements only once.
186    FunEntry,
187    /// The point before a location in a basic block.
188    BeforeLocation(Location),
189    /// An edge between two basic blocks.
190    Edge(BasicBlock, BasicBlock),
191}
192
193struct StatementsAt<'a> {
194    stmts: &'a mut GhostStatements,
195    point: Point,
196}
197
198impl StatementsAt<'_> {
199    fn insert(&mut self, stmt: GhostStatement) {
200        self.stmts.insert_at(self.point, stmt);
201    }
202}
203
204fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
205    use rustc_hir as hir;
206    struct ClosureFinder<'tcx> {
207        tcx: TyCtxt<'tcx>,
208        closures: FxHashSet<LocalDefId>,
209    }
210
211    impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
212        type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
213
214        fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
215            self.tcx
216        }
217
218        fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
219            if let hir::ExprKind::Closure(closure) = ex.kind {
220                self.closures.insert(closure.def_id);
221            }
222
223            hir::intravisit::walk_expr(self, ex);
224        }
225    }
226    let body = tcx.hir_body_owned_by(def_id).value;
227    let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
228    hir::intravisit::Visitor::visit_expr(&mut finder, body);
229    finder.closures.into_iter().chain(iter::once(def_id))
230}