1use flux_arc_interner::{List, SliceInternable};
7use flux_common::bug;
8use flux_rustc_bridge::{ty, ty::GenericArgsExt as _};
9use itertools::Itertools;
10use rustc_abi::VariantIdx;
11use rustc_hir::{def::DefKind, def_id::DefId};
12use rustc_middle::ty::ParamTy;
13
14use super::{RefineArgsExt, fold::TypeFoldable};
15use crate::{
16 global_env::GlobalEnv,
17 queries::{QueryErr, QueryResult},
18 query_bug, rty,
19};
20
21pub fn refine_generics(genv: GlobalEnv, def_id: DefId, generics: &ty::Generics) -> rty::Generics {
22 let is_box = if let DefKind::Struct = genv.def_kind(def_id) {
23 genv.tcx().adt_def(def_id).is_box()
24 } else {
25 false
26 };
27 let params = generics
28 .params
29 .iter()
30 .map(|param| {
31 rty::GenericParamDef {
32 kind: refine_generic_param_def_kind(is_box, param.kind),
33 index: param.index,
34 name: param.name,
35 def_id: param.def_id,
36 }
37 })
38 .collect();
39
40 rty::Generics {
41 own_params: params,
42 parent: generics.parent(),
43 parent_count: generics.parent_count(),
44 has_self: generics.orig.has_self,
45 }
46}
47
48pub fn refine_generic_param_def_kind(
49 is_box: bool,
50 kind: ty::GenericParamDefKind,
51) -> rty::GenericParamDefKind {
52 match kind {
53 ty::GenericParamDefKind::Lifetime => rty::GenericParamDefKind::Lifetime,
54 ty::GenericParamDefKind::Type { has_default } => {
55 if is_box {
56 rty::GenericParamDefKind::Type { has_default }
57 } else {
58 rty::GenericParamDefKind::Base { has_default }
59 }
60 }
61 ty::GenericParamDefKind::Const { has_default, .. } => {
62 rty::GenericParamDefKind::Const { has_default }
63 }
64 }
65}
66
67pub struct Refiner<'genv, 'tcx> {
68 genv: GlobalEnv<'genv, 'tcx>,
69 def_id: DefId,
70 generics: rty::Generics,
71 refine: fn(rty::BaseTy) -> rty::SubsetTyCtor,
72}
73
74impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
75 pub fn new_for_item(
76 genv: GlobalEnv<'genv, 'tcx>,
77 def_id: DefId,
78 refine: fn(rty::BaseTy) -> rty::SubsetTyCtor,
79 ) -> QueryResult<Self> {
80 let generics = genv.generics_of(def_id)?;
81 Ok(Self { genv, def_id, generics, refine })
82 }
83
84 pub fn default_for_item(genv: GlobalEnv<'genv, 'tcx>, def_id: DefId) -> QueryResult<Self> {
85 Self::new_for_item(genv, def_id, refine_default)
86 }
87
88 pub fn with_holes(genv: GlobalEnv<'genv, 'tcx>, def_id: DefId) -> QueryResult<Self> {
89 Self::new_for_item(genv, def_id, |bty| {
90 let sort = bty.sort();
91 let constr = rty::SubsetTy::new(
92 bty.shift_in_escaping(1),
93 rty::Expr::nu(),
94 rty::Expr::hole(rty::HoleKind::Pred),
95 );
96 rty::Binder::bind_with_sort(constr, sort)
97 })
98 }
99
100 pub fn refine<T: Refine + ?Sized>(&self, t: &T) -> QueryResult<T::Output> {
101 t.refine(self)
102 }
103
104 fn refine_existential_predicate_generic_args(
105 &self,
106 def_id: DefId,
107 args: &ty::GenericArgs,
108 ) -> QueryResult<rty::GenericArgs> {
109 let generics = self.generics_of(def_id)?;
110 args.iter()
111 .enumerate()
112 .map(|(idx, arg)| {
113 let param = generics.param_at(idx + 1, self.genv)?;
115 self.refine_generic_arg(¶m, arg)
116 })
117 .try_collect()
118 }
119
120 pub fn refine_variant_def(
121 &self,
122 adt_def_id: DefId,
123 variant_idx: VariantIdx,
124 ) -> QueryResult<rty::PolyVariant> {
125 let adt_def = self.adt_def(adt_def_id)?;
126 let variant_def = adt_def.variant(variant_idx);
127 let fields = variant_def
128 .fields
129 .iter()
130 .map(|fld| {
131 let ty = self.genv.lower_type_of(fld.did)?.instantiate_identity();
132 ty.refine(self)
133 })
134 .try_collect()?;
135
136 let idx = if adt_def.sort_def().is_struct() {
137 rty::Expr::unit_struct(adt_def_id)
138 } else {
139 rty::Expr::ctor_enum(adt_def_id, variant_idx)
140 };
141 let value = rty::VariantSig::new(
142 adt_def,
143 rty::GenericArg::identity_for_item(self.genv, adt_def_id)?,
144 fields,
145 idx,
146 List::empty(),
147 );
148
149 Ok(rty::Binder::bind_with_vars(value, List::empty()))
150 }
151
152 fn refine_generic_args(
153 &self,
154 def_id: DefId,
155 args: &ty::GenericArgs,
156 ) -> QueryResult<rty::GenericArgs> {
157 let generics = self.generics_of(def_id)?;
158 args.iter()
159 .enumerate()
160 .map(|(idx, arg)| {
161 let param = generics.param_at(idx, self.genv)?;
162 self.refine_generic_arg(¶m, arg)
163 })
164 .collect()
165 }
166
167 pub fn refine_generic_arg(
168 &self,
169 param: &rty::GenericParamDef,
170 arg: &ty::GenericArg,
171 ) -> QueryResult<rty::GenericArg> {
172 match (¶m.kind, arg) {
173 (rty::GenericParamDefKind::Type { .. }, ty::GenericArg::Ty(ty)) => {
174 Ok(rty::GenericArg::Ty(ty.refine(self)?))
175 }
176 (rty::GenericParamDefKind::Base { .. }, ty::GenericArg::Ty(ty)) => {
177 let rty::TyOrBase::Base(contr) = self.refine_ty_or_base(ty)? else {
178 return Err(QueryErr::InvalidGenericArg { def_id: param.def_id });
179 };
180 Ok(rty::GenericArg::Base(contr))
181 }
182 (rty::GenericParamDefKind::Lifetime, ty::GenericArg::Lifetime(re)) => {
183 Ok(rty::GenericArg::Lifetime(*re))
184 }
185 (rty::GenericParamDefKind::Const { .. }, ty::GenericArg::Const(ct)) => {
186 Ok(rty::GenericArg::Const(ct.clone()))
187 }
188 _ => bug!("mismatched generic arg `{arg:?}` `{param:?}`"),
189 }
190 }
191
192 fn refine_alias_ty(
193 &self,
194 alias_kind: ty::AliasKind,
195 alias_ty: &ty::AliasTy,
196 ) -> QueryResult<rty::AliasTy> {
197 let def_id = alias_ty.def_id;
198 let args = self.refine_generic_args(def_id, &alias_ty.args)?;
199
200 let refine_args = if let ty::AliasKind::Opaque = alias_kind {
201 rty::RefineArgs::for_item(self.genv, def_id, |param, _| {
202 let param = param.instantiate(self.genv.tcx(), &args, &[]);
203 Ok(rty::Expr::hole(rty::HoleKind::Expr(param.sort)))
204 })?
205 } else {
206 List::empty()
207 };
208
209 Ok(rty::AliasTy::new(def_id, args, refine_args))
210 }
211
212 pub fn refine_ty_or_base(&self, ty: &ty::Ty) -> QueryResult<rty::TyOrBase> {
213 let bty = match ty.kind() {
214 ty::TyKind::Closure(did, args) => {
215 let closure_args = args.as_closure();
216 let upvar_tys = closure_args
217 .upvar_tys()
218 .iter()
219 .map(|ty| ty.refine(self))
220 .try_collect()?;
221 rty::BaseTy::Closure(*did, upvar_tys, args.clone())
222 }
223 ty::TyKind::Coroutine(did, args) => {
224 let args = args.as_coroutine();
225 let resume_ty = args.resume_ty().refine(self)?;
226 let upvar_tys = args.upvar_tys().map(|ty| ty.refine(self)).try_collect()?;
227 rty::BaseTy::Coroutine(*did, resume_ty, upvar_tys)
228 }
229 ty::TyKind::CoroutineWitness(..) => {
230 bug!("implement when we know what this is");
231 }
232 ty::TyKind::Never => rty::BaseTy::Never,
233 ty::TyKind::Ref(r, ty, mutbl) => rty::BaseTy::Ref(*r, ty.refine(self)?, *mutbl),
234 ty::TyKind::Float(float_ty) => rty::BaseTy::Float(*float_ty),
235 ty::TyKind::Tuple(tys) => {
236 let tys = tys.iter().map(|ty| ty.refine(self)).try_collect()?;
237 rty::BaseTy::Tuple(tys)
238 }
239 ty::TyKind::Array(ty, len) => rty::BaseTy::Array(ty.refine(self)?, len.clone()),
240 ty::TyKind::Param(param_ty) => {
241 match self.param(*param_ty)?.kind {
242 rty::GenericParamDefKind::Type { .. } => {
243 return Ok(rty::TyOrBase::Ty(rty::Ty::param(*param_ty)));
244 }
245 rty::GenericParamDefKind::Base { .. } => rty::BaseTy::Param(*param_ty),
246 rty::GenericParamDefKind::Lifetime | rty::GenericParamDefKind::Const { .. } => {
247 bug!()
248 }
249 }
250 }
251 ty::TyKind::Adt(adt_def, args) => {
252 let adt_def = self.genv.adt_def(adt_def.did())?;
253 let args = self.refine_generic_args(adt_def.did(), args)?;
254 rty::BaseTy::adt(adt_def, args)
255 }
256 ty::TyKind::FnDef(def_id, args) => {
257 let args = self.refine_generic_args(*def_id, args)?;
258 rty::BaseTy::fn_def(*def_id, args)
259 }
260 ty::TyKind::Alias(kind, alias_ty) => {
261 let alias_ty = self.as_default().refine_alias_ty(*kind, alias_ty)?;
262 rty::BaseTy::Alias(*kind, alias_ty)
263 }
264 ty::TyKind::Bool => rty::BaseTy::Bool,
265 ty::TyKind::Int(int_ty) => rty::BaseTy::Int(*int_ty),
266 ty::TyKind::Uint(uint_ty) => rty::BaseTy::Uint(*uint_ty),
267 ty::TyKind::Foreign(def_id) => rty::BaseTy::Foreign(*def_id),
268 ty::TyKind::Str => rty::BaseTy::Str,
269 ty::TyKind::Slice(ty) => rty::BaseTy::Slice(ty.refine(self)?),
270 ty::TyKind::Char => rty::BaseTy::Char,
271 ty::TyKind::FnPtr(poly_fn_sig) => {
272 rty::BaseTy::FnPtr(poly_fn_sig.refine(&self.as_default())?)
273 }
274 ty::TyKind::RawPtr(ty, mu) => rty::BaseTy::RawPtr(ty.refine(&self.as_default())?, *mu),
275 ty::TyKind::Dynamic(exi_preds, r) => {
276 let exi_preds = exi_preds
277 .iter()
278 .map(|pred| pred.refine(self))
279 .try_collect()?;
280 rty::BaseTy::Dynamic(exi_preds, *r)
281 }
282 };
283 Ok(rty::TyOrBase::Base((self.refine)(bty)))
284 }
285
286 fn as_default(&self) -> Self {
287 Refiner { refine: refine_default, generics: self.generics.clone(), ..*self }
288 }
289
290 fn adt_def(&self, def_id: DefId) -> QueryResult<rty::AdtDef> {
291 self.genv.adt_def(def_id)
292 }
293
294 fn generics_of(&self, def_id: DefId) -> QueryResult<rty::Generics> {
295 self.genv.generics_of(def_id)
296 }
297
298 fn param(&self, param_ty: ParamTy) -> QueryResult<rty::GenericParamDef> {
299 self.generics.param_at(param_ty.index as usize, self.genv)
300 }
301}
302
303pub trait Refine {
304 type Output;
305
306 fn refine(&self, refiner: &Refiner) -> QueryResult<Self::Output>;
307}
308
309impl Refine for ty::Ty {
310 type Output = rty::Ty;
311
312 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Ty> {
313 Ok(refiner.refine_ty_or_base(self)?.into_ty())
314 }
315}
316
317impl<T: Refine> Refine for ty::Binder<T> {
318 type Output = rty::Binder<T::Output>;
319
320 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Binder<T::Output>> {
321 let vars = refine_bound_variables(self.vars());
322 let inner = self.skip_binder_ref().refine(refiner)?;
323 Ok(rty::Binder::bind_with_vars(inner, vars))
324 }
325}
326
327impl Refine for ty::FnSig {
328 type Output = rty::FnSig;
329
330 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::FnSig> {
332 let inputs = self
333 .inputs()
334 .iter()
335 .map(|ty| ty.refine(refiner))
336 .try_collect()?;
337 let ret = self.output().refine(refiner)?.shift_in_escaping(1);
338 let output = rty::Binder::bind_with_vars(rty::FnOutput::new(ret, vec![]), List::empty());
339 Ok(rty::FnSig::new(self.safety, self.abi, List::empty(), inputs, output))
345 }
346}
347
348impl Refine for ty::Clause {
349 type Output = rty::Clause;
350
351 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Clause> {
352 Ok(rty::Clause { kind: self.kind.refine(refiner)? })
353 }
354}
355
356impl Refine for ty::TraitRef {
357 type Output = rty::TraitRef;
358
359 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::TraitRef> {
360 Ok(rty::TraitRef {
361 def_id: self.def_id,
362 args: refiner.refine_generic_args(self.def_id, &self.args)?,
363 })
364 }
365}
366
367impl Refine for ty::ClauseKind {
368 type Output = rty::ClauseKind;
369
370 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::ClauseKind> {
371 let kind = match self {
372 ty::ClauseKind::Trait(trait_pred) => {
373 let pred = rty::TraitPredicate { trait_ref: trait_pred.trait_ref.refine(refiner)? };
374 rty::ClauseKind::Trait(pred)
375 }
376 ty::ClauseKind::Projection(proj_pred) => {
377 let rty::TyOrBase::Base(term) = refiner.refine_ty_or_base(&proj_pred.term)? else {
378 return Err(query_bug!(
379 refiner.def_id,
380 "sorry, we can't handle non-base associated types"
381 ));
382 };
383 let pred = rty::ProjectionPredicate {
384 projection_ty: refiner
385 .refine_alias_ty(ty::AliasKind::Projection, &proj_pred.projection_ty)?,
386 term,
387 };
388 rty::ClauseKind::Projection(pred)
389 }
390 ty::ClauseKind::RegionOutlives(pred) => {
391 let pred = rty::OutlivesPredicate(pred.0, pred.1);
392 rty::ClauseKind::RegionOutlives(pred)
393 }
394 ty::ClauseKind::TypeOutlives(pred) => {
395 let pred = rty::OutlivesPredicate(pred.0.refine(refiner)?, pred.1);
396 rty::ClauseKind::TypeOutlives(pred)
397 }
398 ty::ClauseKind::ConstArgHasType(const_, ty) => {
399 rty::ClauseKind::ConstArgHasType(const_.clone(), ty.refine(&refiner.as_default())?)
400 }
401 };
402 Ok(kind)
403 }
404}
405
406impl Refine for ty::ExistentialPredicate {
407 type Output = rty::ExistentialPredicate;
408
409 fn refine(&self, refiner: &Refiner) -> QueryResult<Self::Output> {
410 let pred = match self {
411 ty::ExistentialPredicate::Trait(trait_ref) => {
412 rty::ExistentialPredicate::Trait(rty::ExistentialTraitRef {
413 def_id: trait_ref.def_id,
414 args: refiner.refine_existential_predicate_generic_args(
415 trait_ref.def_id,
416 &trait_ref.args,
417 )?,
418 })
419 }
420 ty::ExistentialPredicate::Projection(projection) => {
421 let rty::TyOrBase::Base(term) = refiner.refine_ty_or_base(&projection.term)? else {
422 return Err(query_bug!(
423 refiner.def_id,
424 "sorry, we can't handle non-base associated types"
425 ));
426 };
427 rty::ExistentialPredicate::Projection(rty::ExistentialProjection {
428 def_id: projection.def_id,
429 args: refiner.refine_existential_predicate_generic_args(
430 projection.def_id,
431 &projection.args,
432 )?,
433 term,
434 })
435 }
436 ty::ExistentialPredicate::AutoTrait(def_id) => {
437 rty::ExistentialPredicate::AutoTrait(*def_id)
438 }
439 };
440 Ok(pred)
441 }
442}
443
444impl Refine for ty::GenericPredicates {
445 type Output = rty::GenericPredicates;
446
447 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::GenericPredicates> {
448 Ok(rty::GenericPredicates {
449 parent: self.parent,
450 predicates: refiner.refine(&self.predicates)?,
451 })
452 }
453}
454
455impl<T> Refine for List<T>
456where
457 T: SliceInternable,
458 T: Refine<Output: SliceInternable>,
459{
460 type Output = rty::List<T::Output>;
461
462 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::List<T::Output>> {
463 refiner.refine(&self[..])
464 }
465}
466
467impl<T> Refine for [T]
468where
469 T: Refine<Output: SliceInternable>,
470{
471 type Output = rty::List<T::Output>;
472
473 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::List<T::Output>> {
474 self.iter().map(|t| refiner.refine(t)).try_collect()
475 }
476}
477
478fn refine_default(bty: rty::BaseTy) -> rty::SubsetTyCtor {
479 let sort = bty.sort();
480 let constr = rty::SubsetTy::trivial(bty.shift_in_escaping(1), rty::Expr::nu());
481 rty::Binder::bind_with_sort(constr, sort)
482}
483
484pub fn refine_bound_variables(vars: &[ty::BoundVariableKind]) -> List<rty::BoundVariableKind> {
485 vars.iter()
486 .map(|kind| {
487 match kind {
488 ty::BoundVariableKind::Region(kind) => rty::BoundVariableKind::Region(*kind),
489 }
490 })
491 .collect()
492}