flux_middle/rty/
binder.rs

1use std::slice;
2
3pub use flux_arc_interner::{List, impl_slice_internable};
4use flux_common::tracked_span_bug;
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_rustc_bridge::{
7    ToRustc,
8    ty::{BoundRegion, Region},
9};
10use itertools::Itertools;
11use rustc_data_structures::unord::UnordMap;
12use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable};
13use rustc_middle::ty::{BoundRegionKind, TyCtxt};
14use rustc_span::Symbol;
15
16use super::{
17    Expr, GenericArg, InferMode, RefineParam, Sort,
18    fold::TypeFoldable,
19    subst::{self, BoundVarReplacer, FnMutDelegate},
20};
21
22#[derive(Clone, Debug, TyEncodable, TyDecodable)]
23pub struct EarlyBinder<T>(pub T);
24
25impl<T> EarlyBinder<T> {
26    pub fn as_ref(&self) -> EarlyBinder<&T> {
27        EarlyBinder(&self.0)
28    }
29
30    pub fn as_deref(&self) -> EarlyBinder<&T::Target>
31    where
32        T: std::ops::Deref,
33    {
34        EarlyBinder(self.0.deref())
35    }
36
37    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> EarlyBinder<U> {
38        EarlyBinder(f(self.0))
39    }
40
41    pub fn try_map<U, E>(self, f: impl FnOnce(T) -> Result<U, E>) -> Result<EarlyBinder<U>, E> {
42        Ok(EarlyBinder(f(self.0)?))
43    }
44
45    pub fn skip_binder(self) -> T {
46        self.0
47    }
48
49    pub fn skip_binder_ref(&self) -> &T {
50        &self.0
51    }
52
53    pub fn instantiate_identity(self) -> T {
54        self.0
55    }
56}
57
58impl<I: IntoIterator> EarlyBinder<I> {
59    pub fn iter_identity(self) -> impl Iterator<Item = I::Item> {
60        self.0.into_iter()
61    }
62}
63
64impl<T: TypeFoldable> EarlyBinder<T> {
65    pub fn instantiate(self, tcx: TyCtxt, args: &[GenericArg], refine_args: &[Expr]) -> T {
66        self.as_ref().instantiate_ref(tcx, args, refine_args)
67    }
68}
69
70impl<T: TypeFoldable> EarlyBinder<&T> {
71    pub fn instantiate_ref(self, tcx: TyCtxt, args: &[GenericArg], refine_args: &[Expr]) -> T {
72        self.0
73            .try_fold_with(&mut subst::GenericsSubstFolder::new(
74                subst::GenericArgsDelegate(args, tcx),
75                refine_args,
76            ))
77            .into_ok()
78    }
79}
80
81impl EarlyBinder<RefineParam> {
82    pub fn name(&self) -> Symbol {
83        self.skip_binder_ref().name
84    }
85}
86
87#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
88pub struct Binder<T> {
89    vars: List<BoundVariableKind>,
90    value: T,
91}
92
93impl<T> Binder<T> {
94    pub fn bind_with_vars(value: T, vars: BoundVariableKinds) -> Binder<T> {
95        Binder { vars, value }
96    }
97
98    pub fn dummy(value: T) -> Binder<T> {
99        Binder::bind_with_vars(value, List::empty())
100    }
101
102    pub fn bind_with_sorts(value: T, sorts: &[Sort]) -> Binder<T> {
103        Binder::bind_with_vars(value, sorts.iter().cloned().map_into().collect())
104    }
105
106    pub fn bind_with_sort(value: T, sort: Sort) -> Binder<T> {
107        Binder::bind_with_sorts(value, &[sort])
108    }
109
110    pub fn vars(&self) -> &List<BoundVariableKind> {
111        &self.vars
112    }
113
114    pub fn as_ref(&self) -> Binder<&T> {
115        Binder { vars: self.vars.clone(), value: &self.value }
116    }
117
118    pub fn skip_binder(self) -> T {
119        self.value
120    }
121
122    pub fn skip_binder_ref(&self) -> &T {
123        self.as_ref().skip_binder()
124    }
125
126    pub fn rebind<U>(&self, value: U) -> Binder<U> {
127        Binder { vars: self.vars.clone(), value }
128    }
129
130    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Binder<U> {
131        Binder { vars: self.vars, value: f(self.value) }
132    }
133
134    pub fn map_ref<U>(&self, f: impl FnOnce(&T) -> U) -> Binder<U> {
135        Binder { vars: self.vars.clone(), value: f(&self.value) }
136    }
137
138    pub fn try_map<U, E>(self, f: impl FnOnce(T) -> Result<U, E>) -> Result<Binder<U>, E> {
139        Ok(Binder { vars: self.vars, value: f(self.value)? })
140    }
141
142    #[track_caller]
143    pub fn sort(&self) -> Sort {
144        match &self.vars[..] {
145            [BoundVariableKind::Refine(sort, ..)] => sort.clone(),
146            _ => tracked_span_bug!("expected single-sorted binder"),
147        }
148    }
149}
150
151impl<T> Binder<T>
152where
153    T: TypeFoldable,
154{
155    pub fn replace_bound_vars(
156        &self,
157        mut replace_region: impl FnMut(BoundRegion) -> Region,
158        mut replace_expr: impl FnMut(&Sort, InferMode) -> Expr,
159    ) -> T {
160        let mut exprs = UnordMap::default();
161        let mut regions = UnordMap::default();
162        let delegate = FnMutDelegate::new(
163            |breft| {
164                exprs
165                    .entry(breft.var)
166                    .or_insert_with(|| {
167                        let (sort, mode, _) = self.vars[breft.var.as_usize()].expect_refine();
168                        replace_expr(sort, mode)
169                    })
170                    .clone()
171            },
172            |br| *regions.entry(br.var).or_insert_with(|| replace_region(br)),
173        );
174
175        self.value.fold_with(&mut BoundVarReplacer::new(delegate))
176    }
177
178    pub fn replace_bound_refts(&self, exprs: &[Expr]) -> T {
179        let delegate = FnMutDelegate::new(
180            |breft| exprs[breft.var.as_usize()].clone(),
181            |br| tracked_span_bug!("unexpected escaping region {br:?}"),
182        );
183        self.value.fold_with(&mut BoundVarReplacer::new(delegate))
184    }
185
186    pub fn replace_bound_reft(&self, expr: &Expr) -> T {
187        debug_assert!(matches!(&self.vars[..], [BoundVariableKind::Refine(..)]));
188        self.replace_bound_refts(slice::from_ref(expr))
189    }
190
191    pub fn replace_bound_refts_with(
192        &self,
193        mut f: impl FnMut(&Sort, InferMode, BoundReftKind) -> Expr,
194    ) -> T {
195        let exprs = self
196            .vars
197            .iter()
198            .map(|param| {
199                let (sort, mode, kind) = param.expect_refine();
200                f(sort, mode, kind)
201            })
202            .collect_vec();
203        self.replace_bound_refts(&exprs)
204    }
205}
206
207impl<'tcx, V> ToRustc<'tcx> for Binder<V>
208where
209    V: ToRustc<'tcx, T: rustc_middle::ty::TypeVisitable<TyCtxt<'tcx>>>,
210{
211    type T = rustc_middle::ty::Binder<'tcx, V::T>;
212
213    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
214        let vars = BoundVariableKind::to_rustc(&self.vars, tcx);
215        let value = self.value.to_rustc(tcx);
216        rustc_middle::ty::Binder::bind_with_vars(value, vars)
217    }
218}
219
220#[derive(
221    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
222)]
223pub enum BoundVariableKind {
224    Region(BoundRegionKind),
225    Refine(Sort, InferMode, BoundReftKind),
226}
227
228impl BoundVariableKind {
229    fn expect_refine(&self) -> (&Sort, InferMode, BoundReftKind) {
230        if let BoundVariableKind::Refine(sort, mode, kind) = self {
231            (sort, *mode, *kind)
232        } else {
233            tracked_span_bug!("expected `BoundVariableKind::Refine`")
234        }
235    }
236
237    pub fn expect_sort(&self) -> &Sort {
238        self.expect_refine().0
239    }
240
241    /// Returns `true` if the bound variable kind is [`Refine`].
242    ///
243    /// [`Refine`]: BoundVariableKind::Refine
244    #[must_use]
245    pub fn is_refine(&self) -> bool {
246        matches!(self, Self::Refine(..))
247    }
248
249    // We can't implement [`ToRustc`] on [`List<BoundVariableKind>`] because of coherence so we add
250    // it here
251    fn to_rustc<'tcx>(
252        vars: &[Self],
253        tcx: TyCtxt<'tcx>,
254    ) -> &'tcx rustc_middle::ty::List<rustc_middle::ty::BoundVariableKind> {
255        tcx.mk_bound_variable_kinds_from_iter(vars.iter().flat_map(|kind| {
256            match kind {
257                BoundVariableKind::Region(brk) => {
258                    Some(rustc_middle::ty::BoundVariableKind::Region(*brk))
259                }
260                BoundVariableKind::Refine(..) => None,
261            }
262        }))
263    }
264}
265
266impl From<Sort> for BoundVariableKind {
267    fn from(sort: Sort) -> Self {
268        Self::Refine(sort, InferMode::EVar, BoundReftKind::Annon)
269    }
270}
271
272pub type BoundVariableKinds = List<BoundVariableKind>;
273
274#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Encodable, Decodable)]
275pub enum BoundReftKind {
276    Annon,
277    Named(Symbol),
278}
279
280impl_slice_internable!(BoundVariableKind);