flux_driver/collector/extern_specs.rs

1use std::iter;
2
3use flux_middle::ExternSpecMappingErr;
4use flux_rustc_bridge::lowering;
5use rustc_errors::Diagnostic;
6use rustc_hir as hir;
7use rustc_hir::{
8    BodyId, OwnerId,
9    def_id::{DefId, LocalDefId},
10};
11use rustc_middle::ty::{self, TyCtxt};
12use rustc_span::{ErrorGuaranteed, Span, symbol::kw};
13
14use super::{FluxAttrs, SpecCollector};
15
16type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
17
18pub(super) struct ExternSpecCollector<'a, 'sess, 'tcx> {
19    inner: &'a mut SpecCollector<'sess, 'tcx>,
20    /// The block corresponding to the `const _: () = { ... }` annotated with `flux::extern_spec`
21    block: &'tcx hir::Block<'tcx>,
22}
23
24struct ExternImplItem {
25    impl_id: DefId,
26    item_id: DefId,
27}
28
29impl<'a, 'sess, 'tcx> ExternSpecCollector<'a, 'sess, 'tcx> {
30    pub(super) fn collect(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result {
31        Self::new(inner, body_id)?.run()
32    }
33
34    fn new(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result<Self> {
35        let body = inner.tcx.hir_body(body_id);
36        if let hir::ExprKind::Block(block, _) = body.value.kind {
37            Ok(Self { inner, block })
38        } else {
39            Err(inner
40                .errors
41                .emit(errors::MalformedExternSpec::new(body.value.span)))
42        }
43    }
44
45    fn run(mut self) -> Result {
46        let item = self.item_at(0)?;
47
48        let attrs = self
49            .inner
50            .parse_attrs_and_report_dups(item.owner_id.def_id)?;
51
52        match &item.kind {
53            hir::ItemKind::Fn { .. } => self.collect_extern_fn(item, attrs),
54            hir::ItemKind::Enum(enum_def, _) => {
55                self.collect_extern_enum(item.owner_id, enum_def, attrs)
56            }
57            hir::ItemKind::Struct(variant, _) => {
58                self.collect_extern_struct(item.owner_id, variant, attrs)
59            }
60            hir::ItemKind::Trait(_, _, _, bounds, items) => {
61                self.collect_extern_trait(item.owner_id, bounds, items, attrs)
62            }
63            hir::ItemKind::Impl(impl_) => self.collect_extern_impl(item.owner_id, impl_, attrs),
64            _ => Err(self.malformed()),
65        }
66    }
67
68    fn collect_extern_fn(&mut self, item: &hir::Item, attrs: FluxAttrs) -> Result {
69        self.inner.collect_fn_spec(item.owner_id, attrs)?;
70
71        let extern_id = self.extract_extern_id_from_fn(item)?;
72        self.insert_extern_id(item.owner_id.def_id, extern_id)?;
73        self.check_generics(item.owner_id, extern_id)?;
74
75        Ok(())
76    }
77
78    fn collect_extern_struct(
79        &mut self,
80        struct_id: OwnerId,
81        variant: &hir::VariantData,
82        attrs: FluxAttrs,
83    ) -> Result {
84        let dummy_struct = self.item_at(1)?;
85        self.inner.specs.insert_dummy(dummy_struct.owner_id);
86
87        let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
88        self.insert_extern_id(struct_id.def_id, extern_id)?;
89        self.check_generics(struct_id, extern_id)?;
90
91        self.inner.collect_struct_def(struct_id, attrs, variant)?;
92
93        Ok(())
94    }
95
96    fn collect_extern_enum(
97        &mut self,
98        enum_id: OwnerId,
99        enum_def: &hir::EnumDef,
100        attrs: FluxAttrs,
101    ) -> Result {
102        let dummy_struct = self.item_at(1)?;
103        self.inner.specs.insert_dummy(dummy_struct.owner_id);
104
105        let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
106        self.insert_extern_id(enum_id.def_id, extern_id)?;
107        self.check_generics(enum_id, extern_id)?;
108
109        self.inner.collect_enum_def(enum_id, attrs, enum_def)?;
110
111        // Add stuff about Ctor
112        // Get the AdtDef for the enum
113        let extern_enum_def = self.tcx().adt_def(extern_id);
114
115        // Collect all constructor DefIds from variants
116        let extern_variants = extern_enum_def.variants();
117        let enum_variants = enum_def.variants;
118        let extern_len = extern_variants.len();
119        let enum_len = enum_variants.len();
120        if extern_len != enum_len {
121            let reason = format!("expected {extern_len:?} variants but only have {enum_len:?}");
122            return Err(self.invalid_enum_extern_spec(reason));
123        }
124        for (extern_variant, variant) in extern_enum_def.variants().iter().zip(enum_def.variants) {
125            if let Some(extern_ctor) = extern_variant.ctor_def_id()
126                && let Some(ctor) = variant.data.ctor_def_id()
127                && self.tcx().def_kind(extern_ctor) == self.tcx().def_kind(ctor)
128            {
129                self.insert_extern_id(ctor, extern_ctor)?;
130            } else {
131                let reason = format!(
132                    "extern variant {extern_variant:?} incompatible with specified {variant:?}"
133                );
134                return Err(self.invalid_enum_extern_spec(reason));
135            }
136        }
137        Ok(())
138    }
139
140    fn collect_extern_impl(
141        &mut self,
142        impl_id: OwnerId,
143        impl_: &hir::Impl,
144        attrs: FluxAttrs,
145    ) -> Result {
146        self.inner.collect_impl(impl_id, attrs)?;
147
148        let dummy_item = self.item_at(1)?;
149        self.inner.specs.insert_dummy(dummy_item.owner_id);
150
151        // If this is a trait impl compute the impl_id from the trait_ref
152        let mut impl_of_trait = None;
153        if let hir::ItemKind::Impl(dummy_impl) = dummy_item.kind {
154            impl_of_trait =
155                Some(self.extract_extern_id_from_impl(dummy_item.owner_id, dummy_impl)?);
156
157            self.inner.specs.insert_dummy(self.item_at(2)?.owner_id);
158        }
159
160        let mut extern_impl_id = impl_of_trait;
161        for item in impl_.items {
162            let item_id = item.id.owner_id.def_id;
163            let extern_item = if let hir::AssocItemKind::Fn { .. } = item.kind {
164                let attrs = self.inner.parse_attrs_and_report_dups(item_id)?;
165                self.collect_extern_impl_fn(impl_of_trait, item, attrs)?
166            } else {
167                continue;
168            };
169
170            if *extern_impl_id.get_or_insert(extern_item.impl_id) != extern_item.impl_id {
171                return Err(self.invalid_impl_block());
172            }
173        }
174
175        if let Some(extern_impl_id) = extern_impl_id {
176            self.check_generics(impl_id, extern_impl_id)?;
177            self.insert_extern_id(impl_id.def_id, extern_impl_id)?;
178        }
179
180        Ok(())
181    }
182
183    fn collect_extern_impl_fn(
184        &mut self,
185        impl_of_trait: Option<DefId>,
186        item: &hir::ImplItemRef,
187        attrs: FluxAttrs,
188    ) -> Result<ExternImplItem> {
189        let item_id = item.id.owner_id;
190        self.inner.collect_fn_spec(item_id, attrs)?;
191
192        let extern_impl_item = self.extract_extern_id_from_impl_fn(impl_of_trait, item)?;
193        self.insert_extern_id(item_id.def_id, extern_impl_item.item_id)?;
194        self.check_generics(item_id, extern_impl_item.item_id)?;
195
196        Ok(extern_impl_item)
197    }
198
199    fn collect_extern_trait(
200        &mut self,
201        trait_id: OwnerId,
202        bounds: &hir::GenericBounds,
203        items: &[hir::TraitItemRef],
204        attrs: FluxAttrs,
205    ) -> Result {
206        self.inner.collect_trait(trait_id, attrs)?;
207
208        let extern_trait_id = self.extract_extern_id_from_trait(bounds)?;
209        self.insert_extern_id(trait_id.def_id, extern_trait_id)?;
210        self.check_generics(trait_id, extern_trait_id)?;
211
212        for item in items {
213            let item_id = item.id.owner_id.def_id;
214            if let hir::AssocItemKind::Fn { .. } = item.kind {
215                let attrs = self.inner.parse_attrs_and_report_dups(item_id)?;
216                self.collect_extern_trait_fn(extern_trait_id, item, attrs)?;
217            } else {
218                continue;
219            }
220        }
221
222        Ok(())
223    }
224
225    fn collect_extern_trait_fn(
226        &mut self,
227        extern_trait_id: DefId,
228        item: &hir::TraitItemRef,
229        attrs: FluxAttrs,
230    ) -> Result {
231        let item_id = item.id.owner_id;
232        self.inner.collect_fn_spec(item_id, attrs)?;
233
234        let extern_fn_id = self.extract_extern_id_from_trait_fn(extern_trait_id, item)?;
235        self.insert_extern_id(item.id.owner_id.def_id, extern_fn_id)?;
236        self.check_generics(item_id, extern_fn_id)?;
237
238        Ok(())
239    }
240
241    fn extract_extern_id_from_struct(&self, item: &hir::Item) -> Result<DefId> {
242        if let hir::ItemKind::Struct(data, ..) = item.kind
243            && let Some(extern_field) = data.fields().last()
244            && let ty = self.tcx().type_of(extern_field.def_id)
245            && let Some(adt_def) = ty.skip_binder().ty_adt_def()
246        {
247            Ok(adt_def.did())
248        } else {
249            Err(self.malformed())
250        }
251    }
252
253    fn extract_extern_id_from_fn(&self, item: &hir::Item) -> Result<DefId> {
254        if let hir::ItemKind::Fn { body, .. } = item.kind {
255            self.extract_callee_from_body(body)
256        } else {
257            Err(self.malformed())
258        }
259    }
260
261    fn extract_extern_id_from_impl_fn(
262        &self,
263        impl_of_trait: Option<DefId>,
264        item: &hir::ImplItemRef,
265    ) -> Result<ExternImplItem> {
266        if let hir::ImplItemKind::Fn(_, body_id) = self.tcx().hir_impl_item(item.id).kind {
267            let callee_id = self.extract_callee_from_body(body_id)?;
268            if let Some(extern_impl_id) = impl_of_trait {
269                let map = self.tcx().impl_item_implementor_ids(extern_impl_id);
270                if let Some(extern_item_id) = map.get(&callee_id) {
271                    Ok(ExternImplItem { impl_id: extern_impl_id, item_id: *extern_item_id })
272                } else {
273                    Err(self.item_not_in_trait_impl(item.id.owner_id, callee_id, extern_impl_id))
274                }
275            } else {
276                let opt_extern_impl_id = self.tcx().impl_of_method(callee_id);
277                if let Some(extern_impl_id) = opt_extern_impl_id {
278                    debug_assert!(self.tcx().trait_id_of_impl(extern_impl_id).is_none());
279                    Ok(ExternImplItem { impl_id: extern_impl_id, item_id: callee_id })
280                } else {
281                    Err(self.invalid_item_in_inherent_impl(item.id.owner_id, callee_id))
282                }
283            }
284        } else {
285            Err(self.malformed())
286        }
287    }
288
289    fn extract_extern_id_from_trait(&self, bounds: &hir::GenericBounds) -> Result<DefId> {
290        if let Some(bound) = bounds.first()
291            && let Some(trait_ref) = bound.trait_ref()
292            && let Some(trait_id) = trait_ref.trait_def_id()
293        {
294            Ok(trait_id)
295        } else {
296            Err(self.malformed())
297        }
298    }
299
300    fn extract_extern_id_from_trait_fn(
301        &self,
302        trait_id: DefId,
303        item: &hir::TraitItemRef,
304    ) -> Result<DefId> {
305        if let hir::TraitItemKind::Fn(_, trait_fn) = self.tcx().hir_trait_item(item.id).kind
306            && let hir::TraitFn::Provided(body_id) = trait_fn
307        {
308            let callee_id = self.extract_callee_from_body(body_id)?;
309            if let Some(callee_trait_id) = self.tcx().trait_of_item(callee_id)
310                && trait_id == callee_trait_id
311            {
312                Ok(callee_id)
313            } else {
314                // I can't figure out how to trigger this via code generated with the extern spec
315                // macro that also type checks but leaving it here as a precaution.
316                Err(self.item_not_in_trait(item.id.owner_id, callee_id, trait_id))
317            }
318        } else {
319            Err(self.malformed())
320        }
321    }
322
323    fn extract_extern_id_from_impl(&self, impl_id: OwnerId, impl_: &hir::Impl) -> Result<DefId> {
324        if let Some(item) = impl_.items.first()
325            && let hir::AssocItemKind::Fn { .. } = item.kind
326            && let Some((clause, _)) = self
327                .tcx()
328                .predicates_of(item.id.owner_id.def_id)
329                .predicates
330                .first()
331            && let Some(poly_trait_pred) = clause.as_trait_clause()
332            && let Some(trait_pred) = poly_trait_pred.no_bound_vars()
333        {
334            let trait_ref = trait_pred.trait_ref;
335            lowering::resolve_trait_ref_impl_id(self.tcx(), impl_id.to_def_id(), trait_ref)
336                .map(|(impl_id, _)| impl_id)
337                .ok_or_else(|| self.cannot_resolve_trait_impl())
338        } else {
339            Err(self.malformed())
340        }
341    }
342
343    fn extract_callee_from_body(&self, body_id: hir::BodyId) -> Result<DefId> {
344        let owner = self.tcx().hir_body_owner_def_id(body_id);
345        let typeck = self.tcx().typeck(owner);
346        if let hir::ExprKind::Block(b, _) = self.tcx().hir_body(body_id).value.kind
347            && let Some(e) = b.expr
348            && let hir::ExprKind::Call(callee, _) = e.kind
349            && let rustc_middle::ty::FnDef(callee_id, _) = typeck.node_type(callee.hir_id).kind()
350        {
351            Ok(*callee_id)
352        } else {
353            Err(self.malformed())
354        }
355    }
356
357    /// Returns the item inside the const block at position `i` starting from the end.
358    #[track_caller]
359    fn item_at(&self, i: usize) -> Result<&'tcx hir::Item<'tcx>> {
360        let stmts = self.block.stmts;
361        let index = stmts
362            .len()
363            .checked_sub(i + 1)
364            .ok_or_else(|| self.malformed())?;
365        let st = stmts.get(index).ok_or_else(|| self.malformed())?;
366        if let hir::StmtKind::Item(item_id) = st.kind {
367            Ok(self.tcx().hir_item(item_id))
368        } else {
369            Err(self.malformed())
370        }
371    }
372
373    fn insert_extern_id(&mut self, local_id: LocalDefId, extern_id: DefId) -> Result {
374        self.inner
375            .specs
376            .insert_extern_spec_id_mapping(local_id, extern_id)
377            .map_err(|err| {
378                match err {
379                    ExternSpecMappingErr::IsLocal(extern_id_local) => {
380                        self.emit(errors::ExternSpecForLocalDef {
381                            span: ident_or_def_span(self.tcx(), local_id),
382                            local_def_span: ident_or_def_span(self.tcx(), extern_id_local),
383                            name: self.tcx().def_path_str(extern_id),
384                        })
385                    }
386                    ExternSpecMappingErr::Dup(previous_extern_spec) => {
387                        self.emit(errors::DupExternSpec {
388                            span: ident_or_def_span(self.tcx(), local_id),
389                            previous_span: ident_or_def_span(self.tcx(), previous_extern_spec),
390                            name: self.tcx().def_path_str(extern_id),
391                        })
392                    }
393                }
394            })
395    }
396
397    fn check_generics(&mut self, local_id: OwnerId, extern_id: DefId) -> Result {
398        let tcx = self.tcx();
399        let local_params = &tcx.generics_of(local_id).own_params;
400        let extern_params = &tcx.generics_of(extern_id).own_params;
401
402        let mismatch = 'mismatch: {
403            if local_params.len() != extern_params.len() {
404                break 'mismatch true;
405            }
406            for (local_param, extern_param) in iter::zip(local_params, extern_params) {
407                if !cmp_generic_param_def(local_param, extern_param) {
408                    break 'mismatch true;
409                }
410                // We skip the self parameter because its id is the same as the trait's id, which
411                // has already been inserted.
412                if local_param.name != kw::SelfUpper {
413                    #[expect(clippy::disallowed_methods)]
414                    self.insert_extern_id(local_param.def_id.expect_local(), extern_param.def_id)?;
415                }
416            }
417            false
418        };
419        if mismatch {
420            let local_hir_generics = tcx.hir_get_generics(local_id.def_id).unwrap();
421            let span = local_hir_generics.span;
422            Err(self.emit(errors::MismatchedGenerics {
423                span,
424                extern_def: tcx.def_span(extern_id),
425                def_descr: tcx.def_descr(extern_id),
426            }))
427        } else {
428            Ok(())
429        }
430    }
431
432    #[track_caller]
433    fn malformed(&self) -> ErrorGuaranteed {
434        self.emit(errors::MalformedExternSpec::new(self.block.span))
435    }
436
437    #[track_caller]
438    fn invalid_enum_extern_spec(&self, reason: String) -> ErrorGuaranteed {
439        self.emit(errors::InvalidEnumExternSpec::new(self.block.span, reason))
440    }
441
442    #[track_caller]
443    fn item_not_in_trait_impl(
444        &self,
445        local_id: OwnerId,
446        extern_id: DefId,
447        extern_impl_id: DefId,
448    ) -> ErrorGuaranteed {
449        let tcx = self.tcx();
450        self.emit(errors::ItemNotInTraitImpl {
451            span: ident_or_def_span(tcx, local_id),
452            name: tcx.def_path_str(extern_id),
453            extern_impl_span: tcx.def_span(extern_impl_id),
454        })
455    }
456
457    fn invalid_item_in_inherent_impl(
458        &self,
459        local_id: OwnerId,
460        extern_id: DefId,
461    ) -> ErrorGuaranteed {
462        let tcx = self.tcx();
463        self.emit(errors::InvalidItemInInherentImpl {
464            span: ident_or_def_span(tcx, local_id),
465            name: tcx.def_path_str(extern_id),
466            extern_item_span: tcx.def_span(extern_id),
467        })
468    }
469
470    #[track_caller]
471    fn invalid_impl_block(&self) -> ErrorGuaranteed {
472        self.emit(errors::InvalidImplBlock { span: self.block.span })
473    }
474
475    #[track_caller]
476    fn cannot_resolve_trait_impl(&self) -> ErrorGuaranteed {
477        self.emit(errors::CannotResolveTraitImpl { span: self.block.span })
478    }
479
480    #[track_caller]
481    fn item_not_in_trait(
482        &self,
483        local_id: OwnerId,
484        extern_id: DefId,
485        extern_trait_id: DefId,
486    ) -> ErrorGuaranteed {
487        let tcx = self.tcx();
488        self.emit(errors::ItemNotInTrait {
489            span: ident_or_def_span(tcx, local_id),
490            name: tcx.def_path_str(extern_id),
491            extern_trait_span: tcx.def_span(extern_trait_id),
492        })
493    }
494
495    fn emit<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
496        self.inner.errors.emit(err)
497    }
498
499    fn tcx(&self) -> TyCtxt<'tcx> {
500        self.inner.tcx
501    }
502}
503
504fn cmp_generic_param_def(a: &ty::GenericParamDef, b: &ty::GenericParamDef) -> bool {
505    if a.name != b.name {
506        return false;
507    }
508    if a.index != b.index {
509        return false;
510    }
511    matches!(
512        (&a.kind, &b.kind),
513        (ty::GenericParamDefKind::Lifetime, ty::GenericParamDefKind::Lifetime)
514            | (ty::GenericParamDefKind::Type { .. }, ty::GenericParamDefKind::Type { .. })
515            | (ty::GenericParamDefKind::Const { .. }, ty::GenericParamDefKind::Const { .. })
516    )
517}
518
519fn ident_or_def_span(tcx: TyCtxt, def_id: impl Into<DefId>) -> Span {
520    let def_id = def_id.into();
521    tcx.def_ident_span(def_id)
522        .unwrap_or_else(|| tcx.def_span(def_id))
523}
524
525mod errors {
526    use flux_errors::E0999;
527    use flux_macros::Diagnostic;
528    use rustc_span::Span;
529
530    #[derive(Diagnostic)]
531    #[diag(driver_malformed_extern_spec, code = E0999)]
532    pub(super) struct MalformedExternSpec {
533        #[primary_span]
534        span: Span,
535    }
536
537    impl MalformedExternSpec {
538        pub(super) fn new(span: Span) -> Self {
539            Self { span }
540        }
541    }
542
543    #[derive(Diagnostic)]
544    #[diag(driver_invalid_enum_extern_spec, code = E0999)]
545    pub(super) struct InvalidEnumExternSpec {
546        #[primary_span]
547        span: Span,
548        reason: String,
549    }
550
551    impl InvalidEnumExternSpec {
552        pub(super) fn new(span: Span, reason: String) -> Self {
553            Self { span, reason }
554        }
555    }
556
557    #[derive(Diagnostic)]
558    #[diag(driver_cannot_resolve_trait_impl, code = E0999)]
559    #[note]
560    pub(super) struct CannotResolveTraitImpl {
561        #[primary_span]
562        pub span: Span,
563    }
564
565    #[derive(Diagnostic)]
566    #[diag(driver_invalid_impl_block, code = E0999)]
567    pub(super) struct InvalidImplBlock {
568        #[primary_span]
569        #[label]
570        pub span: Span,
571    }
572
573    #[derive(Diagnostic)]
574    #[diag(driver_item_not_in_trait_impl, code = E0999)]
575    pub(super) struct ItemNotInTraitImpl {
576        #[primary_span]
577        #[label]
578        pub span: Span,
579        pub name: String,
580        #[note]
581        pub extern_impl_span: Span,
582    }
583
584    #[derive(Diagnostic)]
585    #[diag(driver_invalid_item_in_inherent_impl, code = E0999)]
586    pub(super) struct InvalidItemInInherentImpl {
587        #[primary_span]
588        #[label]
589        pub span: Span,
590        pub name: String,
591        #[note]
592        pub extern_item_span: Span,
593    }
594
595    #[derive(Diagnostic)]
596    #[diag(driver_item_not_in_trait, code = E0999)]
597    pub(super) struct ItemNotInTrait {
598        #[primary_span]
599        #[label]
600        pub span: Span,
601        pub name: String,
602        #[note]
603        pub extern_trait_span: Span,
604    }
605
606    #[derive(Diagnostic)]
607    #[diag(driver_extern_spec_for_local_def, code = E0999)]
608    pub(super) struct ExternSpecForLocalDef {
609        #[primary_span]
610        pub span: Span,
611        #[note]
612        pub local_def_span: Span,
613        pub name: String,
614    }
615
616    #[derive(Diagnostic)]
617    #[diag(driver_dup_extern_spec, code = E0999)]
618    pub(super) struct DupExternSpec {
619        #[primary_span]
620        #[label]
621        pub span: Span,
622        #[note]
623        pub previous_span: Span,
624        pub name: String,
625    }
626
627    #[derive(Diagnostic)]
628    #[diag(driver_mismatched_generics, code = E0999)]
629    #[note]
630    pub(super) struct MismatchedGenerics {
631        #[primary_span]
632        #[label]
633        pub span: Span,
634        #[label(driver_extern_def_label)]
635        pub extern_def: Span,
636        pub def_descr: &'static str,
637    }
638}