1use std::iter;
2
3use flux_middle::ExternSpecMappingErr;
4use flux_rustc_bridge::lowering;
5use flux_syntax::surface;
6use rustc_errors::Diagnostic;
7use rustc_hir as hir;
8use rustc_hir::{
9 BodyId, OwnerId,
10 def_id::{DefId, LocalDefId},
11};
12use rustc_middle::ty::{self, TyCtxt};
13use rustc_span::{ErrorGuaranteed, Span, symbol::kw};
14
15use super::{FluxAttrs, SpecCollector};
16
17type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
18
19pub(super) struct ExternSpecCollector<'a, 'sess, 'tcx> {
20 inner: &'a mut SpecCollector<'sess, 'tcx>,
21 block: &'tcx hir::Block<'tcx>,
23}
24
25struct ExternImplItem {
26 impl_id: DefId,
27 item_id: DefId,
28}
29
30impl<'a, 'sess, 'tcx> ExternSpecCollector<'a, 'sess, 'tcx> {
31 pub(super) fn collect(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result {
32 Self::new(inner, body_id)?.run()
33 }
34
35 fn new(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result<Self> {
36 let body = inner.tcx.hir_body(body_id);
37 if let hir::ExprKind::Block(block, _) = body.value.kind {
38 Ok(Self { inner, block })
39 } else {
40 Err(inner
41 .errors
42 .emit(errors::MalformedExternSpec::new(body.value.span)))
43 }
44 }
45
46 fn run(mut self) -> Result {
47 let item = self.item_at(0)?;
48
49 let attrs = self
50 .inner
51 .parse_attrs_and_report_dups(item.owner_id.def_id)?;
52
53 match &item.kind {
54 hir::ItemKind::Fn { .. } => self.collect_extern_fn(item, attrs),
55 hir::ItemKind::Enum(_, _, enum_def) => {
56 self.collect_extern_enum(item.owner_id, enum_def, attrs)
57 }
58 hir::ItemKind::Struct(_, _, variant) => {
59 self.collect_extern_struct(item.owner_id, variant, attrs)
60 }
61 hir::ItemKind::Trait(_, _, _, _, _, bounds, items) => {
62 self.collect_extern_trait(item.owner_id, bounds, items, attrs)
63 }
64 hir::ItemKind::Impl(impl_) => self.collect_extern_impl(item.owner_id, impl_, attrs),
65 _ => Err(self.malformed()),
66 }
67 }
68
69 fn collect_extern_fn(&mut self, item: &hir::Item, mut attrs: FluxAttrs) -> Result {
70 if attrs.has_attrs() {
71 let sig = attrs.fn_sig();
72 self.inner.check_fn_sig_name(item.owner_id, sig.as_ref())?;
73 let node_id = self.inner.next_node_id();
74 self.inner.insert_item(
75 item.owner_id,
76 surface::Item {
77 attrs: attrs.into_attr_vec(),
78 kind: surface::ItemKind::Fn(sig),
79 node_id,
80 },
81 )?;
82 }
83
84 let extern_id = self.extract_extern_id_from_fn(item)?;
85 self.insert_extern_id(item.owner_id.def_id, extern_id)?;
86 self.check_generics(item.owner_id, extern_id)?;
87
88 Ok(())
89 }
90
91 fn collect_extern_struct(
92 &mut self,
93 struct_id: OwnerId,
94 variant: &hir::VariantData,
95 attrs: FluxAttrs,
96 ) -> Result {
97 let dummy_struct = self.item_at(1)?;
98 self.inner.specs.insert_dummy(dummy_struct.owner_id.def_id);
99
100 let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
101 self.insert_extern_id(struct_id.def_id, extern_id)?;
102 self.check_generics(struct_id, extern_id)?;
103
104 if let Some(ctor_id) = variant.ctor_def_id() {
105 self.inner.specs.insert_dummy(ctor_id);
106 }
107
108 self.inner.collect_struct_def(struct_id, attrs, variant)?;
109
110 Ok(())
111 }
112
113 fn collect_extern_enum(
114 &mut self,
115 enum_id: OwnerId,
116 enum_def: &hir::EnumDef,
117 attrs: FluxAttrs,
118 ) -> Result {
119 let dummy_struct = self.item_at(1)?;
120 self.inner.specs.insert_dummy(dummy_struct.owner_id.def_id);
121
122 let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
123 self.insert_extern_id(enum_id.def_id, extern_id)?;
124 self.check_generics(enum_id, extern_id)?;
125
126 self.inner.collect_enum_def(enum_id, attrs, enum_def)?;
127
128 let extern_enum_def = self.tcx().adt_def(extern_id);
131
132 let extern_variants = extern_enum_def.variants();
134 let enum_variants = enum_def.variants;
135 let extern_len = extern_variants.len();
136 let enum_len = enum_variants.len();
137 if extern_len != enum_len {
138 let reason = format!("expected {extern_len:?} variants but only have {enum_len:?}");
139 return Err(self.invalid_enum_extern_spec(reason));
140 }
141 for (extern_variant, variant) in extern_enum_def.variants().iter().zip(enum_def.variants) {
142 if let Some(extern_ctor) = extern_variant.ctor_def_id()
143 && let Some(ctor) = variant.data.ctor_def_id()
144 && self.tcx().def_kind(extern_ctor) == self.tcx().def_kind(ctor)
145 {
146 self.insert_extern_id(ctor, extern_ctor)?;
147 } else {
148 let reason = format!(
149 "extern variant `{}` incompatible with specified `{}`",
150 extern_variant.ident(self.tcx()),
151 rustc_hir_pretty::id_to_string(&self.tcx(), variant.hir_id)
152 );
153 return Err(self.invalid_enum_extern_spec(reason));
154 }
155 }
156 Ok(())
157 }
158
159 fn collect_extern_impl(
160 &mut self,
161 impl_id: OwnerId,
162 impl_: &hir::Impl,
163 attrs: FluxAttrs,
164 ) -> Result {
165 self.inner.collect_impl(impl_id, attrs)?;
166
167 let dummy_item = self.item_at(1)?;
168 self.inner.specs.insert_dummy(dummy_item.owner_id.def_id);
169
170 let mut impl_of_trait = None;
172 if let hir::ItemKind::Impl(dummy_impl) = &dummy_item.kind {
173 impl_of_trait =
174 Some(self.extract_extern_id_from_impl(dummy_item.owner_id, dummy_impl)?);
175
176 self.inner
177 .specs
178 .insert_dummy(self.item_at(2)?.owner_id.def_id);
179 }
180
181 let mut extern_impl_id = impl_of_trait;
182 for item_id in impl_.items {
183 let item = self.tcx().hir_impl_item(*item_id);
184 let extern_item = if let hir::ImplItemKind::Fn { .. } = item.kind {
185 let attrs = self
186 .inner
187 .parse_attrs_and_report_dups(item_id.owner_id.def_id)?;
188 self.collect_extern_impl_fn(impl_of_trait, item, attrs)?
189 } else {
190 continue;
191 };
192
193 if *extern_impl_id.get_or_insert(extern_item.impl_id) != extern_item.impl_id {
194 return Err(self.invalid_impl_block());
195 }
196 }
197
198 if let Some(extern_impl_id) = extern_impl_id {
199 self.check_generics(impl_id, extern_impl_id)?;
200 self.insert_extern_id(impl_id.def_id, extern_impl_id)?;
201 }
202
203 Ok(())
204 }
205
206 fn collect_extern_impl_fn(
207 &mut self,
208 impl_of_trait: Option<DefId>,
209 item: &hir::ImplItem,
210 mut attrs: FluxAttrs,
211 ) -> Result<ExternImplItem> {
212 if attrs.has_attrs() {
213 let sig = attrs.fn_sig();
214 self.inner.check_fn_sig_name(item.owner_id, sig.as_ref())?;
215 let node_id = self.inner.next_node_id();
216 self.inner.insert_impl_item(
217 item.owner_id,
218 surface::ImplItemFn { attrs: attrs.into_attr_vec(), sig, node_id },
219 )?;
220 }
221
222 let extern_impl_item = self.extract_extern_id_from_impl_fn(impl_of_trait, item)?;
223 self.insert_extern_id(item.owner_id.def_id, extern_impl_item.item_id)?;
224 self.check_generics(item.owner_id, extern_impl_item.item_id)?;
225
226 Ok(extern_impl_item)
227 }
228
229 fn collect_extern_trait(
230 &mut self,
231 trait_id: OwnerId,
232 bounds: &hir::GenericBounds,
233 items: &[hir::TraitItemId],
234 attrs: FluxAttrs,
235 ) -> Result {
236 self.inner.collect_trait(trait_id, attrs)?;
237
238 let extern_trait_id = self.extract_extern_id_from_trait(bounds)?;
239 self.insert_extern_id(trait_id.def_id, extern_trait_id)?;
240 self.check_generics(trait_id, extern_trait_id)?;
241
242 for item_id in items {
243 let item = self.tcx().hir_trait_item(*item_id);
244 if let hir::TraitItemKind::Fn { .. } = item.kind {
245 let attrs = self
246 .inner
247 .parse_attrs_and_report_dups(item.owner_id.def_id)?;
248 self.collect_extern_trait_fn(extern_trait_id, item, attrs)?;
249 } else {
250 continue;
251 }
252 }
253
254 Ok(())
255 }
256
257 fn collect_extern_trait_fn(
258 &mut self,
259 extern_trait_id: DefId,
260 item: &hir::TraitItem,
261 mut attrs: FluxAttrs,
262 ) -> Result {
263 let item_id = item.owner_id;
264 if attrs.has_attrs() {
265 let sig = attrs.fn_sig();
266 self.inner.check_fn_sig_name(item.owner_id, sig.as_ref())?;
267 let node_id = self.inner.next_node_id();
268 self.inner.insert_trait_item(
269 item.owner_id,
270 surface::TraitItemFn { attrs: attrs.into_attr_vec(), sig, node_id },
271 )?;
272 }
273
274 let extern_fn_id = self.extract_extern_id_from_trait_fn(extern_trait_id, item)?;
275 self.insert_extern_id(item.owner_id.def_id, extern_fn_id)?;
276 self.check_generics(item_id, extern_fn_id)?;
277
278 Ok(())
279 }
280
281 fn extract_extern_id_from_struct(&self, item: &hir::Item) -> Result<DefId> {
282 if let hir::ItemKind::Struct(_, _, data) = item.kind
283 && let Some(extern_field) = data.fields().last()
284 && let ty = self.tcx().type_of(extern_field.def_id)
285 && let Some(adt_def) = ty.skip_binder().ty_adt_def()
286 {
287 Ok(adt_def.did())
288 } else {
289 Err(self.malformed())
290 }
291 }
292
293 fn extract_extern_id_from_fn(&self, item: &hir::Item) -> Result<DefId> {
294 if let hir::ItemKind::Fn { body, .. } = item.kind {
295 self.extract_callee_from_body(body)
296 } else {
297 Err(self.malformed())
298 }
299 }
300
301 fn extract_extern_id_from_impl_fn(
302 &self,
303 impl_of_trait: Option<DefId>,
304 item: &hir::ImplItem,
305 ) -> Result<ExternImplItem> {
306 if let hir::ImplItemKind::Fn(_, body_id) = item.kind {
307 let callee_id = self.extract_callee_from_body(body_id)?;
308 if let Some(extern_impl_id) = impl_of_trait {
309 let map = self.tcx().impl_item_implementor_ids(extern_impl_id);
310 if let Some(extern_item_id) = map.get(&callee_id) {
311 Ok(ExternImplItem { impl_id: extern_impl_id, item_id: *extern_item_id })
312 } else {
313 Err(self.item_not_in_trait_impl(item.owner_id, callee_id, extern_impl_id))
314 }
315 } else {
316 let opt_extern_impl_id = self.tcx().impl_of_assoc(callee_id);
317 if let Some(extern_impl_id) = opt_extern_impl_id {
318 debug_assert!(self.tcx().trait_id_of_impl(extern_impl_id).is_none());
319 Ok(ExternImplItem { impl_id: extern_impl_id, item_id: callee_id })
320 } else {
321 Err(self.invalid_item_in_inherent_impl(item.owner_id, callee_id))
322 }
323 }
324 } else {
325 Err(self.malformed())
326 }
327 }
328
329 fn extract_extern_id_from_trait(&self, bounds: &hir::GenericBounds) -> Result<DefId> {
330 if let Some(bound) = bounds.first()
331 && let Some(trait_ref) = bound.trait_ref()
332 && let Some(trait_id) = trait_ref.trait_def_id()
333 {
334 Ok(trait_id)
335 } else {
336 Err(self.malformed())
337 }
338 }
339
340 fn extract_extern_id_from_trait_fn(
341 &self,
342 trait_id: DefId,
343 item: &hir::TraitItem,
344 ) -> Result<DefId> {
345 if let hir::TraitItemKind::Fn(_, trait_fn) = item.kind
346 && let hir::TraitFn::Provided(body_id) = trait_fn
347 {
348 let callee_id = self.extract_callee_from_body(body_id)?;
349 if let Some(callee_trait_id) = self.tcx().trait_of_assoc(callee_id)
350 && trait_id == callee_trait_id
351 {
352 Ok(callee_id)
353 } else {
354 Err(self.item_not_in_trait(item.owner_id, callee_id, trait_id))
357 }
358 } else {
359 Err(self.malformed())
360 }
361 }
362
363 fn extract_extern_id_from_impl(&self, impl_id: OwnerId, impl_: &hir::Impl) -> Result<DefId> {
364 if let Some(item_id) = impl_.items.first()
365 && let hir::ImplItemKind::Fn { .. } = self.tcx().hir_impl_item(*item_id).kind
366 && let Some((clause, _)) = self
367 .tcx()
368 .predicates_of(item_id.owner_id.def_id)
369 .predicates
370 .first()
371 && let Some(poly_trait_pred) = clause.as_trait_clause()
372 && let Some(trait_pred) = poly_trait_pred.no_bound_vars()
373 {
374 let trait_ref = trait_pred.trait_ref;
375 lowering::resolve_trait_ref_impl_id(self.tcx(), impl_id.to_def_id(), trait_ref)
376 .map(|(impl_id, _)| impl_id)
377 .ok_or_else(|| self.cannot_resolve_trait_impl())
378 } else {
379 Err(self.malformed())
380 }
381 }
382
383 fn extract_callee_from_body(&self, body_id: hir::BodyId) -> Result<DefId> {
384 let owner = self.tcx().hir_body_owner_def_id(body_id);
385 let typeck = self.tcx().typeck(owner);
386 if let hir::ExprKind::Block(b, _) = self.tcx().hir_body(body_id).value.kind
387 && let Some(e) = b.expr
388 && let hir::ExprKind::Call(callee, _) = e.kind
389 && let rustc_middle::ty::FnDef(callee_id, _) = typeck.node_type(callee.hir_id).kind()
390 {
391 Ok(*callee_id)
392 } else {
393 Err(self.malformed())
394 }
395 }
396
397 #[track_caller]
399 fn item_at(&self, i: usize) -> Result<&'tcx hir::Item<'tcx>> {
400 let stmts = self.block.stmts;
401 let index = stmts
402 .len()
403 .checked_sub(i + 1)
404 .ok_or_else(|| self.malformed())?;
405 let st = stmts.get(index).ok_or_else(|| self.malformed())?;
406 if let hir::StmtKind::Item(item_id) = st.kind {
407 Ok(self.tcx().hir_item(item_id))
408 } else {
409 Err(self.malformed())
410 }
411 }
412
413 fn insert_extern_id(&mut self, local_id: LocalDefId, extern_id: DefId) -> Result {
414 self.inner
415 .specs
416 .insert_extern_spec_id_mapping(local_id, extern_id)
417 .map_err(|err| {
418 match err {
419 ExternSpecMappingErr::IsLocal(extern_id_local) => {
420 self.emit(errors::ExternSpecForLocalDef {
421 span: ident_or_def_span(self.tcx(), local_id),
422 local_def_span: ident_or_def_span(self.tcx(), extern_id_local),
423 name: self.tcx().def_path_str(extern_id),
424 })
425 }
426 ExternSpecMappingErr::Dup(previous_extern_spec) => {
427 self.emit(errors::DupExternSpec {
428 span: ident_or_def_span(self.tcx(), local_id),
429 previous_span: ident_or_def_span(self.tcx(), previous_extern_spec),
430 name: self.tcx().def_path_str(extern_id),
431 })
432 }
433 }
434 })
435 }
436
437 fn check_generics(&mut self, local_id: OwnerId, extern_id: DefId) -> Result {
438 let tcx = self.tcx();
439 let local_params = &tcx.generics_of(local_id).own_params;
440 let extern_params = &tcx.generics_of(extern_id).own_params;
441
442 let mismatch = 'mismatch: {
443 if local_params.len() != extern_params.len() {
444 break 'mismatch true;
445 }
446 for (local_param, extern_param) in iter::zip(local_params, extern_params) {
447 if !cmp_generic_param_def(local_param, extern_param) {
448 break 'mismatch true;
449 }
450 if local_param.name != kw::SelfUpper {
453 #[expect(clippy::disallowed_methods)]
454 self.insert_extern_id(local_param.def_id.expect_local(), extern_param.def_id)?;
455 }
456 }
457 false
458 };
459 if mismatch {
460 let local_hir_generics = tcx.hir_get_generics(local_id.def_id).unwrap();
461 let span = local_hir_generics.span;
462 Err(self.emit(errors::MismatchedGenerics {
463 span,
464 extern_def: tcx.def_span(extern_id),
465 def_descr: tcx.def_descr(extern_id),
466 }))
467 } else {
468 Ok(())
469 }
470 }
471
472 #[track_caller]
473 fn malformed(&self) -> ErrorGuaranteed {
474 self.emit(errors::MalformedExternSpec::new(self.block.span))
475 }
476
477 #[track_caller]
478 fn invalid_enum_extern_spec(&self, reason: String) -> ErrorGuaranteed {
479 self.emit(errors::InvalidEnumExternSpec::new(self.block.span, reason))
480 }
481
482 #[track_caller]
483 fn item_not_in_trait_impl(
484 &self,
485 local_id: OwnerId,
486 extern_id: DefId,
487 extern_impl_id: DefId,
488 ) -> ErrorGuaranteed {
489 let tcx = self.tcx();
490 self.emit(errors::ItemNotInTraitImpl {
491 span: ident_or_def_span(tcx, local_id),
492 name: tcx.def_path_str(extern_id),
493 extern_impl_span: tcx.def_span(extern_impl_id),
494 })
495 }
496
497 fn invalid_item_in_inherent_impl(
498 &self,
499 local_id: OwnerId,
500 extern_id: DefId,
501 ) -> ErrorGuaranteed {
502 let tcx = self.tcx();
503 self.emit(errors::InvalidItemInInherentImpl {
504 span: ident_or_def_span(tcx, local_id),
505 name: tcx.def_path_str(extern_id),
506 extern_item_span: tcx.def_span(extern_id),
507 })
508 }
509
510 #[track_caller]
511 fn invalid_impl_block(&self) -> ErrorGuaranteed {
512 self.emit(errors::InvalidImplBlock { span: self.block.span })
513 }
514
515 #[track_caller]
516 fn cannot_resolve_trait_impl(&self) -> ErrorGuaranteed {
517 self.emit(errors::CannotResolveTraitImpl { span: self.block.span })
518 }
519
520 #[track_caller]
521 fn item_not_in_trait(
522 &self,
523 local_id: OwnerId,
524 extern_id: DefId,
525 extern_trait_id: DefId,
526 ) -> ErrorGuaranteed {
527 let tcx = self.tcx();
528 self.emit(errors::ItemNotInTrait {
529 span: ident_or_def_span(tcx, local_id),
530 name: tcx.def_path_str(extern_id),
531 extern_trait_span: tcx.def_span(extern_trait_id),
532 })
533 }
534
535 #[track_caller]
536 fn emit<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
537 self.inner.errors.emit(err)
538 }
539
540 fn tcx(&self) -> TyCtxt<'tcx> {
541 self.inner.tcx
542 }
543}
544
545fn cmp_generic_param_def(a: &ty::GenericParamDef, b: &ty::GenericParamDef) -> bool {
546 if a.name != b.name {
547 return false;
548 }
549 if a.index != b.index {
550 return false;
551 }
552 matches!(
553 (&a.kind, &b.kind),
554 (ty::GenericParamDefKind::Lifetime, ty::GenericParamDefKind::Lifetime)
555 | (ty::GenericParamDefKind::Type { .. }, ty::GenericParamDefKind::Type { .. })
556 | (ty::GenericParamDefKind::Const { .. }, ty::GenericParamDefKind::Const { .. })
557 )
558}
559
560fn ident_or_def_span(tcx: TyCtxt, def_id: impl Into<DefId>) -> Span {
561 let def_id = def_id.into();
562 tcx.def_ident_span(def_id)
563 .unwrap_or_else(|| tcx.def_span(def_id))
564}
565
// Diagnostics emitted while collecting extern specs. Message text comes from the Fluent
// slugs named in each `#[diag]` attribute (not shown in this file).
mod errors {
    use flux_errors::E0999;
    use flux_macros::Diagnostic;
    use rustc_span::Span;

    // The extern spec block (or its body) does not have the expected shape.
    #[derive(Diagnostic)]
    #[diag(driver_malformed_extern_spec, code = E0999)]
    pub(super) struct MalformedExternSpec {
        #[primary_span]
        span: Span,
    }

    impl MalformedExternSpec {
        // Builds the diagnostic pointing at `span`.
        pub(super) fn new(span: Span) -> Self {
            Self { span }
        }
    }

    // The local enum's variants do not line up with the extern enum's. `reason` is a
    // human-readable description of the mismatch (presumably interpolated into the message).
    #[derive(Diagnostic)]
    #[diag(driver_invalid_enum_extern_spec, code = E0999)]
    pub(super) struct InvalidEnumExternSpec {
        #[primary_span]
        span: Span,
        reason: String,
    }

    impl InvalidEnumExternSpec {
        // Builds the diagnostic pointing at `span` with the given `reason`.
        pub(super) fn new(span: Span, reason: String) -> Self {
            Self { span, reason }
        }
    }

    // The trait impl described by the dummy impl of a trait-impl spec could not be resolved.
    #[derive(Diagnostic)]
    #[diag(driver_cannot_resolve_trait_impl, code = E0999)]
    #[note]
    pub(super) struct CannotResolveTraitImpl {
        #[primary_span]
        pub span: Span,
    }

    // Methods of one impl spec resolved to different extern impls.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_impl_block, code = E0999)]
    pub(super) struct InvalidImplBlock {
        #[primary_span]
        #[label]
        pub span: Span,
    }

    // A method of a trait-impl spec calls `name`, which is not an item of the extern impl
    // at `extern_impl_span`.
    #[derive(Diagnostic)]
    #[diag(driver_item_not_in_trait_impl, code = E0999)]
    pub(super) struct ItemNotInTraitImpl {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_impl_span: Span,
    }

    // A method of an inherent-impl spec calls `name`, which is not an associated item of
    // an impl.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_item_in_inherent_impl, code = E0999)]
    pub(super) struct InvalidItemInInherentImpl {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_item_span: Span,
    }

    // A method of a trait spec calls `name`, which does not belong to the extern trait at
    // `extern_trait_span`.
    #[derive(Diagnostic)]
    #[diag(driver_item_not_in_trait, code = E0999)]
    pub(super) struct ItemNotInTrait {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_trait_span: Span,
    }

    // The target `name` of an extern spec is a local definition
    // (`ExternSpecMappingErr::IsLocal`); `local_def_span` points at it.
    #[derive(Diagnostic)]
    #[diag(driver_extern_spec_for_local_def, code = E0999)]
    pub(super) struct ExternSpecForLocalDef {
        #[primary_span]
        pub span: Span,
        #[help]
        pub local_def_span: Span,
        pub name: String,
    }

    // `name` already has an extern spec (`ExternSpecMappingErr::Dup`); `previous_span`
    // points at the earlier one.
    #[derive(Diagnostic)]
    #[diag(driver_dup_extern_spec, code = E0999)]
    pub(super) struct DupExternSpec {
        #[primary_span]
        #[label]
        pub span: Span,
        #[note]
        pub previous_span: Span,
        pub name: String,
    }

    // The local item's own generic parameters differ from those of the extern definition
    // at `extern_def`. `def_descr` describes the kind of that definition.
    #[derive(Diagnostic)]
    #[diag(driver_mismatched_generics, code = E0999)]
    #[note]
    pub(super) struct MismatchedGenerics {
        #[primary_span]
        #[label]
        pub span: Span,
        #[label(driver_extern_def_label)]
        pub extern_def: Span,
        pub def_descr: &'static str,
    }
}
679}