1mod fold_unfold;
4mod points_to;
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
10use flux_rustc_bridge::{
11 lowering,
12 mir::{BasicBlock, Body, BodyRoot, Place},
13};
14use rustc_data_structures::unord::UnordMap;
15use rustc_hash::{FxHashMap, FxHashSet};
16use rustc_hir::{def::DefKind, def_id::LocalDefId};
17use rustc_middle::{
18 mir::{BorrowKind, Location, Promoted, START_BLOCK},
19 ty::TyCtxt,
20};
21
22type LocationMap = UnordMap<Location, Vec<GhostStatement>>;
23type EdgeMap = UnordMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
24
25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
27pub enum CheckerId {
28 DefId(LocalDefId),
30 Promoted(LocalDefId, Promoted),
32}
33
34impl CheckerId {
35 pub fn root_id(&self) -> LocalDefId {
36 match self {
37 CheckerId::DefId(def_id) => *def_id,
38 CheckerId::Promoted(def_id, _) => *def_id,
39 }
40 }
41
42 pub fn is_promoted(&self) -> bool {
43 matches!(self, CheckerId::Promoted(_, _))
44 }
45}
46
47pub(crate) fn compute_ghost_statements(
48 genv: GlobalEnv,
49 def_id: LocalDefId,
50) -> QueryResult<UnordMap<CheckerId, GhostStatements>> {
51 let mut data = UnordMap::default();
52 for def_id in all_nested_bodies(genv.tcx(), def_id) {
53 let key = CheckerId::DefId(def_id);
54 data.insert(key, GhostStatements::new(genv, key)?);
55 for promoted in genv.mir(def_id)?.promoted.indices() {
56 let key = CheckerId::Promoted(def_id, promoted);
57 data.insert(key, GhostStatements::new(genv, key)?);
58 }
59 }
60 Ok(data)
61}
62
63pub(crate) enum GhostStatement {
64 Fold(Place),
65 Unfold(Place),
66 Unblock(Place),
67 PtrToRef(Place),
68}
69
70impl fmt::Debug for GhostStatement {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 match self {
73 GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
74 GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
75 GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
76 GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
77 }
78 }
79}
80
81pub(crate) struct GhostStatements {
82 at_start: Vec<GhostStatement>,
83 at_location: LocationMap,
84 at_edge: EdgeMap,
85}
86
87impl GhostStatements {
88 fn new(genv: GlobalEnv, checker_id: CheckerId) -> QueryResult<Self> {
89 let def_id = checker_id.root_id();
90 let body_root = genv.mir(def_id)?;
91 let body = match checker_id {
92 CheckerId::DefId(_) => &body_root.body,
93 CheckerId::Promoted(_, promoted) => &body_root.promoted[promoted],
94 };
95
96 bug::track_span(body.span(), || {
97 let mut stmts = Self {
98 at_start: Default::default(),
99 at_location: LocationMap::default(),
100 at_edge: EdgeMap::default(),
101 };
102
103 let fn_sig =
105 if matches!(genv.def_kind(def_id), DefKind::Closure | DefKind::Static { .. })
106 || checker_id.is_promoted()
107 {
108 None
109 } else {
110 Some(genv.fn_sig(def_id)?)
111 };
112
113 fold_unfold::add_ghost_statements(&mut stmts, genv, body, fn_sig.as_ref())?;
114 points_to::add_ghost_statements(&mut stmts, genv, body.rustc_body(), fn_sig.as_ref())?;
115 if !checker_id.is_promoted() {
118 stmts.add_unblocks(genv.tcx(), &body_root);
119 }
120 stmts.dump_ghost_mir(genv.tcx(), body);
121
122 Ok(stmts)
123 })
124 }
125
126 fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body_root: &BodyRoot<'tcx>) {
127 for (location, borrows) in body_root.calculate_borrows_out_of_scope_at_location() {
128 let stmts = borrows
129 .into_iter()
130 .filter(|bidx| {
131 matches!(body_root.borrow_data(*bidx).kind(), BorrowKind::Mut { .. })
138 })
139 .map(|bidx| {
140 let borrow = body_root.borrow_data(bidx);
141 let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
142 GhostStatement::Unblock(place)
143 });
144 self.at_location.entry(location).or_default().extend(stmts);
145 }
146 }
147
148 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
149 self.extend_at(point, [stmt]);
150 }
151
152 fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
153 match point {
154 Point::FunEntry => {
155 self.at_start.extend(stmts);
156 }
157 Point::BeforeLocation(location) => {
158 self.at_location.entry(location).or_default().extend(stmts);
159 }
160 Point::Edge(from, to) => {
161 self.at_edge
162 .entry(from)
163 .or_default()
164 .entry(to)
165 .or_default()
166 .extend(stmts);
167 }
168 }
169 }
170
171 fn at(&mut self, point: Point) -> StatementsAt<'_> {
172 StatementsAt { stmts: self, point }
173 }
174
175 pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
176 match point {
177 Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
178 Point::BeforeLocation(location) => {
179 self.at_location.get(&location).into_iter().flatten()
180 }
181 Point::Edge(from, to) => {
182 self.at_edge
183 .get(&from)
184 .and_then(|m| m.get(&to))
185 .into_iter()
186 .flatten()
187 }
188 }
189 }
190
191 pub(crate) fn dump_ghost_mir<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
192 use rustc_middle::mir::{PassWhere, pretty::MirDumper};
193 if let Some(dumper) = MirDumper::new(tcx, "ghost", body.rustc_body()) {
194 dumper
195 .set_extra_data(&|pass, w| {
196 match pass {
197 PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
198 for stmt in &self.at_start {
199 writeln!(w, " {stmt:?};")?;
200 }
201 }
202 PassWhere::BeforeLocation(location) => {
203 for stmt in self.statements_at(Point::BeforeLocation(location)) {
204 writeln!(w, " {stmt:?};")?;
205 }
206 }
207 PassWhere::AfterTerminator(bb) => {
208 if let Some(map) = self.at_edge.get(&bb) {
209 writeln!(w)?;
210 for (target, stmts) in map {
211 write!(w, " -> {target:?} {{")?;
212 for stmt in stmts {
213 write!(w, "\n {stmt:?};")?;
214 }
215 write!(w, "\n }}")?;
216 }
217 writeln!(w)?;
218 }
219 }
220 _ => {}
221 }
222 Ok(())
223 })
224 .dump_mir(body.rustc_body());
225 }
226 }
227}
228
229#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
231pub(crate) enum Point {
232 FunEntry,
236 BeforeLocation(Location),
238 Edge(BasicBlock, BasicBlock),
240}
241
242struct StatementsAt<'a> {
243 stmts: &'a mut GhostStatements,
244 point: Point,
245}
246
247impl StatementsAt<'_> {
248 fn insert(&mut self, stmt: GhostStatement) {
249 self.stmts.insert_at(self.point, stmt);
250 }
251}
252
253fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
254 use rustc_hir as hir;
255 struct ClosureFinder<'tcx> {
256 tcx: TyCtxt<'tcx>,
257 closures: FxHashSet<LocalDefId>,
258 }
259
260 impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
261 type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
262
263 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
264 self.tcx
265 }
266
267 fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
268 if let hir::ExprKind::Closure(closure) = ex.kind {
269 self.closures.insert(closure.def_id);
270 }
271
272 hir::intravisit::walk_expr(self, ex);
273 }
274 }
275 let body = tcx.hir_body_owned_by(def_id).value;
276 let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
277 hir::intravisit::Visitor::visit_expr(&mut finder, body);
278 finder.closures.into_iter().chain(iter::once(def_id))
279}