flux_refineck/
ghost_statements.rs

//! Ghost statements are statements that are not part of the original MIR but are added from
//! information extracted from the compiler or from some additional analysis.
3mod fold_unfold;
4mod points_to;
5
6use std::{fmt, io, iter};
7
8use flux_common::{bug, dbg};
9use flux_config as config;
10use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
11use flux_rustc_bridge::{
12    lowering,
13    mir::{BasicBlock, Body, Place},
14};
15use rustc_data_structures::unord::UnordMap;
16use rustc_hash::{FxHashMap, FxHashSet};
17use rustc_hir::{def::DefKind, def_id::LocalDefId};
18use rustc_middle::{
19    mir::{Location, START_BLOCK},
20    ty::TyCtxt,
21};
22
23type LocationMap = FxHashMap<Location, Vec<GhostStatement>>;
24type EdgeMap = FxHashMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
25
26pub(crate) fn compute_ghost_statements(
27    genv: GlobalEnv,
28    def_id: LocalDefId,
29) -> QueryResult<UnordMap<LocalDefId, GhostStatements>> {
30    let mut data = UnordMap::default();
31    for def_id in all_nested_bodies(genv.tcx(), def_id) {
32        data.insert(def_id, GhostStatements::new(genv, def_id)?);
33    }
34    Ok(data)
35}
36
37pub(crate) enum GhostStatement {
38    Fold(Place),
39    Unfold(Place),
40    Unblock(Place),
41    PtrToRef(Place),
42}
43
44impl fmt::Debug for GhostStatement {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        match self {
47            GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
48            GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
49            GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
50            GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
51        }
52    }
53}
54
55pub(crate) struct GhostStatements {
56    at_start: Vec<GhostStatement>,
57    at_location: LocationMap,
58    at_edge: EdgeMap,
59}
60
61impl GhostStatements {
62    fn new(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult<Self> {
63        let body = genv.mir(def_id)?;
64
65        bug::track_span(body.span(), || {
66            let mut stmts = Self {
67                at_start: Default::default(),
68                at_location: LocationMap::default(),
69                at_edge: EdgeMap::default(),
70            };
71
72            // We have fn_sig for function items, but not for closures or generators.
73            let fn_sig = if genv.def_kind(def_id) == DefKind::Closure {
74                None
75            } else {
76                Some(genv.fn_sig(def_id)?)
77            };
78
79            fold_unfold::add_ghost_statements(&mut stmts, genv, &body, fn_sig.as_ref())?;
80            points_to::add_ghost_statements(&mut stmts, genv, body.rustc_body(), fn_sig.as_ref())?;
81            stmts.add_unblocks(genv.tcx(), &body);
82
83            if config::dump_mir() {
84                let mut writer =
85                    dbg::writer_for_item(genv.tcx(), def_id.to_def_id(), "ghost.mir").unwrap();
86                stmts.write_mir(genv.tcx(), &body, &mut writer).unwrap();
87            }
88            Ok(stmts)
89        })
90    }
91
92    fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
93        for (location, borrows) in body.calculate_borrows_out_of_scope_at_location() {
94            let stmts = borrows.into_iter().map(|bidx| {
95                let borrow = body.borrow_data(bidx);
96                let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
97                GhostStatement::Unblock(place)
98            });
99            self.at_location.entry(location).or_default().extend(stmts);
100        }
101    }
102
103    fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
104        self.extend_at(point, [stmt]);
105    }
106
107    fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
108        match point {
109            Point::FunEntry => {
110                self.at_start.extend(stmts);
111            }
112            Point::BeforeLocation(location) => {
113                self.at_location.entry(location).or_default().extend(stmts);
114            }
115            Point::Edge(from, to) => {
116                self.at_edge
117                    .entry(from)
118                    .or_default()
119                    .entry(to)
120                    .or_default()
121                    .extend(stmts);
122            }
123        }
124    }
125
126    fn at(&mut self, point: Point) -> StatementsAt {
127        StatementsAt { stmts: self, point }
128    }
129
130    pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
131        match point {
132            Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
133            Point::BeforeLocation(location) => {
134                self.at_location.get(&location).into_iter().flatten()
135            }
136            Point::Edge(from, to) => {
137                self.at_edge
138                    .get(&from)
139                    .and_then(|m| m.get(&to))
140                    .into_iter()
141                    .flatten()
142            }
143        }
144    }
145
146    pub(crate) fn write_mir<'tcx, W: io::Write>(
147        &self,
148        tcx: TyCtxt<'tcx>,
149        body: &Body<'tcx>,
150        w: &mut W,
151    ) -> io::Result<()> {
152        use rustc_middle::mir::PassWhere;
153        rustc_middle::mir::pretty::write_mir_fn(
154            tcx,
155            body.inner(),
156            &mut |pass, w| {
157                match pass {
158                    PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
159                        for stmt in &self.at_start {
160                            writeln!(w, "    {stmt:?};")?;
161                        }
162                    }
163                    PassWhere::BeforeLocation(location) => {
164                        for stmt in self.statements_at(Point::BeforeLocation(location)) {
165                            writeln!(w, "        {stmt:?};")?;
166                        }
167                    }
168                    PassWhere::AfterTerminator(bb) => {
169                        if let Some(map) = self.at_edge.get(&bb) {
170                            writeln!(w)?;
171                            for (target, stmts) in map {
172                                write!(w, "        -> {target:?} {{")?;
173                                for stmt in stmts {
174                                    write!(w, "\n            {stmt:?};")?;
175                                }
176                                write!(w, "\n        }}")?;
177                            }
178                            writeln!(w)?;
179                        }
180                    }
181                    _ => {}
182                }
183                Ok(())
184            },
185            w,
186            rustc_middle::mir::pretty::PrettyPrintMirOptions::from_cli(tcx),
187        )
188    }
189}
190
191/// A point in the control flow graph where ghost statements can be inserted.
192#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
193pub(crate) enum Point {
194    /// The entry of the function before the first basic block. This is not the same as the first
195    /// location in the first basic block because, for some functions, the first basic block can have
196    /// incoming edges, and we want to execute ghost statements only once.
197    FunEntry,
198    /// The point before a location in a basic block.
199    BeforeLocation(Location),
200    /// An edge between two basic blocks.
201    Edge(BasicBlock, BasicBlock),
202}
203
204struct StatementsAt<'a> {
205    stmts: &'a mut GhostStatements,
206    point: Point,
207}
208
209impl StatementsAt<'_> {
210    fn insert(&mut self, stmt: GhostStatement) {
211        self.stmts.insert_at(self.point, stmt);
212    }
213}
214
215fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
216    use rustc_hir as hir;
217    struct ClosureFinder<'tcx> {
218        tcx: TyCtxt<'tcx>,
219        closures: FxHashSet<LocalDefId>,
220    }
221
222    impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
223        type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
224
225        fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
226            self.tcx
227        }
228
229        fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
230            if let hir::ExprKind::Closure(closure) = ex.kind {
231                self.closures.insert(closure.def_id);
232            }
233
234            hir::intravisit::walk_expr(self, ex);
235        }
236    }
237    let body = tcx.hir_body_owned_by(def_id).value;
238    let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
239    hir::intravisit::Visitor::visit_expr(&mut finder, body);
240    finder.closures.into_iter().chain(iter::once(def_id))
241}