1use std::iter;
2
3use flux_middle::ExternSpecMappingErr;
4use flux_rustc_bridge::lowering;
5use rustc_errors::Diagnostic;
6use rustc_hir as hir;
7use rustc_hir::{
8 BodyId, OwnerId,
9 def_id::{DefId, LocalDefId},
10};
11use rustc_middle::ty::{self, TyCtxt};
12use rustc_span::{ErrorGuaranteed, Span, symbol::kw};
13
14use super::{FluxAttrs, SpecCollector};
15
/// Result type used throughout this module.
///
/// The error side is an `ErrorGuaranteed`: every `Err` produced here comes from emitting a
/// diagnostic first. The success type defaults to `()`.
type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
17
/// Collects a single extern spec from the body of its (generated) wrapper item.
///
/// The body is expected to be a block whose trailing statements are items. The last one is the
/// user-written spec item. The statements before it (when present) are dummy helper items that
/// are only used to locate the extern definition.
pub(super) struct ExternSpecCollector<'a, 'sess, 'tcx> {
    /// Enclosing collector: specs and id mappings are registered with it, and diagnostics are
    /// emitted through its `errors`.
    inner: &'a mut SpecCollector<'sess, 'tcx>,
    /// The block forming the extern spec's body; its statements are inspected by position.
    block: &'tcx hir::Block<'tcx>,
}
23
/// The extern counterpart of a method written inside a local extern-spec impl block.
struct ExternImplItem {
    /// `DefId` of the extern impl block that contains `item_id`.
    impl_id: DefId,
    /// `DefId` of the extern associated item itself.
    item_id: DefId,
}
28
29impl<'a, 'sess, 'tcx> ExternSpecCollector<'a, 'sess, 'tcx> {
30 pub(super) fn collect(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result {
31 Self::new(inner, body_id)?.run()
32 }
33
34 fn new(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result<Self> {
35 let body = inner.tcx.hir_body(body_id);
36 if let hir::ExprKind::Block(block, _) = body.value.kind {
37 Ok(Self { inner, block })
38 } else {
39 Err(inner
40 .errors
41 .emit(errors::MalformedExternSpec::new(body.value.span)))
42 }
43 }
44
45 fn run(mut self) -> Result {
46 let item = self.item_at(0)?;
47
48 let attrs = self
49 .inner
50 .parse_attrs_and_report_dups(item.owner_id.def_id)?;
51
52 match &item.kind {
53 hir::ItemKind::Fn { .. } => self.collect_extern_fn(item, attrs),
54 hir::ItemKind::Enum(enum_def, _) => {
55 self.collect_extern_enum(item.owner_id, enum_def, attrs)
56 }
57 hir::ItemKind::Struct(variant, _) => {
58 self.collect_extern_struct(item.owner_id, variant, attrs)
59 }
60 hir::ItemKind::Trait(_, _, _, bounds, items) => {
61 self.collect_extern_trait(item.owner_id, bounds, items, attrs)
62 }
63 hir::ItemKind::Impl(impl_) => self.collect_extern_impl(item.owner_id, impl_, attrs),
64 _ => Err(self.malformed()),
65 }
66 }
67
68 fn collect_extern_fn(&mut self, item: &hir::Item, attrs: FluxAttrs) -> Result {
69 self.inner.collect_fn_spec(item.owner_id, attrs)?;
70
71 let extern_id = self.extract_extern_id_from_fn(item)?;
72 self.insert_extern_id(item.owner_id.def_id, extern_id)?;
73 self.check_generics(item.owner_id, extern_id)?;
74
75 Ok(())
76 }
77
78 fn collect_extern_struct(
79 &mut self,
80 struct_id: OwnerId,
81 variant: &hir::VariantData,
82 attrs: FluxAttrs,
83 ) -> Result {
84 let dummy_struct = self.item_at(1)?;
85 self.inner.specs.insert_dummy(dummy_struct.owner_id);
86
87 let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
88 self.insert_extern_id(struct_id.def_id, extern_id)?;
89 self.check_generics(struct_id, extern_id)?;
90
91 self.inner.collect_struct_def(struct_id, attrs, variant)?;
92
93 Ok(())
94 }
95
96 fn collect_extern_enum(
97 &mut self,
98 enum_id: OwnerId,
99 enum_def: &hir::EnumDef,
100 attrs: FluxAttrs,
101 ) -> Result {
102 let dummy_struct = self.item_at(1)?;
103 self.inner.specs.insert_dummy(dummy_struct.owner_id);
104
105 let extern_id = self.extract_extern_id_from_struct(dummy_struct).unwrap();
106 self.insert_extern_id(enum_id.def_id, extern_id)?;
107 self.check_generics(enum_id, extern_id)?;
108
109 self.inner.collect_enum_def(enum_id, attrs, enum_def)?;
110
111 let extern_enum_def = self.tcx().adt_def(extern_id);
114
115 let extern_variants = extern_enum_def.variants();
117 let enum_variants = enum_def.variants;
118 let extern_len = extern_variants.len();
119 let enum_len = enum_variants.len();
120 if extern_len != enum_len {
121 let reason = format!("expected {extern_len:?} variants but only have {enum_len:?}");
122 return Err(self.invalid_enum_extern_spec(reason));
123 }
124 for (extern_variant, variant) in extern_enum_def.variants().iter().zip(enum_def.variants) {
125 if let Some(extern_ctor) = extern_variant.ctor_def_id()
126 && let Some(ctor) = variant.data.ctor_def_id()
127 && self.tcx().def_kind(extern_ctor) == self.tcx().def_kind(ctor)
128 {
129 self.insert_extern_id(ctor, extern_ctor)?;
130 } else {
131 let reason = format!(
132 "extern variant {extern_variant:?} incompatible with specified {variant:?}"
133 );
134 return Err(self.invalid_enum_extern_spec(reason));
135 }
136 }
137 Ok(())
138 }
139
140 fn collect_extern_impl(
141 &mut self,
142 impl_id: OwnerId,
143 impl_: &hir::Impl,
144 attrs: FluxAttrs,
145 ) -> Result {
146 self.inner.collect_impl(impl_id, attrs)?;
147
148 let dummy_item = self.item_at(1)?;
149 self.inner.specs.insert_dummy(dummy_item.owner_id);
150
151 let mut impl_of_trait = None;
153 if let hir::ItemKind::Impl(dummy_impl) = dummy_item.kind {
154 impl_of_trait =
155 Some(self.extract_extern_id_from_impl(dummy_item.owner_id, dummy_impl)?);
156
157 self.inner.specs.insert_dummy(self.item_at(2)?.owner_id);
158 }
159
160 let mut extern_impl_id = impl_of_trait;
161 for item in impl_.items {
162 let item_id = item.id.owner_id.def_id;
163 let extern_item = if let hir::AssocItemKind::Fn { .. } = item.kind {
164 let attrs = self.inner.parse_attrs_and_report_dups(item_id)?;
165 self.collect_extern_impl_fn(impl_of_trait, item, attrs)?
166 } else {
167 continue;
168 };
169
170 if *extern_impl_id.get_or_insert(extern_item.impl_id) != extern_item.impl_id {
171 return Err(self.invalid_impl_block());
172 }
173 }
174
175 if let Some(extern_impl_id) = extern_impl_id {
176 self.check_generics(impl_id, extern_impl_id)?;
177 self.insert_extern_id(impl_id.def_id, extern_impl_id)?;
178 }
179
180 Ok(())
181 }
182
183 fn collect_extern_impl_fn(
184 &mut self,
185 impl_of_trait: Option<DefId>,
186 item: &hir::ImplItemRef,
187 attrs: FluxAttrs,
188 ) -> Result<ExternImplItem> {
189 let item_id = item.id.owner_id;
190 self.inner.collect_fn_spec(item_id, attrs)?;
191
192 let extern_impl_item = self.extract_extern_id_from_impl_fn(impl_of_trait, item)?;
193 self.insert_extern_id(item_id.def_id, extern_impl_item.item_id)?;
194 self.check_generics(item_id, extern_impl_item.item_id)?;
195
196 Ok(extern_impl_item)
197 }
198
199 fn collect_extern_trait(
200 &mut self,
201 trait_id: OwnerId,
202 bounds: &hir::GenericBounds,
203 items: &[hir::TraitItemRef],
204 attrs: FluxAttrs,
205 ) -> Result {
206 self.inner.collect_trait(trait_id, attrs)?;
207
208 let extern_trait_id = self.extract_extern_id_from_trait(bounds)?;
209 self.insert_extern_id(trait_id.def_id, extern_trait_id)?;
210 self.check_generics(trait_id, extern_trait_id)?;
211
212 for item in items {
213 let item_id = item.id.owner_id.def_id;
214 if let hir::AssocItemKind::Fn { .. } = item.kind {
215 let attrs = self.inner.parse_attrs_and_report_dups(item_id)?;
216 self.collect_extern_trait_fn(extern_trait_id, item, attrs)?;
217 } else {
218 continue;
219 }
220 }
221
222 Ok(())
223 }
224
225 fn collect_extern_trait_fn(
226 &mut self,
227 extern_trait_id: DefId,
228 item: &hir::TraitItemRef,
229 attrs: FluxAttrs,
230 ) -> Result {
231 let item_id = item.id.owner_id;
232 self.inner.collect_fn_spec(item_id, attrs)?;
233
234 let extern_fn_id = self.extract_extern_id_from_trait_fn(extern_trait_id, item)?;
235 self.insert_extern_id(item.id.owner_id.def_id, extern_fn_id)?;
236 self.check_generics(item_id, extern_fn_id)?;
237
238 Ok(())
239 }
240
241 fn extract_extern_id_from_struct(&self, item: &hir::Item) -> Result<DefId> {
242 if let hir::ItemKind::Struct(data, ..) = item.kind
243 && let Some(extern_field) = data.fields().last()
244 && let ty = self.tcx().type_of(extern_field.def_id)
245 && let Some(adt_def) = ty.skip_binder().ty_adt_def()
246 {
247 Ok(adt_def.did())
248 } else {
249 Err(self.malformed())
250 }
251 }
252
253 fn extract_extern_id_from_fn(&self, item: &hir::Item) -> Result<DefId> {
254 if let hir::ItemKind::Fn { body, .. } = item.kind {
255 self.extract_callee_from_body(body)
256 } else {
257 Err(self.malformed())
258 }
259 }
260
261 fn extract_extern_id_from_impl_fn(
262 &self,
263 impl_of_trait: Option<DefId>,
264 item: &hir::ImplItemRef,
265 ) -> Result<ExternImplItem> {
266 if let hir::ImplItemKind::Fn(_, body_id) = self.tcx().hir_impl_item(item.id).kind {
267 let callee_id = self.extract_callee_from_body(body_id)?;
268 if let Some(extern_impl_id) = impl_of_trait {
269 let map = self.tcx().impl_item_implementor_ids(extern_impl_id);
270 if let Some(extern_item_id) = map.get(&callee_id) {
271 Ok(ExternImplItem { impl_id: extern_impl_id, item_id: *extern_item_id })
272 } else {
273 Err(self.item_not_in_trait_impl(item.id.owner_id, callee_id, extern_impl_id))
274 }
275 } else {
276 let opt_extern_impl_id = self.tcx().impl_of_method(callee_id);
277 if let Some(extern_impl_id) = opt_extern_impl_id {
278 debug_assert!(self.tcx().trait_id_of_impl(extern_impl_id).is_none());
279 Ok(ExternImplItem { impl_id: extern_impl_id, item_id: callee_id })
280 } else {
281 Err(self.invalid_item_in_inherent_impl(item.id.owner_id, callee_id))
282 }
283 }
284 } else {
285 Err(self.malformed())
286 }
287 }
288
289 fn extract_extern_id_from_trait(&self, bounds: &hir::GenericBounds) -> Result<DefId> {
290 if let Some(bound) = bounds.first()
291 && let Some(trait_ref) = bound.trait_ref()
292 && let Some(trait_id) = trait_ref.trait_def_id()
293 {
294 Ok(trait_id)
295 } else {
296 Err(self.malformed())
297 }
298 }
299
300 fn extract_extern_id_from_trait_fn(
301 &self,
302 trait_id: DefId,
303 item: &hir::TraitItemRef,
304 ) -> Result<DefId> {
305 if let hir::TraitItemKind::Fn(_, trait_fn) = self.tcx().hir_trait_item(item.id).kind
306 && let hir::TraitFn::Provided(body_id) = trait_fn
307 {
308 let callee_id = self.extract_callee_from_body(body_id)?;
309 if let Some(callee_trait_id) = self.tcx().trait_of_item(callee_id)
310 && trait_id == callee_trait_id
311 {
312 Ok(callee_id)
313 } else {
314 Err(self.item_not_in_trait(item.id.owner_id, callee_id, trait_id))
317 }
318 } else {
319 Err(self.malformed())
320 }
321 }
322
323 fn extract_extern_id_from_impl(&self, impl_id: OwnerId, impl_: &hir::Impl) -> Result<DefId> {
324 if let Some(item) = impl_.items.first()
325 && let hir::AssocItemKind::Fn { .. } = item.kind
326 && let Some((clause, _)) = self
327 .tcx()
328 .predicates_of(item.id.owner_id.def_id)
329 .predicates
330 .first()
331 && let Some(poly_trait_pred) = clause.as_trait_clause()
332 && let Some(trait_pred) = poly_trait_pred.no_bound_vars()
333 {
334 let trait_ref = trait_pred.trait_ref;
335 lowering::resolve_trait_ref_impl_id(self.tcx(), impl_id.to_def_id(), trait_ref)
336 .map(|(impl_id, _)| impl_id)
337 .ok_or_else(|| self.cannot_resolve_trait_impl())
338 } else {
339 Err(self.malformed())
340 }
341 }
342
343 fn extract_callee_from_body(&self, body_id: hir::BodyId) -> Result<DefId> {
344 let owner = self.tcx().hir_body_owner_def_id(body_id);
345 let typeck = self.tcx().typeck(owner);
346 if let hir::ExprKind::Block(b, _) = self.tcx().hir_body(body_id).value.kind
347 && let Some(e) = b.expr
348 && let hir::ExprKind::Call(callee, _) = e.kind
349 && let rustc_middle::ty::FnDef(callee_id, _) = typeck.node_type(callee.hir_id).kind()
350 {
351 Ok(*callee_id)
352 } else {
353 Err(self.malformed())
354 }
355 }
356
357 #[track_caller]
359 fn item_at(&self, i: usize) -> Result<&'tcx hir::Item<'tcx>> {
360 let stmts = self.block.stmts;
361 let index = stmts
362 .len()
363 .checked_sub(i + 1)
364 .ok_or_else(|| self.malformed())?;
365 let st = stmts.get(index).ok_or_else(|| self.malformed())?;
366 if let hir::StmtKind::Item(item_id) = st.kind {
367 Ok(self.tcx().hir_item(item_id))
368 } else {
369 Err(self.malformed())
370 }
371 }
372
373 fn insert_extern_id(&mut self, local_id: LocalDefId, extern_id: DefId) -> Result {
374 self.inner
375 .specs
376 .insert_extern_spec_id_mapping(local_id, extern_id)
377 .map_err(|err| {
378 match err {
379 ExternSpecMappingErr::IsLocal(extern_id_local) => {
380 self.emit(errors::ExternSpecForLocalDef {
381 span: ident_or_def_span(self.tcx(), local_id),
382 local_def_span: ident_or_def_span(self.tcx(), extern_id_local),
383 name: self.tcx().def_path_str(extern_id),
384 })
385 }
386 ExternSpecMappingErr::Dup(previous_extern_spec) => {
387 self.emit(errors::DupExternSpec {
388 span: ident_or_def_span(self.tcx(), local_id),
389 previous_span: ident_or_def_span(self.tcx(), previous_extern_spec),
390 name: self.tcx().def_path_str(extern_id),
391 })
392 }
393 }
394 })
395 }
396
397 fn check_generics(&mut self, local_id: OwnerId, extern_id: DefId) -> Result {
398 let tcx = self.tcx();
399 let local_params = &tcx.generics_of(local_id).own_params;
400 let extern_params = &tcx.generics_of(extern_id).own_params;
401
402 let mismatch = 'mismatch: {
403 if local_params.len() != extern_params.len() {
404 break 'mismatch true;
405 }
406 for (local_param, extern_param) in iter::zip(local_params, extern_params) {
407 if !cmp_generic_param_def(local_param, extern_param) {
408 break 'mismatch true;
409 }
410 if local_param.name != kw::SelfUpper {
413 #[expect(clippy::disallowed_methods)]
414 self.insert_extern_id(local_param.def_id.expect_local(), extern_param.def_id)?;
415 }
416 }
417 false
418 };
419 if mismatch {
420 let local_hir_generics = tcx.hir_get_generics(local_id.def_id).unwrap();
421 let span = local_hir_generics.span;
422 Err(self.emit(errors::MismatchedGenerics {
423 span,
424 extern_def: tcx.def_span(extern_id),
425 def_descr: tcx.def_descr(extern_id),
426 }))
427 } else {
428 Ok(())
429 }
430 }
431
432 #[track_caller]
433 fn malformed(&self) -> ErrorGuaranteed {
434 self.emit(errors::MalformedExternSpec::new(self.block.span))
435 }
436
437 #[track_caller]
438 fn invalid_enum_extern_spec(&self, reason: String) -> ErrorGuaranteed {
439 self.emit(errors::InvalidEnumExternSpec::new(self.block.span, reason))
440 }
441
442 #[track_caller]
443 fn item_not_in_trait_impl(
444 &self,
445 local_id: OwnerId,
446 extern_id: DefId,
447 extern_impl_id: DefId,
448 ) -> ErrorGuaranteed {
449 let tcx = self.tcx();
450 self.emit(errors::ItemNotInTraitImpl {
451 span: ident_or_def_span(tcx, local_id),
452 name: tcx.def_path_str(extern_id),
453 extern_impl_span: tcx.def_span(extern_impl_id),
454 })
455 }
456
457 fn invalid_item_in_inherent_impl(
458 &self,
459 local_id: OwnerId,
460 extern_id: DefId,
461 ) -> ErrorGuaranteed {
462 let tcx = self.tcx();
463 self.emit(errors::InvalidItemInInherentImpl {
464 span: ident_or_def_span(tcx, local_id),
465 name: tcx.def_path_str(extern_id),
466 extern_item_span: tcx.def_span(extern_id),
467 })
468 }
469
470 #[track_caller]
471 fn invalid_impl_block(&self) -> ErrorGuaranteed {
472 self.emit(errors::InvalidImplBlock { span: self.block.span })
473 }
474
475 #[track_caller]
476 fn cannot_resolve_trait_impl(&self) -> ErrorGuaranteed {
477 self.emit(errors::CannotResolveTraitImpl { span: self.block.span })
478 }
479
480 #[track_caller]
481 fn item_not_in_trait(
482 &self,
483 local_id: OwnerId,
484 extern_id: DefId,
485 extern_trait_id: DefId,
486 ) -> ErrorGuaranteed {
487 let tcx = self.tcx();
488 self.emit(errors::ItemNotInTrait {
489 span: ident_or_def_span(tcx, local_id),
490 name: tcx.def_path_str(extern_id),
491 extern_trait_span: tcx.def_span(extern_trait_id),
492 })
493 }
494
495 fn emit<'b>(&'b self, err: impl Diagnostic<'b>) -> ErrorGuaranteed {
496 self.inner.errors.emit(err)
497 }
498
499 fn tcx(&self) -> TyCtxt<'tcx> {
500 self.inner.tcx
501 }
502}
503
504fn cmp_generic_param_def(a: &ty::GenericParamDef, b: &ty::GenericParamDef) -> bool {
505 if a.name != b.name {
506 return false;
507 }
508 if a.index != b.index {
509 return false;
510 }
511 matches!(
512 (&a.kind, &b.kind),
513 (ty::GenericParamDefKind::Lifetime, ty::GenericParamDefKind::Lifetime)
514 | (ty::GenericParamDefKind::Type { .. }, ty::GenericParamDefKind::Type { .. })
515 | (ty::GenericParamDefKind::Const { .. }, ty::GenericParamDefKind::Const { .. })
516 )
517}
518
519fn ident_or_def_span(tcx: TyCtxt, def_id: impl Into<DefId>) -> Span {
520 let def_id = def_id.into();
521 tcx.def_ident_span(def_id)
522 .unwrap_or_else(|| tcx.def_span(def_id))
523}
524
// Diagnostics emitted while collecting extern specs. The message text of each diagnostic lives
// in the fluent slug named in its `#[diag(...)]` attribute, not in this file.
mod errors {
    use flux_errors::E0999;
    use flux_macros::Diagnostic;
    use rustc_span::Span;

    // The extern-spec body, or an item inside it, does not have the expected shape.
    #[derive(Diagnostic)]
    #[diag(driver_malformed_extern_spec, code = E0999)]
    pub(super) struct MalformedExternSpec {
        #[primary_span]
        span: Span,
    }

    impl MalformedExternSpec {
        // Builds the diagnostic pointing at `span`.
        pub(super) fn new(span: Span) -> Self {
            Self { span }
        }
    }

    // An extern enum spec does not match the extern enum; `reason` carries the details.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_enum_extern_spec, code = E0999)]
    pub(super) struct InvalidEnumExternSpec {
        #[primary_span]
        span: Span,
        reason: String,
    }

    impl InvalidEnumExternSpec {
        // Builds the diagnostic pointing at `span` with the given `reason`.
        pub(super) fn new(span: Span, reason: String) -> Self {
            Self { span, reason }
        }
    }

    // The trait ref taken from a dummy impl could not be resolved to a concrete impl.
    #[derive(Diagnostic)]
    #[diag(driver_cannot_resolve_trait_impl, code = E0999)]
    #[note]
    pub(super) struct CannotResolveTraitImpl {
        #[primary_span]
        pub span: Span,
    }

    // The methods of an extern impl spec resolve to items of different extern impl blocks.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_impl_block, code = E0999)]
    pub(super) struct InvalidImplBlock {
        #[primary_span]
        #[label]
        pub span: Span,
    }

    // A method in a trait-impl spec calls `name`, which the extern impl does not implement.
    #[derive(Diagnostic)]
    #[diag(driver_item_not_in_trait_impl, code = E0999)]
    pub(super) struct ItemNotInTraitImpl {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        // Points at the extern impl block.
        #[note]
        pub extern_impl_span: Span,
    }

    // A method in an inherent-impl spec calls `name`, which is not a method of any impl block.
    #[derive(Diagnostic)]
    #[diag(driver_invalid_item_in_inherent_impl, code = E0999)]
    pub(super) struct InvalidItemInInherentImpl {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        // Points at the called extern item.
        #[note]
        pub extern_item_span: Span,
    }

    // A method in a trait spec calls `name`, which is not an item of the extern trait.
    #[derive(Diagnostic)]
    #[diag(driver_item_not_in_trait, code = E0999)]
    pub(super) struct ItemNotInTrait {
        #[primary_span]
        #[label]
        pub span: Span,
        pub name: String,
        // Points at the extern trait.
        #[note]
        pub extern_trait_span: Span,
    }

    // An extern spec targets `name`, which is defined in the local crate.
    #[derive(Diagnostic)]
    #[diag(driver_extern_spec_for_local_def, code = E0999)]
    pub(super) struct ExternSpecForLocalDef {
        #[primary_span]
        pub span: Span,
        // Points at the local definition that was targeted.
        #[note]
        pub local_def_span: Span,
        pub name: String,
    }

    // `name` already has an extern spec.
    #[derive(Diagnostic)]
    #[diag(driver_dup_extern_spec, code = E0999)]
    pub(super) struct DupExternSpec {
        #[primary_span]
        #[label]
        pub span: Span,
        // Points at the previously registered extern spec.
        #[note]
        pub previous_span: Span,
        pub name: String,
    }

    // The generic parameters of a spec item do not match those of its extern definition.
    #[derive(Diagnostic)]
    #[diag(driver_mismatched_generics, code = E0999)]
    #[note]
    pub(super) struct MismatchedGenerics {
        // Span of the local item's generics.
        #[primary_span]
        #[label]
        pub span: Span,
        // Span of the extern definition.
        #[label(driver_extern_def_label)]
        pub extern_def: Span,
        // Human-readable description of the extern definition's kind.
        pub def_descr: &'static str,
    }
}
638}