1mod fold_unfold;
4mod points_to;
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
10use flux_rustc_bridge::{
11 lowering,
12 mir::{BasicBlock, Body, BodyRoot, Place},
13};
14use rustc_data_structures::unord::UnordMap;
15use rustc_hash::{FxHashMap, FxHashSet};
16use rustc_hir::{def::DefKind, def_id::LocalDefId};
17use rustc_middle::{
18 mir::{Location, Promoted, START_BLOCK},
19 ty::TyCtxt,
20};
21
/// Ghost statements to insert before a MIR [`Location`], keyed by that location.
type LocationMap = FxHashMap<Location, Vec<GhostStatement>>;
/// Ghost statements attached to CFG edges: the outer key is the source block of the edge and
/// the inner key is its target block.
type EdgeMap = FxHashMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
24
25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
27pub enum CheckerId {
28 DefId(LocalDefId),
30 Promoted(LocalDefId, Promoted),
32}
33
34impl CheckerId {
35 pub fn root_id(&self) -> LocalDefId {
36 match self {
37 CheckerId::DefId(def_id) => *def_id,
38 CheckerId::Promoted(def_id, _) => *def_id,
39 }
40 }
41
42 pub fn is_promoted(&self) -> bool {
43 matches!(self, CheckerId::Promoted(_, _))
44 }
45}
46
47pub(crate) fn compute_ghost_statements(
48 genv: GlobalEnv,
49 def_id: LocalDefId,
50) -> QueryResult<UnordMap<CheckerId, GhostStatements>> {
51 let mut data = UnordMap::default();
52 for def_id in all_nested_bodies(genv.tcx(), def_id) {
53 let key = CheckerId::DefId(def_id);
54 data.insert(key, GhostStatements::new(genv, key)?);
55 for promoted in genv.mir(def_id)?.promoted.indices() {
56 let key = CheckerId::Promoted(def_id, promoted);
57 data.insert(key, GhostStatements::new(genv, key)?);
58 }
59 }
60 Ok(data)
61}
62
/// A ghost statement attached to a program point of a MIR body. Ghost statements are printed
/// alongside the MIR in ghost MIR dumps.
pub(crate) enum GhostStatement {
    /// Printed as `fold(place)`. NOTE(review): presumably emitted by the `fold_unfold` pass —
    /// confirm.
    Fold(Place),
    /// Printed as `unfold(place)`. NOTE(review): presumably emitted by the `fold_unfold` pass —
    /// confirm.
    Unfold(Place),
    /// Printed as `unblock(place)`. Added by `GhostStatements::add_unblocks` before each
    /// location where a borrow of `place` goes out of scope.
    Unblock(Place),
    /// Printed as `ptr_to_ref(place)`. NOTE(review): presumably emitted by the `points_to`
    /// pass — confirm.
    PtrToRef(Place),
}
69
70impl fmt::Debug for GhostStatement {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 match self {
73 GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
74 GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
75 GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
76 GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
77 }
78 }
79}
80
/// The ghost statements computed for a single MIR body, indexed by the program point they are
/// attached to.
pub(crate) struct GhostStatements {
    /// Statements attached to [`Point::FunEntry`].
    at_start: Vec<GhostStatement>,
    /// Statements attached to [`Point::BeforeLocation`], keyed by location.
    at_location: LocationMap,
    /// Statements attached to [`Point::Edge`], keyed by source block and then target block.
    at_edge: EdgeMap,
}
86
87impl GhostStatements {
88 fn new(genv: GlobalEnv, checker_id: CheckerId) -> QueryResult<Self> {
89 let def_id = checker_id.root_id();
90 let body_root = genv.mir(def_id)?;
91 let body = match checker_id {
92 CheckerId::DefId(_) => &body_root.body,
93 CheckerId::Promoted(_, promoted) => &body_root.promoted[promoted],
94 };
95
96 bug::track_span(body.span(), || {
97 let mut stmts = Self {
98 at_start: Default::default(),
99 at_location: LocationMap::default(),
100 at_edge: EdgeMap::default(),
101 };
102
103 let fn_sig = if genv.def_kind(def_id) == DefKind::Closure || checker_id.is_promoted() {
105 None
106 } else {
107 Some(genv.fn_sig(def_id)?)
108 };
109
110 fold_unfold::add_ghost_statements(&mut stmts, genv, body, fn_sig.as_ref())?;
111 points_to::add_ghost_statements(&mut stmts, genv, &body.rustc_body, fn_sig.as_ref())?;
112 if !checker_id.is_promoted() {
115 stmts.add_unblocks(genv.tcx(), &body_root);
116 }
117 stmts.dump_ghost_mir(genv.tcx(), body);
118
119 Ok(stmts)
120 })
121 }
122
123 fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body_root: &BodyRoot<'tcx>) {
124 for (location, borrows) in body_root.calculate_borrows_out_of_scope_at_location() {
125 let stmts = borrows.into_iter().map(|bidx| {
126 let borrow = body_root.borrow_data(bidx);
127 let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
128 GhostStatement::Unblock(place)
129 });
130 self.at_location.entry(location).or_default().extend(stmts);
131 }
132 }
133
134 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
135 self.extend_at(point, [stmt]);
136 }
137
138 fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
139 match point {
140 Point::FunEntry => {
141 self.at_start.extend(stmts);
142 }
143 Point::BeforeLocation(location) => {
144 self.at_location.entry(location).or_default().extend(stmts);
145 }
146 Point::Edge(from, to) => {
147 self.at_edge
148 .entry(from)
149 .or_default()
150 .entry(to)
151 .or_default()
152 .extend(stmts);
153 }
154 }
155 }
156
157 fn at(&mut self, point: Point) -> StatementsAt<'_> {
158 StatementsAt { stmts: self, point }
159 }
160
161 pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
162 match point {
163 Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
164 Point::BeforeLocation(location) => {
165 self.at_location.get(&location).into_iter().flatten()
166 }
167 Point::Edge(from, to) => {
168 self.at_edge
169 .get(&from)
170 .and_then(|m| m.get(&to))
171 .into_iter()
172 .flatten()
173 }
174 }
175 }
176
177 pub(crate) fn dump_ghost_mir<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
178 use rustc_middle::mir::{PassWhere, pretty::MirDumper};
179 if let Some(dumper) = MirDumper::new(tcx, "ghost", &body.rustc_body) {
180 dumper
181 .set_extra_data(&|pass, w| {
182 match pass {
183 PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
184 for stmt in &self.at_start {
185 writeln!(w, " {stmt:?};")?;
186 }
187 }
188 PassWhere::BeforeLocation(location) => {
189 for stmt in self.statements_at(Point::BeforeLocation(location)) {
190 writeln!(w, " {stmt:?};")?;
191 }
192 }
193 PassWhere::AfterTerminator(bb) => {
194 if let Some(map) = self.at_edge.get(&bb) {
195 writeln!(w)?;
196 for (target, stmts) in map {
197 write!(w, " -> {target:?} {{")?;
198 for stmt in stmts {
199 write!(w, "\n {stmt:?};")?;
200 }
201 write!(w, "\n }}")?;
202 }
203 writeln!(w)?;
204 }
205 }
206 _ => {}
207 }
208 Ok(())
209 })
210 .dump_mir(&body.rustc_body);
211 }
212 }
213}
214
/// A program point in a MIR body to which ghost statements can be attached.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub(crate) enum Point {
    /// The function entry. Ghost MIR dumps print these statements before `START_BLOCK`.
    FunEntry,
    /// Immediately before the statement or terminator at the given location.
    BeforeLocation(Location),
    /// The CFG edge from the first block to the second. Ghost MIR dumps print these statements
    /// after the first block's terminator.
    Edge(BasicBlock, BasicBlock),
}
227
228struct StatementsAt<'a> {
229 stmts: &'a mut GhostStatements,
230 point: Point,
231}
232
233impl StatementsAt<'_> {
234 fn insert(&mut self, stmt: GhostStatement) {
235 self.stmts.insert_at(self.point, stmt);
236 }
237}
238
239fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
240 use rustc_hir as hir;
241 struct ClosureFinder<'tcx> {
242 tcx: TyCtxt<'tcx>,
243 closures: FxHashSet<LocalDefId>,
244 }
245
246 impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
247 type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
248
249 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
250 self.tcx
251 }
252
253 fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
254 if let hir::ExprKind::Closure(closure) = ex.kind {
255 self.closures.insert(closure.def_id);
256 }
257
258 hir::intravisit::walk_expr(self, ex);
259 }
260 }
261 let body = tcx.hir_body_owned_by(def_id).value;
262 let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
263 hir::intravisit::Visitor::visit_expr(&mut finder, body);
264 finder.closures.into_iter().chain(iter::once(def_id))
265}