flux_middle/rty/
mod.rs

1//! Defines how flux represents refinement types internally. Definitions in this module are used
2//! during refinement type checking. A couple of important differences between definitions in this
3//! module and in [`crate::fhir`] are:
4//!
//! * Types in this module use de Bruijn indices to represent local binders.
6//! * Data structures are interned so they can be cheaply cloned.
7mod binder;
8pub mod canonicalize;
9mod expr;
10pub mod fold;
11pub mod normalize;
12mod pretty;
13pub mod refining;
14pub mod region_matching;
15pub mod subst;
16use std::{borrow::Cow, cmp::Ordering, fmt, hash::Hash, sync::LazyLock};
17
18pub use binder::{Binder, BoundReftKind, BoundVariableKind, BoundVariableKinds, EarlyBinder};
19use bitflags::bitflags;
20pub use expr::{
21    AggregateKind, AliasReft, BinOp, BoundReft, Constant, Ctor, ESpan, EVid, EarlyReftParam, Expr,
22    ExprKind, FieldProj, HoleKind, InternalFuncKind, KVar, KVid, Lambda, Loc, Name, NameProvenance,
23    Path, PrettyMap, PrettyVar, Real, SpecFuncKind, UnOp, Var,
24};
25pub use flux_arc_interner::List;
26use flux_arc_interner::{Interned, impl_internable, impl_slice_internable};
27use flux_common::{bug, tracked_span_assert_eq, tracked_span_bug};
28use flux_config::OverflowMode;
29use flux_macros::{TypeFoldable, TypeVisitable};
30pub use flux_rustc_bridge::ty::{
31    AliasKind, BoundRegion, BoundRegionKind, BoundVar, Const, ConstKind, ConstVid, DebruijnIndex,
32    EarlyParamRegion, LateParamRegion, LateParamRegionKind,
33    Region::{self, *},
34    RegionVid,
35};
36use flux_rustc_bridge::{
37    ToRustc,
38    mir::{Place, RawPtrKind},
39    ty::{self, GenericArgsExt as _, VariantDef},
40};
41use itertools::Itertools;
42pub use normalize::{FuncInfo, NormalizedDefns, local_deps};
43use refining::Refiner;
44use rustc_abi;
45pub use rustc_abi::{FIRST_VARIANT, VariantIdx};
46use rustc_data_structures::{fx::FxIndexMap, snapshot_map::SnapshotMap, unord::UnordMap};
47use rustc_hir::{LangItem, Safety, def_id::DefId};
48use rustc_index::{IndexSlice, IndexVec, newtype_index};
49use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable, extension};
50pub use rustc_middle::{
51    mir::Mutability,
52    ty::{AdtFlags, ClosureKind, FloatTy, IntTy, ParamConst, ParamTy, ScalarInt, UintTy},
53};
54use rustc_middle::{
55    query::IntoQueryParam,
56    ty::{TyCtxt, fast_reject::SimplifiedType},
57};
58use rustc_span::{DUMMY_SP, Span, Symbol, sym, symbol::kw};
59use rustc_type_ir::Upcast as _;
60pub use rustc_type_ir::{INNERMOST, TyVid};
61
62use self::fold::TypeFoldable;
63pub use crate::fhir::InferMode;
64use crate::{
65    LocalDefId,
66    def_id::{FluxDefId, FluxLocalDefId},
67    fhir::{self, FhirId, FluxOwnerId},
68    global_env::GlobalEnv,
69    pretty::{Pretty, PrettyCx},
70    queries::{QueryErr, QueryResult},
71    rty::subst::SortSubst,
72};
73
/// The definition of the data sort automatically generated for a struct or enum.
///
/// This is a handle around an interned [`AdtSortDefData`]; all accessors read through to that
/// shared data.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtSortDef(Interned<AdtSortDefData>);
77
/// The fields of a single variant of an [`AdtSortDef`], stored as two parallel lists: the field
/// names and the sort of each field.
#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct AdtSortVariant {
    /// The list of field names as declared in the `#[flux::refined_by(...)]` annotation
    field_names: Vec<Symbol>,
    /// The sort of each of the fields. Note that these can contain [sort variables]. Methods used
    /// to access these sorts guarantee they are properly instantiated.
    ///
    /// [sort variables]: Sort::Var
    sorts: List<Sort>,
}
88
89impl AdtSortVariant {
90    pub fn new(fields: Vec<(Symbol, Sort)>) -> Self {
91        let (field_names, sorts) = fields.into_iter().unzip();
92        AdtSortVariant { field_names, sorts: List::from_vec(sorts) }
93    }
94
95    pub fn fields(&self) -> usize {
96        self.sorts.len()
97    }
98
99    pub fn field_names(&self) -> &Vec<Symbol> {
100        &self.field_names
101    }
102
103    pub fn sort_by_field_name(&self, args: &[Sort]) -> FxIndexMap<Symbol, Sort> {
104        std::iter::zip(&self.field_names, &self.sorts.fold_with(&mut SortSubst::new(args)))
105            .map(|(name, sort)| (*name, sort.clone()))
106            .collect()
107    }
108
109    pub fn field_by_name(
110        &self,
111        def_id: DefId,
112        args: &[Sort],
113        name: Symbol,
114    ) -> Option<(FieldProj, Sort)> {
115        let idx = self.field_names.iter().position(|it| name == *it)?;
116        let proj = FieldProj::Adt { def_id, field: idx as u32 };
117        let sort = self.sorts[idx].fold_with(&mut SortSubst::new(args));
118        Some((proj, sort))
119    }
120
121    pub fn field_sorts(&self, args: &[Sort]) -> List<Sort> {
122        self.sorts.fold_with(&mut SortSubst::new(args))
123    }
124
125    pub fn field_sorts_instantiate_identity(&self) -> List<Sort> {
126        self.sorts.clone()
127    }
128
129    pub fn projections(&self, def_id: DefId) -> impl Iterator<Item = FieldProj> {
130        (0..self.fields()).map(move |i| FieldProj::Adt { def_id, field: i as u32 })
131    }
132}
133
/// The interned payload behind an [`AdtSortDef`].
#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
struct AdtSortDefData {
    /// [`DefId`] of the struct or enum this data sort is associated to.
    def_id: DefId,
    /// The list of the type parameters used in the `#[flux::refined_by(..)]` annotation.
    ///
    /// See [`fhir::RefinedBy::sort_params`] for more details. This is a version of that but using
    /// [`ParamTy`] instead of [`DefId`].
    ///
    /// The length of this list corresponds to the number of sort variables bound by this definition.
    params: Vec<ParamTy>,
    /// The variants of the ADT:
    /// - a `struct` sort -- used for types with a `refined_by` annotation -- has a single variant;
    /// - a `reflected` sort -- used for `reflected` enums -- can have multiple variants.
    variants: IndexVec<VariantIdx, AdtSortVariant>,
    /// Value reported by [`AdtSortDef::is_reflected`].
    is_reflected: bool,
    /// Value reported by [`AdtSortDef::is_struct`]; [`AdtSortDef::struct_variant`] asserts it is
    /// `true` before returning the first variant.
    is_struct: bool,
}
152
153impl AdtSortDef {
154    pub fn new(
155        def_id: DefId,
156        params: Vec<ParamTy>,
157        variants: IndexVec<VariantIdx, AdtSortVariant>,
158        is_reflected: bool,
159        is_struct: bool,
160    ) -> Self {
161        Self(Interned::new(AdtSortDefData { def_id, params, variants, is_reflected, is_struct }))
162    }
163
164    pub fn did(&self) -> DefId {
165        self.0.def_id
166    }
167
168    pub fn variant(&self, idx: VariantIdx) -> &AdtSortVariant {
169        &self.0.variants[idx]
170    }
171
172    pub fn variants(&self) -> &IndexSlice<VariantIdx, AdtSortVariant> {
173        &self.0.variants
174    }
175
176    pub fn opt_struct_variant(&self) -> Option<&AdtSortVariant> {
177        if self.is_struct() { Some(self.struct_variant()) } else { None }
178    }
179
180    #[track_caller]
181    pub fn struct_variant(&self) -> &AdtSortVariant {
182        tracked_span_assert_eq!(self.0.is_struct, true);
183        &self.0.variants[FIRST_VARIANT]
184    }
185
186    pub fn is_reflected(&self) -> bool {
187        self.0.is_reflected
188    }
189
190    pub fn is_struct(&self) -> bool {
191        self.0.is_struct
192    }
193
194    pub fn to_sort(&self, args: &[GenericArg]) -> Sort {
195        let sorts = self
196            .filter_generic_args(args)
197            .map(|arg| arg.expect_base().sort())
198            .collect();
199
200        Sort::App(SortCtor::Adt(self.clone()), sorts)
201    }
202
203    /// Given a list of generic args, returns an iterator of the generic arguments that should be
204    /// mapped to sorts for instantiation.
205    pub fn filter_generic_args<'a, A>(&'a self, args: &'a [A]) -> impl Iterator<Item = &'a A> + 'a {
206        self.0.params.iter().map(|p| &args[p.index as usize])
207    }
208
209    pub fn identity_args(&self) -> List<Sort> {
210        (0..self.0.params.len())
211            .map(|i| Sort::Var(ParamSort::from(i)))
212            .collect()
213    }
214
215    /// Gives the number of sort variables bound by this definition.
216    pub fn param_count(&self) -> usize {
217        self.0.params.len()
218    }
219}
220
/// The generic parameters of a definition, split into the parameters inherited from a parent
/// definition and the ones declared by the definition itself.
#[derive(Debug, Clone, Default, Encodable, Decodable)]
pub struct Generics {
    /// The definition whose generics hold the parameters with index below `parent_count`.
    pub parent: Option<DefId>,
    /// Number of parameters that come before `own_params`; see [`Generics::param_at`].
    pub parent_count: usize,
    /// The parameters declared by this definition itself.
    pub own_params: List<GenericParamDef>,
    // NOTE(review): presumably whether the parameters include an implicit `Self` -- confirm
    // against the code constructing this struct.
    pub has_self: bool,
}
228
229impl Generics {
230    pub fn count(&self) -> usize {
231        self.parent_count + self.own_params.len()
232    }
233
234    pub fn own_default_count(&self) -> usize {
235        self.own_params
236            .iter()
237            .filter(|param| {
238                match param.kind {
239                    GenericParamDefKind::Type { has_default }
240                    | GenericParamDefKind::Const { has_default }
241                    | GenericParamDefKind::Base { has_default } => has_default,
242                    GenericParamDefKind::Lifetime => false,
243                }
244            })
245            .count()
246    }
247
248    pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult<GenericParamDef> {
249        if let Some(index) = param_index.checked_sub(self.parent_count) {
250            Ok(self.own_params[index].clone())
251        } else {
252            let parent = self.parent.expect("parent_count > 0 but no parent?");
253            genv.generics_of(parent)?.param_at(param_index, genv)
254        }
255    }
256
257    pub fn const_params(&self, genv: GlobalEnv) -> QueryResult<Vec<(ParamConst, Sort)>> {
258        let mut res = vec![];
259        for i in 0..self.count() {
260            let param = self.param_at(i, genv)?;
261            if let GenericParamDefKind::Const { .. } = param.kind
262                && let Some(sort) = genv.sort_of_def_id(param.def_id)?
263            {
264                let param_const = ParamConst { name: param.name, index: param.index };
265                res.push((param_const, sort));
266            }
267        }
268        Ok(res)
269    }
270}
271
/// The refinement parameters of a definition. The layout parallels [`Generics`]: an optional
/// parent, the number of parameters attributed to it, and the definition's own parameters.
#[derive(Debug, Clone, TyEncodable, TyDecodable)]
pub struct RefinementGenerics {
    pub parent: Option<DefId>,
    pub parent_count: usize,
    pub own_params: List<RefineParam>,
}
278
/// A single refinement parameter: its sort, its name and its [`InferMode`].
#[derive(
    PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct RefineParam {
    pub sort: Sort,
    pub name: Symbol,
    pub mode: InferMode,
}
287
/// The definition of a single generic parameter, as stored in [`Generics::own_params`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
pub struct GenericParamDef {
    /// What kind of parameter this is (type, base, lifetime or const).
    pub kind: GenericParamDefKind,
    pub def_id: DefId,
    /// The parameter's index; [`Generics::const_params`] copies it into [`ParamConst::index`].
    pub index: u32,
    pub name: Symbol,
}
295
/// The kind of a [`GenericParamDef`]. Every kind except `Lifetime` records whether the
/// parameter has a default.
#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
pub enum GenericParamDefKind {
    Type { has_default: bool },
    Base { has_default: bool },
    Lifetime,
    Const { has_default: bool },
}
303
/// The type parameter with index `0` and name `Self`.
pub const SELF_PARAM_TY: ParamTy = ParamTy { index: 0, name: kw::SelfUpper };
305
/// A list of [`Clause`]s attached to a definition, plus an optional parent definition.
#[derive(Debug, Clone, TyEncodable, TyDecodable)]
pub struct GenericPredicates {
    pub parent: Option<DefId>,
    pub predicates: List<Clause>,
}
311
/// A [`ClauseKind`] under a [`Binder`].
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct Clause {
    kind: Binder<ClauseKind>,
}
318
319impl Clause {
320    pub fn new(vars: impl Into<List<BoundVariableKind>>, kind: ClauseKind) -> Self {
321        Clause { kind: Binder::bind_with_vars(kind, vars.into()) }
322    }
323
324    pub fn kind(&self) -> Binder<ClauseKind> {
325        self.kind.clone()
326    }
327
328    fn as_trait_clause(&self) -> Option<Binder<TraitPredicate>> {
329        let clause = self.kind();
330        if let ClauseKind::Trait(trait_clause) = clause.skip_binder_ref() {
331            Some(clause.rebind(trait_clause.clone()))
332        } else {
333            None
334        }
335    }
336
337    pub fn as_projection_clause(&self) -> Option<Binder<ProjectionPredicate>> {
338        let clause = self.kind();
339        if let ClauseKind::Projection(proj_clause) = clause.skip_binder_ref() {
340            Some(clause.rebind(proj_clause.clone()))
341        } else {
342            None
343        }
344    }
345
346    // FIXME(nilehmann) we should deal with the binder in all the places this is used instead of
347    // blindly skipping it here
348    pub fn kind_skipping_binder(&self) -> ClauseKind {
349        self.kind.clone().skip_binder()
350    }
351
352    /// Group `Fn` trait clauses with their corresponding `FnOnce::Output` projection
353    /// predicate. This assumes there's exactly one corresponding projection predicate and will
354    /// crash otherwise.
355    pub fn split_off_fn_trait_clauses(
356        genv: GlobalEnv,
357        clauses: &Clauses,
358    ) -> (Vec<Clause>, Vec<Binder<FnTraitPredicate>>) {
359        let mut fn_trait_clauses = vec![];
360        let mut fn_trait_output_clauses = vec![];
361        let mut rest = vec![];
362        for clause in clauses {
363            if let Some(trait_clause) = clause.as_trait_clause()
364                && let Some(kind) = genv.tcx().fn_trait_kind_from_def_id(trait_clause.def_id())
365            {
366                fn_trait_clauses.push((kind, trait_clause));
367            } else if let Some(proj_clause) = clause.as_projection_clause()
368                && genv.is_fn_output(proj_clause.projection_def_id())
369            {
370                fn_trait_output_clauses.push(proj_clause);
371            } else {
372                rest.push(clause.clone());
373            }
374        }
375        let fn_trait_clauses = fn_trait_clauses
376            .into_iter()
377            .map(|(kind, fn_trait_clause)| {
378                let mut candidates = vec![];
379                for fn_trait_output_clause in &fn_trait_output_clauses {
380                    if fn_trait_output_clause.self_ty() == fn_trait_clause.self_ty() {
381                        candidates.push(fn_trait_output_clause.clone());
382                    }
383                }
384                tracked_span_assert_eq!(candidates.len(), 1);
385                let proj_pred = candidates.pop().unwrap().skip_binder();
386                fn_trait_clause.map(|fn_trait_clause| {
387                    FnTraitPredicate {
388                        kind,
389                        self_ty: fn_trait_clause.self_ty().to_ty(),
390                        tupled_args: fn_trait_clause.trait_ref.args[1].expect_base().to_ty(),
391                        output: proj_pred.term.to_ty(),
392                    }
393                })
394            })
395            .collect_vec();
396        (rest, fn_trait_clauses)
397    }
398}
399
400impl<'tcx> ToRustc<'tcx> for Clause {
401    type T = rustc_middle::ty::Clause<'tcx>;
402
403    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
404        self.kind.to_rustc(tcx).upcast(tcx)
405    }
406}
407
408impl From<Binder<ClauseKind>> for Clause {
409    fn from(kind: Binder<ClauseKind>) -> Self {
410        Clause { kind }
411    }
412}
413
/// An interned list of [`Clause`]s.
pub type Clauses = List<Clause>;
415
/// The different kinds of clauses. Each variant converts to the `rustc_middle::ty::ClauseKind`
/// variant of the same name (see the [`ToRustc`] implementation).
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub enum ClauseKind {
    Trait(TraitPredicate),
    Projection(ProjectionPredicate),
    RegionOutlives(RegionOutlivesPredicate),
    TypeOutlives(TypeOutlivesPredicate),
    ConstArgHasType(Const, Ty),
    UnstableFeature(Symbol),
}
427
428impl<'tcx> ToRustc<'tcx> for ClauseKind {
429    type T = rustc_middle::ty::ClauseKind<'tcx>;
430
431    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
432        match self {
433            ClauseKind::Trait(trait_predicate) => {
434                rustc_middle::ty::ClauseKind::Trait(trait_predicate.to_rustc(tcx))
435            }
436            ClauseKind::Projection(projection_predicate) => {
437                rustc_middle::ty::ClauseKind::Projection(projection_predicate.to_rustc(tcx))
438            }
439            ClauseKind::RegionOutlives(outlives_predicate) => {
440                rustc_middle::ty::ClauseKind::RegionOutlives(outlives_predicate.to_rustc(tcx))
441            }
442            ClauseKind::TypeOutlives(outlives_predicate) => {
443                rustc_middle::ty::ClauseKind::TypeOutlives(outlives_predicate.to_rustc(tcx))
444            }
445            ClauseKind::ConstArgHasType(constant, ty) => {
446                rustc_middle::ty::ClauseKind::ConstArgHasType(
447                    constant.to_rustc(tcx),
448                    ty.to_rustc(tcx),
449                )
450            }
451            ClauseKind::UnstableFeature(sym) => rustc_middle::ty::ClauseKind::UnstableFeature(*sym),
452        }
453    }
454}
455
/// An outlives predicate relating a value of type `T` (first field) and a [`Region`] (second
/// field). It converts field by field to `rustc_middle::ty::OutlivesPredicate`.
#[derive(Eq, PartialEq, Hash, Clone, Debug, TyEncodable, TyDecodable)]
pub struct OutlivesPredicate<T>(pub T, pub Region);

/// An [`OutlivesPredicate`] whose first component is a type.
pub type TypeOutlivesPredicate = OutlivesPredicate<Ty>;
/// An [`OutlivesPredicate`] whose first component is a region.
pub type RegionOutlivesPredicate = OutlivesPredicate<Region>;
461
462impl<'tcx, V: ToRustc<'tcx>> ToRustc<'tcx> for OutlivesPredicate<V> {
463    type T = rustc_middle::ty::OutlivesPredicate<'tcx, V::T>;
464
465    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
466        rustc_middle::ty::OutlivesPredicate(self.0.to_rustc(tcx), self.1.to_rustc(tcx))
467    }
468}
469
/// A trait predicate, consisting only of a [`TraitRef`]. When converted to rustc it always gets
/// positive polarity.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct TraitPredicate {
    pub trait_ref: TraitRef,
}
476
477impl TraitPredicate {
478    fn self_ty(&self) -> SubsetTyCtor {
479        self.trait_ref.self_ty()
480    }
481}
482
483impl<'tcx> ToRustc<'tcx> for TraitPredicate {
484    type T = rustc_middle::ty::TraitPredicate<'tcx>;
485
486    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
487        rustc_middle::ty::TraitPredicate {
488            polarity: rustc_middle::ty::PredicatePolarity::Positive,
489            trait_ref: self.trait_ref.to_rustc(tcx),
490        }
491    }
492}
493
494pub type PolyTraitPredicate = Binder<TraitPredicate>;
495
496impl PolyTraitPredicate {
497    fn def_id(&self) -> DefId {
498        self.skip_binder_ref().trait_ref.def_id
499    }
500
501    fn self_ty(&self) -> Binder<SubsetTyCtor> {
502        self.clone().map(|predicate| predicate.self_ty())
503    }
504}
505
/// A reference to a trait: a [`DefId`] together with generic arguments. The first argument is
/// the self type (see [`TraitRef::self_ty`]).
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct TraitRef {
    pub def_id: DefId,
    pub args: GenericArgs,
}
513
514impl TraitRef {
515    pub fn self_ty(&self) -> SubsetTyCtor {
516        self.args[0].expect_base().clone()
517    }
518}
519
520impl<'tcx> ToRustc<'tcx> for TraitRef {
521    type T = rustc_middle::ty::TraitRef<'tcx>;
522
523    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
524        rustc_middle::ty::TraitRef::new(tcx, self.def_id, self.args.to_rustc(tcx))
525    }
526}
527
528pub type PolyTraitRef = Binder<TraitRef>;
529
530impl PolyTraitRef {
531    pub fn def_id(&self) -> DefId {
532        self.as_ref().skip_binder().def_id
533    }
534}
535
/// A predicate that converts to the `rustc_middle::ty::ExistentialPredicate` variant of the same
/// name (see the [`ToRustc`] implementation).
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum ExistentialPredicate {
    Trait(ExistentialTraitRef),
    Projection(ExistentialProjection),
    AutoTrait(DefId),
}

/// An [`ExistentialPredicate`] under a [`Binder`].
pub type PolyExistentialPredicate = Binder<ExistentialPredicate>;
544
545impl ExistentialPredicate {
546    /// See [`rustc_middle::ty::ExistentialPredicateStableCmpExt`]
547    pub fn stable_cmp(&self, tcx: TyCtxt, other: &Self) -> Ordering {
548        match (self, other) {
549            (ExistentialPredicate::Trait(_), ExistentialPredicate::Trait(_)) => Ordering::Equal,
550            (ExistentialPredicate::Projection(a), ExistentialPredicate::Projection(b)) => {
551                tcx.def_path_hash(a.def_id)
552                    .cmp(&tcx.def_path_hash(b.def_id))
553            }
554            (ExistentialPredicate::AutoTrait(a), ExistentialPredicate::AutoTrait(b)) => {
555                tcx.def_path_hash(*a).cmp(&tcx.def_path_hash(*b))
556            }
557            (ExistentialPredicate::Trait(_), _) => Ordering::Less,
558            (ExistentialPredicate::Projection(_), ExistentialPredicate::Trait(_)) => {
559                Ordering::Greater
560            }
561            (ExistentialPredicate::Projection(_), _) => Ordering::Less,
562            (ExistentialPredicate::AutoTrait(_), _) => Ordering::Greater,
563        }
564    }
565}
566
567impl<'tcx> ToRustc<'tcx> for ExistentialPredicate {
568    type T = rustc_middle::ty::ExistentialPredicate<'tcx>;
569
570    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
571        match self {
572            ExistentialPredicate::Trait(trait_ref) => {
573                let trait_ref = rustc_middle::ty::ExistentialTraitRef::new_from_args(
574                    tcx,
575                    trait_ref.def_id,
576                    trait_ref.args.to_rustc(tcx),
577                );
578                rustc_middle::ty::ExistentialPredicate::Trait(trait_ref)
579            }
580            ExistentialPredicate::Projection(projection) => {
581                rustc_middle::ty::ExistentialPredicate::Projection(
582                    rustc_middle::ty::ExistentialProjection::new_from_args(
583                        tcx,
584                        projection.def_id,
585                        projection.args.to_rustc(tcx),
586                        projection.term.skip_binder_ref().to_rustc(tcx).into(),
587                    ),
588                )
589            }
590            ExistentialPredicate::AutoTrait(def_id) => {
591                rustc_middle::ty::ExistentialPredicate::AutoTrait(*def_id)
592            }
593        }
594    }
595}
596
/// The payload of [`ExistentialPredicate::Trait`]: a [`DefId`] and generic arguments, converted
/// to rustc with `ExistentialTraitRef::new_from_args`.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct ExistentialTraitRef {
    pub def_id: DefId,
    pub args: GenericArgs,
}
604
605pub type PolyExistentialTraitRef = Binder<ExistentialTraitRef>;
606
607impl PolyExistentialTraitRef {
608    pub fn def_id(&self) -> DefId {
609        self.as_ref().skip_binder().def_id
610    }
611}
612
/// The payload of [`ExistentialPredicate::Projection`]: a [`DefId`], generic arguments and the
/// term the projection is equated with.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct ExistentialProjection {
    pub def_id: DefId,
    pub args: GenericArgs,
    pub term: SubsetTyCtor,
}
621
/// A projection predicate: an alias type (`projection_ty`) paired with a `term`.
#[derive(
    PartialEq, Eq, Hash, Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct ProjectionPredicate {
    pub projection_ty: AliasTy,
    pub term: SubsetTyCtor,
}
629
630impl Pretty for ProjectionPredicate {
631    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
632        write!(
633            f,
634            "ProjectionPredicate << projection_ty = {:?}, term = {:?} >>",
635            self.projection_ty, self.term
636        )
637    }
638}
639
640impl ProjectionPredicate {
641    pub fn self_ty(&self) -> SubsetTyCtor {
642        self.projection_ty.self_ty().clone()
643    }
644}
645
646impl<'tcx> ToRustc<'tcx> for ProjectionPredicate {
647    type T = rustc_middle::ty::ProjectionPredicate<'tcx>;
648
649    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
650        rustc_middle::ty::ProjectionPredicate {
651            projection_term: rustc_middle::ty::AliasTerm::new_from_args(
652                tcx,
653                self.projection_ty.def_id,
654                self.projection_ty.args.to_rustc(tcx),
655            ),
656            term: self.term.as_bty_skipping_binder().to_rustc(tcx).into(),
657        }
658    }
659}
660
661pub type PolyProjectionPredicate = Binder<ProjectionPredicate>;
662
663impl PolyProjectionPredicate {
664    pub fn projection_def_id(&self) -> DefId {
665        self.skip_binder_ref().projection_ty.def_id
666    }
667
668    pub fn self_ty(&self) -> Binder<SubsetTyCtor> {
669        self.clone().map(|predicate| predicate.self_ty())
670    }
671}
672
/// An `Fn*` trait clause combined with its output projection, as assembled by
/// [`Clause::split_off_fn_trait_clauses`].
#[derive(
    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct FnTraitPredicate {
    /// The self type of the trait clause.
    pub self_ty: Ty,
    /// The second generic argument of the trait reference; expected to be a tuple (see
    /// [`FnTraitPredicate::fndef_sig`]).
    pub tupled_args: Ty,
    /// The term of the matching output projection predicate.
    pub output: Ty,
    /// Which of the `Fn*` traits the clause refers to.
    pub kind: ClosureKind,
}
682
683impl Pretty for FnTraitPredicate {
684    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
685        write!(
686            f,
687            "self = {:?}, args = {:?}, output = {:?}, kind = {}",
688            self.self_ty, self.tupled_args, self.output, self.kind
689        )
690    }
691}
692
693impl FnTraitPredicate {
694    pub fn fndef_sig(&self) -> FnSig {
695        let inputs = self.tupled_args.expect_tuple().iter().cloned().collect();
696        let ret = self.output.clone().shift_in_escaping(1);
697        let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
698        FnSig::new(
699            Safety::Safe,
700            rustc_abi::ExternAbi::Rust,
701            List::empty(),
702            inputs,
703            output,
704            Expr::ff(),
705            false,
706        )
707    }
708}
709
710pub fn to_closure_sig(
711    tcx: TyCtxt,
712    closure_id: LocalDefId,
713    tys: &[Ty],
714    args: &flux_rustc_bridge::ty::GenericArgs,
715    poly_sig: &PolyFnSig,
716    no_panic: bool,
717) -> PolyFnSig {
718    let closure_args = args.as_closure();
719    let kind_ty = closure_args.kind_ty().to_rustc(tcx);
720    let Some(kind) = kind_ty.to_opt_closure_kind() else {
721        bug!("to_closure_sig: expected closure kind, found {kind_ty:?}");
722    };
723
724    let mut vars = poly_sig.vars().clone().to_vec();
725    let fn_sig = poly_sig.clone().skip_binder();
726    let closure_ty = Ty::closure(closure_id.into(), tys, args, no_panic);
727    let env_ty = match kind {
728        ClosureKind::Fn => {
729            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
730            let br = BoundRegion {
731                var: BoundVar::from_usize(vars.len() - 1),
732                kind: BoundRegionKind::ClosureEnv,
733            };
734            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Not)
735        }
736        ClosureKind::FnMut => {
737            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
738            let br = BoundRegion {
739                var: BoundVar::from_usize(vars.len() - 1),
740                kind: BoundRegionKind::ClosureEnv,
741            };
742            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Mut)
743        }
744        ClosureKind::FnOnce => closure_ty,
745    };
746
747    let inputs = std::iter::once(env_ty)
748        .chain(fn_sig.inputs().iter().cloned())
749        .collect::<Vec<_>>();
750    let output = fn_sig.output().clone();
751
752    let fn_sig = crate::rty::FnSig::new(
753        fn_sig.safety,
754        fn_sig.abi,
755        fn_sig.requires.clone(),
756        inputs.into(),
757        output,
758        if no_panic { crate::rty::Expr::tt() } else { crate::rty::Expr::ff() },
759        false,
760    );
761
762    PolyFnSig::bind_with_vars(fn_sig, List::from(vars))
763}
764
/// Data describing a coroutine obligation: the coroutine's [`DefId`] together with its resume
/// type, upvar types, output type and (rustc-bridge) generic arguments.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct CoroutineObligPredicate {
    pub def_id: DefId,
    pub resume_ty: Ty,
    pub upvar_tys: List<Ty>,
    pub output: Ty,
    pub args: flux_rustc_bridge::ty::GenericArgs,
}
773
/// An associated refinement, identified by its [`FluxDefId`].
#[derive(Copy, Clone, Encodable, Decodable, Hash, PartialEq, Eq)]
pub struct AssocReft {
    pub def_id: FluxDefId,
    // NOTE: Field is used to denote final associated generic refinements on Traits
    pub final_: bool,
    pub span: Span,
}
781
782impl AssocReft {
783    pub fn new(def_id: FluxDefId, final_: bool, span: Span) -> Self {
784        Self { def_id, final_, span }
785    }
786
787    pub fn name(&self) -> Symbol {
788        self.def_id.name()
789    }
790
791    pub fn def_id(&self) -> FluxDefId {
792        self.def_id
793    }
794}
795
/// A list of associated refinements ([`AssocReft`]) that can be searched by id or by name.
#[derive(Clone, Encodable, Decodable)]
pub struct AssocRefinements {
    pub items: List<AssocReft>,
}
800
801impl Default for AssocRefinements {
802    fn default() -> Self {
803        Self { items: List::empty() }
804    }
805}
806
807impl AssocRefinements {
808    pub fn get(&self, assoc_id: FluxDefId) -> AssocReft {
809        *self
810            .items
811            .into_iter()
812            .find(|it| it.def_id == assoc_id)
813            .unwrap_or_else(|| bug!("caller should guarantee existence of associated refinement"))
814    }
815
816    pub fn find(&self, name: Symbol) -> Option<AssocReft> {
817        Some(*self.items.into_iter().find(|it| it.name() == name)?)
818    }
819}
820
/// A sort constructor, i.e., the head of a sort application ([`Sort::App`]).
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum SortCtor {
    Set,
    Map,
    /// The sort generated for a struct or enum (see [`AdtSortDef`]).
    Adt(AdtSortDef),
    /// A sort identified by a [`FluxDefId`].
    User(FluxDefId),
}
828
newtype_index! {
    /// [`ParamSort`] is used for polymorphic sorts (`Set`, `Map`, etc.) and [bit-vector size parameters].
    /// They should occur "bound" under a [`PolyFuncSort`] or an [`AdtSortDef`]. We assume there's a
    /// single binder and a [`ParamSort`] represents a variable as an index into the list of variables
    /// bound by that binder, i.e., the representation doesn't support higher-ranked sorts.
    ///
    /// [bit-vector size parameters]: BvSize::Param
    #[debug_format = "?{}s"]
    #[encodable]
    pub struct ParamSort {}
}
840
newtype_index! {
    /// A *sort* *v*ariable *id*
    #[debug_format = "?{}s"]
    #[encodable]
    pub struct SortVid {}
}
847
848impl ena::unify::UnifyKey for SortVid {
849    type Value = SortVarVal;
850
851    #[inline]
852    fn index(&self) -> u32 {
853        self.as_u32()
854    }
855
856    #[inline]
857    fn from_index(u: u32) -> Self {
858        SortVid::from_u32(u)
859    }
860
861    fn tag() -> &'static str {
862        "SortVid"
863    }
864}
865
bitflags! {
    /// A *sort constraint* is a set of operations a sort must support.
    ///
    /// During sort checking, we accumulate the operations required for each sort variable. As
    /// unification progresses, these constraints become more specific, i.e., a sort must support
    /// more operations to satisfy them.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
    pub struct SortCstr: u16 {
        /// An empty constraint (any sort satisfies it)
        const BOT     = 0b0000000000;
        /// `*`
        const MUL     = 0b0000000001;
        /// `/`
        const DIV     = 0b0000000010;
        /// `%`
        const MOD     = 0b0000000100;
        /// `+`
        const ADD     = 0b0000001000;
        /// `-`
        const SUB     = 0b0000010000;
        /// `|`
        const BIT_OR  = 0b0000100000;
        /// `&`
        const BIT_AND = 0b0001000000;
        /// `<<`
        const BIT_SHL = 0b0010000000;
        /// `>>`
        const BIT_SHR = 0b0100000000;
        /// `^`
        const BIT_XOR = 0b1000000000;

        /// The set of operations supported by all _numeric_ sorts.
        const NUMERIC = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
        /// The set of operations supported by integers.
        const INT = Self::DIV.bits()
            | Self::MUL.bits()
            | Self::MOD.bits()
            | Self::ADD.bits()
            | Self::SUB.bits();
        /// The set of operations supported by reals.
        const REAL = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
        /// The set of operations supported by bit vectors.
        const BITVEC =  Self::DIV.bits()
            | Self::MUL.bits()
            | Self::MOD.bits()
            | Self::ADD.bits()
            | Self::SUB.bits()
            | Self::BIT_OR.bits()
            | Self::BIT_AND.bits()
            | Self::BIT_SHL.bits()
            | Self::BIT_SHR.bits()
            | Self::BIT_XOR.bits();
        /// The set of operations supported by sets.
        const SET = Self::SUB.bits() | Self::BIT_OR.bits() | Self::BIT_AND.bits();
    }
}
922
923impl SortCstr {
924    /// Returns a constraint that only requires the specified binary operation.
925    pub fn from_bin_op(op: fhir::BinOp) -> Self {
926        match op {
927            fhir::BinOp::Add => Self::ADD,
928            fhir::BinOp::Sub => Self::SUB,
929            fhir::BinOp::Mul => Self::MUL,
930            fhir::BinOp::Div => Self::DIV,
931            fhir::BinOp::Mod => Self::MOD,
932            fhir::BinOp::BitAnd => Self::BIT_AND,
933            fhir::BinOp::BitOr => Self::BIT_OR,
934            fhir::BinOp::BitXor => Self::BIT_XOR,
935            fhir::BinOp::BitShl => Self::BIT_SHL,
936            fhir::BinOp::BitShr => Self::BIT_SHR,
937            _ => bug!("{op:?} not supported as a constraint"),
938        }
939    }
940
941    /// Returns whether a sort satisfies this constraint
942    fn satisfy(self, sort: &Sort) -> bool {
943        match sort {
944            Sort::Int => SortCstr::INT.contains(self),
945            Sort::Real => SortCstr::REAL.contains(self),
946            Sort::BitVec(_) => SortCstr::BITVEC.contains(self),
947            Sort::App(SortCtor::Set, _) => SortCstr::SET.contains(self),
948            _ => self == SortCstr::BOT,
949        }
950    }
951}
952
953/// Unification value for sort variables used during sort checking.
954#[derive(Debug, Clone, PartialEq, Eq)]
955pub enum SortVarVal {
956    /// The variable is not yet solved but the solution must satisfy some constraint.
957    Unsolved(SortCstr),
958    /// The variable has been solved to a sort.
959    Solved(Sort),
960}
961
962impl Default for SortVarVal {
963    fn default() -> Self {
964        SortVarVal::Unsolved(SortCstr::BOT)
965    }
966}
967
968impl SortVarVal {
969    pub fn solved_or(&self, sort: &Sort) -> Sort {
970        match self {
971            SortVarVal::Unsolved(_) => sort.clone(),
972            SortVarVal::Solved(sort) => sort.clone(),
973        }
974    }
975
976    pub fn map_solved(&self, f: impl FnOnce(&Sort) -> Sort) -> SortVarVal {
977        match self {
978            SortVarVal::Unsolved(cstr) => SortVarVal::Unsolved(*cstr),
979            SortVarVal::Solved(sort) => SortVarVal::Solved(f(sort)),
980        }
981    }
982}
983
984impl ena::unify::UnifyValue for SortVarVal {
985    type Error = ();
986
987    fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
988        match (value1, value2) {
989            (SortVarVal::Solved(s1), SortVarVal::Solved(s2)) if s1 == s2 => {
990                Ok(SortVarVal::Solved(s1.clone()))
991            }
992            (SortVarVal::Unsolved(a), SortVarVal::Unsolved(b)) => Ok(SortVarVal::Unsolved(*a | *b)),
993            (SortVarVal::Unsolved(v), SortVarVal::Solved(sort))
994            | (SortVarVal::Solved(sort), SortVarVal::Unsolved(v))
995                if v.satisfy(sort) =>
996            {
997                Ok(SortVarVal::Solved(sort.clone()))
998            }
999            _ => Err(()),
1000        }
1001    }
1002}
1003
1004newtype_index! {
1005    /// A *b*it *v*ector *size* *v*variable *id*
1006    #[debug_format = "?{}size"]
1007    #[encodable]
1008    pub struct BvSizeVid {}
1009}
1010
1011impl ena::unify::UnifyKey for BvSizeVid {
1012    type Value = Option<BvSize>;
1013
1014    #[inline]
1015    fn index(&self) -> u32 {
1016        self.as_u32()
1017    }
1018
1019    #[inline]
1020    fn from_index(u: u32) -> Self {
1021        BvSizeVid::from_u32(u)
1022    }
1023
1024    fn tag() -> &'static str {
1025        "BvSizeVid"
1026    }
1027}
1028
1029impl ena::unify::EqUnifyValue for BvSize {}
1030
1031#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
1032pub enum Sort {
1033    Int,
1034    Bool,
1035    Real,
1036    BitVec(BvSize),
1037    Str,
1038    Char,
1039    Loc,
1040    Param(ParamTy),
1041    Tuple(List<Sort>),
1042    Alias(AliasKind, AliasTy),
1043    Func(PolyFuncSort),
1044    App(SortCtor, List<Sort>),
1045    Var(ParamSort),
1046    Infer(SortVid),
1047    RawPtr,
1048    Err,
1049}
1050
/// Classification of a cast between two sorts, as computed by [`Sort::cast_kind`].
pub enum CastKind {
    /// The cast is the identity and can be erased (e.g. int -> int, char -> int).
    Identity,
    /// A cast from bool to int.
    BoolToInt,
    /// A cast whose target is the unit index (e.g. int -> float).
    IntoUnit,
    /// An uninterpreted cast; these are only allowed with an explicit flag.
    Uninterpreted,
}
1061
1062impl Sort {
1063    pub fn tuple(sorts: impl Into<List<Sort>>) -> Self {
1064        Sort::Tuple(sorts.into())
1065    }
1066
1067    pub fn app(ctor: SortCtor, sorts: List<Sort>) -> Self {
1068        Sort::App(ctor, sorts)
1069    }
1070
1071    pub fn unit() -> Self {
1072        Self::tuple(vec![])
1073    }
1074
1075    #[track_caller]
1076    pub fn expect_func(&self) -> &PolyFuncSort {
1077        if let Sort::Func(sort) = self { sort } else { bug!("expected `Sort::Func`") }
1078    }
1079
1080    pub fn is_loc(&self) -> bool {
1081        matches!(self, Sort::Loc)
1082    }
1083
1084    pub fn is_unit(&self) -> bool {
1085        matches!(self, Sort::Tuple(sorts) if sorts.is_empty())
1086    }
1087
1088    pub fn is_unit_adt(&self) -> Option<DefId> {
1089        if let Sort::App(SortCtor::Adt(sort_def), _) = self
1090            && let Some(variant) = sort_def.opt_struct_variant()
1091            && variant.fields() == 0
1092        {
1093            Some(sort_def.did())
1094        } else {
1095            None
1096        }
1097    }
1098
1099    /// Whether the sort is a function with return sort bool
1100    pub fn is_pred(&self) -> bool {
1101        matches!(self, Sort::Func(fsort) if fsort.skip_binders().output().is_bool())
1102    }
1103
1104    /// Returns `true` if the sort is [`Bool`].
1105    ///
1106    /// [`Bool`]: Sort::Bool
1107    #[must_use]
1108    pub fn is_bool(&self) -> bool {
1109        matches!(self, Self::Bool)
1110    }
1111
1112    pub fn cast_kind(self: &Sort, to: &Sort) -> CastKind {
1113        if self == to
1114            || (matches!(self, Sort::Char | Sort::Int) && matches!(to, Sort::Char | Sort::Int))
1115        {
1116            CastKind::Identity
1117        } else if matches!(self, Sort::Bool) && matches!(to, Sort::Int) {
1118            CastKind::BoolToInt
1119        } else if to.is_unit() {
1120            CastKind::IntoUnit
1121        } else {
1122            CastKind::Uninterpreted
1123        }
1124    }
1125
1126    pub fn walk(&self, mut f: impl FnMut(&Sort, &[FieldProj])) {
1127        fn go(sort: &Sort, f: &mut impl FnMut(&Sort, &[FieldProj]), proj: &mut Vec<FieldProj>) {
1128            match sort {
1129                Sort::Tuple(flds) => {
1130                    for (i, sort) in flds.iter().enumerate() {
1131                        proj.push(FieldProj::Tuple { arity: flds.len(), field: i as u32 });
1132                        go(sort, f, proj);
1133                        proj.pop();
1134                    }
1135                }
1136                Sort::App(SortCtor::Adt(sort_def), args) if sort_def.is_struct() => {
1137                    let field_sorts = sort_def.struct_variant().field_sorts(args);
1138                    for (i, sort) in field_sorts.iter().enumerate() {
1139                        proj.push(FieldProj::Adt { def_id: sort_def.did(), field: i as u32 });
1140                        go(sort, f, proj);
1141                        proj.pop();
1142                    }
1143                }
1144                _ => {
1145                    f(sort, proj);
1146                }
1147            }
1148        }
1149        go(self, &mut f, &mut vec![]);
1150    }
1151}
1152
1153/// The size of a [bit-vector]
1154///
1155/// [bit-vector]: Sort::BitVec
1156#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
1157pub enum BvSize {
1158    /// A fixed size
1159    Fixed(u32),
1160    /// A size that has been parameterized, e.g., bound under a [`PolyFuncSort`]
1161    Param(ParamSort),
1162    /// A size that needs to be inferred. Used during sort checking to instantiate bit-vector
1163    /// sizes at call-sites.
1164    Infer(BvSizeVid),
1165}
1166
1167impl rustc_errors::IntoDiagArg for Sort {
1168    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1169        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1170    }
1171}
1172
1173impl rustc_errors::IntoDiagArg for FuncSort {
1174    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1175        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1176    }
1177}
1178
1179#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
1180pub struct FuncSort {
1181    pub inputs_and_output: List<Sort>,
1182}
1183
1184impl FuncSort {
1185    pub fn new(mut inputs: Vec<Sort>, output: Sort) -> Self {
1186        inputs.push(output);
1187        FuncSort { inputs_and_output: List::from_vec(inputs) }
1188    }
1189
1190    pub fn inputs(&self) -> &[Sort] {
1191        &self.inputs_and_output[0..self.inputs_and_output.len() - 1]
1192    }
1193
1194    pub fn output(&self) -> &Sort {
1195        &self.inputs_and_output[self.inputs_and_output.len() - 1]
1196    }
1197
1198    pub fn to_poly(&self) -> PolyFuncSort {
1199        PolyFuncSort::new(List::empty(), self.clone())
1200    }
1201}
1202
1203/// See [`PolyFuncSort`]
1204#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
1205pub enum SortParamKind {
1206    Sort,
1207    BvSize,
1208}
1209
1210/// A polymorphic function sort parametric over [sorts] or [bit-vector sizes].
1211///
1212/// Parameterizing over bit-vector sizes is a bit of a stretch, because smtlib doesn't support full
1213/// parametric reasoning over them. As long as we used functions parameterized over a size monomorphically
1214/// we should be fine. Right now, we can guarantee this, because size parameters are not exposed in
1215/// the surface syntax and they are only used for predefined (interpreted) theory functions.
1216///
1217/// [sorts]: Sort
1218/// [bit-vector sizes]: BvSize::Param
1219#[derive(Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
1220pub struct PolyFuncSort {
1221    /// The list of parameters including sorts and bit vector sizes
1222    params: List<SortParamKind>,
1223    fsort: FuncSort,
1224}
1225
1226impl PolyFuncSort {
1227    pub fn new(params: List<SortParamKind>, fsort: FuncSort) -> Self {
1228        PolyFuncSort { params, fsort }
1229    }
1230
1231    pub fn skip_binders(&self) -> FuncSort {
1232        self.fsort.clone()
1233    }
1234
1235    pub fn instantiate_identity(&self) -> FuncSort {
1236        self.fsort.clone()
1237    }
1238
1239    pub fn expect_mono(&self) -> FuncSort {
1240        assert!(self.params.is_empty());
1241        self.fsort.clone()
1242    }
1243
1244    pub fn params(&self) -> impl ExactSizeIterator<Item = SortParamKind> + '_ {
1245        self.params.iter().copied()
1246    }
1247
1248    pub fn instantiate(&self, args: &[SortArg]) -> FuncSort {
1249        self.fsort.fold_with(&mut SortSubst::new(args))
1250    }
1251}
1252
1253/// An argument for a generic parameter in a [`Sort`] which can be either a generic sort or a
1254/// generic bit-vector size.
1255#[derive(
1256    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
1257)]
1258pub enum SortArg {
1259    Sort(Sort),
1260    BvSize(BvSize),
1261}
1262
1263#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
1264pub enum ConstantInfo {
1265    /// An uninterpreted constant
1266    Uninterpreted,
1267    /// A non-integral constant whose value is specified by the user
1268    Interpreted(Expr, Sort),
1269}
1270
1271#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
1272pub enum StaticInfo {
1273    Unknown,
1274    /// A static item whose type was specified by the user
1275    Known(Ty),
1276}
1277
1278#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
1279pub struct AdtDef(Interned<AdtDefData>);
1280
1281#[derive(Debug, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
1282pub struct AdtDefData {
1283    invariants: Vec<Invariant>,
1284    sort_def: AdtSortDef,
1285    opaque: bool,
1286    rustc: ty::AdtDef,
1287}
1288
1289/// Option-like enum to explicitly mark that we don't have information about an ADT because it was
1290/// annotated with `#[flux::opaque]`. Note that only structs can be marked as opaque.
1291#[derive(Clone, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
1292pub enum Opaqueness<T> {
1293    Opaque,
1294    Transparent(T),
1295}
1296
1297impl<T> Opaqueness<T> {
1298    pub fn map<S>(self, f: impl FnOnce(T) -> S) -> Opaqueness<S> {
1299        match self {
1300            Opaqueness::Opaque => Opaqueness::Opaque,
1301            Opaqueness::Transparent(value) => Opaqueness::Transparent(f(value)),
1302        }
1303    }
1304
1305    pub fn as_ref(&self) -> Opaqueness<&T> {
1306        match self {
1307            Opaqueness::Opaque => Opaqueness::Opaque,
1308            Opaqueness::Transparent(value) => Opaqueness::Transparent(value),
1309        }
1310    }
1311
1312    pub fn as_deref(&self) -> Opaqueness<&T::Target>
1313    where
1314        T: std::ops::Deref,
1315    {
1316        match self {
1317            Opaqueness::Opaque => Opaqueness::Opaque,
1318            Opaqueness::Transparent(value) => Opaqueness::Transparent(value.deref()),
1319        }
1320    }
1321
1322    pub fn ok_or_else<E>(self, err: impl FnOnce() -> E) -> Result<T, E> {
1323        match self {
1324            Opaqueness::Transparent(v) => Ok(v),
1325            Opaqueness::Opaque => Err(err()),
1326        }
1327    }
1328
1329    #[track_caller]
1330    pub fn expect(self, msg: &str) -> T {
1331        match self {
1332            Opaqueness::Transparent(val) => val,
1333            Opaqueness::Opaque => bug!("{}", msg),
1334        }
1335    }
1336
1337    pub fn ok_or_query_err(self, struct_id: DefId) -> Result<T, QueryErr> {
1338        self.ok_or_else(|| QueryErr::OpaqueStruct { struct_id })
1339    }
1340}
1341
1342impl<T, E> Opaqueness<Result<T, E>> {
1343    pub fn transpose(self) -> Result<Opaqueness<T>, E> {
1344        match self {
1345            Opaqueness::Transparent(Ok(x)) => Ok(Opaqueness::Transparent(x)),
1346            Opaqueness::Transparent(Err(e)) => Err(e),
1347            Opaqueness::Opaque => Ok(Opaqueness::Opaque),
1348        }
1349    }
1350}
1351
1352pub static INT_TYS: [IntTy; 6] =
1353    [IntTy::Isize, IntTy::I8, IntTy::I16, IntTy::I32, IntTy::I64, IntTy::I128];
1354pub static UINT_TYS: [UintTy; 6] =
1355    [UintTy::Usize, UintTy::U8, UintTy::U16, UintTy::U32, UintTy::U64, UintTy::U128];
1356
1357#[derive(
1358    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable,
1359)]
1360pub struct Invariant {
1361    // This predicate may have sort variables, but we don't explicitly mark it like in `PolyFuncSort`.
1362    // See comment on `apply` for details.
1363    pred: Binder<Expr>,
1364}
1365
1366impl Invariant {
1367    pub fn new(pred: Binder<Expr>) -> Self {
1368        Self { pred }
1369    }
1370
1371    pub fn apply(&self, idx: &Expr) -> Expr {
1372        // The predicate may have sort variables but we don't explicitly instantiate them. This
1373        // works because within an expression, sort variables can only appear inside the sort
1374        // annotation for a lambda and invariants cannot have lambdas. It remains to instantiate
1375        // variables in the sort of the binder itself, but since we are removing it, we can avoid
1376        // the explicit instantiation. Ultimately, this works because the expression we generate in
1377        // fixpoint doesn't need sort annotations (sorts are re-inferred).
1378        self.pred.replace_bound_reft(idx)
1379    }
1380}
1381
1382pub type PolyVariants = List<Binder<VariantSig>>;
1383pub type PolyVariant = Binder<VariantSig>;
1384
1385#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
1386pub struct VariantSig {
1387    pub adt_def: AdtDef,
1388    pub args: GenericArgs,
1389    pub fields: List<Ty>,
1390    pub idx: Expr,
1391    pub requires: List<Expr>,
1392}
1393
1394pub type PolyFnSig = Binder<FnSig>;
1395
1396#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
1397pub struct FnSig {
1398    pub safety: Safety,
1399    pub abi: rustc_abi::ExternAbi,
1400    pub requires: List<Expr>,
1401    pub inputs: List<Ty>,
1402    pub output: Binder<FnOutput>,
1403    pub no_panic: Expr,
1404    /// was this auto-lifted (or from a spec)
1405    pub lifted: bool,
1406}
1407
1408#[derive(
1409    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
1410)]
1411pub struct FnOutput {
1412    pub ret: Ty,
1413    pub ensures: List<Ensures>,
1414}
1415
1416#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
1417pub enum Ensures {
1418    Type(Path, Ty),
1419    Pred(Expr),
1420}
1421
1422#[derive(Debug, TypeVisitable, TypeFoldable)]
1423pub struct Qualifier {
1424    pub def_id: FluxLocalDefId,
1425    pub body: Binder<Expr>,
1426    pub kind: QualifierKind,
1427}
1428
1429#[derive(Debug, TypeFoldable, TypeVisitable, Copy, Clone)]
1430pub enum QualifierKind {
1431    Global,
1432    Local,
1433    Hint,
1434}
1435
1436/// A `PrimOpProp` is a single property for a primitive operation which
1437/// can be conjoined to get the definition of the [`PrimRel`] for that
1438/// primitive operation.
1439#[derive(Debug, TypeVisitable, TypeFoldable)]
1440pub struct PrimOpProp {
1441    pub def_id: FluxLocalDefId,
1442    pub op: BinOp,
1443    pub body: Binder<Expr>,
1444}
1445
1446#[derive(Debug, TypeVisitable, TypeFoldable)]
1447pub struct PrimRel {
1448    pub body: Binder<Expr>,
1449}
1450
1451pub type TyCtor = Binder<Ty>;
1452
1453impl TyCtor {
1454    pub fn to_ty(&self) -> Ty {
1455        match &self.vars()[..] {
1456            [] => {
1457                return self.skip_binder_ref().shift_out_escaping(1);
1458            }
1459            [BoundVariableKind::Refine(sort, ..)] => {
1460                if sort.is_unit() {
1461                    return self.replace_bound_reft(&Expr::unit());
1462                }
1463                if let Some(def_id) = sort.is_unit_adt() {
1464                    return self.replace_bound_reft(&Expr::unit_struct(def_id));
1465                }
1466            }
1467            _ => {}
1468        }
1469        Ty::exists(self.clone())
1470    }
1471}
1472
1473#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
1474pub struct Ty(Interned<TyKind>);
1475
1476impl Ty {
1477    pub fn kind(&self) -> &TyKind {
1478        &self.0
1479    }
1480
1481    /// Dummy type used for the `Self` of a `TraitRef` created when converting a trait object, and
1482    /// which gets removed in `ExistentialTraitRef`. This type must not appear anywhere in other
1483    /// converted types and must be a valid `rustc` type (i.e., we must be able to call `to_rustc`
1484    /// on it). `TyKind::Infer(TyVid(0))` does the job, with the caveat that we must skip 0 when
1485    /// generating `TyKind::Infer` for "type holes".
1486    pub fn trait_object_dummy_self() -> Ty {
1487        Ty::infer(TyVid::from_u32(0))
1488    }
1489
1490    pub fn dynamic(preds: impl Into<List<Binder<ExistentialPredicate>>>, region: Region) -> Ty {
1491        BaseTy::Dynamic(preds.into(), region).to_ty()
1492    }
1493
1494    pub fn strg_ref(re: Region, path: Path, ty: Ty) -> Ty {
1495        TyKind::StrgRef(re, path, ty).intern()
1496    }
1497
1498    pub fn ptr(pk: impl Into<PtrKind>, path: impl Into<Path>) -> Ty {
1499        TyKind::Ptr(pk.into(), path.into()).intern()
1500    }
1501
1502    pub fn constr(p: impl Into<Expr>, ty: Ty) -> Ty {
1503        TyKind::Constr(p.into(), ty).intern()
1504    }
1505
1506    pub fn uninit() -> Ty {
1507        TyKind::Uninit.intern()
1508    }
1509
1510    pub fn indexed(bty: BaseTy, idx: impl Into<Expr>) -> Ty {
1511        TyKind::Indexed(bty, idx.into()).intern()
1512    }
1513
1514    pub fn exists(ty: Binder<Ty>) -> Ty {
1515        TyKind::Exists(ty).intern()
1516    }
1517
1518    pub fn exists_with_constr(bty: BaseTy, pred: Expr) -> Ty {
1519        let sort = bty.sort();
1520        let ty = Ty::indexed(bty, Expr::nu());
1521        Ty::exists(Binder::bind_with_sort(Ty::constr(pred, ty), sort))
1522    }
1523
1524    pub fn discr(adt_def: AdtDef, place: Place) -> Ty {
1525        TyKind::Discr(adt_def, place).intern()
1526    }
1527
1528    pub fn unit() -> Ty {
1529        Ty::tuple(vec![])
1530    }
1531
1532    pub fn bool() -> Ty {
1533        BaseTy::Bool.to_ty()
1534    }
1535
1536    pub fn int(int_ty: IntTy) -> Ty {
1537        BaseTy::Int(int_ty).to_ty()
1538    }
1539
1540    pub fn uint(uint_ty: UintTy) -> Ty {
1541        BaseTy::Uint(uint_ty).to_ty()
1542    }
1543
1544    pub fn param(param_ty: ParamTy) -> Ty {
1545        TyKind::Param(param_ty).intern()
1546    }
1547
1548    pub fn downcast(
1549        adt: AdtDef,
1550        args: GenericArgs,
1551        ty: Ty,
1552        variant: VariantIdx,
1553        fields: List<Ty>,
1554    ) -> Ty {
1555        TyKind::Downcast(adt, args, ty, variant, fields).intern()
1556    }
1557
1558    pub fn blocked(ty: Ty) -> Ty {
1559        TyKind::Blocked(ty).intern()
1560    }
1561
1562    pub fn str() -> Ty {
1563        BaseTy::Str.to_ty()
1564    }
1565
1566    pub fn char() -> Ty {
1567        BaseTy::Char.to_ty()
1568    }
1569
1570    pub fn float(float_ty: FloatTy) -> Ty {
1571        BaseTy::Float(float_ty).to_ty()
1572    }
1573
1574    pub fn mk_ref(region: Region, ty: Ty, mutbl: Mutability) -> Ty {
1575        BaseTy::Ref(region, ty, mutbl).to_ty()
1576    }
1577
1578    pub fn mk_slice(ty: Ty) -> Ty {
1579        BaseTy::Slice(ty).to_ty()
1580    }
1581
1582    pub fn mk_box(genv: GlobalEnv, deref_ty: Ty, alloc_ty: GenericArg) -> QueryResult<Ty> {
1583        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1584        let adt_def = genv.adt_def(def_id)?;
1585
1586        let args = List::from_arr([GenericArg::Ty(deref_ty), alloc_ty]);
1587
1588        let bty = BaseTy::adt(adt_def, args);
1589        Ok(Ty::indexed(bty, Expr::unit_struct(def_id)))
1590    }
1591
1592    pub fn mk_box_with_default_alloc(genv: GlobalEnv, deref_ty: Ty) -> QueryResult<Ty> {
1593        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1594
1595        let generics = genv.generics_of(def_id)?;
1596        let alloc_ty = genv
1597            .lower_type_of(generics.own_params[1].def_id)?
1598            .skip_binder();
1599        let alloc_ty = Refiner::default_for_item(genv, def_id)?.refine_generic_arg(
1600            &generics.own_params[1],
1601            &flux_rustc_bridge::ty::GenericArg::Ty(alloc_ty),
1602        )?;
1603
1604        Ty::mk_box(genv, deref_ty, alloc_ty)
1605    }
1606
1607    pub fn tuple(tys: impl Into<List<Ty>>) -> Ty {
1608        BaseTy::Tuple(tys.into()).to_ty()
1609    }
1610
1611    pub fn array(ty: Ty, c: Const) -> Ty {
1612        BaseTy::Array(ty, c).to_ty()
1613    }
1614
1615    pub fn closure(
1616        did: DefId,
1617        tys: impl Into<List<Ty>>,
1618        args: &flux_rustc_bridge::ty::GenericArgs,
1619        no_panic: bool,
1620    ) -> Ty {
1621        BaseTy::Closure(did, tys.into(), args.clone(), no_panic).to_ty()
1622    }
1623
1624    pub fn coroutine(
1625        did: DefId,
1626        resume_ty: Ty,
1627        upvar_tys: List<Ty>,
1628        args: flux_rustc_bridge::ty::GenericArgs,
1629    ) -> Ty {
1630        BaseTy::Coroutine(did, resume_ty, upvar_tys, args.clone()).to_ty()
1631    }
1632
1633    pub fn never() -> Ty {
1634        BaseTy::Never.to_ty()
1635    }
1636
1637    pub fn infer(vid: TyVid) -> Ty {
1638        TyKind::Infer(vid).intern()
1639    }
1640
1641    pub fn unconstr(&self) -> (Ty, Expr) {
1642        fn go(this: &Ty, preds: &mut Vec<Expr>) -> Ty {
1643            if let TyKind::Constr(pred, ty) = this.kind() {
1644                preds.push(pred.clone());
1645                go(ty, preds)
1646            } else {
1647                this.clone()
1648            }
1649        }
1650        let mut preds = vec![];
1651        (go(self, &mut preds), Expr::and_from_iter(preds))
1652    }
1653
1654    pub fn unblocked(&self) -> Ty {
1655        match self.kind() {
1656            TyKind::Blocked(ty) => ty.clone(),
1657            _ => self.clone(),
1658        }
1659    }
1660
1661    /// Whether the type is an `int` or a `uint`
1662    pub fn is_integral(&self) -> bool {
1663        self.as_bty_skipping_existentials()
1664            .map(BaseTy::is_integral)
1665            .unwrap_or_default()
1666    }
1667
1668    /// Whether the type is a `bool`
1669    pub fn is_bool(&self) -> bool {
1670        self.as_bty_skipping_existentials()
1671            .map(BaseTy::is_bool)
1672            .unwrap_or_default()
1673    }
1674
1675    /// Whether the type is a `char`
1676    pub fn is_char(&self) -> bool {
1677        self.as_bty_skipping_existentials()
1678            .map(BaseTy::is_char)
1679            .unwrap_or_default()
1680    }
1681
1682    pub fn is_uninit(&self) -> bool {
1683        matches!(self.kind(), TyKind::Uninit)
1684    }
1685
1686    pub fn is_box(&self) -> bool {
1687        self.as_bty_skipping_existentials()
1688            .map(BaseTy::is_box)
1689            .unwrap_or_default()
1690    }
1691
1692    pub fn is_struct(&self) -> bool {
1693        self.as_bty_skipping_existentials()
1694            .map(BaseTy::is_struct)
1695            .unwrap_or_default()
1696    }
1697
1698    pub fn is_array(&self) -> bool {
1699        self.as_bty_skipping_existentials()
1700            .map(BaseTy::is_array)
1701            .unwrap_or_default()
1702    }
1703
1704    pub fn is_slice(&self) -> bool {
1705        self.as_bty_skipping_existentials()
1706            .map(BaseTy::is_slice)
1707            .unwrap_or_default()
1708    }
1709
1710    pub fn as_bty_skipping_existentials(&self) -> Option<&BaseTy> {
1711        match self.kind() {
1712            TyKind::Indexed(bty, _) => Some(bty),
1713            TyKind::Exists(ty) => Some(ty.skip_binder_ref().as_bty_skipping_existentials()?),
1714            TyKind::Constr(_, ty) => ty.as_bty_skipping_existentials(),
1715            _ => None,
1716        }
1717    }
1718
1719    #[track_caller]
1720    pub fn expect_discr(&self) -> (&AdtDef, &Place) {
1721        if let TyKind::Discr(adt_def, place) = self.kind() {
1722            (adt_def, place)
1723        } else {
1724            tracked_span_bug!("expected discr")
1725        }
1726    }
1727
1728    #[track_caller]
1729    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg], &Expr) {
1730        if let TyKind::Indexed(BaseTy::Adt(adt_def, args), idx) = self.kind() {
1731            (adt_def, args, idx)
1732        } else {
1733            tracked_span_bug!("expected adt `{self:?}`")
1734        }
1735    }
1736
1737    #[track_caller]
1738    pub fn expect_tuple(&self) -> &[Ty] {
1739        if let TyKind::Indexed(BaseTy::Tuple(tys), _) = self.kind() {
1740            tys
1741        } else {
1742            tracked_span_bug!("expected tuple found `{self:?}` (kind: `{:?}`)", self.kind())
1743        }
1744    }
1745
1746    pub fn simplify_type(&self) -> Option<SimplifiedType> {
1747        self.as_bty_skipping_existentials()
1748            .and_then(BaseTy::simplify_type)
1749    }
1750}
1751
1752impl<'tcx> ToRustc<'tcx> for Ty {
1753    type T = rustc_middle::ty::Ty<'tcx>;
1754
1755    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
1756        match self.kind() {
1757            TyKind::Indexed(bty, _) => bty.to_rustc(tcx),
1758            TyKind::Exists(ty) => ty.skip_binder_ref().to_rustc(tcx),
1759            TyKind::Constr(_, ty) => ty.to_rustc(tcx),
1760            TyKind::Param(pty) => pty.to_ty(tcx),
1761            TyKind::StrgRef(re, _, ty) => {
1762                rustc_middle::ty::Ty::new_ref(
1763                    tcx,
1764                    re.to_rustc(tcx),
1765                    ty.to_rustc(tcx),
1766                    Mutability::Mut,
1767                )
1768            }
1769            TyKind::Infer(vid) => rustc_middle::ty::Ty::new_var(tcx, *vid),
1770            TyKind::Uninit
1771            | TyKind::Ptr(_, _)
1772            | TyKind::Discr(..)
1773            | TyKind::Downcast(..)
1774            | TyKind::Blocked(_) => bug!("TODO: to_rustc for `{self:?}`"),
1775        }
1776    }
1777}
1778
1779#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
1780pub enum TyKind {
1781    Indexed(BaseTy, Expr),
1782    Exists(Binder<Ty>),
1783    Constr(Expr, Ty),
1784    Uninit,
1785    StrgRef(Region, Path, Ty),
1786    Ptr(PtrKind, Path),
1787    /// This is a bit of a hack. We use this type internally to represent the result of
1788    /// [`Rvalue::Discriminant`] in a way that we can recover the necessary control information
1789    /// when checking a [`match`]. The hack is that we assume the dicriminant remains the same from
1790    /// the creation of this type until we use it in a [`match`].
1791    ///
1792    ///
1793    /// [`Rvalue::Discriminant`]: flux_rustc_bridge::mir::Rvalue::Discriminant
1794    /// [`match`]: flux_rustc_bridge::mir::TerminatorKind::SwitchInt
1795    Discr(AdtDef, Place),
1796    Param(ParamTy),
1797    /// These only arise when you "narrow" an ADT down to a particular variant;
1798    /// either EXPLICITLY in a `match-of`, or IMPLICITLY when you access a field
1799    /// of a struct to "UNPACK" the struct.
1800    Downcast(AdtDef, GenericArgs, Ty, VariantIdx, List<Ty>),
1801    Blocked(Ty),
1802    /// A type that needs to be inferred by matching the signature against a rust signature.
1803    /// [`TyKind::Infer`] appear as an intermediate step during `conv` and should not be present in
1804    /// the final signature.
1805    Infer(TyVid),
1806}
1807
1808#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
1809pub enum PtrKind {
1810    Mut(Region),
1811    Box,
1812}
1813
1814#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
1815pub enum BaseTy {
1816    Int(IntTy),
1817    Uint(UintTy),
1818    Bool,
1819    Str,
1820    Char,
1821    Slice(Ty),
1822    Adt(AdtDef, GenericArgs),
1823    Float(FloatTy),
1824    RawPtr(Ty, Mutability),
1825    RawPtrMetadata(Ty),
1826    Ref(Region, Ty, Mutability),
1827    FnPtr(PolyFnSig),
1828    FnDef(DefId, GenericArgs),
1829    Tuple(List<Ty>),
1830    Alias(AliasKind, AliasTy),
1831    Array(Ty, Const),
1832    Never,
1833    Closure(DefId, /* upvar_tys */ List<Ty>, flux_rustc_bridge::ty::GenericArgs, bool),
1834    Coroutine(
1835        DefId,
1836        /*resume_ty: */ Ty,
1837        /* upvar_tys: */ List<Ty>,
1838        flux_rustc_bridge::ty::GenericArgs,
1839    ),
1840    Dynamic(List<Binder<ExistentialPredicate>>, Region),
1841    Param(ParamTy),
1842    Infer(TyVid),
1843    Foreign(DefId),
1844    Pat,
1845}
1846
1847impl BaseTy {
1848    pub fn opaque(alias_ty: AliasTy) -> BaseTy {
1849        BaseTy::Alias(AliasKind::Opaque, alias_ty)
1850    }
1851
1852    pub fn projection(alias_ty: AliasTy) -> BaseTy {
1853        BaseTy::Alias(AliasKind::Projection, alias_ty)
1854    }
1855
1856    pub fn adt(adt_def: AdtDef, args: GenericArgs) -> BaseTy {
1857        BaseTy::Adt(adt_def, args)
1858    }
1859
1860    pub fn fn_def(def_id: DefId, args: impl Into<GenericArgs>) -> BaseTy {
1861        BaseTy::FnDef(def_id, args.into())
1862    }
1863
1864    pub fn from_primitive_str(s: &str) -> Option<BaseTy> {
1865        match s {
1866            "i8" => Some(BaseTy::Int(IntTy::I8)),
1867            "i16" => Some(BaseTy::Int(IntTy::I16)),
1868            "i32" => Some(BaseTy::Int(IntTy::I32)),
1869            "i64" => Some(BaseTy::Int(IntTy::I64)),
1870            "i128" => Some(BaseTy::Int(IntTy::I128)),
1871            "u8" => Some(BaseTy::Uint(UintTy::U8)),
1872            "u16" => Some(BaseTy::Uint(UintTy::U16)),
1873            "u32" => Some(BaseTy::Uint(UintTy::U32)),
1874            "u64" => Some(BaseTy::Uint(UintTy::U64)),
1875            "u128" => Some(BaseTy::Uint(UintTy::U128)),
1876            "f16" => Some(BaseTy::Float(FloatTy::F16)),
1877            "f32" => Some(BaseTy::Float(FloatTy::F32)),
1878            "f64" => Some(BaseTy::Float(FloatTy::F64)),
1879            "f128" => Some(BaseTy::Float(FloatTy::F128)),
1880            "isize" => Some(BaseTy::Int(IntTy::Isize)),
1881            "usize" => Some(BaseTy::Uint(UintTy::Usize)),
1882            "bool" => Some(BaseTy::Bool),
1883            "char" => Some(BaseTy::Char),
1884            "str" => Some(BaseTy::Str),
1885            _ => None,
1886        }
1887    }
1888
1889    /// If `self` is a primitive, return its [`Symbol`].
1890    pub fn primitive_symbol(&self) -> Option<Symbol> {
1891        match self {
1892            BaseTy::Bool => Some(sym::bool),
1893            BaseTy::Char => Some(sym::char),
1894            BaseTy::Float(f) => {
1895                match f {
1896                    FloatTy::F16 => Some(sym::f16),
1897                    FloatTy::F32 => Some(sym::f32),
1898                    FloatTy::F64 => Some(sym::f64),
1899                    FloatTy::F128 => Some(sym::f128),
1900                }
1901            }
1902            BaseTy::Int(f) => {
1903                match f {
1904                    IntTy::Isize => Some(sym::isize),
1905                    IntTy::I8 => Some(sym::i8),
1906                    IntTy::I16 => Some(sym::i16),
1907                    IntTy::I32 => Some(sym::i32),
1908                    IntTy::I64 => Some(sym::i64),
1909                    IntTy::I128 => Some(sym::i128),
1910                }
1911            }
1912            BaseTy::Uint(f) => {
1913                match f {
1914                    UintTy::Usize => Some(sym::usize),
1915                    UintTy::U8 => Some(sym::u8),
1916                    UintTy::U16 => Some(sym::u16),
1917                    UintTy::U32 => Some(sym::u32),
1918                    UintTy::U64 => Some(sym::u64),
1919                    UintTy::U128 => Some(sym::u128),
1920                }
1921            }
1922            BaseTy::Str => Some(sym::str),
1923            _ => None,
1924        }
1925    }
1926
1927    pub fn is_integral(&self) -> bool {
1928        matches!(self, BaseTy::Int(_) | BaseTy::Uint(_))
1929    }
1930
1931    pub fn is_signed(&self) -> bool {
1932        matches!(self, BaseTy::Int(_))
1933    }
1934
1935    pub fn is_unsigned(&self) -> bool {
1936        matches!(self, BaseTy::Uint(_))
1937    }
1938
1939    pub fn is_float(&self) -> bool {
1940        matches!(self, BaseTy::Float(_))
1941    }
1942
1943    pub fn is_bool(&self) -> bool {
1944        matches!(self, BaseTy::Bool)
1945    }
1946
1947    fn is_struct(&self) -> bool {
1948        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_struct())
1949    }
1950
1951    fn is_array(&self) -> bool {
1952        matches!(self, BaseTy::Array(..))
1953    }
1954
1955    fn is_slice(&self) -> bool {
1956        matches!(self, BaseTy::Slice(..))
1957    }
1958
1959    pub fn is_box(&self) -> bool {
1960        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_box())
1961    }
1962
1963    pub fn is_char(&self) -> bool {
1964        matches!(self, BaseTy::Char)
1965    }
1966
1967    pub fn is_str(&self) -> bool {
1968        matches!(self, BaseTy::Str)
1969    }
1970
1971    pub fn invariants(
1972        &self,
1973        tcx: TyCtxt,
1974        overflow_mode: OverflowMode,
1975    ) -> impl Iterator<Item = Invariant> {
1976        let (invariants, args) = match self {
1977            BaseTy::Adt(adt_def, args) => (adt_def.invariants().skip_binder(), &args[..]),
1978            BaseTy::Uint(uint_ty) => (uint_invariants(*uint_ty, overflow_mode), &[][..]),
1979            BaseTy::Int(int_ty) => (int_invariants(*int_ty, overflow_mode), &[][..]),
1980            BaseTy::Char => (char_invariants(), &[][..]),
1981            BaseTy::Slice(_) => (slice_invariants(overflow_mode), &[][..]),
1982            _ => (&[][..], &[][..]),
1983        };
1984        invariants
1985            .iter()
1986            .map(move |inv| EarlyBinder(inv).instantiate_ref(tcx, args, &[]))
1987    }
1988
1989    pub fn to_ty(&self) -> Ty {
1990        let sort = self.sort();
1991        if sort.is_unit() {
1992            Ty::indexed(self.clone(), Expr::unit())
1993        } else {
1994            Ty::exists(Binder::bind_with_sort(
1995                Ty::indexed(self.shift_in_escaping(1), Expr::nu()),
1996                sort,
1997            ))
1998        }
1999    }
2000
2001    pub fn to_subset_ty_ctor(&self) -> SubsetTyCtor {
2002        let sort = self.sort();
2003        Binder::bind_with_sort(SubsetTy::trivial(self.clone(), Expr::nu()), sort)
2004    }
2005
2006    #[track_caller]
2007    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg]) {
2008        if let BaseTy::Adt(adt_def, args) = self {
2009            (adt_def, args)
2010        } else {
2011            tracked_span_bug!("expected adt `{self:?}`")
2012        }
2013    }
2014
2015    /// A type is an *atom* if it is "self-delimiting", i.e., it has a clear boundary
2016    /// when printed. This is used to avoid unnecessary parenthesis when pretty printing.
2017    pub fn is_atom(&self) -> bool {
2018        // (nilehmann) I'm not sure about this list, please adjust if you get any odd behavior
2019        matches!(
2020            self,
2021            BaseTy::Int(_)
2022                | BaseTy::Uint(_)
2023                | BaseTy::Slice(_)
2024                | BaseTy::Bool
2025                | BaseTy::Char
2026                | BaseTy::Str
2027                | BaseTy::Adt(..)
2028                | BaseTy::Tuple(..)
2029                | BaseTy::Param(_)
2030                | BaseTy::Array(..)
2031                | BaseTy::Never
2032                | BaseTy::Closure(..)
2033                | BaseTy::Coroutine(..)
2034                // opaque alias are atoms the way we print them now, but they won't
2035                // be if we print them as `impl Trait`
2036                | BaseTy::Alias(..)
2037        )
2038    }
2039
2040    /// Similar to [`rustc_infer::infer::canonical::ir::fast_reject::simplify_type`].
2041    ///
2042    /// This implementation is currently incomplete, so it should only be used in contexts
2043    /// where completeness is not required. Currently, it's used to find incoherent
2044    /// implementations when resolving associated constants. In this context, incompleteness
2045    /// is acceptable since the worst case outcome is simply failing to resolve a type-relative
2046    /// constant.
2047    fn simplify_type(&self) -> Option<SimplifiedType> {
2048        match self {
2049            BaseTy::Bool => Some(SimplifiedType::Bool),
2050            BaseTy::Char => Some(SimplifiedType::Char),
2051            BaseTy::Int(int_type) => Some(SimplifiedType::Int(*int_type)),
2052            BaseTy::Uint(uint_type) => Some(SimplifiedType::Uint(*uint_type)),
2053            BaseTy::Float(float_type) => Some(SimplifiedType::Float(*float_type)),
2054            BaseTy::Adt(def, _) => Some(SimplifiedType::Adt(def.did())),
2055            BaseTy::Str => Some(SimplifiedType::Str),
2056            BaseTy::Array(..) => Some(SimplifiedType::Array),
2057            BaseTy::Slice(..) => Some(SimplifiedType::Slice),
2058            BaseTy::RawPtr(_, mutbl) => Some(SimplifiedType::Ptr(*mutbl)),
2059            BaseTy::Ref(_, _, mutbl) => Some(SimplifiedType::Ref(*mutbl)),
2060            BaseTy::FnDef(def_id, _) | BaseTy::Closure(def_id, ..) => {
2061                Some(SimplifiedType::Closure(*def_id))
2062            }
2063            BaseTy::Coroutine(def_id, ..) => Some(SimplifiedType::Coroutine(*def_id)),
2064            BaseTy::Never => Some(SimplifiedType::Never),
2065            BaseTy::Tuple(tys) => Some(SimplifiedType::Tuple(tys.len())),
2066            BaseTy::FnPtr(poly_fn_sig) => {
2067                Some(SimplifiedType::Function(poly_fn_sig.skip_binder_ref().inputs().len()))
2068            }
2069            BaseTy::Foreign(def_id) => Some(SimplifiedType::Foreign(*def_id)),
2070            BaseTy::RawPtrMetadata(_)
2071            | BaseTy::Alias(..)
2072            | BaseTy::Param(_)
2073            | BaseTy::Dynamic(..)
2074            | BaseTy::Infer(_) => None,
2075            BaseTy::Pat => todo!(),
2076        }
2077    }
2078}
2079
2080impl<'tcx> ToRustc<'tcx> for BaseTy {
2081    type T = rustc_middle::ty::Ty<'tcx>;
2082
2083    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2084        use rustc_middle::ty;
2085        match self {
2086            BaseTy::Int(i) => ty::Ty::new_int(tcx, *i),
2087            BaseTy::Uint(i) => ty::Ty::new_uint(tcx, *i),
2088            BaseTy::Param(pty) => pty.to_ty(tcx),
2089            BaseTy::Slice(ty) => ty::Ty::new_slice(tcx, ty.to_rustc(tcx)),
2090            BaseTy::Bool => tcx.types.bool,
2091            BaseTy::Char => tcx.types.char,
2092            BaseTy::Str => tcx.types.str_,
2093            BaseTy::Adt(adt_def, args) => {
2094                let did = adt_def.did();
2095                let adt_def = tcx.adt_def(did);
2096                let args = args.to_rustc(tcx);
2097                ty::Ty::new_adt(tcx, adt_def, args)
2098            }
2099            BaseTy::FnDef(def_id, args) => {
2100                let args = args.to_rustc(tcx);
2101                ty::Ty::new_fn_def(tcx, *def_id, args)
2102            }
2103            BaseTy::Float(f) => ty::Ty::new_float(tcx, *f),
2104            BaseTy::RawPtr(ty, mutbl) => ty::Ty::new_ptr(tcx, ty.to_rustc(tcx), *mutbl),
2105            BaseTy::Ref(re, ty, mutbl) => {
2106                ty::Ty::new_ref(tcx, re.to_rustc(tcx), ty.to_rustc(tcx), *mutbl)
2107            }
2108            BaseTy::FnPtr(poly_sig) => ty::Ty::new_fn_ptr(tcx, poly_sig.to_rustc(tcx)),
2109            BaseTy::Tuple(tys) => {
2110                let ts = tys.iter().map(|ty| ty.to_rustc(tcx)).collect_vec();
2111                ty::Ty::new_tup(tcx, &ts)
2112            }
2113            BaseTy::Alias(kind, alias_ty) => {
2114                ty::Ty::new_alias(tcx, kind.to_rustc(tcx), alias_ty.to_rustc(tcx))
2115            }
2116            BaseTy::Array(ty, n) => {
2117                let ty = ty.to_rustc(tcx);
2118                let n = n.to_rustc(tcx);
2119                ty::Ty::new_array_with_const_len(tcx, ty, n)
2120            }
2121            BaseTy::Never => tcx.types.never,
2122            BaseTy::Closure(did, _, args, _) => ty::Ty::new_closure(tcx, *did, args.to_rustc(tcx)),
2123            BaseTy::Dynamic(exi_preds, re) => {
2124                let preds: Vec<_> = exi_preds
2125                    .iter()
2126                    .map(|pred| pred.to_rustc(tcx))
2127                    .collect_vec();
2128                let preds = tcx.mk_poly_existential_predicates(&preds);
2129                ty::Ty::new_dynamic(tcx, preds, re.to_rustc(tcx))
2130            }
2131            BaseTy::Coroutine(did, _, _, args) => {
2132                ty::Ty::new_coroutine(tcx, *did, args.to_rustc(tcx))
2133            }
2134            BaseTy::Infer(ty_vid) => ty::Ty::new_var(tcx, *ty_vid),
2135            BaseTy::Foreign(def_id) => ty::Ty::new_foreign(tcx, *def_id),
2136            BaseTy::RawPtrMetadata(ty) => {
2137                ty::Ty::new_ptr(
2138                    tcx,
2139                    ty.to_rustc(tcx),
2140                    RawPtrKind::FakeForPtrMetadata.to_mutbl_lossy(),
2141                )
2142            }
2143            BaseTy::Pat => todo!(),
2144        }
2145    }
2146}
2147
/// An alias type: a definition together with the generic and refinement arguments it is applied
/// to. It is the payload of [`BaseTy::Alias`], for both opaque types and projections.
#[derive(
    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct AliasTy {
    /// The definition this alias refers to.
    pub def_id: DefId,
    /// The generic arguments applied to `def_id`. For projections, the first argument is the
    /// self type (see [`AliasTy::self_ty`]).
    pub args: GenericArgs,
    /// Holds the refinement-arguments for opaque-types; empty for projections
    pub refine_args: RefineArgs,
}
2157
2158impl AliasTy {
2159    pub fn new(def_id: DefId, args: GenericArgs, refine_args: RefineArgs) -> Self {
2160        AliasTy { args, refine_args, def_id }
2161    }
2162}
2163
2164/// This methods work only with associated type projections (i.e., no opaque types)
2165impl AliasTy {
2166    pub fn self_ty(&self) -> SubsetTyCtor {
2167        self.args[0].expect_base().clone()
2168    }
2169
2170    pub fn with_self_ty(&self, self_ty: SubsetTyCtor) -> Self {
2171        Self {
2172            def_id: self.def_id,
2173            args: [GenericArg::Base(self_ty)]
2174                .into_iter()
2175                .chain(self.args.iter().skip(1).cloned())
2176                .collect(),
2177            refine_args: self.refine_args.clone(),
2178        }
2179    }
2180}
2181
2182impl<'tcx> ToRustc<'tcx> for AliasTy {
2183    type T = rustc_middle::ty::AliasTy<'tcx>;
2184
2185    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2186        rustc_middle::ty::AliasTy::new(tcx, self.def_id, self.args.to_rustc(tcx))
2187    }
2188}
2189
/// A list of refinement arguments, each one an [`Expr`].
pub type RefineArgs = List<Expr>;
2191
2192#[extension(pub trait RefineArgsExt)]
2193impl RefineArgs {
2194    fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<RefineArgs> {
2195        Self::for_item(genv, def_id, |param, index| {
2196            Ok(Expr::var(Var::EarlyParam(EarlyReftParam {
2197                index: index as u32,
2198                name: param.name(),
2199            })))
2200        })
2201    }
2202
2203    fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult<RefineArgs>
2204    where
2205        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<Expr>,
2206    {
2207        let reft_generics = genv.refinement_generics_of(def_id)?;
2208        let count = reft_generics.count();
2209        let mut args = Vec::with_capacity(count);
2210        reft_generics.fill_item(genv, &mut args, &mut mk)?;
2211        Ok(List::from_vec(args))
2212    }
2213}
2214
/// A type constructor meant to be used as a generic argument of [kind base]. This is just an alias
/// to [`Binder<SubsetTy>`], but we expect the binder to have a single bound variable of the sort of
/// the underlying [base type].
///
/// [kind base]: GenericParamDefKind::Base
/// [base type]: SubsetTy::bty
pub type SubsetTyCtor = Binder<SubsetTy>;
2222
2223impl SubsetTyCtor {
2224    pub fn as_bty_skipping_binder(&self) -> &BaseTy {
2225        &self.as_ref().skip_binder().bty
2226    }
2227
2228    pub fn to_ty(&self) -> Ty {
2229        let sort = self.sort();
2230        if sort.is_unit() {
2231            self.replace_bound_reft(&Expr::unit()).to_ty()
2232        } else if let Some(def_id) = sort.is_unit_adt() {
2233            self.replace_bound_reft(&Expr::unit_struct(def_id)).to_ty()
2234        } else {
2235            Ty::exists(self.as_ref().map(SubsetTy::to_ty))
2236        }
2237    }
2238
2239    pub fn to_ty_ctor(&self) -> TyCtor {
2240        self.as_ref().map(SubsetTy::to_ty)
2241    }
2242}
2243
2244/// A subset type is a simplified version of a type that has the form `{b[e] | p}` where `b` is a
2245/// [`BaseTy`], `e` a refinement index, and `p` a predicate.
2246///
2247/// These are mainly found under a [`Binder`] with a single variable of the base type's sort. This
2248/// can be interpreted as a type constructor or an existential type. For example, under a binder with a
2249/// variable `v` of sort `int`, we can interpret `{i32[v] | v > 0}` as:
2250/// - A lambda `λv:int. {i32[v] | v > 0}` that "constructs" types when applied to ints, or
2251/// - An existential type `∃v:int. {i32[v] | v > 0}`.
2252///
2253/// This second interpretation is the reason we call this a subset type, i.e., the type `∃v. {b[v] | p}`
/// corresponds to the subset of values of type `b` whose index satisfies `p`. These are the types
2255/// written as `B{v: p}` in the surface syntax and correspond to the types supported in other
2256/// refinement type systems like Liquid Haskell (with the difference that we are explicit
2257/// about separating refinements from program values via an index).
2258///
2259/// The main purpose for subset types is to be used as generic arguments of [kind base] when
/// interpreted as type constructors. They have two key properties that make them suitable
2261/// for this:
2262///
2263/// 1. **Syntactic Restriction**: Subset types are syntactically restricted, making it easier to
2264///    relate them structurally (e.g., for subtyping). For instance, given two types `S<λv. T1>` and
///    `S<λv. T2>`, if `T1` and `T2` are subset types, we know they match structurally (at least
///    shallowly). In particular, the syntactic restriction rules out complex types like
2267///    `S<λv. (i32[v], i32[v])>` simplifying some operations.
2268///
2269/// 2. **Eager Canonicalization**: Subset types can be eagerly canonicalized via [*strengthening*]
2270///    during substitution. For example, suppose we have a function:
2271///    ```text
2272///    fn foo<T>(x: T[@a], y: { T[@b] | b == a }) { }
2273///    ```
2274///    If we instantiate `T` with `λv. { i32[v] | v > 0}`, after substitution and applying the
2275///    lambda (the indexing syntax `T[a]` corresponds to an application of the lambda), we get:
2276///    ```text
2277///    fn foo(x: {i32[@a] | a > 0}, y: { { i32[@b] | b > 0 } | b == a }) { }
2278///    ```
2279///    Via *strengthening* we can canonicalize this to
2280///    ```text
2281///    fn foo(x: {i32[@a] | a > 0}, y: { i32[@b] | b == a && b > 0 }) { }
2282///    ```
2283///    As a result, we can guarantee the syntactic restriction through substitution.
2284///
2285/// [kind base]: GenericParamDefKind::Base
2286/// [*strengthening*]: https://arxiv.org/pdf/2010.07763.pdf
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub struct SubsetTy {
    /// The base type `b` in the subset type `{b[e] | p}`.
    ///
    /// **NOTE:** This is mostly going to be under a [`Binder`]. It is not yet clear to me whether
    /// this [`BaseTy`] should be able to mention variables in the binder. In general, in a type
    /// `∃v. {b[e] | p}`, it's fine to mention `v` inside `b`, but since [`SubsetTy`] is meant to
    /// facilitate syntactic manipulation we may want to restrict this.
    pub bty: BaseTy,
    /// The refinement index `e` in the subset type `{b[e] | p}`. This can be an arbitrary expression,
    /// which makes manipulation easier. However, since this is mostly found under a binder, we expect
    /// it to be [`Expr::nu()`].
    pub idx: Expr,
    /// The predicate `p` in the subset type `{b[e] | p}`. It is [`Expr::tt()`] for a subset type
    /// built with [`SubsetTy::trivial`], and [`SubsetTy::to_ty`] omits it altogether when it is
    /// trivially true.
    pub pred: Expr,
}
2303
2304impl SubsetTy {
2305    pub fn new(bty: BaseTy, idx: impl Into<Expr>, pred: impl Into<Expr>) -> Self {
2306        Self { bty, idx: idx.into(), pred: pred.into() }
2307    }
2308
2309    pub fn trivial(bty: BaseTy, idx: impl Into<Expr>) -> Self {
2310        Self::new(bty, idx, Expr::tt())
2311    }
2312
2313    pub fn strengthen(&self, pred: impl Into<Expr>) -> Self {
2314        let this = self.clone();
2315        let pred = Expr::and(this.pred, pred).simplify(&SnapshotMap::default());
2316        Self { bty: this.bty, idx: this.idx, pred }
2317    }
2318
2319    pub fn to_ty(&self) -> Ty {
2320        let bty = self.bty.clone();
2321        if self.pred.is_trivially_true() {
2322            Ty::indexed(bty, &self.idx)
2323        } else {
2324            Ty::constr(&self.pred, Ty::indexed(bty, &self.idx))
2325        }
2326    }
2327}
2328
2329impl<'tcx> ToRustc<'tcx> for SubsetTy {
2330    type T = rustc_middle::ty::Ty<'tcx>;
2331
2332    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::Ty<'tcx> {
2333        self.bty.to_rustc(tcx)
2334    }
2335}
2336
/// A generic argument. There is one variant per kind of generic parameter (see
/// [`GenericArg::from_param_def`]).
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub enum GenericArg {
    /// An argument for a parameter of kind `GenericParamDefKind::Type`.
    Ty(Ty),
    /// An argument for a parameter of kind `GenericParamDefKind::Base`: a type constructor
    /// instead of a plain type.
    Base(SubsetTyCtor),
    /// An argument for a lifetime parameter.
    Lifetime(Region),
    /// An argument for a const parameter.
    Const(Const),
}
2344
2345impl GenericArg {
2346    #[track_caller]
2347    pub fn expect_type(&self) -> &Ty {
2348        if let GenericArg::Ty(ty) = self {
2349            ty
2350        } else {
2351            bug!("expected `rty::GenericArg::Ty`, found `{self:?}`")
2352        }
2353    }
2354
2355    #[track_caller]
2356    pub fn expect_base(&self) -> &SubsetTyCtor {
2357        if let GenericArg::Base(ctor) = self {
2358            ctor
2359        } else {
2360            bug!("expected `rty::GenericArg::Base`, found `{self:?}`")
2361        }
2362    }
2363
2364    pub fn from_param_def(param: &GenericParamDef) -> Self {
2365        match param.kind {
2366            GenericParamDefKind::Type { .. } => {
2367                let param_ty = ParamTy { index: param.index, name: param.name };
2368                GenericArg::Ty(Ty::param(param_ty))
2369            }
2370            GenericParamDefKind::Base { .. } => {
2371                // λv. T[v]
2372                let param_ty = ParamTy { index: param.index, name: param.name };
2373                GenericArg::Base(Binder::bind_with_sort(
2374                    SubsetTy::trivial(BaseTy::Param(param_ty), Expr::nu()),
2375                    Sort::Param(param_ty),
2376                ))
2377            }
2378            GenericParamDefKind::Lifetime => {
2379                let region = EarlyParamRegion { index: param.index, name: param.name };
2380                GenericArg::Lifetime(Region::ReEarlyParam(region))
2381            }
2382            GenericParamDefKind::Const { .. } => {
2383                let param_const = ParamConst { index: param.index, name: param.name };
2384                let kind = ConstKind::Param(param_const);
2385                GenericArg::Const(Const { kind })
2386            }
2387        }
2388    }
2389
2390    /// Creates a `GenericArgs` from the definition of generic parameters, by calling a closure to
2391    /// obtain arg. The closures get to observe the `GenericArgs` as they're being built, which can
2392    /// be used to correctly replace defaults of generic parameters.
2393    pub fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk_kind: F) -> QueryResult<GenericArgs>
2394    where
2395        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2396    {
2397        let defs = genv.generics_of(def_id)?;
2398        let count = defs.count();
2399        let mut args = Vec::with_capacity(count);
2400        Self::fill_item(genv, &mut args, &defs, &mut mk_kind)?;
2401        Ok(List::from_vec(args))
2402    }
2403
2404    pub fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<GenericArgs> {
2405        Self::for_item(genv, def_id, |param, _| GenericArg::from_param_def(param))
2406    }
2407
2408    fn fill_item<F>(
2409        genv: GlobalEnv,
2410        args: &mut Vec<GenericArg>,
2411        generics: &Generics,
2412        mk_kind: &mut F,
2413    ) -> QueryResult<()>
2414    where
2415        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2416    {
2417        if let Some(def_id) = generics.parent {
2418            let parent_generics = genv.generics_of(def_id)?;
2419            Self::fill_item(genv, args, &parent_generics, mk_kind)?;
2420        }
2421        for param in &generics.own_params {
2422            let kind = mk_kind(param, args);
2423            tracked_span_assert_eq!(param.index as usize, args.len());
2424            args.push(kind);
2425        }
2426        Ok(())
2427    }
2428}
2429
2430impl From<TyOrBase> for GenericArg {
2431    fn from(v: TyOrBase) -> Self {
2432        match v {
2433            TyOrBase::Ty(ty) => GenericArg::Ty(ty),
2434            TyOrBase::Base(ctor) => GenericArg::Base(ctor),
2435        }
2436    }
2437}
2438
2439impl<'tcx> ToRustc<'tcx> for GenericArg {
2440    type T = rustc_middle::ty::GenericArg<'tcx>;
2441
2442    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2443        use rustc_middle::ty;
2444        match self {
2445            GenericArg::Ty(ty) => ty::GenericArg::from(ty.to_rustc(tcx)),
2446            GenericArg::Base(ctor) => ty::GenericArg::from(ctor.skip_binder_ref().to_rustc(tcx)),
2447            GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)),
2448            GenericArg::Const(c) => ty::GenericArg::from(c.to_rustc(tcx)),
2449        }
2450    }
2451}
2452
/// A list of [`GenericArg`]s.
pub type GenericArgs = List<GenericArg>;
2454
2455#[extension(pub trait GenericArgsExt)]
2456impl GenericArgs {
2457    #[track_caller]
2458    fn box_args(&self) -> (&Ty, &GenericArg) {
2459        if let [GenericArg::Ty(deref), alloc] = &self[..] {
2460            (deref, alloc)
2461        } else {
2462            bug!("invalid generic arguments for box");
2463        }
2464    }
2465
2466    // We can't implement [`ToRustc`] because of coherence so we add it here
2467    fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::GenericArgsRef<'tcx> {
2468        tcx.mk_args_from_iter(self.iter().map(|arg| arg.to_rustc(tcx)))
2469    }
2470
2471    fn rebase_onto(
2472        &self,
2473        tcx: &TyCtxt,
2474        source_ancestor: DefId,
2475        target_args: &GenericArgs,
2476    ) -> List<GenericArg> {
2477        let defs = tcx.generics_of(source_ancestor);
2478        target_args
2479            .iter()
2480            .chain(self.iter().skip(defs.count()))
2481            .cloned()
2482            .collect()
2483    }
2484}
2485
/// Either a type or a subset type constructor. These are the two alternatives that can be
/// turned into a [`GenericArg::Ty`] or a [`GenericArg::Base`], respectively.
#[derive(Debug)]
pub enum TyOrBase {
    /// A plain type.
    Ty(Ty),
    /// A subset type constructor.
    Base(SubsetTyCtor),
}
2491
2492impl TyOrBase {
2493    pub fn into_ty(self) -> Ty {
2494        match self {
2495            TyOrBase::Ty(ty) => ty,
2496            TyOrBase::Base(ctor) => ctor.to_ty(),
2497        }
2498    }
2499
2500    #[track_caller]
2501    pub fn expect_base(self) -> SubsetTyCtor {
2502        match self {
2503            TyOrBase::Base(ctor) => ctor,
2504            TyOrBase::Ty(_) => tracked_span_bug!("expected `TyOrBase::Base`"),
2505        }
2506    }
2507
2508    pub fn as_base(self) -> Option<SubsetTyCtor> {
2509        match self {
2510            TyOrBase::Base(ctor) => Some(ctor),
2511            TyOrBase::Ty(_) => None,
2512        }
2513    }
2514}
2515
/// Either a type or a type constructor.
#[derive(Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum TyOrCtor {
    /// A plain type.
    Ty(Ty),
    /// A type constructor.
    Ctor(TyCtor),
}
2521
2522impl TyOrCtor {
2523    #[track_caller]
2524    pub fn expect_ctor(self) -> TyCtor {
2525        match self {
2526            TyOrCtor::Ctor(ctor) => ctor,
2527            TyOrCtor::Ty(_) => tracked_span_bug!("expected `TyOrCtor::Ctor`"),
2528        }
2529    }
2530
2531    pub fn expect_subset_ty_ctor(self) -> SubsetTyCtor {
2532        self.expect_ctor().map(|ty| {
2533            if let canonicalize::CanonicalTy::Constr(constr_ty) = ty.shallow_canonicalize()
2534                && let TyKind::Indexed(bty, idx) = constr_ty.ty().kind()
2535                && idx.is_nu()
2536            {
2537                SubsetTy::new(bty.clone(), Expr::nu(), constr_ty.pred())
2538            } else {
2539                tracked_span_bug!()
2540            }
2541        })
2542    }
2543
2544    pub fn to_ty(&self) -> Ty {
2545        match self {
2546            TyOrCtor::Ctor(ctor) => ctor.to_ty(),
2547            TyOrCtor::Ty(ty) => ty.clone(),
2548        }
2549    }
2550}
2551
2552impl From<TyOrBase> for TyOrCtor {
2553    fn from(v: TyOrBase) -> Self {
2554        match v {
2555            TyOrBase::Ty(ty) => TyOrCtor::Ty(ty),
2556            TyOrBase::Base(ctor) => TyOrCtor::Ctor(ctor.to_ty_ctor()),
2557        }
2558    }
2559}
2560
2561impl CoroutineObligPredicate {
2562    pub fn to_poly_fn_sig(&self) -> PolyFnSig {
2563        let vars = vec![];
2564
2565        let resume_ty = &self.resume_ty;
2566        let env_ty = Ty::coroutine(
2567            self.def_id,
2568            resume_ty.clone(),
2569            self.upvar_tys.clone(),
2570            self.args.clone(),
2571        );
2572
2573        let inputs = List::from_arr([env_ty, resume_ty.clone()]);
2574        let output =
2575            Binder::bind_with_vars(FnOutput::new(self.output.clone(), vec![]), List::empty());
2576
2577        PolyFnSig::bind_with_vars(
2578            FnSig::new(
2579                Safety::Safe,
2580                rustc_abi::ExternAbi::RustCall,
2581                List::empty(),
2582                inputs,
2583                output,
2584                Expr::ff(),
2585                false,
2586            ),
2587            List::from(vars),
2588        )
2589    }
2590}
2591
2592impl RefinementGenerics {
2593    pub fn count(&self) -> usize {
2594        self.parent_count + self.own_params.len()
2595    }
2596
2597    pub fn own_count(&self) -> usize {
2598        self.own_params.len()
2599    }
2600}
2601
2602impl EarlyBinder<RefinementGenerics> {
2603    pub fn parent(&self) -> Option<DefId> {
2604        self.skip_binder_ref().parent
2605    }
2606
2607    pub fn parent_count(&self) -> usize {
2608        self.skip_binder_ref().parent_count
2609    }
2610
2611    pub fn count(&self) -> usize {
2612        self.skip_binder_ref().count()
2613    }
2614
2615    pub fn own_count(&self) -> usize {
2616        self.skip_binder_ref().own_count()
2617    }
2618
2619    pub fn own_param_at(&self, index: usize) -> EarlyBinder<RefineParam> {
2620        self.as_ref().map(|this| this.own_params[index].clone())
2621    }
2622
2623    pub fn param_at(
2624        &self,
2625        param_index: usize,
2626        genv: GlobalEnv,
2627    ) -> QueryResult<EarlyBinder<RefineParam>> {
2628        if let Some(index) = param_index.checked_sub(self.parent_count()) {
2629            Ok(self.own_param_at(index))
2630        } else {
2631            let parent = self.parent().expect("parent_count > 0 but no parent?");
2632            genv.refinement_generics_of(parent)?
2633                .param_at(param_index, genv)
2634        }
2635    }
2636
2637    pub fn iter_own_params(&self) -> impl Iterator<Item = EarlyBinder<RefineParam>> + use<'_> {
2638        self.skip_binder_ref()
2639            .own_params
2640            .iter()
2641            .cloned()
2642            .map(EarlyBinder)
2643    }
2644
2645    pub fn fill_item<F, R>(&self, genv: GlobalEnv, vec: &mut Vec<R>, mk: &mut F) -> QueryResult
2646    where
2647        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<R>,
2648    {
2649        if let Some(def_id) = self.parent() {
2650            genv.refinement_generics_of(def_id)?
2651                .fill_item(genv, vec, mk)?;
2652        }
2653        for param in self.iter_own_params() {
2654            vec.push(mk(param, vec.len())?);
2655        }
2656        Ok(())
2657    }
2658}
2659
2660impl EarlyBinder<GenericPredicates> {
2661    pub fn predicates(&self) -> EarlyBinder<List<Clause>> {
2662        EarlyBinder(self.0.predicates.clone())
2663    }
2664}
2665
2666impl EarlyBinder<FuncSort> {
2667    /// See [`subst::GenericsSubstForSort`]
2668    pub fn instantiate_func_sort<E>(
2669        self,
2670        sort_for_param: impl FnMut(ParamTy) -> Result<Sort, E>,
2671    ) -> Result<FuncSort, E> {
2672        self.0.try_fold_with(&mut subst::GenericsSubstFolder::new(
2673            subst::GenericsSubstForSort { sort_for_param },
2674            &[],
2675        ))
2676    }
2677}
2678
2679impl VariantSig {
2680    pub fn new(
2681        adt_def: AdtDef,
2682        args: GenericArgs,
2683        fields: List<Ty>,
2684        idx: Expr,
2685        requires: List<Expr>,
2686    ) -> Self {
2687        VariantSig { adt_def, args, fields, idx, requires }
2688    }
2689
2690    pub fn fields(&self) -> &[Ty] {
2691        &self.fields
2692    }
2693
2694    pub fn ret(&self) -> Ty {
2695        let bty = BaseTy::Adt(self.adt_def.clone(), self.args.clone());
2696        let idx = self.idx.clone();
2697        Ty::indexed(bty, idx)
2698    }
2699}
2700
2701impl FnSig {
2702    pub fn new(
2703        safety: Safety,
2704        abi: rustc_abi::ExternAbi,
2705        requires: List<Expr>,
2706        inputs: List<Ty>,
2707        output: Binder<FnOutput>,
2708        no_panic: Expr,
2709        lifted: bool,
2710    ) -> Self {
2711        FnSig { safety, abi, requires, inputs, output, no_panic, lifted }
2712    }
2713
2714    pub fn requires(&self) -> &[Expr] {
2715        &self.requires
2716    }
2717
2718    pub fn inputs(&self) -> &[Ty] {
2719        &self.inputs
2720    }
2721
2722    pub fn no_panic(&self) -> Expr {
2723        self.no_panic.clone()
2724    }
2725
2726    pub fn output(&self) -> Binder<FnOutput> {
2727        self.output.clone()
2728    }
2729}
2730
2731impl<'tcx> ToRustc<'tcx> for FnSig {
2732    type T = rustc_middle::ty::FnSig<'tcx>;
2733
2734    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2735        tcx.mk_fn_sig(
2736            self.inputs().iter().map(|ty| ty.to_rustc(tcx)),
2737            self.output().as_ref().skip_binder().to_rustc(tcx),
2738            false,
2739            self.safety,
2740            self.abi,
2741        )
2742    }
2743}
2744
2745impl FnOutput {
2746    pub fn new(ret: Ty, ensures: impl Into<List<Ensures>>) -> Self {
2747        Self { ret, ensures: ensures.into() }
2748    }
2749}
2750
2751impl<'tcx> ToRustc<'tcx> for FnOutput {
2752    type T = rustc_middle::ty::Ty<'tcx>;
2753
2754    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2755        self.ret.to_rustc(tcx)
2756    }
2757}
2758
2759impl AdtDef {
2760    pub fn new(
2761        rustc: ty::AdtDef,
2762        sort_def: AdtSortDef,
2763        invariants: Vec<Invariant>,
2764        opaque: bool,
2765    ) -> Self {
2766        AdtDef(Interned::new(AdtDefData { invariants, sort_def, opaque, rustc }))
2767    }
2768
2769    pub fn did(&self) -> DefId {
2770        self.0.rustc.did()
2771    }
2772
2773    pub fn sort_def(&self) -> &AdtSortDef {
2774        &self.0.sort_def
2775    }
2776
2777    pub fn sort(&self, args: &[GenericArg]) -> Sort {
2778        self.sort_def().to_sort(args)
2779    }
2780
2781    pub fn is_box(&self) -> bool {
2782        self.0.rustc.is_box()
2783    }
2784
2785    pub fn is_enum(&self) -> bool {
2786        self.0.rustc.is_enum()
2787    }
2788
2789    pub fn is_struct(&self) -> bool {
2790        self.0.rustc.is_struct()
2791    }
2792
2793    pub fn is_union(&self) -> bool {
2794        self.0.rustc.is_union()
2795    }
2796
2797    pub fn variants(&self) -> &IndexSlice<VariantIdx, VariantDef> {
2798        self.0.rustc.variants()
2799    }
2800
2801    pub fn variant(&self, idx: VariantIdx) -> &VariantDef {
2802        self.0.rustc.variant(idx)
2803    }
2804
2805    pub fn invariants(&self) -> EarlyBinder<&[Invariant]> {
2806        EarlyBinder(&self.0.invariants)
2807    }
2808
2809    pub fn discriminants(&self) -> impl Iterator<Item = (VariantIdx, u128)> + '_ {
2810        self.0.rustc.discriminants()
2811    }
2812
2813    pub fn is_opaque(&self) -> bool {
2814        self.0.opaque
2815    }
2816}
2817
2818impl EarlyBinder<PolyVariant> {
2819    // The field_idx is `Some(i)` when we have the `i`-th field of a `union`, in which case,
2820    // the `inputs` are _just_ the `i`-th type (and not all the types...)
2821    pub fn to_poly_fn_sig(&self, field_idx: Option<crate::FieldIdx>) -> EarlyBinder<PolyFnSig> {
2822        self.as_ref().map(|poly_variant| {
2823            poly_variant.as_ref().map(|variant| {
2824                let ret = variant.ret().shift_in_escaping(1);
2825                let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
2826                let inputs = match field_idx {
2827                    None => variant.fields.clone(),
2828                    Some(i) => List::singleton(variant.fields[i.index()].clone()),
2829                };
2830                FnSig::new(
2831                    Safety::Safe,
2832                    rustc_abi::ExternAbi::Rust,
2833                    variant.requires.clone(),
2834                    inputs,
2835                    output,
2836                    Expr::tt(),
2837                    false,
2838                )
2839            })
2840        })
2841    }
2842}
2843
2844impl TyKind {
2845    fn intern(self) -> Ty {
2846        Ty(Interned::new(self))
2847    }
2848}
2849
2850/// returns the same invariants as for `usize` which is the length of a slice
2851fn slice_invariants(overflow_mode: OverflowMode) -> &'static [Invariant] {
2852    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2853        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2854    });
2855    static OVERFLOW: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2856        [
2857            Invariant {
2858                pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2859            },
2860            Invariant {
2861                pred: Binder::bind_with_sort(
2862                    Expr::le(Expr::nu(), Expr::uint_max(UintTy::Usize)),
2863                    Sort::Int,
2864                ),
2865            },
2866        ]
2867    });
2868    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2869        &*OVERFLOW
2870    } else {
2871        &*DEFAULT
2872    }
2873}
2874
2875fn uint_invariants(uint_ty: UintTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2876    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2877        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2878    });
2879
2880    static OVERFLOW: LazyLock<UnordMap<UintTy, [Invariant; 2]>> = LazyLock::new(|| {
2881        UINT_TYS
2882            .into_iter()
2883            .map(|uint_ty| {
2884                let invariants = [
2885                    Invariant {
2886                        pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2887                    },
2888                    Invariant {
2889                        pred: Binder::bind_with_sort(
2890                            Expr::le(Expr::nu(), Expr::uint_max(uint_ty)),
2891                            Sort::Int,
2892                        ),
2893                    },
2894                ];
2895                (uint_ty, invariants)
2896            })
2897            .collect()
2898    });
2899    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2900        &OVERFLOW[&uint_ty]
2901    } else {
2902        &*DEFAULT
2903    }
2904}
2905
2906fn char_invariants() -> &'static [Invariant] {
2907    static INVARIANTS: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2908        [
2909            Invariant {
2910                pred: Binder::bind_with_sort(
2911                    Expr::le(
2912                        Expr::cast(Sort::Char, Sort::Int, Expr::nu()),
2913                        Expr::constant((char::MAX as u32).into()),
2914                    ),
2915                    Sort::Int,
2916                ),
2917            },
2918            Invariant {
2919                pred: Binder::bind_with_sort(
2920                    Expr::le(Expr::zero(), Expr::cast(Sort::Char, Sort::Int, Expr::nu())),
2921                    Sort::Int,
2922                ),
2923            },
2924        ]
2925    });
2926    &*INVARIANTS
2927}
2928
2929fn int_invariants(int_ty: IntTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2930    static DEFAULT: [Invariant; 0] = [];
2931
2932    static OVERFLOW: LazyLock<UnordMap<IntTy, [Invariant; 2]>> = LazyLock::new(|| {
2933        INT_TYS
2934            .into_iter()
2935            .map(|int_ty| {
2936                let invariants = [
2937                    Invariant {
2938                        pred: Binder::bind_with_sort(
2939                            Expr::ge(Expr::nu(), Expr::int_min(int_ty)),
2940                            Sort::Int,
2941                        ),
2942                    },
2943                    Invariant {
2944                        pred: Binder::bind_with_sort(
2945                            Expr::le(Expr::nu(), Expr::int_max(int_ty)),
2946                            Sort::Int,
2947                        ),
2948                    },
2949                ];
2950                (int_ty, invariants)
2951            })
2952            .collect()
2953    });
2954    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2955        &OVERFLOW[&int_ty]
2956    } else {
2957        &DEFAULT
2958    }
2959}
2960
// Interning support for the types of this module. `AdtDefData` and `TyKind` are the payloads
// handed to `Interned::new` by `AdtDef::new` and `TyKind::intern` respectively.
impl_internable!(AdtDefData, AdtSortDefData, TyKind);
// Element types registered for slice interning.
// NOTE(review): presumably this is what makes `List<T>` available for each `T` listed below —
// confirm against `flux_arc_interner`.
impl_slice_internable!(
    Ty,
    GenericArg,
    Ensures,
    InferMode,
    Sort,
    SortArg,
    GenericParamDef,
    TraitRef,
    Binder<ExistentialPredicate>,
    Clause,
    PolyVariant,
    Invariant,
    RefineParam,
    FluxDefId,
    SortParamKind,
    AssocReft
);
2980
/// Pattern matching an indexed signed-integer type: `TyKind::Indexed(BaseTy::Int(..), ..)`.
///
/// `$int_ty` is the pattern for the integer type and `$idxs` the pattern for the index.
/// Paths are spelled through `$crate` so the expansion resolves to this crate's `rty` types
/// no matter which `TyKind`/`BaseTy` (if any) is in scope at the call site.
#[macro_export]
macro_rules! _Int {
    ($int_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Int($int_ty), $idxs)
    };
}
pub use crate::_Int as Int;
2988
/// Pattern matching an indexed unsigned-integer type: `TyKind::Indexed(BaseTy::Uint(..), ..)`.
///
/// `$uint_ty` is the pattern for the integer type and `$idxs` the pattern for the index.
/// Paths are spelled through `$crate` so the expansion resolves to this crate's `rty` types
/// no matter which `TyKind`/`BaseTy` (if any) is in scope at the call site.
#[macro_export]
macro_rules! _Uint {
    ($uint_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Uint($uint_ty), $idxs)
    };
}
pub use crate::_Uint as Uint;
2996
/// Pattern matching an indexed boolean type: `TyKind::Indexed(BaseTy::Bool, ..)`.
///
/// `$idxs` is the pattern for the index. Paths are spelled through `$crate` so the expansion
/// resolves to this crate's `rty` types no matter which `TyKind`/`BaseTy` (if any) is in
/// scope at the call site.
#[macro_export]
macro_rules! _Bool {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Bool, $idxs)
    };
}
pub use crate::_Bool as Bool;
3004
/// Pattern matching an indexed `char` type: `TyKind::Indexed(BaseTy::Char, ..)`.
///
/// `$idxs` is the pattern for the index. Paths are spelled through `$crate` so the expansion
/// resolves to this crate's `rty` types no matter which `TyKind`/`BaseTy` (if any) is in
/// scope at the call site.
#[macro_export]
macro_rules! _Char {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Char, $idxs)
    };
}
pub use crate::_Char as Char;
3012
/// Pattern matching an indexed reference type: `TyKind::Indexed(BaseTy::Ref(..), _)`.
///
/// The (one or more, comma separated) patterns are forwarded as the arguments of
/// `BaseTy::Ref`; the index is always ignored with `_`.
#[macro_export]
macro_rules! _Ref {
    ($($pats:pat),+ $(,)?) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Ref($($pats),+), _)
    };
}
pub use crate::_Ref as Ref;
3020
/// Side tables recorded for a single owner.
///
/// All tables except `param_sorts` are keyed by [`fhir::ItemLocalId`] and are exposed through
/// [`LocalTableInContext`] / [`LocalTableInContextMut`], which check that the `FhirId` used
/// for a lookup or an insertion belongs to `owner`.
// NOTE(review): the name suggests these are the results of well-formedness checking — confirm.
pub struct WfckResults {
    /// The owner all item-local ids in the tables below are relative to.
    pub owner: FluxOwnerId,
    /// Sorts keyed by parameter id.
    param_sorts: UnordMap<fhir::ParamId, Sort>,
    /// A sort per item-local id, exposed through `bin_op_sorts`/`bin_op_sorts_mut`.
    bin_op_sorts: ItemLocalMap<Sort>,
    /// A list of sort arguments per item-local id, exposed through
    /// `fn_app_sorts`/`fn_app_sorts_mut`.
    fn_app_sorts: ItemLocalMap<List<SortArg>>,
    /// A sequence of coercions per item-local id.
    coercions: ItemLocalMap<Vec<Coercion>>,
    /// A field projection per item-local id.
    field_projs: ItemLocalMap<FieldProj>,
    /// A sort per item-local id, exposed through `node_sorts`/`node_sorts_mut`.
    node_sorts: ItemLocalMap<Sort>,
    /// A `DefId` per item-local id, exposed through `record_ctors`/`record_ctors_mut`.
    record_ctors: ItemLocalMap<DefId>,
}
3031
/// A coercion recorded in [`WfckResults`]; both variants carry a `DefId`.
// NOTE(review): presumably the `DefId` identifies the definition being injected into /
// projected out of — confirm against the code that produces and consumes coercions.
#[derive(Clone, Copy, Debug)]
pub enum Coercion {
    Inject(DefId),
    Project(DefId),
}
3037
/// An unordered map keyed by [`fhir::ItemLocalId`].
pub type ItemLocalMap<T> = UnordMap<fhir::ItemLocalId, T>;
3039
/// Read-only view of an [`ItemLocalMap`] paired with the owner its keys are relative to.
///
/// Lookups take a full `FhirId` and assert that its owner equals `owner` before consulting
/// the map.
#[derive(Debug)]
pub struct LocalTableInContext<'a, T> {
    /// The owner every `FhirId` passed to this table must belong to.
    owner: FluxOwnerId,
    /// The underlying table.
    data: &'a ItemLocalMap<T>,
}
3045
/// Mutable view of an [`ItemLocalMap`] paired with the owner its keys are relative to.
///
/// Insertions take a full `FhirId` and assert that its owner equals `owner` before touching
/// the map.
pub struct LocalTableInContextMut<'a, T> {
    /// The owner every `FhirId` passed to this table must belong to.
    owner: FluxOwnerId,
    /// The underlying table.
    data: &'a mut ItemLocalMap<T>,
}
3050
3051impl WfckResults {
3052    pub fn new(owner: impl Into<FluxOwnerId>) -> Self {
3053        Self {
3054            owner: owner.into(),
3055            param_sorts: UnordMap::default(),
3056            bin_op_sorts: ItemLocalMap::default(),
3057            coercions: ItemLocalMap::default(),
3058            field_projs: ItemLocalMap::default(),
3059            node_sorts: ItemLocalMap::default(),
3060            record_ctors: ItemLocalMap::default(),
3061            fn_app_sorts: ItemLocalMap::default(),
3062        }
3063    }
3064
3065    pub fn param_sorts_mut(&mut self) -> &mut UnordMap<fhir::ParamId, Sort> {
3066        &mut self.param_sorts
3067    }
3068
3069    pub fn param_sorts(&self) -> &UnordMap<fhir::ParamId, Sort> {
3070        &self.param_sorts
3071    }
3072
3073    pub fn bin_op_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3074        LocalTableInContextMut { owner: self.owner, data: &mut self.bin_op_sorts }
3075    }
3076
3077    pub fn fn_app_sorts_mut(&mut self) -> LocalTableInContextMut<'_, List<SortArg>> {
3078        LocalTableInContextMut { owner: self.owner, data: &mut self.fn_app_sorts }
3079    }
3080
3081    pub fn fn_app_sorts(&self) -> LocalTableInContext<'_, List<SortArg>> {
3082        LocalTableInContext { owner: self.owner, data: &self.fn_app_sorts }
3083    }
3084
3085    pub fn bin_op_sorts(&self) -> LocalTableInContext<'_, Sort> {
3086        LocalTableInContext { owner: self.owner, data: &self.bin_op_sorts }
3087    }
3088
3089    pub fn coercions_mut(&mut self) -> LocalTableInContextMut<'_, Vec<Coercion>> {
3090        LocalTableInContextMut { owner: self.owner, data: &mut self.coercions }
3091    }
3092
3093    pub fn coercions(&self) -> LocalTableInContext<'_, Vec<Coercion>> {
3094        LocalTableInContext { owner: self.owner, data: &self.coercions }
3095    }
3096
3097    pub fn field_projs_mut(&mut self) -> LocalTableInContextMut<'_, FieldProj> {
3098        LocalTableInContextMut { owner: self.owner, data: &mut self.field_projs }
3099    }
3100
3101    pub fn field_projs(&self) -> LocalTableInContext<'_, FieldProj> {
3102        LocalTableInContext { owner: self.owner, data: &self.field_projs }
3103    }
3104
3105    pub fn node_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3106        LocalTableInContextMut { owner: self.owner, data: &mut self.node_sorts }
3107    }
3108
3109    pub fn node_sorts(&self) -> LocalTableInContext<'_, Sort> {
3110        LocalTableInContext { owner: self.owner, data: &self.node_sorts }
3111    }
3112
3113    pub fn record_ctors_mut(&mut self) -> LocalTableInContextMut<'_, DefId> {
3114        LocalTableInContextMut { owner: self.owner, data: &mut self.record_ctors }
3115    }
3116
3117    pub fn record_ctors(&self) -> LocalTableInContext<'_, DefId> {
3118        LocalTableInContext { owner: self.owner, data: &self.record_ctors }
3119    }
3120}
3121
3122impl<T> LocalTableInContextMut<'_, T> {
3123    pub fn insert(&mut self, fhir_id: FhirId, value: T) {
3124        tracked_span_assert_eq!(self.owner, fhir_id.owner);
3125        self.data.insert(fhir_id.local_id, value);
3126    }
3127}
3128
3129impl<'a, T> LocalTableInContext<'a, T> {
3130    pub fn get(&self, fhir_id: FhirId) -> Option<&'a T> {
3131        tracked_span_assert_eq!(self.owner, fhir_id.owner);
3132        self.data.get(&fhir_id.local_id)
3133    }
3134}
3135
3136fn can_auto_strong(fn_sig: &PolyFnSig) -> bool {
3137    struct RegionDetector {
3138        has_region: bool,
3139    }
3140
3141    impl fold::TypeFolder for RegionDetector {
3142        fn fold_region(&mut self, re: &Region) -> Region {
3143            self.has_region = true;
3144            *re
3145        }
3146    }
3147    let mut detector = RegionDetector { has_region: false };
3148    fn_sig
3149        .skip_binder_ref()
3150        .output()
3151        .skip_binder_ref()
3152        .ret
3153        .fold_with(&mut detector);
3154
3155    !detector.has_region
3156}
/// Transforms a function signature by converting its mutable reference inputs into strong
/// references with associated `ensures` clauses.
///
/// The signature is returned unchanged if any of the following holds:
///
/// * its return type mentions a region (see `can_auto_strong`),
/// * `def_id` is a closure, or
/// * the signature's `lifted` flag is `false`.
///
/// Mutable references to slices are currently left untouched as well.
///
/// Specifically, given a source function of type
///
/// ```text
/// fn (x: &mut InnerTy) -> bool
/// ```
///
/// by default the above gives us an `rty::FnSig`
///
/// ```text
/// forall<>. fn (x: &mut InnerTy) -> bool
/// ```
///
/// which this function then transforms to
///
/// ```text
/// forall<l0: Loc>. fn (x: &strg<l0:InnerTy>) -> bool ensures l0:InnerTy
/// ```
pub fn auto_strong(
    genv: GlobalEnv,
    def_id: impl IntoQueryParam<DefId>,
    fn_sig: PolyFnSig,
) -> PolyFnSig {
    // TODO(auto-strong): we only *really* need the first check `can_auto_strong` here.
    // The other two skip `auto-strong` as doing it breaks various downstream things
    // that should be fixed.
    if !can_auto_strong(&fn_sig)
        || matches!(genv.def_kind(def_id), rustc_hir::def::DefKind::Closure)
        || !fn_sig.skip_binder_ref().lifted
    {
        return fn_sig;
    }
    // All bound variables introduced below are anonymous.
    let kind = BoundReftKind::Anon;
    // Existing bound variables of the signature; the new location variables are appended
    // after them in step 2.
    let mut vars = fn_sig.vars().to_vec();
    let fn_sig = fn_sig.skip_binder();
    // new list of (bound_var, inner_ty)
    let mut strg_bvars = vec![];
    // new list of input types
    let mut strg_inputs = vec![];
    // 1. Traverse inputs collecting strong locations
    for ty in &fn_sig.inputs {
        let strg_ty = if let TyKind::Indexed(BaseTy::Ref(re, inner_ty, Mutability::Mut), _) =
            ty.kind()
            && !inner_ty.is_slice()
        // TODO(auto-strong): including `slice` breaks `tock` for some reason we should replicate in our own tests...
        {
            // if input is &mut InnerTy create a new bound var `loc` for the strong location;
            // its index comes right after the existing variables and the locations already
            // collected in this loop
            let var = {
                let idx = vars.len() + strg_bvars.len();
                BoundVar::from_usize(idx)
            };
            strg_bvars.push((var, inner_ty.clone()));
            // inputs sit directly under the signature's binder, hence `INNERMOST`
            let loc = Loc::Var(Var::Bound(INNERMOST, BoundReft { var, kind }));
            // and transform to &strg<loc:InnerTy>
            Ty::strg_ref(*re, Path::new(loc, List::empty()), inner_ty.clone())
        } else {
            // else leave input type unchanged
            ty.clone()
        };
        strg_inputs.push(strg_ty);
    }
    // 2. Add bound vars for strong locations (one `Loc`-sorted variable per collected location)
    for _ in 0..strg_bvars.len() {
        vars.push(BoundVariableKind::Refine(Sort::Loc, InferMode::EVar, kind));
    }
    // 3. Add ensures for strong locations
    let output = fn_sig.output.map(|out| {
        let mut ens = out.ensures.to_vec();
        for (var, inner_ty) in strg_bvars {
            // the ensures live under the additional output binder, so both the location and
            // the inner type are shifted in by one level
            let loc = Loc::Var(Var::Bound(INNERMOST.shifted_in(1), BoundReft { var, kind }));
            let path = Path::new(loc, List::empty());
            ens.push(Ensures::Type(path, inner_ty.shift_in_escaping(1)));
        }
        FnOutput { ensures: List::from_vec(ens), ..out }
    });

    // 4. Reconstruct fn sig with new inputs and output and vars
    let fn_sig = FnSig { inputs: List::from_vec(strg_inputs), output, ..fn_sig };
    Binder::bind_with_vars(fn_sig, vars.into())
}