1mod fold_unfold;
4mod points_to;
5
6use std::{fmt, io, iter};
7
8use flux_common::{bug, dbg};
9use flux_config as config;
10use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
11use flux_rustc_bridge::{
12 lowering,
13 mir::{BasicBlock, Body, Place},
14};
15use rustc_data_structures::unord::UnordMap;
16use rustc_hash::{FxHashMap, FxHashSet};
17use rustc_hir::{def::DefKind, def_id::LocalDefId};
18use rustc_middle::{
19 mir::{Location, START_BLOCK},
20 ty::TyCtxt,
21};
22
/// Ghost statements to be placed immediately before a MIR [`Location`].
type LocationMap = FxHashMap<Location, Vec<GhostStatement>>;
/// Ghost statements to be placed on a control-flow edge, keyed first by the source block and
/// then by the target block.
type EdgeMap = FxHashMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
25
26pub(crate) fn compute_ghost_statements(
27 genv: GlobalEnv,
28 def_id: LocalDefId,
29) -> QueryResult<UnordMap<LocalDefId, GhostStatements>> {
30 let mut data = UnordMap::default();
31 for def_id in all_nested_bodies(genv.tcx(), def_id) {
32 data.insert(def_id, GhostStatements::new(genv, def_id)?);
33 }
34 Ok(data)
35}
36
/// An auxiliary statement recorded alongside a MIR body at some [`Point`].
///
/// Every variant carries the [`Place`] it applies to. The `Debug` impl renders the variants as
/// `fold(..)`, `unfold(..)`, `unblock(..)` and `ptr_to_ref(..)` respectively.
pub(crate) enum GhostStatement {
    /// NOTE(review): presumably emitted by the `fold_unfold` pass — confirm exact semantics.
    Fold(Place),
    /// NOTE(review): presumably emitted by the `fold_unfold` pass — confirm exact semantics.
    Unfold(Place),
    /// `GhostStatements::add_unblocks` records one of these for the borrowed place of every
    /// borrow that goes out of scope at a location.
    Unblock(Place),
    /// NOTE(review): presumably emitted by the `points_to` pass — confirm exact semantics.
    PtrToRef(Place),
}
43
44impl fmt::Debug for GhostStatement {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 match self {
47 GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
48 GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
49 GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
50 GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
51 }
52 }
53}
54
/// All ghost statements computed for a single MIR body, grouped by where they are placed.
pub(crate) struct GhostStatements {
    // Statements at `Point::FunEntry`; printed before `START_BLOCK` in the MIR dump.
    at_start: Vec<GhostStatement>,
    // Statements at `Point::BeforeLocation(loc)`, keyed by location.
    at_location: LocationMap,
    // Statements at `Point::Edge(from, to)`, keyed by `from` and then `to`.
    at_edge: EdgeMap,
}
60
61impl GhostStatements {
62 fn new(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult<Self> {
63 let body = genv.mir(def_id)?;
64
65 bug::track_span(body.span(), || {
66 let mut stmts = Self {
67 at_start: Default::default(),
68 at_location: LocationMap::default(),
69 at_edge: EdgeMap::default(),
70 };
71
72 let fn_sig = if genv.def_kind(def_id) == DefKind::Closure {
74 None
75 } else {
76 Some(genv.fn_sig(def_id)?)
77 };
78
79 fold_unfold::add_ghost_statements(&mut stmts, genv, &body, fn_sig.as_ref())?;
80 points_to::add_ghost_statements(&mut stmts, genv, body.rustc_body(), fn_sig.as_ref())?;
81 stmts.add_unblocks(genv.tcx(), &body);
82
83 if config::dump_mir() {
84 let mut writer =
85 dbg::writer_for_item(genv.tcx(), def_id.to_def_id(), "ghost.mir").unwrap();
86 stmts.write_mir(genv.tcx(), &body, &mut writer).unwrap();
87 }
88 Ok(stmts)
89 })
90 }
91
92 fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
93 for (location, borrows) in body.calculate_borrows_out_of_scope_at_location() {
94 let stmts = borrows.into_iter().map(|bidx| {
95 let borrow = body.borrow_data(bidx);
96 let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
97 GhostStatement::Unblock(place)
98 });
99 self.at_location.entry(location).or_default().extend(stmts);
100 }
101 }
102
103 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
104 self.extend_at(point, [stmt]);
105 }
106
107 fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
108 match point {
109 Point::FunEntry => {
110 self.at_start.extend(stmts);
111 }
112 Point::BeforeLocation(location) => {
113 self.at_location.entry(location).or_default().extend(stmts);
114 }
115 Point::Edge(from, to) => {
116 self.at_edge
117 .entry(from)
118 .or_default()
119 .entry(to)
120 .or_default()
121 .extend(stmts);
122 }
123 }
124 }
125
126 fn at(&mut self, point: Point) -> StatementsAt {
127 StatementsAt { stmts: self, point }
128 }
129
130 pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
131 match point {
132 Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
133 Point::BeforeLocation(location) => {
134 self.at_location.get(&location).into_iter().flatten()
135 }
136 Point::Edge(from, to) => {
137 self.at_edge
138 .get(&from)
139 .and_then(|m| m.get(&to))
140 .into_iter()
141 .flatten()
142 }
143 }
144 }
145
146 pub(crate) fn write_mir<'tcx, W: io::Write>(
147 &self,
148 tcx: TyCtxt<'tcx>,
149 body: &Body<'tcx>,
150 w: &mut W,
151 ) -> io::Result<()> {
152 use rustc_middle::mir::PassWhere;
153 rustc_middle::mir::pretty::write_mir_fn(
154 tcx,
155 body.inner(),
156 &mut |pass, w| {
157 match pass {
158 PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
159 for stmt in &self.at_start {
160 writeln!(w, " {stmt:?};")?;
161 }
162 }
163 PassWhere::BeforeLocation(location) => {
164 for stmt in self.statements_at(Point::BeforeLocation(location)) {
165 writeln!(w, " {stmt:?};")?;
166 }
167 }
168 PassWhere::AfterTerminator(bb) => {
169 if let Some(map) = self.at_edge.get(&bb) {
170 writeln!(w)?;
171 for (target, stmts) in map {
172 write!(w, " -> {target:?} {{")?;
173 for stmt in stmts {
174 write!(w, "\n {stmt:?};")?;
175 }
176 write!(w, "\n }}")?;
177 }
178 writeln!(w)?;
179 }
180 }
181 _ => {}
182 }
183 Ok(())
184 },
185 w,
186 rustc_middle::mir::pretty::PrettyPrintMirOptions::from_cli(tcx),
187 )
188 }
189}
190
/// A position in a MIR body at which ghost statements can be placed.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub(crate) enum Point {
    /// The entry of the function; printed before `START_BLOCK` in the MIR dump.
    FunEntry,
    /// Immediately before the statement or terminator at the given location.
    BeforeLocation(Location),
    /// On the control-flow edge from the first block to the second.
    Edge(BasicBlock, BasicBlock),
}
203
/// A cursor over a [`GhostStatements`] that inserts every statement at one fixed [`Point`].
struct StatementsAt<'a> {
    // The collection being extended.
    stmts: &'a mut GhostStatements,
    // Where every inserted statement is placed.
    point: Point,
}
208
209impl StatementsAt<'_> {
210 fn insert(&mut self, stmt: GhostStatement) {
211 self.stmts.insert_at(self.point, stmt);
212 }
213}
214
215fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
216 use rustc_hir as hir;
217 struct ClosureFinder<'tcx> {
218 tcx: TyCtxt<'tcx>,
219 closures: FxHashSet<LocalDefId>,
220 }
221
222 impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
223 type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
224
225 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
226 self.tcx
227 }
228
229 fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
230 if let hir::ExprKind::Closure(closure) = ex.kind {
231 self.closures.insert(closure.def_id);
232 }
233
234 hir::intravisit::walk_expr(self, ex);
235 }
236 }
237 let body = tcx.hir_body_owned_by(def_id).value;
238 let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
239 hir::intravisit::Visitor::visit_expr(&mut finder, body);
240 finder.closures.into_iter().chain(iter::once(def_id))
241}