1use std::{collections::hash_map, iter};
2
3use flux_common::bug;
4use flux_rustc_bridge::ty;
5use rustc_data_structures::unord::UnordMap;
6
7use super::fold::TypeFoldable;
8use crate::{rty, rty::fold::TypeFolder};
9
10pub fn ty_match_regions(a: &rty::Ty, b: &ty::Ty) -> rty::Ty {
12 let a = replace_regions_with_unique_vars(a);
13 let mut subst = RegionSubst::default();
14 subst.ty_infer_from_ty(&a, b);
15 subst.apply(&a)
16}
17
18pub fn rty_match_regions(a: &rty::Ty, b: &rty::Ty) -> rty::Ty {
19 let a = replace_regions_with_unique_vars(a);
20 let mut subst = RegionSubst::default();
21 subst.rty_infer_from_ty(&a, b);
22 subst.apply(&a)
23}
24
25fn replace_regions_with_unique_vars(ty: &rty::Ty) -> rty::Ty {
29 struct Replacer {
30 next_rvid: u32,
31 }
32 impl TypeFolder for Replacer {
33 fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
34 if let rty::ReBound(..) = re {
35 *re
36 } else {
37 let rvid = self.next_rvid;
38 self.next_rvid += 1;
39 rty::ReVar(rty::RegionVid::from_u32(rvid))
40 }
41 }
42 }
43
44 ty.fold_with(&mut Replacer { next_rvid: 0 })
45}
46
47#[derive(Default, Debug)]
48struct RegionSubst {
49 map: UnordMap<rty::RegionVid, rty::Region>,
50}
51
52impl RegionSubst {
53 fn apply<T: TypeFoldable>(&self, t: &T) -> T {
54 struct Folder<'a>(&'a RegionSubst);
55 impl TypeFolder for Folder<'_> {
56 fn fold_region(&mut self, re: &rty::Region) -> rty::Region {
57 if let rty::ReVar(rvid) = re
59 && let Some(region) = self.0.map.get(rvid)
60 {
61 *region
62 } else {
63 *re
64 }
65 }
66 }
67 t.fold_with(&mut Folder(self))
68 }
69
70 fn infer_from_region(&mut self, a: rty::Region, b: rty::Region) {
71 let rty::ReVar(var) = a else { return };
72 match self.map.entry(var) {
73 hash_map::Entry::Occupied(entry) => {
74 if entry.get() != &b {
75 bug!("ambiguous region substitution: {:?} -> [{:?}, {:?}]", a, entry.get(), b);
76 }
77 }
78 hash_map::Entry::Vacant(entry) => {
79 entry.insert(b);
80 }
81 }
82 }
83}
84
85impl RegionSubst {
86 fn ty_infer_from_fn_sig(&mut self, a: &rty::FnSig, b: &ty::FnSig) {
87 debug_assert_eq!(a.inputs().len(), b.inputs().len());
88 for (ty_a, ty_b) in iter::zip(a.inputs(), b.inputs()) {
89 self.ty_infer_from_ty(ty_a, ty_b);
90 }
91 self.ty_infer_from_ty(&a.output().skip_binder_ref().ret, b.output());
92 }
93
94 fn ty_infer_from_ty(&mut self, a: &rty::Ty, b: &ty::Ty) {
95 match (a.kind(), b.kind()) {
96 (rty::TyKind::Exists(ty_a), _) => {
97 self.ty_infer_from_ty(ty_a.as_ref().skip_binder(), b);
98 }
99 (rty::TyKind::Constr(_, ty_a), _) => {
100 self.ty_infer_from_ty(ty_a, b);
101 }
102 (rty::TyKind::Indexed(bty_a, _), _) => {
103 self.ty_infer_from_bty(bty_a, b);
104 }
105 (rty::TyKind::Ptr(rty::PtrKind::Mut(re_a), _), ty::TyKind::Ref(re_b, _, mutbl)) => {
106 debug_assert!(mutbl.is_mut());
107 self.infer_from_region(*re_a, *re_b);
108 }
109 (rty::TyKind::StrgRef(re_a, ..), ty::TyKind::Ref(re_b, _, mutbl)) => {
110 debug_assert!(mutbl.is_mut());
111 self.infer_from_region(*re_a, *re_b);
112 }
113 _ => {}
114 }
115 }
116
117 fn ty_infer_from_bty(&mut self, a: &rty::BaseTy, ty: &ty::Ty) {
118 match (a, ty.kind()) {
119 (rty::BaseTy::Adt(_, args_a), ty::TyKind::Adt(_, args_b)) => {
120 self.ty_infer_from_generic_args(args_a, args_b);
121 }
122 (rty::BaseTy::Array(ty_a, _), ty::TyKind::Array(ty_b, _)) => {
123 self.ty_infer_from_ty(ty_a, ty_b);
124 }
125 (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), ty::TyKind::Ref(re_b, ty_b, mutbl_b)) => {
126 debug_assert_eq!(mutbl_a, mutbl_b);
127 self.infer_from_region(*re_a, *re_b);
128 self.ty_infer_from_ty(ty_a, ty_b);
129 }
130 (rty::BaseTy::Tuple(fields_a), ty::TyKind::Tuple(fields_b)) => {
131 debug_assert_eq!(fields_a.len(), fields_b.len());
132 for (ty_a, ty_b) in iter::zip(fields_a, fields_b) {
133 self.ty_infer_from_ty(ty_a, ty_b);
134 }
135 }
136 (rty::BaseTy::Slice(ty_a), ty::TyKind::Slice(ty_b)) => {
137 self.ty_infer_from_ty(ty_a, ty_b);
138 }
139 (rty::BaseTy::FnPtr(poly_sig_a), ty::TyKind::FnPtr(poly_sig_b)) => {
140 self.ty_infer_from_fn_sig(
141 poly_sig_a.skip_binder_ref(),
142 poly_sig_b.skip_binder_ref(),
143 );
144 }
145 (rty::BaseTy::RawPtr(ty_a, mutbl_a), ty::TyKind::RawPtr(ty_b, mutbl_b)) => {
146 debug_assert_eq!(mutbl_a, mutbl_b);
147 self.ty_infer_from_ty(ty_a, ty_b);
148 }
149 (rty::BaseTy::Dynamic(preds_a, re_a), ty::TyKind::Dynamic(preds_b, re_b)) => {
150 debug_assert_eq!(preds_a.len(), preds_b.len());
151 self.infer_from_region(*re_a, *re_b);
152 for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
153 self.ty_infer_from_existential_pred(pred_a, pred_b);
154 }
155 }
156 _ => {}
157 }
158 }
159
160 fn ty_infer_from_existential_pred(
161 &mut self,
162 a: &rty::PolyExistentialPredicate,
163 b: &ty::PolyExistentialPredicate,
164 ) {
165 match (a.as_ref().skip_binder(), b.as_ref().skip_binder()) {
166 (
167 rty::ExistentialPredicate::Trait(trait_ref_a),
168 ty::ExistentialPredicate::Trait(trait_ref_b),
169 ) => {
170 debug_assert_eq!(trait_ref_a.def_id, trait_ref_b.def_id);
171 self.ty_infer_from_generic_args(&trait_ref_a.args, &trait_ref_b.args);
172 }
173 (
174 rty::ExistentialPredicate::Projection(proj_a),
175 ty::ExistentialPredicate::Projection(proj_b),
176 ) => {
177 debug_assert_eq!(proj_a.def_id, proj_b.def_id);
178 self.ty_infer_from_generic_args(&proj_a.args, &proj_b.args);
179 self.ty_infer_from_bty(proj_a.term.as_bty_skipping_binder(), &proj_b.term);
180 }
181 _ => {}
182 }
183 }
184
185 fn ty_infer_from_generic_args(&mut self, a: &rty::GenericArgs, b: &ty::GenericArgs) {
186 debug_assert_eq!(a.len(), b.len());
187 for (arg_a, arg_b) in iter::zip(a, b) {
188 self.ty_infer_from_generic_arg(arg_a, arg_b);
189 }
190 }
191
192 fn ty_infer_from_generic_arg(&mut self, a: &rty::GenericArg, b: &ty::GenericArg) {
193 match (a, b) {
194 (rty::GenericArg::Base(ctor_a), ty::GenericArg::Ty(ty_b)) => {
195 self.ty_infer_from_bty(ctor_a.as_bty_skipping_binder(), ty_b);
196 }
197 (rty::GenericArg::Ty(ty_a), ty::GenericArg::Ty(ty_b)) => {
198 self.ty_infer_from_ty(ty_a, ty_b);
199 }
200 (rty::GenericArg::Lifetime(re_a), ty::GenericArg::Lifetime(re_b)) => {
201 self.infer_from_region(*re_a, *re_b);
202 }
203 _ => {}
204 }
205 }
206}
207
208impl RegionSubst {
209 fn rty_infer_from_fn_sig(&mut self, a: &rty::FnSig, b: &rty::FnSig) {
210 debug_assert_eq!(a.inputs().len(), b.inputs().len());
211 for (ty_a, ty_b) in iter::zip(a.inputs(), b.inputs()) {
212 self.rty_infer_from_ty(ty_a, ty_b);
213 }
214 self.rty_infer_from_ty(
215 &a.output().skip_binder_ref().ret,
216 &b.output().skip_binder_ref().ret,
217 );
218 }
219
220 fn rty_infer_from_ty(&mut self, a: &rty::Ty, b: &rty::Ty) {
221 match (a.kind(), b.kind()) {
222 (rty::TyKind::Exists(ctor_a), _) => {
223 self.rty_infer_from_ty(ctor_a.skip_binder_ref(), b);
224 }
225 (_, rty::TyKind::Exists(ctor_b)) => {
226 self.rty_infer_from_ty(a, ctor_b.skip_binder_ref());
227 }
228 (rty::TyKind::Constr(_, ty_a), _) => self.rty_infer_from_ty(ty_a, b),
229 (_, rty::TyKind::Constr(_, ty_b)) => self.rty_infer_from_ty(a, ty_b),
230 (rty::TyKind::Indexed(bty_a, _), rty::TyKind::Indexed(bty_b, _)) => {
231 self.rty_infer_from_bty(bty_a, bty_b);
232 }
233 (rty::TyKind::StrgRef(re_a, _, ty_a), rty::TyKind::StrgRef(re_b, _, ty_b)) => {
234 self.infer_from_region(*re_a, *re_b);
235 self.rty_infer_from_ty(ty_a, ty_b);
236 }
237 _ => {}
238 }
239 }
240
241 fn rty_infer_from_bty(&mut self, a: &rty::BaseTy, b: &rty::BaseTy) {
242 match (a, b) {
243 (rty::BaseTy::Slice(ty_a), rty::BaseTy::Slice(ty_b)) => {
244 self.rty_infer_from_ty(ty_a, ty_b);
245 }
246 (rty::BaseTy::Adt(adt_def_a, args_a), rty::BaseTy::Adt(adt_def_b, args_b)) => {
247 debug_assert_eq!(adt_def_a.did(), adt_def_b.did());
248 for (arg_a, arg_b) in iter::zip(args_a, args_b) {
249 self.rty_infer_from_generic_arg(arg_a, arg_b);
250 }
251 }
252 (rty::BaseTy::RawPtr(ty_a, mutbl_a), rty::BaseTy::RawPtr(ty_b, mutbl_b)) => {
253 debug_assert_eq!(mutbl_a, mutbl_b);
254 self.rty_infer_from_ty(ty_a, ty_b);
255 }
256 (rty::BaseTy::Ref(re_a, ty_a, mutbl_a), rty::BaseTy::Ref(re_b, ty_b, mutbl_b)) => {
257 debug_assert_eq!(mutbl_a, mutbl_b);
258 self.infer_from_region(*re_a, *re_b);
259 self.rty_infer_from_ty(ty_a, ty_b);
260 }
261 (rty::BaseTy::FnPtr(poly_sig_a), rty::BaseTy::FnPtr(poly_sig_b)) => {
262 self.rty_infer_from_fn_sig(
263 poly_sig_a.skip_binder_ref(),
264 poly_sig_b.skip_binder_ref(),
265 );
266 }
267 (rty::BaseTy::Tuple(tys_a), rty::BaseTy::Tuple(tys_b)) => {
268 for (ty_a, ty_b) in iter::zip(tys_a, tys_b) {
269 self.rty_infer_from_ty(ty_a, ty_b);
270 }
271 }
272 (rty::BaseTy::Alias(_, aty_a), rty::BaseTy::Alias(_, aty_b)) => {
273 for (arg_a, arg_b) in iter::zip(&aty_a.args, &aty_b.args) {
274 self.rty_infer_from_generic_arg(arg_a, arg_b);
275 }
276 }
277 (rty::BaseTy::Array(ty_a, _), rty::BaseTy::Array(ty_b, _)) => {
278 self.rty_infer_from_ty(ty_a, ty_b);
279 }
280 (rty::BaseTy::Dynamic(preds_a, re_a), rty::BaseTy::Dynamic(preds_b, re_b)) => {
281 for (pred_a, pred_b) in iter::zip(preds_a, preds_b) {
282 self.rty_infer_from_existential_pred(pred_a, pred_b);
283 }
284 self.infer_from_region(*re_a, *re_b);
285 }
286 _ => {}
287 }
288 }
289
290 fn rty_infer_from_generic_arg(&mut self, a: &rty::GenericArg, b: &rty::GenericArg) {
291 match (a, b) {
292 (rty::GenericArg::Ty(ty_a), rty::GenericArg::Ty(ty_b)) => {
293 self.rty_infer_from_ty(ty_a, ty_b);
294 }
295 (rty::GenericArg::Base(ctor_a), rty::GenericArg::Base(ctor_b)) => {
296 self.rty_infer_from_bty(
297 ctor_a.as_bty_skipping_binder(),
298 ctor_b.as_bty_skipping_binder(),
299 );
300 }
301 (rty::GenericArg::Lifetime(re_a), rty::GenericArg::Lifetime(re_b)) => {
302 self.infer_from_region(*re_a, *re_b);
303 }
304 _ => {}
305 }
306 }
307
308 fn rty_infer_from_existential_pred(
309 &mut self,
310 a: &rty::Binder<rty::ExistentialPredicate>,
311 b: &rty::Binder<rty::ExistentialPredicate>,
312 ) {
313 match (a.skip_binder_ref(), b.skip_binder_ref()) {
314 (
315 rty::ExistentialPredicate::Trait(trait_ref_a),
316 rty::ExistentialPredicate::Trait(trait_ref_b),
317 ) => {
318 debug_assert_eq!(trait_ref_a.def_id, trait_ref_b.def_id);
319 for (arg_a, arg_b) in iter::zip(&trait_ref_a.args, &trait_ref_b.args) {
320 self.rty_infer_from_generic_arg(arg_a, arg_b);
321 }
322 }
323 (
324 rty::ExistentialPredicate::Projection(proj_a),
325 rty::ExistentialPredicate::Projection(proj_b),
326 ) => {
327 debug_assert_eq!(proj_a.def_id, proj_b.def_id);
328 for (arg_a, arg_b) in iter::zip(&proj_a.args, &proj_b.args) {
329 self.rty_infer_from_generic_arg(arg_a, arg_b);
330 }
331 self.rty_infer_from_bty(
332 proj_a.term.as_bty_skipping_binder(),
333 proj_b.term.as_bty_skipping_binder(),
334 );
335 }
336 _ => {}
337 }
338 }
339}