1use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_errors::Errors;
10use flux_middle::{
11 def_id::MaybeExternId,
12 fhir,
13 global_env::GlobalEnv,
14 queries::QueryResult,
15 rty::{
16 self,
17 fold::{TypeFoldable, TypeFolder, TypeSuperFoldable},
18 refining::{Refine as _, Refiner},
19 },
20};
21use flux_rustc_bridge::ty::{self, FieldIdx, VariantIdx};
22use rustc_ast::Mutability;
23use rustc_data_structures::unord::UnordMap;
24use rustc_type_ir::{DebruijnIndex, INNERMOST, InferConst};
25
26pub(crate) fn type_alias(
27 genv: GlobalEnv,
28 alias: &fhir::TyAlias,
29 alias_ty: &rty::TyCtor,
30 def_id: MaybeExternId,
31) -> QueryResult<rty::TyCtor> {
32 let rust_ty = genv.lower_type_of(def_id.resolved_id())?.skip_binder();
33 let expected = rust_ty.refine(&Refiner::default_for_item(genv, def_id.resolved_id())?)?;
34 let mut zipper = Zipper::new(genv, def_id);
35
36 if zipper
37 .enter_a_binder(alias_ty, |zipper, ty| zipper.zip_ty(ty, &expected))
38 .is_err()
39 {
40 zipper
41 .errors
42 .emit(errors::IncompatibleRefinement::type_alias(genv, def_id, alias));
43 }
44
45 zipper.errors.to_result()?;
46
47 Ok(zipper.holes.replace_holes(alias_ty))
48}
49
50pub(crate) fn fn_sig(
51 genv: GlobalEnv,
52 decl: &fhir::FnDecl,
53 fn_sig: &rty::PolyFnSig,
54 def_id: MaybeExternId,
55) -> QueryResult<rty::PolyFnSig> {
56 let rust_fn_sig = genv.lower_fn_sig(def_id.resolved_id())?.skip_binder();
57
58 let expected = Refiner::default_for_item(genv, def_id.resolved_id())?.refine(&rust_fn_sig)?;
59
60 let mut zipper = Zipper::new(genv, def_id);
61 if let Err(err) = zipper.zip_poly_fn_sig(fn_sig, &expected) {
62 zipper.emit_fn_sig_err(err, decl);
63 }
64
65 zipper.errors.to_result()?;
66
67 Ok(zipper.holes.replace_holes(fn_sig))
68}
69
70pub(crate) fn variants(
71 genv: GlobalEnv,
72 variants: &[rty::PolyVariant],
73 adt_def_id: MaybeExternId,
74) -> QueryResult<Vec<rty::PolyVariant>> {
75 let refiner = Refiner::default_for_item(genv, adt_def_id.resolved_id())?;
76 let mut zipper = Zipper::new(genv, adt_def_id);
77 for (i, variant) in variants.iter().enumerate() {
79 let variant_idx = VariantIdx::from_usize(i);
80 let expected = refiner.refine_variant_def(adt_def_id.resolved_id(), variant_idx)?;
81 zipper.zip_variant(variant, &expected, variant_idx);
82 }
83
84 zipper.errors.to_result()?;
85
86 Ok(variants
87 .iter()
88 .map(|v| zipper.holes.replace_holes(v))
89 .collect())
90}
91
/// Walks a Flux-annotated value (`a`) in lockstep with the default refinement of the
/// corresponding Rust value (`b`). It checks that the two are structurally compatible and
/// records how each hole in `a` must be filled.
struct Zipper<'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    /// Item being checked; used when building diagnostics.
    owner_id: MaybeExternId,
    /// For each strong reference in `a`, the pointee type of the matching `&mut` in `b`.
    /// Keyed by the location of the strong reference and consulted when zipping
    /// `ensures` clauses.
    locs: UnordMap<rty::Loc, rty::Ty>,
    /// Solutions found so far for the holes in `a`.
    holes: Holes,
    /// Number of binders currently entered on the `a` side.
    a_binders: u32,
    /// One entry per binder currently entered on the `b` side, innermost last.
    /// `Some(level)` is the `a_binders` level of the `a` binder that was entered together
    /// with it. `None` marks a `b` binder with no counterpart in `a`.
    b_binder_to_a_binder: Vec<Option<u32>>,
    /// Diagnostics reported while zipping.
    errors: Errors<'genv>,
}
105
/// Solutions for the holes (inference variables) left in a Flux annotation.
///
/// Each map is keyed by the variable that marks the hole. The maps are filled by `Zipper`
/// and applied with `Holes::replace_holes`.
#[derive(Default)]
struct Holes {
    /// Solutions for sort holes (`Sort::Infer`).
    sorts: UnordMap<rty::SortVid, rty::Sort>,
    /// Solutions for base-type holes inside subset types (`BaseTy::Infer`).
    subset_tys: UnordMap<rty::TyVid, rty::SubsetTy>,
    /// Solutions for type holes (`TyKind::Infer`).
    types: UnordMap<rty::TyVid, rty::Ty>,
    /// Solutions for region holes (`Region::ReVar`).
    regions: UnordMap<rty::RegionVid, rty::Region>,
    /// Solutions for const holes (`ConstKind::Infer(InferConst::Var(_))`).
    consts: UnordMap<rty::ConstVid, rty::Const>,
}
114
115impl TypeFolder for &Holes {
116 fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
117 if let rty::Sort::Infer(vid) = sort {
118 self.sorts
119 .get(vid)
120 .cloned()
121 .unwrap_or_else(|| bug!("unfilled sort hole {vid:?}"))
122 } else {
123 sort.super_fold_with(self)
124 }
125 }
126
127 fn fold_ty(&mut self, ty: &rty::Ty) -> rty::Ty {
128 if let rty::TyKind::Infer(vid) = ty.kind() {
129 self.types
130 .get(vid)
131 .cloned()
132 .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
133 } else {
134 ty.super_fold_with(self)
135 }
136 }
137
138 fn fold_subset_ty(&mut self, constr: &rty::SubsetTy) -> rty::SubsetTy {
139 if let rty::BaseTy::Infer(vid) = &constr.bty {
140 self.subset_tys
141 .get(vid)
142 .cloned()
143 .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
144 } else {
145 constr.super_fold_with(self)
146 }
147 }
148
149 fn fold_region(&mut self, r: &rty::Region) -> rty::Region {
150 if let rty::Region::ReVar(vid) = r {
151 self.regions
152 .get(vid)
153 .copied()
154 .unwrap_or_else(|| bug!("unfilled region hole {vid:?}"))
155 } else {
156 *r
157 }
158 }
159
160 fn fold_const(&mut self, ct: &rty::Const) -> rty::Const {
161 if let rty::ConstKind::Infer(InferConst::Var(cid)) = ct.kind {
162 self.consts
163 .get(&cid)
164 .cloned()
165 .unwrap_or_else(|| bug!("unfilled const hole {cid:?}"))
166 } else {
167 ct.super_fold_with(self)
168 }
169 }
170}
171
172impl Holes {
173 fn replace_holes<T: TypeFoldable>(&self, t: &T) -> T {
174 let mut this = self;
175 t.fold_with(&mut this)
176 }
177}
178
179impl<'genv, 'tcx> Zipper<'genv, 'tcx> {
180 fn new(genv: GlobalEnv<'genv, 'tcx>, owner_id: MaybeExternId) -> Self {
181 Self {
182 genv,
183 owner_id,
184 locs: UnordMap::default(),
185 holes: Default::default(),
186 a_binders: 0,
187 b_binder_to_a_binder: vec![],
188 errors: Errors::new(genv.sess()),
189 }
190 }
191
192 fn zip_poly_fn_sig(&mut self, a: &rty::PolyFnSig, b: &rty::PolyFnSig) -> Result<(), FnSigErr> {
193 self.enter_binders(a, b, |this, a, b| this.zip_fn_sig(a, b))
194 }
195
196 fn zip_variant(&mut self, a: &rty::PolyVariant, b: &rty::PolyVariant, variant_idx: VariantIdx) {
197 self.enter_binders(a, b, |this, a, b| {
198 debug_assert_eq!(a.args, b.args);
200
201 if a.fields.len() != b.fields.len() {
202 this.errors.emit(errors::FieldCountMismatch::new(
203 this.genv,
204 a.fields.len(),
205 this.owner_id,
206 variant_idx,
207 ));
208 return;
209 }
210 for (i, (ty_a, ty_b)) in iter::zip(&a.fields, &b.fields).enumerate() {
211 let field_idx = FieldIdx::from_usize(i);
212 if this.zip_ty(ty_a, ty_b).is_err() {
213 this.errors.emit(errors::IncompatibleRefinement::field(
214 this.genv,
215 this.owner_id,
216 variant_idx,
217 field_idx,
218 ));
219 }
220 }
221 });
222 }
223
224 fn zip_fn_sig(&mut self, a: &rty::FnSig, b: &rty::FnSig) -> Result<(), FnSigErr> {
225 if a.inputs().len() != b.inputs().len() {
226 Err(FnSigErr::ArgCountMismatch)?;
227 }
228 for (i, (ty_a, ty_b)) in iter::zip(a.inputs(), b.inputs()).enumerate() {
229 self.zip_ty(ty_a, ty_b).map_err(|_| FnSigErr::FnInput(i))?;
230 }
231 self.enter_binders(&a.output, &b.output, |this, output_a, output_b| {
232 this.zip_output(output_a, output_b)
233 })
234 }
235
236 fn zip_output(&mut self, a: &rty::FnOutput, b: &rty::FnOutput) -> Result<(), FnSigErr> {
237 self.zip_ty(&a.ret, &b.ret).map_err(FnSigErr::FnOutput)?;
238
239 for (i, ensures) in a.ensures.iter().enumerate() {
240 if let rty::Ensures::Type(path, ty_a) = ensures {
241 let loc = path.to_loc().unwrap();
242 let ty_b = self.locs.get(&loc).unwrap().shift_in_escaping(1);
243 self.zip_ty(ty_a, &ty_b)
244 .map_err(|_| FnSigErr::Ensures { i, expected: ty_b })?;
245 }
246 }
247 Ok(())
248 }
249
250 fn zip_ty(&mut self, a: &rty::Ty, b: &rty::Ty) -> Result<(), Mismatch> {
251 match (a.kind(), b.kind()) {
252 (rty::TyKind::Infer(vid), _) => {
253 assert_ne!(vid.as_u32(), 0);
254 let b = self.adjust_bvars(b);
255 self.holes.types.insert(*vid, b);
256 Ok(())
257 }
258 (rty::TyKind::Exists(ctor_a), _) => {
259 self.enter_a_binder(ctor_a, |this, ty_a| this.zip_ty(ty_a, b))
260 }
261 (_, rty::TyKind::Exists(ctor_b)) => {
262 self.enter_b_binder(ctor_b, |this, ty_b| this.zip_ty(a, ty_b))
263 }
264 (rty::TyKind::Constr(_, ty_a), _) => self.zip_ty(ty_a, b),
265 (_, rty::TyKind::Constr(_, ty_b)) => self.zip_ty(a, ty_b),
266 (rty::TyKind::Indexed(bty_a, _), rty::TyKind::Indexed(bty_b, _)) => {
267 self.zip_bty(bty_a, bty_b)
268 }
269 (rty::TyKind::StrgRef(re_a, path, ty_a), rty::Ref!(re_b, ty_b, Mutability::Mut)) => {
270 let loc = path.to_loc().unwrap();
271 self.locs.insert(loc, ty_b.clone());
272
273 self.zip_region(re_a, re_b);
274 self.zip_ty(ty_a, ty_b)
275 }
276 (rty::TyKind::Param(pty_a), rty::TyKind::Param(pty_b)) => {
277 assert_eq_or_incompatible(pty_a, pty_b)
278 }
279 (
280 rty::TyKind::Ptr(_, _)
281 | rty::TyKind::Discr(..)
282 | rty::TyKind::Downcast(_, _, _, _, _)
283 | rty::TyKind::Blocked(_)
284 | rty::TyKind::Uninit,
285 _,
286 ) => {
287 bug!("unexpected type {a:?}");
288 }
289 _ => Err(Mismatch::new(a, b)),
290 }
291 }
292
293 fn zip_bty(&mut self, a: &rty::BaseTy, b: &rty::BaseTy) -> Result<(), Mismatch> {
294 match (a, b) {
295 (rty::BaseTy::Int(ity_a), rty::BaseTy::Int(ity_b)) => {
296 assert_eq_or_incompatible(ity_a, ity_b)
297 }
298 (rty::BaseTy::Uint(uity_a), rty::BaseTy::Uint(uity_b)) => {
299 assert_eq_or_incompatible(uity_a, uity_b)
300 }
301 (rty::BaseTy::Bool, rty::BaseTy::Bool) => Ok(()),
302 (rty::BaseTy::Str, rty::BaseTy::Str) => Ok(()),
303 (rty::BaseTy::Char, rty::BaseTy::Char) => Ok(()),
304 (rty::BaseTy::Float(fty_a), rty::BaseTy::Float(fty_b)) => {
305 assert_eq_or_incompatible(fty_a, fty_b)
306 }
307 (rty::BaseTy::Slice(ty_a), rty::BaseTy::Slice(ty_b)) => self.zip_ty(ty_a, ty_b),
308 (rty::BaseTy::Adt(adt_def_a, args_a), rty::BaseTy::Adt(adt_def_b, args_b)) => {
309 assert_eq_or_incompatible(adt_def_a.did(), adt_def_b.did())?;
310 assert_eq_or_incompatible(args_a.len(), args_b.len())?;
311 for (arg_a, arg_b) in iter::zip(args_a, args_b) {
312 self.zip_generic_arg(arg_a, arg_b)?;
313 }
314 Ok(())
315 }
316 (rty::BaseTy::RawPtr(ty_a, mutbl_a), rty::BaseTy::RawPtr(ty_b, mutbl_b)) => {
317 assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
318 self.zip_ty(ty_a, ty_b)
319 }
320 (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), rty::BaseTy::Ref(re_b, ty_b, mutbl_b)) => {
321 assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
322 self.zip_region(re_a, re_b);
323 self.zip_ty(ty_a, ty_b)
324 }
325 (rty::BaseTy::FnPtr(poly_sig_a), rty::BaseTy::FnPtr(poly_sig_b)) => {
326 self.zip_poly_fn_sig(poly_sig_a, poly_sig_b)
327 .map_err(|_| Mismatch::new(poly_sig_a, poly_sig_b))
328 }
329 (rty::BaseTy::Tuple(tys_a), rty::BaseTy::Tuple(tys_b)) => {
330 assert_eq_or_incompatible(tys_a.len(), tys_b.len())?;
331 for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
332 self.zip_ty(ty_a, ty_b)?;
333 }
334 Ok(())
335 }
336 (rty::BaseTy::Alias(kind_a, aty_a), rty::BaseTy::Alias(kind_b, aty_b)) => {
337 assert_eq_or_incompatible(kind_a, kind_b)?;
338 assert_eq_or_incompatible(aty_a.def_id, aty_b.def_id)?;
339 assert_eq_or_incompatible(aty_a.args.len(), aty_b.args.len())?;
340 for (arg_a, arg_b) in iter::zip(&aty_a.args, &aty_b.args) {
341 self.zip_generic_arg(arg_a, arg_b)?;
342 }
343 Ok(())
344 }
345 (rty::BaseTy::Array(ty_a, len_a), rty::BaseTy::Array(ty_b, len_b)) => {
346 self.zip_const(len_a, len_b)?;
347 self.zip_ty(ty_a, ty_b)
348 }
349 (rty::BaseTy::Never, rty::BaseTy::Never) => Ok(()),
350 (rty::BaseTy::Param(pty_a), rty::BaseTy::Param(pty_b)) => {
351 assert_eq_or_incompatible(pty_a, pty_b)
352 }
353 (rty::BaseTy::Dynamic(preds_a, re_a), rty::BaseTy::Dynamic(preds_b, re_b)) => {
354 assert_eq_or_incompatible(preds_a.len(), preds_b.len())?;
355 for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
356 self.zip_poly_existential_pred(pred_a, pred_b)?;
357 }
358 self.zip_region(re_a, re_b);
359 Ok(())
360 }
361 (rty::BaseTy::Foreign(def_id_a), rty::BaseTy::Foreign(def_id_b)) => {
362 assert_eq_or_incompatible(def_id_a, def_id_b)
363 }
364 (rty::BaseTy::Closure(..) | rty::BaseTy::Coroutine(..), _) => {
365 bug!("unexpected type `{a:?}`");
366 }
367 _ => Err(Mismatch::new(a, b)),
368 }
369 }
370
371 fn zip_generic_arg(
372 &mut self,
373 a: &rty::GenericArg,
374 b: &rty::GenericArg,
375 ) -> Result<(), Mismatch> {
376 match (a, b) {
377 (rty::GenericArg::Ty(ty_a), rty::GenericArg::Ty(ty_b)) => self.zip_ty(ty_a, ty_b),
378 (rty::GenericArg::Base(ctor_a), rty::GenericArg::Base(ctor_b)) => {
379 self.zip_sorts(&ctor_a.sort(), &ctor_b.sort());
380 self.enter_binders(ctor_a, ctor_b, |this, sty_a, sty_b| {
381 this.zip_subset_ty(sty_a, sty_b)
382 })
383 }
384 (rty::GenericArg::Lifetime(re_a), rty::GenericArg::Lifetime(re_b)) => {
385 self.zip_region(re_a, re_b);
386 Ok(())
387 }
388 (rty::GenericArg::Const(ct_a), rty::GenericArg::Const(ct_b)) => {
389 self.zip_const(ct_a, ct_b)
390 }
391 _ => Err(Mismatch::new(a, b)),
392 }
393 }
394
395 fn zip_sorts(&mut self, a: &rty::Sort, b: &rty::Sort) {
396 if let rty::Sort::Infer(vid) = a {
397 assert_ne!(vid.as_u32(), 0);
398 self.holes.sorts.insert(*vid, b.clone());
399 }
400 }
401
402 fn zip_subset_ty(&mut self, a: &rty::SubsetTy, b: &rty::SubsetTy) -> Result<(), Mismatch> {
403 if let rty::BaseTy::Infer(vid) = a.bty {
404 assert_ne!(vid.as_u32(), 0);
405 let b = self.adjust_bvars(b);
406 self.holes.subset_tys.insert(vid, b);
407 Ok(())
408 } else {
409 self.zip_bty(&a.bty, &b.bty)
410 }
411 }
412
413 fn zip_const(&mut self, a: &rty::Const, b: &ty::Const) -> Result<(), Mismatch> {
414 match (&a.kind, &b.kind) {
415 (rty::ConstKind::Infer(ty::InferConst::Var(cid)), _) => {
416 self.holes.consts.insert(*cid, b.clone());
417 Ok(())
418 }
419 (rty::ConstKind::Param(param_const_a), ty::ConstKind::Param(param_const_b)) => {
420 assert_eq_or_incompatible(param_const_a, param_const_b)
421 }
422 (rty::ConstKind::Value(ty_a, val_a), ty::ConstKind::Value(ty_b, val_b)) => {
423 assert_eq_or_incompatible(ty_a, ty_b)?;
424 assert_eq_or_incompatible(val_a, val_b)
425 }
426 (rty::ConstKind::Unevaluated(c1), ty::ConstKind::Unevaluated(c2)) => {
427 assert_eq_or_incompatible(c1, c2)
428 }
429 _ => Err(Mismatch::new(a, b)),
430 }
431 }
432
433 fn zip_region(&mut self, a: &rty::Region, b: &ty::Region) {
434 if let rty::Region::ReVar(vid) = a {
435 let re = self.adjust_bvars(b);
436 self.holes.regions.insert(*vid, re);
437 }
438 }
439
440 fn zip_poly_existential_pred(
441 &mut self,
442 a: &rty::Binder<rty::ExistentialPredicate>,
443 b: &rty::Binder<rty::ExistentialPredicate>,
444 ) -> Result<(), Mismatch> {
445 self.enter_binders(a, b, |this, a, b| {
446 match (a, b) {
447 (
448 rty::ExistentialPredicate::Trait(trait_ref_a),
449 rty::ExistentialPredicate::Trait(trait_ref_b),
450 ) => {
451 assert_eq_or_incompatible(trait_ref_a.def_id, trait_ref_b.def_id)?;
452 assert_eq_or_incompatible(trait_ref_a.args.len(), trait_ref_b.args.len())?;
453 for (arg_a, arg_b) in iter::zip(&trait_ref_a.args, &trait_ref_b.args) {
454 this.zip_generic_arg(arg_a, arg_b)?;
455 }
456 Ok(())
457 }
458 (
459 rty::ExistentialPredicate::Projection(projection_a),
460 rty::ExistentialPredicate::Projection(projection_b),
461 ) => {
462 assert_eq_or_incompatible(projection_a.def_id, projection_b.def_id)?;
463 assert_eq_or_incompatible(projection_a.args.len(), projection_b.args.len())?;
464 for (arg_a, arg_b) in iter::zip(&projection_a.args, &projection_b.args) {
465 this.zip_generic_arg(arg_a, arg_b)?;
466 }
467 this.enter_binders(&projection_a.term, &projection_b.term, |this, a, b| {
468 this.zip_bty(&a.bty, &b.bty)
469 })
470 }
471 (
472 rty::ExistentialPredicate::AutoTrait(def_id_a),
473 rty::ExistentialPredicate::AutoTrait(def_id_b),
474 ) => assert_eq_or_incompatible(def_id_a, def_id_b),
475 _ => Err(Mismatch::new(a, b)),
476 }
477 })
478 }
479
480 fn enter_binders<T, R>(
482 &mut self,
483 a: &rty::Binder<T>,
484 b: &rty::Binder<T>,
485 f: impl FnOnce(&mut Self, &T, &T) -> R,
486 ) -> R {
487 self.b_binder_to_a_binder.push(Some(self.a_binders));
488 self.a_binders += 1;
489 let r = f(self, a.skip_binder_ref(), b.skip_binder_ref());
490 self.a_binders -= 1;
491 self.b_binder_to_a_binder.pop();
492 r
493 }
494
495 fn enter_a_binder<T, R>(
497 &mut self,
498 t: &rty::Binder<T>,
499 f: impl FnOnce(&mut Self, &T) -> R,
500 ) -> R {
501 self.a_binders += 1;
502 let r = f(self, t.skip_binder_ref());
503 self.a_binders -= 1;
504 r
505 }
506
507 fn enter_b_binder<T, R>(
509 &mut self,
510 t: &rty::Binder<T>,
511 f: impl FnOnce(&mut Self, &T) -> R,
512 ) -> R {
513 self.b_binder_to_a_binder.push(None);
514 let r = f(self, t.skip_binder_ref());
515 self.b_binder_to_a_binder.pop();
516 r
517 }
518
519 fn adjust_bvars<T: TypeFoldable + Clone + std::fmt::Debug>(&self, t: &T) -> T {
520 struct Adjuster<'a, 'genv, 'tcx> {
521 current_index: DebruijnIndex,
522 zipper: &'a Zipper<'genv, 'tcx>,
523 }
524
525 impl Adjuster<'_, '_, '_> {
526 fn adjust(&self, debruijn: DebruijnIndex) -> DebruijnIndex {
527 let b_binders = self.zipper.b_binder_to_a_binder.len();
528 let mapped_binder = self.zipper.b_binder_to_a_binder
529 [b_binders - debruijn.as_usize() - 1]
530 .unwrap_or_else(|| {
531 bug!("bound var without corresponding binder: `{debruijn:?}`")
532 });
533 DebruijnIndex::from_u32(self.zipper.a_binders - mapped_binder - 1)
534 .shifted_in(self.current_index.as_u32())
535 }
536 }
537
538 impl TypeFolder for Adjuster<'_, '_, '_> {
539 fn enter_binder(&mut self, _: &rty::BoundVariableKinds) {
540 self.current_index.shift_in(1);
541 }
542
543 fn exit_binder(&mut self) {
544 self.current_index.shift_out(1);
545 }
546
547 fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
548 if let rty::ReBound(debruijn, br) = *re
549 && debruijn >= self.current_index
550 {
551 rty::ReBound(self.adjust(debruijn), br)
552 } else {
553 *re
554 }
555 }
556
557 fn fold_expr(&mut self, expr: &rty::Expr) -> rty::Expr {
558 if let rty::ExprKind::Var(rty::Var::Bound(debruijn, breft)) = expr.kind()
559 && *debruijn >= self.current_index
560 {
561 rty::Expr::bvar(self.adjust(*debruijn), breft.var, breft.kind)
562 } else {
563 expr.super_fold_with(self)
564 }
565 }
566 }
567 t.fold_with(&mut Adjuster { current_index: INNERMOST, zipper: self })
568 }
569
570 fn emit_fn_sig_err(&mut self, err: FnSigErr, decl: &fhir::FnDecl) {
571 match err {
572 FnSigErr::ArgCountMismatch => {
573 self.errors.emit(errors::IncompatibleParamCount::new(
574 self.genv,
575 decl,
576 self.owner_id,
577 ));
578 }
579 FnSigErr::FnInput(i) => {
580 self.errors.emit(errors::IncompatibleRefinement::fn_input(
581 self.genv,
582 self.owner_id,
583 decl,
584 i,
585 ));
586 }
587 FnSigErr::FnOutput(_) => {
588 self.errors.emit(errors::IncompatibleRefinement::fn_output(
589 self.genv,
590 self.owner_id,
591 decl,
592 ));
593 }
594 FnSigErr::Ensures { i, expected } => {
595 self.errors.emit(errors::IncompatibleRefinement::ensures(
596 self.genv,
597 self.owner_id,
598 decl,
599 &expected,
600 i,
601 ));
602 }
603 }
604 }
605}
606
607fn assert_eq_or_incompatible<T: Eq + fmt::Debug>(a: T, b: T) -> Result<(), Mismatch> {
608 if a != b {
609 return Err(Mismatch::new(a, b));
610 }
611 Ok(())
612}
613
/// Two values that were expected to be equal (or structurally compatible) but were not.
///
/// The payload is the `"{a:?} != {b:?}"` rendering of both sides. It is only inspected
/// while debugging; diagnostics are built separately.
#[expect(dead_code, reason = "we use the String for debugging")]
struct Mismatch(String);

impl Mismatch {
    /// Records a mismatch between `a` and `b` using their `Debug` representations.
    fn new<T: fmt::Debug>(a: T, b: T) -> Self {
        Self(format!("{a:?} != {b:?}"))
    }
}
622
/// Reason why zipping a function signature failed. Turned into a diagnostic by
/// `Zipper::emit_fn_sig_err`.
enum FnSigErr {
    /// The Flux and Rust signatures have a different number of inputs.
    ArgCountMismatch,
    /// The input at this position is incompatible.
    FnInput(usize),
    /// The return type is incompatible.
    #[expect(dead_code, reason = "we use the struct for debugging")]
    FnOutput(Mismatch),
    /// The `i`-th `ensures` clause is incompatible with `expected`. `expected` is the
    /// pointee type of the `&mut` reference in the Rust signature that matched the strong
    /// reference for that location.
    Ensures {
        i: usize,
        expected: rty::Ty,
    },
}
633
/// Diagnostics reported when a Flux annotation is not structurally compatible with the
/// Rust item it refines.
mod errors {
    use flux_common::span_bug;
    use flux_errors::E0999;
    use flux_macros::Diagnostic;
    use flux_middle::{def_id::MaybeExternId, fhir, global_env::GlobalEnv, rty};
    use flux_rustc_bridge::{
        ToRustc,
        ty::{FieldIdx, VariantIdx},
    };
    use rustc_span::{DUMMY_SP, Span};

    // A type written in a Flux annotation is incompatible with the corresponding Rust type.
    // Field names are referenced by the Fluent message for the slug below.
    #[derive(Diagnostic)]
    #[diag(fhir_analysis_incompatible_refinement, code = E0999)]
    #[note]
    pub(super) struct IncompatibleRefinement<'tcx> {
        // Span of the offending type in the Flux annotation.
        #[primary_span]
        #[label]
        span: Span,
        // Span of the corresponding Rust declaration, when one can be found.
        #[label(fhir_analysis_expected_label)]
        expected_span: Option<Span>,
        // The Rust type the annotation was expected to refine.
        expected_ty: rustc_middle::ty::Ty<'tcx>,
        // Item-kind description, as returned by `TyCtxt::def_descr`.
        def_descr: &'static str,
    }

    impl<'tcx> IncompatibleRefinement<'tcx> {
        /// Diagnostic for a type alias whose Flux definition does not match its Rust type.
        pub(super) fn type_alias(
            genv: GlobalEnv<'_, 'tcx>,
            def_id: MaybeExternId,
            type_alias: &fhir::TyAlias,
        ) -> Self {
            let tcx = genv.tcx();
            Self {
                span: type_alias.ty.span,
                def_descr: tcx.def_descr(def_id.resolved_id()),
                expected_span: Some(tcx.def_span(def_id)),
                expected_ty: tcx.type_of(def_id).skip_binder(),
            }
        }

        /// Diagnostic for the input at position `pos` of a function signature.
        ///
        /// For local items the expected span is that input in the HIR declaration. For
        /// extern specs it is the span of the extern definition.
        pub(super) fn fn_input(
            genv: GlobalEnv<'_, 'tcx>,
            fn_id: MaybeExternId,
            decl: &fhir::FnDecl,
            pos: usize,
        ) -> Self {
            let expected_span = match fn_id {
                MaybeExternId::Local(local_id) => {
                    genv.tcx()
                        .hir_node_by_def_id(local_id)
                        .fn_decl()
                        .and_then(|fn_decl| fn_decl.inputs.get(pos))
                        .map(|input| input.span)
                }
                MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
            };

            // Indexes the Rust inputs with `pos`; arities were already checked to match.
            let expected_ty = genv
                .tcx()
                .fn_sig(fn_id.resolved_id())
                .skip_binder()
                .inputs()
                .map_bound(|inputs| inputs[pos])
                .skip_binder();

            Self {
                span: decl.inputs[pos].span,
                def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
                expected_span,
                expected_ty,
            }
        }

        /// Diagnostic for an incompatible return type.
        ///
        /// The expected span is the HIR output for local items, or the extern definition
        /// otherwise.
        pub(super) fn fn_output(
            genv: GlobalEnv<'_, 'tcx>,
            fn_id: MaybeExternId,
            decl: &fhir::FnDecl,
        ) -> Self {
            let expected_span = match fn_id {
                MaybeExternId::Local(local_id) => {
                    genv.tcx()
                        .hir_node_by_def_id(local_id)
                        .fn_decl()
                        .map(|fn_decl| fn_decl.output.span())
                }
                MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
            };

            let expected_ty = genv
                .tcx()
                .fn_sig(fn_id.resolved_id())
                .skip_binder()
                .output()
                .skip_binder();
            let spec_span = decl.output.ret.span;
            Self {
                span: spec_span,
                def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
                expected_span,
                expected_ty,
            }
        }

        /// Diagnostic for the `i`-th `ensures` clause of `decl`.
        ///
        /// `expected` is converted to a rustc type for display. No expected span is
        /// reported. Triggers `span_bug!` if that clause is not `fhir::Ensures::Type`.
        pub(super) fn ensures(
            genv: GlobalEnv<'_, 'tcx>,
            fn_id: MaybeExternId,
            decl: &fhir::FnDecl,
            expected: &rty::Ty,
            i: usize,
        ) -> Self {
            let fhir::Ensures::Type(_, ty) = &decl.output.ensures[i] else {
                span_bug!(decl.span, "expected `fhir::Ensures::Type`");
            };
            let tcx = genv.tcx();
            Self {
                span: ty.span,
                def_descr: tcx.def_descr(fn_id.resolved_id()),
                expected_span: None,
                expected_ty: expected.to_rustc(tcx),
            }
        }

        /// Diagnostic for field `field_idx` of variant `variant_idx` of `adt_id`.
        ///
        /// The primary span is the field's type in the Flux item for enums and transparent
        /// structs, and `DUMMY_SP` otherwise. Panics (`unwrap`) if the Flux item cannot be
        /// retrieved.
        pub(super) fn field(
            genv: GlobalEnv<'_, 'tcx>,
            adt_id: MaybeExternId,
            variant_idx: VariantIdx,
            field_idx: FieldIdx,
        ) -> Self {
            let tcx = genv.tcx();
            let adt_def = tcx.adt_def(adt_id);
            let field_def = &adt_def.variant(variant_idx).fields[field_idx];

            let item = genv.fhir_expect_item(adt_id.local_id()).unwrap();
            let span = match &item.kind {
                fhir::ItemKind::Enum(enum_def) => {
                    enum_def.variants[variant_idx.as_usize()].fields[field_idx.as_usize()]
                        .ty
                        .span
                }
                fhir::ItemKind::Struct(struct_def)
                    if let fhir::StructKind::Transparent { fields } = &struct_def.kind =>
                {
                    fields[field_idx.as_usize()].ty.span
                }
                _ => DUMMY_SP,
            };

            Self {
                span,
                def_descr: tcx.def_descr(field_def.did),
                expected_span: Some(tcx.def_span(field_def.did)),
                expected_ty: tcx.type_of(field_def.did).skip_binder(),
            }
        }
    }

    // The Flux signature declares a different number of parameters than the Rust function.
    #[derive(Diagnostic)]
    #[diag(fhir_analysis_incompatible_param_count, code = E0999)]
    pub(super) struct IncompatibleParamCount {
        // Span covering the parameters of the Flux signature.
        #[primary_span]
        #[label]
        span: Span,
        // Number of parameters in the Flux signature.
        found: usize,
        // Span covering the parameters of the Rust signature (or the item span).
        #[label(fhir_analysis_expected_label)]
        expected_span: Span,
        // Number of parameters in the Rust signature.
        expected: usize,
        def_descr: &'static str,
    }

    impl IncompatibleParamCount {
        /// Builds the diagnostic for `decl`, the Flux signature of `def_id`.
        ///
        /// Each span runs from the first to the last parameter. With no parameters, the
        /// Flux side falls back to the whole declaration and the Rust side to the item span.
        /// The Rust side also falls back to the item span for non-local items.
        pub(super) fn new(genv: GlobalEnv, decl: &fhir::FnDecl, def_id: MaybeExternId) -> Self {
            let def_descr = genv.tcx().def_descr(def_id.resolved_id());

            let span = if !decl.inputs.is_empty() {
                decl.inputs[decl.inputs.len() - 1]
                    .span
                    .with_lo(decl.inputs[0].span.lo())
            } else {
                decl.span
            };

            // NOTE(review): panics if a local `def_id` has no HIR fn declaration.
            let expected_span = if let Some(local_id) = def_id.as_local()
                && let expected_decl = genv.tcx().hir_node_by_def_id(local_id).fn_decl().unwrap()
                && !expected_decl.inputs.is_empty()
            {
                expected_decl.inputs[expected_decl.inputs.len() - 1]
                    .span
                    .with_lo(expected_decl.inputs[0].span.lo())
            } else {
                genv.tcx().def_span(def_id)
            };

            let expected = genv
                .tcx()
                .fn_sig(def_id)
                .skip_binder()
                .skip_binder()
                .inputs()
                .len();

            Self { span, found: decl.inputs.len(), expected_span, expected, def_descr }
        }
    }

    // A Flux variant declares a different number of fields than the Rust variant.
    #[derive(Diagnostic)]
    #[diag(fhir_analysis_field_count_mismatch, code = E0999)]
    pub(super) struct FieldCountMismatch {
        // Span of the Flux variant (`DUMMY_SP` when it cannot be located).
        #[primary_span]
        #[label]
        span: Span,
        // Number of fields in the Flux variant.
        fields: usize,
        // Span of the Rust variant definition.
        #[label(fhir_analysis_expected_label)]
        expected_span: Span,
        // Number of fields in the Rust variant.
        expected_fields: usize,
    }

    impl FieldCountMismatch {
        /// Builds the diagnostic for variant `variant_idx` of `adt_def_id`, which has
        /// `found` fields in Flux.
        ///
        /// The primary span is only located for enum variants; otherwise `DUMMY_SP` is used.
        pub(super) fn new(
            genv: GlobalEnv,
            found: usize,
            adt_def_id: MaybeExternId,
            variant_idx: VariantIdx,
        ) -> Self {
            let adt_def = genv.tcx().adt_def(adt_def_id);
            let expected_variant = adt_def.variant(variant_idx);

            let span = if let Ok(fhir::Node::Item(item)) = genv.fhir_node(adt_def_id.local_id())
                && let fhir::ItemKind::Enum(enum_def) = &item.kind
                && let Some(variant) = enum_def.variants.get(variant_idx.as_usize())
            {
                variant.span
            } else {
                DUMMY_SP
            };

            Self {
                span,
                fields: found,
                expected_span: genv.tcx().def_span(expected_variant.def_id),
                expected_fields: expected_variant.fields.len(),
            }
        }
    }
}