1use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_errors::Errors;
10use flux_middle::{
11 def_id::MaybeExternId,
12 fhir,
13 global_env::GlobalEnv,
14 queries::QueryResult,
15 rty::{
16 self,
17 fold::{TypeFoldable, TypeFolder, TypeSuperFoldable},
18 refining::{Refine as _, Refiner},
19 },
20};
21use flux_rustc_bridge::ty::{self, FieldIdx, VariantIdx};
22use rustc_ast::Mutability;
23use rustc_data_structures::unord::UnordMap;
24use rustc_type_ir::{DebruijnIndex, INNERMOST, InferConst};
25
26pub(crate) fn type_alias(
27 genv: GlobalEnv,
28 alias: &fhir::TyAlias,
29 alias_ty: &rty::TyCtor,
30 def_id: MaybeExternId,
31) -> QueryResult<rty::TyCtor> {
32 let rust_ty = genv.lower_type_of(def_id.resolved_id())?.skip_binder();
33 let expected = rust_ty.refine(&Refiner::default_for_item(genv, def_id.resolved_id())?)?;
34 let mut zipper = Zipper::new(genv, def_id);
35
36 if zipper
37 .enter_a_binder(alias_ty, |zipper, ty| zipper.zip_ty(ty, &expected))
38 .is_err()
39 {
40 zipper
41 .errors
42 .emit(errors::IncompatibleRefinement::type_alias(genv, def_id, alias));
43 }
44
45 zipper.errors.to_result()?;
46
47 Ok(zipper.holes.replace_holes(alias_ty))
48}
49
50pub(crate) fn fn_sig(
51 genv: GlobalEnv,
52 decl: &fhir::FnDecl,
53 fn_sig: &rty::PolyFnSig,
54 def_id: MaybeExternId,
55) -> QueryResult<rty::PolyFnSig> {
56 let rust_fn_sig = genv.lower_fn_sig(def_id.resolved_id())?.skip_binder();
57
58 let expected = Refiner::default_for_item(genv, def_id.resolved_id())?.refine(&rust_fn_sig)?;
59
60 let mut zipper = Zipper::new(genv, def_id);
61 if let Err(err) = zipper.zip_poly_fn_sig(fn_sig, &expected) {
62 zipper.emit_fn_sig_err(err, decl);
63 }
64
65 zipper.errors.to_result()?;
66
67 Ok(zipper.holes.replace_holes(fn_sig))
68}
69
70pub(crate) fn variants(
71 genv: GlobalEnv,
72 variants: &[rty::PolyVariant],
73 adt_def_id: MaybeExternId,
74) -> QueryResult<Vec<rty::PolyVariant>> {
75 let refiner = Refiner::default_for_item(genv, adt_def_id.resolved_id())?;
76 let mut zipper = Zipper::new(genv, adt_def_id);
77 for (i, variant) in variants.iter().enumerate() {
79 let variant_idx = VariantIdx::from_usize(i);
80 let expected = refiner.refine_variant_def(adt_def_id.resolved_id(), variant_idx)?;
81 zipper.zip_variant(variant, &expected, variant_idx);
82 }
83
84 zipper.errors.to_result()?;
85
86 Ok(variants
87 .iter()
88 .map(|v| zipper.holes.replace_holes(v))
89 .collect())
90}
91
/// State used while walking two types side by side: `a`, the refined type being checked, and
/// `b`, the expected type it is compared against.
struct Zipper<'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    /// The item being checked; used to build diagnostics.
    owner_id: MaybeExternId,
    /// For each location seen in a strong reference on the `a` side, the `b` type it was zipped
    /// against (filled in `zip_ty`, read in `zip_output` for `ensures` clauses).
    locs: UnordMap<rty::Loc, rty::Ty>,
    /// Values recorded for the inference holes found in `a`.
    holes: Holes,
    /// Number of binders currently entered on the `a` side.
    a_binders: u32,
    /// One entry per binder currently entered on the `b` side. `Some(k)` means the binder was
    /// entered together with a binder in `a` when `a_binders` was `k`; `None` means it has no
    /// counterpart in `a`.
    b_binder_to_a_binder: Vec<Option<u32>>,
    /// Diagnostics accumulated during the walk.
    errors: Errors<'genv>,
}
105
/// Values recorded for inference holes, one map per kind of hole, keyed by the hole's variable
/// id. Filled while zipping and consumed by `replace_holes`.
#[derive(Default)]
struct Holes {
    sorts: UnordMap<rty::SortVid, rty::Sort>,
    subset_tys: UnordMap<rty::TyVid, rty::SubsetTy>,
    types: UnordMap<rty::TyVid, rty::Ty>,
    regions: UnordMap<rty::RegionVid, rty::Region>,
    consts: UnordMap<rty::ConstVid, rty::Const>,
}
114
115impl TypeFolder for &Holes {
116 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
117 if let rty::Sort::Infer(vid) = sort {
118 self.sorts
119 .get(vid)
120 .cloned()
121 .unwrap_or_else(|| bug!("unfilled sort hole {vid:?}"))
122 } else {
123 sort.super_fold_with(self)
124 }
125 }
126
127 fn fold_ty(&mut self, ty: &rty::Ty) -> rty::Ty {
128 if let rty::TyKind::Infer(vid) = ty.kind() {
129 self.types
130 .get(vid)
131 .cloned()
132 .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
133 } else {
134 ty.super_fold_with(self)
135 }
136 }
137
138 fn fold_subset_ty(&mut self, constr: &rty::SubsetTy) -> rty::SubsetTy {
139 if let rty::BaseTy::Infer(vid) = &constr.bty {
140 self.subset_tys
141 .get(vid)
142 .cloned()
143 .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
144 } else {
145 constr.super_fold_with(self)
146 }
147 }
148
149 fn fold_region(&mut self, r: &rty::Region) -> rty::Region {
150 if let rty::Region::ReVar(vid) = r {
151 self.regions
152 .get(vid)
153 .copied()
154 .unwrap_or_else(|| bug!("unfilled region hole {vid:?}"))
155 } else {
156 *r
157 }
158 }
159
160 fn fold_const(&mut self, ct: &rty::Const) -> rty::Const {
161 if let rty::ConstKind::Infer(InferConst::Var(cid)) = ct.kind {
162 self.consts
163 .get(&cid)
164 .cloned()
165 .unwrap_or_else(|| bug!("unfilled const hole {cid:?}"))
166 } else {
167 ct.super_fold_with(self)
168 }
169 }
170}
171
172impl Holes {
173 fn replace_holes<T: TypeFoldable>(&self, t: &T) -> T {
174 let mut this = self;
175 t.fold_with(&mut this)
176 }
177}
178
179impl<'genv, 'tcx> Zipper<'genv, 'tcx> {
180 fn new(genv: GlobalEnv<'genv, 'tcx>, owner_id: MaybeExternId) -> Self {
181 Self {
182 genv,
183 owner_id,
184 locs: UnordMap::default(),
185 holes: Default::default(),
186 a_binders: 0,
187 b_binder_to_a_binder: vec![],
188 errors: Errors::new(genv.sess()),
189 }
190 }
191
192 fn is_async_fn(&self) -> bool {
193 self.genv
194 .tcx()
195 .asyncness(self.owner_id.resolved_id())
196 .is_async()
197 }
198
199 fn zip_poly_fn_sig(&mut self, a: &rty::PolyFnSig, b: &rty::PolyFnSig) -> Result<(), FnSigErr> {
200 self.enter_binders(a, b, |this, a, b| this.zip_fn_sig(a, b))
201 }
202
203 fn zip_variant(&mut self, a: &rty::PolyVariant, b: &rty::PolyVariant, variant_idx: VariantIdx) {
204 self.enter_binders(a, b, |this, a, b| {
205 debug_assert_eq!(a.args, b.args);
207
208 if a.fields.len() != b.fields.len() {
209 this.errors.emit(errors::FieldCountMismatch::new(
210 this.genv,
211 a.fields.len(),
212 this.owner_id,
213 variant_idx,
214 ));
215 return;
216 }
217 for (i, (ty_a, ty_b)) in iter::zip(&a.fields, &b.fields).enumerate() {
218 let field_idx = FieldIdx::from_usize(i);
219 if this.zip_ty(ty_a, ty_b).is_err() {
220 this.errors.emit(errors::IncompatibleRefinement::field(
221 this.genv,
222 this.owner_id,
223 variant_idx,
224 field_idx,
225 ));
226 }
227 }
228 });
229 }
230
231 fn zip_fn_sig(&mut self, a: &rty::FnSig, b: &rty::FnSig) -> Result<(), FnSigErr> {
232 if a.inputs().len() != b.inputs().len() {
233 Err(FnSigErr::ArgCountMismatch)?;
234 }
235 for (i, (ty_a, ty_b)) in iter::zip(a.inputs(), b.inputs()).enumerate() {
236 self.zip_ty(ty_a, ty_b).map_err(|_| FnSigErr::FnInput(i))?;
237 }
238 self.enter_binders(&a.output, &b.output, |this, output_a, output_b| {
239 this.zip_output(output_a, output_b)
240 })
241 }
242
243 fn zip_output(&mut self, a: &rty::FnOutput, b: &rty::FnOutput) -> Result<(), FnSigErr> {
244 self.zip_ty(&a.ret, &b.ret).map_err(FnSigErr::FnOutput)?;
245
246 for (i, ensures) in a.ensures.iter().enumerate() {
247 if let rty::Ensures::Type(path, ty_a) = ensures {
248 let loc = path.to_loc().unwrap();
249 let ty_b = self.locs.get(&loc).unwrap().shift_in_escaping(1);
250 self.zip_ty(ty_a, &ty_b)
251 .map_err(|_| FnSigErr::Ensures { i, expected: ty_b })?;
252 }
253 }
254 Ok(())
255 }
256
    /// Zips the refined type `a` against the expected type `b`, recording the value of every
    /// hole found in `a` and returning a `Mismatch` if the two shapes differ.
    ///
    /// The order of the arms matters: holes are handled first, then existentials and
    /// constraints are peeled off either side, and only then are the shapes compared.
    fn zip_ty(&mut self, a: &rty::Ty, b: &rty::Ty) -> Result<(), Mismatch> {
        match (a.kind(), b.kind()) {
            // A hole in `a` accepts anything: record `b` with its bound variables re-indexed
            // for the current position in `a`.
            (rty::TyKind::Infer(vid), _) => {
                // NOTE(review): variable id 0 appears to be reserved — confirm.
                assert_ne!(vid.as_u32(), 0);
                let b = self.adjust_bvars(b);
                self.holes.types.insert(*vid, b);
                Ok(())
            }
            // Existentials are entered on one side only.
            (rty::TyKind::Exists(ctor_a), _) => {
                self.enter_a_binder(ctor_a, |this, ty_a| this.zip_ty(ty_a, b))
            }
            (_, rty::TyKind::Exists(ctor_b)) => {
                self.enter_b_binder(ctor_b, |this, ty_b| this.zip_ty(a, ty_b))
            }
            // The predicate of a constrained type is ignored; only the inner type is compared.
            (rty::TyKind::Constr(_, ty_a), _) => self.zip_ty(ty_a, b),
            (_, rty::TyKind::Constr(_, ty_b)) => self.zip_ty(a, ty_b),
            // Indices are ignored; only the base types are compared.
            (rty::TyKind::Indexed(bty_a, _), rty::TyKind::Indexed(bty_b, _)) => {
                self.zip_bty(bty_a, bty_b)
            }
            // A strong reference in `a` matches a mutable reference in `b`. The expected
            // pointee is remembered per location so `zip_output` can check `ensures` clauses.
            (rty::TyKind::StrgRef(re_a, path, ty_a), rty::Ref!(re_b, ty_b, Mutability::Mut)) => {
                let loc = path.to_loc().unwrap();
                self.locs.insert(loc, ty_b.clone());

                self.zip_region(re_a, re_b);
                self.zip_ty(ty_a, ty_b)
            }
            (rty::TyKind::Param(pty_a), rty::TyKind::Param(pty_b)) => {
                assert_eq_or_incompatible(pty_a, pty_b)
            }
            // These kinds are not expected on the `a` side here.
            (
                rty::TyKind::Ptr(_, _)
                | rty::TyKind::Discr(..)
                | rty::TyKind::Downcast(_, _, _, _, _)
                | rty::TyKind::Blocked(_)
                | rty::TyKind::Uninit,
                _,
            ) => {
                bug!("unexpected type {a:?}");
            }
            _ => Err(Mismatch::new(a, b)),
        }
    }
299
300 fn zip_bty(&mut self, a: &rty::BaseTy, b: &rty::BaseTy) -> Result<(), Mismatch> {
301 match (a, b) {
302 (rty::BaseTy::Int(ity_a), rty::BaseTy::Int(ity_b)) => {
303 assert_eq_or_incompatible(ity_a, ity_b)
304 }
305 (rty::BaseTy::Uint(uity_a), rty::BaseTy::Uint(uity_b)) => {
306 assert_eq_or_incompatible(uity_a, uity_b)
307 }
308 (rty::BaseTy::Bool, rty::BaseTy::Bool) => Ok(()),
309 (rty::BaseTy::Str, rty::BaseTy::Str) => Ok(()),
310 (rty::BaseTy::Char, rty::BaseTy::Char) => Ok(()),
311 (rty::BaseTy::Float(fty_a), rty::BaseTy::Float(fty_b)) => {
312 assert_eq_or_incompatible(fty_a, fty_b)
313 }
314 (rty::BaseTy::Slice(ty_a), rty::BaseTy::Slice(ty_b)) => self.zip_ty(ty_a, ty_b),
315 (rty::BaseTy::Adt(adt_def_a, args_a), rty::BaseTy::Adt(adt_def_b, args_b)) => {
316 assert_eq_or_incompatible(adt_def_a.did(), adt_def_b.did())?;
317 assert_eq_or_incompatible(args_a.len(), args_b.len())?;
318 for (arg_a, arg_b) in iter::zip(args_a, args_b) {
319 self.zip_generic_arg(arg_a, arg_b)?;
320 }
321 Ok(())
322 }
323 (rty::BaseTy::RawPtr(ty_a, mutbl_a), rty::BaseTy::RawPtr(ty_b, mutbl_b)) => {
324 assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
325 self.zip_ty(ty_a, ty_b)
326 }
327 (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), rty::BaseTy::Ref(re_b, ty_b, mutbl_b)) => {
328 assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
329 self.zip_region(re_a, re_b);
330 self.zip_ty(ty_a, ty_b)
331 }
332 (rty::BaseTy::FnPtr(poly_sig_a), rty::BaseTy::FnPtr(poly_sig_b)) => {
333 self.zip_poly_fn_sig(poly_sig_a, poly_sig_b)
334 .map_err(|_| Mismatch::new(poly_sig_a, poly_sig_b))
335 }
336 (rty::BaseTy::Tuple(tys_a), rty::BaseTy::Tuple(tys_b)) => {
337 assert_eq_or_incompatible(tys_a.len(), tys_b.len())?;
338 for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
339 self.zip_ty(ty_a, ty_b)?;
340 }
341 Ok(())
342 }
343 (rty::BaseTy::Alias(kind_a, aty_a), rty::BaseTy::Alias(kind_b, aty_b)) => {
344 assert_eq_or_incompatible(kind_a, kind_b)?;
345 assert_eq_or_incompatible(aty_a.def_id, aty_b.def_id)?;
346 assert_eq_or_incompatible(aty_a.args.len(), aty_b.args.len())?;
347 for (arg_a, arg_b) in iter::zip(&aty_a.args, &aty_b.args) {
348 self.zip_generic_arg(arg_a, arg_b)?;
349 }
350 Ok(())
351 }
352 (rty::BaseTy::Array(ty_a, len_a), rty::BaseTy::Array(ty_b, len_b)) => {
353 self.zip_const(len_a, len_b)?;
354 self.zip_ty(ty_a, ty_b)
355 }
356 (rty::BaseTy::Never, rty::BaseTy::Never) => Ok(()),
357 (rty::BaseTy::Param(pty_a), rty::BaseTy::Param(pty_b)) => {
358 assert_eq_or_incompatible(pty_a, pty_b)
359 }
360 (rty::BaseTy::Dynamic(preds_a, re_a), rty::BaseTy::Dynamic(preds_b, re_b)) => {
361 assert_eq_or_incompatible(preds_a.len(), preds_b.len())?;
362 for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
363 self.zip_poly_existential_pred(pred_a, pred_b)?;
364 }
365 self.zip_region(re_a, re_b);
366 Ok(())
367 }
368 (rty::BaseTy::Foreign(def_id_a), rty::BaseTy::Foreign(def_id_b)) => {
369 assert_eq_or_incompatible(def_id_a, def_id_b)
370 }
371 (rty::BaseTy::Closure(..) | rty::BaseTy::Coroutine(..), _) => {
372 bug!("unexpected type `{a:?}`");
373 }
374 _ => Err(Mismatch::new(a, b)),
375 }
376 }
377
378 fn zip_generic_arg(
379 &mut self,
380 a: &rty::GenericArg,
381 b: &rty::GenericArg,
382 ) -> Result<(), Mismatch> {
383 match (a, b) {
384 (rty::GenericArg::Ty(ty_a), rty::GenericArg::Ty(ty_b)) => self.zip_ty(ty_a, ty_b),
385 (rty::GenericArg::Base(ctor_a), rty::GenericArg::Base(ctor_b)) => {
386 self.zip_sorts(&ctor_a.sort(), &ctor_b.sort());
387 self.enter_binders(ctor_a, ctor_b, |this, sty_a, sty_b| {
388 this.zip_subset_ty(sty_a, sty_b)
389 })
390 }
391 (rty::GenericArg::Lifetime(re_a), rty::GenericArg::Lifetime(re_b)) => {
392 self.zip_region(re_a, re_b);
393 Ok(())
394 }
395 (rty::GenericArg::Const(ct_a), rty::GenericArg::Const(ct_b)) => {
396 self.zip_const(ct_a, ct_b)
397 }
398 _ => Err(Mismatch::new(a, b)),
399 }
400 }
401
402 fn zip_sorts(&mut self, a: &rty::Sort, b: &rty::Sort) {
403 if let rty::Sort::Infer(vid) = a {
404 assert_ne!(vid.as_u32(), 0);
405 self.holes.sorts.insert(*vid, b.clone());
406 }
407 }
408
409 fn zip_subset_ty(&mut self, a: &rty::SubsetTy, b: &rty::SubsetTy) -> Result<(), Mismatch> {
410 if let rty::BaseTy::Infer(vid) = a.bty {
411 assert_ne!(vid.as_u32(), 0);
412 let b = self.adjust_bvars(b);
413 self.holes.subset_tys.insert(vid, b);
414 Ok(())
415 } else {
416 self.zip_bty(&a.bty, &b.bty)
417 }
418 }
419
420 fn zip_const(&mut self, a: &rty::Const, b: &ty::Const) -> Result<(), Mismatch> {
421 match (&a.kind, &b.kind) {
422 (rty::ConstKind::Infer(ty::InferConst::Var(cid)), _) => {
423 self.holes.consts.insert(*cid, b.clone());
424 Ok(())
425 }
426 (rty::ConstKind::Param(param_const_a), ty::ConstKind::Param(param_const_b)) => {
427 assert_eq_or_incompatible(param_const_a, param_const_b)
428 }
429 (rty::ConstKind::Value(ty_a, val_a), ty::ConstKind::Value(ty_b, val_b)) => {
430 assert_eq_or_incompatible(ty_a, ty_b)?;
431 assert_eq_or_incompatible(val_a, val_b)
432 }
433 (rty::ConstKind::Unevaluated(c1), ty::ConstKind::Unevaluated(c2)) => {
434 assert_eq_or_incompatible(c1, c2)
435 }
436 _ => Err(Mismatch::new(a, b)),
437 }
438 }
439
440 fn zip_region(&mut self, a: &rty::Region, b: &ty::Region) {
441 if let rty::Region::ReVar(vid) = a {
442 let re = self.adjust_bvars(b);
443 self.holes.regions.insert(*vid, re);
444 }
445 }
446
447 fn zip_poly_existential_pred(
448 &mut self,
449 a: &rty::Binder<rty::ExistentialPredicate>,
450 b: &rty::Binder<rty::ExistentialPredicate>,
451 ) -> Result<(), Mismatch> {
452 self.enter_binders(a, b, |this, a, b| {
453 match (a, b) {
454 (
455 rty::ExistentialPredicate::Trait(trait_ref_a),
456 rty::ExistentialPredicate::Trait(trait_ref_b),
457 ) => {
458 assert_eq_or_incompatible(trait_ref_a.def_id, trait_ref_b.def_id)?;
459 assert_eq_or_incompatible(trait_ref_a.args.len(), trait_ref_b.args.len())?;
460 for (arg_a, arg_b) in iter::zip(&trait_ref_a.args, &trait_ref_b.args) {
461 this.zip_generic_arg(arg_a, arg_b)?;
462 }
463 Ok(())
464 }
465 (
466 rty::ExistentialPredicate::Projection(projection_a),
467 rty::ExistentialPredicate::Projection(projection_b),
468 ) => {
469 assert_eq_or_incompatible(projection_a.def_id, projection_b.def_id)?;
470 assert_eq_or_incompatible(projection_a.args.len(), projection_b.args.len())?;
471 for (arg_a, arg_b) in iter::zip(&projection_a.args, &projection_b.args) {
472 this.zip_generic_arg(arg_a, arg_b)?;
473 }
474 this.enter_binders(&projection_a.term, &projection_b.term, |this, a, b| {
475 this.zip_bty(&a.bty, &b.bty)
476 })
477 }
478 (
479 rty::ExistentialPredicate::AutoTrait(def_id_a),
480 rty::ExistentialPredicate::AutoTrait(def_id_b),
481 ) => assert_eq_or_incompatible(def_id_a, def_id_b),
482 _ => Err(Mismatch::new(a, b)),
483 }
484 })
485 }
486
487 fn enter_binders<T, R>(
489 &mut self,
490 a: &rty::Binder<T>,
491 b: &rty::Binder<T>,
492 f: impl FnOnce(&mut Self, &T, &T) -> R,
493 ) -> R {
494 self.b_binder_to_a_binder.push(Some(self.a_binders));
495 self.a_binders += 1;
496 let r = f(self, a.skip_binder_ref(), b.skip_binder_ref());
497 self.a_binders -= 1;
498 self.b_binder_to_a_binder.pop();
499 r
500 }
501
502 fn enter_a_binder<T, R>(
504 &mut self,
505 t: &rty::Binder<T>,
506 f: impl FnOnce(&mut Self, &T) -> R,
507 ) -> R {
508 self.a_binders += 1;
509 let r = f(self, t.skip_binder_ref());
510 self.a_binders -= 1;
511 r
512 }
513
514 fn enter_b_binder<T, R>(
516 &mut self,
517 t: &rty::Binder<T>,
518 f: impl FnOnce(&mut Self, &T) -> R,
519 ) -> R {
520 self.b_binder_to_a_binder.push(None);
521 let r = f(self, t.skip_binder_ref());
522 self.b_binder_to_a_binder.pop();
523 r
524 }
525
526 fn adjust_bvars<T: TypeFoldable + Clone + std::fmt::Debug>(&self, t: &T) -> T {
527 struct Adjuster<'a, 'genv, 'tcx> {
528 current_index: DebruijnIndex,
529 zipper: &'a Zipper<'genv, 'tcx>,
530 }
531
532 impl Adjuster<'_, '_, '_> {
533 fn adjust(&self, debruijn: DebruijnIndex) -> DebruijnIndex {
534 let b_binders = self.zipper.b_binder_to_a_binder.len();
535 let mapped_binder = self.zipper.b_binder_to_a_binder
536 [b_binders - debruijn.as_usize() - 1]
537 .unwrap_or_else(|| {
538 bug!("bound var without corresponding binder: `{debruijn:?}`")
539 });
540 DebruijnIndex::from_u32(self.zipper.a_binders - mapped_binder - 1)
541 .shifted_in(self.current_index.as_u32())
542 }
543 }
544
545 impl TypeFolder for Adjuster<'_, '_, '_> {
546 fn enter_binder(&mut self, _: &rty::BoundVariableKinds) {
547 self.current_index.shift_in(1);
548 }
549
550 fn exit_binder(&mut self) {
551 self.current_index.shift_out(1);
552 }
553
554 fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
555 if let rty::ReBound(debruijn, br) = *re
556 && debruijn >= self.current_index
557 {
558 rty::ReBound(self.adjust(debruijn), br)
559 } else {
560 *re
561 }
562 }
563
564 fn fold_expr(&mut self, expr: &rty::Expr) -> rty::Expr {
565 if let rty::ExprKind::Var(rty::Var::Bound(debruijn, breft)) = expr.kind()
566 && *debruijn >= self.current_index
567 {
568 rty::Expr::bvar(self.adjust(*debruijn), breft.var, breft.kind)
569 } else {
570 expr.super_fold_with(self)
571 }
572 }
573 }
574 t.fold_with(&mut Adjuster { current_index: INNERMOST, zipper: self })
575 }
576
577 fn emit_fn_sig_err(&mut self, err: FnSigErr, decl: &fhir::FnDecl) {
578 match err {
579 FnSigErr::ArgCountMismatch => {
580 self.errors.emit(errors::IncompatibleParamCount::new(
581 self.genv,
582 decl,
583 self.owner_id,
584 ));
585 }
586 FnSigErr::FnInput(i) => {
587 self.errors.emit(errors::IncompatibleRefinement::fn_input(
588 self.genv,
589 self.owner_id,
590 decl,
591 i,
592 ));
593 }
594 FnSigErr::FnOutput(_) => {
595 self.errors.emit(errors::IncompatibleRefinement::fn_output(
596 self.genv,
597 self.owner_id,
598 decl,
599 self.is_async_fn(),
600 ));
601 }
602 FnSigErr::Ensures { i, expected } => {
603 self.errors.emit(errors::IncompatibleRefinement::ensures(
604 self.genv,
605 self.owner_id,
606 decl,
607 &expected,
608 i,
609 ));
610 }
611 }
612 }
613}
614
615fn assert_eq_or_incompatible<T: Eq + fmt::Debug>(a: T, b: T) -> Result<(), Mismatch> {
616 if a != b {
617 return Err(Mismatch::new(a, b));
618 }
619 Ok(())
620}
621
/// Describes why two items failed to zip. The payload is the text `"{a:?} != {b:?}"` and is
/// only meant to be inspected while debugging.
#[expect(dead_code, reason = "we use the String for debugging")]
struct Mismatch(String);

impl Mismatch {
    /// Builds a mismatch from the `Debug` rendering of both sides.
    fn new<T: fmt::Debug>(a: T, b: T) -> Self {
        Self(format!("{a:?} != {b:?}"))
    }
}
630
/// Reasons why zipping two function signatures can fail.
enum FnSigErr {
    /// The two signatures have a different number of inputs.
    ArgCountMismatch,
    /// The input at this position failed to zip.
    FnInput(usize),
    /// The return types failed to zip.
    #[expect(dead_code, reason = "we use the struct for debugging")]
    FnOutput(Mismatch),
    /// The `i`-th `ensures` clause failed to zip against the `expected` type.
    Ensures {
        i: usize,
        expected: rty::Ty,
    },
}
641
642mod errors {
643 use flux_common::span_bug;
644 use flux_errors::E0999;
645 use flux_macros::Diagnostic;
646 use flux_middle::{def_id::MaybeExternId, fhir, global_env::GlobalEnv, rty};
647 use flux_rustc_bridge::{
648 ToRustc,
649 ty::{FieldIdx, VariantIdx},
650 };
651 use rustc_span::{DUMMY_SP, Span};
652
    // Diagnostic reported when a refined type is not structurally compatible with the Rust
    // type it annotates.
    #[derive(Diagnostic)]
    #[diag(fhir_analysis_incompatible_refinement, code = E0999)]
    #[note]
    pub(super) struct IncompatibleRefinement<'tcx> {
        // Span of the refined type in the specification.
        #[primary_span]
        #[label]
        span: Span,
        // Span of the Rust definition the refinement is compared against, when available.
        #[label(fhir_analysis_expected_label)]
        expected_span: Option<Span>,
        // The Rust type that was expected.
        expected_ty: rustc_middle::ty::Ty<'tcx>,
        // Description of the kind of definition, as returned by `def_descr`.
        def_descr: &'static str,
        // Set to `Some(())` to attach the async-specific help message.
        #[help(fhir_analysis_async_hint)]
        async_hint: Option<()>,
    }
667
668 impl<'tcx> IncompatibleRefinement<'tcx> {
669 pub(super) fn type_alias(
670 genv: GlobalEnv<'_, 'tcx>,
671 def_id: MaybeExternId,
672 type_alias: &fhir::TyAlias,
673 ) -> Self {
674 let tcx = genv.tcx();
675 Self {
676 span: type_alias.ty.span,
677 def_descr: tcx.def_descr(def_id.resolved_id()),
678 expected_span: Some(tcx.def_span(def_id)),
679 expected_ty: tcx.type_of(def_id).skip_binder(),
680 async_hint: None,
681 }
682 }
683
684 pub(super) fn fn_input(
685 genv: GlobalEnv<'_, 'tcx>,
686 fn_id: MaybeExternId,
687 decl: &fhir::FnDecl,
688 pos: usize,
689 ) -> Self {
690 let expected_span = match fn_id {
691 MaybeExternId::Local(local_id) => {
692 genv.tcx()
693 .hir_node_by_def_id(local_id)
694 .fn_decl()
695 .and_then(|fn_decl| fn_decl.inputs.get(pos))
696 .map(|input| input.span)
697 }
698 MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
699 };
700
701 let expected_ty = genv
702 .tcx()
703 .fn_sig(fn_id.resolved_id())
704 .skip_binder()
705 .inputs()
706 .map_bound(|inputs| inputs[pos])
707 .skip_binder();
708
709 Self {
710 span: decl.inputs[pos].span,
711 def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
712 expected_span,
713 expected_ty,
714 async_hint: None,
715 }
716 }
717
718 pub(super) fn fn_output(
719 genv: GlobalEnv<'_, 'tcx>,
720 fn_id: MaybeExternId,
721 decl: &fhir::FnDecl,
722 is_async: bool,
723 ) -> Self {
724 let expected_span = match fn_id {
725 MaybeExternId::Local(local_id) => {
726 genv.tcx()
727 .hir_node_by_def_id(local_id)
728 .fn_decl()
729 .map(|fn_decl| fn_decl.output.span())
730 }
731 MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
732 };
733
734 let expected_ty = genv
735 .tcx()
736 .fn_sig(fn_id.resolved_id())
737 .skip_binder()
738 .output()
739 .skip_binder();
740 let spec_span = decl.output.ret.span;
741
742 let async_hint = if is_async { Some(()) } else { None };
743
744 Self {
745 span: spec_span,
746 def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
747 expected_span,
748 expected_ty,
749 async_hint,
750 }
751 }
752
753 pub(super) fn ensures(
754 genv: GlobalEnv<'_, 'tcx>,
755 fn_id: MaybeExternId,
756 decl: &fhir::FnDecl,
757 expected: &rty::Ty,
758 i: usize,
759 ) -> Self {
760 let fhir::Ensures::Type(_, ty) = &decl.output.ensures[i] else {
761 span_bug!(decl.span, "expected `fhir::Ensures::Type`");
762 };
763 let tcx = genv.tcx();
764 Self {
765 span: ty.span,
766 def_descr: tcx.def_descr(fn_id.resolved_id()),
767 expected_span: None,
768 expected_ty: expected.to_rustc(tcx),
769 async_hint: None,
770 }
771 }
772
773 pub(super) fn field(
774 genv: GlobalEnv<'_, 'tcx>,
775 adt_id: MaybeExternId,
776 variant_idx: VariantIdx,
777 field_idx: FieldIdx,
778 ) -> Self {
779 let tcx = genv.tcx();
780 let adt_def = tcx.adt_def(adt_id);
781 let field_def = &adt_def.variant(variant_idx).fields[field_idx];
782
783 let item = genv.fhir_expect_item(adt_id.local_id()).unwrap();
784 let span = match &item.kind {
785 fhir::ItemKind::Enum(enum_def) => {
786 enum_def.variants[variant_idx.as_usize()].fields[field_idx.as_usize()]
787 .ty
788 .span
789 }
790 fhir::ItemKind::Struct(struct_def)
791 if let fhir::StructKind::Transparent { fields } = &struct_def.kind =>
792 {
793 fields[field_idx.as_usize()].ty.span
794 }
795 _ => DUMMY_SP,
796 };
797
798 Self {
799 span,
800 def_descr: tcx.def_descr(field_def.did),
801 expected_span: Some(tcx.def_span(field_def.did)),
802 expected_ty: tcx.type_of(field_def.did).skip_binder(),
803 async_hint: None,
804 }
805 }
806 }
807
    // Diagnostic reported when a refined signature has a different number of parameters than
    // the Rust function it annotates.
    #[derive(Diagnostic)]
    #[diag(fhir_analysis_incompatible_param_count, code = E0999)]
    pub(super) struct IncompatibleParamCount {
        // Span covering the parameters of the refined signature.
        #[primary_span]
        #[label]
        span: Span,
        // Number of parameters in the refined signature.
        found: usize,
        // Span covering the parameters of the Rust function (or the whole definition).
        #[label(fhir_analysis_expected_label)]
        expected_span: Span,
        // Number of parameters in the Rust signature.
        expected: usize,
        // Description of the kind of definition, as returned by `def_descr`.
        def_descr: &'static str,
    }
820
821 impl IncompatibleParamCount {
822 pub(super) fn new(genv: GlobalEnv, decl: &fhir::FnDecl, def_id: MaybeExternId) -> Self {
823 let def_descr = genv.tcx().def_descr(def_id.resolved_id());
824
825 let span = if !decl.inputs.is_empty() {
826 decl.inputs[decl.inputs.len() - 1]
827 .span
828 .with_lo(decl.inputs[0].span.lo())
829 } else {
830 decl.span
831 };
832
833 let expected_span = if let Some(local_id) = def_id.as_local()
834 && let expected_decl = genv.tcx().hir_node_by_def_id(local_id).fn_decl().unwrap()
835 && !expected_decl.inputs.is_empty()
836 {
837 expected_decl.inputs[expected_decl.inputs.len() - 1]
838 .span
839 .with_lo(expected_decl.inputs[0].span.lo())
840 } else {
841 genv.tcx().def_span(def_id)
842 };
843
844 let expected = genv
845 .tcx()
846 .fn_sig(def_id)
847 .skip_binder()
848 .skip_binder()
849 .inputs()
850 .len();
851
852 Self { span, found: decl.inputs.len(), expected_span, expected, def_descr }
853 }
854 }
855
    // Diagnostic reported when a refined variant has a different number of fields than the
    // Rust variant it annotates.
    #[derive(Diagnostic)]
    #[diag(fhir_analysis_field_count_mismatch, code = E0999)]
    pub(super) struct FieldCountMismatch {
        // Span of the refined variant, or `DUMMY_SP` when it cannot be located.
        #[primary_span]
        #[label]
        span: Span,
        // Number of fields in the refined variant.
        fields: usize,
        // Span of the Rust variant definition.
        #[label(fhir_analysis_expected_label)]
        expected_span: Span,
        // Number of fields in the Rust variant.
        expected_fields: usize,
    }
867
868 impl FieldCountMismatch {
869 pub(super) fn new(
870 genv: GlobalEnv,
871 found: usize,
872 adt_def_id: MaybeExternId,
873 variant_idx: VariantIdx,
874 ) -> Self {
875 let adt_def = genv.tcx().adt_def(adt_def_id);
876 let expected_variant = adt_def.variant(variant_idx);
877
878 let span = if let Ok(fhir::Node::Item(item)) = genv.fhir_node(adt_def_id.local_id())
881 && let fhir::ItemKind::Enum(enum_def) = &item.kind
882 && let Some(variant) = enum_def.variants.get(variant_idx.as_usize())
883 {
884 variant.span
885 } else {
886 DUMMY_SP
887 };
888
889 Self {
890 span,
891 fields: found,
892 expected_span: genv.tcx().def_span(expected_variant.def_id),
893 expected_fields: expected_variant.fields.len(),
894 }
895 }
896 }
897}