flux_driver/collector/detached_specs.rs

1use std::collections::{HashMap, hash_map::Entry};
2
3use flux_common::dbg::{self, SpanTrace};
4use flux_syntax::surface::{self, DetachedItem, ExprPath, NodeId};
5use itertools::Itertools;
6use rustc_errors::ErrorGuaranteed;
7use rustc_hir::{
8    OwnerId,
9    def::{DefKind, Res},
10    def_id::LocalDefId,
11};
12use rustc_middle::ty::{AssocItem, AssocKind, Ty, TyCtxt};
13use rustc_span::{Symbol, def_id::DefId};
14
15use crate::collector::{FluxAttrs, SpecCollector, errors};
16type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
17
18#[derive(PartialEq, Eq, Debug, Hash, Clone, Copy)]
19enum LookupRes {
20    DefId(DefId),
21    Name(Symbol),
22}
23
24impl LookupRes {
25    fn from_name<T: std::fmt::Debug>(thing: &T) -> Self {
26        let str = format!("{thing:?}");
27        LookupRes::Name(Symbol::intern(&str))
28    }
29
30    fn new(ty: &Ty) -> Self {
31        match ty.kind() {
32            rustc_middle::ty::TyKind::Adt(adt_def, _) => LookupRes::DefId(adt_def.did()),
33            _ => Self::from_name(ty),
34        }
35    }
36}
37
/// Lookup key for a local trait impl: the implemented trait paired with the impl's self type.
///
/// NOTE(review): the self type is built with `LookupRes::new`, which drops generic arguments.
/// Two impls of the same trait for different instantiations of one ADT therefore share a key.
#[derive(PartialEq, Eq, Debug, Hash)]
struct TraitImplKey {
    /// The trait being implemented.
    trait_: LookupRes,
    /// The type the trait is implemented for.
    self_ty: LookupRes,
}
43
44fn path_to_symbol(path: &surface::ExprPath) -> Symbol {
45    let path_string = format!(
46        "{}",
47        path.segments
48            .iter()
49            .format_with("::", |s, f| f(&s.ident.name))
50    );
51    Symbol::intern(&path_string)
52}
53
54fn item_def_kind(kind: &surface::DetachedItemKind) -> Vec<DefKind> {
55    match kind {
56        surface::DetachedItemKind::FnSig(_) => vec![DefKind::Fn],
57        surface::DetachedItemKind::Mod(_) => vec![DefKind::Mod],
58        surface::DetachedItemKind::Struct(_) => vec![DefKind::Struct],
59        surface::DetachedItemKind::Enum(_) => vec![DefKind::Enum],
60        surface::DetachedItemKind::InherentImpl(_) | surface::DetachedItemKind::TraitImpl(_) => {
61            vec![DefKind::Struct, DefKind::Enum]
62        }
63        surface::DetachedItemKind::Trait(_) => vec![DefKind::Trait],
64        surface::DetachedItemKind::Static(_) => {
65            vec![DefKind::Static {
66                mutability: rustc_ast::Mutability::Not,
67                nested: false,
68                safety: rustc_hir::Safety::Safe,
69            }]
70        }
71    }
72}
73
74#[derive(Debug)]
75struct ScopeResolver {
76    items: HashMap<(Symbol, DefKind), LookupRes>,
77}
78
79impl ScopeResolver {
80    fn new(tcx: TyCtxt, def_id: LocalDefId, impl_resolver: &TraitImplResolver) -> Self {
81        let mut items = HashMap::default();
82        for child in tcx.module_children_local(def_id) {
83            let ident = child.ident;
84            if let Res::Def(exp_kind, def_id) = child.res {
85                items.insert((ident.name, exp_kind), LookupRes::DefId(def_id));
86            }
87        }
88        for pty in rustc_hir::PrimTy::ALL {
89            let name = pty.name();
90            items.insert((name, DefKind::Struct), LookupRes::Name(name)); // HACK: use DefKind::Struct for primitive...
91        }
92        for trait_impl_key in impl_resolver.items.keys() {
93            if let LookupRes::DefId(trait_id) = trait_impl_key.trait_ {
94                let name = Symbol::intern(&tcx.def_path_str(trait_id));
95                items.insert((name, DefKind::Trait), trait_impl_key.trait_);
96            }
97        }
98        Self { items }
99    }
100
101    fn lookup(&self, path: &ExprPath, item_kind: &surface::DetachedItemKind) -> Option<LookupRes> {
102        let symbol = path_to_symbol(path);
103        for kind in item_def_kind(item_kind) {
104            let key = (symbol, kind);
105            if let Some(res) = self.items.get(&key) {
106                return Some(*res);
107            }
108        }
109        None
110    }
111}
112
113#[derive(Debug)]
114struct TraitImplResolver {
115    items: HashMap<TraitImplKey, LocalDefId>,
116}
117
118impl TraitImplResolver {
119    fn new(tcx: TyCtxt) -> Self {
120        let mut items = HashMap::default();
121        for (trait_id, impl_ids) in tcx.all_local_trait_impls(()) {
122            let trait_ = LookupRes::DefId(*trait_id);
123            for impl_id in impl_ids {
124                let poly_trait_ref = tcx.impl_trait_ref(*impl_id);
125                let self_ty = poly_trait_ref.instantiate_identity().self_ty();
126                let self_ty = LookupRes::new(&self_ty);
127                let key = TraitImplKey { trait_, self_ty };
128                items.insert(key, *impl_id);
129            }
130        }
131        Self { items }
132    }
133
134    fn resolve(&self, trait_: LookupRes, self_ty: LookupRes) -> Option<LocalDefId> {
135        let key = TraitImplKey { trait_, self_ty };
136        self.items.get(&key).copied()
137    }
138}
139
140pub(super) struct DetachedSpecsCollector<'a, 'sess, 'tcx> {
141    inner: &'a mut SpecCollector<'sess, 'tcx>,
142    id_resolver: HashMap<NodeId, LookupRes>,
143    impl_resolver: TraitImplResolver,
144}
145
146impl<'a, 'sess, 'tcx> DetachedSpecsCollector<'a, 'sess, 'tcx> {
147    pub(super) fn collect(
148        inner: &'a mut SpecCollector<'sess, 'tcx>,
149        attrs: &mut FluxAttrs,
150        module_id: LocalDefId,
151    ) -> Result {
152        if let Some(detached_specs) = attrs.detached_specs() {
153            let trait_impl_resolver = TraitImplResolver::new(inner.tcx);
154            let mut collector =
155                Self { inner, id_resolver: HashMap::default(), impl_resolver: trait_impl_resolver };
156            collector.run(detached_specs, module_id)?;
157        };
158        Ok(())
159    }
160
161    fn run(&mut self, detached_specs: surface::DetachedSpecs, def_id: LocalDefId) -> Result {
162        self.resolve(&detached_specs, def_id)?;
163        for item in detached_specs.items {
164            self.attach(item)?;
165        }
166        Ok(())
167    }
168
169    fn resolve_path_kind(
170        &mut self,
171        resolver: &ScopeResolver,
172        path: &ExprPath,
173        kind: &surface::DetachedItemKind,
174    ) -> Result {
175        let Some(res) = resolver.lookup(path, kind) else {
176            return Err(self
177                .inner
178                .errors
179                .emit(errors::UnresolvedSpecification::new(path, "name")));
180        };
181        self.id_resolver.insert(path.node_id, res);
182        Ok(())
183    }
184
185    fn resolve(&mut self, detached_specs: &surface::DetachedSpecs, def_id: LocalDefId) -> Result {
186        let resolver = ScopeResolver::new(self.inner.tcx, def_id, &self.impl_resolver);
187        for item in &detached_specs.items {
188            self.resolve_path_kind(&resolver, &item.path, &item.kind)?;
189            if let surface::DetachedItemKind::TraitImpl(trait_impl) = &item.kind {
190                let kind = surface::DetachedItemKind::Trait(surface::DetachedTrait::default());
191                self.resolve_path_kind(&resolver, &trait_impl.trait_, &kind)?;
192            }
193        }
194        Ok(())
195    }
196
197    #[allow(
198        clippy::disallowed_methods,
199        reason = "this is pre-extern specs so it's fine: https://flux-rs.zulipchat.com/#narrow/channel/486369-verify-std/topic/detached-specs/near/529548357"
200    )]
201    fn unwrap_def_id(&self, def_id: &DefId) -> Result<Option<LocalDefId>> {
202        Ok(def_id.as_local())
203    }
204
205    fn lookup(&mut self, item: &surface::DetachedItem) -> Result<LocalDefId> {
206        let path_def_id = self.id_resolver.get(&item.path.node_id);
207
208        if let surface::DetachedItemKind::TraitImpl(trait_impl) = &item.kind
209            && let Some(trait_) = self.id_resolver.get(&trait_impl.trait_.node_id)
210            && let Some(self_ty) = path_def_id
211            && let Some(impl_id) = self.impl_resolver.resolve(*trait_, *self_ty)
212        {
213            return Ok(impl_id);
214        }
215        if let Some(LookupRes::DefId(def_id)) = self.id_resolver.get(&item.path.node_id)
216            && let Some(local_def_id) = self.unwrap_def_id(def_id)?
217        {
218            return Ok(local_def_id);
219        }
220        Err(self
221            .inner
222            .errors
223            .emit(errors::UnresolvedSpecification::new(&item.path, "item")))
224    }
225
226    fn attach(&mut self, item: surface::DetachedItem) -> Result {
227        let def_id = self.lookup(&item)?;
228        let owner_id = self.inner.tcx.local_def_id_to_hir_id(def_id).owner;
229        let span = item.span();
230        let dst_span = self.inner.tcx.def_span(def_id);
231        dbg::hyperlink!(self.inner.tcx, span, dst_span);
232        match item.kind {
233            surface::DetachedItemKind::FnSig(fn_sig) => {
234                self.inner.insert_item(
235                    owner_id,
236                    surface::Item {
237                        attrs: item.attrs,
238                        kind: surface::ItemKind::Fn(Some(fn_sig)),
239                        node_id: item.node_id,
240                    },
241                )?;
242            }
243            surface::DetachedItemKind::Struct(struct_def) => {
244                self.inner.insert_item(
245                    owner_id,
246                    surface::Item {
247                        attrs: item.attrs,
248                        kind: surface::ItemKind::Struct(struct_def),
249                        node_id: item.node_id,
250                    },
251                )?;
252            }
253            surface::DetachedItemKind::Enum(enum_def) => {
254                self.inner.insert_item(
255                    owner_id,
256                    surface::Item {
257                        attrs: item.attrs,
258                        kind: surface::ItemKind::Enum(enum_def),
259                        node_id: item.node_id,
260                    },
261                )?;
262            }
263            surface::DetachedItemKind::Mod(detached_specs) => {
264                self.run(detached_specs, owner_id.def_id)?;
265            }
266            surface::DetachedItemKind::Trait(trait_def) => {
267                self.collect_trait(owner_id, item.node_id, item.attrs, trait_def)?;
268            }
269            surface::DetachedItemKind::InherentImpl(inherent_impl) => {
270                let tcx = self.inner.tcx;
271                let assoc_items = tcx
272                    .inherent_impls(def_id)
273                    .iter()
274                    .flat_map(|impl_id| tcx.associated_items(impl_id).in_definition_order());
275                self.collect_assoc_methods(
276                    inherent_impl.items,
277                    assoc_items,
278                    |this, owner_id, item| {
279                        this.inner.insert_impl_item(
280                            owner_id,
281                            surface::ImplItemFn {
282                                attrs: item.attrs,
283                                sig: Some(item.kind),
284                                node_id: item.node_id,
285                            },
286                        )
287                    },
288                )?;
289            }
290            surface::DetachedItemKind::TraitImpl(trait_impl) => {
291                self.collect_trait_impl(owner_id, item.node_id, item.attrs, trait_impl)?;
292            }
293            surface::DetachedItemKind::Static(static_info) => {
294                self.inner.insert_item(
295                    owner_id,
296                    surface::Item {
297                        attrs: item.attrs,
298                        kind: surface::ItemKind::Static(static_info),
299                        node_id: item.node_id,
300                    },
301                )?;
302            }
303        };
304        Ok(())
305    }
306
307    fn collect_trait(
308        &mut self,
309        owner_id: OwnerId,
310        node_id: NodeId,
311        attrs: Vec<surface::Attr>,
312        trait_def: surface::DetachedTrait,
313    ) -> Result {
314        // 1. Collect the associated-refinements
315        self.inner.insert_item(
316            owner_id,
317            surface::Item {
318                attrs,
319                kind: surface::ItemKind::Trait(surface::Trait {
320                    generics: None,
321                    assoc_refinements: trait_def.refts,
322                }),
323                node_id,
324            },
325        )?;
326
327        // 2. Collect the method specifications
328        let tcx = self.inner.tcx;
329        let assoc_items = tcx.associated_items(owner_id.def_id).in_definition_order();
330        self.collect_assoc_methods(trait_def.items, assoc_items, |this, owner_id, item| {
331            this.inner.insert_trait_item(
332                owner_id,
333                surface::TraitItemFn {
334                    attrs: item.attrs,
335                    sig: Some(item.kind),
336                    node_id: item.node_id,
337                },
338            )
339        })
340    }
341
342    fn collect_trait_impl(
343        &mut self,
344        owner_id: OwnerId,
345        node_id: NodeId,
346        attrs: Vec<surface::Attr>,
347        trait_impl: surface::DetachedTraitImpl,
348    ) -> Result {
349        // 1. Collect the associated-refinements
350        self.inner.insert_item(
351            owner_id,
352            surface::Item {
353                attrs,
354                kind: surface::ItemKind::Impl(surface::Impl {
355                    generics: None,
356                    assoc_refinements: trait_impl.refts,
357                }),
358                node_id,
359            },
360        )?;
361
362        // 2. Collect the method specifications
363        let tcx = self.inner.tcx;
364        let assoc_items = tcx.associated_items(owner_id.def_id).in_definition_order();
365        self.collect_assoc_methods(trait_impl.items, assoc_items, |this, owner_id, item| {
366            this.inner.insert_impl_item(
367                owner_id,
368                surface::ImplItemFn {
369                    attrs: item.attrs,
370                    sig: Some(item.kind),
371                    node_id: item.node_id,
372                },
373            )
374        })
375    }
376
377    fn collect_assoc_methods(
378        &mut self,
379        methods: Vec<DetachedItem<surface::FnSig>>,
380        assoc_items: impl Iterator<Item = &'tcx AssocItem>,
381        mut insert_item: impl FnMut(&mut Self, OwnerId, DetachedItem<surface::FnSig>) -> Result,
382    ) -> Result {
383        let mut table: HashMap<Symbol, DetachedItem<(surface::FnSig, Option<DefId>)>> =
384            HashMap::default();
385        // 1. make a table of the impl-items
386        for item in methods {
387            let name = path_to_symbol(&item.path);
388            let span = item.path.span;
389            if let Entry::Occupied(_) = table.entry(name) {
390                return Err(self
391                    .inner
392                    .errors
393                    .emit(errors::MultipleSpecifications { name, span }));
394            } else {
395                table.insert(name, item.map_kind(|spec| (spec, None)));
396            }
397        }
398        // 2. walk over all the assoc-items to resolve names
399        for item in assoc_items {
400            if let AssocKind::Fn { name, .. } = item.kind
401                && let Some(val) = table.get_mut(&name)
402                && val.kind.1.is_none()
403            {
404                val.kind.1 = Some(item.def_id);
405            }
406        }
407        // 3. Attach the `fn_sig` to the resolved `DefId`
408        for (_name, item) in table {
409            let Some(def_id) = item.kind.1 else {
410                return Err(self
411                    .inner
412                    .errors
413                    .emit(errors::UnresolvedSpecification::new(&item.path, "identifier")));
414            };
415            if let Some(def_id) = self.unwrap_def_id(&def_id)? {
416                dbg::hyperlink!(self.inner.tcx, item.path.span, self.inner.tcx.def_span(def_id));
417                let owner_id = self.inner.tcx.local_def_id_to_hir_id(def_id).owner;
418                insert_item(self, owner_id, item.map_kind(|k| k.0))?;
419            }
420        }
421        Ok(())
422    }
423}