1mod fold_unfold;
4mod points_to;
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_middle::{global_env::GlobalEnv, queries::QueryResult};
10use flux_rustc_bridge::{
11 lowering,
12 mir::{BasicBlock, Body, Place},
13};
14use rustc_data_structures::unord::UnordMap;
15use rustc_hash::{FxHashMap, FxHashSet};
16use rustc_hir::{def::DefKind, def_id::LocalDefId};
17use rustc_middle::{
18 mir::{Location, START_BLOCK},
19 ty::TyCtxt,
20};
21
/// Ghost statements to be placed before a given MIR [`Location`].
type LocationMap = FxHashMap<Location, Vec<GhostStatement>>;
/// Ghost statements attached to control-flow edges, keyed first by the source block and then by
/// the target block of the edge.
type EdgeMap = FxHashMap<BasicBlock, FxHashMap<BasicBlock, Vec<GhostStatement>>>;
24
25pub(crate) fn compute_ghost_statements(
26 genv: GlobalEnv,
27 def_id: LocalDefId,
28) -> QueryResult<UnordMap<LocalDefId, GhostStatements>> {
29 let mut data = UnordMap::default();
30 for def_id in all_nested_bodies(genv.tcx(), def_id) {
31 data.insert(def_id, GhostStatements::new(genv, def_id)?);
32 }
33 Ok(data)
34}
35
/// An extra operation on a [`Place`] that is stored alongside a MIR body (in
/// `GhostStatements`) rather than in the body itself, and is associated with a [`Point`].
pub(crate) enum GhostStatement {
    /// Fold `place`; rendered as `fold(place)` by the `Debug` impl.
    /// NOTE(review): presumably produced by the `fold_unfold` pass — confirm.
    Fold(Place),
    /// Unfold `place`; rendered as `unfold(place)` by the `Debug` impl.
    /// NOTE(review): presumably produced by the `fold_unfold` pass — confirm.
    Unfold(Place),
    /// Unblock `place`. `GhostStatements::add_unblocks` emits one of these for the borrowed place
    /// of each borrow that goes out of scope at a location.
    Unblock(Place),
    /// Rendered as `ptr_to_ref(place)` by the `Debug` impl.
    /// NOTE(review): presumably produced by the `points_to` pass — confirm.
    PtrToRef(Place),
}
42
43impl fmt::Debug for GhostStatement {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 GhostStatement::Fold(place) => write!(f, "fold({place:?})"),
47 GhostStatement::Unfold(place) => write!(f, "unfold({place:?})"),
48 GhostStatement::Unblock(place) => write!(f, "unblock({place:?})"),
49 GhostStatement::PtrToRef(place) => write!(f, "ptr_to_ref({place:?})"),
50 }
51 }
52}
53
/// The ghost statements computed for a single MIR body, grouped by where they are placed.
pub(crate) struct GhostStatements {
    /// Statements recorded at [`Point::FunEntry`].
    at_start: Vec<GhostStatement>,
    /// Statements recorded at [`Point::BeforeLocation`], keyed by location.
    at_location: LocationMap,
    /// Statements recorded at [`Point::Edge`], keyed by source block and then target block.
    at_edge: EdgeMap,
}
59
60impl GhostStatements {
61 fn new(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult<Self> {
62 let body = genv.mir(def_id)?;
63
64 bug::track_span(body.span(), || {
65 let mut stmts = Self {
66 at_start: Default::default(),
67 at_location: LocationMap::default(),
68 at_edge: EdgeMap::default(),
69 };
70
71 let fn_sig = if genv.def_kind(def_id) == DefKind::Closure {
73 None
74 } else {
75 Some(genv.fn_sig(def_id)?)
76 };
77
78 fold_unfold::add_ghost_statements(&mut stmts, genv, &body, fn_sig.as_ref())?;
79 points_to::add_ghost_statements(&mut stmts, genv, body.rustc_body(), fn_sig.as_ref())?;
80 stmts.add_unblocks(genv.tcx(), &body);
81
82 stmts.dump_ghost_mir(genv.tcx(), &body);
83
84 Ok(stmts)
85 })
86 }
87
88 fn add_unblocks<'tcx>(&mut self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
89 for (location, borrows) in body.calculate_borrows_out_of_scope_at_location() {
90 let stmts = borrows.into_iter().map(|bidx| {
91 let borrow = body.borrow_data(bidx);
92 let place = lowering::lower_place(tcx, &borrow.borrowed_place()).unwrap();
93 GhostStatement::Unblock(place)
94 });
95 self.at_location.entry(location).or_default().extend(stmts);
96 }
97 }
98
99 fn insert_at(&mut self, point: Point, stmt: GhostStatement) {
100 self.extend_at(point, [stmt]);
101 }
102
103 fn extend_at(&mut self, point: Point, stmts: impl IntoIterator<Item = GhostStatement>) {
104 match point {
105 Point::FunEntry => {
106 self.at_start.extend(stmts);
107 }
108 Point::BeforeLocation(location) => {
109 self.at_location.entry(location).or_default().extend(stmts);
110 }
111 Point::Edge(from, to) => {
112 self.at_edge
113 .entry(from)
114 .or_default()
115 .entry(to)
116 .or_default()
117 .extend(stmts);
118 }
119 }
120 }
121
122 fn at(&mut self, point: Point) -> StatementsAt<'_> {
123 StatementsAt { stmts: self, point }
124 }
125
126 pub(crate) fn statements_at(&self, point: Point) -> impl Iterator<Item = &GhostStatement> {
127 match point {
128 Point::FunEntry => Some(&self.at_start).into_iter().flatten(),
129 Point::BeforeLocation(location) => {
130 self.at_location.get(&location).into_iter().flatten()
131 }
132 Point::Edge(from, to) => {
133 self.at_edge
134 .get(&from)
135 .and_then(|m| m.get(&to))
136 .into_iter()
137 .flatten()
138 }
139 }
140 }
141
142 pub(crate) fn dump_ghost_mir<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
143 use rustc_middle::mir::{PassWhere, pretty::MirDumper};
144 if let Some(dumper) = MirDumper::new(tcx, "ghost", body.inner()) {
145 dumper
146 .set_extra_data(&|pass, w| {
147 match pass {
148 PassWhere::BeforeBlock(bb) if bb == START_BLOCK => {
149 for stmt in &self.at_start {
150 writeln!(w, " {stmt:?};")?;
151 }
152 }
153 PassWhere::BeforeLocation(location) => {
154 for stmt in self.statements_at(Point::BeforeLocation(location)) {
155 writeln!(w, " {stmt:?};")?;
156 }
157 }
158 PassWhere::AfterTerminator(bb) => {
159 if let Some(map) = self.at_edge.get(&bb) {
160 writeln!(w)?;
161 for (target, stmts) in map {
162 write!(w, " -> {target:?} {{")?;
163 for stmt in stmts {
164 write!(w, "\n {stmt:?};")?;
165 }
166 write!(w, "\n }}")?;
167 }
168 writeln!(w)?;
169 }
170 }
171 _ => {}
172 }
173 Ok(())
174 })
175 .dump_mir(body.inner());
176 }
177 }
178}
179
/// A position in a MIR body at which ghost statements can be recorded.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub(crate) enum Point {
    /// The entry of the function; these statements are dumped before `START_BLOCK`.
    FunEntry,
    /// Immediately before the given MIR location.
    BeforeLocation(Location),
    /// On the control-flow edge `from -> to`, given as `Edge(from, to)`; these statements are
    /// dumped after the terminator of `from`.
    Edge(BasicBlock, BasicBlock),
}
192
/// A cursor that records ghost statements at a fixed [`Point`]; created by
/// `GhostStatements::at`.
struct StatementsAt<'a> {
    /// The statement collection being extended.
    stmts: &'a mut GhostStatements,
    /// The point every inserted statement is recorded at.
    point: Point,
}
197
198impl StatementsAt<'_> {
199 fn insert(&mut self, stmt: GhostStatement) {
200 self.stmts.insert_at(self.point, stmt);
201 }
202}
203
204fn all_nested_bodies(tcx: TyCtxt, def_id: LocalDefId) -> impl Iterator<Item = LocalDefId> {
205 use rustc_hir as hir;
206 struct ClosureFinder<'tcx> {
207 tcx: TyCtxt<'tcx>,
208 closures: FxHashSet<LocalDefId>,
209 }
210
211 impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ClosureFinder<'tcx> {
212 type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
213
214 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
215 self.tcx
216 }
217
218 fn visit_expr(&mut self, ex: &'tcx hir::Expr<'tcx>) {
219 if let hir::ExprKind::Closure(closure) = ex.kind {
220 self.closures.insert(closure.def_id);
221 }
222
223 hir::intravisit::walk_expr(self, ex);
224 }
225 }
226 let body = tcx.hir_body_owned_by(def_id).value;
227 let mut finder = ClosureFinder { tcx, closures: FxHashSet::default() };
228 hir::intravisit::Visitor::visit_expr(&mut finder, body);
229 finder.closures.into_iter().chain(iter::once(def_id))
230}