// flux_middle/rty/refining.rs

//! *Refining* is the process of generating a refined version of a Rust type.
//!
//! Concretely, this module provides functions to go from types in [`flux_rustc_bridge::ty`] to
//! types in [`rty`].
5
6use flux_arc_interner::{List, SliceInternable};
7use flux_common::bug;
8use flux_rustc_bridge::{ty, ty::GenericArgsExt as _};
9use itertools::Itertools;
10use rustc_abi::VariantIdx;
11use rustc_hir::{def::DefKind, def_id::DefId};
12use rustc_middle::ty::ParamTy;
13
14use super::{RefineArgsExt, fold::TypeFoldable};
15use crate::{
16    global_env::GlobalEnv,
17    queries::{QueryErr, QueryResult},
18    query_bug,
19    rty::{self, Expr},
20};
21
22pub fn refine_generics(genv: GlobalEnv, def_id: DefId, generics: &ty::Generics) -> rty::Generics {
23    let is_box = if let DefKind::Struct = genv.def_kind(def_id) {
24        genv.tcx().adt_def(def_id).is_box()
25    } else {
26        false
27    };
28    let params = generics
29        .params
30        .iter()
31        .map(|param| {
32            rty::GenericParamDef {
33                kind: refine_generic_param_def_kind(is_box, param.kind),
34                index: param.index,
35                name: param.name,
36                def_id: param.def_id,
37            }
38        })
39        .collect();
40
41    rty::Generics {
42        own_params: params,
43        parent: generics.parent(),
44        parent_count: generics.parent_count(),
45        has_self: generics.orig.has_self,
46    }
47}
48
49pub fn refine_generic_param_def_kind(
50    is_box: bool,
51    kind: ty::GenericParamDefKind,
52) -> rty::GenericParamDefKind {
53    match kind {
54        ty::GenericParamDefKind::Lifetime => rty::GenericParamDefKind::Lifetime,
55        ty::GenericParamDefKind::Type { has_default } => {
56            if is_box {
57                rty::GenericParamDefKind::Type { has_default }
58            } else {
59                rty::GenericParamDefKind::Base { has_default }
60            }
61        }
62        ty::GenericParamDefKind::Const { has_default, .. } => {
63            rty::GenericParamDefKind::Const { has_default }
64        }
65    }
66}
67
68pub struct Refiner<'genv, 'tcx> {
69    genv: GlobalEnv<'genv, 'tcx>,
70    def_id: DefId,
71    generics: rty::Generics,
72    refine: fn(rty::BaseTy) -> rty::SubsetTyCtor,
73}
74
75impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
76    pub fn new_for_item(
77        genv: GlobalEnv<'genv, 'tcx>,
78        def_id: DefId,
79        refine: fn(rty::BaseTy) -> rty::SubsetTyCtor,
80    ) -> QueryResult<Self> {
81        let generics = genv.generics_of(def_id)?;
82        Ok(Self { genv, def_id, generics, refine })
83    }
84
85    pub fn default_for_item(genv: GlobalEnv<'genv, 'tcx>, def_id: DefId) -> QueryResult<Self> {
86        Self::new_for_item(genv, def_id, refine_default)
87    }
88
89    pub fn with_holes(genv: GlobalEnv<'genv, 'tcx>, def_id: DefId) -> QueryResult<Self> {
90        Self::new_for_item(genv, def_id, |bty| {
91            let sort = bty.sort();
92            let constr = rty::SubsetTy::new(
93                bty.shift_in_escaping(1),
94                rty::Expr::nu(),
95                rty::Expr::hole(rty::HoleKind::Pred),
96            );
97            rty::Binder::bind_with_sort(constr, sort)
98        })
99    }
100
101    pub fn refine<T: Refine + ?Sized>(&self, t: &T) -> QueryResult<T::Output> {
102        t.refine(self)
103    }
104
105    fn refine_existential_predicate_generic_args(
106        &self,
107        def_id: DefId,
108        args: &ty::GenericArgs,
109    ) -> QueryResult<rty::GenericArgs> {
110        let generics = self.generics_of(def_id)?;
111        args.iter()
112            .enumerate()
113            .map(|(idx, arg)| {
114                // We need to skip the generic for Self
115                let param = generics.param_at(idx + 1, self.genv)?;
116                self.refine_generic_arg(&param, arg)
117            })
118            .try_collect()
119    }
120
121    pub fn refine_variant_def(
122        &self,
123        adt_def_id: DefId,
124        variant_idx: VariantIdx,
125    ) -> QueryResult<rty::PolyVariant> {
126        let adt_def = self.adt_def(adt_def_id)?;
127        let variant_def = adt_def.variant(variant_idx);
128        let fields = variant_def
129            .fields
130            .iter()
131            .map(|fld| {
132                let ty = self.genv.lower_type_of(fld.did)?.instantiate_identity();
133                ty.refine(self)
134            })
135            .try_collect()?;
136
137        let idx = if adt_def.sort_def().is_struct() {
138            rty::Expr::unit_struct(adt_def_id)
139        } else {
140            rty::Expr::ctor_enum(adt_def_id, variant_idx)
141        };
142        let value = rty::VariantSig::new(
143            adt_def,
144            rty::GenericArg::identity_for_item(self.genv, adt_def_id)?,
145            fields,
146            idx,
147            List::empty(),
148        );
149
150        Ok(rty::Binder::bind_with_vars(value, List::empty()))
151    }
152
153    pub fn refine_generic_args(
154        &self,
155        def_id: DefId,
156        args: &ty::GenericArgs,
157    ) -> QueryResult<rty::GenericArgs> {
158        let generics = self.generics_of(def_id)?;
159        args.iter()
160            .enumerate()
161            .map(|(idx, arg)| {
162                let param = generics.param_at(idx, self.genv)?;
163                self.refine_generic_arg(&param, arg)
164            })
165            .collect()
166    }
167
168    pub fn refine_generic_arg(
169        &self,
170        param: &rty::GenericParamDef,
171        arg: &ty::GenericArg,
172    ) -> QueryResult<rty::GenericArg> {
173        match (&param.kind, arg) {
174            (rty::GenericParamDefKind::Type { .. }, ty::GenericArg::Ty(ty)) => {
175                Ok(rty::GenericArg::Ty(ty.refine(self)?))
176            }
177            (rty::GenericParamDefKind::Base { .. }, ty::GenericArg::Ty(ty)) => {
178                let rty::TyOrBase::Base(contr) = self.refine_ty_or_base(ty)? else {
179                    return Err(QueryErr::InvalidGenericArg { def_id: param.def_id });
180                };
181                Ok(rty::GenericArg::Base(contr))
182            }
183            (rty::GenericParamDefKind::Lifetime, ty::GenericArg::Lifetime(re)) => {
184                Ok(rty::GenericArg::Lifetime(*re))
185            }
186            (rty::GenericParamDefKind::Const { .. }, ty::GenericArg::Const(ct)) => {
187                Ok(rty::GenericArg::Const(ct.clone()))
188            }
189            _ => bug!("mismatched generic arg `{arg:?}` `{param:?}`"),
190        }
191    }
192
193    fn refine_alias_ty(
194        &self,
195        alias_kind: ty::AliasKind,
196        alias_ty: &ty::AliasTy,
197    ) -> QueryResult<rty::AliasTy> {
198        let def_id = alias_ty.def_id;
199        let args = self.refine_generic_args(def_id, &alias_ty.args)?;
200
201        let refine_args = if let ty::AliasKind::Opaque = alias_kind {
202            rty::RefineArgs::for_item(self.genv, def_id, |param, _| {
203                let param = param.instantiate(self.genv.tcx(), &args, &[]);
204                Ok(rty::Expr::hole(rty::HoleKind::Expr(param.sort)))
205            })?
206        } else {
207            List::empty()
208        };
209
210        Ok(rty::AliasTy::new(def_id, args, refine_args))
211    }
212
213    pub fn refine_ty_or_base(&self, ty: &ty::Ty) -> QueryResult<rty::TyOrBase> {
214        let bty = match ty.kind() {
215            ty::TyKind::Closure(did, args) => {
216                let no_panic = self.genv.no_panic(did);
217                let closure_args = args.as_closure();
218                let upvar_tys = closure_args
219                    .upvar_tys()
220                    .iter()
221                    .map(|ty| ty.refine(self))
222                    .try_collect()?;
223                rty::BaseTy::Closure(*did, upvar_tys, args.clone(), no_panic)
224            }
225            ty::TyKind::Coroutine(did, args) => {
226                let coroutine_args = args.as_coroutine();
227                let resume_ty = coroutine_args.resume_ty().refine(self)?;
228                let upvar_tys = coroutine_args
229                    .upvar_tys()
230                    .map(|ty| ty.refine(self))
231                    .try_collect()?;
232                rty::BaseTy::Coroutine(*did, resume_ty, upvar_tys, args.clone())
233            }
234            ty::TyKind::CoroutineWitness(..) => {
235                bug!("implement when we know what this is");
236            }
237            ty::TyKind::Never => rty::BaseTy::Never,
238            ty::TyKind::Ref(r, ty, mutbl) => rty::BaseTy::Ref(*r, ty.refine(self)?, *mutbl),
239            ty::TyKind::Float(float_ty) => rty::BaseTy::Float(*float_ty),
240            ty::TyKind::Tuple(tys) => {
241                let tys = tys.iter().map(|ty| ty.refine(self)).try_collect()?;
242                rty::BaseTy::Tuple(tys)
243            }
244            ty::TyKind::Array(ty, len) => rty::BaseTy::Array(ty.refine(self)?, len.clone()),
245            ty::TyKind::Param(param_ty) => {
246                match self.param(*param_ty)?.kind {
247                    rty::GenericParamDefKind::Type { .. } => {
248                        return Ok(rty::TyOrBase::Ty(rty::Ty::param(*param_ty)));
249                    }
250                    rty::GenericParamDefKind::Base { .. } => rty::BaseTy::Param(*param_ty),
251                    rty::GenericParamDefKind::Lifetime | rty::GenericParamDefKind::Const { .. } => {
252                        bug!()
253                    }
254                }
255            }
256            ty::TyKind::Adt(adt_def, args) => {
257                let adt_def = self.genv.adt_def(adt_def.did())?;
258                let args = self.refine_generic_args(adt_def.did(), args)?;
259                rty::BaseTy::adt(adt_def, args)
260            }
261            ty::TyKind::FnDef(def_id, args) => {
262                let args = self.refine_generic_args(*def_id, args)?;
263                rty::BaseTy::fn_def(*def_id, args)
264            }
265            ty::TyKind::Alias(kind, alias_ty) => {
266                let alias_ty = self.as_default().refine_alias_ty(*kind, alias_ty)?;
267                rty::BaseTy::Alias(*kind, alias_ty)
268            }
269            ty::TyKind::Bool => rty::BaseTy::Bool,
270            ty::TyKind::Int(int_ty) => rty::BaseTy::Int(*int_ty),
271            ty::TyKind::Uint(uint_ty) => rty::BaseTy::Uint(*uint_ty),
272            ty::TyKind::Foreign(def_id) => rty::BaseTy::Foreign(*def_id),
273            ty::TyKind::Str => rty::BaseTy::Str,
274            ty::TyKind::Slice(ty) => rty::BaseTy::Slice(ty.refine(self)?),
275            ty::TyKind::Char => rty::BaseTy::Char,
276            ty::TyKind::FnPtr(poly_fn_sig) => {
277                rty::BaseTy::FnPtr(poly_fn_sig.refine(&self.as_default())?)
278            }
279            ty::TyKind::RawPtr(ty, mu) => rty::BaseTy::RawPtr(ty.refine(&self.as_default())?, *mu),
280            ty::TyKind::Dynamic(exi_preds, r) => {
281                let exi_preds = exi_preds
282                    .iter()
283                    .map(|pred| pred.refine(self))
284                    .try_collect()?;
285                rty::BaseTy::Dynamic(exi_preds, *r)
286            }
287            ty::TyKind::Pat => rty::BaseTy::Pat,
288        };
289        Ok(rty::TyOrBase::Base((self.refine)(bty)))
290    }
291
292    fn as_default(&self) -> Self {
293        Refiner { refine: refine_default, generics: self.generics.clone(), ..*self }
294    }
295
296    fn adt_def(&self, def_id: DefId) -> QueryResult<rty::AdtDef> {
297        self.genv.adt_def(def_id)
298    }
299
300    fn generics_of(&self, def_id: DefId) -> QueryResult<rty::Generics> {
301        self.genv.generics_of(def_id)
302    }
303
304    fn param(&self, param_ty: ParamTy) -> QueryResult<rty::GenericParamDef> {
305        self.generics.param_at(param_ty.index as usize, self.genv)
306    }
307}
308
309pub trait Refine {
310    type Output;
311
312    fn refine(&self, refiner: &Refiner) -> QueryResult<Self::Output>;
313}
314
315impl Refine for ty::Ty {
316    type Output = rty::Ty;
317
318    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Ty> {
319        Ok(refiner.refine_ty_or_base(self)?.into_ty())
320    }
321}
322
323impl<T: Refine> Refine for ty::Binder<T> {
324    type Output = rty::Binder<T::Output>;
325
326    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Binder<T::Output>> {
327        let vars = refine_bound_variables(self.vars());
328        let inner = self.skip_binder_ref().refine(refiner)?;
329        Ok(rty::Binder::bind_with_vars(inner, vars))
330    }
331}
332
333impl Refine for ty::FnSig {
334    type Output = rty::FnSig;
335
336    // TODO(hof2)
337    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::FnSig> {
338        let inputs = self
339            .inputs()
340            .iter()
341            .map(|ty| ty.refine(refiner))
342            .try_collect()?;
343        let ret = self.output().refine(refiner)?.shift_in_escaping(1);
344        let output = rty::Binder::bind_with_vars(rty::FnOutput::new(ret, vec![]), List::empty());
345        // TODO(hof2) make a hoister to hoist all the stuff out of the inputs,
346        // the hoister will have a list of all the variables it hoisted and the
347        // single hole for the "requires"; then we "fill" the hole with a KVAR
348        // and generate a PolyFnSig with the hoisted variables
349        // see `into_bb_env` in `type_env.rs` for an example.
350        Ok(rty::FnSig::new(self.safety, self.abi, List::empty(), inputs, output, Expr::ff(), true))
351    }
352}
353
354impl Refine for ty::Clause {
355    type Output = rty::Clause;
356
357    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Clause> {
358        Ok(rty::Clause { kind: self.kind.refine(refiner)? })
359    }
360}
361
362impl Refine for ty::TraitRef {
363    type Output = rty::TraitRef;
364
365    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::TraitRef> {
366        Ok(rty::TraitRef {
367            def_id: self.def_id,
368            args: refiner.refine_generic_args(self.def_id, &self.args)?,
369        })
370    }
371}
372
373impl Refine for ty::ClauseKind {
374    type Output = rty::ClauseKind;
375
376    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::ClauseKind> {
377        let kind = match self {
378            ty::ClauseKind::Trait(trait_pred) => {
379                let pred = rty::TraitPredicate { trait_ref: trait_pred.trait_ref.refine(refiner)? };
380                rty::ClauseKind::Trait(pred)
381            }
382            ty::ClauseKind::Projection(proj_pred) => {
383                let rty::TyOrBase::Base(term) = refiner.refine_ty_or_base(&proj_pred.term)? else {
384                    return Err(query_bug!(
385                        refiner.def_id,
386                        "sorry, we can't handle non-base associated types"
387                    ));
388                };
389                let pred = rty::ProjectionPredicate {
390                    projection_ty: refiner
391                        .refine_alias_ty(ty::AliasKind::Projection, &proj_pred.projection_ty)?,
392                    term,
393                };
394                rty::ClauseKind::Projection(pred)
395            }
396            ty::ClauseKind::RegionOutlives(pred) => {
397                let pred = rty::OutlivesPredicate(pred.0, pred.1);
398                rty::ClauseKind::RegionOutlives(pred)
399            }
400            ty::ClauseKind::TypeOutlives(pred) => {
401                let pred = rty::OutlivesPredicate(pred.0.refine(refiner)?, pred.1);
402                rty::ClauseKind::TypeOutlives(pred)
403            }
404            ty::ClauseKind::ConstArgHasType(const_, ty) => {
405                rty::ClauseKind::ConstArgHasType(const_.clone(), ty.refine(&refiner.as_default())?)
406            }
407        };
408        Ok(kind)
409    }
410}
411
412impl Refine for ty::ExistentialPredicate {
413    type Output = rty::ExistentialPredicate;
414
415    fn refine(&self, refiner: &Refiner) -> QueryResult<Self::Output> {
416        let pred = match self {
417            ty::ExistentialPredicate::Trait(trait_ref) => {
418                rty::ExistentialPredicate::Trait(rty::ExistentialTraitRef {
419                    def_id: trait_ref.def_id,
420                    args: refiner.refine_existential_predicate_generic_args(
421                        trait_ref.def_id,
422                        &trait_ref.args,
423                    )?,
424                })
425            }
426            ty::ExistentialPredicate::Projection(projection) => {
427                let rty::TyOrBase::Base(term) = refiner.refine_ty_or_base(&projection.term)? else {
428                    return Err(query_bug!(
429                        refiner.def_id,
430                        "sorry, we can't handle non-base associated types"
431                    ));
432                };
433                rty::ExistentialPredicate::Projection(rty::ExistentialProjection {
434                    def_id: projection.def_id,
435                    args: refiner.refine_existential_predicate_generic_args(
436                        projection.def_id,
437                        &projection.args,
438                    )?,
439                    term,
440                })
441            }
442            ty::ExistentialPredicate::AutoTrait(def_id) => {
443                rty::ExistentialPredicate::AutoTrait(*def_id)
444            }
445        };
446        Ok(pred)
447    }
448}
449
450impl Refine for ty::GenericPredicates {
451    type Output = rty::GenericPredicates;
452
453    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::GenericPredicates> {
454        Ok(rty::GenericPredicates {
455            parent: self.parent,
456            predicates: refiner.refine(&self.predicates)?,
457        })
458    }
459}
460
461impl<T> Refine for List<T>
462where
463    T: SliceInternable,
464    T: Refine<Output: SliceInternable>,
465{
466    type Output = rty::List<T::Output>;
467
468    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::List<T::Output>> {
469        refiner.refine(&self[..])
470    }
471}
472
473impl<T> Refine for [T]
474where
475    T: Refine<Output: SliceInternable>,
476{
477    type Output = rty::List<T::Output>;
478
479    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::List<T::Output>> {
480        self.iter().map(|t| refiner.refine(t)).try_collect()
481    }
482}
483
484fn refine_default(bty: rty::BaseTy) -> rty::SubsetTyCtor {
485    let sort = bty.sort();
486    let constr = rty::SubsetTy::trivial(bty.shift_in_escaping(1), rty::Expr::nu());
487    rty::Binder::bind_with_sort(constr, sort)
488}
489
490pub fn refine_bound_variables(vars: &[ty::BoundVariableKind]) -> List<rty::BoundVariableKind> {
491    vars.iter()
492        .map(|kind| {
493            match kind {
494                ty::BoundVariableKind::Region(kind) => rty::BoundVariableKind::Region(*kind),
495            }
496        })
497        .collect()
498}