flux_middle/rty/
region_matching.rs

1use std::{collections::hash_map, iter};
2
3use flux_common::bug;
4use flux_rustc_bridge::ty;
5use rustc_data_structures::unord::UnordMap;
6
7use super::fold::TypeFoldable;
8use crate::{rty, rty::fold::TypeFolder};
9
10/// See `flux_refineck::type_env::TypeEnv::assign`
11pub fn ty_match_regions(a: &rty::Ty, b: &ty::Ty) -> rty::Ty {
12    let a = replace_regions_with_unique_vars(a);
13    let mut subst = RegionSubst::default();
14    subst.ty_infer_from_ty(&a, b);
15    subst.apply(&a)
16}
17
18pub fn rty_match_regions(a: &rty::Ty, b: &rty::Ty) -> rty::Ty {
19    let a = replace_regions_with_unique_vars(a);
20    let mut subst = RegionSubst::default();
21    subst.rty_infer_from_ty(&a, b);
22    subst.apply(&a)
23}
24
25/// Replace all non-bound regions with a [`rty::ReVar`] assigning each a unique [`rty::RegionVid`].
26/// This is used to have a unique identifier for each position such that we can infer a region
27/// substitution.
28fn replace_regions_with_unique_vars(ty: &rty::Ty) -> rty::Ty {
29    struct Replacer {
30        next_rvid: u32,
31    }
32    impl TypeFolder for Replacer {
33        fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
34            if let rty::ReBound(..) = re {
35                *re
36            } else {
37                let rvid = self.next_rvid;
38                self.next_rvid += 1;
39                rty::ReVar(rty::RegionVid::from_u32(rvid))
40            }
41        }
42    }
43
44    ty.fold_with(&mut Replacer { next_rvid: 0 })
45}
46
47#[derive(Default, Debug)]
48struct RegionSubst {
49    map: UnordMap<rty::RegionVid, rty::Region>,
50}
51
52impl RegionSubst {
53    fn apply<T: TypeFoldable>(&self, t: &T) -> T {
54        struct Folder<'a>(&'a RegionSubst);
55        impl TypeFolder for Folder<'_> {
56            fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
57                // FIXME the map should always contain a region
58                if let rty::ReVar(rvid) = re
59                    && let Some(region) = self.0.map.get(rvid)
60                {
61                    *region
62                } else {
63                    *re
64                }
65            }
66        }
67        t.fold_with(&mut Folder(self))
68    }
69
70    fn infer_from_region(&mut self, a: rty::Region, b: rty::Region) {
71        let rty::ReVar(var) = a else { return };
72        match self.map.entry(var) {
73            hash_map::Entry::Occupied(entry) => {
74                if entry.get() != &b {
75                    bug!("ambiguous region substitution: {:?} -> [{:?}, {:?}]", a, entry.get(), b);
76                }
77            }
78            hash_map::Entry::Vacant(entry) => {
79                entry.insert(b);
80            }
81        }
82    }
83}
84
85impl RegionSubst {
86    fn ty_infer_from_fn_sig(&mut self, a: &rty::FnSig, b: &ty::FnSig) {
87        debug_assert_eq!(a.inputs().len(), b.inputs().len());
88        for (ty_a, ty_b) in iter::zip(a.inputs(), b.inputs()) {
89            self.ty_infer_from_ty(ty_a, ty_b);
90        }
91        self.ty_infer_from_ty(&a.output().skip_binder_ref().ret, b.output());
92    }
93
94    fn ty_infer_from_ty(&mut self, a: &rty::Ty, b: &ty::Ty) {
95        match (a.kind(), b.kind()) {
96            (rty::TyKind::Exists(ty_a), _) => {
97                self.ty_infer_from_ty(ty_a.as_ref().skip_binder(), b);
98            }
99            (rty::TyKind::Constr(_, ty_a), _) => {
100                self.ty_infer_from_ty(ty_a, b);
101            }
102            (rty::TyKind::Indexed(bty_a, _), _) => {
103                self.ty_infer_from_bty(bty_a, b);
104            }
105            (rty::TyKind::Ptr(rty::PtrKind::Mut(re_a), _), ty::TyKind::Ref(re_b, _, mutbl)) => {
106                debug_assert!(mutbl.is_mut());
107                self.infer_from_region(*re_a, *re_b);
108            }
109            (rty::TyKind::StrgRef(re_a, ..), ty::TyKind::Ref(re_b, _, mutbl)) => {
110                debug_assert!(mutbl.is_mut());
111                self.infer_from_region(*re_a, *re_b);
112            }
113            _ => {}
114        }
115    }
116
117    fn ty_infer_from_bty(&mut self, a: &rty::BaseTy, ty: &ty::Ty) {
118        match (a, ty.kind()) {
119            (rty::BaseTy::Adt(_, args_a), ty::TyKind::Adt(_, args_b)) => {
120                self.ty_infer_from_generic_args(args_a, args_b);
121            }
122            (rty::BaseTy::Array(ty_a, _), ty::TyKind::Array(ty_b, _)) => {
123                self.ty_infer_from_ty(ty_a, ty_b);
124            }
125            (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), ty::TyKind::Ref(re_b, ty_b, mutbl_b)) => {
126                debug_assert_eq!(mutbl_a, mutbl_b);
127                self.infer_from_region(*re_a, *re_b);
128                self.ty_infer_from_ty(ty_a, ty_b);
129            }
130            (rty::BaseTy::Tuple(fields_a), ty::TyKind::Tuple(fields_b)) => {
131                debug_assert_eq!(fields_a.len(), fields_b.len());
132                for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
133                    self.ty_infer_from_ty(ty_a, ty_b);
134                }
135            }
136            (rty::BaseTy::Slice(ty_a), ty::TyKind::Slice(ty_b)) => {
137                self.ty_infer_from_ty(ty_a, ty_b);
138            }
139            (rty::BaseTy::FnPtr(poly_sig_a), ty::TyKind::FnPtr(poly_sig_b)) => {
140                self.ty_infer_from_fn_sig(
141                    poly_sig_a.skip_binder_ref(),
142                    poly_sig_b.skip_binder_ref(),
143                );
144            }
145            (rty::BaseTy::RawPtr(ty_a, mutbl_a), ty::TyKind::RawPtr(ty_b, mutbl_b)) => {
146                debug_assert_eq!(mutbl_a, mutbl_b);
147                self.ty_infer_from_ty(ty_a, ty_b);
148            }
149            (rty::BaseTy::Dynamic(preds_a, re_a), ty::TyKind::Dynamic(preds_b, re_b)) => {
150                debug_assert_eq!(preds_a.len(), preds_b.len());
151                self.infer_from_region(*re_a, *re_b);
152                for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
153                    self.ty_infer_from_existential_pred(pred_a, pred_b);
154                }
155            }
156            _ => {}
157        }
158    }
159
160    fn ty_infer_from_existential_pred(
161        &mut self,
162        a: &rty::PolyExistentialPredicate,
163        b: &ty::PolyExistentialPredicate,
164    ) {
165        match (a.as_ref().skip_binder(), b.as_ref().skip_binder()) {
166            (
167                rty::ExistentialPredicate::Trait(trait_ref_a),
168                ty::ExistentialPredicate::Trait(trait_ref_b),
169            ) => {
170                debug_assert_eq!(trait_ref_a.def_id, trait_ref_b.def_id);
171                self.ty_infer_from_generic_args(&trait_ref_a.args, &trait_ref_b.args);
172            }
173            (
174                rty::ExistentialPredicate::Projection(proj_a),
175                ty::ExistentialPredicate::Projection(proj_b),
176            ) => {
177                debug_assert_eq!(proj_a.def_id, proj_b.def_id);
178                self.ty_infer_from_generic_args(&proj_a.args, &proj_b.args);
179                self.ty_infer_from_bty(proj_a.term.as_bty_skipping_binder(), &proj_b.term);
180            }
181            _ => {}
182        }
183    }
184
185    fn ty_infer_from_generic_args(&mut self, a: &rty::GenericArgs, b: &ty::GenericArgs) {
186        debug_assert_eq!(a.len(), b.len());
187        for (arg_a, arg_b) in iter::zip(a, b) {
188            self.ty_infer_from_generic_arg(arg_a, arg_b);
189        }
190    }
191
192    fn ty_infer_from_generic_arg(&mut self, a: &rty::GenericArg, b: &ty::GenericArg) {
193        match (a, b) {
194            (rty::GenericArg::Base(ctor_a), ty::GenericArg::Ty(ty_b)) => {
195                self.ty_infer_from_bty(ctor_a.as_bty_skipping_binder(), ty_b);
196            }
197            (rty::GenericArg::Ty(ty_a), ty::GenericArg::Ty(ty_b)) => {
198                self.ty_infer_from_ty(ty_a, ty_b);
199            }
200            (rty::GenericArg::Lifetime(re_a), ty::GenericArg::Lifetime(re_b)) => {
201                self.infer_from_region(*re_a, *re_b);
202            }
203            _ => {}
204        }
205    }
206}
207
208impl RegionSubst {
209    fn rty_infer_from_fn_sig(&mut self, a: &rty::FnSig, b: &rty::FnSig) {
210        debug_assert_eq!(a.inputs().len(), b.inputs().len());
211        for (ty_a, ty_b) in iter::zip(a.inputs(), b.inputs()) {
212            self.rty_infer_from_ty(ty_a, ty_b);
213        }
214        self.rty_infer_from_ty(
215            &a.output().skip_binder_ref().ret,
216            &b.output().skip_binder_ref().ret,
217        );
218    }
219
220    fn rty_infer_from_ty(&mut self, a: &rty::Ty, b: &rty::Ty) {
221        match (a.kind(), b.kind()) {
222            (rty::TyKind::Exists(ctor_a), _) => {
223                self.rty_infer_from_ty(ctor_a.skip_binder_ref(), b);
224            }
225            (_, rty::TyKind::Exists(ctor_b)) => {
226                self.rty_infer_from_ty(a, ctor_b.skip_binder_ref());
227            }
228            (rty::TyKind::Constr(_, ty_a), _) => self.rty_infer_from_ty(ty_a, b),
229            (_, rty::TyKind::Constr(_, ty_b)) => self.rty_infer_from_ty(a, ty_b),
230            (rty::TyKind::Indexed(bty_a, _), rty::TyKind::Indexed(bty_b, _)) => {
231                self.rty_infer_from_bty(bty_a, bty_b);
232            }
233            (rty::TyKind::StrgRef(re_a, _, ty_a), rty::TyKind::StrgRef(re_b, _, ty_b)) => {
234                self.infer_from_region(*re_a, *re_b);
235                self.rty_infer_from_ty(ty_a, ty_b);
236            }
237            _ => {}
238        }
239    }
240
241    fn rty_infer_from_bty(&mut self, a: &rty::BaseTy, b: &rty::BaseTy) {
242        match (a, b) {
243            (rty::BaseTy::Slice(ty_a), rty::BaseTy::Slice(ty_b)) => {
244                self.rty_infer_from_ty(ty_a, ty_b);
245            }
246            (rty::BaseTy::Adt(adt_def_a, args_a), rty::BaseTy::Adt(adt_def_b, args_b)) => {
247                debug_assert_eq!(adt_def_a.did(), adt_def_b.did());
248                for (arg_a, arg_b) in iter::zip(args_a, args_b) {
249                    self.rty_infer_from_generic_arg(arg_a, arg_b);
250                }
251            }
252            (rty::BaseTy::RawPtr(ty_a, mutbl_a), rty::BaseTy::RawPtr(ty_b, mutbl_b)) => {
253                debug_assert_eq!(mutbl_a, mutbl_b);
254                self.rty_infer_from_ty(ty_a, ty_b);
255            }
256            (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), rty::BaseTy::Ref(re_b, ty_b, mutbl_b)) => {
257                debug_assert_eq!(mutbl_a, mutbl_b);
258                self.infer_from_region(*re_a, *re_b);
259                self.rty_infer_from_ty(ty_a, ty_b);
260            }
261            (rty::BaseTy::FnPtr(poly_sig_a), rty::BaseTy::FnPtr(poly_sig_b)) => {
262                self.rty_infer_from_fn_sig(
263                    poly_sig_a.skip_binder_ref(),
264                    poly_sig_b.skip_binder_ref(),
265                );
266            }
267            (rty::BaseTy::Tuple(tys_a), rty::BaseTy::Tuple(tys_b)) => {
268                for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
269                    self.rty_infer_from_ty(ty_a, ty_b);
270                }
271            }
272            (rty::BaseTy::Alias(_, aty_a), rty::BaseTy::Alias(_, aty_b)) => {
273                for (arg_a, arg_b) in iter::zip(&aty_a.args, &aty_b.args) {
274                    self.rty_infer_from_generic_arg(arg_a, arg_b);
275                }
276            }
277            (rty::BaseTy::Array(ty_a, _), rty::BaseTy::Array(ty_b, _)) => {
278                self.rty_infer_from_ty(ty_a, ty_b);
279            }
280            (rty::BaseTy::Dynamic(preds_a, re_a), rty::BaseTy::Dynamic(preds_b, re_b)) => {
281                for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
282                    self.rty_infer_from_existential_pred(pred_a, pred_b);
283                }
284                self.infer_from_region(*re_a, *re_b);
285            }
286            _ => {}
287        }
288    }
289
290    fn rty_infer_from_generic_arg(&mut self, a: &rty::GenericArg, b: &rty::GenericArg) {
291        match (a, b) {
292            (rty::GenericArg::Ty(ty_a), rty::GenericArg::Ty(ty_b)) => {
293                self.rty_infer_from_ty(ty_a, ty_b);
294            }
295            (rty::GenericArg::Base(ctor_a), rty::GenericArg::Base(ctor_b)) => {
296                self.rty_infer_from_bty(
297                    ctor_a.as_bty_skipping_binder(),
298                    ctor_b.as_bty_skipping_binder(),
299                );
300            }
301            (rty::GenericArg::Lifetime(re_a), rty::GenericArg::Lifetime(re_b)) => {
302                self.infer_from_region(*re_a, *re_b);
303            }
304            _ => {}
305        }
306    }
307
308    fn rty_infer_from_existential_pred(
309        &mut self,
310        a: &rty::Binder<rty::ExistentialPredicate>,
311        b: &rty::Binder<rty::ExistentialPredicate>,
312    ) {
313        match (a.skip_binder_ref(), b.skip_binder_ref()) {
314            (
315                rty::ExistentialPredicate::Trait(trait_ref_a),
316                rty::ExistentialPredicate::Trait(trait_ref_b),
317            ) => {
318                debug_assert_eq!(trait_ref_a.def_id, trait_ref_b.def_id);
319                for (arg_a, arg_b) in iter::zip(&trait_ref_a.args, &trait_ref_b.args) {
320                    self.rty_infer_from_generic_arg(arg_a, arg_b);
321                }
322            }
323            (
324                rty::ExistentialPredicate::Projection(proj_a),
325                rty::ExistentialPredicate::Projection(proj_b),
326            ) => {
327                debug_assert_eq!(proj_a.def_id, proj_b.def_id);
328                for (arg_a, arg_b) in iter::zip(&proj_a.args, &proj_b.args) {
329                    self.rty_infer_from_generic_arg(arg_a, arg_b);
330                }
331                self.rty_infer_from_bty(
332                    proj_a.term.as_bty_skipping_binder(),
333                    proj_b.term.as_bty_skipping_binder(),
334                );
335            }
336            _ => {}
337        }
338    }
339}