1mod fold_unfold;
4mod points_to;
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
10use flux_rustc_bridge::{
11 lowering,
12 mir::{BasicBlock, Body, BodyRoot, Place},
13};
14use rustc_data_structures::unord::UnordMap;
15use rustc_hash::{FxHashMap, FxHashSet};
16use rustc_hir::{def::DefKind, def_id::LocalDefId};
17use rustc_middle::{
18 mir::{Location, Promoted, START_BLOCK},
19 ty::TyCtxt,
20};
21
/// Ghost statements attached before MIR locations, keyed by the location.
type LocationMap = FxHashMap<Location, Vec<GhostStatement>>;
/// Ghost statements attached to control-flow edges, keyed first by the source block and then
/// by the target block.
type EdgeMap = FxHashMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
24
25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
27pub enum CheckerId {
28 DefId(LocalDefId),
30 Promoted(LocalDefId, Promoted),
32}
33
34impl CheckerId {
35 pub fn root_id(&self) -> LocalDefId {
36 match self {
37 CheckerId::DefId(def_id) => *def_id,
38 CheckerId::Promoted(def_id, _) => *def_id,
39 }
40 }
41
42 pub fn is_promoted(&self) -> bool {
43 matches!(self, CheckerId::Promoted(_, _))
44 }
45}
46
47pub(crate) fn compute_ghost_statements(
48 genv: GlobalEnv,
49 def_id: LocalDefId,
50) -> QueryResult<UnordMap<CheckerId, GhostStatements>> {
51 let mut data = UnordMap::default();
52 for def_id in all_nested_bodies(genv.tcx(), def_id) {
53 let key = CheckerId::DefId(def_id);
54 data.insert(key, GhostStatements::new(genv, key)?);
55 for promoted in genv.mir(def_id)?.promoted.indices() {
56 let key = CheckerId::Promoted(def_id, promoted);
57 data.insert(key, GhostStatements::new(genv, key)?);
58 }
59 }
60 Ok(data)
61}
62
63pub(crate) enum GhostStatement {
64 Fold(Place),
65 Unfold(Place),
66 Unblock(Place),
67 PtrToRef(Place),
68}
69
70impl fmt::Debug for GhostStatement {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 match self {
73 GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
74 GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
75 GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
76 GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
77 }
78 }
79}
80
81pub(crate) struct GhostStatements {
82 at_start: Vec<GhostStatement>,
83 at_location: LocationMap,
84 at_edge: EdgeMap,
85}
86
87impl GhostStatements {
88 fn new(genv: GlobalEnv, checker_id: CheckerId) -> QueryResult<Self> {
89 let def_id = checker_id.root_id();
90 let body_root = genv.mir(def_id)?;
91 let body = match checker_id {
92 CheckerId::DefId(_) => &body_root.body,
93 CheckerId::Promoted(_, promoted) => &body_root.promoted[promoted],
94 };
95
96 bug::track_span(body.span(), || {
97 let mut stmts = Self {
98 at_start: Default::default(),
99 at_location: LocationMap::default(),
100 at_edge: EdgeMap::default(),
101 };
102
103 let fn_sig =
105 if matches!(genv.def_kind(def_id), DefKind::Closure | DefKind::Static { .. })
106 || checker_id.is_promoted()
107 {
108 None
109 } else {
110 Some(genv.fn_sig(def_id)?)
111 };
112
113 fold_unfold::add_ghost_statements(&mut stmts, genv, body, fn_sig.as_ref())?;
114 points_to::add_ghost_statements(&mut stmts, genv, &body.rustc_body, fn_sig.as_ref())?;
115 if !checker_id.is_promoted() {
118 stmts.add_unblocks(genv.tcx(), &body_root);
119 }
120 stmts.dump_ghost_mir(genv.tcx(), body);
121
122 Ok(stmts)
123 })
124 }
125
126 fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body_root: &BodyRoot<'tcx>) {
127 for (location, borrows) in body_root.calculate_borrows_out_of_scope_at_location() {
128 let stmts = borrows.into_iter().map(|bidx| {
129 let borrow = body_root.borrow_data(bidx);
130 let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
131 GhostStatement::Unblock(place)
132 });
133 self.at_location.entry(location).or_default().extend(stmts);
134 }
135 }
136
137 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
138 self.extend_at(point, [stmt]);
139 }
140
141 fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
142 match point {
143 Point::FunEntry => {
144 self.at_start.extend(stmts);
145 }
146 Point::BeforeLocation(location) => {
147 self.at_location.entry(location).or_default().extend(stmts);
148 }
149 Point::Edge(from, to) => {
150 self.at_edge
151 .entry(from)
152 .or_default()
153 .entry(to)
154 .or_default()
155 .extend(stmts);
156 }
157 }
158 }
159
160 fn at(&mut self, point: Point) -> StatementsAt<'_> {
161 StatementsAt { stmts: self, point }
162 }
163
164 pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
165 match point {
166 Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
167 Point::BeforeLocation(location) => {
168 self.at_location.get(&location).into_iter().flatten()
169 }
170 Point::Edge(from, to) => {
171 self.at_edge
172 .get(&from)
173 .and_then(|m| m.get(&to))
174 .into_iter()
175 .flatten()
176 }
177 }
178 }
179
180 pub(crate) fn dump_ghost_mir<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
181 use rustc_middle::mir::{PassWhere, pretty::MirDumper};
182 if let Some(dumper) = MirDumper::new(tcx, "ghost", &body.rustc_body) {
183 dumper
184 .set_extra_data(&|pass, w| {
185 match pass {
186 PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
187 for stmt in &self.at_start {
188 writeln!(w, " {stmt:?};")?;
189 }
190 }
191 PassWhere::BeforeLocation(location) => {
192 for stmt in self.statements_at(Point::BeforeLocation(location)) {
193 writeln!(w, " {stmt:?};")?;
194 }
195 }
196 PassWhere::AfterTerminator(bb) => {
197 if let Some(map) = self.at_edge.get(&bb) {
198 writeln!(w)?;
199 for (target, stmts) in map {
200 write!(w, " -> {target:?} {{")?;
201 for stmt in stmts {
202 write!(w, "\n {stmt:?};")?;
203 }
204 write!(w, "\n }}")?;
205 }
206 writeln!(w)?;
207 }
208 }
209 _ => {}
210 }
211 Ok(())
212 })
213 .dump_mir(&body.rustc_body);
214 }
215 }
216}
217
218#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
220pub(crate) enum Point {
221 FunEntry,
225 BeforeLocation(Location),
227 Edge(BasicBlock, BasicBlock),
229}
230
231struct StatementsAt<'a> {
232 stmts: &'a mut GhostStatements,
233 point: Point,
234}
235
236impl StatementsAt<'_> {
237 fn insert(&mut self, stmt: GhostStatement) {
238 self.stmts.insert_at(self.point, stmt);
239 }
240}
241
242fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
243 use rustc_hir as hir;
244 struct ClosureFinder<'tcx> {
245 tcx: TyCtxt<'tcx>,
246 closures: FxHashSet<LocalDefId>,
247 }
248
249 impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
250 type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
251
252 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
253 self.tcx
254 }
255
256 fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
257 if let hir::ExprKind::Closure(closure) = ex.kind {
258 self.closures.insert(closure.def_id);
259 }
260
261 hir::intravisit::walk_expr(self, ex);
262 }
263 }
264 let body = tcx.hir_body_owned_by(def_id).value;
265 let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
266 hir::intravisit::Visitor::visit_expr(&mut finder, body);
267 finder.closures.into_iter().chain(iter::once(def_id))
268}