1use std::slice;
2
3pub use flux_arc_interner::{List, impl_slice_internable};
4use flux_common::tracked_span_bug;
5use flux_macros::{TypeFoldable, TypeVisitable};
6use flux_rustc_bridge::{
7 ToRustc,
8 ty::{BoundRegion, Region},
9};
10use itertools::Itertools;
11use rustc_data_structures::unord::UnordMap;
12use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable};
13use rustc_middle::ty::{BoundRegionKind, TyCtxt};
14use rustc_span::Symbol;
15
16use super::{
17 Expr, GenericArg, InferMode, RefineParam, Sort,
18 fold::TypeFoldable,
19 subst::{self, BoundVarReplacer, FnMutDelegate},
20};
21
22#[derive(Clone, Debug, TyEncodable, TyDecodable)]
23pub struct EarlyBinder<T>(pub T);
24
25impl<T> EarlyBinder<T> {
26 pub fn as_ref(&self) -> EarlyBinder<&T> {
27 EarlyBinder(&self.0)
28 }
29
30 pub fn as_deref(&self) -> EarlyBinder<&T::Target>
31 where
32 T: std::ops::Deref,
33 {
34 EarlyBinder(self.0.deref())
35 }
36
37 pub fn map<U>(self, f: impl FnOnce(T) -> U) -> EarlyBinder<U> {
38 EarlyBinder(f(self.0))
39 }
40
41 pub fn try_map<U, E>(self, f: impl FnOnce(T) -> Result<U, E>) -> Result<EarlyBinder<U>, E> {
42 Ok(EarlyBinder(f(self.0)?))
43 }
44
45 pub fn skip_binder(self) -> T {
46 self.0
47 }
48
49 pub fn skip_binder_ref(&self) -> &T {
50 &self.0
51 }
52
53 pub fn instantiate_identity(self) -> T {
54 self.0
55 }
56}
57
58impl<I: IntoIterator> EarlyBinder<I> {
59 pub fn iter_identity(self) -> impl Iterator<Item = I::Item> {
60 self.0.into_iter()
61 }
62}
63
64impl<T: TypeFoldable> EarlyBinder<T> {
65 pub fn instantiate(self, tcx: TyCtxt, args: &[GenericArg], refine_args: &[Expr]) -> T {
66 self.as_ref().instantiate_ref(tcx, args, refine_args)
67 }
68}
69
70impl<T: TypeFoldable> EarlyBinder<&T> {
71 pub fn instantiate_ref(self, tcx: TyCtxt, args: &[GenericArg], refine_args: &[Expr]) -> T {
72 self.0
73 .try_fold_with(&mut subst::GenericsSubstFolder::new(
74 subst::GenericArgsDelegate(args, tcx),
75 refine_args,
76 ))
77 .into_ok()
78 }
79}
80
81impl EarlyBinder<RefineParam> {
82 pub fn name(&self) -> Symbol {
83 self.skip_binder_ref().name
84 }
85}
86
87#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
88pub struct Binder<T> {
89 vars: List<BoundVariableKind>,
90 value: T,
91}
92
93impl<T> Binder<T> {
94 pub fn bind_with_vars(value: T, vars: BoundVariableKinds) -> Binder<T> {
95 Binder { vars, value }
96 }
97
98 pub fn dummy(value: T) -> Binder<T> {
99 Binder::bind_with_vars(value, List::empty())
100 }
101
102 pub fn bind_with_sorts(value: T, sorts: &[Sort]) -> Binder<T> {
103 Binder::bind_with_vars(value, sorts.iter().cloned().map_into().collect())
104 }
105
106 pub fn bind_with_sort(value: T, sort: Sort) -> Binder<T> {
107 Binder::bind_with_sorts(value, &[sort])
108 }
109
110 pub fn vars(&self) -> &List<BoundVariableKind> {
111 &self.vars
112 }
113
114 pub fn as_ref(&self) -> Binder<&T> {
115 Binder { vars: self.vars.clone(), value: &self.value }
116 }
117
118 pub fn skip_binder(self) -> T {
119 self.value
120 }
121
122 pub fn skip_binder_ref(&self) -> &T {
123 self.as_ref().skip_binder()
124 }
125
126 pub fn rebind<U>(&self, value: U) -> Binder<U> {
127 Binder { vars: self.vars.clone(), value }
128 }
129
130 pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Binder<U> {
131 Binder { vars: self.vars, value: f(self.value) }
132 }
133
134 pub fn map_ref<U>(&self, f: impl FnOnce(&T) -> U) -> Binder<U> {
135 Binder { vars: self.vars.clone(), value: f(&self.value) }
136 }
137
138 pub fn try_map<U, E>(self, f: impl FnOnce(T) -> Result<U, E>) -> Result<Binder<U>, E> {
139 Ok(Binder { vars: self.vars, value: f(self.value)? })
140 }
141
142 #[track_caller]
143 pub fn sort(&self) -> Sort {
144 match &self.vars[..] {
145 [BoundVariableKind::Refine(sort, ..)] => sort.clone(),
146 _ => tracked_span_bug!("expected single-sorted binder"),
147 }
148 }
149}
150
151impl<T> Binder<T>
152where
153 T: TypeFoldable,
154{
155 pub fn replace_bound_vars(
156 &self,
157 mut replace_region: impl FnMut(BoundRegion) -> Region,
158 mut replace_expr: impl FnMut(&Sort, InferMode) -> Expr,
159 ) -> T {
160 let mut exprs = UnordMap::default();
161 let mut regions = UnordMap::default();
162 let delegate = FnMutDelegate::new(
163 |breft| {
164 exprs
165 .entry(breft.var)
166 .or_insert_with(|| {
167 let (sort, mode, _) = self.vars[breft.var.as_usize()].expect_refine();
168 replace_expr(sort, mode)
169 })
170 .clone()
171 },
172 |br| *regions.entry(br.var).or_insert_with(|| replace_region(br)),
173 );
174
175 self.value.fold_with(&mut BoundVarReplacer::new(delegate))
176 }
177
178 pub fn replace_bound_refts(&self, exprs: &[Expr]) -> T {
179 let delegate = FnMutDelegate::new(
180 |breft| exprs[breft.var.as_usize()].clone(),
181 |br| tracked_span_bug!("unexpected escaping region {br:?}"),
182 );
183 self.value.fold_with(&mut BoundVarReplacer::new(delegate))
184 }
185
186 pub fn replace_bound_reft(&self, expr: &Expr) -> T {
187 debug_assert!(matches!(&self.vars[..], [BoundVariableKind::Refine(..)]));
188 self.replace_bound_refts(slice::from_ref(expr))
189 }
190
191 pub fn replace_bound_refts_with(
192 &self,
193 mut f: impl FnMut(&Sort, InferMode, BoundReftKind) -> Expr,
194 ) -> T {
195 let exprs = self
196 .vars
197 .iter()
198 .map(|param| {
199 let (sort, mode, kind) = param.expect_refine();
200 f(sort, mode, kind)
201 })
202 .collect_vec();
203 self.replace_bound_refts(&exprs)
204 }
205}
206
207impl<'tcx, V> ToRustc<'tcx> for Binder<V>
208where
209 V: ToRustc<'tcx, T: rustc_middle::ty::TypeVisitable<TyCtxt<'tcx>>>,
210{
211 type T = rustc_middle::ty::Binder<'tcx, V::T>;
212
213 fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
214 let vars = BoundVariableKind::to_rustc(&self.vars, tcx);
215 let value = self.value.to_rustc(tcx);
216 rustc_middle::ty::Binder::bind_with_vars(value, vars)
217 }
218}
219
220#[derive(
221 Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
222)]
223pub enum BoundVariableKind {
224 Region(BoundRegionKind),
225 Refine(Sort, InferMode, BoundReftKind),
226}
227
228impl BoundVariableKind {
229 fn expect_refine(&self) -> (&Sort, InferMode, BoundReftKind) {
230 if let BoundVariableKind::Refine(sort, mode, kind) = self {
231 (sort, *mode, *kind)
232 } else {
233 tracked_span_bug!("expected `BoundVariableKind::Refine`")
234 }
235 }
236
237 pub fn expect_sort(&self) -> &Sort {
238 self.expect_refine().0
239 }
240
241 #[must_use]
245 pub fn is_refine(&self) -> bool {
246 matches!(self, Self::Refine(..))
247 }
248
249 fn to_rustc<'tcx>(
252 vars: &[Self],
253 tcx: TyCtxt<'tcx>,
254 ) -> &'tcx rustc_middle::ty::List<rustc_middle::ty::BoundVariableKind> {
255 tcx.mk_bound_variable_kinds_from_iter(vars.iter().flat_map(|kind| {
256 match kind {
257 BoundVariableKind::Region(brk) => {
258 Some(rustc_middle::ty::BoundVariableKind::Region(*brk))
259 }
260 BoundVariableKind::Refine(..) => None,
261 }
262 }))
263 }
264}
265
266impl From<Sort> for BoundVariableKind {
267 fn from(sort: Sort) -> Self {
268 Self::Refine(sort, InferMode::EVar, BoundReftKind::Annon)
269 }
270}
271
272pub type BoundVariableKinds = List<BoundVariableKind>;
273
274#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Encodable, Decodable)]
275pub enum BoundReftKind {
276 Annon,
277 Named(Symbol),
278}
279
// Registers `BoundVariableKind` with `flux_arc_interner`'s slice interning.
// Presumably this is what lets `List<BoundVariableKind>` (`BoundVariableKinds`) be
// built and interned; confirm against the macro definition.
impl_slice_internable!(BoundVariableKind);