flux_refineck/
ghost_statements.rs

//! Ghost statements are statements that are not part of the original MIR, but are added from
//! information extracted from the compiler or from additional analyses.
3mod fold_unfold;
4mod points_to;
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
10use flux_rustc_bridge::{
11    lowering,
12    mir::{BasicBlock, Body, BodyRoot, Place},
13};
14use rustc_data_structures::unord::UnordMap;
15use rustc_hash::{FxHashMap, FxHashSet};
16use rustc_hir::{def::DefKind, def_id::LocalDefId};
17use rustc_middle::{
18    mir::{Location, Promoted, START_BLOCK},
19    ty::TyCtxt,
20};
21
/// Ghost statements to run before a given MIR location, in insertion order.
type LocationMap = FxHashMap<Location, Vec<GhostStatement>>;
/// Ghost statements to run on a control-flow edge, keyed first by the source block and then by
/// the target block.
type EdgeMap = FxHashMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
24
25/// A type to indicate _who_ the ghost statements are for: either a regular `DefId` (including closures)  a promoted body.
26#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
27pub enum CheckerId {
28    /// A regular function or closure
29    DefId(LocalDefId),
30    /// A promoted body (within a function or closure)
31    Promoted(LocalDefId, Promoted),
32}
33
34impl CheckerId {
35    pub fn root_id(&self) -> LocalDefId {
36        match self {
37            CheckerId::DefId(def_id) => *def_id,
38            CheckerId::Promoted(def_id, _) => *def_id,
39        }
40    }
41
42    pub fn is_promoted(&self) -> bool {
43        matches!(self, CheckerId::Promoted(_, _))
44    }
45}
46
47pub(crate) fn compute_ghost_statements(
48    genv: GlobalEnv,
49    def_id: LocalDefId,
50) -> QueryResult<UnordMap<CheckerId, GhostStatements>> {
51    let mut data = UnordMap::default();
52    for def_id in all_nested_bodies(genv.tcx(), def_id) {
53        let key = CheckerId::DefId(def_id);
54        data.insert(key, GhostStatements::new(genv, key)?);
55        for promoted in genv.mir(def_id)?.promoted.indices() {
56            let key = CheckerId::Promoted(def_id, promoted);
57            data.insert(key, GhostStatements::new(genv, key)?);
58        }
59    }
60    Ok(data)
61}
62
63pub(crate) enum GhostStatement {
64    Fold(Place),
65    Unfold(Place),
66    Unblock(Place),
67    PtrToRef(Place),
68}
69
70impl fmt::Debug for GhostStatement {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        match self {
73            GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
74            GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
75            GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
76            GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
77        }
78    }
79}
80
81pub(crate) struct GhostStatements {
82    at_start: Vec<GhostStatement>,
83    at_location: LocationMap,
84    at_edge: EdgeMap,
85}
86
87impl GhostStatements {
88    fn new(genv: GlobalEnv, checker_id: CheckerId) -> QueryResult<Self> {
89        let def_id = checker_id.root_id();
90        let body_root = genv.mir(def_id)?;
91        let body = match checker_id {
92            CheckerId::DefId(_) => &body_root.body,
93            CheckerId::Promoted(_, promoted) => &body_root.promoted[promoted],
94        };
95
96        bug::track_span(body.span(), || {
97            let mut stmts = Self {
98                at_start: Default::default(),
99                at_location: LocationMap::default(),
100                at_edge: EdgeMap::default(),
101            };
102
103            // We have fn_sig for function items, but not for closures or generators or promoteds.
104            let fn_sig = if genv.def_kind(def_id) == DefKind::Closure || checker_id.is_promoted() {
105                None
106            } else {
107                Some(genv.fn_sig(def_id)?)
108            };
109
110            fold_unfold::add_ghost_statements(&mut stmts, genv, body, fn_sig.as_ref())?;
111            points_to::add_ghost_statements(&mut stmts, genv, &body.rustc_body, fn_sig.as_ref())?;
112            // We only add unblock statements for the main body because borrows in promoted constants
113            // have to be live in the main body so they never go out of scope in the promoted body.
114            if !checker_id.is_promoted() {
115                stmts.add_unblocks(genv.tcx(), &body_root);
116            }
117            stmts.dump_ghost_mir(genv.tcx(), body);
118
119            Ok(stmts)
120        })
121    }
122
123    fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body_root: &BodyRoot<'tcx>) {
124        for (location, borrows) in body_root.calculate_borrows_out_of_scope_at_location() {
125            let stmts = borrows.into_iter().map(|bidx| {
126                let borrow = body_root.borrow_data(bidx);
127                let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
128                GhostStatement::Unblock(place)
129            });
130            self.at_location.entry(location).or_default().extend(stmts);
131        }
132    }
133
134    fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
135        self.extend_at(point, [stmt]);
136    }
137
138    fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
139        match point {
140            Point::FunEntry => {
141                self.at_start.extend(stmts);
142            }
143            Point::BeforeLocation(location) => {
144                self.at_location.entry(location).or_default().extend(stmts);
145            }
146            Point::Edge(from, to) => {
147                self.at_edge
148                    .entry(from)
149                    .or_default()
150                    .entry(to)
151                    .or_default()
152                    .extend(stmts);
153            }
154        }
155    }
156
157    fn at(&mut self, point: Point) -> StatementsAt<'_> {
158        StatementsAt { stmts: self, point }
159    }
160
161    pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
162        match point {
163            Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
164            Point::BeforeLocation(location) => {
165                self.at_location.get(&location).into_iter().flatten()
166            }
167            Point::Edge(from, to) => {
168                self.at_edge
169                    .get(&from)
170                    .and_then(|m| m.get(&to))
171                    .into_iter()
172                    .flatten()
173            }
174        }
175    }
176
177    pub(crate) fn dump_ghost_mir<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
178        use rustc_middle::mir::{PassWhere, pretty::MirDumper};
179        if let Some(dumper) = MirDumper::new(tcx, "ghost", &body.rustc_body) {
180            dumper
181                .set_extra_data(&|pass, w| {
182                    match pass {
183                        PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
184                            for stmt in &self.at_start {
185                                writeln!(w, "    {stmt:?};")?;
186                            }
187                        }
188                        PassWhere::BeforeLocation(location) => {
189                            for stmt in self.statements_at(Point::BeforeLocation(location)) {
190                                writeln!(w, "        {stmt:?};")?;
191                            }
192                        }
193                        PassWhere::AfterTerminator(bb) => {
194                            if let Some(map) = self.at_edge.get(&bb) {
195                                writeln!(w)?;
196                                for (target, stmts) in map {
197                                    write!(w, "        -> {target:?} {{")?;
198                                    for stmt in stmts {
199                                        write!(w, "\n            {stmt:?};")?;
200                                    }
201                                    write!(w, "\n        }}")?;
202                                }
203                                writeln!(w)?;
204                            }
205                        }
206                        _ => {}
207                    }
208                    Ok(())
209                })
210                .dump_mir(&body.rustc_body);
211        }
212    }
213}
214
215/// A point in the control flow graph where ghost statements can be inserted.
216#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
217pub(crate) enum Point {
218    /// The entry of the function before the first basic block. This is not the same as the first
219    /// location in the first basic block because, for some functions, the first basic block can have
220    /// incoming edges, and we want to execute ghost statements only once.
221    FunEntry,
222    /// The point before a location in a basic block.
223    BeforeLocation(Location),
224    /// An edge between two basic blocks.
225    Edge(BasicBlock, BasicBlock),
226}
227
228struct StatementsAt<'a> {
229    stmts: &'a mut GhostStatements,
230    point: Point,
231}
232
233impl StatementsAt<'_> {
234    fn insert(&mut self, stmt: GhostStatement) {
235        self.stmts.insert_at(self.point, stmt);
236    }
237}
238
239fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
240    use rustc_hir as hir;
241    struct ClosureFinder<'tcx> {
242        tcx: TyCtxt<'tcx>,
243        closures: FxHashSet<LocalDefId>,
244    }
245
246    impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
247        type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
248
249        fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
250            self.tcx
251        }
252
253        fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
254            if let hir::ExprKind::Closure(closure) = ex.kind {
255                self.closures.insert(closure.def_id);
256            }
257
258            hir::intravisit::walk_expr(self, ex);
259        }
260    }
261    let body = tcx.hir_body_owned_by(def_id).value;
262    let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
263    hir::intravisit::Visitor::visit_expr(&mut finder, body);
264    finder.closures.into_iter().chain(iter::once(def_id))
265}