1use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_errors::Errors;
10use flux_middle::{
11 def_id::MaybeExternId,
12 fhir,
13 global_env::GlobalEnv,
14 queries::QueryResult,
15 rty::{
16 self,
17 fold::{TypeFoldable, TypeFolder, TypeSuperFoldable},
18 refining::{Refine as _, Refiner},
19 },
20};
21use flux_rustc_bridge::ty::{self, FieldIdx, VariantIdx};
22use rustc_ast::Mutability;
23use rustc_data_structures::unord::UnordMap;
24use rustc_type_ir::{DebruijnIndex, INNERMOST, InferConst};
25
26pub(crate) fn type_alias(
27 genv: GlobalEnv,
28 alias: &fhir::TyAlias,
29 alias_ty: &rty::TyCtor,
30 def_id: MaybeExternId,
31) -> QueryResult<rty::TyCtor> {
32 let rust_ty = genv.lower_type_of(def_id.resolved_id())?.skip_binder();
33 let expected = rust_ty.refine(&Refiner::default_for_item(genv, def_id.resolved_id())?)?;
34 let mut zipper = Zipper::new(genv, def_id);
35
36 if zipper
37 .enter_a_binder(alias_ty, |zipper, ty| zipper.zip_ty(ty, &expected))
38 .is_err()
39 {
40 zipper
41 .errors
42 .emit(errors::IncompatibleRefinement::type_alias(genv, def_id, alias));
43 }
44
45 zipper.errors.into_result()?;
46
47 Ok(zipper.holes.replace_holes(alias_ty))
48}
49
50pub(crate) fn fn_sig(
51 genv: GlobalEnv,
52 decl: &fhir::FnDecl,
53 fn_sig: &rty::PolyFnSig,
54 def_id: MaybeExternId,
55) -> QueryResult<rty::PolyFnSig> {
56 let rust_fn_sig = genv.lower_fn_sig(def_id.resolved_id())?.skip_binder();
57
58 let expected = Refiner::default_for_item(genv, def_id.resolved_id())?.refine(&rust_fn_sig)?;
59
60 let mut zipper = Zipper::new(genv, def_id);
61 if let Err(err) = zipper.zip_poly_fn_sig(fn_sig, &expected) {
62 zipper.emit_fn_sig_err(err, decl);
63 }
64
65 zipper.errors.into_result()?;
66
67 Ok(zipper.holes.replace_holes(fn_sig))
68}
69
70pub(crate) fn variants(
71 genv: GlobalEnv,
72 variants: &[rty::PolyVariant],
73 adt_def_id: MaybeExternId,
74) -> QueryResult<Vec<rty::PolyVariant>> {
75 let refiner = Refiner::default_for_item(genv, adt_def_id.resolved_id())?;
76 let mut zipper = Zipper::new(genv, adt_def_id);
77 for (i, variant) in variants.iter().enumerate() {
79 let variant_idx = VariantIdx::from_usize(i);
80 let expected = refiner.refine_variant_def(adt_def_id.resolved_id(), variant_idx)?;
81 zipper.zip_variant(variant, &expected, variant_idx);
82 }
83
84 zipper.errors.into_result()?;
85
86 Ok(variants
87 .iter()
88 .map(|v| zipper.holes.replace_holes(v))
89 .collect())
90}
91
92struct Zipper<'genv, 'tcx> {
93 genv: GlobalEnv<'genv, 'tcx>,
94 owner_id: MaybeExternId,
95 locs: UnordMap<rty::Loc, rty::Ty>,
96 holes: Holes,
97 a_binders: u32,
99 b_binder_to_a_binder: Vec<Option<u32>>,
103 errors: Errors<'genv>,
104}
105
106#[derive(Default)]
107struct Holes {
108 sorts: UnordMap<rty::SortVid, rty::Sort>,
109 subset_tys: UnordMap<rty::TyVid, rty::SubsetTy>,
110 types: UnordMap<rty::TyVid, rty::Ty>,
111 regions: UnordMap<rty::RegionVid, rty::Region>,
112 consts: UnordMap<rty::ConstVid, rty::Const>,
113}
114
115impl TypeFolder for &Holes {
116 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
117 if let rty::Sort::Infer(rty::SortInfer::SortVar(vid)) = sort {
118 self.sorts
119 .get(vid)
120 .cloned()
121 .unwrap_or_else(|| bug!("unfilled sort hole {vid:?}"))
122 } else {
123 sort.super_fold_with(self)
124 }
125 }
126
127 fn fold_ty(&mut self, ty: &rty::Ty) -> rty::Ty {
128 if let rty::TyKind::Infer(vid) = ty.kind() {
129 self.types
130 .get(vid)
131 .cloned()
132 .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
133 } else {
134 ty.super_fold_with(self)
135 }
136 }
137
138 fn fold_subset_ty(&mut self, constr: &rty::SubsetTy) -> rty::SubsetTy {
139 if let rty::BaseTy::Infer(vid) = &constr.bty {
140 self.subset_tys
141 .get(vid)
142 .cloned()
143 .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
144 } else {
145 constr.super_fold_with(self)
146 }
147 }
148
149 fn fold_region(&mut self, r: &rty::Region) -> rty::Region {
150 if let rty::Region::ReVar(vid) = r {
151 self.regions
152 .get(vid)
153 .copied()
154 .unwrap_or_else(|| bug!("unfilled region hole {vid:?}"))
155 } else {
156 *r
157 }
158 }
159
160 fn fold_const(&mut self, ct: &rty::Const) -> rty::Const {
161 if let rty::ConstKind::Infer(InferConst::Var(cid)) = ct.kind {
162 self.consts
163 .get(&cid)
164 .cloned()
165 .unwrap_or_else(|| bug!("unfilled const hole {cid:?}"))
166 } else {
167 ct.super_fold_with(self)
168 }
169 }
170}
171
172impl Holes {
173 fn replace_holes<T: TypeFoldable>(&self, t: &T) -> T {
174 let mut this = self;
175 t.fold_with(&mut this)
176 }
177}
178
179impl<'genv, 'tcx> Zipper<'genv, 'tcx> {
180 fn new(genv: GlobalEnv<'genv, 'tcx>, owner_id: MaybeExternId) -> Self {
181 Self {
182 genv,
183 owner_id,
184 locs: UnordMap::default(),
185 holes: Default::default(),
186 a_binders: 0,
187 b_binder_to_a_binder: vec![],
188 errors: Errors::new(genv.sess()),
189 }
190 }
191
192 fn zip_poly_fn_sig(&mut self, a: &rty::PolyFnSig, b: &rty::PolyFnSig) -> Result<(), FnSigErr> {
193 self.enter_binders(a, b, |this, a, b| this.zip_fn_sig(a, b))
194 }
195
196 fn zip_variant(&mut self, a: &rty::PolyVariant, b: &rty::PolyVariant, variant_idx: VariantIdx) {
197 self.enter_binders(a, b, |this, a, b| {
198 debug_assert_eq!(a.args, b.args);
200
201 if a.fields.len() != b.fields.len() {
202 this.errors.emit(errors::FieldCountMismatch::new(
203 this.genv,
204 a.fields.len(),
205 this.owner_id,
206 variant_idx,
207 ));
208 return;
209 }
210 for (i, (ty_a, ty_b)) in iter::zip(&a.fields, &b.fields).enumerate() {
211 let field_idx = FieldIdx::from_usize(i);
212 if this.zip_ty(ty_a, ty_b).is_err() {
213 this.errors.emit(errors::IncompatibleRefinement::field(
214 this.genv,
215 this.owner_id,
216 variant_idx,
217 field_idx,
218 ));
219 }
220 }
221 });
222 }
223
224 fn zip_fn_sig(&mut self, a: &rty::FnSig, b: &rty::FnSig) -> Result<(), FnSigErr> {
225 if a.inputs().len() != b.inputs().len() {
226 Err(FnSigErr::ArgCountMismatch)?;
227 }
228 for (i, (ty_a, ty_b)) in iter::zip(a.inputs(), b.inputs()).enumerate() {
229 self.zip_ty(ty_a, ty_b).map_err(|_| FnSigErr::FnInput(i))?;
230 }
231 self.enter_binders(&a.output, &b.output, |this, output_a, output_b| {
232 this.zip_output(output_a, output_b)
233 })
234 }
235
236 fn zip_output(&mut self, a: &rty::FnOutput, b: &rty::FnOutput) -> Result<(), FnSigErr> {
237 self.zip_ty(&a.ret, &b.ret).map_err(FnSigErr::FnOutput)?;
238
239 for (i, ensures) in a.ensures.iter().enumerate() {
240 if let rty::Ensures::Type(path, ty_a) = ensures {
241 let loc = path.to_loc().unwrap();
242 let ty_b = self.locs.get(&loc).unwrap().clone();
243 self.zip_ty(ty_a, &ty_b)
244 .map_err(|_| FnSigErr::Ensures { i, expected: ty_b })?;
245 }
246 }
247 Ok(())
248 }
249
250 fn zip_ty(&mut self, a: &rty::Ty, b: &rty::Ty) -> Result<(), Mismatch> {
251 match (a.kind(), b.kind()) {
252 (rty::TyKind::Infer(vid), _) => {
253 assert_ne!(vid.as_u32(), 0);
254 let b = self.adjust_bvars(b);
255 self.holes.types.insert(*vid, b);
256 Ok(())
257 }
258 (rty::TyKind::Exists(ctor_a), _) => {
259 self.enter_a_binder(ctor_a, |this, ty_a| this.zip_ty(ty_a, b))
260 }
261 (_, rty::TyKind::Exists(ctor_b)) => {
262 self.enter_b_binder(ctor_b, |this, ty_b| this.zip_ty(a, ty_b))
263 }
264 (rty::TyKind::Constr(_, ty_a), _) => self.zip_ty(ty_a, b),
265 (_, rty::TyKind::Constr(_, ty_b)) => self.zip_ty(a, ty_b),
266 (rty::TyKind::Indexed(bty_a, _), rty::TyKind::Indexed(bty_b, _)) => {
267 self.zip_bty(bty_a, bty_b)
268 }
269 (rty::TyKind::StrgRef(re_a, path, ty_a), rty::Ref!(re_b, ty_b, Mutability::Mut)) => {
270 let loc = path.to_loc().unwrap();
271 self.locs.insert(loc, ty_b.clone());
272
273 self.zip_region(re_a, re_b);
274 self.zip_ty(ty_a, ty_b)
275 }
276 (rty::TyKind::Param(pty_a), rty::TyKind::Param(pty_b)) => {
277 assert_eq_or_incompatible(pty_a, pty_b)
278 }
279 (
280 rty::TyKind::Ptr(_, _)
281 | rty::TyKind::Discr(_, _)
282 | rty::TyKind::Downcast(_, _, _, _, _)
283 | rty::TyKind::Blocked(_)
284 | rty::TyKind::Uninit,
285 _,
286 ) => {
287 bug!("unexpected type {a:?}");
288 }
289 _ => Err(Mismatch::new(a, b)),
290 }
291 }
292
293 fn zip_bty(&mut self, a: &rty::BaseTy, b: &rty::BaseTy) -> Result<(), Mismatch> {
294 match (a, b) {
295 (rty::BaseTy::Int(ity_a), rty::BaseTy::Int(ity_b)) => {
296 assert_eq_or_incompatible(ity_a, ity_b)
297 }
298 (rty::BaseTy::Uint(uity_a), rty::BaseTy::Uint(uity_b)) => {
299 assert_eq_or_incompatible(uity_a, uity_b)
300 }
301 (rty::BaseTy::Bool, rty::BaseTy::Bool) => Ok(()),
302 (rty::BaseTy::Str, rty::BaseTy::Str) => Ok(()),
303 (rty::BaseTy::Char, rty::BaseTy::Char) => Ok(()),
304 (rty::BaseTy::Float(fty_a), rty::BaseTy::Float(fty_b)) => {
305 assert_eq_or_incompatible(fty_a, fty_b)
306 }
307 (rty::BaseTy::Slice(ty_a), rty::BaseTy::Slice(ty_b)) => self.zip_ty(ty_a, ty_b),
308 (rty::BaseTy::Adt(adt_def_a, args_a), rty::BaseTy::Adt(adt_def_b, args_b)) => {
309 assert_eq_or_incompatible(adt_def_a.did(), adt_def_b.did())?;
310 assert_eq_or_incompatible(args_a.len(), args_b.len())?;
311 for (arg_a, arg_b) in iter::zip(args_a, args_b) {
312 self.zip_generic_arg(arg_a, arg_b)?;
313 }
314 Ok(())
315 }
316 (rty::BaseTy::RawPtr(ty_a, mutbl_a), rty::BaseTy::RawPtr(ty_b, mutbl_b)) => {
317 assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
318 self.zip_ty(ty_a, ty_b)
319 }
320 (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), rty::BaseTy::Ref(re_b, ty_b, mutbl_b)) => {
321 assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
322 self.zip_region(re_a, re_b);
323 self.zip_ty(ty_a, ty_b)
324 }
325 (rty::BaseTy::FnPtr(poly_sig_a), rty::BaseTy::FnPtr(poly_sig_b)) => {
326 self.zip_poly_fn_sig(poly_sig_a, poly_sig_b)
327 .map_err(|_| Mismatch::new(poly_sig_a, poly_sig_b))
328 }
329 (rty::BaseTy::Tuple(tys_a), rty::BaseTy::Tuple(tys_b)) => {
330 assert_eq_or_incompatible(tys_a.len(), tys_b.len())?;
331 for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
332 self.zip_ty(ty_a, ty_b)?;
333 }
334 Ok(())
335 }
336 (rty::BaseTy::Alias(kind_a, aty_a), rty::BaseTy::Alias(kind_b, aty_b)) => {
337 assert_eq_or_incompatible(kind_a, kind_b)?;
338 assert_eq_or_incompatible(aty_a.def_id, aty_b.def_id)?;
339 assert_eq_or_incompatible(aty_a.args.len(), aty_b.args.len())?;
340 for (arg_a, arg_b) in iter::zip(&aty_a.args, &aty_b.args) {
341 self.zip_generic_arg(arg_a, arg_b)?;
342 }
343 Ok(())
344 }
345 (rty::BaseTy::Array(ty_a, len_a), rty::BaseTy::Array(ty_b, len_b)) => {
346 self.zip_const(len_a, len_b)?;
347 self.zip_ty(ty_a, ty_b)
348 }
349 (rty::BaseTy::Never, rty::BaseTy::Never) => Ok(()),
350 (rty::BaseTy::Param(pty_a), rty::BaseTy::Param(pty_b)) => {
351 assert_eq_or_incompatible(pty_a, pty_b)
352 }
353 (rty::BaseTy::Dynamic(preds_a, re_a), rty::BaseTy::Dynamic(preds_b, re_b)) => {
354 assert_eq_or_incompatible(preds_a.len(), preds_b.len())?;
355 for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
356 self.zip_poly_existential_pred(pred_a, pred_b)?;
357 }
358 self.zip_region(re_a, re_b);
359 Ok(())
360 }
361 (rty::BaseTy::Foreign(def_id_a), rty::BaseTy::Foreign(def_id_b)) => {
362 assert_eq_or_incompatible(def_id_a, def_id_b)
363 }
364 (rty::BaseTy::Closure(..) | rty::BaseTy::Coroutine(..), _) => {
365 bug!("unexpected type `{a:?}`");
366 }
367 _ => Err(Mismatch::new(a, b)),
368 }
369 }
370
371 fn zip_generic_arg(
372 &mut self,
373 a: &rty::GenericArg,
374 b: &rty::GenericArg,
375 ) -> Result<(), Mismatch> {
376 match (a, b) {
377 (rty::GenericArg::Ty(ty_a), rty::GenericArg::Ty(ty_b)) => self.zip_ty(ty_a, ty_b),
378 (rty::GenericArg::Base(ctor_a), rty::GenericArg::Base(ctor_b)) => {
379 self.zip_sorts(&ctor_a.sort(), &ctor_b.sort());
380 self.enter_binders(ctor_a, ctor_b, |this, sty_a, sty_b| {
381 this.zip_subset_ty(sty_a, sty_b)
382 })
383 }
384 (rty::GenericArg::Lifetime(re_a), rty::GenericArg::Lifetime(re_b)) => {
385 self.zip_region(re_a, re_b);
386 Ok(())
387 }
388 (rty::GenericArg::Const(ct_a), rty::GenericArg::Const(ct_b)) => {
389 self.zip_const(ct_a, ct_b)
390 }
391 _ => Err(Mismatch::new(a, b)),
392 }
393 }
394
395 fn zip_sorts(&mut self, a: &rty::Sort, b: &rty::Sort) {
396 if let rty::Sort::Infer(rty::SortInfer::SortVar(vid)) = a {
397 assert_ne!(vid.as_u32(), 0);
398 self.holes.sorts.insert(*vid, b.clone());
399 }
400 }
401
402 fn zip_subset_ty(&mut self, a: &rty::SubsetTy, b: &rty::SubsetTy) -> Result<(), Mismatch> {
403 if let rty::BaseTy::Infer(vid) = a.bty {
404 assert_ne!(vid.as_u32(), 0);
405 let b = self.adjust_bvars(b);
406 self.holes.subset_tys.insert(vid, b);
407 Ok(())
408 } else {
409 self.zip_bty(&a.bty, &b.bty)
410 }
411 }
412
413 fn zip_const(&mut self, a: &rty::Const, b: &ty::Const) -> Result<(), Mismatch> {
414 match (&a.kind, &b.kind) {
415 (rty::ConstKind::Infer(ty::InferConst::Var(cid)), _) => {
416 self.holes.consts.insert(*cid, b.clone());
417 Ok(())
418 }
419 (rty::ConstKind::Param(param_const_a), ty::ConstKind::Param(param_const_b)) => {
420 assert_eq_or_incompatible(param_const_a, param_const_b)
421 }
422 (rty::ConstKind::Value(ty_a, val_a), ty::ConstKind::Value(ty_b, val_b)) => {
423 assert_eq_or_incompatible(ty_a, ty_b)?;
424 assert_eq_or_incompatible(val_a, val_b)
425 }
426 (rty::ConstKind::Unevaluated(c1), ty::ConstKind::Unevaluated(c2)) => {
427 assert_eq_or_incompatible(c1, c2)
428 }
429 _ => Err(Mismatch::new(a, b)),
430 }
431 }
432
433 fn zip_region(&mut self, a: &rty::Region, b: &ty::Region) {
434 if let rty::Region::ReVar(vid) = a {
435 let re = self.adjust_bvars(b);
436 self.holes.regions.insert(*vid, re);
437 }
438 }
439
440 fn zip_poly_existential_pred(
441 &mut self,
442 a: &rty::Binder<rty::ExistentialPredicate>,
443 b: &rty::Binder<rty::ExistentialPredicate>,
444 ) -> Result<(), Mismatch> {
445 self.enter_binders(a, b, |this, a, b| {
446 match (a, b) {
447 (
448 rty::ExistentialPredicate::Trait(trait_ref_a),
449 rty::ExistentialPredicate::Trait(trait_ref_b),
450 ) => {
451 assert_eq_or_incompatible(trait_ref_a.def_id, trait_ref_b.def_id)?;
452 assert_eq_or_incompatible(trait_ref_a.args.len(), trait_ref_b.args.len())?;
453 for (arg_a, arg_b) in iter::zip(&trait_ref_a.args, &trait_ref_b.args) {
454 this.zip_generic_arg(arg_a, arg_b)?;
455 }
456 Ok(())
457 }
458 (
459 rty::ExistentialPredicate::Projection(projection_a),
460 rty::ExistentialPredicate::Projection(projection_b),
461 ) => {
462 assert_eq_or_incompatible(projection_a.def_id, projection_b.def_id)?;
463 assert_eq_or_incompatible(projection_a.args.len(), projection_b.args.len())?;
464 for (arg_a, arg_b) in iter::zip(&projection_a.args, &projection_b.args) {
465 this.zip_generic_arg(arg_a, arg_b)?;
466 }
467 this.enter_binders(&projection_a.term, &projection_b.term, |this, a, b| {
468 this.zip_bty(&a.bty, &b.bty)
469 })
470 }
471 (
472 rty::ExistentialPredicate::AutoTrait(def_id_a),
473 rty::ExistentialPredicate::AutoTrait(def_id_b),
474 ) => assert_eq_or_incompatible(def_id_a, def_id_b),
475 _ => Err(Mismatch::new(a, b)),
476 }
477 })
478 }
479
480 fn enter_binders<T, R>(
482 &mut self,
483 a: &rty::Binder<T>,
484 b: &rty::Binder<T>,
485 f: impl FnOnce(&mut Self, &T, &T) -> R,
486 ) -> R {
487 self.b_binder_to_a_binder.push(Some(self.a_binders));
488 self.a_binders += 1;
489 let r = f(self, a.skip_binder_ref(), b.skip_binder_ref());
490 self.a_binders -= 1;
491 self.b_binder_to_a_binder.pop();
492 r
493 }
494
495 fn enter_a_binder<T, R>(
497 &mut self,
498 t: &rty::Binder<T>,
499 f: impl FnOnce(&mut Self, &T) -> R,
500 ) -> R {
501 self.a_binders += 1;
502 let r = f(self, t.skip_binder_ref());
503 self.a_binders -= 1;
504 r
505 }
506
507 fn enter_b_binder<T, R>(
509 &mut self,
510 t: &rty::Binder<T>,
511 f: impl FnOnce(&mut Self, &T) -> R,
512 ) -> R {
513 self.b_binder_to_a_binder.push(None);
514 let r = f(self, t.skip_binder_ref());
515 self.b_binder_to_a_binder.pop();
516 r
517 }
518
519 fn adjust_bvars<T: TypeFoldable + Clone + std::fmt::Debug>(&self, t: &T) -> T {
520 struct Adjuster<'a, 'genv, 'tcx> {
521 current_index: DebruijnIndex,
522 zipper: &'a Zipper<'genv, 'tcx>,
523 }
524
525 impl Adjuster<'_, '_, '_> {
526 fn adjust(&self, debruijn: DebruijnIndex) -> DebruijnIndex {
527 let b_binders = self.zipper.b_binder_to_a_binder.len();
528 let mapped_binder = self.zipper.b_binder_to_a_binder
529 [b_binders - debruijn.as_usize() - 1]
530 .unwrap_or_else(|| {
531 bug!("bound var without corresponding binder: `{debruijn:?}`")
532 });
533 DebruijnIndex::from_u32(self.zipper.a_binders - mapped_binder - 1)
534 .shifted_in(self.current_index.as_u32())
535 }
536 }
537
538 impl TypeFolder for Adjuster<'_, '_, '_> {
539 fn fold_binder<T>(&mut self, t: &rty::Binder<T>) -> rty::Binder<T>
540 where
541 T: TypeFoldable,
542 {
543 self.current_index.shift_in(1);
544 let r = t.super_fold_with(self);
545 self.current_index.shift_out(1);
546 r
547 }
548
549 fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
550 if let rty::ReBound(debruijn, br) = *re
551 && debruijn >= self.current_index
552 {
553 rty::ReBound(self.adjust(debruijn), br)
554 } else {
555 *re
556 }
557 }
558
559 fn fold_expr(&mut self, expr: &rty::Expr) -> rty::Expr {
560 if let rty::ExprKind::Var(rty::Var::Bound(debruijn, breft)) = expr.kind()
561 && *debruijn >= self.current_index
562 {
563 rty::Expr::bvar(self.adjust(*debruijn), breft.var, breft.kind)
564 } else {
565 expr.super_fold_with(self)
566 }
567 }
568 }
569 t.fold_with(&mut Adjuster { current_index: INNERMOST, zipper: self })
570 }
571
572 fn emit_fn_sig_err(&mut self, err: FnSigErr, decl: &fhir::FnDecl) {
573 match err {
574 FnSigErr::ArgCountMismatch => {
575 self.errors.emit(errors::IncompatibleParamCount::new(
576 self.genv,
577 decl,
578 self.owner_id,
579 ));
580 }
581 FnSigErr::FnInput(i) => {
582 self.errors.emit(errors::IncompatibleRefinement::fn_input(
583 self.genv,
584 self.owner_id,
585 decl,
586 i,
587 ));
588 }
589 FnSigErr::FnOutput(_) => {
590 self.errors.emit(errors::IncompatibleRefinement::fn_output(
591 self.genv,
592 self.owner_id,
593 decl,
594 ));
595 }
596 FnSigErr::Ensures { i, expected } => {
597 self.errors.emit(errors::IncompatibleRefinement::ensures(
598 self.genv,
599 self.owner_id,
600 decl,
601 &expected,
602 i,
603 ));
604 }
605 }
606 }
607}
608
609fn assert_eq_or_incompatible<T: Eq + fmt::Debug>(a: T, b: T) -> Result<(), Mismatch> {
610 if a != b {
611 return Err(Mismatch::new(a, b));
612 }
613 Ok(())
614}
615
/// Structural mismatch found while zipping. The wrapped string holds the debug rendering of
/// the two sides and is never read by the program; it exists only to be inspected while
/// debugging.
#[expect(dead_code, reason = "we use the String for debugging")]
struct Mismatch(String);

impl Mismatch {
    /// Builds a mismatch whose message is `"{a:?} != {b:?}"`.
    fn new<T: fmt::Debug>(a: T, b: T) -> Self {
        Self(format!("{a:?} != {b:?}"))
    }
}
624
625enum FnSigErr {
626 ArgCountMismatch,
627 FnInput(usize),
628 #[expect(dead_code, reason = "we use the struct for debugging")]
629 FnOutput(Mismatch),
630 Ensures {
631 i: usize,
632 expected: rty::Ty,
633 },
634}
635
636mod errors {
637 use flux_common::span_bug;
638 use flux_errors::E0999;
639 use flux_macros::Diagnostic;
640 use flux_middle::{def_id::MaybeExternId, fhir, global_env::GlobalEnv, rty};
641 use flux_rustc_bridge::{
642 ToRustc,
643 ty::{FieldIdx, VariantIdx},
644 };
645 use rustc_span::{DUMMY_SP, Span};
646
647 #[derive(Diagnostic)]
648 #[diag(fhir_analysis_incompatible_refinement, code = E0999)]
649 #[note]
650 pub(super) struct IncompatibleRefinement<'tcx> {
651 #[primary_span]
652 #[label]
653 span: Span,
654 #[label(fhir_analysis_expected_label)]
655 expected_span: Option<Span>,
656 expected_ty: rustc_middle::ty::Ty<'tcx>,
657 def_descr: &'static str,
658 }
659
660 impl<'tcx> IncompatibleRefinement<'tcx> {
661 pub(super) fn type_alias(
662 genv: GlobalEnv<'_, 'tcx>,
663 def_id: MaybeExternId,
664 type_alias: &fhir::TyAlias,
665 ) -> Self {
666 let tcx = genv.tcx();
667 Self {
668 span: type_alias.ty.span,
669 def_descr: tcx.def_descr(def_id.resolved_id()),
670 expected_span: Some(tcx.def_span(def_id)),
671 expected_ty: tcx.type_of(def_id).skip_binder(),
672 }
673 }
674
675 pub(super) fn fn_input(
676 genv: GlobalEnv<'_, 'tcx>,
677 fn_id: MaybeExternId,
678 decl: &fhir::FnDecl,
679 pos: usize,
680 ) -> Self {
681 let expected_span = match fn_id {
682 MaybeExternId::Local(local_id) => {
683 genv.tcx()
684 .hir_node_by_def_id(local_id)
685 .fn_decl()
686 .and_then(|fn_decl| fn_decl.inputs.get(pos))
687 .map(|input| input.span)
688 }
689 MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
690 };
691
692 let expected_ty = genv
693 .tcx()
694 .fn_sig(fn_id.resolved_id())
695 .skip_binder()
696 .inputs()
697 .map_bound(|inputs| inputs[pos])
698 .skip_binder();
699
700 Self {
701 span: decl.inputs[pos].span,
702 def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
703 expected_span,
704 expected_ty,
705 }
706 }
707
708 pub(super) fn fn_output(
709 genv: GlobalEnv<'_, 'tcx>,
710 fn_id: MaybeExternId,
711 decl: &fhir::FnDecl,
712 ) -> Self {
713 let expected_span = match fn_id {
714 MaybeExternId::Local(local_id) => {
715 genv.tcx()
716 .hir_node_by_def_id(local_id)
717 .fn_decl()
718 .map(|fn_decl| fn_decl.output.span())
719 }
720 MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
721 };
722
723 let expected_ty = genv
724 .tcx()
725 .fn_sig(fn_id.resolved_id())
726 .skip_binder()
727 .output()
728 .skip_binder();
729 let spec_span = decl.output.ret.span;
730 Self {
731 span: spec_span,
732 def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
733 expected_span,
734 expected_ty,
735 }
736 }
737
738 pub(super) fn ensures(
739 genv: GlobalEnv<'_, 'tcx>,
740 fn_id: MaybeExternId,
741 decl: &fhir::FnDecl,
742 expected: &rty::Ty,
743 i: usize,
744 ) -> Self {
745 let fhir::Ensures::Type(_, ty) = &decl.output.ensures[i] else {
746 span_bug!(decl.span, "expected `fhir::Ensures::Type`");
747 };
748 let tcx = genv.tcx();
749 Self {
750 span: ty.span,
751 def_descr: tcx.def_descr(fn_id.resolved_id()),
752 expected_span: None,
753 expected_ty: expected.to_rustc(tcx),
754 }
755 }
756
757 pub(super) fn field(
758 genv: GlobalEnv<'_, 'tcx>,
759 adt_id: MaybeExternId,
760 variant_idx: VariantIdx,
761 field_idx: FieldIdx,
762 ) -> Self {
763 let tcx = genv.tcx();
764 let adt_def = tcx.adt_def(adt_id);
765 let field_def = &adt_def.variant(variant_idx).fields[field_idx];
766
767 let item = genv.map().expect_item(adt_id.local_id()).unwrap();
768 let span = match &item.kind {
769 fhir::ItemKind::Enum(enum_def) => {
770 enum_def.variants[variant_idx.as_usize()].fields[field_idx.as_usize()]
771 .ty
772 .span
773 }
774 fhir::ItemKind::Struct(struct_def)
775 if let fhir::StructKind::Transparent { fields } = &struct_def.kind =>
776 {
777 fields[field_idx.as_usize()].ty.span
778 }
779 _ => DUMMY_SP,
780 };
781
782 Self {
783 span,
784 def_descr: tcx.def_descr(field_def.did),
785 expected_span: Some(tcx.def_span(field_def.did)),
786 expected_ty: tcx.type_of(field_def.did).skip_binder(),
787 }
788 }
789 }
790
791 #[derive(Diagnostic)]
792 #[diag(fhir_analysis_incompatible_param_count, code = E0999)]
793 pub(super) struct IncompatibleParamCount {
794 #[primary_span]
795 #[label]
796 span: Span,
797 found: usize,
798 #[label(fhir_analysis_expected_label)]
799 expected_span: Span,
800 expected: usize,
801 def_descr: &'static str,
802 }
803
804 impl IncompatibleParamCount {
805 pub(super) fn new(genv: GlobalEnv, decl: &fhir::FnDecl, def_id: MaybeExternId) -> Self {
806 let def_descr = genv.tcx().def_descr(def_id.resolved_id());
807
808 let span = if !decl.inputs.is_empty() {
809 decl.inputs[decl.inputs.len() - 1]
810 .span
811 .with_lo(decl.inputs[0].span.lo())
812 } else {
813 decl.span
814 };
815
816 let expected_span = if let Some(local_id) = def_id.as_local()
817 && let expected_decl = genv.tcx().hir_node_by_def_id(local_id).fn_decl().unwrap()
818 && !expected_decl.inputs.is_empty()
819 {
820 expected_decl.inputs[expected_decl.inputs.len() - 1]
821 .span
822 .with_lo(expected_decl.inputs[0].span.lo())
823 } else {
824 genv.tcx().def_span(def_id)
825 };
826
827 let expected = genv
828 .tcx()
829 .fn_sig(def_id)
830 .skip_binder()
831 .skip_binder()
832 .inputs()
833 .len();
834
835 Self { span, found: decl.inputs.len(), expected_span, expected, def_descr }
836 }
837 }
838
839 #[derive(Diagnostic)]
840 #[diag(fhir_analysis_field_count_mismatch, code = E0999)]
841 pub(super) struct FieldCountMismatch {
842 #[primary_span]
843 #[label]
844 span: Span,
845 fields: usize,
846 #[label(fhir_analysis_expected_label)]
847 expected_span: Span,
848 expected_fields: usize,
849 }
850
851 impl FieldCountMismatch {
852 pub(super) fn new(
853 genv: GlobalEnv,
854 found: usize,
855 adt_def_id: MaybeExternId,
856 variant_idx: VariantIdx,
857 ) -> Self {
858 let adt_def = genv.tcx().adt_def(adt_def_id);
859 let expected_variant = adt_def.variant(variant_idx);
860
861 let span = if let Ok(fhir::Node::Item(item)) = genv.map().node(adt_def_id.local_id())
864 && let fhir::ItemKind::Enum(enum_def) = &item.kind
865 && let Some(variant) = enum_def.variants.get(variant_idx.as_usize())
866 {
867 variant.span
868 } else {
869 DUMMY_SP
870 };
871
872 Self {
873 span,
874 fields: found,
875 expected_span: genv.tcx().def_span(expected_variant.def_id),
876 expected_fields: expected_variant.fields.len(),
877 }
878 }
879 }
880}