flux_middle/rty/
mod.rs

1//! Defines how flux represents refinement types internally. Definitions in this module are used
2//! during refinement type checking. A couple of important differences between definitions in this
3//! module and in [`crate::fhir`] are:
4//!
//! * Types in this module use de Bruijn indices to represent local binders.
6//! * Data structures are interned so they can be cheaply cloned.
7mod binder;
8pub mod canonicalize;
9mod expr;
10pub mod fold;
11pub mod normalize;
12mod pretty;
13pub mod refining;
14pub mod region_matching;
15pub mod subst;
16use std::{borrow::Cow, cmp::Ordering, fmt, hash::Hash, sync::LazyLock};
17
18pub use binder::{Binder, BoundReftKind, BoundVariableKind, BoundVariableKinds, EarlyBinder};
19use bitflags::bitflags;
20pub use expr::{
21    AggregateKind, AliasReft, BinOp, BoundReft, Constant, Ctor, ESpan, EVid, EarlyReftParam, Expr,
22    ExprKind, FieldProj, HoleKind, InternalFuncKind, KVar, KVid, Lambda, Loc, Name, NameProvenance,
23    Path, PrettyMap, PrettyVar, QuantDom, RawPtrField, Real, SpecFuncKind, UnOp, Var,
24};
25pub use flux_arc_interner::List;
26use flux_arc_interner::{Interned, impl_internable, impl_slice_internable};
27use flux_common::{bug, tracked_span_assert_eq, tracked_span_bug};
28use flux_config::OverflowMode;
29use flux_macros::{TypeFoldable, TypeVisitable};
30pub use flux_rustc_bridge::ty::{
31    AliasKind, BoundRegion, BoundRegionKind, BoundVar, Const, ConstKind, ConstVid, DebruijnIndex,
32    EarlyParamRegion, LateParamRegion, LateParamRegionKind,
33    Region::{self, *},
34    RegionVid,
35};
36use flux_rustc_bridge::{
37    ToRustc,
38    mir::{Place, RawPtrKind},
39    ty::{self, GenericArgsExt as _, VariantDef},
40};
41use itertools::Itertools;
42pub use normalize::{FuncInfo, NormalizedDefns, local_deps};
43use refining::Refiner;
44use rustc_abi;
45pub use rustc_abi::{FIRST_VARIANT, VariantIdx};
46use rustc_data_structures::{fx::FxIndexMap, snapshot_map::SnapshotMap, unord::UnordMap};
47use rustc_hir::{LangItem, Safety, def_id::DefId};
48use rustc_index::{IndexSlice, IndexVec, newtype_index};
49use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable, extension};
50pub use rustc_middle::{
51    mir::Mutability,
52    ty::{AdtFlags, ClosureKind, FloatTy, IntTy, ParamConst, ParamTy, ScalarInt, UintTy},
53};
54use rustc_middle::{
55    query::IntoQueryParam,
56    ty::{TyCtxt, fast_reject::SimplifiedType},
57};
58use rustc_span::{DUMMY_SP, Span, Symbol, sym, symbol::kw};
59use rustc_type_ir::Upcast as _;
60pub use rustc_type_ir::{INNERMOST, TyVid};
61
62use self::fold::TypeFoldable;
63pub use crate::fhir::InferMode;
64use crate::{
65    LocalDefId,
66    def_id::{FluxDefId, FluxLocalDefId},
67    fhir::{self, FhirId, FluxOwnerId},
68    global_env::GlobalEnv,
69    pretty::{Pretty, PrettyCx},
70    queries::{QueryErr, QueryResult},
71    rty::subst::SortSubst,
72};
73
74/// The definition of the data sort automatically generated for a struct or enum.
75#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
76pub struct AdtSortDef(Interned<AdtSortDefData>);
77
78#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
79pub struct AdtSortVariant {
80    /// The list of field names as declared in the `#[flux::refined_by(...)]` annotation
81    field_names: Vec<Symbol>,
82    /// The sort of each of the fields. Note that these can contain [sort variables]. Methods used
83    /// to access these sorts guarantee they are properly instantiated.
84    ///
85    /// [sort variables]: Sort::Var
86    sorts: List<Sort>,
87}
88
89impl AdtSortVariant {
90    pub fn new(fields: Vec<(Symbol, Sort)>) -> Self {
91        let (field_names, sorts) = fields.into_iter().unzip();
92        AdtSortVariant { field_names, sorts: List::from_vec(sorts) }
93    }
94
95    pub fn fields(&self) -> usize {
96        self.sorts.len()
97    }
98
99    pub fn field_names(&self) -> &Vec<Symbol> {
100        &self.field_names
101    }
102
103    pub fn sort_by_field_name(&self, args: &[Sort]) -> FxIndexMap<Symbol, Sort> {
104        std::iter::zip(&self.field_names, &self.sorts.fold_with(&mut SortSubst::new(args)))
105            .map(|(name, sort)| (*name, sort.clone()))
106            .collect()
107    }
108
109    pub fn field_by_name(
110        &self,
111        def_id: DefId,
112        args: &[Sort],
113        name: Symbol,
114    ) -> Option<(FieldProj, Sort)> {
115        let idx = self.field_names.iter().position(|it| name == *it)?;
116        let proj = FieldProj::Adt { def_id, field: idx as u32 };
117        let sort = self.sorts[idx].fold_with(&mut SortSubst::new(args));
118        Some((proj, sort))
119    }
120
121    pub fn field_sorts(&self, args: &[Sort]) -> List<Sort> {
122        self.sorts.fold_with(&mut SortSubst::new(args))
123    }
124
125    pub fn field_sorts_instantiate_identity(&self) -> List<Sort> {
126        self.sorts.clone()
127    }
128
129    pub fn projections(&self, def_id: DefId) -> impl Iterator<Item = FieldProj> {
130        (0..self.fields()).map(move |i| FieldProj::Adt { def_id, field: i as u32 })
131    }
132}
133
134#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
135struct AdtSortDefData {
136    /// [`DefId`] of the struct or enum this data sort is associated to.
137    def_id: DefId,
138    /// The list of the type parameters used in the `#[flux::refined_by(..)]` annotation.
139    ///
140    /// See [`fhir::RefinedBy::sort_params`] for more details. This is a version of that but using
141    /// [`ParamTy`] instead of [`DefId`].
142    ///
143    /// The length of this list corresponds to the number of sort variables bound by this definition.
144    params: Vec<ParamTy>,
145    /// A vec of variants of the ADT;
146    /// - a `struct` sort -- used for types with a `refined_by` has a single variant;
147    /// - a `reflected` sort -- used for `reflected` enums have multiple variants
148    variants: IndexVec<VariantIdx, AdtSortVariant>,
149    is_reflected: bool,
150    is_struct: bool,
151}
152
153impl AdtSortDef {
154    pub fn new(
155        def_id: DefId,
156        params: Vec<ParamTy>,
157        variants: IndexVec<VariantIdx, AdtSortVariant>,
158        is_reflected: bool,
159        is_struct: bool,
160    ) -> Self {
161        Self(Interned::new(AdtSortDefData { def_id, params, variants, is_reflected, is_struct }))
162    }
163
164    pub fn did(&self) -> DefId {
165        self.0.def_id
166    }
167
168    pub fn variant(&self, idx: VariantIdx) -> &AdtSortVariant {
169        &self.0.variants[idx]
170    }
171
172    pub fn variants(&self) -> &IndexSlice<VariantIdx, AdtSortVariant> {
173        &self.0.variants
174    }
175
176    pub fn opt_struct_variant(&self) -> Option<&AdtSortVariant> {
177        if self.is_struct() { Some(self.struct_variant()) } else { None }
178    }
179
180    #[track_caller]
181    pub fn struct_variant(&self) -> &AdtSortVariant {
182        tracked_span_assert_eq!(self.0.is_struct, true);
183        &self.0.variants[FIRST_VARIANT]
184    }
185
186    pub fn is_reflected(&self) -> bool {
187        self.0.is_reflected
188    }
189
190    pub fn is_struct(&self) -> bool {
191        self.0.is_struct
192    }
193
194    pub fn to_sort(&self, args: &[GenericArg]) -> Sort {
195        let sorts = self
196            .filter_generic_args(args)
197            .map(|arg| arg.expect_base().sort())
198            .collect();
199
200        Sort::App(SortCtor::Adt(self.clone()), sorts)
201    }
202
203    /// Given a list of generic args, returns an iterator of the generic arguments that should be
204    /// mapped to sorts for instantiation.
205    pub fn filter_generic_args<'a, A>(&'a self, args: &'a [A]) -> impl Iterator<Item = &'a A> + 'a {
206        self.0.params.iter().map(|p| &args[p.index as usize])
207    }
208
209    pub fn identity_args(&self) -> List<Sort> {
210        (0..self.0.params.len())
211            .map(|i| Sort::Var(ParamSort::from(i)))
212            .collect()
213    }
214
215    /// Gives the number of sort variables bound by this definition.
216    pub fn param_count(&self) -> usize {
217        self.0.params.len()
218    }
219}
220
221#[derive(Debug, Clone, Default, Encodable, Decodable)]
222pub struct Generics {
223    pub parent: Option<DefId>,
224    pub parent_count: usize,
225    pub own_params: List<GenericParamDef>,
226    pub has_self: bool,
227}
228
229impl Generics {
230    pub fn count(&self) -> usize {
231        self.parent_count + self.own_params.len()
232    }
233
234    pub fn own_default_count(&self) -> usize {
235        self.own_params
236            .iter()
237            .filter(|param| {
238                match param.kind {
239                    GenericParamDefKind::Type { has_default }
240                    | GenericParamDefKind::Const { has_default }
241                    | GenericParamDefKind::Base { has_default } => has_default,
242                    GenericParamDefKind::Lifetime => false,
243                }
244            })
245            .count()
246    }
247
248    pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult<GenericParamDef> {
249        if let Some(index) = param_index.checked_sub(self.parent_count) {
250            Ok(self.own_params[index].clone())
251        } else {
252            let parent = self.parent.expect("parent_count > 0 but no parent?");
253            genv.generics_of(parent)?.param_at(param_index, genv)
254        }
255    }
256
257    pub fn const_params(&self, genv: GlobalEnv) -> QueryResult<Vec<(ParamConst, Sort)>> {
258        let mut res = vec![];
259        for i in 0..self.count() {
260            let param = self.param_at(i, genv)?;
261            if let GenericParamDefKind::Const { .. } = param.kind
262                && let Some(sort) = genv.sort_of_def_id(param.def_id)?
263            {
264                let param_const = ParamConst { name: param.name, index: param.index };
265                res.push((param_const, sort));
266            }
267        }
268        Ok(res)
269    }
270}
271
272#[derive(Debug, Clone, TyEncodable, TyDecodable)]
273pub struct RefinementGenerics {
274    pub parent: Option<DefId>,
275    pub parent_count: usize,
276    pub own_params: List<RefineParam>,
277}
278
279#[derive(
280    PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
281)]
282pub struct RefineParam {
283    pub sort: Sort,
284    pub name: Symbol,
285    pub mode: InferMode,
286}
287
288#[derive(Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
289pub struct GenericParamDef {
290    pub kind: GenericParamDefKind,
291    pub def_id: DefId,
292    pub index: u32,
293    pub name: Symbol,
294}
295
296#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
297pub enum GenericParamDefKind {
298    Type { has_default: bool },
299    Base { has_default: bool },
300    Lifetime,
301    Const { has_default: bool },
302}
303
304pub const SELF_PARAM_TY: ParamTy = ParamTy { index: 0, name: kw::SelfUpper };
305
306#[derive(Debug, Clone, TyEncodable, TyDecodable)]
307pub struct GenericPredicates {
308    pub parent: Option<DefId>,
309    pub predicates: List<Clause>,
310}
311
312#[derive(
313    Debug, PartialEq, Eq, Hash, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
314)]
315pub struct Clause {
316    kind: Binder<ClauseKind>,
317}
318
319impl Clause {
320    pub fn new(vars: impl Into<List<BoundVariableKind>>, kind: ClauseKind) -> Self {
321        Clause { kind: Binder::bind_with_vars(kind, vars.into()) }
322    }
323
324    pub fn kind(&self) -> Binder<ClauseKind> {
325        self.kind.clone()
326    }
327
328    fn as_trait_clause(&self) -> Option<Binder<TraitPredicate>> {
329        let clause = self.kind();
330        if let ClauseKind::Trait(trait_clause) = clause.skip_binder_ref() {
331            Some(clause.rebind(trait_clause.clone()))
332        } else {
333            None
334        }
335    }
336
337    pub fn as_projection_clause(&self) -> Option<Binder<ProjectionPredicate>> {
338        let clause = self.kind();
339        if let ClauseKind::Projection(proj_clause) = clause.skip_binder_ref() {
340            Some(clause.rebind(proj_clause.clone()))
341        } else {
342            None
343        }
344    }
345
346    // FIXME(nilehmann) we should deal with the binder in all the places this is used instead of
347    // blindly skipping it here
348    pub fn kind_skipping_binder(&self) -> ClauseKind {
349        self.kind.clone().skip_binder()
350    }
351
352    /// Group `Fn` trait clauses with their corresponding `FnOnce::Output` projection
353    /// predicate. This assumes there's exactly one corresponding projection predicate and will
354    /// crash otherwise.
355    pub fn split_off_fn_trait_clauses(
356        genv: GlobalEnv,
357        clauses: &Clauses,
358    ) -> (Vec<Clause>, Vec<Binder<FnTraitPredicate>>) {
359        let mut fn_trait_clauses = vec![];
360        let mut fn_trait_output_clauses = vec![];
361        let mut rest = vec![];
362        for clause in clauses {
363            if let Some(trait_clause) = clause.as_trait_clause()
364                && let Some(kind) = genv.tcx().fn_trait_kind_from_def_id(trait_clause.def_id())
365            {
366                fn_trait_clauses.push((kind, trait_clause));
367            } else if let Some(proj_clause) = clause.as_projection_clause()
368                && genv.is_fn_output(proj_clause.projection_def_id())
369            {
370                fn_trait_output_clauses.push(proj_clause);
371            } else {
372                rest.push(clause.clone());
373            }
374        }
375        let fn_trait_clauses = fn_trait_clauses
376            .into_iter()
377            .map(|(kind, fn_trait_clause)| {
378                let mut candidates = vec![];
379                for fn_trait_output_clause in &fn_trait_output_clauses {
380                    if fn_trait_output_clause.self_ty() == fn_trait_clause.self_ty() {
381                        candidates.push(fn_trait_output_clause.clone());
382                    }
383                }
384                tracked_span_assert_eq!(candidates.len(), 1);
385                let proj_pred = candidates.pop().unwrap().skip_binder();
386                fn_trait_clause.map(|fn_trait_clause| {
387                    FnTraitPredicate {
388                        kind,
389                        self_ty: fn_trait_clause.self_ty().to_ty(),
390                        tupled_args: fn_trait_clause.trait_ref.args[1].expect_base().to_ty(),
391                        output: proj_pred.term.to_ty(),
392                    }
393                })
394            })
395            .collect_vec();
396        (rest, fn_trait_clauses)
397    }
398}
399
400impl<'tcx> ToRustc<'tcx> for Clause {
401    type T = rustc_middle::ty::Clause<'tcx>;
402
403    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
404        self.kind.to_rustc(tcx).upcast(tcx)
405    }
406}
407
408impl From<Binder<ClauseKind>> for Clause {
409    fn from(kind: Binder<ClauseKind>) -> Self {
410        Clause { kind }
411    }
412}
413
414pub type Clauses = List<Clause>;
415
416#[derive(
417    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
418)]
419pub enum ClauseKind {
420    Trait(TraitPredicate),
421    Projection(ProjectionPredicate),
422    RegionOutlives(RegionOutlivesPredicate),
423    TypeOutlives(TypeOutlivesPredicate),
424    ConstArgHasType(Const, Ty),
425    UnstableFeature(Symbol),
426}
427
428impl<'tcx> ToRustc<'tcx> for ClauseKind {
429    type T = rustc_middle::ty::ClauseKind<'tcx>;
430
431    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
432        match self {
433            ClauseKind::Trait(trait_predicate) => {
434                rustc_middle::ty::ClauseKind::Trait(trait_predicate.to_rustc(tcx))
435            }
436            ClauseKind::Projection(projection_predicate) => {
437                rustc_middle::ty::ClauseKind::Projection(projection_predicate.to_rustc(tcx))
438            }
439            ClauseKind::RegionOutlives(outlives_predicate) => {
440                rustc_middle::ty::ClauseKind::RegionOutlives(outlives_predicate.to_rustc(tcx))
441            }
442            ClauseKind::TypeOutlives(outlives_predicate) => {
443                rustc_middle::ty::ClauseKind::TypeOutlives(outlives_predicate.to_rustc(tcx))
444            }
445            ClauseKind::ConstArgHasType(constant, ty) => {
446                rustc_middle::ty::ClauseKind::ConstArgHasType(
447                    constant.to_rustc(tcx),
448                    ty.to_rustc(tcx),
449                )
450            }
451            ClauseKind::UnstableFeature(sym) => rustc_middle::ty::ClauseKind::UnstableFeature(*sym),
452        }
453    }
454}
455
456#[derive(Eq, PartialEq, Hash, Clone, Debug, TyEncodable, TyDecodable)]
457pub struct OutlivesPredicate<T>(pub T, pub Region);
458
459pub type TypeOutlivesPredicate = OutlivesPredicate<Ty>;
460pub type RegionOutlivesPredicate = OutlivesPredicate<Region>;
461
462impl<'tcx, V: ToRustc<'tcx>> ToRustc<'tcx> for OutlivesPredicate<V> {
463    type T = rustc_middle::ty::OutlivesPredicate<'tcx, V::T>;
464
465    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
466        rustc_middle::ty::OutlivesPredicate(self.0.to_rustc(tcx), self.1.to_rustc(tcx))
467    }
468}
469
470#[derive(
471    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
472)]
473pub struct TraitPredicate {
474    pub trait_ref: TraitRef,
475}
476
477impl TraitPredicate {
478    fn self_ty(&self) -> SubsetTyCtor {
479        self.trait_ref.self_ty()
480    }
481}
482
483impl<'tcx> ToRustc<'tcx> for TraitPredicate {
484    type T = rustc_middle::ty::TraitPredicate<'tcx>;
485
486    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
487        rustc_middle::ty::TraitPredicate {
488            polarity: rustc_middle::ty::PredicatePolarity::Positive,
489            trait_ref: self.trait_ref.to_rustc(tcx),
490        }
491    }
492}
493
494pub type PolyTraitPredicate = Binder<TraitPredicate>;
495
496impl PolyTraitPredicate {
497    fn def_id(&self) -> DefId {
498        self.skip_binder_ref().trait_ref.def_id
499    }
500
501    fn self_ty(&self) -> Binder<SubsetTyCtor> {
502        self.clone().map(|predicate| predicate.self_ty())
503    }
504}
505
506#[derive(
507    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
508)]
509pub struct TraitRef {
510    pub def_id: DefId,
511    pub args: GenericArgs,
512}
513
514impl TraitRef {
515    pub fn self_ty(&self) -> SubsetTyCtor {
516        self.args[0].expect_base().clone()
517    }
518}
519
520impl<'tcx> ToRustc<'tcx> for TraitRef {
521    type T = rustc_middle::ty::TraitRef<'tcx>;
522
523    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
524        rustc_middle::ty::TraitRef::new(tcx, self.def_id, self.args.to_rustc(tcx))
525    }
526}
527
528pub type PolyTraitRef = Binder<TraitRef>;
529
530impl PolyTraitRef {
531    pub fn def_id(&self) -> DefId {
532        self.as_ref().skip_binder().def_id
533    }
534}
535
536#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
537pub enum ExistentialPredicate {
538    Trait(ExistentialTraitRef),
539    Projection(ExistentialProjection),
540    AutoTrait(DefId),
541}
542
543pub type PolyExistentialPredicate = Binder<ExistentialPredicate>;
544
545impl ExistentialPredicate {
546    /// See [`rustc_middle::ty::ExistentialPredicateStableCmpExt`]
547    pub fn stable_cmp(&self, tcx: TyCtxt, other: &Self) -> Ordering {
548        match (self, other) {
549            (ExistentialPredicate::Trait(_), ExistentialPredicate::Trait(_)) => Ordering::Equal,
550            (ExistentialPredicate::Projection(a), ExistentialPredicate::Projection(b)) => {
551                tcx.def_path_hash(a.def_id)
552                    .cmp(&tcx.def_path_hash(b.def_id))
553            }
554            (ExistentialPredicate::AutoTrait(a), ExistentialPredicate::AutoTrait(b)) => {
555                tcx.def_path_hash(*a).cmp(&tcx.def_path_hash(*b))
556            }
557            (ExistentialPredicate::Trait(_), _) => Ordering::Less,
558            (ExistentialPredicate::Projection(_), ExistentialPredicate::Trait(_)) => {
559                Ordering::Greater
560            }
561            (ExistentialPredicate::Projection(_), _) => Ordering::Less,
562            (ExistentialPredicate::AutoTrait(_), _) => Ordering::Greater,
563        }
564    }
565}
566
567impl<'tcx> ToRustc<'tcx> for ExistentialPredicate {
568    type T = rustc_middle::ty::ExistentialPredicate<'tcx>;
569
570    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
571        match self {
572            ExistentialPredicate::Trait(trait_ref) => {
573                let trait_ref = rustc_middle::ty::ExistentialTraitRef::new_from_args(
574                    tcx,
575                    trait_ref.def_id,
576                    trait_ref.args.to_rustc(tcx),
577                );
578                rustc_middle::ty::ExistentialPredicate::Trait(trait_ref)
579            }
580            ExistentialPredicate::Projection(projection) => {
581                rustc_middle::ty::ExistentialPredicate::Projection(
582                    rustc_middle::ty::ExistentialProjection::new_from_args(
583                        tcx,
584                        projection.def_id,
585                        projection.args.to_rustc(tcx),
586                        projection.term.skip_binder_ref().to_rustc(tcx).into(),
587                    ),
588                )
589            }
590            ExistentialPredicate::AutoTrait(def_id) => {
591                rustc_middle::ty::ExistentialPredicate::AutoTrait(*def_id)
592            }
593        }
594    }
595}
596
597#[derive(
598    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
599)]
600pub struct ExistentialTraitRef {
601    pub def_id: DefId,
602    pub args: GenericArgs,
603}
604
605pub type PolyExistentialTraitRef = Binder<ExistentialTraitRef>;
606
607impl PolyExistentialTraitRef {
608    pub fn def_id(&self) -> DefId {
609        self.as_ref().skip_binder().def_id
610    }
611}
612
613#[derive(
614    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
615)]
616pub struct ExistentialProjection {
617    pub def_id: DefId,
618    pub args: GenericArgs,
619    pub term: SubsetTyCtor,
620}
621
622#[derive(
623    PartialEq, Eq, Hash, Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
624)]
625pub struct ProjectionPredicate {
626    pub projection_ty: AliasTy,
627    pub term: SubsetTyCtor,
628}
629
630impl Pretty for ProjectionPredicate {
631    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
632        write!(
633            f,
634            "ProjectionPredicate << projection_ty = {:?}, term = {:?} >>",
635            self.projection_ty, self.term
636        )
637    }
638}
639
640impl ProjectionPredicate {
641    pub fn self_ty(&self) -> SubsetTyCtor {
642        self.projection_ty.self_ty().clone()
643    }
644}
645
646impl<'tcx> ToRustc<'tcx> for ProjectionPredicate {
647    type T = rustc_middle::ty::ProjectionPredicate<'tcx>;
648
649    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
650        rustc_middle::ty::ProjectionPredicate {
651            projection_term: rustc_middle::ty::AliasTerm::new_from_args(
652                tcx,
653                self.projection_ty.def_id,
654                self.projection_ty.args.to_rustc(tcx),
655            ),
656            term: self.term.as_bty_skipping_binder().to_rustc(tcx).into(),
657        }
658    }
659}
660
661pub type PolyProjectionPredicate = Binder<ProjectionPredicate>;
662
663impl PolyProjectionPredicate {
664    pub fn projection_def_id(&self) -> DefId {
665        self.skip_binder_ref().projection_ty.def_id
666    }
667
668    pub fn self_ty(&self) -> Binder<SubsetTyCtor> {
669        self.clone().map(|predicate| predicate.self_ty())
670    }
671}
672
673#[derive(
674    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
675)]
676pub struct FnTraitPredicate {
677    pub self_ty: Ty,
678    pub tupled_args: Ty,
679    pub output: Ty,
680    pub kind: ClosureKind,
681}
682
683impl Pretty for FnTraitPredicate {
684    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
685        write!(
686            f,
687            "self = {:?}, args = {:?}, output = {:?}, kind = {}",
688            self.self_ty, self.tupled_args, self.output, self.kind
689        )
690    }
691}
692
693impl FnTraitPredicate {
694    pub fn fndef_sig(&self) -> FnSig {
695        let inputs = self.tupled_args.expect_tuple().iter().cloned().collect();
696        let ret = self.output.clone().shift_in_escaping(1);
697        let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
698        FnSig::new(
699            Safety::Safe,
700            rustc_abi::ExternAbi::Rust,
701            List::empty(),
702            inputs,
703            output,
704            Expr::ff(),
705            false,
706        )
707    }
708}
709
710pub fn to_closure_sig(
711    tcx: TyCtxt,
712    closure_id: LocalDefId,
713    tys: &[Ty],
714    args: &flux_rustc_bridge::ty::GenericArgs,
715    poly_sig: &PolyFnSig,
716    no_panic: bool,
717) -> PolyFnSig {
718    let closure_args = args.as_closure();
719    let kind_ty = closure_args.kind_ty().to_rustc(tcx);
720    let Some(kind) = kind_ty.to_opt_closure_kind() else {
721        bug!("to_closure_sig: expected closure kind, found {kind_ty:?}");
722    };
723
724    let mut vars = poly_sig.vars().clone().to_vec();
725    let fn_sig = poly_sig.clone().skip_binder();
726    let closure_ty = Ty::closure(closure_id.into(), tys, args, no_panic);
727    let env_ty = match kind {
728        ClosureKind::Fn => {
729            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
730            let br = BoundRegion {
731                var: BoundVar::from_usize(vars.len() - 1),
732                kind: BoundRegionKind::ClosureEnv,
733            };
734            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Not)
735        }
736        ClosureKind::FnMut => {
737            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
738            let br = BoundRegion {
739                var: BoundVar::from_usize(vars.len() - 1),
740                kind: BoundRegionKind::ClosureEnv,
741            };
742            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Mut)
743        }
744        ClosureKind::FnOnce => closure_ty,
745    };
746
747    let inputs = std::iter::once(env_ty)
748        .chain(fn_sig.inputs().iter().cloned())
749        .collect::<Vec<_>>();
750    let output = fn_sig.output().clone();
751
752    let fn_sig = crate::rty::FnSig::new(
753        fn_sig.safety,
754        fn_sig.abi,
755        fn_sig.requires.clone(),
756        inputs.into(),
757        output,
758        if no_panic { crate::rty::Expr::tt() } else { crate::rty::Expr::ff() },
759        false,
760    );
761
762    PolyFnSig::bind_with_vars(fn_sig, List::from(vars))
763}
764
765#[derive(Clone, PartialEq, Eq, Hash, Debug)]
766pub struct CoroutineObligPredicate {
767    pub def_id: DefId,
768    pub resume_ty: Ty,
769    pub upvar_tys: List<Ty>,
770    pub output: Ty,
771    pub args: flux_rustc_bridge::ty::GenericArgs,
772}
773
774#[derive(Copy, Clone, Encodable, Decodable, Hash, PartialEq, Debug, Eq)]
775pub struct AssocReft {
776    pub def_id: FluxDefId,
777    // NOTE: Field is used to denote final associated generic refinements on Traits
778    pub final_: bool,
779    pub span: Span,
780}
781
782impl AssocReft {
783    pub fn new(def_id: FluxDefId, final_: bool, span: Span) -> Self {
784        Self { def_id, final_, span }
785    }
786
787    pub fn name(&self) -> Symbol {
788        self.def_id.name()
789    }
790
791    pub fn def_id(&self) -> FluxDefId {
792        self.def_id
793    }
794}
795
796#[derive(Clone, Encodable, Decodable)]
797pub struct AssocRefinements {
798    pub items: List<AssocReft>,
799}
800
801impl Default for AssocRefinements {
802    fn default() -> Self {
803        Self { items: List::empty() }
804    }
805}
806
807impl AssocRefinements {
808    pub fn get(&self, assoc_id: FluxDefId) -> AssocReft {
809        *self
810            .items
811            .into_iter()
812            .find(|it| it.def_id == assoc_id)
813            .unwrap_or_else(|| {
814                bug!("caller should guarantee existence of associated refinement {assoc_id:?}")
815            })
816    }
817
818    pub fn find(&self, name: Symbol) -> Option<AssocReft> {
819        Some(*self.items.into_iter().find(|it| it.name() == name)?)
820    }
821}
822
823#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
824pub enum SortCtor {
825    Set,
826    Map,
827    Adt(AdtSortDef),
828    User(FluxDefId),
829}
830
831newtype_index! {
832    /// [`ParamSort`] is used for polymorphic sorts (`Set`, `Map`, etc.) and [bit-vector size parameters].
833    /// They should occur "bound" under a [`PolyFuncSort`] or an [`AdtSortDef`]. We assume there's a
834    /// single binder and a [`ParamSort`] represents a variable as an index into the list of variables
835    /// bound by that binder, i.e., the representation doesnt't support higher-ranked sorts.
836    ///
837    /// [bit-vector size parameters]: BvSize::Param
838    #[debug_format = "?{}s"]
839    #[encodable]
840    pub struct ParamSort {}
841}
842
843newtype_index! {
844    /// A *sort* *v*variable *id*
845    #[debug_format = "?{}s"]
846    #[encodable]
847    pub struct SortVid {}
848}
849
850impl ena::unify::UnifyKey for SortVid {
851    type Value = SortVarVal;
852
853    #[inline]
854    fn index(&self) -> u32 {
855        self.as_u32()
856    }
857
858    #[inline]
859    fn from_index(u: u32) -> Self {
860        SortVid::from_u32(u)
861    }
862
863    fn tag() -> &'static str {
864        "SortVid"
865    }
866}
867
868bitflags! {
869    /// A *sort constraint* is a set of operations a sort must support.
870    ///
871    /// During sort checking, we accumulate the operations required for each sort variable. As
872    /// unification progresses, these constraints become more specific, i.e, a sort must support
873    /// more operations to satisfy them.
874    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
875    pub struct SortCstr: u16 {
876        /// An empty constraint (any sort satisfies it)
877        const BOT     = 0b0000000000;
878        /// `*`
879        const MUL     = 0b0000000001;
880        /// `/`
881        const DIV     = 0b0000000010;
882        /// `%`
883        const MOD     = 0b0000000100;
884        /// `+`
885        const ADD     = 0b0000001000;
886        /// `-`
887        const SUB     = 0b0000010000;
888        /// `|`
889        const BIT_OR  = 0b0000100000;
890        /// `&`
891        const BIT_AND = 0b0001000000;
892        /// `>>`
893        const BIT_SHL = 0b0010000000;
894        /// `<<`
895        const BIT_SHR = 0b0100000000;
896        /// `^`
897        const BIT_XOR = 0b1000000000;
898
899        /// The set of operations supported by all _numeric_ sorts.
900        const NUMERIC = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
901        /// The set of operations supported by integers.
902        const INT = Self::DIV.bits()
903            | Self::MUL.bits()
904            | Self::MOD.bits()
905            | Self::ADD.bits()
906            | Self::SUB.bits();
907        /// The set of operations supported by reals.
908        const REAL = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
909        /// The set of operations supported by bit vectors.
910        const BITVEC =  Self::DIV.bits()
911            | Self::MUL.bits()
912            | Self::MOD.bits()
913            | Self::ADD.bits()
914            | Self::SUB.bits()
915            | Self::BIT_OR.bits()
916            | Self::BIT_AND.bits()
917            | Self::BIT_SHL.bits()
918            | Self::BIT_SHR.bits()
919            | Self::BIT_XOR.bits();
920        /// The set of operations supported by sets.
921        const SET = Self::SUB.bits() | Self::BIT_OR.bits() | Self::BIT_AND.bits();
922    }
923}
924
925impl SortCstr {
926    /// Returns a constraint that only requires the specified binary operation.
927    pub fn from_bin_op(op: fhir::BinOp) -> Self {
928        match op {
929            fhir::BinOp::Add => Self::ADD,
930            fhir::BinOp::Sub => Self::SUB,
931            fhir::BinOp::Mul => Self::MUL,
932            fhir::BinOp::Div => Self::DIV,
933            fhir::BinOp::Mod => Self::MOD,
934            fhir::BinOp::BitAnd => Self::BIT_AND,
935            fhir::BinOp::BitOr => Self::BIT_OR,
936            fhir::BinOp::BitXor => Self::BIT_XOR,
937            fhir::BinOp::BitShl => Self::BIT_SHL,
938            fhir::BinOp::BitShr => Self::BIT_SHR,
939            _ => bug!("{op:?} not supported as a constraint"),
940        }
941    }
942
943    /// Returns whether a sort satisfies this constraint
944    fn satisfy(self, sort: &Sort) -> bool {
945        match sort {
946            Sort::Int => SortCstr::INT.contains(self),
947            Sort::Real => SortCstr::REAL.contains(self),
948            Sort::BitVec(_) => SortCstr::BITVEC.contains(self),
949            Sort::App(SortCtor::Set, _) => SortCstr::SET.contains(self),
950            _ => self == SortCstr::BOT,
951        }
952    }
953}
954
/// Unification value for sort variables used during sort checking.
///
/// Values are merged by the `ena::unify::UnifyValue` impl for this type:
/// - two unsolved values merge into the union of their constraints;
/// - two solved values merge only when they are the same sort;
/// - an unsolved value merges with a solved one only when the sort satisfies the constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SortVarVal {
    /// The variable is not yet solved but the solution must satisfy some constraint.
    Unsolved(SortCstr),
    /// The variable has been solved to a sort.
    Solved(Sort),
}
963
964impl Default for SortVarVal {
965    fn default() -> Self {
966        SortVarVal::Unsolved(SortCstr::BOT)
967    }
968}
969
970impl SortVarVal {
971    pub fn solved_or(&self, sort: &Sort) -> Sort {
972        match self {
973            SortVarVal::Unsolved(_) => sort.clone(),
974            SortVarVal::Solved(sort) => sort.clone(),
975        }
976    }
977
978    pub fn map_solved(&self, f: impl FnOnce(&Sort) -> Sort) -> SortVarVal {
979        match self {
980            SortVarVal::Unsolved(cstr) => SortVarVal::Unsolved(*cstr),
981            SortVarVal::Solved(sort) => SortVarVal::Solved(f(sort)),
982        }
983    }
984}
985
986impl ena::unify::UnifyValue for SortVarVal {
987    type Error = ();
988
989    fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
990        match (value1, value2) {
991            (SortVarVal::Solved(s1), SortVarVal::Solved(s2)) if s1 == s2 => {
992                Ok(SortVarVal::Solved(s1.clone()))
993            }
994            (SortVarVal::Unsolved(a), SortVarVal::Unsolved(b)) => Ok(SortVarVal::Unsolved(*a | *b)),
995            (SortVarVal::Unsolved(v), SortVarVal::Solved(sort))
996            | (SortVarVal::Solved(sort), SortVarVal::Unsolved(v))
997                if v.satisfy(sort) =>
998            {
999                Ok(SortVarVal::Solved(sort.clone()))
1000            }
1001            _ => Err(()),
1002        }
1003    }
1004}
1005
1006newtype_index! {
1007    /// A *b*it *v*ector *size* *v*variable *id*
1008    #[debug_format = "?{}size"]
1009    #[encodable]
1010    pub struct BvSizeVid {}
1011}
1012
1013impl ena::unify::UnifyKey for BvSizeVid {
1014    type Value = Option<BvSize>;
1015
1016    #[inline]
1017    fn index(&self) -> u32 {
1018        self.as_u32()
1019    }
1020
1021    #[inline]
1022    fn from_index(u: u32) -> Self {
1023        BvSizeVid::from_u32(u)
1024    }
1025
1026    fn tag() -> &'static str {
1027        "BvSizeVid"
1028    }
1029}
1030
1031impl ena::unify::EqUnifyValue for BvSize {}
1032
/// The sort of a refinement expression, for example the index of a [`TyKind::Indexed`] type.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum Sort {
    Int,
    Bool,
    Real,
    /// A bit-vector whose width is given by a [`BvSize`].
    BitVec(BvSize),
    Str,
    Char,
    Loc,
    Param(ParamTy),
    /// A tuple of sorts. The empty tuple is the unit sort (see [`Sort::unit`]).
    Tuple(List<Sort>),
    Alias(AliasKind, AliasTy),
    /// A (possibly polymorphic) function sort.
    Func(PolyFuncSort),
    /// A sort constructor (e.g. a set or an ADT) applied to argument sorts.
    App(SortCtor, List<Sort>),
    /// NOTE(review): presumably a sort parameter bound by a [`PolyFuncSort`], analogous to
    /// [`BvSize::Param`]. Confirm.
    Var(ParamSort),
    /// A sort variable to be solved during sort checking. Presumably unified with a
    /// [`SortVarVal`]; confirm.
    Infer(SortVid),
    /// The sort of raw pointers. Its field sorts are given by [`RawPtrField`].
    RawPtr,
    /// NOTE(review): presumably a placeholder produced after a sort error. Confirm.
    Err,
}
1052
/// How a cast from one sort to another is interpreted, as computed by [`Sort::cast_kind`].
pub enum CastKind {
    /// Identity cast, which is erasable (e.g. int -> int, char -> int)
    Identity,
    /// From bool to int
    BoolToInt,
    /// Casts into a sort with a unit index (e.g. int -> float)
    IntoUnit,
    /// Uninterpreted casts, only allowed with an explicit flag
    Uninterpreted,
}
1063
1064impl Sort {
1065    pub fn tuple(sorts: impl Into<List<Sort>>) -> Self {
1066        Sort::Tuple(sorts.into())
1067    }
1068
1069    pub fn app(ctor: SortCtor, sorts: List<Sort>) -> Self {
1070        Sort::App(ctor, sorts)
1071    }
1072
1073    pub fn unit() -> Self {
1074        Self::tuple(vec![])
1075    }
1076
1077    pub fn field_sorts(&self) -> Option<List<Sort>> {
1078        match self {
1079            Sort::RawPtr => Some(RawPtrField::iter().map(RawPtrField::sort).collect()),
1080            Sort::App(SortCtor::Adt(sort_def), args) if sort_def.is_struct() => {
1081                Some(sort_def.struct_variant().field_sorts(args))
1082            }
1083            _ => None,
1084        }
1085    }
1086
1087    #[track_caller]
1088    pub fn expect_func(&self) -> &PolyFuncSort {
1089        if let Sort::Func(sort) = self { sort } else { bug!("expected `Sort::Func`") }
1090    }
1091
1092    pub fn is_loc(&self) -> bool {
1093        matches!(self, Sort::Loc)
1094    }
1095
1096    pub fn is_unit(&self) -> bool {
1097        matches!(self, Sort::Tuple(sorts) if sorts.is_empty())
1098    }
1099
1100    pub fn is_unit_adt(&self) -> Option<DefId> {
1101        if let Sort::App(SortCtor::Adt(sort_def), _) = self
1102            && let Some(variant) = sort_def.opt_struct_variant()
1103            && variant.fields() == 0
1104        {
1105            Some(sort_def.did())
1106        } else {
1107            None
1108        }
1109    }
1110
1111    /// Whether the sort is a function with return sort bool
1112    pub fn is_pred(&self) -> bool {
1113        matches!(self, Sort::Func(fsort) if fsort.skip_binders().output().is_bool())
1114    }
1115
1116    /// Returns `true` if the sort is [`Bool`].
1117    ///
1118    /// [`Bool`]: Sort::Bool
1119    #[must_use]
1120    pub fn is_bool(&self) -> bool {
1121        matches!(self, Self::Bool)
1122    }
1123
1124    pub fn cast_kind(self: &Sort, to: &Sort) -> CastKind {
1125        if self == to
1126            || (matches!(self, Sort::Char | Sort::Int) && matches!(to, Sort::Char | Sort::Int))
1127        {
1128            CastKind::Identity
1129        } else if matches!(self, Sort::Bool) && matches!(to, Sort::Int) {
1130            CastKind::BoolToInt
1131        } else if to.is_unit() {
1132            CastKind::IntoUnit
1133        } else {
1134            CastKind::Uninterpreted
1135        }
1136    }
1137
1138    pub fn walk(&self, mut f: impl FnMut(&Sort, &[FieldProj])) {
1139        fn go(sort: &Sort, f: &mut impl FnMut(&Sort, &[FieldProj]), proj: &mut Vec<FieldProj>) {
1140            match sort {
1141                Sort::Tuple(flds) => {
1142                    for (i, sort) in flds.iter().enumerate() {
1143                        proj.push(FieldProj::Tuple { arity: flds.len(), field: i as u32 });
1144                        go(sort, f, proj);
1145                        proj.pop();
1146                    }
1147                }
1148                Sort::App(SortCtor::Adt(sort_def), args) if sort_def.is_struct() => {
1149                    let field_sorts = sort_def.struct_variant().field_sorts(args);
1150                    for (i, sort) in field_sorts.iter().enumerate() {
1151                        proj.push(FieldProj::Adt { def_id: sort_def.did(), field: i as u32 });
1152                        go(sort, f, proj);
1153                        proj.pop();
1154                    }
1155                }
1156                Sort::RawPtr => {
1157                    for field in RawPtrField::iter() {
1158                        let sort = field.sort();
1159                        proj.push(FieldProj::RawPtr { field });
1160                        go(&sort, f, proj);
1161                        proj.pop();
1162                    }
1163                }
1164                _ => {
1165                    f(sort, proj);
1166                }
1167            }
1168        }
1169        go(self, &mut f, &mut vec![]);
1170    }
1171}
1172
/// The size of a [bit-vector]
///
/// [bit-vector]: Sort::BitVec
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BvSize {
    /// A fixed size
    Fixed(u32),
    /// A size that has been parameterized, e.g., bound under a [`PolyFuncSort`]
    Param(ParamSort),
    /// A size that needs to be inferred. Used during sort checking to instantiate bit-vector
    /// sizes at call-sites. The variable is resolved through `ena` (see [`BvSizeVid`]).
    Infer(BvSizeVid),
}
1186
1187impl rustc_errors::IntoDiagArg for Sort {
1188    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1189        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1190    }
1191}
1192
1193impl rustc_errors::IntoDiagArg for FuncSort {
1194    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1195        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1196    }
1197}
1198
/// A (monomorphic) function sort.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FuncSort {
    /// The input sorts followed by the output sort, which is always the last element. Prefer
    /// [`FuncSort::inputs`] and [`FuncSort::output`] over indexing this directly.
    pub inputs_and_output: List<Sort>,
}
1203
1204impl FuncSort {
1205    pub fn new(mut inputs: Vec<Sort>, output: Sort) -> Self {
1206        inputs.push(output);
1207        FuncSort { inputs_and_output: List::from_vec(inputs) }
1208    }
1209
1210    pub fn inputs(&self) -> &[Sort] {
1211        &self.inputs_and_output[0..self.inputs_and_output.len() - 1]
1212    }
1213
1214    pub fn output(&self) -> &Sort {
1215        &self.inputs_and_output[self.inputs_and_output.len() - 1]
1216    }
1217
1218    pub fn to_poly(&self) -> PolyFuncSort {
1219        PolyFuncSort::new(List::empty(), self.clone())
1220    }
1221}
1222
/// The kind of a parameter of a [`PolyFuncSort`]: either a sort or a bit-vector size.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub enum SortParamKind {
    /// A parameter ranging over sorts.
    Sort,
    /// A parameter ranging over bit-vector sizes.
    BvSize,
}
1229
/// A polymorphic function sort parametric over [sorts] or [bit-vector sizes].
///
/// Parameterizing over bit-vector sizes is a bit of a stretch, because smtlib doesn't support full
/// parametric reasoning over them. As long as we use functions parameterized over a size
/// monomorphically we should be fine. Right now, we can guarantee this, because size parameters are
/// not exposed in the surface syntax and they are only used for predefined (interpreted) theory
/// functions.
///
/// [sorts]: Sort
/// [bit-vector sizes]: BvSize::Param
#[derive(Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub struct PolyFuncSort {
    /// The list of parameters including sorts and bit vector sizes
    params: List<SortParamKind>,
    /// The function sort, which may refer to the parameters above.
    fsort: FuncSort,
}
1245
1246impl PolyFuncSort {
1247    pub fn new(params: List<SortParamKind>, fsort: FuncSort) -> Self {
1248        PolyFuncSort { params, fsort }
1249    }
1250
1251    pub fn skip_binders(&self) -> FuncSort {
1252        self.fsort.clone()
1253    }
1254
1255    pub fn instantiate_identity(&self) -> FuncSort {
1256        self.fsort.clone()
1257    }
1258
1259    pub fn expect_mono(&self) -> FuncSort {
1260        assert!(self.params.is_empty());
1261        self.fsort.clone()
1262    }
1263
1264    pub fn params(&self) -> impl ExactSizeIterator<Item = SortParamKind> + '_ {
1265        self.params.iter().copied()
1266    }
1267
1268    pub fn instantiate(&self, args: &[SortArg]) -> FuncSort {
1269        self.fsort.fold_with(&mut SortSubst::new(args))
1270    }
1271}
1272
/// An argument for a generic parameter in a [`Sort`] which can be either a generic sort or a
/// generic bit-vector size. Used to instantiate a [`PolyFuncSort`] via
/// [`PolyFuncSort::instantiate`].
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub enum SortArg {
    /// Argument for a [`SortParamKind::Sort`] parameter.
    Sort(Sort),
    /// Argument for a [`SortParamKind::BvSize`] parameter.
    BvSize(BvSize),
}
1282
/// What flux knows about the value of a constant.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum ConstantInfo {
    /// An uninterpreted constant
    Uninterpreted,
    /// A non-integral constant whose value is specified by the user, together with its sort
    Interpreted(Expr, Sort),
}
1290
/// What flux knows about the type of a static item.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum StaticInfo {
    /// No user-provided type is available.
    Unknown,
    /// A static item whose type was specified by the user
    Known(Ty),
}
1297
/// An interned ADT definition, cheap to clone (see the module docs).
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDef(Interned<AdtDefData>);

/// The data behind an [`AdtDef`].
#[derive(Debug, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDefData {
    /// Predicates over the ADT's index (see [`Invariant::apply`]).
    invariants: Vec<Invariant>,
    /// NOTE(review): presumably the sort definition used to index values of this ADT. Confirm.
    sort_def: AdtSortDef,
    /// Whether the ADT was marked opaque (cf. [`Opaqueness`]).
    opaque: bool,
    /// The underlying rustc ADT definition.
    rustc: ty::AdtDef,
}
1308
/// Option-like enum to explicitly mark that we don't have information about an ADT because it was
/// annotated with `#[flux::opaque]`. Note that only structs can be marked as opaque.
#[derive(Clone, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum Opaqueness<T> {
    /// No information is available because the ADT is opaque.
    Opaque,
    /// The information is available.
    Transparent(T),
}
1316
1317impl<T> Opaqueness<T> {
1318    pub fn map<S>(self, f: impl FnOnce(T) -> S) -> Opaqueness<S> {
1319        match self {
1320            Opaqueness::Opaque => Opaqueness::Opaque,
1321            Opaqueness::Transparent(value) => Opaqueness::Transparent(f(value)),
1322        }
1323    }
1324
1325    pub fn as_ref(&self) -> Opaqueness<&T> {
1326        match self {
1327            Opaqueness::Opaque => Opaqueness::Opaque,
1328            Opaqueness::Transparent(value) => Opaqueness::Transparent(value),
1329        }
1330    }
1331
1332    pub fn as_deref(&self) -> Opaqueness<&T::Target>
1333    where
1334        T: std::ops::Deref,
1335    {
1336        match self {
1337            Opaqueness::Opaque => Opaqueness::Opaque,
1338            Opaqueness::Transparent(value) => Opaqueness::Transparent(value.deref()),
1339        }
1340    }
1341
1342    pub fn ok_or_else<E>(self, err: impl FnOnce() -> E) -> Result<T, E> {
1343        match self {
1344            Opaqueness::Transparent(v) => Ok(v),
1345            Opaqueness::Opaque => Err(err()),
1346        }
1347    }
1348
1349    #[track_caller]
1350    pub fn expect(self, msg: &str) -> T {
1351        match self {
1352            Opaqueness::Transparent(val) => val,
1353            Opaqueness::Opaque => bug!("{}", msg),
1354        }
1355    }
1356
1357    pub fn ok_or_query_err(self, struct_id: DefId) -> Result<T, QueryErr> {
1358        self.ok_or_else(|| QueryErr::OpaqueStruct { struct_id })
1359    }
1360}
1361
1362impl<T, E> Opaqueness<Result<T, E>> {
1363    pub fn transpose(self) -> Result<Opaqueness<T>, E> {
1364        match self {
1365            Opaqueness::Transparent(Ok(x)) => Ok(Opaqueness::Transparent(x)),
1366            Opaqueness::Transparent(Err(e)) => Err(e),
1367            Opaqueness::Opaque => Ok(Opaqueness::Opaque),
1368        }
1369    }
1370}
1371
/// Every signed integer type: `isize` first, then the fixed widths in increasing order.
pub static INT_TYS: [IntTy; 6] =
    [IntTy::Isize, IntTy::I8, IntTy::I16, IntTy::I32, IntTy::I64, IntTy::I128];
/// Every unsigned integer type: `usize` first, then the fixed widths in increasing order.
pub static UINT_TYS: [UintTy; 6] =
    [UintTy::Usize, UintTy::U8, UintTy::U16, UintTy::U32, UintTy::U64, UintTy::U128];
1376
/// A predicate over the index of an ADT, stored under a binder for that index. Instantiate it
/// with [`Invariant::apply`].
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable,
)]
pub struct Invariant {
    // This predicate may have sort variables, but we don't explicitly mark it like in `PolyFuncSort`.
    // See comment on `apply` for details.
    pred: Binder<Expr>,
}
1385
1386impl Invariant {
1387    pub fn new(pred: Binder<Expr>) -> Self {
1388        Self { pred }
1389    }
1390
1391    pub fn apply(&self, idx: &Expr) -> Expr {
1392        // The predicate may have sort variables but we don't explicitly instantiate them. This
1393        // works because within an expression, sort variables can only appear inside the sort
1394        // annotation for a lambda and invariants cannot have lambdas. It remains to instantiate
1395        // variables in the sort of the binder itself, but since we are removing it, we can avoid
1396        // the explicit instantiation. Ultimately, this works because the expression we generate in
1397        // fixpoint doesn't need sort annotations (sorts are re-inferred).
1398        self.pred.replace_bound_reft(idx)
1399    }
1400}
1401
/// A list of variant signatures, each under its own binder.
pub type PolyVariants = List<Binder<VariantSig>>;
/// A single variant signature under a binder.
pub type PolyVariant = Binder<VariantSig>;
1404
/// The signature of one variant of an ADT.
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct VariantSig {
    /// The ADT the variant belongs to.
    pub adt_def: AdtDef,
    /// Generic arguments of the ADT.
    pub args: GenericArgs,
    /// The types of the variant's fields.
    pub fields: List<Ty>,
    /// NOTE(review): presumably the index of the ADT value built by this variant. Confirm.
    pub idx: Expr,
    /// NOTE(review): presumably predicates required when constructing the variant. Confirm.
    pub requires: List<Expr>,
}
1413
/// A function signature under a binder.
pub type PolyFnSig = Binder<FnSig>;

/// A refined function signature.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FnSig {
    pub safety: Safety,
    pub abi: rustc_abi::ExternAbi,
    /// Predicates the caller must establish (NOTE(review): inferred from the name; confirm).
    pub requires: List<Expr>,
    pub inputs: List<Ty>,
    /// The return type and postconditions, under a binder.
    pub output: Binder<FnOutput>,
    /// NOTE(review): presumably a predicate stating when the function cannot panic. Confirm.
    pub no_panic: Expr,
    /// was this auto-lifted (or from a spec)
    pub lifted: bool,
}
1427
/// The output of a function signature: its return type plus the [`Ensures`] clauses.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct FnOutput {
    pub ret: Ty,
    pub ensures: List<Ensures>,
}
1435
/// A clause in a function's output.
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum Ensures {
    /// NOTE(review): presumably "the location `Path` has type `Ty`" on return. Confirm.
    Type(Path, Ty),
    /// A predicate.
    Pred(Expr),
}
1441
/// A qualifier: a predicate under a binder, identified by a local flux definition.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct Qualifier {
    pub def_id: FluxLocalDefId,
    pub body: Binder<Expr>,
    pub kind: QualifierKind,
}

/// The kind of a [`Qualifier`]. NOTE(review): the precise scoping each kind implies is defined
/// by the consumers of qualifiers. Confirm there.
#[derive(Debug, TypeFoldable, TypeVisitable, Copy, Clone)]
pub enum QualifierKind {
    Global,
    Local,
    Hint,
}
1455
/// A `PrimOpProp` is a single property for a primitive operation which
/// can be conjoined to get the definition of the [`PrimRel`] for that
/// primitive operation.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimOpProp {
    pub def_id: FluxLocalDefId,
    /// The primitive operation the property is about.
    pub op: BinOp,
    pub body: Binder<Expr>,
}

/// The relation defined for a primitive operation, presumably the conjunction of its
/// [`PrimOpProp`]s.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimRel {
    pub body: Binder<Expr>,
}
1470
1471pub type TyCtor = Binder<Ty>;
1472
1473impl TyCtor {
1474    pub fn to_ty(&self) -> Ty {
1475        match &self.vars()[..] {
1476            [] => {
1477                return self.skip_binder_ref().shift_out_escaping(1);
1478            }
1479            [BoundVariableKind::Refine(sort, ..)] => {
1480                if sort.is_unit() {
1481                    return self.replace_bound_reft(&Expr::unit());
1482                }
1483                if let Some(def_id) = sort.is_unit_adt() {
1484                    return self.replace_bound_reft(&Expr::unit_struct(def_id));
1485                }
1486            }
1487            _ => {}
1488        }
1489        Ty::exists(self.clone())
1490    }
1491}
1492
/// A refinement type. It is interned, so cloning is cheap. Use [`Ty::kind`] to inspect its
/// structure.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct Ty(Interned<TyKind>);
1495
1496impl Ty {
1497    pub fn kind(&self) -> &TyKind {
1498        &self.0
1499    }
1500
1501    /// Dummy type used for the `Self` of a `TraitRef` created when converting a trait object, and
1502    /// which gets removed in `ExistentialTraitRef`. This type must not appear anywhere in other
1503    /// converted types and must be a valid `rustc` type (i.e., we must be able to call `to_rustc`
1504    /// on it). `TyKind::Infer(TyVid(0))` does the job, with the caveat that we must skip 0 when
1505    /// generating `TyKind::Infer` for "type holes".
1506    pub fn trait_object_dummy_self() -> Ty {
1507        Ty::infer(TyVid::from_u32(0))
1508    }
1509
1510    pub fn dynamic(preds: impl Into<List<Binder<ExistentialPredicate>>>, region: Region) -> Ty {
1511        BaseTy::Dynamic(preds.into(), region).to_ty()
1512    }
1513
1514    pub fn strg_ref(re: Region, path: Path, ty: Ty) -> Ty {
1515        TyKind::StrgRef(re, path, ty).intern()
1516    }
1517
1518    pub fn ptr(pk: impl Into<PtrKind>, path: impl Into<Path>) -> Ty {
1519        TyKind::Ptr(pk.into(), path.into()).intern()
1520    }
1521
1522    pub fn constr(p: impl Into<Expr>, ty: Ty) -> Ty {
1523        TyKind::Constr(p.into(), ty).intern()
1524    }
1525
1526    pub fn uninit() -> Ty {
1527        TyKind::Uninit.intern()
1528    }
1529
1530    pub fn indexed(bty: BaseTy, idx: impl Into<Expr>) -> Ty {
1531        TyKind::Indexed(bty, idx.into()).intern()
1532    }
1533
1534    pub fn exists(ty: Binder<Ty>) -> Ty {
1535        TyKind::Exists(ty).intern()
1536    }
1537
1538    pub fn exists_with_constr(bty: BaseTy, pred: Expr) -> Ty {
1539        let sort = bty.sort();
1540        let ty = Ty::indexed(bty, Expr::nu());
1541        Ty::exists(Binder::bind_with_sort(Ty::constr(pred, ty), sort))
1542    }
1543
1544    pub fn discr(adt_def: AdtDef, place: Place) -> Ty {
1545        TyKind::Discr(adt_def, place).intern()
1546    }
1547
1548    pub fn unit() -> Ty {
1549        Ty::tuple(vec![])
1550    }
1551
1552    pub fn bool() -> Ty {
1553        BaseTy::Bool.to_ty()
1554    }
1555
1556    pub fn int(int_ty: IntTy) -> Ty {
1557        BaseTy::Int(int_ty).to_ty()
1558    }
1559
1560    pub fn uint(uint_ty: UintTy) -> Ty {
1561        BaseTy::Uint(uint_ty).to_ty()
1562    }
1563
1564    pub fn param(param_ty: ParamTy) -> Ty {
1565        TyKind::Param(param_ty).intern()
1566    }
1567
1568    pub fn downcast(
1569        adt: AdtDef,
1570        args: GenericArgs,
1571        ty: Ty,
1572        variant: VariantIdx,
1573        fields: List<Ty>,
1574    ) -> Ty {
1575        TyKind::Downcast(adt, args, ty, variant, fields).intern()
1576    }
1577
1578    pub fn blocked(ty: Ty) -> Ty {
1579        TyKind::Blocked(ty).intern()
1580    }
1581
1582    pub fn str() -> Ty {
1583        BaseTy::Str.to_ty()
1584    }
1585
1586    pub fn char() -> Ty {
1587        BaseTy::Char.to_ty()
1588    }
1589
1590    pub fn float(float_ty: FloatTy) -> Ty {
1591        BaseTy::Float(float_ty).to_ty()
1592    }
1593
1594    pub fn mk_ref(region: Region, ty: Ty, mutbl: Mutability) -> Ty {
1595        BaseTy::Ref(region, ty, mutbl).to_ty()
1596    }
1597
1598    pub fn mk_slice(ty: Ty) -> Ty {
1599        BaseTy::Slice(ty).to_ty()
1600    }
1601
1602    pub fn mk_box(genv: GlobalEnv, deref_ty: Ty, alloc_ty: GenericArg) -> QueryResult<Ty> {
1603        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1604        let adt_def = genv.adt_def(def_id)?;
1605
1606        let args = List::from_arr([GenericArg::Ty(deref_ty), alloc_ty]);
1607
1608        let bty = BaseTy::adt(adt_def, args);
1609        Ok(Ty::indexed(bty, Expr::unit_struct(def_id)))
1610    }
1611
1612    pub fn mk_box_with_default_alloc(genv: GlobalEnv, deref_ty: Ty) -> QueryResult<Ty> {
1613        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1614
1615        let generics = genv.generics_of(def_id)?;
1616        let alloc_ty = genv
1617            .lower_type_of(generics.own_params[1].def_id)?
1618            .skip_binder();
1619        let alloc_ty = Refiner::default_for_item(genv, def_id)?.refine_generic_arg(
1620            &generics.own_params[1],
1621            &flux_rustc_bridge::ty::GenericArg::Ty(alloc_ty),
1622        )?;
1623
1624        Ty::mk_box(genv, deref_ty, alloc_ty)
1625    }
1626
1627    pub fn tuple(tys: impl Into<List<Ty>>) -> Ty {
1628        BaseTy::Tuple(tys.into()).to_ty()
1629    }
1630
1631    pub fn array(ty: Ty, c: Const) -> Ty {
1632        BaseTy::Array(ty, c).to_ty()
1633    }
1634
1635    pub fn closure(
1636        did: DefId,
1637        tys: impl Into<List<Ty>>,
1638        args: &flux_rustc_bridge::ty::GenericArgs,
1639        no_panic: bool,
1640    ) -> Ty {
1641        BaseTy::Closure(did, tys.into(), args.clone(), no_panic).to_ty()
1642    }
1643
1644    pub fn coroutine(
1645        did: DefId,
1646        resume_ty: Ty,
1647        upvar_tys: List<Ty>,
1648        args: flux_rustc_bridge::ty::GenericArgs,
1649    ) -> Ty {
1650        BaseTy::Coroutine(did, resume_ty, upvar_tys, args.clone()).to_ty()
1651    }
1652
1653    pub fn never() -> Ty {
1654        BaseTy::Never.to_ty()
1655    }
1656
1657    pub fn infer(vid: TyVid) -> Ty {
1658        TyKind::Infer(vid).intern()
1659    }
1660
1661    pub fn unconstr(&self) -> (Ty, Expr) {
1662        fn go(this: &Ty, preds: &mut Vec<Expr>) -> Ty {
1663            if let TyKind::Constr(pred, ty) = this.kind() {
1664                preds.push(pred.clone());
1665                go(ty, preds)
1666            } else {
1667                this.clone()
1668            }
1669        }
1670        let mut preds = vec![];
1671        (go(self, &mut preds), Expr::and_from_iter(preds))
1672    }
1673
1674    pub fn unblocked(&self) -> Ty {
1675        match self.kind() {
1676            TyKind::Blocked(ty) => ty.clone(),
1677            _ => self.clone(),
1678        }
1679    }
1680
1681    /// Whether the type is an `int` or a `uint`
1682    pub fn is_integral(&self) -> bool {
1683        self.as_bty_skipping_existentials()
1684            .map(BaseTy::is_integral)
1685            .unwrap_or_default()
1686    }
1687
1688    /// Whether the type is a `bool`
1689    pub fn is_bool(&self) -> bool {
1690        self.as_bty_skipping_existentials()
1691            .map(BaseTy::is_bool)
1692            .unwrap_or_default()
1693    }
1694
1695    /// Whether the type is a `char`
1696    pub fn is_char(&self) -> bool {
1697        self.as_bty_skipping_existentials()
1698            .map(BaseTy::is_char)
1699            .unwrap_or_default()
1700    }
1701
1702    pub fn is_uninit(&self) -> bool {
1703        matches!(self.kind(), TyKind::Uninit)
1704    }
1705
1706    pub fn is_box(&self) -> bool {
1707        self.as_bty_skipping_existentials()
1708            .map(BaseTy::is_box)
1709            .unwrap_or_default()
1710    }
1711
1712    pub fn is_struct(&self) -> bool {
1713        self.as_bty_skipping_existentials()
1714            .map(BaseTy::is_struct)
1715            .unwrap_or_default()
1716    }
1717
1718    pub fn is_array(&self) -> bool {
1719        self.as_bty_skipping_existentials()
1720            .map(BaseTy::is_array)
1721            .unwrap_or_default()
1722    }
1723
1724    pub fn is_slice(&self) -> bool {
1725        self.as_bty_skipping_existentials()
1726            .map(BaseTy::is_slice)
1727            .unwrap_or_default()
1728    }
1729
1730    pub fn as_bty_skipping_existentials(&self) -> Option<&BaseTy> {
1731        match self.kind() {
1732            TyKind::Indexed(bty, _) => Some(bty),
1733            TyKind::Exists(ty) => Some(ty.skip_binder_ref().as_bty_skipping_existentials()?),
1734            TyKind::Constr(_, ty) => ty.as_bty_skipping_existentials(),
1735            _ => None,
1736        }
1737    }
1738
1739    #[track_caller]
1740    pub fn expect_discr(&self) -> (&AdtDef, &Place) {
1741        if let TyKind::Discr(adt_def, place) = self.kind() {
1742            (adt_def, place)
1743        } else {
1744            tracked_span_bug!("expected discr")
1745        }
1746    }
1747
1748    #[track_caller]
1749    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg], &Expr) {
1750        if let TyKind::Indexed(BaseTy::Adt(adt_def, args), idx) = self.kind() {
1751            (adt_def, args, idx)
1752        } else {
1753            tracked_span_bug!("expected adt `{self:?}`")
1754        }
1755    }
1756
1757    #[track_caller]
1758    pub fn expect_tuple(&self) -> &[Ty] {
1759        if let TyKind::Indexed(BaseTy::Tuple(tys), _) = self.kind() {
1760            tys
1761        } else {
1762            tracked_span_bug!("expected tuple found `{self:?}` (kind: `{:?}`)", self.kind())
1763        }
1764    }
1765
1766    pub fn simplify_type(&self) -> Option<SimplifiedType> {
1767        self.as_bty_skipping_existentials()
1768            .and_then(BaseTy::simplify_type)
1769    }
1770}
1771
1772impl<'tcx> ToRustc<'tcx> for Ty {
1773    type T = rustc_middle::ty::Ty<'tcx>;
1774
1775    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
1776        match self.kind() {
1777            TyKind::Indexed(bty, _) => bty.to_rustc(tcx),
1778            TyKind::Exists(ty) => ty.skip_binder_ref().to_rustc(tcx),
1779            TyKind::Constr(_, ty) => ty.to_rustc(tcx),
1780            TyKind::Param(pty) => pty.to_ty(tcx),
1781            TyKind::StrgRef(re, _, ty) => {
1782                rustc_middle::ty::Ty::new_ref(
1783                    tcx,
1784                    re.to_rustc(tcx),
1785                    ty.to_rustc(tcx),
1786                    Mutability::Mut,
1787                )
1788            }
1789            TyKind::Infer(vid) => rustc_middle::ty::Ty::new_var(tcx, *vid),
1790            TyKind::Uninit
1791            | TyKind::Ptr(_, _)
1792            | TyKind::Discr(..)
1793            | TyKind::Downcast(..)
1794            | TyKind::Blocked(_) => bug!("TODO: to_rustc for `{self:?}`"),
1795        }
1796    }
1797}
1798
/// The structure of a [`Ty`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
pub enum TyKind {
    /// A base type refined by an index expression.
    Indexed(BaseTy, Expr),
    /// An existential type: the body is under a binder.
    Exists(Binder<Ty>),
    /// A type constrained by a predicate (see [`Ty::unconstr`]).
    Constr(Expr, Ty),
    Uninit,
    /// A strong reference to a location `Path`. [`ToRustc`] erases it to a `&mut` reference.
    StrgRef(Region, Path, Ty),
    Ptr(PtrKind, Path),
    /// This is a bit of a hack. We use this type internally to represent the result of
    /// [`Rvalue::Discriminant`] in a way that we can recover the necessary control information
    /// when checking a [`match`]. The hack is that we assume the discriminant remains the same
    /// from the creation of this type until we use it in a [`match`].
    ///
    /// [`Rvalue::Discriminant`]: flux_rustc_bridge::mir::Rvalue::Discriminant
    /// [`match`]: flux_rustc_bridge::mir::TerminatorKind::SwitchInt
    Discr(AdtDef, Place),
    /// A generic type parameter.
    Param(ParamTy),
    /// These only arise when you "narrow" an ADT down to a particular variant;
    /// either EXPLICITLY in a `match-of`, or IMPLICITLY when you access a field
    /// of a struct to "UNPACK" the struct.
    Downcast(AdtDef, GenericArgs, Ty, VariantIdx, List<Ty>),
    /// A wrapper removed by [`Ty::unblocked`].
    Blocked(Ty),
    /// A type that needs to be inferred by matching the signature against a rust signature.
    /// [`TyKind::Infer`] appear as an intermediate step during `conv` and should not be present in
    /// the final signature.
    Infer(TyVid),
}
1827
/// The kind of pointer in a [`TyKind::Ptr`].
#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum PtrKind {
    /// A pointer with the given region (NOTE(review): presumably a mutable reference; confirm).
    Mut(Region),
    /// NOTE(review): presumably a pointer owned through a `Box`. Confirm.
    Box,
}
1833
/// The unrefined shape of a type; it is refined by an index in [`TyKind::Indexed`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BaseTy {
    Int(IntTy),
    Uint(UintTy),
    Bool,
    Str,
    Char,
    Slice(Ty),
    Adt(AdtDef, GenericArgs),
    Float(FloatTy),
    RawPtr(Ty, Mutability),
    RawPtrMetadata(Ty),
    Ref(Region, Ty, Mutability),
    FnPtr(PolyFnSig),
    FnDef(DefId, GenericArgs),
    Tuple(List<Ty>),
    Alias(AliasKind, AliasTy),
    Array(Ty, Const),
    Never,
    /// A closure. Fields are: its def id, the types of its upvars, its generic args, and the
    /// `no_panic` flag (named after the parameter of [`Ty::closure`]).
    Closure(DefId, /* upvar_tys */ List<Ty>, flux_rustc_bridge::ty::GenericArgs, bool),
    Coroutine(
        DefId,
        /*resume_ty: */ Ty,
        /* upvar_tys: */ List<Ty>,
        flux_rustc_bridge::ty::GenericArgs,
    ),
    /// A trait object: its existential predicates and region.
    Dynamic(List<Binder<ExistentialPredicate>>, Region),
    Param(ParamTy),
    Infer(TyVid),
    Foreign(DefId),
    Pat,
}
1866
1867impl BaseTy {
1868    pub fn opaque(alias_ty: AliasTy) -> BaseTy {
1869        BaseTy::Alias(AliasKind::Opaque, alias_ty)
1870    }
1871
1872    pub fn projection(alias_ty: AliasTy) -> BaseTy {
1873        BaseTy::Alias(AliasKind::Projection, alias_ty)
1874    }
1875
1876    pub fn adt(adt_def: AdtDef, args: GenericArgs) -> BaseTy {
1877        BaseTy::Adt(adt_def, args)
1878    }
1879
1880    pub fn fn_def(def_id: DefId, args: impl Into<GenericArgs>) -> BaseTy {
1881        BaseTy::FnDef(def_id, args.into())
1882    }
1883
1884    pub fn from_primitive_str(s: &str) -> Option<BaseTy> {
1885        match s {
1886            "i8" => Some(BaseTy::Int(IntTy::I8)),
1887            "i16" => Some(BaseTy::Int(IntTy::I16)),
1888            "i32" => Some(BaseTy::Int(IntTy::I32)),
1889            "i64" => Some(BaseTy::Int(IntTy::I64)),
1890            "i128" => Some(BaseTy::Int(IntTy::I128)),
1891            "u8" => Some(BaseTy::Uint(UintTy::U8)),
1892            "u16" => Some(BaseTy::Uint(UintTy::U16)),
1893            "u32" => Some(BaseTy::Uint(UintTy::U32)),
1894            "u64" => Some(BaseTy::Uint(UintTy::U64)),
1895            "u128" => Some(BaseTy::Uint(UintTy::U128)),
1896            "f16" => Some(BaseTy::Float(FloatTy::F16)),
1897            "f32" => Some(BaseTy::Float(FloatTy::F32)),
1898            "f64" => Some(BaseTy::Float(FloatTy::F64)),
1899            "f128" => Some(BaseTy::Float(FloatTy::F128)),
1900            "isize" => Some(BaseTy::Int(IntTy::Isize)),
1901            "usize" => Some(BaseTy::Uint(UintTy::Usize)),
1902            "bool" => Some(BaseTy::Bool),
1903            "char" => Some(BaseTy::Char),
1904            "str" => Some(BaseTy::Str),
1905            _ => None,
1906        }
1907    }
1908
1909    /// If `self` is a primitive, return its [`Symbol`].
1910    pub fn primitive_symbol(&self) -> Option<Symbol> {
1911        match self {
1912            BaseTy::Bool => Some(sym::bool),
1913            BaseTy::Char => Some(sym::char),
1914            BaseTy::Float(f) => {
1915                match f {
1916                    FloatTy::F16 => Some(sym::f16),
1917                    FloatTy::F32 => Some(sym::f32),
1918                    FloatTy::F64 => Some(sym::f64),
1919                    FloatTy::F128 => Some(sym::f128),
1920                }
1921            }
1922            BaseTy::Int(f) => {
1923                match f {
1924                    IntTy::Isize => Some(sym::isize),
1925                    IntTy::I8 => Some(sym::i8),
1926                    IntTy::I16 => Some(sym::i16),
1927                    IntTy::I32 => Some(sym::i32),
1928                    IntTy::I64 => Some(sym::i64),
1929                    IntTy::I128 => Some(sym::i128),
1930                }
1931            }
1932            BaseTy::Uint(f) => {
1933                match f {
1934                    UintTy::Usize => Some(sym::usize),
1935                    UintTy::U8 => Some(sym::u8),
1936                    UintTy::U16 => Some(sym::u16),
1937                    UintTy::U32 => Some(sym::u32),
1938                    UintTy::U64 => Some(sym::u64),
1939                    UintTy::U128 => Some(sym::u128),
1940                }
1941            }
1942            BaseTy::Str => Some(sym::str),
1943            _ => None,
1944        }
1945    }
1946
1947    pub fn is_integral(&self) -> bool {
1948        matches!(self, BaseTy::Int(_) | BaseTy::Uint(_))
1949    }
1950
1951    pub fn is_signed(&self) -> bool {
1952        matches!(self, BaseTy::Int(_))
1953    }
1954
1955    pub fn is_unsigned(&self) -> bool {
1956        matches!(self, BaseTy::Uint(_))
1957    }
1958
1959    pub fn is_float(&self) -> bool {
1960        matches!(self, BaseTy::Float(_))
1961    }
1962
1963    pub fn is_bool(&self) -> bool {
1964        matches!(self, BaseTy::Bool)
1965    }
1966
1967    fn is_struct(&self) -> bool {
1968        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_struct())
1969    }
1970
1971    fn is_array(&self) -> bool {
1972        matches!(self, BaseTy::Array(..))
1973    }
1974
1975    fn is_slice(&self) -> bool {
1976        matches!(self, BaseTy::Slice(..))
1977    }
1978
1979    pub fn is_box(&self) -> bool {
1980        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_box())
1981    }
1982
1983    pub fn is_char(&self) -> bool {
1984        matches!(self, BaseTy::Char)
1985    }
1986
1987    pub fn is_str(&self) -> bool {
1988        matches!(self, BaseTy::Str)
1989    }
1990
1991    pub fn invariants(
1992        &self,
1993        tcx: TyCtxt,
1994        overflow_mode: OverflowMode,
1995    ) -> impl Iterator<Item = Invariant> {
1996        let (invariants, args) = match self {
1997            BaseTy::Adt(adt_def, args) => (adt_def.invariants().skip_binder(), &args[..]),
1998            BaseTy::Uint(uint_ty) => (uint_invariants(*uint_ty, overflow_mode), &[][..]),
1999            BaseTy::Int(int_ty) => (int_invariants(*int_ty, overflow_mode), &[][..]),
2000            BaseTy::Char => (char_invariants(), &[][..]),
2001            BaseTy::Slice(_) => (slice_invariants(overflow_mode), &[][..]),
2002            _ => (&[][..], &[][..]),
2003        };
2004        invariants
2005            .iter()
2006            .map(move |inv| EarlyBinder(inv).instantiate_ref(tcx, args, &[]))
2007    }
2008
2009    pub fn to_ty(&self) -> Ty {
2010        let sort = self.sort();
2011        if sort.is_unit() {
2012            Ty::indexed(self.clone(), Expr::unit())
2013        } else {
2014            Ty::exists(Binder::bind_with_sort(
2015                Ty::indexed(self.shift_in_escaping(1), Expr::nu()),
2016                sort,
2017            ))
2018        }
2019    }
2020
2021    pub fn to_subset_ty_ctor(&self) -> SubsetTyCtor {
2022        let sort = self.sort();
2023        Binder::bind_with_sort(SubsetTy::trivial(self.clone(), Expr::nu()), sort)
2024    }
2025
2026    #[track_caller]
2027    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg]) {
2028        if let BaseTy::Adt(adt_def, args) = self {
2029            (adt_def, args)
2030        } else {
2031            tracked_span_bug!("expected adt `{self:?}`")
2032        }
2033    }
2034
2035    /// A type is an *atom* if it is "self-delimiting", i.e., it has a clear boundary
2036    /// when printed. This is used to avoid unnecessary parenthesis when pretty printing.
2037    pub fn is_atom(&self) -> bool {
2038        // (nilehmann) I'm not sure about this list, please adjust if you get any odd behavior
2039        matches!(
2040            self,
2041            BaseTy::Int(_)
2042                | BaseTy::Uint(_)
2043                | BaseTy::Slice(_)
2044                | BaseTy::Bool
2045                | BaseTy::Char
2046                | BaseTy::Str
2047                | BaseTy::Adt(..)
2048                | BaseTy::Tuple(..)
2049                | BaseTy::Param(_)
2050                | BaseTy::Array(..)
2051                | BaseTy::Never
2052                | BaseTy::Closure(..)
2053                | BaseTy::Coroutine(..)
2054                // opaque alias are atoms the way we print them now, but they won't
2055                // be if we print them as `impl Trait`
2056                | BaseTy::Alias(..)
2057        )
2058    }
2059
2060    /// Similar to [`rustc_infer::infer::canonical::ir::fast_reject::simplify_type`].
2061    ///
2062    /// This implementation is currently incomplete, so it should only be used in contexts
2063    /// where completeness is not required. Currently, it's used to find incoherent
2064    /// implementations when resolving associated constants. In this context, incompleteness
2065    /// is acceptable since the worst case outcome is simply failing to resolve a type-relative
2066    /// constant.
2067    fn simplify_type(&self) -> Option<SimplifiedType> {
2068        match self {
2069            BaseTy::Bool => Some(SimplifiedType::Bool),
2070            BaseTy::Char => Some(SimplifiedType::Char),
2071            BaseTy::Int(int_type) => Some(SimplifiedType::Int(*int_type)),
2072            BaseTy::Uint(uint_type) => Some(SimplifiedType::Uint(*uint_type)),
2073            BaseTy::Float(float_type) => Some(SimplifiedType::Float(*float_type)),
2074            BaseTy::Adt(def, _) => Some(SimplifiedType::Adt(def.did())),
2075            BaseTy::Str => Some(SimplifiedType::Str),
2076            BaseTy::Array(..) => Some(SimplifiedType::Array),
2077            BaseTy::Slice(..) => Some(SimplifiedType::Slice),
2078            BaseTy::RawPtr(_, mutbl) => Some(SimplifiedType::Ptr(*mutbl)),
2079            BaseTy::Ref(_, _, mutbl) => Some(SimplifiedType::Ref(*mutbl)),
2080            BaseTy::FnDef(def_id, _) | BaseTy::Closure(def_id, ..) => {
2081                Some(SimplifiedType::Closure(*def_id))
2082            }
2083            BaseTy::Coroutine(def_id, ..) => Some(SimplifiedType::Coroutine(*def_id)),
2084            BaseTy::Never => Some(SimplifiedType::Never),
2085            BaseTy::Tuple(tys) => Some(SimplifiedType::Tuple(tys.len())),
2086            BaseTy::FnPtr(poly_fn_sig) => {
2087                Some(SimplifiedType::Function(poly_fn_sig.skip_binder_ref().inputs().len()))
2088            }
2089            BaseTy::Foreign(def_id) => Some(SimplifiedType::Foreign(*def_id)),
2090            BaseTy::RawPtrMetadata(_)
2091            | BaseTy::Alias(..)
2092            | BaseTy::Param(_)
2093            | BaseTy::Dynamic(..)
2094            | BaseTy::Infer(_) => None,
2095            BaseTy::Pat => todo!(),
2096        }
2097    }
2098}
2099
2100impl<'tcx> ToRustc<'tcx> for BaseTy {
2101    type T = rustc_middle::ty::Ty<'tcx>;
2102
2103    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2104        use rustc_middle::ty;
2105        match self {
2106            BaseTy::Int(i) => ty::Ty::new_int(tcx, *i),
2107            BaseTy::Uint(i) => ty::Ty::new_uint(tcx, *i),
2108            BaseTy::Param(pty) => pty.to_ty(tcx),
2109            BaseTy::Slice(ty) => ty::Ty::new_slice(tcx, ty.to_rustc(tcx)),
2110            BaseTy::Bool => tcx.types.bool,
2111            BaseTy::Char => tcx.types.char,
2112            BaseTy::Str => tcx.types.str_,
2113            BaseTy::Adt(adt_def, args) => {
2114                let did = adt_def.did();
2115                let adt_def = tcx.adt_def(did);
2116                let args = args.to_rustc(tcx);
2117                ty::Ty::new_adt(tcx, adt_def, args)
2118            }
2119            BaseTy::FnDef(def_id, args) => {
2120                let args = args.to_rustc(tcx);
2121                ty::Ty::new_fn_def(tcx, *def_id, args)
2122            }
2123            BaseTy::Float(f) => ty::Ty::new_float(tcx, *f),
2124            BaseTy::RawPtr(ty, mutbl) => ty::Ty::new_ptr(tcx, ty.to_rustc(tcx), *mutbl),
2125            BaseTy::Ref(re, ty, mutbl) => {
2126                ty::Ty::new_ref(tcx, re.to_rustc(tcx), ty.to_rustc(tcx), *mutbl)
2127            }
2128            BaseTy::FnPtr(poly_sig) => ty::Ty::new_fn_ptr(tcx, poly_sig.to_rustc(tcx)),
2129            BaseTy::Tuple(tys) => {
2130                let ts = tys.iter().map(|ty| ty.to_rustc(tcx)).collect_vec();
2131                ty::Ty::new_tup(tcx, &ts)
2132            }
2133            BaseTy::Alias(kind, alias_ty) => {
2134                ty::Ty::new_alias(tcx, kind.to_rustc(tcx), alias_ty.to_rustc(tcx))
2135            }
2136            BaseTy::Array(ty, n) => {
2137                let ty = ty.to_rustc(tcx);
2138                let n = n.to_rustc(tcx);
2139                ty::Ty::new_array_with_const_len(tcx, ty, n)
2140            }
2141            BaseTy::Never => tcx.types.never,
2142            BaseTy::Closure(did, _, args, _) => ty::Ty::new_closure(tcx, *did, args.to_rustc(tcx)),
2143            BaseTy::Dynamic(exi_preds, re) => {
2144                let preds: Vec<_> = exi_preds
2145                    .iter()
2146                    .map(|pred| pred.to_rustc(tcx))
2147                    .collect_vec();
2148                let preds = tcx.mk_poly_existential_predicates(&preds);
2149                ty::Ty::new_dynamic(tcx, preds, re.to_rustc(tcx))
2150            }
2151            BaseTy::Coroutine(did, _, _, args) => {
2152                ty::Ty::new_coroutine(tcx, *did, args.to_rustc(tcx))
2153            }
2154            BaseTy::Infer(ty_vid) => ty::Ty::new_var(tcx, *ty_vid),
2155            BaseTy::Foreign(def_id) => ty::Ty::new_foreign(tcx, *def_id),
2156            BaseTy::RawPtrMetadata(ty) => {
2157                ty::Ty::new_ptr(
2158                    tcx,
2159                    ty.to_rustc(tcx),
2160                    RawPtrKind::FakeForPtrMetadata.to_mutbl_lossy(),
2161                )
2162            }
2163            BaseTy::Pat => todo!(),
2164        }
2165    }
2166}
2167
/// The data of an alias type ([`BaseTy::Alias`]), e.g., an opaque type or an associated-type
/// projection.
#[derive(
    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct AliasTy {
    /// The definition the alias refers to.
    pub def_id: DefId,
    /// The generic arguments of the alias. For projections, the first one is the `Self` type.
    pub args: GenericArgs,
    /// Holds the refinement-arguments for opaque-types; empty for projections
    pub refine_args: RefineArgs,
}
2177
2178impl AliasTy {
2179    pub fn new(def_id: DefId, args: GenericArgs, refine_args: RefineArgs) -> Self {
2180        AliasTy { args, refine_args, def_id }
2181    }
2182}
2183
2184/// This methods work only with associated type projections (i.e., no opaque types)
2185impl AliasTy {
2186    pub fn self_ty(&self) -> SubsetTyCtor {
2187        self.args[0].expect_base().clone()
2188    }
2189
2190    pub fn with_self_ty(&self, self_ty: SubsetTyCtor) -> Self {
2191        Self {
2192            def_id: self.def_id,
2193            args: [GenericArg::Base(self_ty)]
2194                .into_iter()
2195                .chain(self.args.iter().skip(1).cloned())
2196                .collect(),
2197            refine_args: self.refine_args.clone(),
2198        }
2199    }
2200}
2201
2202impl<'tcx> ToRustc<'tcx> for AliasTy {
2203    type T = rustc_middle::ty::AliasTy<'tcx>;
2204
2205    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2206        rustc_middle::ty::AliasTy::new(tcx, self.def_id, self.args.to_rustc(tcx))
2207    }
2208}
2209
/// Refinement arguments: one expression per refinement parameter of an item, with parents'
/// parameters first (see `RefineArgsExt::for_item`).
pub type RefineArgs = List<Expr>;
2211
2212#[extension(pub trait RefineArgsExt)]
2213impl RefineArgs {
2214    fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<RefineArgs> {
2215        Self::for_item(genv, def_id, |param, index| {
2216            Ok(Expr::var(Var::EarlyParam(EarlyReftParam {
2217                index: index as u32,
2218                name: param.name(),
2219            })))
2220        })
2221    }
2222
2223    fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult<RefineArgs>
2224    where
2225        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<Expr>,
2226    {
2227        let reft_generics = genv.refinement_generics_of(def_id)?;
2228        let count = reft_generics.count();
2229        let mut args = Vec::with_capacity(count);
2230        reft_generics.fill_item(genv, &mut args, &mut mk)?;
2231        Ok(List::from_vec(args))
2232    }
2233}
2234
/// A type constructor meant to be used as a generic argument of [kind base]. This is just an alias
/// to [`Binder<SubsetTy>`], but we expect the binder to have a single bound variable of the sort of
/// the underlying [base type].
///
/// [kind base]: GenericParamDefKind::Base
/// [base type]: SubsetTy::bty
pub type SubsetTyCtor = Binder<SubsetTy>;
2242
2243impl SubsetTyCtor {
2244    pub fn as_bty_skipping_binder(&self) -> &BaseTy {
2245        &self.as_ref().skip_binder().bty
2246    }
2247
2248    pub fn to_ty(&self) -> Ty {
2249        let sort = self.sort();
2250        if sort.is_unit() {
2251            self.replace_bound_reft(&Expr::unit()).to_ty()
2252        } else if let Some(def_id) = sort.is_unit_adt() {
2253            self.replace_bound_reft(&Expr::unit_struct(def_id)).to_ty()
2254        } else {
2255            Ty::exists(self.as_ref().map(SubsetTy::to_ty))
2256        }
2257    }
2258
2259    pub fn to_ty_ctor(&self) -> TyCtor {
2260        self.as_ref().map(SubsetTy::to_ty)
2261    }
2262}
2263
/// A subset type is a simplified version of a type that has the form `{b[e] | p}` where `b` is a
/// [`BaseTy`], `e` a refinement index, and `p` a predicate.
///
/// These are mainly found under a [`Binder`] with a single variable of the base type's sort. This
/// can be interpreted as a type constructor or an existential type. For example, under a binder with a
/// variable `v` of sort `int`, we can interpret `{i32[v] | v > 0}` as:
/// - A lambda `λv:int. {i32[v] | v > 0}` that "constructs" types when applied to ints, or
/// - An existential type `∃v:int. {i32[v] | v > 0}`.
///
/// This second interpretation is the reason we call this a subset type, i.e., the type `∃v. {b[v] | p}`
/// corresponds to the subset of values of type `b` whose index satisfies `p`. These are the types
/// written as `B{v: p}` in the surface syntax and correspond to the types supported in other
/// refinement type systems like Liquid Haskell (with the difference that we are explicit
/// about separating refinements from program values via an index).
///
/// The main purpose for subset types is to be used as generic arguments of [kind base] when
/// interpreted as type constructors. They have two key properties that make them suitable
/// for this:
///
/// 1. **Syntactic Restriction**: Subset types are syntactically restricted, making it easier to
///    relate them structurally (e.g., for subtyping). For instance, given two types `S<λv. T1>` and
///    `S<λv. T2>`, if `T1` and `T2` are subset types, we know they match structurally (at least
///    shallowly). In particular, the syntactic restriction rules out complex types like
///    `S<λv. (i32[v], i32[v])>`, simplifying some operations.
///
/// 2. **Eager Canonicalization**: Subset types can be eagerly canonicalized via [*strengthening*]
///    during substitution. For example, suppose we have a function:
///    ```text
///    fn foo<T>(x: T[@a], y: { T[@b] | b == a }) { }
///    ```
///    If we instantiate `T` with `λv. { i32[v] | v > 0}`, after substitution and applying the
///    lambda (the indexing syntax `T[a]` corresponds to an application of the lambda), we get:
///    ```text
///    fn foo(x: {i32[@a] | a > 0}, y: { { i32[@b] | b > 0 } | b == a }) { }
///    ```
///    Via *strengthening* we can canonicalize this to
///    ```text
///    fn foo(x: {i32[@a] | a > 0}, y: { i32[@b] | b == a && b > 0 }) { }
///    ```
///    As a result, we can guarantee the syntactic restriction through substitution.
///
/// [kind base]: GenericParamDefKind::Base
/// [*strengthening*]: https://arxiv.org/pdf/2010.07763.pdf
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub struct SubsetTy {
    /// The base type `b` in the subset type `{b[e] | p}`.
    ///
    /// **NOTE:** This is mostly going to be under a [`Binder`]. It is not yet clear to me whether
    /// this [`BaseTy`] should be able to mention variables in the binder. In general, in a type
    /// `∃v. {b[e] | p}`, it's fine to mention `v` inside `b`, but since [`SubsetTy`] is meant to
    /// facilitate syntactic manipulation we may want to restrict this.
    pub bty: BaseTy,
    /// The refinement index `e` in the subset type `{b[e] | p}`. This can be an arbitrary expression,
    /// which makes manipulation easier. However, since this is mostly found under a binder, we expect
    /// it to be [`Expr::nu()`].
    pub idx: Expr,
    /// The predicate `p` in the subset type `{b[e] | p}`.
    pub pred: Expr,
}
2323
2324impl SubsetTy {
2325    pub fn new(bty: BaseTy, idx: impl Into<Expr>, pred: impl Into<Expr>) -> Self {
2326        Self { bty, idx: idx.into(), pred: pred.into() }
2327    }
2328
2329    pub fn trivial(bty: BaseTy, idx: impl Into<Expr>) -> Self {
2330        Self::new(bty, idx, Expr::tt())
2331    }
2332
2333    pub fn strengthen(&self, pred: impl Into<Expr>) -> Self {
2334        let this = self.clone();
2335        let pred = Expr::and(this.pred, pred).simplify(&SnapshotMap::default());
2336        Self { bty: this.bty, idx: this.idx, pred }
2337    }
2338
2339    pub fn to_ty(&self) -> Ty {
2340        let bty = self.bty.clone();
2341        if self.pred.is_trivially_true() {
2342            Ty::indexed(bty, &self.idx)
2343        } else {
2344            Ty::constr(&self.pred, Ty::indexed(bty, &self.idx))
2345        }
2346    }
2347}
2348
2349impl<'tcx> ToRustc<'tcx> for SubsetTy {
2350    type T = rustc_middle::ty::Ty<'tcx>;
2351
2352    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::Ty<'tcx> {
2353        self.bty.to_rustc(tcx)
2354    }
2355}
2356
/// A generic argument: a type, a subset type constructor (for parameters of [kind base]), a
/// region, or a constant.
///
/// [kind base]: GenericParamDefKind::Base
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub enum GenericArg {
    Ty(Ty),
    /// The argument for a parameter of kind base.
    Base(SubsetTyCtor),
    Lifetime(Region),
    Const(Const),
}
2364
2365impl GenericArg {
2366    #[track_caller]
2367    pub fn expect_type(&self) -> &Ty {
2368        if let GenericArg::Ty(ty) = self {
2369            ty
2370        } else {
2371            bug!("expected `rty::GenericArg::Ty`, found `{self:?}`")
2372        }
2373    }
2374
2375    #[track_caller]
2376    pub fn expect_base(&self) -> &SubsetTyCtor {
2377        if let GenericArg::Base(ctor) = self {
2378            ctor
2379        } else {
2380            bug!("expected `rty::GenericArg::Base`, found `{self:?}`")
2381        }
2382    }
2383
2384    pub fn from_param_def(param: &GenericParamDef) -> Self {
2385        match param.kind {
2386            GenericParamDefKind::Type { .. } => {
2387                let param_ty = ParamTy { index: param.index, name: param.name };
2388                GenericArg::Ty(Ty::param(param_ty))
2389            }
2390            GenericParamDefKind::Base { .. } => {
2391                // λv. T[v]
2392                let param_ty = ParamTy { index: param.index, name: param.name };
2393                GenericArg::Base(Binder::bind_with_sort(
2394                    SubsetTy::trivial(BaseTy::Param(param_ty), Expr::nu()),
2395                    Sort::Param(param_ty),
2396                ))
2397            }
2398            GenericParamDefKind::Lifetime => {
2399                let region = EarlyParamRegion { index: param.index, name: param.name };
2400                GenericArg::Lifetime(Region::ReEarlyParam(region))
2401            }
2402            GenericParamDefKind::Const { .. } => {
2403                let param_const = ParamConst { index: param.index, name: param.name };
2404                let kind = ConstKind::Param(param_const);
2405                GenericArg::Const(Const { kind })
2406            }
2407        }
2408    }
2409
2410    /// Creates a `GenericArgs` from the definition of generic parameters, by calling a closure to
2411    /// obtain arg. The closures get to observe the `GenericArgs` as they're being built, which can
2412    /// be used to correctly replace defaults of generic parameters.
2413    pub fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk_kind: F) -> QueryResult<GenericArgs>
2414    where
2415        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2416    {
2417        let defs = genv.generics_of(def_id)?;
2418        let count = defs.count();
2419        let mut args = Vec::with_capacity(count);
2420        Self::fill_item(genv, &mut args, &defs, &mut mk_kind)?;
2421        Ok(List::from_vec(args))
2422    }
2423
2424    pub fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<GenericArgs> {
2425        Self::for_item(genv, def_id, |param, _| GenericArg::from_param_def(param))
2426    }
2427
2428    fn fill_item<F>(
2429        genv: GlobalEnv,
2430        args: &mut Vec<GenericArg>,
2431        generics: &Generics,
2432        mk_kind: &mut F,
2433    ) -> QueryResult<()>
2434    where
2435        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2436    {
2437        if let Some(def_id) = generics.parent {
2438            let parent_generics = genv.generics_of(def_id)?;
2439            Self::fill_item(genv, args, &parent_generics, mk_kind)?;
2440        }
2441        for param in &generics.own_params {
2442            let kind = mk_kind(param, args);
2443            tracked_span_assert_eq!(param.index as usize, args.len());
2444            args.push(kind);
2445        }
2446        Ok(())
2447    }
2448}
2449
2450impl From<TyOrBase> for GenericArg {
2451    fn from(v: TyOrBase) -> Self {
2452        match v {
2453            TyOrBase::Ty(ty) => GenericArg::Ty(ty),
2454            TyOrBase::Base(ctor) => GenericArg::Base(ctor),
2455        }
2456    }
2457}
2458
2459impl<'tcx> ToRustc<'tcx> for GenericArg {
2460    type T = rustc_middle::ty::GenericArg<'tcx>;
2461
2462    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2463        use rustc_middle::ty;
2464        match self {
2465            GenericArg::Ty(ty) => ty::GenericArg::from(ty.to_rustc(tcx)),
2466            GenericArg::Base(ctor) => ty::GenericArg::from(ctor.skip_binder_ref().to_rustc(tcx)),
2467            GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)),
2468            GenericArg::Const(c) => ty::GenericArg::from(c.to_rustc(tcx)),
2469        }
2470    }
2471}
2472
/// An interned list of [`GenericArg`]s.
pub type GenericArgs = List<GenericArg>;
2474
2475#[extension(pub trait GenericArgsExt)]
2476impl GenericArgs {
2477    #[track_caller]
2478    fn box_args(&self) -> (&Ty, &GenericArg) {
2479        if let [GenericArg::Ty(deref), alloc] = &self[..] {
2480            (deref, alloc)
2481        } else {
2482            bug!("invalid generic arguments for box");
2483        }
2484    }
2485
2486    // We can't implement [`ToRustc`] because of coherence so we add it here
2487    fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::GenericArgsRef<'tcx> {
2488        tcx.mk_args_from_iter(self.iter().map(|arg| arg.to_rustc(tcx)))
2489    }
2490
2491    fn rebase_onto(
2492        &self,
2493        tcx: &TyCtxt,
2494        source_ancestor: DefId,
2495        target_args: &GenericArgs,
2496    ) -> List<GenericArg> {
2497        let defs = tcx.generics_of(source_ancestor);
2498        target_args
2499            .iter()
2500            .chain(self.iter().skip(defs.count()))
2501            .cloned()
2502            .collect()
2503    }
2504}
2505
/// Either a [`Ty`] or a [`SubsetTyCtor`]; converts into the [`GenericArg`] variant with the same
/// payload.
#[derive(Debug)]
pub enum TyOrBase {
    Ty(Ty),
    Base(SubsetTyCtor),
}
2511
2512impl TyOrBase {
2513    pub fn into_ty(self) -> Ty {
2514        match self {
2515            TyOrBase::Ty(ty) => ty,
2516            TyOrBase::Base(ctor) => ctor.to_ty(),
2517        }
2518    }
2519
2520    #[track_caller]
2521    pub fn expect_base(self) -> SubsetTyCtor {
2522        match self {
2523            TyOrBase::Base(ctor) => ctor,
2524            TyOrBase::Ty(_) => tracked_span_bug!("expected `TyOrBase::Base`"),
2525        }
2526    }
2527
2528    pub fn as_base(self) -> Option<SubsetTyCtor> {
2529        match self {
2530            TyOrBase::Base(ctor) => Some(ctor),
2531            TyOrBase::Ty(_) => None,
2532        }
2533    }
2534}
2535
/// Either a [`Ty`] or a [`TyCtor`] (a type under a binder).
#[derive(Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum TyOrCtor {
    Ty(Ty),
    Ctor(TyCtor),
}
2541
2542impl TyOrCtor {
2543    #[track_caller]
2544    pub fn expect_ctor(self) -> TyCtor {
2545        match self {
2546            TyOrCtor::Ctor(ctor) => ctor,
2547            TyOrCtor::Ty(_) => tracked_span_bug!("expected `TyOrCtor::Ctor`"),
2548        }
2549    }
2550
2551    pub fn expect_subset_ty_ctor(self) -> SubsetTyCtor {
2552        self.expect_ctor().map(|ty| {
2553            if let canonicalize::CanonicalTy::Constr(constr_ty) = ty.shallow_canonicalize()
2554                && let TyKind::Indexed(bty, idx) = constr_ty.ty().kind()
2555                && idx.is_nu()
2556            {
2557                SubsetTy::new(bty.clone(), Expr::nu(), constr_ty.pred())
2558            } else {
2559                tracked_span_bug!()
2560            }
2561        })
2562    }
2563
2564    pub fn to_ty(&self) -> Ty {
2565        match self {
2566            TyOrCtor::Ctor(ctor) => ctor.to_ty(),
2567            TyOrCtor::Ty(ty) => ty.clone(),
2568        }
2569    }
2570}
2571
2572impl From<TyOrBase> for TyOrCtor {
2573    fn from(v: TyOrBase) -> Self {
2574        match v {
2575            TyOrBase::Ty(ty) => TyOrCtor::Ty(ty),
2576            TyOrBase::Base(ctor) => TyOrCtor::Ctor(ctor.to_ty_ctor()),
2577        }
2578    }
2579}
2580
2581impl CoroutineObligPredicate {
2582    pub fn to_poly_fn_sig(&self) -> PolyFnSig {
2583        let vars = vec![];
2584
2585        let resume_ty = &self.resume_ty;
2586        let env_ty = Ty::coroutine(
2587            self.def_id,
2588            resume_ty.clone(),
2589            self.upvar_tys.clone(),
2590            self.args.clone(),
2591        );
2592
2593        let inputs = List::from_arr([env_ty, resume_ty.clone()]);
2594        let output =
2595            Binder::bind_with_vars(FnOutput::new(self.output.clone(), vec![]), List::empty());
2596
2597        PolyFnSig::bind_with_vars(
2598            FnSig::new(
2599                Safety::Safe,
2600                rustc_abi::ExternAbi::RustCall,
2601                List::empty(),
2602                inputs,
2603                output,
2604                Expr::ff(),
2605                false,
2606            ),
2607            List::from(vars),
2608        )
2609    }
2610}
2611
2612impl RefinementGenerics {
2613    pub fn count(&self) -> usize {
2614        self.parent_count + self.own_params.len()
2615    }
2616
2617    pub fn own_count(&self) -> usize {
2618        self.own_params.len()
2619    }
2620}
2621
2622impl EarlyBinder<RefinementGenerics> {
2623    pub fn parent(&self) -> Option<DefId> {
2624        self.skip_binder_ref().parent
2625    }
2626
2627    pub fn parent_count(&self) -> usize {
2628        self.skip_binder_ref().parent_count
2629    }
2630
2631    pub fn count(&self) -> usize {
2632        self.skip_binder_ref().count()
2633    }
2634
2635    pub fn own_count(&self) -> usize {
2636        self.skip_binder_ref().own_count()
2637    }
2638
2639    pub fn own_param_at(&self, index: usize) -> EarlyBinder<RefineParam> {
2640        self.as_ref().map(|this| this.own_params[index].clone())
2641    }
2642
2643    pub fn param_at(
2644        &self,
2645        param_index: usize,
2646        genv: GlobalEnv,
2647    ) -> QueryResult<EarlyBinder<RefineParam>> {
2648        if let Some(index) = param_index.checked_sub(self.parent_count()) {
2649            Ok(self.own_param_at(index))
2650        } else {
2651            let parent = self.parent().expect("parent_count > 0 but no parent?");
2652            genv.refinement_generics_of(parent)?
2653                .param_at(param_index, genv)
2654        }
2655    }
2656
2657    pub fn iter_own_params(&self) -> impl Iterator<Item = EarlyBinder<RefineParam>> + use<'_> {
2658        self.skip_binder_ref()
2659            .own_params
2660            .iter()
2661            .cloned()
2662            .map(EarlyBinder)
2663    }
2664
2665    pub fn fill_item<F, R>(&self, genv: GlobalEnv, vec: &mut Vec<R>, mk: &mut F) -> QueryResult
2666    where
2667        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<R>,
2668    {
2669        if let Some(def_id) = self.parent() {
2670            genv.refinement_generics_of(def_id)?
2671                .fill_item(genv, vec, mk)?;
2672        }
2673        for param in self.iter_own_params() {
2674            vec.push(mk(param, vec.len())?);
2675        }
2676        Ok(())
2677    }
2678}
2679
2680impl EarlyBinder<GenericPredicates> {
2681    pub fn predicates(&self) -> EarlyBinder<List<Clause>> {
2682        EarlyBinder(self.0.predicates.clone())
2683    }
2684}
2685
2686impl EarlyBinder<FuncSort> {
2687    /// See [`subst::GenericsSubstForSort`]
2688    pub fn instantiate_func_sort<E>(
2689        self,
2690        sort_for_param: impl FnMut(ParamTy) -> Result<Sort, E>,
2691    ) -> Result<FuncSort, E> {
2692        self.0.try_fold_with(&mut subst::GenericsSubstFolder::new(
2693            subst::GenericsSubstForSort { sort_for_param },
2694            &[],
2695        ))
2696    }
2697}
2698
2699impl VariantSig {
2700    pub fn new(
2701        adt_def: AdtDef,
2702        args: GenericArgs,
2703        fields: List<Ty>,
2704        idx: Expr,
2705        requires: List<Expr>,
2706    ) -> Self {
2707        VariantSig { adt_def, args, fields, idx, requires }
2708    }
2709
2710    pub fn fields(&self) -> &[Ty] {
2711        &self.fields
2712    }
2713
2714    pub fn ret(&self) -> Ty {
2715        let bty = BaseTy::Adt(self.adt_def.clone(), self.args.clone());
2716        let idx = self.idx.clone();
2717        Ty::indexed(bty, idx)
2718    }
2719}
2720
2721impl FnSig {
2722    pub fn new(
2723        safety: Safety,
2724        abi: rustc_abi::ExternAbi,
2725        requires: List<Expr>,
2726        inputs: List<Ty>,
2727        output: Binder<FnOutput>,
2728        no_panic: Expr,
2729        lifted: bool,
2730    ) -> Self {
2731        FnSig { safety, abi, requires, inputs, output, no_panic, lifted }
2732    }
2733
2734    pub fn requires(&self) -> &[Expr] {
2735        &self.requires
2736    }
2737
2738    pub fn inputs(&self) -> &[Ty] {
2739        &self.inputs
2740    }
2741
2742    pub fn no_panic(&self) -> Expr {
2743        self.no_panic.clone()
2744    }
2745
2746    pub fn output(&self) -> Binder<FnOutput> {
2747        self.output.clone()
2748    }
2749}
2750
2751impl<'tcx> ToRustc<'tcx> for FnSig {
2752    type T = rustc_middle::ty::FnSig<'tcx>;
2753
2754    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2755        tcx.mk_fn_sig(
2756            self.inputs().iter().map(|ty| ty.to_rustc(tcx)),
2757            self.output().as_ref().skip_binder().to_rustc(tcx),
2758            false,
2759            self.safety,
2760            self.abi,
2761        )
2762    }
2763}
2764
2765impl FnOutput {
2766    pub fn new(ret: Ty, ensures: impl Into<List<Ensures>>) -> Self {
2767        Self { ret, ensures: ensures.into() }
2768    }
2769}
2770
2771impl<'tcx> ToRustc<'tcx> for FnOutput {
2772    type T = rustc_middle::ty::Ty<'tcx>;
2773
2774    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2775        self.ret.to_rustc(tcx)
2776    }
2777}
2778
2779impl AdtDef {
2780    pub fn new(
2781        rustc: ty::AdtDef,
2782        sort_def: AdtSortDef,
2783        invariants: Vec<Invariant>,
2784        opaque: bool,
2785    ) -> Self {
2786        AdtDef(Interned::new(AdtDefData { invariants, sort_def, opaque, rustc }))
2787    }
2788
2789    pub fn did(&self) -> DefId {
2790        self.0.rustc.did()
2791    }
2792
2793    pub fn sort_def(&self) -> &AdtSortDef {
2794        &self.0.sort_def
2795    }
2796
2797    pub fn sort(&self, args: &[GenericArg]) -> Sort {
2798        self.sort_def().to_sort(args)
2799    }
2800
2801    pub fn is_box(&self) -> bool {
2802        self.0.rustc.is_box()
2803    }
2804
2805    pub fn is_enum(&self) -> bool {
2806        self.0.rustc.is_enum()
2807    }
2808
2809    pub fn is_struct(&self) -> bool {
2810        self.0.rustc.is_struct()
2811    }
2812
2813    pub fn is_union(&self) -> bool {
2814        self.0.rustc.is_union()
2815    }
2816
2817    pub fn variants(&self) -> &IndexSlice<VariantIdx, VariantDef> {
2818        self.0.rustc.variants()
2819    }
2820
2821    pub fn variant(&self, idx: VariantIdx) -> &VariantDef {
2822        self.0.rustc.variant(idx)
2823    }
2824
2825    pub fn invariants(&self) -> EarlyBinder<&[Invariant]> {
2826        EarlyBinder(&self.0.invariants)
2827    }
2828
2829    pub fn discriminants(&self) -> impl Iterator<Item = (VariantIdx, u128)> + '_ {
2830        self.0.rustc.discriminants()
2831    }
2832
2833    pub fn is_opaque(&self) -> bool {
2834        self.0.opaque
2835    }
2836}
2837
2838impl EarlyBinder<PolyVariant> {
2839    // The field_idx is `Some(i)` when we have the `i`-th field of a `union`, in which case,
2840    // the `inputs` are _just_ the `i`-th type (and not all the types...)
2841    pub fn to_poly_fn_sig(&self, field_idx: Option<crate::FieldIdx>) -> EarlyBinder<PolyFnSig> {
2842        self.as_ref().map(|poly_variant| {
2843            poly_variant.as_ref().map(|variant| {
2844                let ret = variant.ret().shift_in_escaping(1);
2845                let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
2846                let inputs = match field_idx {
2847                    None => variant.fields.clone(),
2848                    Some(i) => List::singleton(variant.fields[i.index()].clone()),
2849                };
2850                FnSig::new(
2851                    Safety::Safe,
2852                    rustc_abi::ExternAbi::Rust,
2853                    variant.requires.clone(),
2854                    inputs,
2855                    output,
2856                    Expr::tt(),
2857                    false,
2858                )
2859            })
2860        })
2861    }
2862}
2863
2864impl TyKind {
2865    fn intern(self) -> Ty {
2866        Ty(Interned::new(self))
2867    }
2868}
2869
2870/// returns the same invariants as for `usize` which is the length of a slice
2871fn slice_invariants(overflow_mode: OverflowMode) -> &'static [Invariant] {
2872    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2873        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2874    });
2875    static OVERFLOW: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2876        [
2877            Invariant {
2878                pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2879            },
2880            Invariant {
2881                pred: Binder::bind_with_sort(
2882                    Expr::le(Expr::nu(), Expr::uint_max(UintTy::Usize)),
2883                    Sort::Int,
2884                ),
2885            },
2886        ]
2887    });
2888    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2889        &*OVERFLOW
2890    } else {
2891        &*DEFAULT
2892    }
2893}
2894
2895fn uint_invariants(uint_ty: UintTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2896    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2897        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2898    });
2899
2900    static OVERFLOW: LazyLock<UnordMap<UintTy, [Invariant; 2]>> = LazyLock::new(|| {
2901        UINT_TYS
2902            .into_iter()
2903            .map(|uint_ty| {
2904                let invariants = [
2905                    Invariant {
2906                        pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2907                    },
2908                    Invariant {
2909                        pred: Binder::bind_with_sort(
2910                            Expr::le(Expr::nu(), Expr::uint_max(uint_ty)),
2911                            Sort::Int,
2912                        ),
2913                    },
2914                ];
2915                (uint_ty, invariants)
2916            })
2917            .collect()
2918    });
2919    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2920        &OVERFLOW[&uint_ty]
2921    } else {
2922        &*DEFAULT
2923    }
2924}
2925
2926fn char_invariants() -> &'static [Invariant] {
2927    static INVARIANTS: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2928        [
2929            Invariant {
2930                pred: Binder::bind_with_sort(
2931                    Expr::le(
2932                        Expr::cast(Sort::Char, Sort::Int, Expr::nu()),
2933                        Expr::constant((char::MAX as u32).into()),
2934                    ),
2935                    Sort::Int,
2936                ),
2937            },
2938            Invariant {
2939                pred: Binder::bind_with_sort(
2940                    Expr::le(Expr::zero(), Expr::cast(Sort::Char, Sort::Int, Expr::nu())),
2941                    Sort::Int,
2942                ),
2943            },
2944        ]
2945    });
2946    &*INVARIANTS
2947}
2948
2949fn int_invariants(int_ty: IntTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2950    static DEFAULT: [Invariant; 0] = [];
2951
2952    static OVERFLOW: LazyLock<UnordMap<IntTy, [Invariant; 2]>> = LazyLock::new(|| {
2953        INT_TYS
2954            .into_iter()
2955            .map(|int_ty| {
2956                let invariants = [
2957                    Invariant {
2958                        pred: Binder::bind_with_sort(
2959                            Expr::ge(Expr::nu(), Expr::int_min(int_ty)),
2960                            Sort::Int,
2961                        ),
2962                    },
2963                    Invariant {
2964                        pred: Binder::bind_with_sort(
2965                            Expr::le(Expr::nu(), Expr::int_max(int_ty)),
2966                            Sort::Int,
2967                        ),
2968                    },
2969                ];
2970                (int_ty, invariants)
2971            })
2972            .collect()
2973    });
2974    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2975        &OVERFLOW[&int_ty]
2976    } else {
2977        &DEFAULT
2978    }
2979}
2980
// Interning support for the data behind `AdtDef` and `Ty` (both are built via `Interned::new`).
impl_internable!(AdtDefData, AdtSortDefData, TyKind);
// Interning support for slices of these element types; presumably this is what allows
// `List<T>` to be built for them (e.g. `List<Ty>`, `List<Clause>`). TODO confirm against
// the `flux_arc_interner` macros.
impl_slice_internable!(
    Ty,
    GenericArg,
    Ensures,
    InferMode,
    Sort,
    SortArg,
    GenericParamDef,
    TraitRef,
    Binder<ExistentialPredicate>,
    Clause,
    PolyVariant,
    Invariant,
    RefineParam,
    FluxDefId,
    SortParamKind,
    AssocReft
);
3000
/// Pattern for an indexed signed integer type: `Int!(int_ty, idx)` matches
/// `TyKind::Indexed(BaseTy::Int(int_ty), idx)`.
///
/// Paths are `$crate`-qualified so the macro works where `TyKind`/`BaseTy` are not imported.
#[macro_export]
macro_rules! _Int {
    ($int_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Int($int_ty), $idxs)
    };
}
pub use crate::_Int as Int;

/// Pattern for an indexed unsigned integer type: `Uint!(uint_ty, idx)` matches
/// `TyKind::Indexed(BaseTy::Uint(uint_ty), idx)`.
///
/// Paths are `$crate`-qualified so the macro works where `TyKind`/`BaseTy` are not imported.
#[macro_export]
macro_rules! _Uint {
    ($uint_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Uint($uint_ty), $idxs)
    };
}
pub use crate::_Uint as Uint;

/// Pattern for an indexed `bool`: `Bool!(idx)` matches `TyKind::Indexed(BaseTy::Bool, idx)`.
///
/// Paths are `$crate`-qualified so the macro works where `TyKind`/`BaseTy` are not imported.
#[macro_export]
macro_rules! _Bool {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Bool, $idxs)
    };
}
pub use crate::_Bool as Bool;

/// Pattern for an indexed `char`: `Char!(idx)` matches `TyKind::Indexed(BaseTy::Char, idx)`.
///
/// Paths are `$crate`-qualified so the macro works where `TyKind`/`BaseTy` are not imported.
#[macro_export]
macro_rules! _Char {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Char, $idxs)
    };
}
pub use crate::_Char as Char;

/// Pattern for a reference type: `Ref!(pats..)` matches
/// `TyKind::Indexed(BaseTy::Ref(pats..), _)`, ignoring the index.
#[macro_export]
macro_rules! _Ref {
    ($($pats:pat),+ $(,)?) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Ref($($pats),+), _)
    };
}
pub use crate::_Ref as Ref;
3040
/// Per-owner side tables of sort information attached to fhir nodes.
///
/// Every `ItemLocalMap` table is keyed by the `fhir::ItemLocalId` of a node belonging to
/// `owner`; the `LocalTableInContext(Mut)` accessors assert that the `FhirId` used for a lookup
/// or insertion has this owner. NOTE(review): the `Wfck` prefix suggests the tables are filled
/// by well-formedness checking — confirm against the producer.
pub struct WfckResults {
    /// The owner whose fhir nodes the local tables refer to.
    pub owner: FluxOwnerId,
    /// Sort of each refinement parameter, keyed by `fhir::ParamId`.
    param_sorts: UnordMap<fhir::ParamId, Sort>,
    /// A `Sort` per binary-operation node.
    bin_op_sorts: ItemLocalMap<Sort>,
    /// Sort arguments per function-application node.
    fn_app_sorts: ItemLocalMap<List<SortArg>>,
    /// Coercions recorded per node (a node may have several).
    coercions: ItemLocalMap<Vec<Coercion>>,
    /// Field projection resolved per node.
    field_projs: ItemLocalMap<FieldProj>,
    /// A `Sort` per node.
    node_sorts: ItemLocalMap<Sort>,
    /// Record constructor resolved per node.
    record_ctors: ItemLocalMap<RecordCtor>,
}

/// A coercion recorded in `WfckResults::coercions`.
/// NOTE(review): meaning of the `DefId` (the item injected into / projected out of) is inferred
/// from the variant names — confirm.
#[derive(Clone, Copy, Debug)]
pub enum Coercion {
    Inject(DefId),
    Project(DefId),
}

/// The constructor recorded in `WfckResults::record_ctors`: either a struct (by `DefId`) or a
/// raw pointer.
#[derive(Clone, Copy, Debug)]
pub enum RecordCtor {
    Struct(DefId),
    RawPtr,
}

/// Map keyed by the owner-relative id of a fhir node.
pub type ItemLocalMap<T> = UnordMap<fhir::ItemLocalId, T>;

/// Read-only view of one `WfckResults` table that checks the owner of every `FhirId` it is
/// queried with.
#[derive(Debug)]
pub struct LocalTableInContext<'a, T> {
    owner: FluxOwnerId,
    data: &'a ItemLocalMap<T>,
}

/// Mutable view of one `WfckResults` table that checks the owner of every `FhirId` it is
/// written with.
pub struct LocalTableInContextMut<'a, T> {
    owner: FluxOwnerId,
    data: &'a mut ItemLocalMap<T>,
}
3076
3077impl WfckResults {
3078    pub fn new(owner: impl Into<FluxOwnerId>) -> Self {
3079        Self {
3080            owner: owner.into(),
3081            param_sorts: UnordMap::default(),
3082            bin_op_sorts: ItemLocalMap::default(),
3083            coercions: ItemLocalMap::default(),
3084            field_projs: ItemLocalMap::default(),
3085            node_sorts: ItemLocalMap::default(),
3086            record_ctors: ItemLocalMap::default(),
3087            fn_app_sorts: ItemLocalMap::default(),
3088        }
3089    }
3090
3091    pub fn param_sorts_mut(&mut self) -> &mut UnordMap<fhir::ParamId, Sort> {
3092        &mut self.param_sorts
3093    }
3094
3095    pub fn param_sorts(&self) -> &UnordMap<fhir::ParamId, Sort> {
3096        &self.param_sorts
3097    }
3098
3099    pub fn bin_op_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3100        LocalTableInContextMut { owner: self.owner, data: &mut self.bin_op_sorts }
3101    }
3102
3103    pub fn fn_app_sorts_mut(&mut self) -> LocalTableInContextMut<'_, List<SortArg>> {
3104        LocalTableInContextMut { owner: self.owner, data: &mut self.fn_app_sorts }
3105    }
3106
3107    pub fn fn_app_sorts(&self) -> LocalTableInContext<'_, List<SortArg>> {
3108        LocalTableInContext { owner: self.owner, data: &self.fn_app_sorts }
3109    }
3110
3111    pub fn bin_op_sorts(&self) -> LocalTableInContext<'_, Sort> {
3112        LocalTableInContext { owner: self.owner, data: &self.bin_op_sorts }
3113    }
3114
3115    pub fn coercions_mut(&mut self) -> LocalTableInContextMut<'_, Vec<Coercion>> {
3116        LocalTableInContextMut { owner: self.owner, data: &mut self.coercions }
3117    }
3118
3119    pub fn coercions(&self) -> LocalTableInContext<'_, Vec<Coercion>> {
3120        LocalTableInContext { owner: self.owner, data: &self.coercions }
3121    }
3122
3123    pub fn field_projs_mut(&mut self) -> LocalTableInContextMut<'_, FieldProj> {
3124        LocalTableInContextMut { owner: self.owner, data: &mut self.field_projs }
3125    }
3126
3127    pub fn field_projs(&self) -> LocalTableInContext<'_, FieldProj> {
3128        LocalTableInContext { owner: self.owner, data: &self.field_projs }
3129    }
3130
3131    pub fn node_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3132        LocalTableInContextMut { owner: self.owner, data: &mut self.node_sorts }
3133    }
3134
3135    pub fn node_sorts(&self) -> LocalTableInContext<'_, Sort> {
3136        LocalTableInContext { owner: self.owner, data: &self.node_sorts }
3137    }
3138
3139    pub fn record_ctors_mut(&mut self) -> LocalTableInContextMut<'_, RecordCtor> {
3140        LocalTableInContextMut { owner: self.owner, data: &mut self.record_ctors }
3141    }
3142
3143    pub fn record_ctors(&self) -> LocalTableInContext<'_, RecordCtor> {
3144        LocalTableInContext { owner: self.owner, data: &self.record_ctors }
3145    }
3146}
3147
3148impl<T> LocalTableInContextMut<'_, T> {
3149    pub fn insert(&mut self, fhir_id: FhirId, value: T) {
3150        tracked_span_assert_eq!(self.owner, fhir_id.owner);
3151        self.data.insert(fhir_id.local_id, value);
3152    }
3153}
3154
3155impl<'a, T> LocalTableInContext<'a, T> {
3156    pub fn get(&self, fhir_id: FhirId) -> Option<&'a T> {
3157        tracked_span_assert_eq!(self.owner, fhir_id.owner);
3158        self.data.get(&fhir_id.local_id)
3159    }
3160}
3161
3162fn can_auto_strong(fn_sig: &PolyFnSig) -> bool {
3163    struct RegionDetector {
3164        has_region: bool,
3165    }
3166
3167    impl fold::TypeFolder for RegionDetector {
3168        fn fold_region(&mut self, re: &Region) -> Region {
3169            self.has_region = true;
3170            *re
3171        }
3172    }
3173    let mut detector = RegionDetector { has_region: false };
3174    fn_sig
3175        .skip_binder_ref()
3176        .output()
3177        .skip_binder_ref()
3178        .ret
3179        .fold_with(&mut detector);
3180
3181    !detector.has_region
3182}
3183/// The [`auto_strong`] function transforms function signatures by automatically converting
3184/// mutable reference parameters into strong references with associated ensures clauses. This
3185/// transformation is applied only when the function signature does not already contain region
3186/// variables in its return type.
3187///
3188/// Specifically, given a source function of type
3189///
3190///    fn (x: &mut InnerTy) -> bool
3191///
3192/// By default the above gives us an `rty::FnSig`
3193///
3194///    forall<>. fn (x: &mut InnerTy) -> bool
3195///
3196/// Which this function then transforms to
3197///
3198///     forall<l0: Loc>. fn (x: &strg<l0:InnerTy>) -> bool ensures l0:InnerTy
3199pub fn auto_strong(
3200    genv: GlobalEnv,
3201    def_id: impl IntoQueryParam<DefId>,
3202    fn_sig: PolyFnSig,
3203) -> PolyFnSig {
3204    // TODO(auto-strong): we only *really* need the first check `can_auto_strong` here.
3205    // The other two skip `auto-strong` as doing it breaks various downstream things
3206    // that should be fixed.
3207    if !can_auto_strong(&fn_sig)
3208        || matches!(genv.def_kind(def_id), rustc_hir::def::DefKind::Closure)
3209        || !fn_sig.skip_binder_ref().lifted
3210    {
3211        return fn_sig;
3212    }
3213    let kind = BoundReftKind::Anon;
3214    let mut vars = fn_sig.vars().to_vec();
3215    let fn_sig = fn_sig.skip_binder();
3216    // new list of (bound_var, inner_ty)
3217    let mut strg_bvars = vec![];
3218    // new list of input types
3219    let mut strg_inputs = vec![];
3220    // 1. Traverse inputs collecting strong locations
3221    for ty in &fn_sig.inputs {
3222        let strg_ty = if let TyKind::Indexed(BaseTy::Ref(re, inner_ty, Mutability::Mut), _) =
3223            ty.kind()
3224            && !inner_ty.is_slice()
3225        // TODO(auto-strong): including `slice` breaks `tock` for some reason we should replicate in our own tests...
3226        {
3227            // if input is &mut InnerTy create a new bound var `loc` for the strong location
3228            let var = {
3229                let idx = vars.len() + strg_bvars.len();
3230                BoundVar::from_usize(idx)
3231            };
3232            strg_bvars.push((var, inner_ty.clone()));
3233            let loc = Loc::Var(Var::Bound(INNERMOST, BoundReft { var, kind }));
3234            // and transform to &strg<loc:InnerTy>
3235            Ty::strg_ref(*re, Path::new(loc, List::empty()), inner_ty.clone())
3236        } else {
3237            // else leave input type unchanged
3238            ty.clone()
3239        };
3240        strg_inputs.push(strg_ty);
3241    }
3242    // 2. Add bound vars for strong locations
3243    for _ in 0..strg_bvars.len() {
3244        vars.push(BoundVariableKind::Refine(Sort::Loc, InferMode::EVar, kind));
3245    }
3246    // 3. Add ensures for strong locations
3247    let output = fn_sig.output.map(|out| {
3248        let mut ens = out.ensures.to_vec();
3249        for (var, inner_ty) in strg_bvars {
3250            let loc = Loc::Var(Var::Bound(INNERMOST.shifted_in(1), BoundReft { var, kind }));
3251            let path = Path::new(loc, List::empty());
3252            ens.push(Ensures::Type(path, inner_ty.shift_in_escaping(1)));
3253        }
3254        FnOutput { ensures: List::from_vec(ens), ..out }
3255    });
3256
3257    // 4. Reconstruct fn sig with new inputs and output and vars
3258    let fn_sig = FnSig { inputs: List::from_vec(strg_inputs), output, ..fn_sig };
3259    Binder::bind_with_vars(fn_sig, vars.into())
3260}