flux_fhir_analysis/conv/
struct_compat.rs

//! Check whether two refinement types/signatures are structurally compatible.
//!
//! Used to check if a user spec is compatible with the underlying Rust type. The code also
//! infers types annotated with `_` in the surface syntax.
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_errors::Errors;
10use flux_middle::{
11    def_id::MaybeExternId,
12    fhir,
13    global_env::GlobalEnv,
14    queries::QueryResult,
15    rty::{
16        self,
17        fold::{TypeFoldable, TypeFolder, TypeSuperFoldable},
18        refining::{Refine as _, Refiner},
19    },
20};
21use flux_rustc_bridge::ty::{self, FieldIdx, VariantIdx};
22use rustc_ast::Mutability;
23use rustc_data_structures::unord::UnordMap;
24use rustc_type_ir::{DebruijnIndex, INNERMOST, InferConst};
25
26pub(crate) fn type_alias(
27    genv: GlobalEnv,
28    alias: &fhir::TyAlias,
29    alias_ty: &rty::TyCtor,
30    def_id: MaybeExternId,
31) -> QueryResult<rty::TyCtor> {
32    let rust_ty = genv.lower_type_of(def_id.resolved_id())?.skip_binder();
33    let expected = rust_ty.refine(&Refiner::default_for_item(genv, def_id.resolved_id())?)?;
34    let mut zipper = Zipper::new(genv, def_id);
35
36    if zipper
37        .enter_a_binder(alias_ty, |zipper, ty| zipper.zip_ty(ty, &expected))
38        .is_err()
39    {
40        zipper
41            .errors
42            .emit(errors::IncompatibleRefinement::type_alias(genv, def_id, alias));
43    }
44
45    zipper.errors.to_result()?;
46
47    Ok(zipper.holes.replace_holes(alias_ty))
48}
49
50pub(crate) fn fn_sig(
51    genv: GlobalEnv,
52    decl: &fhir::FnDecl,
53    fn_sig: &rty::PolyFnSig,
54    def_id: MaybeExternId,
55) -> QueryResult<rty::PolyFnSig> {
56    let rust_fn_sig = genv.lower_fn_sig(def_id.resolved_id())?.skip_binder();
57
58    let expected = Refiner::default_for_item(genv, def_id.resolved_id())?.refine(&rust_fn_sig)?;
59
60    let mut zipper = Zipper::new(genv, def_id);
61    if let Err(err) = zipper.zip_poly_fn_sig(fn_sig, &expected) {
62        zipper.emit_fn_sig_err(err, decl);
63    }
64
65    zipper.errors.to_result()?;
66
67    Ok(zipper.holes.replace_holes(fn_sig))
68}
69
70pub(crate) fn variants(
71    genv: GlobalEnv,
72    variants: &[rty::PolyVariant],
73    adt_def_id: MaybeExternId,
74) -> QueryResult<Vec<rty::PolyVariant>> {
75    let refiner = Refiner::default_for_item(genv, adt_def_id.resolved_id())?;
76    let mut zipper = Zipper::new(genv, adt_def_id);
77    // TODO check same number of variants
78    for (i, variant) in variants.iter().enumerate() {
79        let variant_idx = VariantIdx::from_usize(i);
80        let expected = refiner.refine_variant_def(adt_def_id.resolved_id(), variant_idx)?;
81        zipper.zip_variant(variant, &expected, variant_idx);
82    }
83
84    zipper.errors.to_result()?;
85
86    Ok(variants
87        .iter()
88        .map(|v| zipper.holes.replace_holes(v))
89        .collect())
90}
91
92struct Zipper<'genv, 'tcx> {
93    genv: GlobalEnv<'genv, 'tcx>,
94    owner_id: MaybeExternId,
95    locs: UnordMap<rty::Loc, rty::Ty>,
96    holes: Holes,
97    /// Number of binders we've entered in `a`
98    a_binders: u32,
99    /// Each element in the vector correspond to a binder in `b`. For some binders we map it to
100    /// a corresponding binder in `a`. We assume that expressions filling holes will only contain
101    /// variables pointing to some of these mapped binders.
102    b_binder_to_a_binder: Vec<Option<u32>>,
103    errors: Errors<'genv>,
104}
105
106#[derive(Default)]
107struct Holes {
108    sorts: UnordMap<rty::SortVid, rty::Sort>,
109    subset_tys: UnordMap<rty::TyVid, rty::SubsetTy>,
110    types: UnordMap<rty::TyVid, rty::Ty>,
111    regions: UnordMap<rty::RegionVid, rty::Region>,
112    consts: UnordMap<rty::ConstVid, rty::Const>,
113}
114
115impl TypeFolder for &Holes {
116    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
117        if let rty::Sort::Infer(rty::SortInfer::SortVar(vid)) = sort {
118            self.sorts
119                .get(vid)
120                .cloned()
121                .unwrap_or_else(|| bug!("unfilled sort hole {vid:?}"))
122        } else {
123            sort.super_fold_with(self)
124        }
125    }
126
127    fn fold_ty(&mut self, ty: &rty::Ty) -> rty::Ty {
128        if let rty::TyKind::Infer(vid) = ty.kind() {
129            self.types
130                .get(vid)
131                .cloned()
132                .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
133        } else {
134            ty.super_fold_with(self)
135        }
136    }
137
138    fn fold_subset_ty(&mut self, constr: &rty::SubsetTy) -> rty::SubsetTy {
139        if let rty::BaseTy::Infer(vid) = &constr.bty {
140            self.subset_tys
141                .get(vid)
142                .cloned()
143                .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
144        } else {
145            constr.super_fold_with(self)
146        }
147    }
148
149    fn fold_region(&mut self, r: &rty::Region) -> rty::Region {
150        if let rty::Region::ReVar(vid) = r {
151            self.regions
152                .get(vid)
153                .copied()
154                .unwrap_or_else(|| bug!("unfilled region hole {vid:?}"))
155        } else {
156            *r
157        }
158    }
159
160    fn fold_const(&mut self, ct: &rty::Const) -> rty::Const {
161        if let rty::ConstKind::Infer(InferConst::Var(cid)) = ct.kind {
162            self.consts
163                .get(&cid)
164                .cloned()
165                .unwrap_or_else(|| bug!("unfilled const hole {cid:?}"))
166        } else {
167            ct.super_fold_with(self)
168        }
169    }
170}
171
172impl Holes {
173    fn replace_holes<T: TypeFoldable>(&self, t: &T) -> T {
174        let mut this = self;
175        t.fold_with(&mut this)
176    }
177}
178
179impl<'genv, 'tcx> Zipper<'genv, 'tcx> {
180    fn new(genv: GlobalEnv<'genv, 'tcx>, owner_id: MaybeExternId) -> Self {
181        Self {
182            genv,
183            owner_id,
184            locs: UnordMap::default(),
185            holes: Default::default(),
186            a_binders: 0,
187            b_binder_to_a_binder: vec![],
188            errors: Errors::new(genv.sess()),
189        }
190    }
191
192    fn zip_poly_fn_sig(&mut self, a: &rty::PolyFnSig, b: &rty::PolyFnSig) -> Result<(), FnSigErr> {
193        self.enter_binders(a, b, |this, a, b| this.zip_fn_sig(a, b))
194    }
195
196    fn zip_variant(&mut self, a: &rty::PolyVariant, b: &rty::PolyVariant, variant_idx: VariantIdx) {
197        self.enter_binders(a, b, |this, a, b| {
198            // The args are always `GenericArgs::identity_for_item` inside the `EarlyBinder`
199            debug_assert_eq!(a.args, b.args);
200
201            if a.fields.len() != b.fields.len() {
202                this.errors.emit(errors::FieldCountMismatch::new(
203                    this.genv,
204                    a.fields.len(),
205                    this.owner_id,
206                    variant_idx,
207                ));
208                return;
209            }
210            for (i, (ty_a, ty_b)) in iter::zip(&a.fields, &b.fields).enumerate() {
211                let field_idx = FieldIdx::from_usize(i);
212                if this.zip_ty(ty_a, ty_b).is_err() {
213                    this.errors.emit(errors::IncompatibleRefinement::field(
214                        this.genv,
215                        this.owner_id,
216                        variant_idx,
217                        field_idx,
218                    ));
219                }
220            }
221        });
222    }
223
224    fn zip_fn_sig(&mut self, a: &rty::FnSig, b: &rty::FnSig) -> Result<(), FnSigErr> {
225        if a.inputs().len() != b.inputs().len() {
226            Err(FnSigErr::ArgCountMismatch)?;
227        }
228        for (i, (ty_a, ty_b)) in iter::zip(a.inputs(), b.inputs()).enumerate() {
229            self.zip_ty(ty_a, ty_b).map_err(|_| FnSigErr::FnInput(i))?;
230        }
231        self.enter_binders(&a.output, &b.output, |this, output_a, output_b| {
232            this.zip_output(output_a, output_b)
233        })
234    }
235
236    fn zip_output(&mut self, a: &rty::FnOutput, b: &rty::FnOutput) -> Result<(), FnSigErr> {
237        self.zip_ty(&a.ret, &b.ret).map_err(FnSigErr::FnOutput)?;
238
239        for (i, ensures) in a.ensures.iter().enumerate() {
240            if let rty::Ensures::Type(path, ty_a) = ensures {
241                let loc = path.to_loc().unwrap();
242                let ty_b = self.locs.get(&loc).unwrap().shift_in_escaping(1);
243                self.zip_ty(ty_a, &ty_b)
244                    .map_err(|_| FnSigErr::Ensures { i, expected: ty_b })?;
245            }
246        }
247        Ok(())
248    }
249
250    fn zip_ty(&mut self, a: &rty::Ty, b: &rty::Ty) -> Result<(), Mismatch> {
251        match (a.kind(), b.kind()) {
252            (rty::TyKind::Infer(vid), _) => {
253                assert_ne!(vid.as_u32(), 0);
254                let b = self.adjust_bvars(b);
255                self.holes.types.insert(*vid, b);
256                Ok(())
257            }
258            (rty::TyKind::Exists(ctor_a), _) => {
259                self.enter_a_binder(ctor_a, |this, ty_a| this.zip_ty(ty_a, b))
260            }
261            (_, rty::TyKind::Exists(ctor_b)) => {
262                self.enter_b_binder(ctor_b, |this, ty_b| this.zip_ty(a, ty_b))
263            }
264            (rty::TyKind::Constr(_, ty_a), _) => self.zip_ty(ty_a, b),
265            (_, rty::TyKind::Constr(_, ty_b)) => self.zip_ty(a, ty_b),
266            (rty::TyKind::Indexed(bty_a, _), rty::TyKind::Indexed(bty_b, _)) => {
267                self.zip_bty(bty_a, bty_b)
268            }
269            (rty::TyKind::StrgRef(re_a, path, ty_a), rty::Ref!(re_b, ty_b, Mutability::Mut)) => {
270                let loc = path.to_loc().unwrap();
271                self.locs.insert(loc, ty_b.clone());
272
273                self.zip_region(re_a, re_b);
274                self.zip_ty(ty_a, ty_b)
275            }
276            (rty::TyKind::Param(pty_a), rty::TyKind::Param(pty_b)) => {
277                assert_eq_or_incompatible(pty_a, pty_b)
278            }
279            (
280                rty::TyKind::Ptr(_, _)
281                | rty::TyKind::Discr(..)
282                | rty::TyKind::Downcast(_, _, _, _, _)
283                | rty::TyKind::Blocked(_)
284                | rty::TyKind::Uninit,
285                _,
286            ) => {
287                bug!("unexpected type {a:?}");
288            }
289            _ => Err(Mismatch::new(a, b)),
290        }
291    }
292
293    fn zip_bty(&mut self, a: &rty::BaseTy, b: &rty::BaseTy) -> Result<(), Mismatch> {
294        match (a, b) {
295            (rty::BaseTy::Int(ity_a), rty::BaseTy::Int(ity_b)) => {
296                assert_eq_or_incompatible(ity_a, ity_b)
297            }
298            (rty::BaseTy::Uint(uity_a), rty::BaseTy::Uint(uity_b)) => {
299                assert_eq_or_incompatible(uity_a, uity_b)
300            }
301            (rty::BaseTy::Bool, rty::BaseTy::Bool) => Ok(()),
302            (rty::BaseTy::Str, rty::BaseTy::Str) => Ok(()),
303            (rty::BaseTy::Char, rty::BaseTy::Char) => Ok(()),
304            (rty::BaseTy::Float(fty_a), rty::BaseTy::Float(fty_b)) => {
305                assert_eq_or_incompatible(fty_a, fty_b)
306            }
307            (rty::BaseTy::Slice(ty_a), rty::BaseTy::Slice(ty_b)) => self.zip_ty(ty_a, ty_b),
308            (rty::BaseTy::Adt(adt_def_a, args_a), rty::BaseTy::Adt(adt_def_b, args_b)) => {
309                assert_eq_or_incompatible(adt_def_a.did(), adt_def_b.did())?;
310                assert_eq_or_incompatible(args_a.len(), args_b.len())?;
311                for (arg_a, arg_b) in iter::zip(args_a, args_b) {
312                    self.zip_generic_arg(arg_a, arg_b)?;
313                }
314                Ok(())
315            }
316            (rty::BaseTy::RawPtr(ty_a, mutbl_a), rty::BaseTy::RawPtr(ty_b, mutbl_b)) => {
317                assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
318                self.zip_ty(ty_a, ty_b)
319            }
320            (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), rty::BaseTy::Ref(re_b, ty_b, mutbl_b)) => {
321                assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
322                self.zip_region(re_a, re_b);
323                self.zip_ty(ty_a, ty_b)
324            }
325            (rty::BaseTy::FnPtr(poly_sig_a), rty::BaseTy::FnPtr(poly_sig_b)) => {
326                self.zip_poly_fn_sig(poly_sig_a, poly_sig_b)
327                    .map_err(|_| Mismatch::new(poly_sig_a, poly_sig_b))
328            }
329            (rty::BaseTy::Tuple(tys_a), rty::BaseTy::Tuple(tys_b)) => {
330                assert_eq_or_incompatible(tys_a.len(), tys_b.len())?;
331                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
332                    self.zip_ty(ty_a, ty_b)?;
333                }
334                Ok(())
335            }
336            (rty::BaseTy::Alias(kind_a, aty_a), rty::BaseTy::Alias(kind_b, aty_b)) => {
337                assert_eq_or_incompatible(kind_a, kind_b)?;
338                assert_eq_or_incompatible(aty_a.def_id, aty_b.def_id)?;
339                assert_eq_or_incompatible(aty_a.args.len(), aty_b.args.len())?;
340                for (arg_a, arg_b) in iter::zip(&aty_a.args, &aty_b.args) {
341                    self.zip_generic_arg(arg_a, arg_b)?;
342                }
343                Ok(())
344            }
345            (rty::BaseTy::Array(ty_a, len_a), rty::BaseTy::Array(ty_b, len_b)) => {
346                self.zip_const(len_a, len_b)?;
347                self.zip_ty(ty_a, ty_b)
348            }
349            (rty::BaseTy::Never, rty::BaseTy::Never) => Ok(()),
350            (rty::BaseTy::Param(pty_a), rty::BaseTy::Param(pty_b)) => {
351                assert_eq_or_incompatible(pty_a, pty_b)
352            }
353            (rty::BaseTy::Dynamic(preds_a, re_a), rty::BaseTy::Dynamic(preds_b, re_b)) => {
354                assert_eq_or_incompatible(preds_a.len(), preds_b.len())?;
355                for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
356                    self.zip_poly_existential_pred(pred_a, pred_b)?;
357                }
358                self.zip_region(re_a, re_b);
359                Ok(())
360            }
361            (rty::BaseTy::Foreign(def_id_a), rty::BaseTy::Foreign(def_id_b)) => {
362                assert_eq_or_incompatible(def_id_a, def_id_b)
363            }
364            (rty::BaseTy::Closure(..) | rty::BaseTy::Coroutine(..), _) => {
365                bug!("unexpected type `{a:?}`");
366            }
367            _ => Err(Mismatch::new(a, b)),
368        }
369    }
370
371    fn zip_generic_arg(
372        &mut self,
373        a: &rty::GenericArg,
374        b: &rty::GenericArg,
375    ) -> Result<(), Mismatch> {
376        match (a, b) {
377            (rty::GenericArg::Ty(ty_a), rty::GenericArg::Ty(ty_b)) => self.zip_ty(ty_a, ty_b),
378            (rty::GenericArg::Base(ctor_a), rty::GenericArg::Base(ctor_b)) => {
379                self.zip_sorts(&ctor_a.sort(), &ctor_b.sort());
380                self.enter_binders(ctor_a, ctor_b, |this, sty_a, sty_b| {
381                    this.zip_subset_ty(sty_a, sty_b)
382                })
383            }
384            (rty::GenericArg::Lifetime(re_a), rty::GenericArg::Lifetime(re_b)) => {
385                self.zip_region(re_a, re_b);
386                Ok(())
387            }
388            (rty::GenericArg::Const(ct_a), rty::GenericArg::Const(ct_b)) => {
389                self.zip_const(ct_a, ct_b)
390            }
391            _ => Err(Mismatch::new(a, b)),
392        }
393    }
394
395    fn zip_sorts(&mut self, a: &rty::Sort, b: &rty::Sort) {
396        if let rty::Sort::Infer(rty::SortInfer::SortVar(vid)) = a {
397            assert_ne!(vid.as_u32(), 0);
398            self.holes.sorts.insert(*vid, b.clone());
399        }
400    }
401
402    fn zip_subset_ty(&mut self, a: &rty::SubsetTy, b: &rty::SubsetTy) -> Result<(), Mismatch> {
403        if let rty::BaseTy::Infer(vid) = a.bty {
404            assert_ne!(vid.as_u32(), 0);
405            let b = self.adjust_bvars(b);
406            self.holes.subset_tys.insert(vid, b);
407            Ok(())
408        } else {
409            self.zip_bty(&a.bty, &b.bty)
410        }
411    }
412
413    fn zip_const(&mut self, a: &rty::Const, b: &ty::Const) -> Result<(), Mismatch> {
414        match (&a.kind, &b.kind) {
415            (rty::ConstKind::Infer(ty::InferConst::Var(cid)), _) => {
416                self.holes.consts.insert(*cid, b.clone());
417                Ok(())
418            }
419            (rty::ConstKind::Param(param_const_a), ty::ConstKind::Param(param_const_b)) => {
420                assert_eq_or_incompatible(param_const_a, param_const_b)
421            }
422            (rty::ConstKind::Value(ty_a, val_a), ty::ConstKind::Value(ty_b, val_b)) => {
423                assert_eq_or_incompatible(ty_a, ty_b)?;
424                assert_eq_or_incompatible(val_a, val_b)
425            }
426            (rty::ConstKind::Unevaluated(c1), ty::ConstKind::Unevaluated(c2)) => {
427                assert_eq_or_incompatible(c1, c2)
428            }
429            _ => Err(Mismatch::new(a, b)),
430        }
431    }
432
433    fn zip_region(&mut self, a: &rty::Region, b: &ty::Region) {
434        if let rty::Region::ReVar(vid) = a {
435            let re = self.adjust_bvars(b);
436            self.holes.regions.insert(*vid, re);
437        }
438    }
439
440    fn zip_poly_existential_pred(
441        &mut self,
442        a: &rty::Binder<rty::ExistentialPredicate>,
443        b: &rty::Binder<rty::ExistentialPredicate>,
444    ) -> Result<(), Mismatch> {
445        self.enter_binders(a, b, |this, a, b| {
446            match (a, b) {
447                (
448                    rty::ExistentialPredicate::Trait(trait_ref_a),
449                    rty::ExistentialPredicate::Trait(trait_ref_b),
450                ) => {
451                    assert_eq_or_incompatible(trait_ref_a.def_id, trait_ref_b.def_id)?;
452                    assert_eq_or_incompatible(trait_ref_a.args.len(), trait_ref_b.args.len())?;
453                    for (arg_a, arg_b) in iter::zip(&trait_ref_a.args, &trait_ref_b.args) {
454                        this.zip_generic_arg(arg_a, arg_b)?;
455                    }
456                    Ok(())
457                }
458                (
459                    rty::ExistentialPredicate::Projection(projection_a),
460                    rty::ExistentialPredicate::Projection(projection_b),
461                ) => {
462                    assert_eq_or_incompatible(projection_a.def_id, projection_b.def_id)?;
463                    assert_eq_or_incompatible(projection_a.args.len(), projection_b.args.len())?;
464                    for (arg_a, arg_b) in iter::zip(&projection_a.args, &projection_b.args) {
465                        this.zip_generic_arg(arg_a, arg_b)?;
466                    }
467                    this.enter_binders(&projection_a.term, &projection_b.term, |this, a, b| {
468                        this.zip_bty(&a.bty, &b.bty)
469                    })
470                }
471                (
472                    rty::ExistentialPredicate::AutoTrait(def_id_a),
473                    rty::ExistentialPredicate::AutoTrait(def_id_b),
474                ) => assert_eq_or_incompatible(def_id_a, def_id_b),
475                _ => Err(Mismatch::new(a, b)),
476            }
477        })
478    }
479
480    /// Enter a binder in both `a` and `b` creating a mapping between the two.
481    fn enter_binders<T, R>(
482        &mut self,
483        a: &rty::Binder<T>,
484        b: &rty::Binder<T>,
485        f: impl FnOnce(&mut Self, &T, &T) -> R,
486    ) -> R {
487        self.b_binder_to_a_binder.push(Some(self.a_binders));
488        self.a_binders += 1;
489        let r = f(self, a.skip_binder_ref(), b.skip_binder_ref());
490        self.a_binders -= 1;
491        self.b_binder_to_a_binder.pop();
492        r
493    }
494
495    /// Enter a binder in `a` without a corresponding mapping in `b`
496    fn enter_a_binder<T, R>(
497        &mut self,
498        t: &rty::Binder<T>,
499        f: impl FnOnce(&mut Self, &T) -> R,
500    ) -> R {
501        self.a_binders += 1;
502        let r = f(self, t.skip_binder_ref());
503        self.a_binders -= 1;
504        r
505    }
506
507    /// Enter a binder in `b` without a corresponding mapping in `a`
508    fn enter_b_binder<T, R>(
509        &mut self,
510        t: &rty::Binder<T>,
511        f: impl FnOnce(&mut Self, &T) -> R,
512    ) -> R {
513        self.b_binder_to_a_binder.push(None);
514        let r = f(self, t.skip_binder_ref());
515        self.b_binder_to_a_binder.pop();
516        r
517    }
518
519    fn adjust_bvars<T: TypeFoldable + Clone + std::fmt::Debug>(&self, t: &T) -> T {
520        struct Adjuster<'a, 'genv, 'tcx> {
521            current_index: DebruijnIndex,
522            zipper: &'a Zipper<'genv, 'tcx>,
523        }
524
525        impl Adjuster<'_, '_, '_> {
526            fn adjust(&self, debruijn: DebruijnIndex) -> DebruijnIndex {
527                let b_binders = self.zipper.b_binder_to_a_binder.len();
528                let mapped_binder = self.zipper.b_binder_to_a_binder
529                    [b_binders - debruijn.as_usize() - 1]
530                    .unwrap_or_else(|| {
531                        bug!("bound var without corresponding binder: `{debruijn:?}`")
532                    });
533                DebruijnIndex::from_u32(self.zipper.a_binders - mapped_binder - 1)
534                    .shifted_in(self.current_index.as_u32())
535            }
536        }
537
538        impl TypeFolder for Adjuster<'_, '_, '_> {
539            fn fold_binder<T>(&mut self, t: &rty::Binder<T>) -> rty::Binder<T>
540            where
541                T: TypeFoldable,
542            {
543                self.current_index.shift_in(1);
544                let r = t.super_fold_with(self);
545                self.current_index.shift_out(1);
546                r
547            }
548
549            fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
550                if let rty::ReBound(debruijn, br) = *re
551                    && debruijn >= self.current_index
552                {
553                    rty::ReBound(self.adjust(debruijn), br)
554                } else {
555                    *re
556                }
557            }
558
559            fn fold_expr(&mut self, expr: &rty::Expr) -> rty::Expr {
560                if let rty::ExprKind::Var(rty::Var::Bound(debruijn, breft)) = expr.kind()
561                    && *debruijn >= self.current_index
562                {
563                    rty::Expr::bvar(self.adjust(*debruijn), breft.var, breft.kind)
564                } else {
565                    expr.super_fold_with(self)
566                }
567            }
568        }
569        t.fold_with(&mut Adjuster { current_index: INNERMOST, zipper: self })
570    }
571
572    fn emit_fn_sig_err(&mut self, err: FnSigErr, decl: &fhir::FnDecl) {
573        match err {
574            FnSigErr::ArgCountMismatch => {
575                self.errors.emit(errors::IncompatibleParamCount::new(
576                    self.genv,
577                    decl,
578                    self.owner_id,
579                ));
580            }
581            FnSigErr::FnInput(i) => {
582                self.errors.emit(errors::IncompatibleRefinement::fn_input(
583                    self.genv,
584                    self.owner_id,
585                    decl,
586                    i,
587                ));
588            }
589            FnSigErr::FnOutput(_) => {
590                self.errors.emit(errors::IncompatibleRefinement::fn_output(
591                    self.genv,
592                    self.owner_id,
593                    decl,
594                ));
595            }
596            FnSigErr::Ensures { i, expected } => {
597                self.errors.emit(errors::IncompatibleRefinement::ensures(
598                    self.genv,
599                    self.owner_id,
600                    decl,
601                    &expected,
602                    i,
603                ));
604            }
605        }
606    }
607}
608
609fn assert_eq_or_incompatible<T: Eq + fmt::Debug>(a: T, b: T) -> Result<(), Mismatch> {
610    if a != b {
611        return Err(Mismatch::new(a, b));
612    }
613    Ok(())
614}
615
/// Opaque structural-mismatch error. The payload is a human-readable `"{a:?} != {b:?}"`
/// rendering of the two values that failed to match. It is only inspected when debugging.
#[expect(dead_code, reason = "we use the String for debugging")]
struct Mismatch(String);

impl Mismatch {
    /// Records that `a` and `b` were expected to match but did not.
    fn new<T: fmt::Debug>(a: T, b: T) -> Self {
        Self(format!("{a:?} != {b:?}"))
    }
}
624
625enum FnSigErr {
626    ArgCountMismatch,
627    FnInput(usize),
628    #[expect(dead_code, reason = "we use the struct for debugging")]
629    FnOutput(Mismatch),
630    Ensures {
631        i: usize,
632        expected: rty::Ty,
633    },
634}
635
636mod errors {
637    use flux_common::span_bug;
638    use flux_errors::E0999;
639    use flux_macros::Diagnostic;
640    use flux_middle::{def_id::MaybeExternId, fhir, global_env::GlobalEnv, rty};
641    use flux_rustc_bridge::{
642        ToRustc,
643        ty::{FieldIdx, VariantIdx},
644    };
645    use rustc_span::{DUMMY_SP, Span};
646
647    #[derive(Diagnostic)]
648    #[diag(fhir_analysis_incompatible_refinement, code = E0999)]
649    #[note]
650    pub(super) struct IncompatibleRefinement<'tcx> {
651        #[primary_span]
652        #[label]
653        span: Span,
654        #[label(fhir_analysis_expected_label)]
655        expected_span: Option<Span>,
656        expected_ty: rustc_middle::ty::Ty<'tcx>,
657        def_descr: &'static str,
658    }
659
660    impl<'tcx> IncompatibleRefinement<'tcx> {
661        pub(super) fn type_alias(
662            genv: GlobalEnv<'_, 'tcx>,
663            def_id: MaybeExternId,
664            type_alias: &fhir::TyAlias,
665        ) -> Self {
666            let tcx = genv.tcx();
667            Self {
668                span: type_alias.ty.span,
669                def_descr: tcx.def_descr(def_id.resolved_id()),
670                expected_span: Some(tcx.def_span(def_id)),
671                expected_ty: tcx.type_of(def_id).skip_binder(),
672            }
673        }
674
675        pub(super) fn fn_input(
676            genv: GlobalEnv<'_, 'tcx>,
677            fn_id: MaybeExternId,
678            decl: &fhir::FnDecl,
679            pos: usize,
680        ) -> Self {
681            let expected_span = match fn_id {
682                MaybeExternId::Local(local_id) => {
683                    genv.tcx()
684                        .hir_node_by_def_id(local_id)
685                        .fn_decl()
686                        .and_then(|fn_decl| fn_decl.inputs.get(pos))
687                        .map(|input| input.span)
688                }
689                MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
690            };
691
692            let expected_ty = genv
693                .tcx()
694                .fn_sig(fn_id.resolved_id())
695                .skip_binder()
696                .inputs()
697                .map_bound(|inputs| inputs[pos])
698                .skip_binder();
699
700            Self {
701                span: decl.inputs[pos].span,
702                def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
703                expected_span,
704                expected_ty,
705            }
706        }
707
708        pub(super) fn fn_output(
709            genv: GlobalEnv<'_, 'tcx>,
710            fn_id: MaybeExternId,
711            decl: &fhir::FnDecl,
712        ) -> Self {
713            let expected_span = match fn_id {
714                MaybeExternId::Local(local_id) => {
715                    genv.tcx()
716                        .hir_node_by_def_id(local_id)
717                        .fn_decl()
718                        .map(|fn_decl| fn_decl.output.span())
719                }
720                MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
721            };
722
723            let expected_ty = genv
724                .tcx()
725                .fn_sig(fn_id.resolved_id())
726                .skip_binder()
727                .output()
728                .skip_binder();
729            let spec_span = decl.output.ret.span;
730            Self {
731                span: spec_span,
732                def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
733                expected_span,
734                expected_ty,
735            }
736        }
737
738        pub(super) fn ensures(
739            genv: GlobalEnv<'_, 'tcx>,
740            fn_id: MaybeExternId,
741            decl: &fhir::FnDecl,
742            expected: &rty::Ty,
743            i: usize,
744        ) -> Self {
745            let fhir::Ensures::Type(_, ty) = &decl.output.ensures[i] else {
746                span_bug!(decl.span, "expected `fhir::Ensures::Type`");
747            };
748            let tcx = genv.tcx();
749            Self {
750                span: ty.span,
751                def_descr: tcx.def_descr(fn_id.resolved_id()),
752                expected_span: None,
753                expected_ty: expected.to_rustc(tcx),
754            }
755        }
756
757        pub(super) fn field(
758            genv: GlobalEnv<'_, 'tcx>,
759            adt_id: MaybeExternId,
760            variant_idx: VariantIdx,
761            field_idx: FieldIdx,
762        ) -> Self {
763            let tcx = genv.tcx();
764            let adt_def = tcx.adt_def(adt_id);
765            let field_def = &adt_def.variant(variant_idx).fields[field_idx];
766
767            let item = genv.fhir_expect_item(adt_id.local_id()).unwrap();
768            let span = match &item.kind {
769                fhir::ItemKind::Enum(enum_def) => {
770                    enum_def.variants[variant_idx.as_usize()].fields[field_idx.as_usize()]
771                        .ty
772                        .span
773                }
774                fhir::ItemKind::Struct(struct_def)
775                    if let fhir::StructKind::Transparent { fields } = &struct_def.kind =>
776                {
777                    fields[field_idx.as_usize()].ty.span
778                }
779                _ => DUMMY_SP,
780            };
781
782            Self {
783                span,
784                def_descr: tcx.def_descr(field_def.did),
785                expected_span: Some(tcx.def_span(field_def.did)),
786                expected_ty: tcx.type_of(field_def.did).skip_binder(),
787            }
788        }
789    }
790
791    #[derive(Diagnostic)]
792    #[diag(fhir_analysis_incompatible_param_count, code = E0999)]
793    pub(super) struct IncompatibleParamCount {
794        #[primary_span]
795        #[label]
796        span: Span,
797        found: usize,
798        #[label(fhir_analysis_expected_label)]
799        expected_span: Span,
800        expected: usize,
801        def_descr: &'static str,
802    }
803
804    impl IncompatibleParamCount {
805        pub(super) fn new(genv: GlobalEnv, decl: &fhir::FnDecl, def_id: MaybeExternId) -> Self {
806            let def_descr = genv.tcx().def_descr(def_id.resolved_id());
807
808            let span = if !decl.inputs.is_empty() {
809                decl.inputs[decl.inputs.len() - 1]
810                    .span
811                    .with_lo(decl.inputs[0].span.lo())
812            } else {
813                decl.span
814            };
815
816            let expected_span = if let Some(local_id) = def_id.as_local()
817                && let expected_decl = genv.tcx().hir_node_by_def_id(local_id).fn_decl().unwrap()
818                && !expected_decl.inputs.is_empty()
819            {
820                expected_decl.inputs[expected_decl.inputs.len() - 1]
821                    .span
822                    .with_lo(expected_decl.inputs[0].span.lo())
823            } else {
824                genv.tcx().def_span(def_id)
825            };
826
827            let expected = genv
828                .tcx()
829                .fn_sig(def_id)
830                .skip_binder()
831                .skip_binder()
832                .inputs()
833                .len();
834
835            Self { span, found: decl.inputs.len(), expected_span, expected, def_descr }
836        }
837    }
838
839    #[derive(Diagnostic)]
840    #[diag(fhir_analysis_field_count_mismatch, code = E0999)]
841    pub(super) struct FieldCountMismatch {
842        #[primary_span]
843        #[label]
844        span: Span,
845        fields: usize,
846        #[label(fhir_analysis_expected_label)]
847        expected_span: Span,
848        expected_fields: usize,
849    }
850
851    impl FieldCountMismatch {
852        pub(super) fn new(
853            genv: GlobalEnv,
854            found: usize,
855            adt_def_id: MaybeExternId,
856            variant_idx: VariantIdx,
857        ) -> Self {
858            let adt_def = genv.tcx().adt_def(adt_def_id);
859            let expected_variant = adt_def.variant(variant_idx);
860
861            // Get the span of the variant if this is an enum. Structs cannot have produce a field
862            // count mismatch.
863            let span = if let Ok(fhir::Node::Item(item)) = genv.fhir_node(adt_def_id.local_id())
864                && let fhir::ItemKind::Enum(enum_def) = &item.kind
865                && let Some(variant) = enum_def.variants.get(variant_idx.as_usize())
866            {
867                variant.span
868            } else {
869                DUMMY_SP
870            };
871
872            Self {
873                span,
874                fields: found,
875                expected_span: genv.tcx().def_span(expected_variant.def_id),
876                expected_fields: expected_variant.fields.len(),
877            }
878        }
879    }
880}