flux_driver/collector/detached_specs.rs

1use std::collections::{HashMap, hash_map::Entry};
2
3use flux_common::dbg::{self, SpanTrace};
4use flux_syntax::surface::{self, DetachedItem, ExprPath, NodeId};
5use itertools::Itertools;
6use rustc_errors::ErrorGuaranteed;
7use rustc_hir::{
8    OwnerId,
9    def::{DefKind, Res},
10    def_id::LocalDefId,
11};
12use rustc_middle::ty::{AssocItem, AssocKind, Ty, TyCtxt};
13use rustc_span::{Symbol, def_id::DefId};
14
15use crate::collector::{FluxAttrs, SpecCollector, errors};
16type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
17
18#[derive(PartialEq, Eq, Debug, Hash, Clone, Copy)]
19enum LookupRes {
20    DefId(DefId),
21    Name(Symbol),
22}
23
24impl LookupRes {
25    fn from_name<T: std::fmt::Debug>(thing: &T) -> Self {
26        let str = format!("{thing:?}");
27        LookupRes::Name(Symbol::intern(&str))
28    }
29
30    fn new(ty: &Ty) -> Self {
31        match ty.kind() {
32            rustc_middle::ty::TyKind::Adt(adt_def, _) => LookupRes::DefId(adt_def.did()),
33            _ => Self::from_name(ty),
34        }
35    }
36}
37
38#[derive(PartialEq, Eq, Debug, Hash)]
39struct TraitImplKey {
40    trait_: LookupRes,
41    self_ty: LookupRes,
42}
43
44fn path_to_symbol(path: &surface::ExprPath) -> Symbol {
45    let path_string = format!(
46        "{}",
47        path.segments
48            .iter()
49            .format_with("::", |s, f| f(&s.ident.name))
50    );
51    Symbol::intern(&path_string)
52}
53
54fn item_def_kind(kind: &surface::DetachedItemKind) -> Vec<DefKind> {
55    match kind {
56        surface::DetachedItemKind::FnSig(_) => vec![DefKind::Fn],
57        surface::DetachedItemKind::Mod(_) => vec![DefKind::Mod],
58        surface::DetachedItemKind::Struct(_) => vec![DefKind::Struct],
59        surface::DetachedItemKind::Enum(_) => vec![DefKind::Enum],
60        surface::DetachedItemKind::InherentImpl(_) | surface::DetachedItemKind::TraitImpl(_) => {
61            vec![DefKind::Struct, DefKind::Enum]
62        }
63        surface::DetachedItemKind::Trait(_) => vec![DefKind::Trait],
64    }
65}
66
67#[derive(Debug)]
68struct ScopeResolver {
69    items: HashMap<(Symbol, DefKind), LookupRes>,
70}
71
72impl ScopeResolver {
73    fn new(tcx: TyCtxt, def_id: LocalDefId, impl_resolver: &TraitImplResolver) -> Self {
74        let mut items = HashMap::default();
75        for child in tcx.module_children_local(def_id) {
76            let ident = child.ident;
77            if let Res::Def(exp_kind, def_id) = child.res {
78                items.insert((ident.name, exp_kind), LookupRes::DefId(def_id));
79            }
80        }
81        for pty in rustc_hir::PrimTy::ALL {
82            let name = pty.name();
83            items.insert((name, DefKind::Struct), LookupRes::Name(name)); // HACK: use DefKind::Struct for primitive...
84        }
85        for trait_impl_key in impl_resolver.items.keys() {
86            if let LookupRes::DefId(trait_id) = trait_impl_key.trait_ {
87                let name = Symbol::intern(&tcx.def_path_str(trait_id));
88                items.insert((name, DefKind::Trait), trait_impl_key.trait_);
89            }
90        }
91        Self { items }
92    }
93
94    fn lookup(&self, path: &ExprPath, item_kind: &surface::DetachedItemKind) -> Option<LookupRes> {
95        let symbol = path_to_symbol(path);
96        for kind in item_def_kind(item_kind) {
97            let key = (symbol, kind);
98            if let Some(res) = self.items.get(&key) {
99                return Some(*res);
100            }
101        }
102        None
103    }
104}
105
106#[derive(Debug)]
107struct TraitImplResolver {
108    items: HashMap<TraitImplKey, LocalDefId>,
109}
110
111impl TraitImplResolver {
112    fn new(tcx: TyCtxt) -> Self {
113        let mut items = HashMap::default();
114        for (trait_id, impl_ids) in tcx.all_local_trait_impls(()) {
115            let trait_ = LookupRes::DefId(*trait_id);
116            for impl_id in impl_ids {
117                let poly_trait_ref = tcx.impl_trait_ref(*impl_id);
118                let self_ty = poly_trait_ref.instantiate_identity().self_ty();
119                let self_ty = LookupRes::new(&self_ty);
120                let key = TraitImplKey { trait_, self_ty };
121                items.insert(key, *impl_id);
122            }
123        }
124        Self { items }
125    }
126
127    fn resolve(&self, trait_: LookupRes, self_ty: LookupRes) -> Option<LocalDefId> {
128        let key = TraitImplKey { trait_, self_ty };
129        self.items.get(&key).copied()
130    }
131}
132
133pub(super) struct DetachedSpecsCollector<'a, 'sess, 'tcx> {
134    inner: &'a mut SpecCollector<'sess, 'tcx>,
135    id_resolver: HashMap<NodeId, LookupRes>,
136    impl_resolver: TraitImplResolver,
137}
138
139impl<'a, 'sess, 'tcx> DetachedSpecsCollector<'a, 'sess, 'tcx> {
140    pub(super) fn collect(
141        inner: &'a mut SpecCollector<'sess, 'tcx>,
142        attrs: &mut FluxAttrs,
143        module_id: LocalDefId,
144    ) -> Result {
145        if let Some(detached_specs) = attrs.detached_specs() {
146            let trait_impl_resolver = TraitImplResolver::new(inner.tcx);
147            let mut collector =
148                Self { inner, id_resolver: HashMap::default(), impl_resolver: trait_impl_resolver };
149            collector.run(detached_specs, module_id)?;
150        };
151        Ok(())
152    }
153
154    fn run(&mut self, detached_specs: surface::DetachedSpecs, def_id: LocalDefId) -> Result {
155        self.resolve(&detached_specs, def_id)?;
156        for item in detached_specs.items {
157            self.attach(item)?;
158        }
159        Ok(())
160    }
161
162    fn resolve_path_kind(
163        &mut self,
164        resolver: &ScopeResolver,
165        path: &ExprPath,
166        kind: &surface::DetachedItemKind,
167    ) -> Result {
168        let Some(res) = resolver.lookup(path, kind) else {
169            return Err(self
170                .inner
171                .errors
172                .emit(errors::UnresolvedSpecification::new(path, "name")));
173        };
174        self.id_resolver.insert(path.node_id, res);
175        Ok(())
176    }
177
178    fn resolve(&mut self, detached_specs: &surface::DetachedSpecs, def_id: LocalDefId) -> Result {
179        let resolver = ScopeResolver::new(self.inner.tcx, def_id, &self.impl_resolver);
180        for item in &detached_specs.items {
181            self.resolve_path_kind(&resolver, &item.path, &item.kind)?;
182            if let surface::DetachedItemKind::TraitImpl(trait_impl) = &item.kind {
183                let kind = surface::DetachedItemKind::Trait(surface::DetachedTrait::default());
184                self.resolve_path_kind(&resolver, &trait_impl.trait_, &kind)?;
185            }
186        }
187        Ok(())
188    }
189
190    #[allow(
191        clippy::disallowed_methods,
192        reason = "this is pre-extern specs so it's fine: https://flux-rs.zulipchat.com/#narrow/channel/486369-verify-std/topic/detached-specs/near/529548357"
193    )]
194    fn unwrap_def_id(&self, def_id: &DefId) -> Result<Option<LocalDefId>> {
195        Ok(def_id.as_local())
196    }
197
198    fn lookup(&mut self, item: &surface::DetachedItem) -> Result<LocalDefId> {
199        let path_def_id = self.id_resolver.get(&item.path.node_id);
200
201        if let surface::DetachedItemKind::TraitImpl(trait_impl) = &item.kind
202            && let Some(trait_) = self.id_resolver.get(&trait_impl.trait_.node_id)
203            && let Some(self_ty) = path_def_id
204            && let Some(impl_id) = self.impl_resolver.resolve(*trait_, *self_ty)
205        {
206            return Ok(impl_id);
207        }
208        if let Some(LookupRes::DefId(def_id)) = self.id_resolver.get(&item.path.node_id)
209            && let Some(local_def_id) = self.unwrap_def_id(def_id)?
210        {
211            return Ok(local_def_id);
212        }
213        Err(self
214            .inner
215            .errors
216            .emit(errors::UnresolvedSpecification::new(&item.path, "item")))
217    }
218
219    fn attach(&mut self, item: surface::DetachedItem) -> Result {
220        let def_id = self.lookup(&item)?;
221        let owner_id = self.inner.tcx.local_def_id_to_hir_id(def_id).owner;
222        let span = item.span();
223        let dst_span = self.inner.tcx.def_span(def_id);
224        dbg::hyperlink!(self.inner.tcx, span, dst_span);
225        match item.kind {
226            surface::DetachedItemKind::FnSig(fn_sig) => {
227                self.inner.insert_item(
228                    owner_id,
229                    surface::Item {
230                        attrs: item.attrs,
231                        kind: surface::ItemKind::Fn(Some(fn_sig)),
232                        node_id: item.node_id,
233                    },
234                )?;
235            }
236            surface::DetachedItemKind::Struct(struct_def) => {
237                self.inner.insert_item(
238                    owner_id,
239                    surface::Item {
240                        attrs: item.attrs,
241                        kind: surface::ItemKind::Struct(struct_def),
242                        node_id: item.node_id,
243                    },
244                )?;
245            }
246            surface::DetachedItemKind::Enum(enum_def) => {
247                self.inner.insert_item(
248                    owner_id,
249                    surface::Item {
250                        attrs: item.attrs,
251                        kind: surface::ItemKind::Enum(enum_def),
252                        node_id: item.node_id,
253                    },
254                )?;
255            }
256            surface::DetachedItemKind::Mod(detached_specs) => {
257                self.run(detached_specs, owner_id.def_id)?;
258            }
259            surface::DetachedItemKind::Trait(trait_def) => {
260                self.collect_trait(owner_id, item.node_id, item.attrs, trait_def)?;
261            }
262            surface::DetachedItemKind::InherentImpl(inherent_impl) => {
263                let tcx = self.inner.tcx;
264                let assoc_items = tcx
265                    .inherent_impls(def_id)
266                    .iter()
267                    .flat_map(|impl_id| tcx.associated_items(impl_id).in_definition_order());
268                self.collect_assoc_methods(
269                    inherent_impl.items,
270                    assoc_items,
271                    |this, owner_id, item| {
272                        this.inner.insert_impl_item(
273                            owner_id,
274                            surface::ImplItemFn {
275                                attrs: item.attrs,
276                                sig: Some(item.kind),
277                                node_id: item.node_id,
278                            },
279                        )
280                    },
281                )?;
282            }
283            surface::DetachedItemKind::TraitImpl(trait_impl) => {
284                self.collect_trait_impl(owner_id, item.node_id, item.attrs, trait_impl)?;
285            }
286        };
287        Ok(())
288    }
289
290    fn collect_trait(
291        &mut self,
292        owner_id: OwnerId,
293        node_id: NodeId,
294        attrs: Vec<surface::Attr>,
295        trait_def: surface::DetachedTrait,
296    ) -> Result {
297        // 1. Collect the associated-refinements
298        self.inner.insert_item(
299            owner_id,
300            surface::Item {
301                attrs,
302                kind: surface::ItemKind::Trait(surface::Trait {
303                    generics: None,
304                    assoc_refinements: trait_def.refts,
305                }),
306                node_id,
307            },
308        )?;
309
310        // 2. Collect the method specifications
311        let tcx = self.inner.tcx;
312        let assoc_items = tcx.associated_items(owner_id.def_id).in_definition_order();
313        self.collect_assoc_methods(trait_def.items, assoc_items, |this, owner_id, item| {
314            this.inner.insert_trait_item(
315                owner_id,
316                surface::TraitItemFn {
317                    attrs: item.attrs,
318                    sig: Some(item.kind),
319                    node_id: item.node_id,
320                },
321            )
322        })
323    }
324
325    fn collect_trait_impl(
326        &mut self,
327        owner_id: OwnerId,
328        node_id: NodeId,
329        attrs: Vec<surface::Attr>,
330        trait_impl: surface::DetachedTraitImpl,
331    ) -> Result {
332        // 1. Collect the associated-refinements
333        self.inner.insert_item(
334            owner_id,
335            surface::Item {
336                attrs,
337                kind: surface::ItemKind::Impl(surface::Impl {
338                    generics: None,
339                    assoc_refinements: trait_impl.refts,
340                }),
341                node_id,
342            },
343        )?;
344
345        // 2. Collect the method specifications
346        let tcx = self.inner.tcx;
347        let assoc_items = tcx.associated_items(owner_id.def_id).in_definition_order();
348        self.collect_assoc_methods(trait_impl.items, assoc_items, |this, owner_id, item| {
349            this.inner.insert_impl_item(
350                owner_id,
351                surface::ImplItemFn {
352                    attrs: item.attrs,
353                    sig: Some(item.kind),
354                    node_id: item.node_id,
355                },
356            )
357        })
358    }
359
360    fn collect_assoc_methods(
361        &mut self,
362        methods: Vec<DetachedItem<surface::FnSig>>,
363        assoc_items: impl Iterator<Item = &'tcx AssocItem>,
364        mut insert_item: impl FnMut(&mut Self, OwnerId, DetachedItem<surface::FnSig>) -> Result,
365    ) -> Result {
366        let mut table: HashMap<Symbol, DetachedItem<(surface::FnSig, Option<DefId>)>> =
367            HashMap::default();
368        // 1. make a table of the impl-items
369        for item in methods {
370            let name = path_to_symbol(&item.path);
371            let span = item.path.span;
372            if let Entry::Occupied(_) = table.entry(name) {
373                return Err(self
374                    .inner
375                    .errors
376                    .emit(errors::MultipleSpecifications { name, span }));
377            } else {
378                table.insert(name, item.map_kind(|spec| (spec, None)));
379            }
380        }
381        // 2. walk over all the assoc-items to resolve names
382        for item in assoc_items {
383            if let AssocKind::Fn { name, .. } = item.kind
384                && let Some(val) = table.get_mut(&name)
385                && val.kind.1.is_none()
386            {
387                val.kind.1 = Some(item.def_id);
388            }
389        }
390        // 3. Attach the `fn_sig` to the resolved `DefId`
391        for (_name, item) in table {
392            let Some(def_id) = item.kind.1 else {
393                return Err(self
394                    .inner
395                    .errors
396                    .emit(errors::UnresolvedSpecification::new(&item.path, "identifier")));
397            };
398            if let Some(def_id) = self.unwrap_def_id(&def_id)? {
399                dbg::hyperlink!(self.inner.tcx, item.path.span, self.inner.tcx.def_span(def_id));
400                let owner_id = self.inner.tcx.local_def_id_to_hir_id(def_id).owner;
401                insert_item(self, owner_id, item.map_kind(|k| k.0))?;
402            }
403        }
404        Ok(())
405    }
406}