flux_middle/rty/
refining.rs

//! *Refining* is the process of generating a refined version of a Rust type.
//!
//! Concretely, this module provides functions to go from types in [`flux_rustc_bridge::ty`] to
//! types in [`rty`].
5
6use flux_arc_interner::{List, SliceInternable};
7use flux_common::bug;
8use flux_rustc_bridge::{ty, ty::GenericArgsExt as _};
9use itertools::Itertools;
10use rustc_abi::VariantIdx;
11use rustc_hir::{def::DefKind, def_id::DefId};
12use rustc_middle::ty::ParamTy;
13
14use super::{RefineArgsExt, fold::TypeFoldable};
15use crate::{
16    global_env::GlobalEnv,
17    queries::{QueryErr, QueryResult},
18    query_bug, rty,
19};
20
21pub fn refine_generics(genv: GlobalEnv, def_id: DefId, generics: &ty::Generics) -> rty::Generics {
22    let is_box = if let DefKind::Struct = genv.def_kind(def_id) {
23        genv.tcx().adt_def(def_id).is_box()
24    } else {
25        false
26    };
27    let params = generics
28        .params
29        .iter()
30        .map(|param| {
31            rty::GenericParamDef {
32                kind: refine_generic_param_def_kind(is_box, param.kind),
33                index: param.index,
34                name: param.name,
35                def_id: param.def_id,
36            }
37        })
38        .collect();
39
40    rty::Generics {
41        own_params: params,
42        parent: generics.parent(),
43        parent_count: generics.parent_count(),
44        has_self: generics.orig.has_self,
45    }
46}
47
48pub fn refine_generic_param_def_kind(
49    is_box: bool,
50    kind: ty::GenericParamDefKind,
51) -> rty::GenericParamDefKind {
52    match kind {
53        ty::GenericParamDefKind::Lifetime => rty::GenericParamDefKind::Lifetime,
54        ty::GenericParamDefKind::Type { has_default } => {
55            if is_box {
56                rty::GenericParamDefKind::Type { has_default }
57            } else {
58                rty::GenericParamDefKind::Base { has_default }
59            }
60        }
61        ty::GenericParamDefKind::Const { has_default, .. } => {
62            rty::GenericParamDefKind::Const { has_default }
63        }
64    }
65}
66
67pub struct Refiner<'genv, 'tcx> {
68    genv: GlobalEnv<'genv, 'tcx>,
69    def_id: DefId,
70    generics: rty::Generics,
71    refine: fn(rty::BaseTy) -> rty::SubsetTyCtor,
72}
73
74impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
75    pub fn new_for_item(
76        genv: GlobalEnv<'genv, 'tcx>,
77        def_id: DefId,
78        refine: fn(rty::BaseTy) -> rty::SubsetTyCtor,
79    ) -> QueryResult<Self> {
80        let generics = genv.generics_of(def_id)?;
81        Ok(Self { genv, def_id, generics, refine })
82    }
83
84    pub fn default_for_item(genv: GlobalEnv<'genv, 'tcx>, def_id: DefId) -> QueryResult<Self> {
85        Self::new_for_item(genv, def_id, refine_default)
86    }
87
88    pub fn with_holes(genv: GlobalEnv<'genv, 'tcx>, def_id: DefId) -> QueryResult<Self> {
89        Self::new_for_item(genv, def_id, |bty| {
90            let sort = bty.sort();
91            let constr = rty::SubsetTy::new(
92                bty.shift_in_escaping(1),
93                rty::Expr::nu(),
94                rty::Expr::hole(rty::HoleKind::Pred),
95            );
96            rty::Binder::bind_with_sort(constr, sort)
97        })
98    }
99
100    pub fn refine<T: Refine + ?Sized>(&self, t: &T) -> QueryResult<T::Output> {
101        t.refine(self)
102    }
103
104    fn refine_existential_predicate_generic_args(
105        &self,
106        def_id: DefId,
107        args: &ty::GenericArgs,
108    ) -> QueryResult<rty::GenericArgs> {
109        let generics = self.generics_of(def_id)?;
110        args.iter()
111            .enumerate()
112            .map(|(idx, arg)| {
113                // We need to skip the generic for Self
114                let param = generics.param_at(idx + 1, self.genv)?;
115                self.refine_generic_arg(&param, arg)
116            })
117            .try_collect()
118    }
119
120    pub fn refine_variant_def(
121        &self,
122        adt_def_id: DefId,
123        variant_idx: VariantIdx,
124    ) -> QueryResult<rty::PolyVariant> {
125        let adt_def = self.adt_def(adt_def_id)?;
126        let variant_def = adt_def.variant(variant_idx);
127        let fields = variant_def
128            .fields
129            .iter()
130            .map(|fld| {
131                let ty = self.genv.lower_type_of(fld.did)?.instantiate_identity();
132                ty.refine(self)
133            })
134            .try_collect()?;
135
136        let idx = if adt_def.sort_def().is_struct() {
137            rty::Expr::unit_struct(adt_def_id)
138        } else {
139            rty::Expr::ctor_enum(adt_def_id, variant_idx)
140        };
141        let value = rty::VariantSig::new(
142            adt_def,
143            rty::GenericArg::identity_for_item(self.genv, adt_def_id)?,
144            fields,
145            idx,
146            List::empty(),
147        );
148
149        Ok(rty::Binder::bind_with_vars(value, List::empty()))
150    }
151
152    fn refine_generic_args(
153        &self,
154        def_id: DefId,
155        args: &ty::GenericArgs,
156    ) -> QueryResult<rty::GenericArgs> {
157        let generics = self.generics_of(def_id)?;
158        args.iter()
159            .enumerate()
160            .map(|(idx, arg)| {
161                let param = generics.param_at(idx, self.genv)?;
162                self.refine_generic_arg(&param, arg)
163            })
164            .collect()
165    }
166
167    pub fn refine_generic_arg(
168        &self,
169        param: &rty::GenericParamDef,
170        arg: &ty::GenericArg,
171    ) -> QueryResult<rty::GenericArg> {
172        match (&param.kind, arg) {
173            (rty::GenericParamDefKind::Type { .. }, ty::GenericArg::Ty(ty)) => {
174                Ok(rty::GenericArg::Ty(ty.refine(self)?))
175            }
176            (rty::GenericParamDefKind::Base { .. }, ty::GenericArg::Ty(ty)) => {
177                let rty::TyOrBase::Base(contr) = self.refine_ty_or_base(ty)? else {
178                    return Err(QueryErr::InvalidGenericArg { def_id: param.def_id });
179                };
180                Ok(rty::GenericArg::Base(contr))
181            }
182            (rty::GenericParamDefKind::Lifetime, ty::GenericArg::Lifetime(re)) => {
183                Ok(rty::GenericArg::Lifetime(*re))
184            }
185            (rty::GenericParamDefKind::Const { .. }, ty::GenericArg::Const(ct)) => {
186                Ok(rty::GenericArg::Const(ct.clone()))
187            }
188            _ => bug!("mismatched generic arg `{arg:?}` `{param:?}`"),
189        }
190    }
191
192    fn refine_alias_ty(
193        &self,
194        alias_kind: ty::AliasKind,
195        alias_ty: &ty::AliasTy,
196    ) -> QueryResult<rty::AliasTy> {
197        let def_id = alias_ty.def_id;
198        let args = self.refine_generic_args(def_id, &alias_ty.args)?;
199
200        let refine_args = if let ty::AliasKind::Opaque = alias_kind {
201            rty::RefineArgs::for_item(self.genv, def_id, |param, _| {
202                let param = param.instantiate(self.genv.tcx(), &args, &[]);
203                Ok(rty::Expr::hole(rty::HoleKind::Expr(param.sort)))
204            })?
205        } else {
206            List::empty()
207        };
208
209        Ok(rty::AliasTy::new(def_id, args, refine_args))
210    }
211
212    pub fn refine_ty_or_base(&self, ty: &ty::Ty) -> QueryResult<rty::TyOrBase> {
213        let bty = match ty.kind() {
214            ty::TyKind::Closure(did, args) => {
215                let closure_args = args.as_closure();
216                let upvar_tys = closure_args
217                    .upvar_tys()
218                    .iter()
219                    .map(|ty| ty.refine(self))
220                    .try_collect()?;
221                rty::BaseTy::Closure(*did, upvar_tys, args.clone())
222            }
223            ty::TyKind::Coroutine(did, args) => {
224                let args = args.as_coroutine();
225                let resume_ty = args.resume_ty().refine(self)?;
226                let upvar_tys = args.upvar_tys().map(|ty| ty.refine(self)).try_collect()?;
227                rty::BaseTy::Coroutine(*did, resume_ty, upvar_tys)
228            }
229            ty::TyKind::CoroutineWitness(..) => {
230                bug!("implement when we know what this is");
231            }
232            ty::TyKind::Never => rty::BaseTy::Never,
233            ty::TyKind::Ref(r, ty, mutbl) => rty::BaseTy::Ref(*r, ty.refine(self)?, *mutbl),
234            ty::TyKind::Float(float_ty) => rty::BaseTy::Float(*float_ty),
235            ty::TyKind::Tuple(tys) => {
236                let tys = tys.iter().map(|ty| ty.refine(self)).try_collect()?;
237                rty::BaseTy::Tuple(tys)
238            }
239            ty::TyKind::Array(ty, len) => rty::BaseTy::Array(ty.refine(self)?, len.clone()),
240            ty::TyKind::Param(param_ty) => {
241                match self.param(*param_ty)?.kind {
242                    rty::GenericParamDefKind::Type { .. } => {
243                        return Ok(rty::TyOrBase::Ty(rty::Ty::param(*param_ty)));
244                    }
245                    rty::GenericParamDefKind::Base { .. } => rty::BaseTy::Param(*param_ty),
246                    rty::GenericParamDefKind::Lifetime | rty::GenericParamDefKind::Const { .. } => {
247                        bug!()
248                    }
249                }
250            }
251            ty::TyKind::Adt(adt_def, args) => {
252                let adt_def = self.genv.adt_def(adt_def.did())?;
253                let args = self.refine_generic_args(adt_def.did(), args)?;
254                rty::BaseTy::adt(adt_def, args)
255            }
256            ty::TyKind::FnDef(def_id, args) => {
257                let args = self.refine_generic_args(*def_id, args)?;
258                rty::BaseTy::fn_def(*def_id, args)
259            }
260            ty::TyKind::Alias(kind, alias_ty) => {
261                let alias_ty = self.as_default().refine_alias_ty(*kind, alias_ty)?;
262                rty::BaseTy::Alias(*kind, alias_ty)
263            }
264            ty::TyKind::Bool => rty::BaseTy::Bool,
265            ty::TyKind::Int(int_ty) => rty::BaseTy::Int(*int_ty),
266            ty::TyKind::Uint(uint_ty) => rty::BaseTy::Uint(*uint_ty),
267            ty::TyKind::Foreign(def_id) => rty::BaseTy::Foreign(*def_id),
268            ty::TyKind::Str => rty::BaseTy::Str,
269            ty::TyKind::Slice(ty) => rty::BaseTy::Slice(ty.refine(self)?),
270            ty::TyKind::Char => rty::BaseTy::Char,
271            ty::TyKind::FnPtr(poly_fn_sig) => {
272                rty::BaseTy::FnPtr(poly_fn_sig.refine(&self.as_default())?)
273            }
274            ty::TyKind::RawPtr(ty, mu) => rty::BaseTy::RawPtr(ty.refine(&self.as_default())?, *mu),
275            ty::TyKind::Dynamic(exi_preds, r) => {
276                let exi_preds = exi_preds
277                    .iter()
278                    .map(|pred| pred.refine(self))
279                    .try_collect()?;
280                rty::BaseTy::Dynamic(exi_preds, *r)
281            }
282        };
283        Ok(rty::TyOrBase::Base((self.refine)(bty)))
284    }
285
286    fn as_default(&self) -> Self {
287        Refiner { refine: refine_default, generics: self.generics.clone(), ..*self }
288    }
289
290    fn adt_def(&self, def_id: DefId) -> QueryResult<rty::AdtDef> {
291        self.genv.adt_def(def_id)
292    }
293
294    fn generics_of(&self, def_id: DefId) -> QueryResult<rty::Generics> {
295        self.genv.generics_of(def_id)
296    }
297
298    fn param(&self, param_ty: ParamTy) -> QueryResult<rty::GenericParamDef> {
299        self.generics.param_at(param_ty.index as usize, self.genv)
300    }
301}
302
303pub trait Refine {
304    type Output;
305
306    fn refine(&self, refiner: &Refiner) -> QueryResult<Self::Output>;
307}
308
309impl Refine for ty::Ty {
310    type Output = rty::Ty;
311
312    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Ty> {
313        Ok(refiner.refine_ty_or_base(self)?.into_ty())
314    }
315}
316
317impl<T: Refine> Refine for ty::Binder<T> {
318    type Output = rty::Binder<T::Output>;
319
320    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Binder<T::Output>> {
321        let vars = refine_bound_variables(self.vars());
322        let inner = self.skip_binder_ref().refine(refiner)?;
323        Ok(rty::Binder::bind_with_vars(inner, vars))
324    }
325}
326
327impl Refine for ty::FnSig {
328    type Output = rty::FnSig;
329
330    // TODO(hof2)
331    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::FnSig> {
332        let inputs = self
333            .inputs()
334            .iter()
335            .map(|ty| ty.refine(refiner))
336            .try_collect()?;
337        let ret = self.output().refine(refiner)?.shift_in_escaping(1);
338        let output = rty::Binder::bind_with_vars(rty::FnOutput::new(ret, vec![]), List::empty());
339        // TODO(hof2) make a hoister to hoist all the stuff out of the inputs,
340        // the hoister will have a list of all the variables it hoisted and the
341        // single hole for the "requires"; then we "fill" the hole with a KVAR
342        // and generate a PolyFnSig with the hoisted variables
343        // see `into_bb_env` in `type_env.rs` for an example.
344        Ok(rty::FnSig::new(self.safety, self.abi, List::empty(), inputs, output))
345    }
346}
347
348impl Refine for ty::Clause {
349    type Output = rty::Clause;
350
351    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::Clause> {
352        Ok(rty::Clause { kind: self.kind.refine(refiner)? })
353    }
354}
355
356impl Refine for ty::TraitRef {
357    type Output = rty::TraitRef;
358
359    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::TraitRef> {
360        Ok(rty::TraitRef {
361            def_id: self.def_id,
362            args: refiner.refine_generic_args(self.def_id, &self.args)?,
363        })
364    }
365}
366
367impl Refine for ty::ClauseKind {
368    type Output = rty::ClauseKind;
369
370    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::ClauseKind> {
371        let kind = match self {
372            ty::ClauseKind::Trait(trait_pred) => {
373                let pred = rty::TraitPredicate { trait_ref: trait_pred.trait_ref.refine(refiner)? };
374                rty::ClauseKind::Trait(pred)
375            }
376            ty::ClauseKind::Projection(proj_pred) => {
377                let rty::TyOrBase::Base(term) = refiner.refine_ty_or_base(&proj_pred.term)? else {
378                    return Err(query_bug!(
379                        refiner.def_id,
380                        "sorry, we can't handle non-base associated types"
381                    ));
382                };
383                let pred = rty::ProjectionPredicate {
384                    projection_ty: refiner
385                        .refine_alias_ty(ty::AliasKind::Projection, &proj_pred.projection_ty)?,
386                    term,
387                };
388                rty::ClauseKind::Projection(pred)
389            }
390            ty::ClauseKind::RegionOutlives(pred) => {
391                let pred = rty::OutlivesPredicate(pred.0, pred.1);
392                rty::ClauseKind::RegionOutlives(pred)
393            }
394            ty::ClauseKind::TypeOutlives(pred) => {
395                let pred = rty::OutlivesPredicate(pred.0.refine(refiner)?, pred.1);
396                rty::ClauseKind::TypeOutlives(pred)
397            }
398            ty::ClauseKind::ConstArgHasType(const_, ty) => {
399                rty::ClauseKind::ConstArgHasType(const_.clone(), ty.refine(&refiner.as_default())?)
400            }
401        };
402        Ok(kind)
403    }
404}
405
406impl Refine for ty::ExistentialPredicate {
407    type Output = rty::ExistentialPredicate;
408
409    fn refine(&self, refiner: &Refiner) -> QueryResult<Self::Output> {
410        let pred = match self {
411            ty::ExistentialPredicate::Trait(trait_ref) => {
412                rty::ExistentialPredicate::Trait(rty::ExistentialTraitRef {
413                    def_id: trait_ref.def_id,
414                    args: refiner.refine_existential_predicate_generic_args(
415                        trait_ref.def_id,
416                        &trait_ref.args,
417                    )?,
418                })
419            }
420            ty::ExistentialPredicate::Projection(projection) => {
421                let rty::TyOrBase::Base(term) = refiner.refine_ty_or_base(&projection.term)? else {
422                    return Err(query_bug!(
423                        refiner.def_id,
424                        "sorry, we can't handle non-base associated types"
425                    ));
426                };
427                rty::ExistentialPredicate::Projection(rty::ExistentialProjection {
428                    def_id: projection.def_id,
429                    args: refiner.refine_existential_predicate_generic_args(
430                        projection.def_id,
431                        &projection.args,
432                    )?,
433                    term,
434                })
435            }
436            ty::ExistentialPredicate::AutoTrait(def_id) => {
437                rty::ExistentialPredicate::AutoTrait(*def_id)
438            }
439        };
440        Ok(pred)
441    }
442}
443
444impl Refine for ty::GenericPredicates {
445    type Output = rty::GenericPredicates;
446
447    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::GenericPredicates> {
448        Ok(rty::GenericPredicates {
449            parent: self.parent,
450            predicates: refiner.refine(&self.predicates)?,
451        })
452    }
453}
454
455impl<T> Refine for List<T>
456where
457    T: SliceInternable,
458    T: Refine<Output: SliceInternable>,
459{
460    type Output = rty::List<T::Output>;
461
462    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::List<T::Output>> {
463        refiner.refine(&self[..])
464    }
465}
466
467impl<T> Refine for [T]
468where
469    T: Refine<Output: SliceInternable>,
470{
471    type Output = rty::List<T::Output>;
472
473    fn refine(&self, refiner: &Refiner) -> QueryResult<rty::List<T::Output>> {
474        self.iter().map(|t| refiner.refine(t)).try_collect()
475    }
476}
477
478fn refine_default(bty: rty::BaseTy) -> rty::SubsetTyCtor {
479    let sort = bty.sort();
480    let constr = rty::SubsetTy::trivial(bty.shift_in_escaping(1), rty::Expr::nu());
481    rty::Binder::bind_with_sort(constr, sort)
482}
483
484pub fn refine_bound_variables(vars: &[ty::BoundVariableKind]) -> List<rty::BoundVariableKind> {
485    vars.iter()
486        .map(|kind| {
487            match kind {
488                ty::BoundVariableKind::Region(kind) => rty::BoundVariableKind::Region(*kind),
489            }
490        })
491        .collect()
492}