1use std::iter;
2
3use flux_middle::ExternSpecMappingErr;
4use flux_rustc_bridge::lowering;
5use rustc_errors::Diagnostic;
6use rustc_hir as hir;
7use rustc_hir::{
8 BodyId, OwnerId,
9 def_id::{DefId, LocalDefId},
10};
11use rustc_middle::ty::{self, TyCtxt};
12use rustc_span::{ErrorGuaranteed, Span, symbol::kw};
13
14use super::{FluxAttrs, SpecCollector};
15
16type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
17
18pub(super) struct ExternSpecCollector<'a, 'sess, 'tcx> {
19 inner: &'a mut SpecCollector<'sess, 'tcx>,
20 block: &'tcx hir::Block<'tcx>,
22}
23
24struct ExternImplItem {
25 impl_id: DefId,
26 item_id: DefId,
27}
28
29impl<'a, 'sess, 'tcx> ExternSpecCollector<'a, 'sess, 'tcx> {
30 pub(super) fn collect(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result {
31 Self::new(inner, body_id)?.run()
32 }
33
34 fn new(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result<Self> {
35 let body = inner.tcx.hir_body(body_id);
36 if let hir::ExprKind::Block(block, _) = body.value.kind {
37 Ok(Self { inner, block })
38 } else {
39 Err(inner
40 .errors
41 .emit(errors::MalformedExternSpec::new(body.value.span)))
42 }
43 }
44
45 fn run(mut self) -> Result {
46 let item = self.item_at(0)?;
47
48 let attrs = self
49 .inner
50 .parse_attrs_and_report_dups(item.owner_id.def_id)?;
51
52 match &item.kind {
53 hir::ItemKind::Fn { .. } => self.collect_extern_fn(item, attrs),
54 hir::ItemKind::Enum(_, _, enum_def) => {
55 self.collect_extern_enum(item.owner_id, enum_def, attrs)
56 }
57 hir::ItemKind::Struct(_, _, variant) => {
58 self.collect_extern_struct(item.owner_id, variant, attrs)
59 }
60 hir::ItemKind::Trait(_, _, _, _, _, bounds, items) => {
61 self.collect_extern_trait(item.owner_id, bounds, items, attrs)
62 }
63 hir::ItemKind::Impl(impl_) => self.collect_extern_impl(item.owner_id, impl_, attrs),
64 _ => Err(self.malformed()),
65 }
66 }
67
68 fn collect_extern_fn(&mut self, item: &hir::Item, attrs: FluxAttrs) -> Result {
69 self.inner.collect_fn_spec(item.owner_id, attrs)?;
70
71 let extern_id = self.extract_extern_id_from_fn(item)?;
72 self.insert_extern_id(item.owner_id.def_id, extern_id)?;
73 self.check_generics(item.owner_id, extern_id)?;
74
75 Ok(())
76 }
77
78 fn collect_extern_struct(
79 &mut self,
80 struct_id: OwnerId,
81 variant: &hir::VariantData,
82 attrs: FluxAttrs,
83 ) -> Result {
84 let dummy_struct = self.item_at(1)?;
85 self.inner.specs.insert_dummy(dummy_struct.owner_id.def_id);
86
87 let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
88 self.insert_extern_id(struct_id.def_id, extern_id)?;
89 self.check_generics(struct_id, extern_id)?;
90
91 if let Some(ctor_id) = variant.ctor_def_id() {
92 self.inner.specs.insert_dummy(ctor_id);
93 }
94
95 self.inner.collect_struct_def(struct_id, attrs, variant)?;
96
97 Ok(())
98 }
99
100 fn collect_extern_enum(
101 &mut self,
102 enum_id: OwnerId,
103 enum_def: &hir::EnumDef,
104 attrs: FluxAttrs,
105 ) -> Result {
106 let dummy_struct = self.item_at(1)?;
107 self.inner.specs.insert_dummy(dummy_struct.owner_id.def_id);
108
109 let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
110 self.insert_extern_id(enum_id.def_id, extern_id)?;
111 self.check_generics(enum_id, extern_id)?;
112
113 self.inner.collect_enum_def(enum_id, attrs, enum_def)?;
114
115 let extern_enum_def = self.tcx().adt_def(extern_id);
118
119 let extern_variants = extern_enum_def.variants();
121 let enum_variants = enum_def.variants;
122 let extern_len = extern_variants.len();
123 let enum_len = enum_variants.len();
124 if extern_len != enum_len {
125 let reason = format!("expected {extern_len:?} variants but only have {enum_len:?}");
126 return Err(self.invalid_enum_extern_spec(reason));
127 }
128 for (extern_variant, variant) in extern_enum_def.variants().iter().zip(enum_def.variants) {
129 if let Some(extern_ctor) = extern_variant.ctor_def_id()
130 && let Some(ctor) = variant.data.ctor_def_id()
131 && self.tcx().def_kind(extern_ctor) == self.tcx().def_kind(ctor)
132 {
133 self.insert_extern_id(ctor, extern_ctor)?;
134 } else {
135 let reason = format!(
136 "extern variant `{}` incompatible with specified `{}`",
137 extern_variant.ident(self.tcx()),
138 rustc_hir_pretty::id_to_string(&self.tcx(), variant.hir_id)
139 );
140 return Err(self.invalid_enum_extern_spec(reason));
141 }
142 }
143 Ok(())
144 }
145
146 fn collect_extern_impl(
147 &mut self,
148 impl_id: OwnerId,
149 impl_: &hir::Impl,
150 attrs: FluxAttrs,
151 ) -> Result {
152 self.inner.collect_impl(impl_id, attrs)?;
153
154 let dummy_item = self.item_at(1)?;
155 self.inner.specs.insert_dummy(dummy_item.owner_id.def_id);
156
157 let mut impl_of_trait = None;
159 if let hir::ItemKind::Impl(dummy_impl) = &dummy_item.kind {
160 impl_of_trait =
161 Some(self.extract_extern_id_from_impl(dummy_item.owner_id, dummy_impl)?);
162
163 self.inner
164 .specs
165 .insert_dummy(self.item_at(2)?.owner_id.def_id);
166 }
167
168 let mut extern_impl_id = impl_of_trait;
169 for item_id in impl_.items {
170 let item = self.tcx().hir_impl_item(*item_id);
171 let extern_item = if let hir::ImplItemKind::Fn { .. } = item.kind {
172 let attrs = self
173 .inner
174 .parse_attrs_and_report_dups(item_id.owner_id.def_id)?;
175 self.collect_extern_impl_fn(impl_of_trait, item, attrs)?
176 } else {
177 continue;
178 };
179
180 if *extern_impl_id.get_or_insert(extern_item.impl_id) != extern_item.impl_id {
181 return Err(self.invalid_impl_block());
182 }
183 }
184
185 if let Some(extern_impl_id) = extern_impl_id {
186 self.check_generics(impl_id, extern_impl_id)?;
187 self.insert_extern_id(impl_id.def_id, extern_impl_id)?;
188 }
189
190 Ok(())
191 }
192
193 fn collect_extern_impl_fn(
194 &mut self,
195 impl_of_trait: Option<DefId>,
196 item: &hir::ImplItem,
197 attrs: FluxAttrs,
198 ) -> Result<ExternImplItem> {
199 self.inner.collect_fn_spec(item.owner_id, attrs)?;
200
201 let extern_impl_item = self.extract_extern_id_from_impl_fn(impl_of_trait, item)?;
202 self.insert_extern_id(item.owner_id.def_id, extern_impl_item.item_id)?;
203 self.check_generics(item.owner_id, extern_impl_item.item_id)?;
204
205 Ok(extern_impl_item)
206 }
207
208 fn collect_extern_trait(
209 &mut self,
210 trait_id: OwnerId,
211 bounds: &hir::GenericBounds,
212 items: &[hir::TraitItemId],
213 attrs: FluxAttrs,
214 ) -> Result {
215 self.inner.collect_trait(trait_id, attrs)?;
216
217 let extern_trait_id = self.extract_extern_id_from_trait(bounds)?;
218 self.insert_extern_id(trait_id.def_id, extern_trait_id)?;
219 self.check_generics(trait_id, extern_trait_id)?;
220
221 for item_id in items {
222 let item = self.tcx().hir_trait_item(*item_id);
223 if let hir::TraitItemKind::Fn { .. } = item.kind {
224 let attrs = self
225 .inner
226 .parse_attrs_and_report_dups(item.owner_id.def_id)?;
227 self.collect_extern_trait_fn(extern_trait_id, item, attrs)?;
228 } else {
229 continue;
230 }
231 }
232
233 Ok(())
234 }
235
236 fn collect_extern_trait_fn(
237 &mut self,
238 extern_trait_id: DefId,
239 item: &hir::TraitItem,
240 attrs: FluxAttrs,
241 ) -> Result {
242 let item_id = item.owner_id;
243 self.inner.collect_fn_spec(item_id, attrs)?;
244
245 let extern_fn_id = self.extract_extern_id_from_trait_fn(extern_trait_id, item)?;
246 self.insert_extern_id(item.owner_id.def_id, extern_fn_id)?;
247 self.check_generics(item_id, extern_fn_id)?;
248
249 Ok(())
250 }
251
252 fn extract_extern_id_from_struct(&self, item: &hir::Item) -> Result<DefId> {
253 if let hir::ItemKind::Struct(_, _, data) = item.kind
254 && let Some(extern_field) = data.fields().last()
255 && let ty = self.tcx().type_of(extern_field.def_id)
256 && let Some(adt_def) = ty.skip_binder().ty_adt_def()
257 {
258 Ok(adt_def.did())
259 } else {
260 Err(self.malformed())
261 }
262 }
263
264 fn extract_extern_id_from_fn(&self, item: &hir::Item) -> Result<DefId> {
265 if let hir::ItemKind::Fn { body, .. } = item.kind {
266 self.extract_callee_from_body(body)
267 } else {
268 Err(self.malformed())
269 }
270 }
271
272 fn extract_extern_id_from_impl_fn(
273 &self,
274 impl_of_trait: Option<DefId>,
275 item: &hir::ImplItem,
276 ) -> Result<ExternImplItem> {
277 if let hir::ImplItemKind::Fn(_, body_id) = item.kind {
278 let callee_id = self.extract_callee_from_body(body_id)?;
279 if let Some(extern_impl_id) = impl_of_trait {
280 let map = self.tcx().impl_item_implementor_ids(extern_impl_id);
281 if let Some(extern_item_id) = map.get(&callee_id) {
282 Ok(ExternImplItem { impl_id: extern_impl_id, item_id: *extern_item_id })
283 } else {
284 Err(self.item_not_in_trait_impl(item.owner_id, callee_id, extern_impl_id))
285 }
286 } else {
287 let opt_extern_impl_id = self.tcx().impl_of_assoc(callee_id);
288 if let Some(extern_impl_id) = opt_extern_impl_id {
289 debug_assert!(self.tcx().trait_id_of_impl(extern_impl_id).is_none());
290 Ok(ExternImplItem { impl_id: extern_impl_id, item_id: callee_id })
291 } else {
292 Err(self.invalid_item_in_inherent_impl(item.owner_id, callee_id))
293 }
294 }
295 } else {
296 Err(self.malformed())
297 }
298 }
299
300 fn extract_extern_id_from_trait(&self, bounds: &hir::GenericBounds) -> Result<DefId> {
301 if let Some(bound) = bounds.first()
302 && let Some(trait_ref) = bound.trait_ref()
303 && let Some(trait_id) = trait_ref.trait_def_id()
304 {
305 Ok(trait_id)
306 } else {
307 Err(self.malformed())
308 }
309 }
310
311 fn extract_extern_id_from_trait_fn(
312 &self,
313 trait_id: DefId,
314 item: &hir::TraitItem,
315 ) -> Result<DefId> {
316 if let hir::TraitItemKind::Fn(_, trait_fn) = item.kind
317 && let hir::TraitFn::Provided(body_id) = trait_fn
318 {
319 let callee_id = self.extract_callee_from_body(body_id)?;
320 if let Some(callee_trait_id) = self.tcx().trait_of_assoc(callee_id)
321 && trait_id == callee_trait_id
322 {
323 Ok(callee_id)
324 } else {
325 Err(self.item_not_in_trait(item.owner_id, callee_id, trait_id))
328 }
329 } else {
330 Err(self.malformed())
331 }
332 }
333
334 fn extract_extern_id_from_impl(&self, impl_id: OwnerId, impl_: &hir::Impl) -> Result<DefId> {
335 if let Some(item_id) = impl_.items.first()
336 && let hir::ImplItemKind::Fn { .. } = self.tcx().hir_impl_item(*item_id).kind
337 && let Some((clause, _)) = self
338 .tcx()
339 .predicates_of(item_id.owner_id.def_id)
340 .predicates
341 .first()
342 && let Some(poly_trait_pred) = clause.as_trait_clause()
343 && let Some(trait_pred) = poly_trait_pred.no_bound_vars()
344 {
345 let trait_ref = trait_pred.trait_ref;
346 lowering::resolve_trait_ref_impl_id(self.tcx(), impl_id.to_def_id(), trait_ref)
347 .map(|(impl_id, _)| impl_id)
348 .ok_or_else(|| self.cannot_resolve_trait_impl())
349 } else {
350 Err(self.malformed())
351 }
352 }
353
354 fn extract_callee_from_body(&self, body_id: hir::BodyId) -> Result<DefId> {
355 let owner = self.tcx().hir_body_owner_def_id(body_id);
356 let typeck = self.tcx().typeck(owner);
357 if let hir::ExprKind::Block(b, _) = self.tcx().hir_body(body_id).value.kind
358 && let Some(e) = b.expr
359 && let hir::ExprKind::Call(callee, _) = e.kind
360 && let rustc_middle::ty::FnDef(callee_id, _) = typeck.node_type(callee.hir_id).kind()
361 {
362 Ok(*callee_id)
363 } else {
364 Err(self.malformed())
365 }
366 }
367
368 #[track_caller]
370 fn item_at(&self, i: usize) -> Result<&'tcx hir::Item<'tcx>> {
371 let stmts = self.block.stmts;
372 let index = stmts
373 .len()
374 .checked_sub(i + 1)
375 .ok_or_else(|| self.malformed())?;
376 let st = stmts.get(index).ok_or_else(|| self.malformed())?;
377 if let hir::StmtKind::Item(item_id) = st.kind {
378 Ok(self.tcx().hir_item(item_id))
379 } else {
380 Err(self.malformed())
381 }
382 }
383
384 fn insert_extern_id(&mut self, local_id: LocalDefId, extern_id: DefId) -> Result {
385 self.inner
386 .specs
387 .insert_extern_spec_id_mapping(local_id, extern_id)
388 .map_err(|err| {
389 match err {
390 ExternSpecMappingErr::IsLocal(extern_id_local) => {
391 self.emit(errors::ExternSpecForLocalDef {
392 span: ident_or_def_span(self.tcx(), local_id),
393 local_def_span: ident_or_def_span(self.tcx(), extern_id_local),
394 name: self.tcx().def_path_str(extern_id),
395 })
396 }
397 ExternSpecMappingErr::Dup(previous_extern_spec) => {
398 self.emit(errors::DupExternSpec {
399 span: ident_or_def_span(self.tcx(), local_id),
400 previous_span: ident_or_def_span(self.tcx(), previous_extern_spec),
401 name: self.tcx().def_path_str(extern_id),
402 })
403 }
404 }
405 })
406 }
407
408 fn check_generics(&mut self, local_id: OwnerId, extern_id: DefId) -> Result {
409 let tcx = self.tcx();
410 let local_params = &tcx.generics_of(local_id).own_params;
411 let extern_params = &tcx.generics_of(extern_id).own_params;
412
413 let mismatch = 'mismatch: {
414 if local_params.len() != extern_params.len() {
415 break 'mismatch true;
416 }
417 for (local_param, extern_param) in iter::zip(local_params, extern_params) {
418 if !cmp_generic_param_def(local_param, extern_param) {
419 break 'mismatch true;
420 }
421 if local_param.name != kw::SelfUpper {
424 #[expect(clippy::disallowed_methods)]
425 self.insert_extern_id(local_param.def_id.expect_local(), extern_param.def_id)?;
426 }
427 }
428 false
429 };
430 if mismatch {
431 let local_hir_generics = tcx.hir_get_generics(local_id.def_id).unwrap();
432 let span = local_hir_generics.span;
433 Err(self.emit(errors::MismatchedGenerics {
434 span,
435 extern_def: tcx.def_span(extern_id),
436 def_descr: tcx.def_descr(extern_id),
437 }))
438 } else {
439 Ok(())
440 }
441 }
442
443 #[track_caller]
444 fn malformed(&self) -> ErrorGuaranteed {
445 self.emit(errors::MalformedExternSpec::new(self.block.span))
446 }
447
448 #[track_caller]
449 fn invalid_enum_extern_spec(&self, reason: String) -> ErrorGuaranteed {
450 self.emit(errors::InvalidEnumExternSpec::new(self.block.span, reason))
451 }
452
453 #[track_caller]
454 fn item_not_in_trait_impl(
455 &self,
456 local_id: OwnerId,
457 extern_id: DefId,
458 extern_impl_id: DefId,
459 ) -> ErrorGuaranteed {
460 let tcx = self.tcx();
461 self.emit(errors::ItemNotInTraitImpl {
462 span: ident_or_def_span(tcx, local_id),
463 name: tcx.def_path_str(extern_id),
464 extern_impl_span: tcx.def_span(extern_impl_id),
465 })
466 }
467
468 fn invalid_item_in_inherent_impl(
469 &self,
470 local_id: OwnerId,
471 extern_id: DefId,
472 ) -> ErrorGuaranteed {
473 let tcx = self.tcx();
474 self.emit(errors::InvalidItemInInherentImpl {
475 span: ident_or_def_span(tcx, local_id),
476 name: tcx.def_path_str(extern_id),
477 extern_item_span: tcx.def_span(extern_id),
478 })
479 }
480
481 #[track_caller]
482 fn invalid_impl_block(&self) -> ErrorGuaranteed {
483 self.emit(errors::InvalidImplBlock { span: self.block.span })
484 }
485
486 #[track_caller]
487 fn cannot_resolve_trait_impl(&self) -> ErrorGuaranteed {
488 self.emit(errors::CannotResolveTraitImpl { span: self.block.span })
489 }
490
491 #[track_caller]
492 fn item_not_in_trait(
493 &self,
494 local_id: OwnerId,
495 extern_id: DefId,
496 extern_trait_id: DefId,
497 ) -> ErrorGuaranteed {
498 let tcx = self.tcx();
499 self.emit(errors::ItemNotInTrait {
500 span: ident_or_def_span(tcx, local_id),
501 name: tcx.def_path_str(extern_id),
502 extern_trait_span: tcx.def_span(extern_trait_id),
503 })
504 }
505
506 #[track_caller]
507 fn emit<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
508 self.inner.errors.emit(err)
509 }
510
511 fn tcx(&self) -> TyCtxt<'tcx> {
512 self.inner.tcx
513 }
514}
515
516fn cmp_generic_param_def(a: &ty::GenericParamDef, b: &ty::GenericParamDef) -> bool {
517 if a.name != b.name {
518 return false;
519 }
520 if a.index != b.index {
521 return false;
522 }
523 matches!(
524 (&a.kind, &b.kind),
525 (ty::GenericParamDefKind::Lifetime, ty::GenericParamDefKind::Lifetime)
526 | (ty::GenericParamDefKind::Type { .. }, ty::GenericParamDefKind::Type { .. })
527 | (ty::GenericParamDefKind::Const { .. }, ty::GenericParamDefKind::Const { .. })
528 )
529}
530
531fn ident_or_def_span(tcx: TyCtxt, def_id: impl Into<DefId>) -> Span {
532 let def_id = def_id.into();
533 tcx.def_ident_span(def_id)
534 .unwrap_or_else(|| tcx.def_span(def_id))
535}
536
// Diagnostics reported while collecting extern specs. Every diagnostic uses error code `E0999`,
// and its message is identified by the slug in its `#[diag(...)]` attribute.
mod errors {
    use flux_errors::E0999;
    use flux_macros::Diagnostic;
    use rustc_span::Span;

    // The extern spec body (or one of its items) does not have the expected shape.
    #[derive(Diagnostic)]
    #[diag(driver_malformed_extern_spec, code = E0999)]
    pub(super) struct MalformedExternSpec {
        #[primary_span]
        span: Span,
    }

    impl MalformedExternSpec {
        pub(super) fn new(span: Span) -> Self {
            Self { span }
        }
    }

    // An extern enum spec does not line up with the extern enum; `reason` explains the mismatch.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_enum_extern_spec, code = E0999)]
    pub(super) struct InvalidEnumExternSpec {
        #[primary_span]
        span: Span,
        reason: String,
    }

    impl InvalidEnumExternSpec {
        pub(super) fn new(span: Span, reason: String) -> Self {
            Self { span, reason }
        }
    }

    // The extern trait impl referred to by a trait impl spec could not be resolved.
    #[derive(Diagnostic)]
    #[diag(driver_cannot_resolve_trait_impl, code = E0999)]
    #[note]
    pub(super) struct CannotResolveTraitImpl {
        #[primary_span]
        pub span: Span,
    }

    // The methods of an impl spec refer to more than one extern impl block.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_impl_block, code = E0999)]
    pub(super) struct InvalidImplBlock {
        #[primary_span]
        #[label]
        pub span: Span,
    }

    // A method in a trait impl spec refers to `name`, which is not an item of the extern impl
    // located at `extern_impl_span`.
    #[derive(Diagnostic)]
    #[diag(driver_item_not_in_trait_impl, code = E0999)]
    pub(super) struct ItemNotInTraitImpl {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_impl_span: Span,
    }

    // A method in an inherent impl spec refers to `name`, which is not an associated item of any
    // impl.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_item_in_inherent_impl, code = E0999)]
    pub(super) struct InvalidItemInInherentImpl {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_item_span: Span,
    }

    // A method in a trait spec refers to `name`, which is not an item of the extern trait
    // located at `extern_trait_span`.
    #[derive(Diagnostic)]
    #[diag(driver_item_not_in_trait, code = E0999)]
    pub(super) struct ItemNotInTrait {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        #[note]
        pub extern_trait_span: Span,
    }

    // An extern spec targets `name`, which is defined locally (at `local_def_span`).
    #[derive(Diagnostic)]
    #[diag(driver_extern_spec_for_local_def, code = E0999)]
    pub(super) struct ExternSpecForLocalDef {
        #[primary_span]
        pub span: Span,
        #[note]
        pub local_def_span: Span,
        pub name: String,
    }

    // `name` already has an extern spec, declared at `previous_span`.
    #[derive(Diagnostic)]
    #[diag(driver_dup_extern_spec, code = E0999)]
    pub(super) struct DupExternSpec {
        #[primary_span]
        #[label]
        pub span: Span,
        #[note]
        pub previous_span: Span,
        pub name: String,
    }

    // The generic parameters of a local spec item differ from those of the extern definition
    // (at `extern_def`); `def_descr` describes what kind of definition that is.
    #[derive(Diagnostic)]
    #[diag(driver_mismatched_generics, code = E0999)]
    #[note]
    pub(super) struct MismatchedGenerics {
        #[primary_span]
        #[label]
        pub span: Span,
        #[label(driver_extern_def_label)]
        pub extern_def: Span,
        pub def_descr: &'static str,
    }
}