1use flux_arc_interner::{List, SliceInternable};
7use flux_common::bug;
8use flux_rustc_bridge::{ty, ty::GenericArgsExt as _};
9use itertools::Itertools;
10use rustc_abi::VariantIdx;
11use rustc_hir::{def::DefKind, def_id::DefId};
12use rustc_middle::ty::ParamTy;
13
14use super::{RefineArgsExt, fold::TypeFoldable};
15use crate::{
16 global_env::GlobalEnv,
17 queries::{QueryErr, QueryResult},
18 query_bug,
19 rty::{self, Expr},
20};
21
22pub fn refine_generics(genv: GlobalEnv, def_id: DefId, generics: &ty::Generics) -> rty::Generics {
23 let is_box = if let DefKind::Struct = genv.def_kind(def_id) {
24 genv.tcx().adt_def(def_id).is_box()
25 } else {
26 false
27 };
28 let params = generics
29 .params
30 .iter()
31 .map(|param| {
32 rty::GenericParamDef {
33 kind: refine_generic_param_def_kind(is_box, param.kind),
34 index: param.index,
35 name: param.name,
36 def_id: param.def_id,
37 }
38 })
39 .collect();
40
41 rty::Generics {
42 own_params: params,
43 parent: generics.parent(),
44 parent_count: generics.parent_count(),
45 has_self: generics.orig.has_self,
46 }
47}
48
49pub fn refine_generic_param_def_kind(
50 is_box: bool,
51 kind: ty::GenericParamDefKind,
52) -> rty::GenericParamDefKind {
53 match kind {
54 ty::GenericParamDefKind::Lifetime => rty::GenericParamDefKind::Lifetime,
55 ty::GenericParamDefKind::Type { has_default } => {
56 if is_box {
57 rty::GenericParamDefKind::Type { has_default }
58 } else {
59 rty::GenericParamDefKind::Base { has_default }
60 }
61 }
62 ty::GenericParamDefKind::Const { has_default, .. } => {
63 rty::GenericParamDefKind::Const { has_default }
64 }
65 }
66}
67
/// Converts unrefined rustc-bridge values (`ty`) into refined values (`rty`) for one item.
///
/// The refinement attached to each base type built by [`Refiner::refine_ty_or_base`] is chosen
/// by the `refine` function pointer. Examples are a trivial refinement (`default_for_item`) or
/// a predicate hole (`with_holes`).
pub struct Refiner<'genv, 'tcx> {
    genv: GlobalEnv<'genv, 'tcx>,
    /// Item this refiner was created for; also the item that query bugs are reported against.
    def_id: DefId,
    /// Generics of `def_id`, used to look up the kind of a `ty::TyKind::Param`.
    generics: rty::Generics,
    /// Builds the refined (subset) type constructor wrapped around a base type.
    refine: fn(rty::BaseTy) -> rty::SubsetTyCtor,
}
74
75impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
76 pub fn new_for_item(
77 genv: GlobalEnv<'genv, 'tcx>,
78 def_id: DefId,
79 refine: fn(rty::BaseTy) -> rty::SubsetTyCtor,
80 ) -> QueryResult<Self> {
81 let generics = genv.generics_of(def_id)?;
82 Ok(Self { genv, def_id, generics, refine })
83 }
84
85 pub fn default_for_item(genv: GlobalEnv<'genv, 'tcx>, def_id: DefId) -> QueryResult<Self> {
86 Self::new_for_item(genv, def_id, refine_default)
87 }
88
89 pub fn with_holes(genv: GlobalEnv<'genv, 'tcx>, def_id: DefId) -> QueryResult<Self> {
90 Self::new_for_item(genv, def_id, |bty| {
91 let sort = bty.sort();
92 let constr = rty::SubsetTy::new(
93 bty.shift_in_escaping(1),
94 rty::Expr::nu(),
95 rty::Expr::hole(rty::HoleKind::Pred),
96 );
97 rty::Binder::bind_with_sort(constr, sort)
98 })
99 }
100
101 pub fn refine<T: Refine + ?Sized>(&self, t: &T) -> QueryResult<T::Output> {
102 t.refine(self)
103 }
104
105 fn refine_existential_predicate_generic_args(
106 &self,
107 def_id: DefId,
108 args: &ty::GenericArgs,
109 ) -> QueryResult<rty::GenericArgs> {
110 let generics = self.generics_of(def_id)?;
111 args.iter()
112 .enumerate()
113 .map(|(idx, arg)| {
114 let param = generics.param_at(idx + 1, self.genv)?;
116 self.refine_generic_arg(¶m, arg)
117 })
118 .try_collect()
119 }
120
121 pub fn refine_variant_def(
122 &self,
123 adt_def_id: DefId,
124 variant_idx: VariantIdx,
125 ) -> QueryResult<rty::PolyVariant> {
126 let adt_def = self.adt_def(adt_def_id)?;
127 let variant_def = adt_def.variant(variant_idx);
128 let fields = variant_def
129 .fields
130 .iter()
131 .map(|fld| {
132 let ty = self.genv.lower_type_of(fld.did)?.instantiate_identity();
133 ty.refine(self)
134 })
135 .try_collect()?;
136
137 let idx = if adt_def.sort_def().is_struct() {
138 rty::Expr::unit_struct(adt_def_id)
139 } else {
140 rty::Expr::ctor_enum(adt_def_id, variant_idx)
141 };
142 let value = rty::VariantSig::new(
143 adt_def,
144 rty::GenericArg::identity_for_item(self.genv, adt_def_id)?,
145 fields,
146 idx,
147 List::empty(),
148 );
149
150 Ok(rty::Binder::bind_with_vars(value, List::empty()))
151 }
152
153 pub fn refine_generic_args(
154 &self,
155 def_id: DefId,
156 args: &ty::GenericArgs,
157 ) -> QueryResult<rty::GenericArgs> {
158 let generics = self.generics_of(def_id)?;
159 args.iter()
160 .enumerate()
161 .map(|(idx, arg)| {
162 let param = generics.param_at(idx, self.genv)?;
163 self.refine_generic_arg(¶m, arg)
164 })
165 .collect()
166 }
167
168 pub fn refine_generic_arg(
169 &self,
170 param: &rty::GenericParamDef,
171 arg: &ty::GenericArg,
172 ) -> QueryResult<rty::GenericArg> {
173 match (¶m.kind, arg) {
174 (rty::GenericParamDefKind::Type { .. }, ty::GenericArg::Ty(ty)) => {
175 Ok(rty::GenericArg::Ty(ty.refine(self)?))
176 }
177 (rty::GenericParamDefKind::Base { .. }, ty::GenericArg::Ty(ty)) => {
178 let rty::TyOrBase::Base(contr) = self.refine_ty_or_base(ty)? else {
179 return Err(QueryErr::InvalidGenericArg { def_id: param.def_id });
180 };
181 Ok(rty::GenericArg::Base(contr))
182 }
183 (rty::GenericParamDefKind::Lifetime, ty::GenericArg::Lifetime(re)) => {
184 Ok(rty::GenericArg::Lifetime(*re))
185 }
186 (rty::GenericParamDefKind::Const { .. }, ty::GenericArg::Const(ct)) => {
187 Ok(rty::GenericArg::Const(ct.clone()))
188 }
189 _ => bug!("mismatched generic arg `{arg:?}` `{param:?}`"),
190 }
191 }
192
193 fn refine_alias_ty(
194 &self,
195 alias_kind: ty::AliasKind,
196 alias_ty: &ty::AliasTy,
197 ) -> QueryResult<rty::AliasTy> {
198 let def_id = alias_ty.def_id;
199 let args = self.refine_generic_args(def_id, &alias_ty.args)?;
200
201 let refine_args = if let ty::AliasKind::Opaque = alias_kind {
202 rty::RefineArgs::for_item(self.genv, def_id, |param, _| {
203 let param = param.instantiate(self.genv.tcx(), &args, &[]);
204 Ok(rty::Expr::hole(rty::HoleKind::Expr(param.sort)))
205 })?
206 } else {
207 List::empty()
208 };
209
210 Ok(rty::AliasTy::new(def_id, args, refine_args))
211 }
212
213 pub fn refine_ty_or_base(&self, ty: &ty::Ty) -> QueryResult<rty::TyOrBase> {
214 let bty = match ty.kind() {
215 ty::TyKind::Closure(did, args) => {
216 let no_panic = self.genv.no_panic(did);
217 let closure_args = args.as_closure();
218 let upvar_tys = closure_args
219 .upvar_tys()
220 .iter()
221 .map(|ty| ty.refine(self))
222 .try_collect()?;
223 rty::BaseTy::Closure(*did, upvar_tys, args.clone(), no_panic)
224 }
225 ty::TyKind::Coroutine(did, args) => {
226 let coroutine_args = args.as_coroutine();
227 let resume_ty = coroutine_args.resume_ty().refine(self)?;
228 let upvar_tys = coroutine_args
229 .upvar_tys()
230 .map(|ty| ty.refine(self))
231 .try_collect()?;
232 rty::BaseTy::Coroutine(*did, resume_ty, upvar_tys, args.clone())
233 }
234 ty::TyKind::CoroutineWitness(..) => {
235 bug!("implement when we know what this is");
236 }
237 ty::TyKind::Never => rty::BaseTy::Never,
238 ty::TyKind::Ref(r, ty, mutbl) => rty::BaseTy::Ref(*r, ty.refine(self)?, *mutbl),
239 ty::TyKind::Float(float_ty) => rty::BaseTy::Float(*float_ty),
240 ty::TyKind::Tuple(tys) => {
241 let tys = tys.iter().map(|ty| ty.refine(self)).try_collect()?;
242 rty::BaseTy::Tuple(tys)
243 }
244 ty::TyKind::Array(ty, len) => rty::BaseTy::Array(ty.refine(self)?, len.clone()),
245 ty::TyKind::Param(param_ty) => {
246 match self.param(*param_ty)?.kind {
247 rty::GenericParamDefKind::Type { .. } => {
248 return Ok(rty::TyOrBase::Ty(rty::Ty::param(*param_ty)));
249 }
250 rty::GenericParamDefKind::Base { .. } => rty::BaseTy::Param(*param_ty),
251 rty::GenericParamDefKind::Lifetime | rty::GenericParamDefKind::Const { .. } => {
252 bug!()
253 }
254 }
255 }
256 ty::TyKind::Adt(adt_def, args) => {
257 let adt_def = self.genv.adt_def(adt_def.did())?;
258 let args = self.refine_generic_args(adt_def.did(), args)?;
259 rty::BaseTy::adt(adt_def, args)
260 }
261 ty::TyKind::FnDef(def_id, args) => {
262 let args = self.refine_generic_args(*def_id, args)?;
263 rty::BaseTy::fn_def(*def_id, args)
264 }
265 ty::TyKind::Alias(kind, alias_ty) => {
266 let alias_ty = self.as_default().refine_alias_ty(*kind, alias_ty)?;
267 rty::BaseTy::Alias(*kind, alias_ty)
268 }
269 ty::TyKind::Bool => rty::BaseTy::Bool,
270 ty::TyKind::Int(int_ty) => rty::BaseTy::Int(*int_ty),
271 ty::TyKind::Uint(uint_ty) => rty::BaseTy::Uint(*uint_ty),
272 ty::TyKind::Foreign(def_id) => rty::BaseTy::Foreign(*def_id),
273 ty::TyKind::Str => rty::BaseTy::Str,
274 ty::TyKind::Slice(ty) => rty::BaseTy::Slice(ty.refine(self)?),
275 ty::TyKind::Char => rty::BaseTy::Char,
276 ty::TyKind::FnPtr(poly_fn_sig) => {
277 rty::BaseTy::FnPtr(poly_fn_sig.refine(&self.as_default())?)
278 }
279 ty::TyKind::RawPtr(ty, mu) => rty::BaseTy::RawPtr(ty.refine(&self.as_default())?, *mu),
280 ty::TyKind::Dynamic(exi_preds, r) => {
281 let exi_preds = exi_preds
282 .iter()
283 .map(|pred| pred.refine(self))
284 .try_collect()?;
285 rty::BaseTy::Dynamic(exi_preds, *r)
286 }
287 ty::TyKind::Pat => rty::BaseTy::Pat,
288 };
289 Ok(rty::TyOrBase::Base((self.refine)(bty)))
290 }
291
292 fn as_default(&self) -> Self {
293 Refiner { refine: refine_default, generics: self.generics.clone(), ..*self }
294 }
295
296 fn adt_def(&self, def_id: DefId) -> QueryResult<rty::AdtDef> {
297 self.genv.adt_def(def_id)
298 }
299
300 fn generics_of(&self, def_id: DefId) -> QueryResult<rty::Generics> {
301 self.genv.generics_of(def_id)
302 }
303
304 fn param(&self, param_ty: ParamTy) -> QueryResult<rty::GenericParamDef> {
305 self.generics.param_at(param_ty.index as usize, self.genv)
306 }
307}
308
/// Conversion from an unrefined rustc-bridge value into its refined counterpart.
pub trait Refine {
    /// The refined value produced by [`Refine::refine`].
    type Output;

    /// Refines `self` using the settings of `refiner`.
    fn refine(&self, refiner: &Refiner) -> QueryResult<Self::Output>;
}
314
315impl Refine for ty::Ty {
316 type Output = rty::Ty;
317
318 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Ty> {
319 Ok(refiner.refine_ty_or_base(self)?.into_ty())
320 }
321}
322
323impl<T: Refine> Refine for ty::Binder<T> {
324 type Output = rty::Binder<T::Output>;
325
326 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Binder<T::Output>> {
327 let vars = refine_bound_variables(self.vars());
328 let inner = self.skip_binder_ref().refine(refiner)?;
329 Ok(rty::Binder::bind_with_vars(inner, vars))
330 }
331}
332
333impl Refine for ty::FnSig {
334 type Output = rty::FnSig;
335
336 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::FnSig> {
338 let inputs = self
339 .inputs()
340 .iter()
341 .map(|ty| ty.refine(refiner))
342 .try_collect()?;
343 let ret = self.output().refine(refiner)?.shift_in_escaping(1);
344 let output = rty::Binder::bind_with_vars(rty::FnOutput::new(ret, vec![]), List::empty());
345 Ok(rty::FnSig::new(self.safety, self.abi, List::empty(), inputs, output, Expr::ff(), true))
351 }
352}
353
354impl Refine for ty::Clause {
355 type Output = rty::Clause;
356
357 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Clause> {
358 Ok(rty::Clause { kind: self.kind.refine(refiner)? })
359 }
360}
361
362impl Refine for ty::TraitRef {
363 type Output = rty::TraitRef;
364
365 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::TraitRef> {
366 Ok(rty::TraitRef {
367 def_id: self.def_id,
368 args: refiner.refine_generic_args(self.def_id, &self.args)?,
369 })
370 }
371}
372
373impl Refine for ty::ClauseKind {
374 type Output = rty::ClauseKind;
375
376 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::ClauseKind> {
377 let kind = match self {
378 ty::ClauseKind::Trait(trait_pred) => {
379 let pred = rty::TraitPredicate { trait_ref: trait_pred.trait_ref.refine(refiner)? };
380 rty::ClauseKind::Trait(pred)
381 }
382 ty::ClauseKind::Projection(proj_pred) => {
383 let rty::TyOrBase::Base(term) = refiner.refine_ty_or_base(&proj_pred.term)? else {
384 return Err(query_bug!(
385 refiner.def_id,
386 "sorry, we can't handle non-base associated types"
387 ));
388 };
389 let pred = rty::ProjectionPredicate {
390 projection_ty: refiner
391 .refine_alias_ty(ty::AliasKind::Projection, &proj_pred.projection_ty)?,
392 term,
393 };
394 rty::ClauseKind::Projection(pred)
395 }
396 ty::ClauseKind::RegionOutlives(pred) => {
397 let pred = rty::OutlivesPredicate(pred.0, pred.1);
398 rty::ClauseKind::RegionOutlives(pred)
399 }
400 ty::ClauseKind::TypeOutlives(pred) => {
401 let pred = rty::OutlivesPredicate(pred.0.refine(refiner)?, pred.1);
402 rty::ClauseKind::TypeOutlives(pred)
403 }
404 ty::ClauseKind::ConstArgHasType(const_, ty) => {
405 rty::ClauseKind::ConstArgHasType(const_.clone(), ty.refine(&refiner.as_default())?)
406 }
407 };
408 Ok(kind)
409 }
410}
411
412impl Refine for ty::ExistentialPredicate {
413 type Output = rty::ExistentialPredicate;
414
415 fn refine(&self, refiner: &Refiner) -> QueryResult<Self::Output> {
416 let pred = match self {
417 ty::ExistentialPredicate::Trait(trait_ref) => {
418 rty::ExistentialPredicate::Trait(rty::ExistentialTraitRef {
419 def_id: trait_ref.def_id,
420 args: refiner.refine_existential_predicate_generic_args(
421 trait_ref.def_id,
422 &trait_ref.args,
423 )?,
424 })
425 }
426 ty::ExistentialPredicate::Projection(projection) => {
427 let rty::TyOrBase::Base(term) = refiner.refine_ty_or_base(&projection.term)? else {
428 return Err(query_bug!(
429 refiner.def_id,
430 "sorry, we can't handle non-base associated types"
431 ));
432 };
433 rty::ExistentialPredicate::Projection(rty::ExistentialProjection {
434 def_id: projection.def_id,
435 args: refiner.refine_existential_predicate_generic_args(
436 projection.def_id,
437 &projection.args,
438 )?,
439 term,
440 })
441 }
442 ty::ExistentialPredicate::AutoTrait(def_id) => {
443 rty::ExistentialPredicate::AutoTrait(*def_id)
444 }
445 };
446 Ok(pred)
447 }
448}
449
450impl Refine for ty::GenericPredicates {
451 type Output = rty::GenericPredicates;
452
453 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::GenericPredicates> {
454 Ok(rty::GenericPredicates {
455 parent: self.parent,
456 predicates: refiner.refine(&self.predicates)?,
457 })
458 }
459}
460
461impl<T> Refine for List<T>
462where
463 T: SliceInternable,
464 T: Refine<Output: SliceInternable>,
465{
466 type Output = rty::List<T::Output>;
467
468 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::List<T::Output>> {
469 refiner.refine(&self[..])
470 }
471}
472
473impl<T> Refine for [T]
474where
475 T: Refine<Output: SliceInternable>,
476{
477 type Output = rty::List<T::Output>;
478
479 fn refine(&self, refiner: &Refiner) -> QueryResult<rty::List<T::Output>> {
480 self.iter().map(|t| refiner.refine(t)).try_collect()
481 }
482}
483
484fn refine_default(bty: rty::BaseTy) -> rty::SubsetTyCtor {
485 let sort = bty.sort();
486 let constr = rty::SubsetTy::trivial(bty.shift_in_escaping(1), rty::Expr::nu());
487 rty::Binder::bind_with_sort(constr, sort)
488}
489
490pub fn refine_bound_variables(vars: &[ty::BoundVariableKind]) -> List<rty::BoundVariableKind> {
491 vars.iter()
492 .map(|kind| {
493 match kind {
494 ty::BoundVariableKind::Region(kind) => rty::BoundVariableKind::Region(*kind),
495 }
496 })
497 .collect()
498}