flux_driver/collector/
extern_specs.rs

1use std::iter;
2
3use flux_middle::ExternSpecMappingErr;
4use flux_rustc_bridge::lowering;
5use rustc_errors::Diagnostic;
6use rustc_hir as hir;
7use rustc_hir::{
8    BodyId, OwnerId,
9    def_id::{DefId, LocalDefId},
10};
11use rustc_middle::ty::{self, TyCtxt};
12use rustc_span::{ErrorGuaranteed, Span, symbol::kw};
13
14use super::{FluxAttrs, SpecCollector};
15
16type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
17
18pub(super) struct ExternSpecCollector<'a, 'sess, 'tcx> {
19    inner: &'a mut SpecCollector<'sess, 'tcx>,
20    /// The block corresponding to the `const _: () = { ... }` annotated with `flux::extern_spec`
21    block: &'tcx hir::Block<'tcx>,
22}
23
24struct ExternImplItem {
25    impl_id: DefId,
26    item_id: DefId,
27}
28
29impl<'a, 'sess, 'tcx> ExternSpecCollector<'a, 'sess, 'tcx> {
30    pub(super) fn collect(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result {
31        Self::new(inner, body_id)?.run()
32    }
33
34    fn new(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result<Self> {
35        let body = inner.tcx.hir_body(body_id);
36        if let hir::ExprKind::Block(block, _) = body.value.kind {
37            Ok(Self { inner, block })
38        } else {
39            Err(inner
40                .errors
41                .emit(errors::MalformedExternSpec::new(body.value.span)))
42        }
43    }
44
45    fn run(mut self) -> Result {
46        let item = self.item_at(0)?;
47
48        let attrs = self
49            .inner
50            .parse_attrs_and_report_dups(item.owner_id.def_id)?;
51
52        match &item.kind {
53            hir::ItemKind::Fn { .. } => self.collect_extern_fn(item, attrs),
54            hir::ItemKind::Enum(_, _, enum_def) => {
55                self.collect_extern_enum(item.owner_id, enum_def, attrs)
56            }
57            hir::ItemKind::Struct(_, _, variant) => {
58                self.collect_extern_struct(item.owner_id, variant, attrs)
59            }
60            hir::ItemKind::Trait(_, _, _, _, _, bounds, items) => {
61                self.collect_extern_trait(item.owner_id, bounds, items, attrs)
62            }
63            hir::ItemKind::Impl(impl_) => self.collect_extern_impl(item.owner_id, impl_, attrs),
64            _ => Err(self.malformed()),
65        }
66    }
67
68    fn collect_extern_fn(&mut self, item: &hir::Item, attrs: FluxAttrs) -> Result {
69        self.inner.collect_fn_spec(item.owner_id, attrs)?;
70
71        let extern_id = self.extract_extern_id_from_fn(item)?;
72        self.insert_extern_id(item.owner_id.def_id, extern_id)?;
73        self.check_generics(item.owner_id, extern_id)?;
74
75        Ok(())
76    }
77
78    fn collect_extern_struct(
79        &mut self,
80        struct_id: OwnerId,
81        variant: &hir::VariantData,
82        attrs: FluxAttrs,
83    ) -> Result {
84        let dummy_struct = self.item_at(1)?;
85        self.inner.specs.insert_dummy(dummy_struct.owner_id.def_id);
86
87        let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
88        self.insert_extern_id(struct_id.def_id, extern_id)?;
89        self.check_generics(struct_id, extern_id)?;
90
91        if let Some(ctor_id) = variant.ctor_def_id() {
92            self.inner.specs.insert_dummy(ctor_id);
93        }
94
95        self.inner.collect_struct_def(struct_id, attrs, variant)?;
96
97        Ok(())
98    }
99
100    fn collect_extern_enum(
101        &mut self,
102        enum_id: OwnerId,
103        enum_def: &hir::EnumDef,
104        attrs: FluxAttrs,
105    ) -> Result {
106        let dummy_struct = self.item_at(1)?;
107        self.inner.specs.insert_dummy(dummy_struct.owner_id.def_id);
108
109        let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
110        self.insert_extern_id(enum_id.def_id, extern_id)?;
111        self.check_generics(enum_id, extern_id)?;
112
113        self.inner.collect_enum_def(enum_id, attrs, enum_def)?;
114
115        // Add stuff about Ctor
116        // Get the AdtDef for the enum
117        let extern_enum_def = self.tcx().adt_def(extern_id);
118
119        // Collect all constructor DefIds from variants
120        let extern_variants = extern_enum_def.variants();
121        let enum_variants = enum_def.variants;
122        let extern_len = extern_variants.len();
123        let enum_len = enum_variants.len();
124        if extern_len != enum_len {
125            let reason = format!("expected {extern_len:?} variants but only have {enum_len:?}");
126            return Err(self.invalid_enum_extern_spec(reason));
127        }
128        for (extern_variant, variant) in extern_enum_def.variants().iter().zip(enum_def.variants) {
129            if let Some(extern_ctor) = extern_variant.ctor_def_id()
130                && let Some(ctor) = variant.data.ctor_def_id()
131                && self.tcx().def_kind(extern_ctor) == self.tcx().def_kind(ctor)
132            {
133                self.insert_extern_id(ctor, extern_ctor)?;
134            } else {
135                let reason = format!(
136                    "extern variant `{}` incompatible with specified `{}`",
137                    extern_variant.ident(self.tcx()),
138                    rustc_hir_pretty::id_to_string(&self.tcx(), variant.hir_id)
139                );
140                return Err(self.invalid_enum_extern_spec(reason));
141            }
142        }
143        Ok(())
144    }
145
146    fn collect_extern_impl(
147        &mut self,
148        impl_id: OwnerId,
149        impl_: &hir::Impl,
150        attrs: FluxAttrs,
151    ) -> Result {
152        self.inner.collect_impl(impl_id, attrs)?;
153
154        let dummy_item = self.item_at(1)?;
155        self.inner.specs.insert_dummy(dummy_item.owner_id.def_id);
156
157        // If this is a trait impl compute the impl_id from the trait_ref
158        let mut impl_of_trait = None;
159        if let hir::ItemKind::Impl(dummy_impl) = &dummy_item.kind {
160            impl_of_trait =
161                Some(self.extract_extern_id_from_impl(dummy_item.owner_id, dummy_impl)?);
162
163            self.inner
164                .specs
165                .insert_dummy(self.item_at(2)?.owner_id.def_id);
166        }
167
168        let mut extern_impl_id = impl_of_trait;
169        for item_id in impl_.items {
170            let item = self.tcx().hir_impl_item(*item_id);
171            let extern_item = if let hir::ImplItemKind::Fn { .. } = item.kind {
172                let attrs = self
173                    .inner
174                    .parse_attrs_and_report_dups(item_id.owner_id.def_id)?;
175                self.collect_extern_impl_fn(impl_of_trait, item, attrs)?
176            } else {
177                continue;
178            };
179
180            if *extern_impl_id.get_or_insert(extern_item.impl_id) != extern_item.impl_id {
181                return Err(self.invalid_impl_block());
182            }
183        }
184
185        if let Some(extern_impl_id) = extern_impl_id {
186            self.check_generics(impl_id, extern_impl_id)?;
187            self.insert_extern_id(impl_id.def_id, extern_impl_id)?;
188        }
189
190        Ok(())
191    }
192
193    fn collect_extern_impl_fn(
194        &mut self,
195        impl_of_trait: Option<DefId>,
196        item: &hir::ImplItem,
197        attrs: FluxAttrs,
198    ) -> Result<ExternImplItem> {
199        self.inner.collect_fn_spec(item.owner_id, attrs)?;
200
201        let extern_impl_item = self.extract_extern_id_from_impl_fn(impl_of_trait, item)?;
202        self.insert_extern_id(item.owner_id.def_id, extern_impl_item.item_id)?;
203        self.check_generics(item.owner_id, extern_impl_item.item_id)?;
204
205        Ok(extern_impl_item)
206    }
207
208    fn collect_extern_trait(
209        &mut self,
210        trait_id: OwnerId,
211        bounds: &hir::GenericBounds,
212        items: &[hir::TraitItemId],
213        attrs: FluxAttrs,
214    ) -> Result {
215        self.inner.collect_trait(trait_id, attrs)?;
216
217        let extern_trait_id = self.extract_extern_id_from_trait(bounds)?;
218        self.insert_extern_id(trait_id.def_id, extern_trait_id)?;
219        self.check_generics(trait_id, extern_trait_id)?;
220
221        for item_id in items {
222            let item = self.tcx().hir_trait_item(*item_id);
223            if let hir::TraitItemKind::Fn { .. } = item.kind {
224                let attrs = self
225                    .inner
226                    .parse_attrs_and_report_dups(item.owner_id.def_id)?;
227                self.collect_extern_trait_fn(extern_trait_id, item, attrs)?;
228            } else {
229                continue;
230            }
231        }
232
233        Ok(())
234    }
235
236    fn collect_extern_trait_fn(
237        &mut self,
238        extern_trait_id: DefId,
239        item: &hir::TraitItem,
240        attrs: FluxAttrs,
241    ) -> Result {
242        let item_id = item.owner_id;
243        self.inner.collect_fn_spec(item_id, attrs)?;
244
245        let extern_fn_id = self.extract_extern_id_from_trait_fn(extern_trait_id, item)?;
246        self.insert_extern_id(item.owner_id.def_id, extern_fn_id)?;
247        self.check_generics(item_id, extern_fn_id)?;
248
249        Ok(())
250    }
251
252    fn extract_extern_id_from_struct(&self, item: &hir::Item) -> Result<DefId> {
253        if let hir::ItemKind::Struct(_, _, data) = item.kind
254            && let Some(extern_field) = data.fields().last()
255            && let ty = self.tcx().type_of(extern_field.def_id)
256            && let Some(adt_def) = ty.skip_binder().ty_adt_def()
257        {
258            Ok(adt_def.did())
259        } else {
260            Err(self.malformed())
261        }
262    }
263
264    fn extract_extern_id_from_fn(&self, item: &hir::Item) -> Result<DefId> {
265        if let hir::ItemKind::Fn { body, .. } = item.kind {
266            self.extract_callee_from_body(body)
267        } else {
268            Err(self.malformed())
269        }
270    }
271
272    fn extract_extern_id_from_impl_fn(
273        &self,
274        impl_of_trait: Option<DefId>,
275        item: &hir::ImplItem,
276    ) -> Result<ExternImplItem> {
277        if let hir::ImplItemKind::Fn(_, body_id) = item.kind {
278            let callee_id = self.extract_callee_from_body(body_id)?;
279            if let Some(extern_impl_id) = impl_of_trait {
280                let map = self.tcx().impl_item_implementor_ids(extern_impl_id);
281                if let Some(extern_item_id) = map.get(&callee_id) {
282                    Ok(ExternImplItem { impl_id: extern_impl_id, item_id: *extern_item_id })
283                } else {
284                    Err(self.item_not_in_trait_impl(item.owner_id, callee_id, extern_impl_id))
285                }
286            } else {
287                let opt_extern_impl_id = self.tcx().impl_of_assoc(callee_id);
288                if let Some(extern_impl_id) = opt_extern_impl_id {
289                    debug_assert!(self.tcx().trait_id_of_impl(extern_impl_id).is_none());
290                    Ok(ExternImplItem { impl_id: extern_impl_id, item_id: callee_id })
291                } else {
292                    Err(self.invalid_item_in_inherent_impl(item.owner_id, callee_id))
293                }
294            }
295        } else {
296            Err(self.malformed())
297        }
298    }
299
300    fn extract_extern_id_from_trait(&self, bounds: &hir::GenericBounds) -> Result<DefId> {
301        if let Some(bound) = bounds.first()
302            && let Some(trait_ref) = bound.trait_ref()
303            && let Some(trait_id) = trait_ref.trait_def_id()
304        {
305            Ok(trait_id)
306        } else {
307            Err(self.malformed())
308        }
309    }
310
311    fn extract_extern_id_from_trait_fn(
312        &self,
313        trait_id: DefId,
314        item: &hir::TraitItem,
315    ) -> Result<DefId> {
316        if let hir::TraitItemKind::Fn(_, trait_fn) = item.kind
317            && let hir::TraitFn::Provided(body_id) = trait_fn
318        {
319            let callee_id = self.extract_callee_from_body(body_id)?;
320            if let Some(callee_trait_id) = self.tcx().trait_of_assoc(callee_id)
321                && trait_id == callee_trait_id
322            {
323                Ok(callee_id)
324            } else {
325                // I can't figure out how to trigger this via code generated with the extern spec
326                // macro that also type checks but leaving it here as a precaution.
327                Err(self.item_not_in_trait(item.owner_id, callee_id, trait_id))
328            }
329        } else {
330            Err(self.malformed())
331        }
332    }
333
334    fn extract_extern_id_from_impl(&self, impl_id: OwnerId, impl_: &hir::Impl) -> Result<DefId> {
335        if let Some(item_id) = impl_.items.first()
336            && let hir::ImplItemKind::Fn { .. } = self.tcx().hir_impl_item(*item_id).kind
337            && let Some((clause, _)) = self
338                .tcx()
339                .predicates_of(item_id.owner_id.def_id)
340                .predicates
341                .first()
342            && let Some(poly_trait_pred) = clause.as_trait_clause()
343            && let Some(trait_pred) = poly_trait_pred.no_bound_vars()
344        {
345            let trait_ref = trait_pred.trait_ref;
346            lowering::resolve_trait_ref_impl_id(self.tcx(), impl_id.to_def_id(), trait_ref)
347                .map(|(impl_id, _)| impl_id)
348                .ok_or_else(|| self.cannot_resolve_trait_impl())
349        } else {
350            Err(self.malformed())
351        }
352    }
353
354    fn extract_callee_from_body(&self, body_id: hir::BodyId) -> Result<DefId> {
355        let owner = self.tcx().hir_body_owner_def_id(body_id);
356        let typeck = self.tcx().typeck(owner);
357        if let hir::ExprKind::Block(b, _) = self.tcx().hir_body(body_id).value.kind
358            && let Some(e) = b.expr
359            && let hir::ExprKind::Call(callee, _) = e.kind
360            && let rustc_middle::ty::FnDef(callee_id, _) = typeck.node_type(callee.hir_id).kind()
361        {
362            Ok(*callee_id)
363        } else {
364            Err(self.malformed())
365        }
366    }
367
368    /// Returns the item inside the const block at position `i` starting from the end.
369    #[track_caller]
370    fn item_at(&self, i: usize) -> Result<&'tcx hir::Item<'tcx>> {
371        let stmts = self.block.stmts;
372        let index = stmts
373            .len()
374            .checked_sub(i + 1)
375            .ok_or_else(|| self.malformed())?;
376        let st = stmts.get(index).ok_or_else(|| self.malformed())?;
377        if let hir::StmtKind::Item(item_id) = st.kind {
378            Ok(self.tcx().hir_item(item_id))
379        } else {
380            Err(self.malformed())
381        }
382    }
383
384    fn insert_extern_id(&mut self, local_id: LocalDefId, extern_id: DefId) -> Result {
385        self.inner
386            .specs
387            .insert_extern_spec_id_mapping(local_id, extern_id)
388            .map_err(|err| {
389                match err {
390                    ExternSpecMappingErr::IsLocal(extern_id_local) => {
391                        self.emit(errors::ExternSpecForLocalDef {
392                            span: ident_or_def_span(self.tcx(), local_id),
393                            local_def_span: ident_or_def_span(self.tcx(), extern_id_local),
394                            name: self.tcx().def_path_str(extern_id),
395                        })
396                    }
397                    ExternSpecMappingErr::Dup(previous_extern_spec) => {
398                        self.emit(errors::DupExternSpec {
399                            span: ident_or_def_span(self.tcx(), local_id),
400                            previous_span: ident_or_def_span(self.tcx(), previous_extern_spec),
401                            name: self.tcx().def_path_str(extern_id),
402                        })
403                    }
404                }
405            })
406    }
407
408    fn check_generics(&mut self, local_id: OwnerId, extern_id: DefId) -> Result {
409        let tcx = self.tcx();
410        let local_params = &tcx.generics_of(local_id).own_params;
411        let extern_params = &tcx.generics_of(extern_id).own_params;
412
413        let mismatch = 'mismatch: {
414            if local_params.len() != extern_params.len() {
415                break 'mismatch true;
416            }
417            for (local_param, extern_param) in iter::zip(local_params, extern_params) {
418                if !cmp_generic_param_def(local_param, extern_param) {
419                    break 'mismatch true;
420                }
421                // We skip the self parameter because its id is the same as the trait's id, which
422                // has already been inserted.
423                if local_param.name != kw::SelfUpper {
424                    #[expect(clippy::disallowed_methods)]
425                    self.insert_extern_id(local_param.def_id.expect_local(), extern_param.def_id)?;
426                }
427            }
428            false
429        };
430        if mismatch {
431            let local_hir_generics = tcx.hir_get_generics(local_id.def_id).unwrap();
432            let span = local_hir_generics.span;
433            Err(self.emit(errors::MismatchedGenerics {
434                span,
435                extern_def: tcx.def_span(extern_id),
436                def_descr: tcx.def_descr(extern_id),
437            }))
438        } else {
439            Ok(())
440        }
441    }
442
443    #[track_caller]
444    fn malformed(&self) -> ErrorGuaranteed {
445        self.emit(errors::MalformedExternSpec::new(self.block.span))
446    }
447
448    #[track_caller]
449    fn invalid_enum_extern_spec(&self, reason: String) -> ErrorGuaranteed {
450        self.emit(errors::InvalidEnumExternSpec::new(self.block.span, reason))
451    }
452
453    #[track_caller]
454    fn item_not_in_trait_impl(
455        &self,
456        local_id: OwnerId,
457        extern_id: DefId,
458        extern_impl_id: DefId,
459    ) -> ErrorGuaranteed {
460        let tcx = self.tcx();
461        self.emit(errors::ItemNotInTraitImpl {
462            span: ident_or_def_span(tcx, local_id),
463            name: tcx.def_path_str(extern_id),
464            extern_impl_span: tcx.def_span(extern_impl_id),
465        })
466    }
467
468    fn invalid_item_in_inherent_impl(
469        &self,
470        local_id: OwnerId,
471        extern_id: DefId,
472    ) -> ErrorGuaranteed {
473        let tcx = self.tcx();
474        self.emit(errors::InvalidItemInInherentImpl {
475            span: ident_or_def_span(tcx, local_id),
476            name: tcx.def_path_str(extern_id),
477            extern_item_span: tcx.def_span(extern_id),
478        })
479    }
480
481    #[track_caller]
482    fn invalid_impl_block(&self) -> ErrorGuaranteed {
483        self.emit(errors::InvalidImplBlock { span: self.block.span })
484    }
485
486    #[track_caller]
487    fn cannot_resolve_trait_impl(&self) -> ErrorGuaranteed {
488        self.emit(errors::CannotResolveTraitImpl { span: self.block.span })
489    }
490
491    #[track_caller]
492    fn item_not_in_trait(
493        &self,
494        local_id: OwnerId,
495        extern_id: DefId,
496        extern_trait_id: DefId,
497    ) -> ErrorGuaranteed {
498        let tcx = self.tcx();
499        self.emit(errors::ItemNotInTrait {
500            span: ident_or_def_span(tcx, local_id),
501            name: tcx.def_path_str(extern_id),
502            extern_trait_span: tcx.def_span(extern_trait_id),
503        })
504    }
505
506    #[track_caller]
507    fn emit<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
508        self.inner.errors.emit(err)
509    }
510
511    fn tcx(&self) -> TyCtxt<'tcx> {
512        self.inner.tcx
513    }
514}
515
516fn cmp_generic_param_def(a: &ty::GenericParamDef, b: &ty::GenericParamDef) -> bool {
517    if a.name != b.name {
518        return false;
519    }
520    if a.index != b.index {
521        return false;
522    }
523    matches!(
524        (&a.kind, &b.kind),
525        (ty::GenericParamDefKind::Lifetime, ty::GenericParamDefKind::Lifetime)
526            | (ty::GenericParamDefKind::Type { .. }, ty::GenericParamDefKind::Type { .. })
527            | (ty::GenericParamDefKind::Const { .. }, ty::GenericParamDefKind::Const { .. })
528    )
529}
530
531fn ident_or_def_span(tcx: TyCtxt, def_id: impl Into<DefId>) -> Span {
532    let def_id = def_id.into();
533    tcx.def_ident_span(def_id)
534        .unwrap_or_else(|| tcx.def_span(def_id))
535}
536
// Diagnostics reported while collecting extern specs. All of them use error code E0999. The
// `driver_*` slugs name messages that are defined outside this file.
mod errors {
    use flux_errors::E0999;
    use flux_macros::Diagnostic;
    use rustc_span::Span;

    // The extern spec block (or one of its items) does not have the expected shape.
    // Built via `new`, which is used by `ExternSpecCollector::new` and `ExternSpecCollector::malformed`.
    #[derive(Diagnostic)]
    #[diag(driver_malformed_extern_spec, code = E0999)]
    pub(super) struct MalformedExternSpec {
        #[primary_span]
        span: Span,
    }

    impl MalformedExternSpec {
        pub(super) fn new(span: Span) -> Self {
            Self { span }
        }
    }

    // The variants of an extern enum spec do not line up with those of the extern enum
    // (count or constructor kind). `reason` is the explanation built in `collect_extern_enum`.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_enum_extern_spec, code = E0999)]
    pub(super) struct InvalidEnumExternSpec {
        #[primary_span]
        span: Span,
        reason: String,
    }

    impl InvalidEnumExternSpec {
        pub(super) fn new(span: Span, reason: String) -> Self {
            Self { span, reason }
        }
    }

    // The trait ref taken from a dummy impl could not be resolved to an impl
    // (see `extract_extern_id_from_impl`).
    #[derive(Diagnostic)]
    #[diag(driver_cannot_resolve_trait_impl, code = E0999)]
    #[note]
    pub(super) struct CannotResolveTraitImpl {
        #[primary_span]
        pub span: Span,
    }

    // The methods of an extern impl spec resolve to different extern impl blocks.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_impl_block, code = E0999)]
    pub(super) struct InvalidImplBlock {
        #[primary_span]
        #[label]
        pub span: Span,
    }

    // A method of an extern trait-impl spec calls `name`, which is not found among the items
    // of the resolved extern impl. `extern_impl_span` points at that impl.
    #[derive(Diagnostic)]
    #[diag(driver_item_not_in_trait_impl, code = E0999)]
    pub(super) struct ItemNotInTraitImpl {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_impl_span: Span,
    }

    // A method of an extern inherent-impl spec calls `name`, which is not an associated item
    // of any impl. `extern_item_span` points at the called definition.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_item_in_inherent_impl, code = E0999)]
    pub(super) struct InvalidItemInInherentImpl {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_item_span: Span,
    }

    // A function of an extern trait spec calls `name`, which does not belong to the extern
    // trait. `extern_trait_span` points at that trait.
    #[derive(Diagnostic)]
    #[diag(driver_item_not_in_trait, code = E0999)]
    pub(super) struct ItemNotInTrait {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_trait_span: Span,
    }

    // Reported for `ExternSpecMappingErr::IsLocal`: the spec targets a local definition, and
    // `local_def_span` points at it.
    #[derive(Diagnostic)]
    #[diag(driver_extern_spec_for_local_def, code = E0999)]
    pub(super) struct ExternSpecForLocalDef {
        #[primary_span]
        pub span: Span,
        #[note]
        pub local_def_span: Span,
        pub name: String,
    }

    // Reported for `ExternSpecMappingErr::Dup`: `previous_span` points at the previously
    // registered extern spec.
    #[derive(Diagnostic)]
    #[diag(driver_dup_extern_spec, code = E0999)]
    pub(super) struct DupExternSpec {
        #[primary_span]
        #[label]
        pub span: Span,
        #[note]
        pub previous_span: Span,
        pub name: String,
    }

    // Reported by `check_generics` when local and extern generic parameters disagree. `span`
    // is the local generics, `extern_def` the extern definition, and `def_descr` describes the
    // extern definition.
    #[derive(Diagnostic)]
    #[diag(driver_mismatched_generics, code = E0999)]
    #[note]
    pub(super) struct MismatchedGenerics {
        #[primary_span]
        #[label]
        pub span: Span,
        #[label(driver_extern_def_label)]
        pub extern_def: Span,
        pub def_descr: &'static str,
    }
}