flux_fhir_analysis/conv/
struct_compat.rs

//! Check whether two refinement types/signatures are structurally compatible.
//!
//! Used to check if a user spec is compatible with the underlying rust type. The code also
//! infers types annotated with `_` in the surface syntax.
5
6use std::{fmt, iter};
7
8use flux_common::bug;
9use flux_errors::Errors;
10use flux_middle::{
11    def_id::MaybeExternId,
12    fhir,
13    global_env::GlobalEnv,
14    queries::QueryResult,
15    rty::{
16        self,
17        fold::{TypeFoldable, TypeFolder, TypeSuperFoldable},
18        refining::{Refine as _, Refiner},
19    },
20};
21use flux_rustc_bridge::ty::{self, FieldIdx, VariantIdx};
22use rustc_ast::Mutability;
23use rustc_data_structures::unord::UnordMap;
24use rustc_type_ir::{DebruijnIndex, INNERMOST, InferConst};
25
26pub(crate) fn type_alias(
27    genv: GlobalEnv,
28    alias: &fhir::TyAlias,
29    alias_ty: &rty::TyCtor,
30    def_id: MaybeExternId,
31) -> QueryResult<rty::TyCtor> {
32    let rust_ty = genv.lower_type_of(def_id.resolved_id())?.skip_binder();
33    let expected = rust_ty.refine(&Refiner::default_for_item(genv, def_id.resolved_id())?)?;
34    let mut zipper = Zipper::new(genv, def_id);
35
36    if zipper
37        .enter_a_binder(alias_ty, |zipper, ty| zipper.zip_ty(ty, &expected))
38        .is_err()
39    {
40        zipper
41            .errors
42            .emit(errors::IncompatibleRefinement::type_alias(genv, def_id, alias));
43    }
44
45    zipper.errors.to_result()?;
46
47    Ok(zipper.holes.replace_holes(alias_ty))
48}
49
50pub(crate) fn fn_sig(
51    genv: GlobalEnv,
52    decl: &fhir::FnDecl,
53    fn_sig: &rty::PolyFnSig,
54    def_id: MaybeExternId,
55) -> QueryResult<rty::PolyFnSig> {
56    let rust_fn_sig = genv.lower_fn_sig(def_id.resolved_id())?.skip_binder();
57
58    let expected = Refiner::default_for_item(genv, def_id.resolved_id())?.refine(&rust_fn_sig)?;
59
60    let mut zipper = Zipper::new(genv, def_id);
61    if let Err(err) = zipper.zip_poly_fn_sig(fn_sig, &expected) {
62        zipper.emit_fn_sig_err(err, decl);
63    }
64
65    zipper.errors.to_result()?;
66
67    Ok(zipper.holes.replace_holes(fn_sig))
68}
69
70pub(crate) fn variants(
71    genv: GlobalEnv,
72    variants: &[rty::PolyVariant],
73    adt_def_id: MaybeExternId,
74) -> QueryResult<Vec<rty::PolyVariant>> {
75    let refiner = Refiner::default_for_item(genv, adt_def_id.resolved_id())?;
76    let mut zipper = Zipper::new(genv, adt_def_id);
77    // TODO check same number of variants
78    for (i, variant) in variants.iter().enumerate() {
79        let variant_idx = VariantIdx::from_usize(i);
80        let expected = refiner.refine_variant_def(adt_def_id.resolved_id(), variant_idx)?;
81        zipper.zip_variant(variant, &expected, variant_idx);
82    }
83
84    zipper.errors.to_result()?;
85
86    Ok(variants
87        .iter()
88        .map(|v| zipper.holes.replace_holes(v))
89        .collect())
90}
91
92struct Zipper<'genv, 'tcx> {
93    genv: GlobalEnv<'genv, 'tcx>,
94    owner_id: MaybeExternId,
95    locs: UnordMap<rty::Loc, rty::Ty>,
96    holes: Holes,
97    /// Number of binders we've entered in `a`
98    a_binders: u32,
99    /// Each element in the vector correspond to a binder in `b`. For some binders we map it to
100    /// a corresponding binder in `a`. We assume that expressions filling holes will only contain
101    /// variables pointing to some of these mapped binders.
102    b_binder_to_a_binder: Vec<Option<u32>>,
103    errors: Errors<'genv>,
104}
105
106#[derive(Default)]
107struct Holes {
108    sorts: UnordMap<rty::SortVid, rty::Sort>,
109    subset_tys: UnordMap<rty::TyVid, rty::SubsetTy>,
110    types: UnordMap<rty::TyVid, rty::Ty>,
111    regions: UnordMap<rty::RegionVid, rty::Region>,
112    consts: UnordMap<rty::ConstVid, rty::Const>,
113}
114
115impl TypeFolder for &Holes {
116    fn fold_sort(&mut self, sort: &rty::Sort) -> rty::Sort {
117        if let rty::Sort::Infer(vid) = sort {
118            self.sorts
119                .get(vid)
120                .cloned()
121                .unwrap_or_else(|| bug!("unfilled sort hole {vid:?}"))
122        } else {
123            sort.super_fold_with(self)
124        }
125    }
126
127    fn fold_ty(&mut self, ty: &rty::Ty) -> rty::Ty {
128        if let rty::TyKind::Infer(vid) = ty.kind() {
129            self.types
130                .get(vid)
131                .cloned()
132                .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
133        } else {
134            ty.super_fold_with(self)
135        }
136    }
137
138    fn fold_subset_ty(&mut self, constr: &rty::SubsetTy) -> rty::SubsetTy {
139        if let rty::BaseTy::Infer(vid) = &constr.bty {
140            self.subset_tys
141                .get(vid)
142                .cloned()
143                .unwrap_or_else(|| bug!("unfilled type hole {vid:?}"))
144        } else {
145            constr.super_fold_with(self)
146        }
147    }
148
149    fn fold_region(&mut self, r: &rty::Region) -> rty::Region {
150        if let rty::Region::ReVar(vid) = r {
151            self.regions
152                .get(vid)
153                .copied()
154                .unwrap_or_else(|| bug!("unfilled region hole {vid:?}"))
155        } else {
156            *r
157        }
158    }
159
160    fn fold_const(&mut self, ct: &rty::Const) -> rty::Const {
161        if let rty::ConstKind::Infer(InferConst::Var(cid)) = ct.kind {
162            self.consts
163                .get(&cid)
164                .cloned()
165                .unwrap_or_else(|| bug!("unfilled const hole {cid:?}"))
166        } else {
167            ct.super_fold_with(self)
168        }
169    }
170}
171
172impl Holes {
173    fn replace_holes<T: TypeFoldable>(&self, t: &T) -> T {
174        let mut this = self;
175        t.fold_with(&mut this)
176    }
177}
178
179impl<'genv, 'tcx> Zipper<'genv, 'tcx> {
180    fn new(genv: GlobalEnv<'genv, 'tcx>, owner_id: MaybeExternId) -> Self {
181        Self {
182            genv,
183            owner_id,
184            locs: UnordMap::default(),
185            holes: Default::default(),
186            a_binders: 0,
187            b_binder_to_a_binder: vec![],
188            errors: Errors::new(genv.sess()),
189        }
190    }
191
192    fn is_async_fn(&self) -> bool {
193        self.genv
194            .tcx()
195            .asyncness(self.owner_id.resolved_id())
196            .is_async()
197    }
198
199    fn zip_poly_fn_sig(&mut self, a: &rty::PolyFnSig, b: &rty::PolyFnSig) -> Result<(), FnSigErr> {
200        self.enter_binders(a, b, |this, a, b| this.zip_fn_sig(a, b))
201    }
202
203    fn zip_variant(&mut self, a: &rty::PolyVariant, b: &rty::PolyVariant, variant_idx: VariantIdx) {
204        self.enter_binders(a, b, |this, a, b| {
205            // The args are always `GenericArgs::identity_for_item` inside the `EarlyBinder`
206            debug_assert_eq!(a.args, b.args);
207
208            if a.fields.len() != b.fields.len() {
209                this.errors.emit(errors::FieldCountMismatch::new(
210                    this.genv,
211                    a.fields.len(),
212                    this.owner_id,
213                    variant_idx,
214                ));
215                return;
216            }
217            for (i, (ty_a, ty_b)) in iter::zip(&a.fields, &b.fields).enumerate() {
218                let field_idx = FieldIdx::from_usize(i);
219                if this.zip_ty(ty_a, ty_b).is_err() {
220                    this.errors.emit(errors::IncompatibleRefinement::field(
221                        this.genv,
222                        this.owner_id,
223                        variant_idx,
224                        field_idx,
225                    ));
226                }
227            }
228        });
229    }
230
231    fn zip_fn_sig(&mut self, a: &rty::FnSig, b: &rty::FnSig) -> Result<(), FnSigErr> {
232        if a.inputs().len() != b.inputs().len() {
233            Err(FnSigErr::ArgCountMismatch)?;
234        }
235        for (i, (ty_a, ty_b)) in iter::zip(a.inputs(), b.inputs()).enumerate() {
236            self.zip_ty(ty_a, ty_b).map_err(|_| FnSigErr::FnInput(i))?;
237        }
238        self.enter_binders(&a.output, &b.output, |this, output_a, output_b| {
239            this.zip_output(output_a, output_b)
240        })
241    }
242
243    fn zip_output(&mut self, a: &rty::FnOutput, b: &rty::FnOutput) -> Result<(), FnSigErr> {
244        self.zip_ty(&a.ret, &b.ret).map_err(FnSigErr::FnOutput)?;
245
246        for (i, ensures) in a.ensures.iter().enumerate() {
247            if let rty::Ensures::Type(path, ty_a) = ensures {
248                let loc = path.to_loc().unwrap();
249                let ty_b = self.locs.get(&loc).unwrap().shift_in_escaping(1);
250                self.zip_ty(ty_a, &ty_b)
251                    .map_err(|_| FnSigErr::Ensures { i, expected: ty_b })?;
252            }
253        }
254        Ok(())
255    }
256
257    fn zip_ty(&mut self, a: &rty::Ty, b: &rty::Ty) -> Result<(), Mismatch> {
258        match (a.kind(), b.kind()) {
259            (rty::TyKind::Infer(vid), _) => {
260                assert_ne!(vid.as_u32(), 0);
261                let b = self.adjust_bvars(b);
262                self.holes.types.insert(*vid, b);
263                Ok(())
264            }
265            (rty::TyKind::Exists(ctor_a), _) => {
266                self.enter_a_binder(ctor_a, |this, ty_a| this.zip_ty(ty_a, b))
267            }
268            (_, rty::TyKind::Exists(ctor_b)) => {
269                self.enter_b_binder(ctor_b, |this, ty_b| this.zip_ty(a, ty_b))
270            }
271            (rty::TyKind::Constr(_, ty_a), _) => self.zip_ty(ty_a, b),
272            (_, rty::TyKind::Constr(_, ty_b)) => self.zip_ty(a, ty_b),
273            (rty::TyKind::Indexed(bty_a, _), rty::TyKind::Indexed(bty_b, _)) => {
274                self.zip_bty(bty_a, bty_b)
275            }
276            (rty::TyKind::StrgRef(re_a, path, ty_a), rty::Ref!(re_b, ty_b, Mutability::Mut)) => {
277                let loc = path.to_loc().unwrap();
278                self.locs.insert(loc, ty_b.clone());
279
280                self.zip_region(re_a, re_b);
281                self.zip_ty(ty_a, ty_b)
282            }
283            (rty::TyKind::Param(pty_a), rty::TyKind::Param(pty_b)) => {
284                assert_eq_or_incompatible(pty_a, pty_b)
285            }
286            (
287                rty::TyKind::Ptr(_, _)
288                | rty::TyKind::Discr(..)
289                | rty::TyKind::Downcast(_, _, _, _, _)
290                | rty::TyKind::Blocked(_)
291                | rty::TyKind::Uninit,
292                _,
293            ) => {
294                bug!("unexpected type {a:?}");
295            }
296            _ => Err(Mismatch::new(a, b)),
297        }
298    }
299
300    fn zip_bty(&mut self, a: &rty::BaseTy, b: &rty::BaseTy) -> Result<(), Mismatch> {
301        match (a, b) {
302            (rty::BaseTy::Int(ity_a), rty::BaseTy::Int(ity_b)) => {
303                assert_eq_or_incompatible(ity_a, ity_b)
304            }
305            (rty::BaseTy::Uint(uity_a), rty::BaseTy::Uint(uity_b)) => {
306                assert_eq_or_incompatible(uity_a, uity_b)
307            }
308            (rty::BaseTy::Bool, rty::BaseTy::Bool) => Ok(()),
309            (rty::BaseTy::Str, rty::BaseTy::Str) => Ok(()),
310            (rty::BaseTy::Char, rty::BaseTy::Char) => Ok(()),
311            (rty::BaseTy::Float(fty_a), rty::BaseTy::Float(fty_b)) => {
312                assert_eq_or_incompatible(fty_a, fty_b)
313            }
314            (rty::BaseTy::Slice(ty_a), rty::BaseTy::Slice(ty_b)) => self.zip_ty(ty_a, ty_b),
315            (rty::BaseTy::Adt(adt_def_a, args_a), rty::BaseTy::Adt(adt_def_b, args_b)) => {
316                assert_eq_or_incompatible(adt_def_a.did(), adt_def_b.did())?;
317                assert_eq_or_incompatible(args_a.len(), args_b.len())?;
318                for (arg_a, arg_b) in iter::zip(args_a, args_b) {
319                    self.zip_generic_arg(arg_a, arg_b)?;
320                }
321                Ok(())
322            }
323            (rty::BaseTy::RawPtr(ty_a, mutbl_a), rty::BaseTy::RawPtr(ty_b, mutbl_b)) => {
324                assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
325                self.zip_ty(ty_a, ty_b)
326            }
327            (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), rty::BaseTy::Ref(re_b, ty_b, mutbl_b)) => {
328                assert_eq_or_incompatible(mutbl_a, mutbl_b)?;
329                self.zip_region(re_a, re_b);
330                self.zip_ty(ty_a, ty_b)
331            }
332            (rty::BaseTy::FnPtr(poly_sig_a), rty::BaseTy::FnPtr(poly_sig_b)) => {
333                self.zip_poly_fn_sig(poly_sig_a, poly_sig_b)
334                    .map_err(|_| Mismatch::new(poly_sig_a, poly_sig_b))
335            }
336            (rty::BaseTy::Tuple(tys_a), rty::BaseTy::Tuple(tys_b)) => {
337                assert_eq_or_incompatible(tys_a.len(), tys_b.len())?;
338                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
339                    self.zip_ty(ty_a, ty_b)?;
340                }
341                Ok(())
342            }
343            (rty::BaseTy::Alias(kind_a, aty_a), rty::BaseTy::Alias(kind_b, aty_b)) => {
344                assert_eq_or_incompatible(kind_a, kind_b)?;
345                assert_eq_or_incompatible(aty_a.def_id, aty_b.def_id)?;
346                assert_eq_or_incompatible(aty_a.args.len(), aty_b.args.len())?;
347                for (arg_a, arg_b) in iter::zip(&aty_a.args, &aty_b.args) {
348                    self.zip_generic_arg(arg_a, arg_b)?;
349                }
350                Ok(())
351            }
352            (rty::BaseTy::Array(ty_a, len_a), rty::BaseTy::Array(ty_b, len_b)) => {
353                self.zip_const(len_a, len_b)?;
354                self.zip_ty(ty_a, ty_b)
355            }
356            (rty::BaseTy::Never, rty::BaseTy::Never) => Ok(()),
357            (rty::BaseTy::Param(pty_a), rty::BaseTy::Param(pty_b)) => {
358                assert_eq_or_incompatible(pty_a, pty_b)
359            }
360            (rty::BaseTy::Dynamic(preds_a, re_a), rty::BaseTy::Dynamic(preds_b, re_b)) => {
361                assert_eq_or_incompatible(preds_a.len(), preds_b.len())?;
362                for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
363                    self.zip_poly_existential_pred(pred_a, pred_b)?;
364                }
365                self.zip_region(re_a, re_b);
366                Ok(())
367            }
368            (rty::BaseTy::Foreign(def_id_a), rty::BaseTy::Foreign(def_id_b)) => {
369                assert_eq_or_incompatible(def_id_a, def_id_b)
370            }
371            (rty::BaseTy::Closure(..) | rty::BaseTy::Coroutine(..), _) => {
372                bug!("unexpected type `{a:?}`");
373            }
374            _ => Err(Mismatch::new(a, b)),
375        }
376    }
377
378    fn zip_generic_arg(
379        &mut self,
380        a: &rty::GenericArg,
381        b: &rty::GenericArg,
382    ) -> Result<(), Mismatch> {
383        match (a, b) {
384            (rty::GenericArg::Ty(ty_a), rty::GenericArg::Ty(ty_b)) => self.zip_ty(ty_a, ty_b),
385            (rty::GenericArg::Base(ctor_a), rty::GenericArg::Base(ctor_b)) => {
386                self.zip_sorts(&ctor_a.sort(), &ctor_b.sort());
387                self.enter_binders(ctor_a, ctor_b, |this, sty_a, sty_b| {
388                    this.zip_subset_ty(sty_a, sty_b)
389                })
390            }
391            (rty::GenericArg::Lifetime(re_a), rty::GenericArg::Lifetime(re_b)) => {
392                self.zip_region(re_a, re_b);
393                Ok(())
394            }
395            (rty::GenericArg::Const(ct_a), rty::GenericArg::Const(ct_b)) => {
396                self.zip_const(ct_a, ct_b)
397            }
398            _ => Err(Mismatch::new(a, b)),
399        }
400    }
401
402    fn zip_sorts(&mut self, a: &rty::Sort, b: &rty::Sort) {
403        if let rty::Sort::Infer(vid) = a {
404            assert_ne!(vid.as_u32(), 0);
405            self.holes.sorts.insert(*vid, b.clone());
406        }
407    }
408
409    fn zip_subset_ty(&mut self, a: &rty::SubsetTy, b: &rty::SubsetTy) -> Result<(), Mismatch> {
410        if let rty::BaseTy::Infer(vid) = a.bty {
411            assert_ne!(vid.as_u32(), 0);
412            let b = self.adjust_bvars(b);
413            self.holes.subset_tys.insert(vid, b);
414            Ok(())
415        } else {
416            self.zip_bty(&a.bty, &b.bty)
417        }
418    }
419
420    fn zip_const(&mut self, a: &rty::Const, b: &ty::Const) -> Result<(), Mismatch> {
421        match (&a.kind, &b.kind) {
422            (rty::ConstKind::Infer(ty::InferConst::Var(cid)), _) => {
423                self.holes.consts.insert(*cid, b.clone());
424                Ok(())
425            }
426            (rty::ConstKind::Param(param_const_a), ty::ConstKind::Param(param_const_b)) => {
427                assert_eq_or_incompatible(param_const_a, param_const_b)
428            }
429            (rty::ConstKind::Value(ty_a, val_a), ty::ConstKind::Value(ty_b, val_b)) => {
430                assert_eq_or_incompatible(ty_a, ty_b)?;
431                assert_eq_or_incompatible(val_a, val_b)
432            }
433            (rty::ConstKind::Unevaluated(c1), ty::ConstKind::Unevaluated(c2)) => {
434                assert_eq_or_incompatible(c1, c2)
435            }
436            _ => Err(Mismatch::new(a, b)),
437        }
438    }
439
440    fn zip_region(&mut self, a: &rty::Region, b: &ty::Region) {
441        if let rty::Region::ReVar(vid) = a {
442            let re = self.adjust_bvars(b);
443            self.holes.regions.insert(*vid, re);
444        }
445    }
446
447    fn zip_poly_existential_pred(
448        &mut self,
449        a: &rty::Binder<rty::ExistentialPredicate>,
450        b: &rty::Binder<rty::ExistentialPredicate>,
451    ) -> Result<(), Mismatch> {
452        self.enter_binders(a, b, |this, a, b| {
453            match (a, b) {
454                (
455                    rty::ExistentialPredicate::Trait(trait_ref_a),
456                    rty::ExistentialPredicate::Trait(trait_ref_b),
457                ) => {
458                    assert_eq_or_incompatible(trait_ref_a.def_id, trait_ref_b.def_id)?;
459                    assert_eq_or_incompatible(trait_ref_a.args.len(), trait_ref_b.args.len())?;
460                    for (arg_a, arg_b) in iter::zip(&trait_ref_a.args, &trait_ref_b.args) {
461                        this.zip_generic_arg(arg_a, arg_b)?;
462                    }
463                    Ok(())
464                }
465                (
466                    rty::ExistentialPredicate::Projection(projection_a),
467                    rty::ExistentialPredicate::Projection(projection_b),
468                ) => {
469                    assert_eq_or_incompatible(projection_a.def_id, projection_b.def_id)?;
470                    assert_eq_or_incompatible(projection_a.args.len(), projection_b.args.len())?;
471                    for (arg_a, arg_b) in iter::zip(&projection_a.args, &projection_b.args) {
472                        this.zip_generic_arg(arg_a, arg_b)?;
473                    }
474                    this.enter_binders(&projection_a.term, &projection_b.term, |this, a, b| {
475                        this.zip_bty(&a.bty, &b.bty)
476                    })
477                }
478                (
479                    rty::ExistentialPredicate::AutoTrait(def_id_a),
480                    rty::ExistentialPredicate::AutoTrait(def_id_b),
481                ) => assert_eq_or_incompatible(def_id_a, def_id_b),
482                _ => Err(Mismatch::new(a, b)),
483            }
484        })
485    }
486
487    /// Enter a binder in both `a` and `b` creating a mapping between the two.
488    fn enter_binders<T, R>(
489        &mut self,
490        a: &rty::Binder<T>,
491        b: &rty::Binder<T>,
492        f: impl FnOnce(&mut Self, &T, &T) -> R,
493    ) -> R {
494        self.b_binder_to_a_binder.push(Some(self.a_binders));
495        self.a_binders += 1;
496        let r = f(self, a.skip_binder_ref(), b.skip_binder_ref());
497        self.a_binders -= 1;
498        self.b_binder_to_a_binder.pop();
499        r
500    }
501
502    /// Enter a binder in `a` without a corresponding mapping in `b`
503    fn enter_a_binder<T, R>(
504        &mut self,
505        t: &rty::Binder<T>,
506        f: impl FnOnce(&mut Self, &T) -> R,
507    ) -> R {
508        self.a_binders += 1;
509        let r = f(self, t.skip_binder_ref());
510        self.a_binders -= 1;
511        r
512    }
513
514    /// Enter a binder in `b` without a corresponding mapping in `a`
515    fn enter_b_binder<T, R>(
516        &mut self,
517        t: &rty::Binder<T>,
518        f: impl FnOnce(&mut Self, &T) -> R,
519    ) -> R {
520        self.b_binder_to_a_binder.push(None);
521        let r = f(self, t.skip_binder_ref());
522        self.b_binder_to_a_binder.pop();
523        r
524    }
525
526    fn adjust_bvars<T: TypeFoldable + Clone + std::fmt::Debug>(&self, t: &T) -> T {
527        struct Adjuster<'a, 'genv, 'tcx> {
528            current_index: DebruijnIndex,
529            zipper: &'a Zipper<'genv, 'tcx>,
530        }
531
532        impl Adjuster<'_, '_, '_> {
533            fn adjust(&self, debruijn: DebruijnIndex) -> DebruijnIndex {
534                let b_binders = self.zipper.b_binder_to_a_binder.len();
535                let mapped_binder = self.zipper.b_binder_to_a_binder
536                    [b_binders - debruijn.as_usize() - 1]
537                    .unwrap_or_else(|| {
538                        bug!("bound var without corresponding binder: `{debruijn:?}`")
539                    });
540                DebruijnIndex::from_u32(self.zipper.a_binders - mapped_binder - 1)
541                    .shifted_in(self.current_index.as_u32())
542            }
543        }
544
545        impl TypeFolder for Adjuster<'_, '_, '_> {
546            fn enter_binder(&mut self, _: &rty::BoundVariableKinds) {
547                self.current_index.shift_in(1);
548            }
549
550            fn exit_binder(&mut self) {
551                self.current_index.shift_out(1);
552            }
553
554            fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
555                if let rty::ReBound(debruijn, br) = *re
556                    && debruijn >= self.current_index
557                {
558                    rty::ReBound(self.adjust(debruijn), br)
559                } else {
560                    *re
561                }
562            }
563
564            fn fold_expr(&mut self, expr: &rty::Expr) -> rty::Expr {
565                if let rty::ExprKind::Var(rty::Var::Bound(debruijn, breft)) = expr.kind()
566                    && *debruijn >= self.current_index
567                {
568                    rty::Expr::bvar(self.adjust(*debruijn), breft.var, breft.kind)
569                } else {
570                    expr.super_fold_with(self)
571                }
572            }
573        }
574        t.fold_with(&mut Adjuster { current_index: INNERMOST, zipper: self })
575    }
576
577    fn emit_fn_sig_err(&mut self, err: FnSigErr, decl: &fhir::FnDecl) {
578        match err {
579            FnSigErr::ArgCountMismatch => {
580                self.errors.emit(errors::IncompatibleParamCount::new(
581                    self.genv,
582                    decl,
583                    self.owner_id,
584                ));
585            }
586            FnSigErr::FnInput(i) => {
587                self.errors.emit(errors::IncompatibleRefinement::fn_input(
588                    self.genv,
589                    self.owner_id,
590                    decl,
591                    i,
592                ));
593            }
594            FnSigErr::FnOutput(_) => {
595                self.errors.emit(errors::IncompatibleRefinement::fn_output(
596                    self.genv,
597                    self.owner_id,
598                    decl,
599                    self.is_async_fn(),
600                ));
601            }
602            FnSigErr::Ensures { i, expected } => {
603                self.errors.emit(errors::IncompatibleRefinement::ensures(
604                    self.genv,
605                    self.owner_id,
606                    decl,
607                    &expected,
608                    i,
609                ));
610            }
611        }
612    }
613}
614
615fn assert_eq_or_incompatible<T: Eq + fmt::Debug>(a: T, b: T) -> Result<(), Mismatch> {
616    if a != b {
617        return Err(Mismatch::new(a, b));
618    }
619    Ok(())
620}
621
/// Structural mismatch found while zipping. The payload is the text `"<a> != <b>"` built by
/// `Mismatch::new`; it is kept for debugging only.
#[expect(dead_code, reason = "we use the the String for debugging")]
struct Mismatch(String);
624
625impl Mismatch {
626    fn new<T: fmt::Debug>(a: T, b: T) -> Self {
627        Self(format!("{a:?} != {b:?}"))
628    }
629}
630
631enum FnSigErr {
632    ArgCountMismatch,
633    FnInput(usize),
634    #[expect(dead_code, reason = "we use the struct for debugging")]
635    FnOutput(Mismatch),
636    Ensures {
637        i: usize,
638        expected: rty::Ty,
639    },
640}
641
642mod errors {
643    use flux_common::span_bug;
644    use flux_errors::E0999;
645    use flux_macros::Diagnostic;
646    use flux_middle::{def_id::MaybeExternId, fhir, global_env::GlobalEnv, rty};
647    use flux_rustc_bridge::{
648        ToRustc,
649        ty::{FieldIdx, VariantIdx},
650    };
651    use rustc_span::{DUMMY_SP, Span};
652
653    #[derive(Diagnostic)]
654    #[diag(fhir_analysis_incompatible_refinement, code = E0999)]
655    #[note]
656    pub(super) struct IncompatibleRefinement<'tcx> {
657        #[primary_span]
658        #[label]
659        span: Span,
660        #[label(fhir_analysis_expected_label)]
661        expected_span: Option<Span>,
662        expected_ty: rustc_middle::ty::Ty<'tcx>,
663        def_descr: &'static str,
664        #[help(fhir_analysis_async_hint)]
665        async_hint: Option<()>,
666    }
667
668    impl<'tcx> IncompatibleRefinement<'tcx> {
669        pub(super) fn type_alias(
670            genv: GlobalEnv<'_, 'tcx>,
671            def_id: MaybeExternId,
672            type_alias: &fhir::TyAlias,
673        ) -> Self {
674            let tcx = genv.tcx();
675            Self {
676                span: type_alias.ty.span,
677                def_descr: tcx.def_descr(def_id.resolved_id()),
678                expected_span: Some(tcx.def_span(def_id)),
679                expected_ty: tcx.type_of(def_id).skip_binder(),
680                async_hint: None,
681            }
682        }
683
684        pub(super) fn fn_input(
685            genv: GlobalEnv<'_, 'tcx>,
686            fn_id: MaybeExternId,
687            decl: &fhir::FnDecl,
688            pos: usize,
689        ) -> Self {
690            let expected_span = match fn_id {
691                MaybeExternId::Local(local_id) => {
692                    genv.tcx()
693                        .hir_node_by_def_id(local_id)
694                        .fn_decl()
695                        .and_then(|fn_decl| fn_decl.inputs.get(pos))
696                        .map(|input| input.span)
697                }
698                MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
699            };
700
701            let expected_ty = genv
702                .tcx()
703                .fn_sig(fn_id.resolved_id())
704                .skip_binder()
705                .inputs()
706                .map_bound(|inputs| inputs[pos])
707                .skip_binder();
708
709            Self {
710                span: decl.inputs[pos].span,
711                def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
712                expected_span,
713                expected_ty,
714                async_hint: None,
715            }
716        }
717
718        pub(super) fn fn_output(
719            genv: GlobalEnv<'_, 'tcx>,
720            fn_id: MaybeExternId,
721            decl: &fhir::FnDecl,
722            is_async: bool,
723        ) -> Self {
724            let expected_span = match fn_id {
725                MaybeExternId::Local(local_id) => {
726                    genv.tcx()
727                        .hir_node_by_def_id(local_id)
728                        .fn_decl()
729                        .map(|fn_decl| fn_decl.output.span())
730                }
731                MaybeExternId::Extern(_, extern_id) => Some(genv.tcx().def_span(extern_id)),
732            };
733
734            let expected_ty = genv
735                .tcx()
736                .fn_sig(fn_id.resolved_id())
737                .skip_binder()
738                .output()
739                .skip_binder();
740            let spec_span = decl.output.ret.span;
741
742            let async_hint = if is_async { Some(()) } else { None };
743
744            Self {
745                span: spec_span,
746                def_descr: genv.tcx().def_descr(fn_id.resolved_id()),
747                expected_span,
748                expected_ty,
749                async_hint,
750            }
751        }
752
753        pub(super) fn ensures(
754            genv: GlobalEnv<'_, 'tcx>,
755            fn_id: MaybeExternId,
756            decl: &fhir::FnDecl,
757            expected: &rty::Ty,
758            i: usize,
759        ) -> Self {
760            let fhir::Ensures::Type(_, ty) = &decl.output.ensures[i] else {
761                span_bug!(decl.span, "expected `fhir::Ensures::Type`");
762            };
763            let tcx = genv.tcx();
764            Self {
765                span: ty.span,
766                def_descr: tcx.def_descr(fn_id.resolved_id()),
767                expected_span: None,
768                expected_ty: expected.to_rustc(tcx),
769                async_hint: None,
770            }
771        }
772
773        pub(super) fn field(
774            genv: GlobalEnv<'_, 'tcx>,
775            adt_id: MaybeExternId,
776            variant_idx: VariantIdx,
777            field_idx: FieldIdx,
778        ) -> Self {
779            let tcx = genv.tcx();
780            let adt_def = tcx.adt_def(adt_id);
781            let field_def = &adt_def.variant(variant_idx).fields[field_idx];
782
783            let item = genv.fhir_expect_item(adt_id.local_id()).unwrap();
784            let span = match &item.kind {
785                fhir::ItemKind::Enum(enum_def) => {
786                    enum_def.variants[variant_idx.as_usize()].fields[field_idx.as_usize()]
787                        .ty
788                        .span
789                }
790                fhir::ItemKind::Struct(struct_def)
791                    if let fhir::StructKind::Transparent { fields } = &struct_def.kind =>
792                {
793                    fields[field_idx.as_usize()].ty.span
794                }
795                _ => DUMMY_SP,
796            };
797
798            Self {
799                span,
800                def_descr: tcx.def_descr(field_def.did),
801                expected_span: Some(tcx.def_span(field_def.did)),
802                expected_ty: tcx.type_of(field_def.did).skip_binder(),
803                async_hint: None,
804            }
805        }
806    }
807
808    #[derive(Diagnostic)]
809    #[diag(fhir_analysis_incompatible_param_count, code = E0999)]
810    pub(super) struct IncompatibleParamCount {
811        #[primary_span]
812        #[label]
813        span: Span,
814        found: usize,
815        #[label(fhir_analysis_expected_label)]
816        expected_span: Span,
817        expected: usize,
818        def_descr: &'static str,
819    }
820
821    impl IncompatibleParamCount {
822        pub(super) fn new(genv: GlobalEnv, decl: &fhir::FnDecl, def_id: MaybeExternId) -> Self {
823            let def_descr = genv.tcx().def_descr(def_id.resolved_id());
824
825            let span = if !decl.inputs.is_empty() {
826                decl.inputs[decl.inputs.len() - 1]
827                    .span
828                    .with_lo(decl.inputs[0].span.lo())
829            } else {
830                decl.span
831            };
832
833            let expected_span = if let Some(local_id) = def_id.as_local()
834                && let expected_decl = genv.tcx().hir_node_by_def_id(local_id).fn_decl().unwrap()
835                && !expected_decl.inputs.is_empty()
836            {
837                expected_decl.inputs[expected_decl.inputs.len() - 1]
838                    .span
839                    .with_lo(expected_decl.inputs[0].span.lo())
840            } else {
841                genv.tcx().def_span(def_id)
842            };
843
844            let expected = genv
845                .tcx()
846                .fn_sig(def_id)
847                .skip_binder()
848                .skip_binder()
849                .inputs()
850                .len();
851
852            Self { span, found: decl.inputs.len(), expected_span, expected, def_descr }
853        }
854    }
855
856    #[derive(Diagnostic)]
857    #[diag(fhir_analysis_field_count_mismatch, code = E0999)]
858    pub(super) struct FieldCountMismatch {
859        #[primary_span]
860        #[label]
861        span: Span,
862        fields: usize,
863        #[label(fhir_analysis_expected_label)]
864        expected_span: Span,
865        expected_fields: usize,
866    }
867
868    impl FieldCountMismatch {
869        pub(super) fn new(
870            genv: GlobalEnv,
871            found: usize,
872            adt_def_id: MaybeExternId,
873            variant_idx: VariantIdx,
874        ) -> Self {
875            let adt_def = genv.tcx().adt_def(adt_def_id);
876            let expected_variant = adt_def.variant(variant_idx);
877
878            // Get the span of the variant if this is an enum. Structs cannot have produce a field
879            // count mismatch.
880            let span = if let Ok(fhir::Node::Item(item)) = genv.fhir_node(adt_def_id.local_id())
881                && let fhir::ItemKind::Enum(enum_def) = &item.kind
882                && let Some(variant) = enum_def.variants.get(variant_idx.as_usize())
883            {
884                variant.span
885            } else {
886                DUMMY_SP
887            };
888
889            Self {
890                span,
891                fields: found,
892                expected_span: genv.tcx().def_span(expected_variant.def_id),
893                expected_fields: expected_variant.fields.len(),
894            }
895        }
896    }
897}