flux_middle/rty/
mod.rs

1//! Defines how flux represents refinement types internally. Definitions in this module are used
2//! during refinement type checking. A couple of important differences between definitions in this
3//! module and in [`crate::fhir`] are:
4//!
5//! * Types in this module use debruijn indices to represent local binders.
6//! * Data structures are interned so they can be cheaply cloned.
7mod binder;
8pub mod canonicalize;
9mod expr;
10pub mod fold;
11pub mod normalize;
12mod pretty;
13pub mod refining;
14pub mod region_matching;
15pub mod subst;
16use std::{borrow::Cow, cmp::Ordering, fmt, hash::Hash, sync::LazyLock};
17
18pub use binder::{Binder, BoundReftKind, BoundVariableKind, BoundVariableKinds, EarlyBinder};
19use bitflags::bitflags;
20pub use expr::{
21    AggregateKind, AliasReft, BinOp, BoundReft, Constant, Ctor, ESpan, EVid, EarlyReftParam, Expr,
22    ExprKind, FieldProj, HoleKind, InternalFuncKind, KVar, KVid, Lambda, Loc, Name, NameProvenance,
23    Path, PrettyMap, PrettyVar, Real, SpecFuncKind, UnOp, Var,
24};
25pub use flux_arc_interner::List;
26use flux_arc_interner::{Interned, impl_internable, impl_slice_internable};
27use flux_common::{bug, tracked_span_assert_eq, tracked_span_bug};
28use flux_config::OverflowMode;
29use flux_macros::{TypeFoldable, TypeVisitable};
30pub use flux_rustc_bridge::ty::{
31    AliasKind, BoundRegion, BoundRegionKind, BoundVar, Const, ConstKind, ConstVid, DebruijnIndex,
32    EarlyParamRegion, LateParamRegion, LateParamRegionKind,
33    Region::{self, *},
34    RegionVid,
35};
36use flux_rustc_bridge::{
37    ToRustc,
38    mir::{Place, RawPtrKind},
39    ty::{self, GenericArgsExt as _, VariantDef},
40};
41use itertools::Itertools;
42pub use normalize::{FuncInfo, NormalizedDefns, local_deps};
43use refining::{Refine as _, Refiner};
44use rustc_abi;
45pub use rustc_abi::{FIRST_VARIANT, VariantIdx};
46use rustc_data_structures::{fx::FxIndexMap, snapshot_map::SnapshotMap, unord::UnordMap};
47use rustc_hir::{LangItem, Safety, def_id::DefId};
48use rustc_index::{IndexSlice, IndexVec, newtype_index};
49use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable, extension};
50pub use rustc_middle::{
51    mir::Mutability,
52    ty::{AdtFlags, ClosureKind, FloatTy, IntTy, ParamConst, ParamTy, ScalarInt, UintTy},
53};
54use rustc_middle::{
55    query::IntoQueryParam,
56    ty::{TyCtxt, fast_reject::SimplifiedType},
57};
58use rustc_span::{DUMMY_SP, Span, Symbol, sym, symbol::kw};
59use rustc_type_ir::Upcast as _;
60pub use rustc_type_ir::{INNERMOST, TyVid};
61
62use self::fold::TypeFoldable;
63pub use crate::fhir::InferMode;
64use crate::{
65    LocalDefId,
66    def_id::{FluxDefId, FluxLocalDefId},
67    fhir::{self, FhirId, FluxOwnerId},
68    global_env::GlobalEnv,
69    pretty::{Pretty, PrettyCx},
70    queries::{QueryErr, QueryResult},
71    rty::subst::SortSubst,
72};
73
74/// The definition of the data sort automatically generated for a struct or enum.
75#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
76pub struct AdtSortDef(Interned<AdtSortDefData>);
77
78#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
79pub struct AdtSortVariant {
80    /// The list of field names as declared in the `#[flux::refined_by(...)]` annotation
81    field_names: Vec<Symbol>,
82    /// The sort of each of the fields. Note that these can contain [sort variables]. Methods used
83    /// to access these sorts guarantee they are properly instantiated.
84    ///
85    /// [sort variables]: Sort::Var
86    sorts: List<Sort>,
87}
88
89impl AdtSortVariant {
90    pub fn new(fields: Vec<(Symbol, Sort)>) -> Self {
91        let (field_names, sorts) = fields.into_iter().unzip();
92        AdtSortVariant { field_names, sorts: List::from_vec(sorts) }
93    }
94
95    pub fn fields(&self) -> usize {
96        self.sorts.len()
97    }
98
99    pub fn field_names(&self) -> &Vec<Symbol> {
100        &self.field_names
101    }
102
103    pub fn sort_by_field_name(&self, args: &[Sort]) -> FxIndexMap<Symbol, Sort> {
104        std::iter::zip(&self.field_names, &self.sorts.fold_with(&mut SortSubst::new(args)))
105            .map(|(name, sort)| (*name, sort.clone()))
106            .collect()
107    }
108
109    pub fn field_by_name(
110        &self,
111        def_id: DefId,
112        args: &[Sort],
113        name: Symbol,
114    ) -> Option<(FieldProj, Sort)> {
115        let idx = self.field_names.iter().position(|it| name == *it)?;
116        let proj = FieldProj::Adt { def_id, field: idx as u32 };
117        let sort = self.sorts[idx].fold_with(&mut SortSubst::new(args));
118        Some((proj, sort))
119    }
120
121    pub fn field_sorts(&self, args: &[Sort]) -> List<Sort> {
122        self.sorts.fold_with(&mut SortSubst::new(args))
123    }
124
125    pub fn field_sorts_instantiate_identity(&self) -> List<Sort> {
126        self.sorts.clone()
127    }
128
129    pub fn projections(&self, def_id: DefId) -> impl Iterator<Item = FieldProj> {
130        (0..self.fields()).map(move |i| FieldProj::Adt { def_id, field: i as u32 })
131    }
132}
133
134#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
135struct AdtSortDefData {
136    /// [`DefId`] of the struct or enum this data sort is associated to.
137    def_id: DefId,
138    /// The list of the type parameters used in the `#[flux::refined_by(..)]` annotation.
139    ///
140    /// See [`fhir::RefinedBy::sort_params`] for more details. This is a version of that but using
141    /// [`ParamTy`] instead of [`DefId`].
142    ///
143    /// The length of this list corresponds to the number of sort variables bound by this definition.
144    params: Vec<ParamTy>,
145    /// A vec of variants of the ADT;
146    /// - a `struct` sort -- used for types with a `refined_by` has a single variant;
147    /// - a `reflected` sort -- used for `reflected` enums have multiple variants
148    variants: IndexVec<VariantIdx, AdtSortVariant>,
149    is_reflected: bool,
150    is_struct: bool,
151}
152
153impl AdtSortDef {
154    pub fn new(
155        def_id: DefId,
156        params: Vec<ParamTy>,
157        variants: IndexVec<VariantIdx, AdtSortVariant>,
158        is_reflected: bool,
159        is_struct: bool,
160    ) -> Self {
161        Self(Interned::new(AdtSortDefData { def_id, params, variants, is_reflected, is_struct }))
162    }
163
164    pub fn did(&self) -> DefId {
165        self.0.def_id
166    }
167
168    pub fn variant(&self, idx: VariantIdx) -> &AdtSortVariant {
169        &self.0.variants[idx]
170    }
171
172    pub fn variants(&self) -> &IndexSlice<VariantIdx, AdtSortVariant> {
173        &self.0.variants
174    }
175
176    pub fn opt_struct_variant(&self) -> Option<&AdtSortVariant> {
177        if self.is_struct() { Some(self.struct_variant()) } else { None }
178    }
179
180    #[track_caller]
181    pub fn struct_variant(&self) -> &AdtSortVariant {
182        tracked_span_assert_eq!(self.0.is_struct, true);
183        &self.0.variants[FIRST_VARIANT]
184    }
185
186    pub fn is_reflected(&self) -> bool {
187        self.0.is_reflected
188    }
189
190    pub fn is_struct(&self) -> bool {
191        self.0.is_struct
192    }
193
194    pub fn to_sort(&self, args: &[GenericArg]) -> Sort {
195        let sorts = self
196            .filter_generic_args(args)
197            .map(|arg| arg.expect_base().sort())
198            .collect();
199
200        Sort::App(SortCtor::Adt(self.clone()), sorts)
201    }
202
203    /// Given a list of generic args, returns an iterator of the generic arguments that should be
204    /// mapped to sorts for instantiation.
205    pub fn filter_generic_args<'a, A>(&'a self, args: &'a [A]) -> impl Iterator<Item = &'a A> + 'a {
206        self.0.params.iter().map(|p| &args[p.index as usize])
207    }
208
209    pub fn identity_args(&self) -> List<Sort> {
210        (0..self.0.params.len())
211            .map(|i| Sort::Var(ParamSort::from(i)))
212            .collect()
213    }
214
215    /// Gives the number of sort variables bound by this definition.
216    pub fn param_count(&self) -> usize {
217        self.0.params.len()
218    }
219}
220
221#[derive(Debug, Clone, Default, Encodable, Decodable)]
222pub struct Generics {
223    pub parent: Option<DefId>,
224    pub parent_count: usize,
225    pub own_params: List<GenericParamDef>,
226    pub has_self: bool,
227}
228
229impl Generics {
230    pub fn count(&self) -> usize {
231        self.parent_count + self.own_params.len()
232    }
233
234    pub fn own_default_count(&self) -> usize {
235        self.own_params
236            .iter()
237            .filter(|param| {
238                match param.kind {
239                    GenericParamDefKind::Type { has_default }
240                    | GenericParamDefKind::Const { has_default }
241                    | GenericParamDefKind::Base { has_default } => has_default,
242                    GenericParamDefKind::Lifetime => false,
243                }
244            })
245            .count()
246    }
247
248    pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult<GenericParamDef> {
249        if let Some(index) = param_index.checked_sub(self.parent_count) {
250            Ok(self.own_params[index].clone())
251        } else {
252            let parent = self.parent.expect("parent_count > 0 but no parent?");
253            genv.generics_of(parent)?.param_at(param_index, genv)
254        }
255    }
256
257    pub fn const_params(&self, genv: GlobalEnv) -> QueryResult<Vec<(ParamConst, Sort)>> {
258        let mut res = vec![];
259        for i in 0..self.count() {
260            let param = self.param_at(i, genv)?;
261            if let GenericParamDefKind::Const { .. } = param.kind
262                && let Some(sort) = genv.sort_of_def_id(param.def_id)?
263            {
264                let param_const = ParamConst { name: param.name, index: param.index };
265                res.push((param_const, sort));
266            }
267        }
268        Ok(res)
269    }
270}
271
272#[derive(Debug, Clone, TyEncodable, TyDecodable)]
273pub struct RefinementGenerics {
274    pub parent: Option<DefId>,
275    pub parent_count: usize,
276    pub own_params: List<RefineParam>,
277}
278
279#[derive(
280    PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
281)]
282pub struct RefineParam {
283    pub sort: Sort,
284    pub name: Symbol,
285    pub mode: InferMode,
286}
287
288#[derive(Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
289pub struct GenericParamDef {
290    pub kind: GenericParamDefKind,
291    pub def_id: DefId,
292    pub index: u32,
293    pub name: Symbol,
294}
295
296#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
297pub enum GenericParamDefKind {
298    Type { has_default: bool },
299    Base { has_default: bool },
300    Lifetime,
301    Const { has_default: bool },
302}
303
304pub const SELF_PARAM_TY: ParamTy = ParamTy { index: 0, name: kw::SelfUpper };
305
306#[derive(Debug, Clone, TyEncodable, TyDecodable)]
307pub struct GenericPredicates {
308    pub parent: Option<DefId>,
309    pub predicates: List<Clause>,
310}
311
312#[derive(
313    Debug, PartialEq, Eq, Hash, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
314)]
315pub struct Clause {
316    kind: Binder<ClauseKind>,
317}
318
319impl Clause {
320    pub fn new(vars: impl Into<List<BoundVariableKind>>, kind: ClauseKind) -> Self {
321        Clause { kind: Binder::bind_with_vars(kind, vars.into()) }
322    }
323
324    pub fn kind(&self) -> Binder<ClauseKind> {
325        self.kind.clone()
326    }
327
328    fn as_trait_clause(&self) -> Option<Binder<TraitPredicate>> {
329        let clause = self.kind();
330        if let ClauseKind::Trait(trait_clause) = clause.skip_binder_ref() {
331            Some(clause.rebind(trait_clause.clone()))
332        } else {
333            None
334        }
335    }
336
337    pub fn as_projection_clause(&self) -> Option<Binder<ProjectionPredicate>> {
338        let clause = self.kind();
339        if let ClauseKind::Projection(proj_clause) = clause.skip_binder_ref() {
340            Some(clause.rebind(proj_clause.clone()))
341        } else {
342            None
343        }
344    }
345
346    // FIXME(nilehmann) we should deal with the binder in all the places this is used instead of
347    // blindly skipping it here
348    pub fn kind_skipping_binder(&self) -> ClauseKind {
349        self.kind.clone().skip_binder()
350    }
351
352    /// Group `Fn` trait clauses with their corresponding `FnOnce::Output` projection
353    /// predicate. This assumes there's exactly one corresponding projection predicate and will
354    /// crash otherwise.
355    pub fn split_off_fn_trait_clauses(
356        genv: GlobalEnv,
357        clauses: &Clauses,
358    ) -> (Vec<Clause>, Vec<Binder<FnTraitPredicate>>) {
359        let mut fn_trait_clauses = vec![];
360        let mut fn_trait_output_clauses = vec![];
361        let mut rest = vec![];
362        for clause in clauses {
363            if let Some(trait_clause) = clause.as_trait_clause()
364                && let Some(kind) = genv.tcx().fn_trait_kind_from_def_id(trait_clause.def_id())
365            {
366                fn_trait_clauses.push((kind, trait_clause));
367            } else if let Some(proj_clause) = clause.as_projection_clause()
368                && genv.is_fn_output(proj_clause.projection_def_id())
369            {
370                fn_trait_output_clauses.push(proj_clause);
371            } else {
372                rest.push(clause.clone());
373            }
374        }
375        let fn_trait_clauses = fn_trait_clauses
376            .into_iter()
377            .map(|(kind, fn_trait_clause)| {
378                let mut candidates = vec![];
379                for fn_trait_output_clause in &fn_trait_output_clauses {
380                    if fn_trait_output_clause.self_ty() == fn_trait_clause.self_ty() {
381                        candidates.push(fn_trait_output_clause.clone());
382                    }
383                }
384                tracked_span_assert_eq!(candidates.len(), 1);
385                let proj_pred = candidates.pop().unwrap().skip_binder();
386                fn_trait_clause.map(|fn_trait_clause| {
387                    FnTraitPredicate {
388                        kind,
389                        self_ty: fn_trait_clause.self_ty().to_ty(),
390                        tupled_args: fn_trait_clause.trait_ref.args[1].expect_base().to_ty(),
391                        output: proj_pred.term.to_ty(),
392                    }
393                })
394            })
395            .collect_vec();
396        (rest, fn_trait_clauses)
397    }
398}
399
400impl<'tcx> ToRustc<'tcx> for Clause {
401    type T = rustc_middle::ty::Clause<'tcx>;
402
403    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
404        self.kind.to_rustc(tcx).upcast(tcx)
405    }
406}
407
408impl From<Binder<ClauseKind>> for Clause {
409    fn from(kind: Binder<ClauseKind>) -> Self {
410        Clause { kind }
411    }
412}
413
414pub type Clauses = List<Clause>;
415
416#[derive(
417    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
418)]
419pub enum ClauseKind {
420    Trait(TraitPredicate),
421    Projection(ProjectionPredicate),
422    RegionOutlives(RegionOutlivesPredicate),
423    TypeOutlives(TypeOutlivesPredicate),
424    ConstArgHasType(Const, Ty),
425}
426
427impl<'tcx> ToRustc<'tcx> for ClauseKind {
428    type T = rustc_middle::ty::ClauseKind<'tcx>;
429
430    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
431        match self {
432            ClauseKind::Trait(trait_predicate) => {
433                rustc_middle::ty::ClauseKind::Trait(trait_predicate.to_rustc(tcx))
434            }
435            ClauseKind::Projection(projection_predicate) => {
436                rustc_middle::ty::ClauseKind::Projection(projection_predicate.to_rustc(tcx))
437            }
438            ClauseKind::RegionOutlives(outlives_predicate) => {
439                rustc_middle::ty::ClauseKind::RegionOutlives(outlives_predicate.to_rustc(tcx))
440            }
441            ClauseKind::TypeOutlives(outlives_predicate) => {
442                rustc_middle::ty::ClauseKind::TypeOutlives(outlives_predicate.to_rustc(tcx))
443            }
444            ClauseKind::ConstArgHasType(constant, ty) => {
445                rustc_middle::ty::ClauseKind::ConstArgHasType(
446                    constant.to_rustc(tcx),
447                    ty.to_rustc(tcx),
448                )
449            }
450        }
451    }
452}
453
454#[derive(Eq, PartialEq, Hash, Clone, Debug, TyEncodable, TyDecodable)]
455pub struct OutlivesPredicate<T>(pub T, pub Region);
456
457pub type TypeOutlivesPredicate = OutlivesPredicate<Ty>;
458pub type RegionOutlivesPredicate = OutlivesPredicate<Region>;
459
460impl<'tcx, V: ToRustc<'tcx>> ToRustc<'tcx> for OutlivesPredicate<V> {
461    type T = rustc_middle::ty::OutlivesPredicate<'tcx, V::T>;
462
463    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
464        rustc_middle::ty::OutlivesPredicate(self.0.to_rustc(tcx), self.1.to_rustc(tcx))
465    }
466}
467
468#[derive(
469    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
470)]
471pub struct TraitPredicate {
472    pub trait_ref: TraitRef,
473}
474
475impl TraitPredicate {
476    fn self_ty(&self) -> SubsetTyCtor {
477        self.trait_ref.self_ty()
478    }
479}
480
481impl<'tcx> ToRustc<'tcx> for TraitPredicate {
482    type T = rustc_middle::ty::TraitPredicate<'tcx>;
483
484    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
485        rustc_middle::ty::TraitPredicate {
486            polarity: rustc_middle::ty::PredicatePolarity::Positive,
487            trait_ref: self.trait_ref.to_rustc(tcx),
488        }
489    }
490}
491
492pub type PolyTraitPredicate = Binder<TraitPredicate>;
493
494impl PolyTraitPredicate {
495    fn def_id(&self) -> DefId {
496        self.skip_binder_ref().trait_ref.def_id
497    }
498
499    fn self_ty(&self) -> Binder<SubsetTyCtor> {
500        self.clone().map(|predicate| predicate.self_ty())
501    }
502}
503
504#[derive(
505    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
506)]
507pub struct TraitRef {
508    pub def_id: DefId,
509    pub args: GenericArgs,
510}
511
512impl TraitRef {
513    pub fn self_ty(&self) -> SubsetTyCtor {
514        self.args[0].expect_base().clone()
515    }
516}
517
518impl<'tcx> ToRustc<'tcx> for TraitRef {
519    type T = rustc_middle::ty::TraitRef<'tcx>;
520
521    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
522        rustc_middle::ty::TraitRef::new(tcx, self.def_id, self.args.to_rustc(tcx))
523    }
524}
525
526pub type PolyTraitRef = Binder<TraitRef>;
527
528impl PolyTraitRef {
529    pub fn def_id(&self) -> DefId {
530        self.as_ref().skip_binder().def_id
531    }
532}
533
534#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
535pub enum ExistentialPredicate {
536    Trait(ExistentialTraitRef),
537    Projection(ExistentialProjection),
538    AutoTrait(DefId),
539}
540
541pub type PolyExistentialPredicate = Binder<ExistentialPredicate>;
542
543impl ExistentialPredicate {
544    /// See [`rustc_middle::ty::ExistentialPredicateStableCmpExt`]
545    pub fn stable_cmp(&self, tcx: TyCtxt, other: &Self) -> Ordering {
546        match (self, other) {
547            (ExistentialPredicate::Trait(_), ExistentialPredicate::Trait(_)) => Ordering::Equal,
548            (ExistentialPredicate::Projection(a), ExistentialPredicate::Projection(b)) => {
549                tcx.def_path_hash(a.def_id)
550                    .cmp(&tcx.def_path_hash(b.def_id))
551            }
552            (ExistentialPredicate::AutoTrait(a), ExistentialPredicate::AutoTrait(b)) => {
553                tcx.def_path_hash(*a).cmp(&tcx.def_path_hash(*b))
554            }
555            (ExistentialPredicate::Trait(_), _) => Ordering::Less,
556            (ExistentialPredicate::Projection(_), ExistentialPredicate::Trait(_)) => {
557                Ordering::Greater
558            }
559            (ExistentialPredicate::Projection(_), _) => Ordering::Less,
560            (ExistentialPredicate::AutoTrait(_), _) => Ordering::Greater,
561        }
562    }
563}
564
565impl<'tcx> ToRustc<'tcx> for ExistentialPredicate {
566    type T = rustc_middle::ty::ExistentialPredicate<'tcx>;
567
568    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
569        match self {
570            ExistentialPredicate::Trait(trait_ref) => {
571                let trait_ref = rustc_middle::ty::ExistentialTraitRef::new_from_args(
572                    tcx,
573                    trait_ref.def_id,
574                    trait_ref.args.to_rustc(tcx),
575                );
576                rustc_middle::ty::ExistentialPredicate::Trait(trait_ref)
577            }
578            ExistentialPredicate::Projection(projection) => {
579                rustc_middle::ty::ExistentialPredicate::Projection(
580                    rustc_middle::ty::ExistentialProjection::new_from_args(
581                        tcx,
582                        projection.def_id,
583                        projection.args.to_rustc(tcx),
584                        projection.term.skip_binder_ref().to_rustc(tcx).into(),
585                    ),
586                )
587            }
588            ExistentialPredicate::AutoTrait(def_id) => {
589                rustc_middle::ty::ExistentialPredicate::AutoTrait(*def_id)
590            }
591        }
592    }
593}
594
595#[derive(
596    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
597)]
598pub struct ExistentialTraitRef {
599    pub def_id: DefId,
600    pub args: GenericArgs,
601}
602
603pub type PolyExistentialTraitRef = Binder<ExistentialTraitRef>;
604
605impl PolyExistentialTraitRef {
606    pub fn def_id(&self) -> DefId {
607        self.as_ref().skip_binder().def_id
608    }
609}
610
611#[derive(
612    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
613)]
614pub struct ExistentialProjection {
615    pub def_id: DefId,
616    pub args: GenericArgs,
617    pub term: SubsetTyCtor,
618}
619
620#[derive(
621    PartialEq, Eq, Hash, Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
622)]
623pub struct ProjectionPredicate {
624    pub projection_ty: AliasTy,
625    pub term: SubsetTyCtor,
626}
627
628impl Pretty for ProjectionPredicate {
629    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
630        write!(
631            f,
632            "ProjectionPredicate << projection_ty = {:?}, term = {:?} >>",
633            self.projection_ty, self.term
634        )
635    }
636}
637
638impl ProjectionPredicate {
639    pub fn self_ty(&self) -> SubsetTyCtor {
640        self.projection_ty.self_ty().clone()
641    }
642}
643
644impl<'tcx> ToRustc<'tcx> for ProjectionPredicate {
645    type T = rustc_middle::ty::ProjectionPredicate<'tcx>;
646
647    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
648        rustc_middle::ty::ProjectionPredicate {
649            projection_term: rustc_middle::ty::AliasTerm::new_from_args(
650                tcx,
651                self.projection_ty.def_id,
652                self.projection_ty.args.to_rustc(tcx),
653            ),
654            term: self.term.as_bty_skipping_binder().to_rustc(tcx).into(),
655        }
656    }
657}
658
659pub type PolyProjectionPredicate = Binder<ProjectionPredicate>;
660
661impl PolyProjectionPredicate {
662    pub fn projection_def_id(&self) -> DefId {
663        self.skip_binder_ref().projection_ty.def_id
664    }
665
666    pub fn self_ty(&self) -> Binder<SubsetTyCtor> {
667        self.clone().map(|predicate| predicate.self_ty())
668    }
669}
670
671#[derive(
672    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
673)]
674pub struct FnTraitPredicate {
675    pub self_ty: Ty,
676    pub tupled_args: Ty,
677    pub output: Ty,
678    pub kind: ClosureKind,
679}
680
681impl Pretty for FnTraitPredicate {
682    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
683        write!(
684            f,
685            "self = {:?}, args = {:?}, output = {:?}, kind = {}",
686            self.self_ty, self.tupled_args, self.output, self.kind
687        )
688    }
689}
690
691impl FnTraitPredicate {
692    pub fn fndef_sig(&self) -> FnSig {
693        let inputs = self.tupled_args.expect_tuple().iter().cloned().collect();
694        let ret = self.output.clone().shift_in_escaping(1);
695        let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
696        FnSig::new(
697            Safety::Safe,
698            rustc_abi::ExternAbi::Rust,
699            List::empty(),
700            inputs,
701            output,
702            Expr::ff(),
703            false,
704        )
705    }
706}
707
708pub fn to_closure_sig(
709    tcx: TyCtxt,
710    closure_id: LocalDefId,
711    tys: &[Ty],
712    args: &flux_rustc_bridge::ty::GenericArgs,
713    poly_sig: &PolyFnSig,
714    no_panic: bool,
715) -> PolyFnSig {
716    let closure_args = args.as_closure();
717    let kind_ty = closure_args.kind_ty().to_rustc(tcx);
718    let Some(kind) = kind_ty.to_opt_closure_kind() else {
719        bug!("to_closure_sig: expected closure kind, found {kind_ty:?}");
720    };
721
722    let mut vars = poly_sig.vars().clone().to_vec();
723    let fn_sig = poly_sig.clone().skip_binder();
724    let closure_ty = Ty::closure(closure_id.into(), tys, args, no_panic);
725    let env_ty = match kind {
726        ClosureKind::Fn => {
727            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
728            let br = BoundRegion {
729                var: BoundVar::from_usize(vars.len() - 1),
730                kind: BoundRegionKind::ClosureEnv,
731            };
732            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Not)
733        }
734        ClosureKind::FnMut => {
735            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
736            let br = BoundRegion {
737                var: BoundVar::from_usize(vars.len() - 1),
738                kind: BoundRegionKind::ClosureEnv,
739            };
740            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Mut)
741        }
742        ClosureKind::FnOnce => closure_ty,
743    };
744
745    let inputs = std::iter::once(env_ty)
746        .chain(fn_sig.inputs().iter().cloned())
747        .collect::<Vec<_>>();
748    let output = fn_sig.output().clone();
749
750    let fn_sig = crate::rty::FnSig::new(
751        fn_sig.safety,
752        fn_sig.abi,
753        fn_sig.requires.clone(),
754        inputs.into(),
755        output,
756        if no_panic { crate::rty::Expr::tt() } else { crate::rty::Expr::ff() },
757        false,
758    );
759
760    PolyFnSig::bind_with_vars(fn_sig, List::from(vars))
761}
762
763#[derive(Clone, PartialEq, Eq, Hash, Debug)]
764pub struct CoroutineObligPredicate {
765    pub def_id: DefId,
766    pub resume_ty: Ty,
767    pub upvar_tys: List<Ty>,
768    pub output: Ty,
769    pub args: flux_rustc_bridge::ty::GenericArgs,
770}
771
772#[derive(Copy, Clone, Encodable, Decodable, Hash, PartialEq, Eq)]
773pub struct AssocReft {
774    pub def_id: FluxDefId,
775    // NOTE: Field is used to denote final associated generic refinements on Traits
776    pub final_: bool,
777    pub span: Span,
778}
779
780impl AssocReft {
781    pub fn new(def_id: FluxDefId, final_: bool, span: Span) -> Self {
782        Self { def_id, final_, span }
783    }
784
785    pub fn name(&self) -> Symbol {
786        self.def_id.name()
787    }
788
789    pub fn def_id(&self) -> FluxDefId {
790        self.def_id
791    }
792}
793
794#[derive(Clone, Encodable, Decodable)]
795pub struct AssocRefinements {
796    pub items: List<AssocReft>,
797}
798
799impl Default for AssocRefinements {
800    fn default() -> Self {
801        Self { items: List::empty() }
802    }
803}
804
805impl AssocRefinements {
806    pub fn get(&self, assoc_id: FluxDefId) -> AssocReft {
807        *self
808            .items
809            .into_iter()
810            .find(|it| it.def_id == assoc_id)
811            .unwrap_or_else(|| bug!("caller should guarantee existence of associated refinement"))
812    }
813
814    pub fn find(&self, name: Symbol) -> Option<AssocReft> {
815        Some(*self.items.into_iter().find(|it| it.name() == name)?)
816    }
817}
818
819#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
820pub enum SortCtor {
821    Set,
822    Map,
823    Adt(AdtSortDef),
824    User(FluxDefId),
825}
826
newtype_index! {
    /// [`ParamSort`] is used for polymorphic sorts (`Set`, `Map`, etc.) and [bit-vector size parameters].
    /// They should occur "bound" under a [`PolyFuncSort`] or an [`AdtSortDef`]. We assume there's a
    /// single binder and a [`ParamSort`] represents a variable as an index into the list of variables
    /// bound by that binder, i.e., the representation doesn't support higher-ranked sorts.
    ///
    /// [bit-vector size parameters]: BvSize::Param
    #[debug_format = "?{}s"]
    #[encodable]
    pub struct ParamSort {}
}
838
newtype_index! {
    /// A *sort* *v*ariable *id*, used as a key for unification during sort checking.
    #[debug_format = "?{}s"]
    #[encodable]
    pub struct SortVid {}
}
845
846impl ena::unify::UnifyKey for SortVid {
847    type Value = SortVarVal;
848
849    #[inline]
850    fn index(&self) -> u32 {
851        self.as_u32()
852    }
853
854    #[inline]
855    fn from_index(u: u32) -> Self {
856        SortVid::from_u32(u)
857    }
858
859    fn tag() -> &'static str {
860        "SortVid"
861    }
862}
863
864bitflags! {
865    /// A *sort constraint* is a set of operations a sort must support.
866    ///
867    /// During sort checking, we accumulate the operations required for each sort variable. As
868    /// unification progresses, these constraints become more specific, i.e, a sort must support
869    /// more operations to satisfy them.
870    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
871    pub struct SortCstr: u16 {
872        /// An empty constraint (any sort satisfies it)
873        const BOT     = 0b0000000000;
874        /// `*`
875        const MUL     = 0b0000000001;
876        /// `/`
877        const DIV     = 0b0000000010;
878        /// `%`
879        const MOD     = 0b0000000100;
880        /// `+`
881        const ADD     = 0b0000001000;
882        /// `-`
883        const SUB     = 0b0000010000;
884        /// `|`
885        const BIT_OR  = 0b0000100000;
886        /// `&`
887        const BIT_AND = 0b0001000000;
888        /// `>>`
889        const BIT_SHL = 0b0010000000;
890        /// `<<`
891        const BIT_SHR = 0b0100000000;
892        /// `^`
893        const BIT_XOR = 0b1000000000;
894
895        /// The set of operations supported by all _numeric_ sorts.
896        const NUMERIC = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
897        /// The set of operations supported by integers.
898        const INT = Self::DIV.bits()
899            | Self::MUL.bits()
900            | Self::MOD.bits()
901            | Self::ADD.bits()
902            | Self::SUB.bits();
903        /// The set of operations supported by reals.
904        const REAL = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
905        /// The set of operations supported by bit vectors.
906        const BITVEC =  Self::DIV.bits()
907            | Self::MUL.bits()
908            | Self::MOD.bits()
909            | Self::ADD.bits()
910            | Self::SUB.bits()
911            | Self::BIT_OR.bits()
912            | Self::BIT_AND.bits()
913            | Self::BIT_SHL.bits()
914            | Self::BIT_SHR.bits()
915            | Self::BIT_XOR.bits();
916        /// The set of operations supported by sets.
917        const SET = Self::SUB.bits() | Self::BIT_OR.bits() | Self::BIT_AND.bits();
918    }
919}
920
921impl SortCstr {
922    /// Returns a constraint that only requires the specified binary operation.
923    pub fn from_bin_op(op: fhir::BinOp) -> Self {
924        match op {
925            fhir::BinOp::Add => Self::ADD,
926            fhir::BinOp::Sub => Self::SUB,
927            fhir::BinOp::Mul => Self::MUL,
928            fhir::BinOp::Div => Self::DIV,
929            fhir::BinOp::Mod => Self::MOD,
930            fhir::BinOp::BitAnd => Self::BIT_AND,
931            fhir::BinOp::BitOr => Self::BIT_OR,
932            fhir::BinOp::BitXor => Self::BIT_XOR,
933            fhir::BinOp::BitShl => Self::BIT_SHL,
934            fhir::BinOp::BitShr => Self::BIT_SHR,
935            _ => bug!("{op:?} not supported as a constraint"),
936        }
937    }
938
939    /// Returns whether a sort satisfies this constraint
940    fn satisfy(self, sort: &Sort) -> bool {
941        match sort {
942            Sort::Int => SortCstr::INT.contains(self),
943            Sort::Real => SortCstr::REAL.contains(self),
944            Sort::BitVec(_) => SortCstr::BITVEC.contains(self),
945            Sort::App(SortCtor::Set, _) => SortCstr::SET.contains(self),
946            _ => self == SortCstr::BOT,
947        }
948    }
949}
950
/// Unification value for sort variables used during sort checking.
///
/// Unification (see the `ena::unify::UnifyValue` impl for this type) behaves as follows:
/// * two unsolved values merge their constraints;
/// * two solved values must be equal;
/// * an unsolved value unifies with a solved one only if the solution satisfies the accumulated
///   constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SortVarVal {
    /// The variable is not yet solved but the solution must satisfy some constraint.
    Unsolved(SortCstr),
    /// The variable has been solved to a sort.
    Solved(Sort),
}
959
960impl Default for SortVarVal {
961    fn default() -> Self {
962        SortVarVal::Unsolved(SortCstr::BOT)
963    }
964}
965
966impl SortVarVal {
967    pub fn solved_or(&self, sort: &Sort) -> Sort {
968        match self {
969            SortVarVal::Unsolved(_) => sort.clone(),
970            SortVarVal::Solved(sort) => sort.clone(),
971        }
972    }
973
974    pub fn map_solved(&self, f: impl FnOnce(&Sort) -> Sort) -> SortVarVal {
975        match self {
976            SortVarVal::Unsolved(cstr) => SortVarVal::Unsolved(*cstr),
977            SortVarVal::Solved(sort) => SortVarVal::Solved(f(sort)),
978        }
979    }
980}
981
982impl ena::unify::UnifyValue for SortVarVal {
983    type Error = ();
984
985    fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
986        match (value1, value2) {
987            (SortVarVal::Solved(s1), SortVarVal::Solved(s2)) if s1 == s2 => {
988                Ok(SortVarVal::Solved(s1.clone()))
989            }
990            (SortVarVal::Unsolved(a), SortVarVal::Unsolved(b)) => Ok(SortVarVal::Unsolved(*a | *b)),
991            (SortVarVal::Unsolved(v), SortVarVal::Solved(sort))
992            | (SortVarVal::Solved(sort), SortVarVal::Unsolved(v))
993                if v.satisfy(sort) =>
994            {
995                Ok(SortVarVal::Solved(sort.clone()))
996            }
997            _ => Err(()),
998        }
999    }
1000}
1001
newtype_index! {
    /// A *b*it *v*ector *size* *v*ariable *id*.
    ///
    /// Identifies a bit-vector size that still has to be inferred (see [`BvSize::Infer`]). It is
    /// used as an `ena` unification key whose value is `Option<BvSize>`.
    #[debug_format = "?{}size"]
    #[encodable]
    pub struct BvSizeVid {}
}
1008
1009impl ena::unify::UnifyKey for BvSizeVid {
1010    type Value = Option<BvSize>;
1011
1012    #[inline]
1013    fn index(&self) -> u32 {
1014        self.as_u32()
1015    }
1016
1017    #[inline]
1018    fn from_index(u: u32) -> Self {
1019        BvSizeVid::from_u32(u)
1020    }
1021
1022    fn tag() -> &'static str {
1023        "BvSizeVid"
1024    }
1025}
1026
// Resolved bit-vector sizes unify only when they are equal. Through ena's blanket impls this
// also makes `Option<BvSize>` (the value type of `BvSizeVid`) a unification value.
impl ena::unify::EqUnifyValue for BvSize {}
1028
/// The sort of a refinement expression.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum Sort {
    /// Integers; they support the operations in [`SortCstr::INT`].
    Int,
    /// Booleans.
    Bool,
    /// Reals; they support the operations in [`SortCstr::REAL`].
    Real,
    /// A bit-vector of the given [`BvSize`].
    BitVec(BvSize),
    Str,
    Char,
    /// Locations (see [`Sort::is_loc`]).
    Loc,
    /// NOTE(review): presumably the sort associated with a generic type parameter; confirm.
    Param(ParamTy),
    /// A tuple of sorts. The empty tuple is the unit sort (see [`Sort::unit`]).
    Tuple(List<Sort>),
    Alias(AliasKind, AliasTy),
    /// A (possibly polymorphic) function sort.
    Func(PolyFuncSort),
    /// A sort constructor applied to argument sorts, e.g., a set or the sort of an ADT.
    App(SortCtor, List<Sort>),
    /// NOTE(review): presumably a sort parameter bound by a [`PolyFuncSort`]; confirm.
    Var(ParamSort),
    /// NOTE(review): presumably a sort inference variable used during sort checking; confirm.
    Infer(SortVid),
    /// NOTE(review): presumably an error sort used for recovery; confirm.
    Err,
}
1047
/// How a cast between two sorts is interpreted; computed by [`Sort::cast_kind`].
pub enum CastKind {
    /// Identity cast, which is erasable (e.g., int -> int, char -> int).
    Identity,
    /// From bool to int.
    BoolToInt,
    /// Casts to a unit index (e.g., int -> float).
    IntoUnit,
    /// Uninterpreted casts, only allowed with an explicit flag.
    Uninterpreted,
}
1058
1059impl Sort {
1060    pub fn tuple(sorts: impl Into<List<Sort>>) -> Self {
1061        Sort::Tuple(sorts.into())
1062    }
1063
1064    pub fn app(ctor: SortCtor, sorts: List<Sort>) -> Self {
1065        Sort::App(ctor, sorts)
1066    }
1067
1068    pub fn unit() -> Self {
1069        Self::tuple(vec![])
1070    }
1071
1072    #[track_caller]
1073    pub fn expect_func(&self) -> &PolyFuncSort {
1074        if let Sort::Func(sort) = self { sort } else { bug!("expected `Sort::Func`") }
1075    }
1076
1077    pub fn is_loc(&self) -> bool {
1078        matches!(self, Sort::Loc)
1079    }
1080
1081    pub fn is_unit(&self) -> bool {
1082        matches!(self, Sort::Tuple(sorts) if sorts.is_empty())
1083    }
1084
1085    pub fn is_unit_adt(&self) -> Option<DefId> {
1086        if let Sort::App(SortCtor::Adt(sort_def), _) = self
1087            && let Some(variant) = sort_def.opt_struct_variant()
1088            && variant.fields() == 0
1089        {
1090            Some(sort_def.did())
1091        } else {
1092            None
1093        }
1094    }
1095
1096    /// Whether the sort is a function with return sort bool
1097    pub fn is_pred(&self) -> bool {
1098        matches!(self, Sort::Func(fsort) if fsort.skip_binders().output().is_bool())
1099    }
1100
1101    /// Returns `true` if the sort is [`Bool`].
1102    ///
1103    /// [`Bool`]: Sort::Bool
1104    #[must_use]
1105    pub fn is_bool(&self) -> bool {
1106        matches!(self, Self::Bool)
1107    }
1108
1109    pub fn cast_kind(self: &Sort, to: &Sort) -> CastKind {
1110        if self == to
1111            || (matches!(self, Sort::Char | Sort::Int) && matches!(to, Sort::Char | Sort::Int))
1112        {
1113            CastKind::Identity
1114        } else if matches!(self, Sort::Bool) && matches!(to, Sort::Int) {
1115            CastKind::BoolToInt
1116        } else if to.is_unit() {
1117            CastKind::IntoUnit
1118        } else {
1119            CastKind::Uninterpreted
1120        }
1121    }
1122
1123    pub fn walk(&self, mut f: impl FnMut(&Sort, &[FieldProj])) {
1124        fn go(sort: &Sort, f: &mut impl FnMut(&Sort, &[FieldProj]), proj: &mut Vec<FieldProj>) {
1125            match sort {
1126                Sort::Tuple(flds) => {
1127                    for (i, sort) in flds.iter().enumerate() {
1128                        proj.push(FieldProj::Tuple { arity: flds.len(), field: i as u32 });
1129                        go(sort, f, proj);
1130                        proj.pop();
1131                    }
1132                }
1133                Sort::App(SortCtor::Adt(sort_def), args) if sort_def.is_struct() => {
1134                    let field_sorts = sort_def.struct_variant().field_sorts(args);
1135                    for (i, sort) in field_sorts.iter().enumerate() {
1136                        proj.push(FieldProj::Adt { def_id: sort_def.did(), field: i as u32 });
1137                        go(sort, f, proj);
1138                        proj.pop();
1139                    }
1140                }
1141                _ => {
1142                    f(sort, proj);
1143                }
1144            }
1145        }
1146        go(self, &mut f, &mut vec![]);
1147    }
1148}
1149
/// The size of a [bit-vector]
///
/// A size is either known, a parameter of a polymorphic function sort, or an inference variable
/// that is resolved during sort checking.
///
/// [bit-vector]: Sort::BitVec
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BvSize {
    /// A fixed size
    Fixed(u32),
    /// A size that has been parameterized, e.g., bound under a [`PolyFuncSort`]
    Param(ParamSort),
    /// A size that needs to be inferred. Used during sort checking to instantiate bit-vector
    /// sizes at call-sites.
    Infer(BvSizeVid),
}
1163
1164impl rustc_errors::IntoDiagArg for Sort {
1165    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1166        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1167    }
1168}
1169
1170impl rustc_errors::IntoDiagArg for FuncSort {
1171    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1172        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1173    }
1174}
1175
/// A function sort. The input sorts and the output sort are stored in a single list whose last
/// element is the output (see [`FuncSort::new`], [`FuncSort::inputs`], and [`FuncSort::output`]).
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FuncSort {
    /// All input sorts followed by the output sort; expected to be non-empty.
    pub inputs_and_output: List<Sort>,
}
1180
1181impl FuncSort {
1182    pub fn new(mut inputs: Vec<Sort>, output: Sort) -> Self {
1183        inputs.push(output);
1184        FuncSort { inputs_and_output: List::from_vec(inputs) }
1185    }
1186
1187    pub fn inputs(&self) -> &[Sort] {
1188        &self.inputs_and_output[0..self.inputs_and_output.len() - 1]
1189    }
1190
1191    pub fn output(&self) -> &Sort {
1192        &self.inputs_and_output[self.inputs_and_output.len() - 1]
1193    }
1194
1195    pub fn to_poly(&self) -> PolyFuncSort {
1196        PolyFuncSort::new(List::empty(), self.clone())
1197    }
1198}
1199
/// The kind of a parameter bound by a polymorphic function sort: either a sort or a bit-vector
/// size. See [`PolyFuncSort`]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub enum SortParamKind {
    /// A sort parameter.
    Sort,
    /// A bit-vector size parameter.
    BvSize,
}
1206
/// A polymorphic function sort parametric over [sorts] or [bit-vector sizes].
///
/// Parameterizing over bit-vector sizes is a bit of a stretch, because smtlib doesn't support full
/// parametric reasoning over them. As long as functions parameterized over a size are only used
/// monomorphically, we should be fine. Right now, we can guarantee this, because size parameters
/// are not exposed in the surface syntax and they are only used for predefined (interpreted)
/// theory functions.
///
/// [sorts]: Sort
/// [bit-vector sizes]: BvSize::Param
#[derive(Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub struct PolyFuncSort {
    /// The list of parameters including sorts and bit vector sizes
    params: List<SortParamKind>,
    /// The function sort under the binder; it may mention the parameters above.
    fsort: FuncSort,
}
1222
1223impl PolyFuncSort {
1224    pub fn new(params: List<SortParamKind>, fsort: FuncSort) -> Self {
1225        PolyFuncSort { params, fsort }
1226    }
1227
1228    pub fn skip_binders(&self) -> FuncSort {
1229        self.fsort.clone()
1230    }
1231
1232    pub fn instantiate_identity(&self) -> FuncSort {
1233        self.fsort.clone()
1234    }
1235
1236    pub fn expect_mono(&self) -> FuncSort {
1237        assert!(self.params.is_empty());
1238        self.fsort.clone()
1239    }
1240
1241    pub fn params(&self) -> impl ExactSizeIterator<Item = SortParamKind> + '_ {
1242        self.params.iter().copied()
1243    }
1244
1245    pub fn instantiate(&self, args: &[SortArg]) -> FuncSort {
1246        self.fsort.fold_with(&mut SortSubst::new(args))
1247    }
1248}
1249
/// An argument for a generic parameter in a [`Sort`] which can be either a generic sort or a
/// generic bit-vector size. Used, e.g., by [`PolyFuncSort::instantiate`].
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub enum SortArg {
    /// Argument for a sort parameter.
    Sort(Sort),
    /// Argument for a bit-vector size parameter.
    BvSize(BvSize),
}
1259
/// Describes whether a constant is left uninterpreted or has a user-specified value.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum ConstantInfo {
    /// An uninterpreted constant
    Uninterpreted,
    /// A non-integral constant whose value (an expression of the given sort) is specified by the
    /// user
    Interpreted(Expr, Sort),
}
1267
/// Interned handle to an [`AdtDefData`], cheap to clone.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDef(Interned<AdtDefData>);
1270
/// Flux-specific information attached to a Rust ADT definition.
#[derive(Debug, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDefData {
    /// Invariants that hold for the ADT's index.
    invariants: Vec<Invariant>,
    /// The sort definition of the ADT's index.
    sort_def: AdtSortDef,
    /// Whether the ADT is opaque (see [`Opaqueness`]).
    opaque: bool,
    /// The underlying Rust definition, as seen through `flux_rustc_bridge`.
    rustc: ty::AdtDef,
}
1278
/// Option-like enum to explicitly mark that we don't have information about an ADT because it was
/// annotated with `#[flux::opaque]`. Note that only structs can be marked as opaque.
#[derive(Clone, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum Opaqueness<T> {
    /// No information is available because the ADT is opaque.
    Opaque,
    /// The information is available.
    Transparent(T),
}
1286
1287impl<T> Opaqueness<T> {
1288    pub fn map<S>(self, f: impl FnOnce(T) -> S) -> Opaqueness<S> {
1289        match self {
1290            Opaqueness::Opaque => Opaqueness::Opaque,
1291            Opaqueness::Transparent(value) => Opaqueness::Transparent(f(value)),
1292        }
1293    }
1294
1295    pub fn as_ref(&self) -> Opaqueness<&T> {
1296        match self {
1297            Opaqueness::Opaque => Opaqueness::Opaque,
1298            Opaqueness::Transparent(value) => Opaqueness::Transparent(value),
1299        }
1300    }
1301
1302    pub fn as_deref(&self) -> Opaqueness<&T::Target>
1303    where
1304        T: std::ops::Deref,
1305    {
1306        match self {
1307            Opaqueness::Opaque => Opaqueness::Opaque,
1308            Opaqueness::Transparent(value) => Opaqueness::Transparent(value.deref()),
1309        }
1310    }
1311
1312    pub fn ok_or_else<E>(self, err: impl FnOnce() -> E) -> Result<T, E> {
1313        match self {
1314            Opaqueness::Transparent(v) => Ok(v),
1315            Opaqueness::Opaque => Err(err()),
1316        }
1317    }
1318
1319    #[track_caller]
1320    pub fn expect(self, msg: &str) -> T {
1321        match self {
1322            Opaqueness::Transparent(val) => val,
1323            Opaqueness::Opaque => bug!("{}", msg),
1324        }
1325    }
1326
1327    pub fn ok_or_query_err(self, struct_id: DefId) -> Result<T, QueryErr> {
1328        self.ok_or_else(|| QueryErr::OpaqueStruct { struct_id })
1329    }
1330}
1331
1332impl<T, E> Opaqueness<Result<T, E>> {
1333    pub fn transpose(self) -> Result<Opaqueness<T>, E> {
1334        match self {
1335            Opaqueness::Transparent(Ok(x)) => Ok(Opaqueness::Transparent(x)),
1336            Opaqueness::Transparent(Err(e)) => Err(e),
1337            Opaqueness::Opaque => Ok(Opaqueness::Opaque),
1338        }
1339    }
1340}
1341
/// Every signed integer type.
pub static INT_TYS: [IntTy; 6] =
    [IntTy::Isize, IntTy::I8, IntTy::I16, IntTy::I32, IntTy::I64, IntTy::I128];
/// Every unsigned integer type.
pub static UINT_TYS: [UintTy; 6] =
    [UintTy::Usize, UintTy::U8, UintTy::U16, UintTy::U32, UintTy::U64, UintTy::U128];
1346
/// An ADT invariant: a predicate over the ADT's index, which is bound by `pred`'s binder.
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable,
)]
pub struct Invariant {
    // This predicate may have sort variables, but we don't explicitly mark it like in `PolyFuncSort`.
    // See comment on `apply` for details.
    pred: Binder<Expr>,
}
1355
1356impl Invariant {
1357    pub fn new(pred: Binder<Expr>) -> Self {
1358        Self { pred }
1359    }
1360
1361    pub fn apply(&self, idx: &Expr) -> Expr {
1362        // The predicate may have sort variables but we don't explicitly instantiate them. This
1363        // works because within an expression, sort variables can only appear inside the sort
1364        // annotation for a lambda and invariants cannot have lambdas. It remains to instantiate
1365        // variables in the sort of the binder itself, but since we are removing it, we can avoid
1366        // the explicit instantiation. Ultimately, this works because the expression we generate in
1367        // fixpoint doesn't need sort annotations (sorts are re-inferred).
1368        self.pred.replace_bound_reft(idx)
1369    }
1370}
1371
/// A list of variant signatures, each under its own binder.
pub type PolyVariants = List<Binder<VariantSig>>;
/// A single variant signature under a binder.
pub type PolyVariant = Binder<VariantSig>;
1374
/// The signature of an ADT variant.
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct VariantSig {
    /// The ADT the variant belongs to.
    pub adt_def: AdtDef,
    /// The generic arguments of the ADT.
    pub args: GenericArgs,
    /// The types of the variant's fields.
    pub fields: List<Ty>,
    /// The index of the ADT value described by this variant.
    pub idx: Expr,
    /// NOTE(review): presumably preconditions for constructing the variant; confirm.
    pub requires: List<Expr>,
}
1383
/// A function signature under a binder.
pub type PolyFnSig = Binder<FnSig>;

/// A refined function signature.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FnSig {
    /// The function's safety (safe or `unsafe`).
    pub safety: Safety,
    /// The function's ABI.
    pub abi: rustc_abi::ExternAbi,
    /// Preconditions.
    pub requires: List<Expr>,
    /// The types of the inputs.
    pub inputs: List<Ty>,
    /// The output, under its own binder.
    pub output: Binder<FnOutput>,
    /// NOTE(review): presumably a predicate under which the function does not panic; confirm.
    pub no_panic: Expr,
    /// Whether this signature was auto-lifted (`true`) or comes from a spec (`false`).
    pub lifted: bool,
}
1397
/// The output of a function signature: its return type and postconditions.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct FnOutput {
    /// The return type.
    pub ret: Ty,
    /// Postconditions.
    pub ensures: List<Ensures>,
}
1405
/// A single postcondition of a function.
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum Ensures {
    /// NOTE(review): presumably "the location at `Path` has type `Ty`" on return; confirm.
    Type(Path, Ty),
    /// A predicate that holds on return.
    Pred(Expr),
}
1411
/// A qualifier: a named predicate with bound variables.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct Qualifier {
    /// Where the qualifier is defined.
    pub def_id: FluxLocalDefId,
    /// The predicate, under a binder for its parameters.
    pub body: Binder<Expr>,
    /// How the qualifier is meant to be used.
    pub kind: QualifierKind,
}

/// NOTE(review): the scope/use of a [`Qualifier`]; the precise meaning of each kind is defined by
/// its consumers (not visible here).
#[derive(Debug, TypeFoldable, TypeVisitable, Copy, Clone)]
pub enum QualifierKind {
    Global,
    Local,
    Hint,
}
1425
/// A `PrimOpProp` is a single property for a primitive operation which
/// can be conjoined to get the definition of the [`PrimRel`] for that
/// primitive operation.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimOpProp {
    /// Where the property is defined.
    pub def_id: FluxLocalDefId,
    /// The primitive operation this property describes.
    pub op: BinOp,
    /// The property itself, under a binder for its parameters.
    pub body: Binder<Expr>,
}

/// The relation defining a primitive operation, obtained by conjoining its [`PrimOpProp`]s.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimRel {
    /// The relation, under a binder for its parameters.
    pub body: Binder<Expr>,
}
1440
1441pub type TyCtor = Binder<Ty>;
1442
1443impl TyCtor {
1444    pub fn to_ty(&self) -> Ty {
1445        match &self.vars()[..] {
1446            [] => {
1447                return self.skip_binder_ref().shift_out_escaping(1);
1448            }
1449            [BoundVariableKind::Refine(sort, ..)] => {
1450                if sort.is_unit() {
1451                    return self.replace_bound_reft(&Expr::unit());
1452                }
1453                if let Some(def_id) = sort.is_unit_adt() {
1454                    return self.replace_bound_reft(&Expr::unit_struct(def_id));
1455                }
1456            }
1457            _ => {}
1458        }
1459        Ty::exists(self.clone())
1460    }
1461}
1462
/// A refinement type. It is interned so it can be cheaply cloned; see [`TyKind`] for its shapes.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct Ty(Interned<TyKind>);
1465
1466impl Ty {
1467    pub fn kind(&self) -> &TyKind {
1468        &self.0
1469    }
1470
1471    /// Dummy type used for the `Self` of a `TraitRef` created when converting a trait object, and
1472    /// which gets removed in `ExistentialTraitRef`. This type must not appear anywhere in other
1473    /// converted types and must be a valid `rustc` type (i.e., we must be able to call `to_rustc`
1474    /// on it). `TyKind::Infer(TyVid(0))` does the job, with the caveat that we must skip 0 when
1475    /// generating `TyKind::Infer` for "type holes".
1476    pub fn trait_object_dummy_self() -> Ty {
1477        Ty::infer(TyVid::from_u32(0))
1478    }
1479
1480    pub fn dynamic(preds: impl Into<List<Binder<ExistentialPredicate>>>, region: Region) -> Ty {
1481        BaseTy::Dynamic(preds.into(), region).to_ty()
1482    }
1483
1484    pub fn strg_ref(re: Region, path: Path, ty: Ty) -> Ty {
1485        TyKind::StrgRef(re, path, ty).intern()
1486    }
1487
1488    pub fn ptr(pk: impl Into<PtrKind>, path: impl Into<Path>) -> Ty {
1489        TyKind::Ptr(pk.into(), path.into()).intern()
1490    }
1491
1492    pub fn constr(p: impl Into<Expr>, ty: Ty) -> Ty {
1493        TyKind::Constr(p.into(), ty).intern()
1494    }
1495
1496    pub fn uninit() -> Ty {
1497        TyKind::Uninit.intern()
1498    }
1499
1500    pub fn indexed(bty: BaseTy, idx: impl Into<Expr>) -> Ty {
1501        TyKind::Indexed(bty, idx.into()).intern()
1502    }
1503
1504    pub fn exists(ty: Binder<Ty>) -> Ty {
1505        TyKind::Exists(ty).intern()
1506    }
1507
1508    pub fn exists_with_constr(bty: BaseTy, pred: Expr) -> Ty {
1509        let sort = bty.sort();
1510        let ty = Ty::indexed(bty, Expr::nu());
1511        Ty::exists(Binder::bind_with_sort(Ty::constr(pred, ty), sort))
1512    }
1513
1514    pub fn discr(adt_def: AdtDef, place: Place) -> Ty {
1515        TyKind::Discr(adt_def, place).intern()
1516    }
1517
1518    pub fn unit() -> Ty {
1519        Ty::tuple(vec![])
1520    }
1521
1522    pub fn bool() -> Ty {
1523        BaseTy::Bool.to_ty()
1524    }
1525
1526    pub fn int(int_ty: IntTy) -> Ty {
1527        BaseTy::Int(int_ty).to_ty()
1528    }
1529
1530    pub fn uint(uint_ty: UintTy) -> Ty {
1531        BaseTy::Uint(uint_ty).to_ty()
1532    }
1533
1534    pub fn param(param_ty: ParamTy) -> Ty {
1535        TyKind::Param(param_ty).intern()
1536    }
1537
1538    pub fn downcast(
1539        adt: AdtDef,
1540        args: GenericArgs,
1541        ty: Ty,
1542        variant: VariantIdx,
1543        fields: List<Ty>,
1544    ) -> Ty {
1545        TyKind::Downcast(adt, args, ty, variant, fields).intern()
1546    }
1547
1548    pub fn blocked(ty: Ty) -> Ty {
1549        TyKind::Blocked(ty).intern()
1550    }
1551
1552    pub fn str() -> Ty {
1553        BaseTy::Str.to_ty()
1554    }
1555
1556    pub fn char() -> Ty {
1557        BaseTy::Char.to_ty()
1558    }
1559
1560    pub fn float(float_ty: FloatTy) -> Ty {
1561        BaseTy::Float(float_ty).to_ty()
1562    }
1563
1564    pub fn mk_ref(region: Region, ty: Ty, mutbl: Mutability) -> Ty {
1565        BaseTy::Ref(region, ty, mutbl).to_ty()
1566    }
1567
1568    pub fn mk_slice(ty: Ty) -> Ty {
1569        BaseTy::Slice(ty).to_ty()
1570    }
1571
1572    pub fn mk_box(genv: GlobalEnv, deref_ty: Ty, alloc_ty: Ty) -> QueryResult<Ty> {
1573        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1574        let adt_def = genv.adt_def(def_id)?;
1575
1576        let args = List::from_arr([GenericArg::Ty(deref_ty), GenericArg::Ty(alloc_ty)]);
1577
1578        let bty = BaseTy::adt(adt_def, args);
1579        Ok(Ty::indexed(bty, Expr::unit_struct(def_id)))
1580    }
1581
1582    pub fn mk_box_with_default_alloc(genv: GlobalEnv, deref_ty: Ty) -> QueryResult<Ty> {
1583        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1584
1585        let generics = genv.generics_of(def_id)?;
1586        let alloc_ty = genv
1587            .lower_type_of(generics.own_params[1].def_id)?
1588            .skip_binder()
1589            .refine(&Refiner::default_for_item(genv, def_id)?)?;
1590
1591        Ty::mk_box(genv, deref_ty, alloc_ty)
1592    }
1593
1594    pub fn tuple(tys: impl Into<List<Ty>>) -> Ty {
1595        BaseTy::Tuple(tys.into()).to_ty()
1596    }
1597
1598    pub fn array(ty: Ty, c: Const) -> Ty {
1599        BaseTy::Array(ty, c).to_ty()
1600    }
1601
1602    pub fn closure(
1603        did: DefId,
1604        tys: impl Into<List<Ty>>,
1605        args: &flux_rustc_bridge::ty::GenericArgs,
1606        no_panic: bool,
1607    ) -> Ty {
1608        BaseTy::Closure(did, tys.into(), args.clone(), no_panic).to_ty()
1609    }
1610
1611    pub fn coroutine(
1612        did: DefId,
1613        resume_ty: Ty,
1614        upvar_tys: List<Ty>,
1615        args: flux_rustc_bridge::ty::GenericArgs,
1616    ) -> Ty {
1617        BaseTy::Coroutine(did, resume_ty, upvar_tys, args.clone()).to_ty()
1618    }
1619
1620    pub fn never() -> Ty {
1621        BaseTy::Never.to_ty()
1622    }
1623
1624    pub fn infer(vid: TyVid) -> Ty {
1625        TyKind::Infer(vid).intern()
1626    }
1627
1628    pub fn unconstr(&self) -> (Ty, Expr) {
1629        fn go(this: &Ty, preds: &mut Vec<Expr>) -> Ty {
1630            if let TyKind::Constr(pred, ty) = this.kind() {
1631                preds.push(pred.clone());
1632                go(ty, preds)
1633            } else {
1634                this.clone()
1635            }
1636        }
1637        let mut preds = vec![];
1638        (go(self, &mut preds), Expr::and_from_iter(preds))
1639    }
1640
1641    pub fn unblocked(&self) -> Ty {
1642        match self.kind() {
1643            TyKind::Blocked(ty) => ty.clone(),
1644            _ => self.clone(),
1645        }
1646    }
1647
1648    /// Whether the type is an `int` or a `uint`
1649    pub fn is_integral(&self) -> bool {
1650        self.as_bty_skipping_existentials()
1651            .map(BaseTy::is_integral)
1652            .unwrap_or_default()
1653    }
1654
1655    /// Whether the type is a `bool`
1656    pub fn is_bool(&self) -> bool {
1657        self.as_bty_skipping_existentials()
1658            .map(BaseTy::is_bool)
1659            .unwrap_or_default()
1660    }
1661
1662    /// Whether the type is a `char`
1663    pub fn is_char(&self) -> bool {
1664        self.as_bty_skipping_existentials()
1665            .map(BaseTy::is_char)
1666            .unwrap_or_default()
1667    }
1668
1669    pub fn is_uninit(&self) -> bool {
1670        matches!(self.kind(), TyKind::Uninit)
1671    }
1672
1673    pub fn is_box(&self) -> bool {
1674        self.as_bty_skipping_existentials()
1675            .map(BaseTy::is_box)
1676            .unwrap_or_default()
1677    }
1678
1679    pub fn is_struct(&self) -> bool {
1680        self.as_bty_skipping_existentials()
1681            .map(BaseTy::is_struct)
1682            .unwrap_or_default()
1683    }
1684
1685    pub fn is_array(&self) -> bool {
1686        self.as_bty_skipping_existentials()
1687            .map(BaseTy::is_array)
1688            .unwrap_or_default()
1689    }
1690
1691    pub fn is_slice(&self) -> bool {
1692        self.as_bty_skipping_existentials()
1693            .map(BaseTy::is_slice)
1694            .unwrap_or_default()
1695    }
1696
1697    pub fn as_bty_skipping_existentials(&self) -> Option<&BaseTy> {
1698        match self.kind() {
1699            TyKind::Indexed(bty, _) => Some(bty),
1700            TyKind::Exists(ty) => Some(ty.skip_binder_ref().as_bty_skipping_existentials()?),
1701            TyKind::Constr(_, ty) => ty.as_bty_skipping_existentials(),
1702            _ => None,
1703        }
1704    }
1705
1706    #[track_caller]
1707    pub fn expect_discr(&self) -> (&AdtDef, &Place) {
1708        if let TyKind::Discr(adt_def, place) = self.kind() {
1709            (adt_def, place)
1710        } else {
1711            tracked_span_bug!("expected discr")
1712        }
1713    }
1714
1715    #[track_caller]
1716    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg], &Expr) {
1717        if let TyKind::Indexed(BaseTy::Adt(adt_def, args), idx) = self.kind() {
1718            (adt_def, args, idx)
1719        } else {
1720            tracked_span_bug!("expected adt `{self:?}`")
1721        }
1722    }
1723
1724    #[track_caller]
1725    pub fn expect_tuple(&self) -> &[Ty] {
1726        if let TyKind::Indexed(BaseTy::Tuple(tys), _) = self.kind() {
1727            tys
1728        } else {
1729            tracked_span_bug!("expected tuple found `{self:?}` (kind: `{:?}`)", self.kind())
1730        }
1731    }
1732
1733    pub fn simplify_type(&self) -> Option<SimplifiedType> {
1734        self.as_bty_skipping_existentials()
1735            .and_then(BaseTy::simplify_type)
1736    }
1737}
1738
1739impl<'tcx> ToRustc<'tcx> for Ty {
1740    type T = rustc_middle::ty::Ty<'tcx>;
1741
1742    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
1743        match self.kind() {
1744            TyKind::Indexed(bty, _) => bty.to_rustc(tcx),
1745            TyKind::Exists(ty) => ty.skip_binder_ref().to_rustc(tcx),
1746            TyKind::Constr(_, ty) => ty.to_rustc(tcx),
1747            TyKind::Param(pty) => pty.to_ty(tcx),
1748            TyKind::StrgRef(re, _, ty) => {
1749                rustc_middle::ty::Ty::new_ref(
1750                    tcx,
1751                    re.to_rustc(tcx),
1752                    ty.to_rustc(tcx),
1753                    Mutability::Mut,
1754                )
1755            }
1756            TyKind::Infer(vid) => rustc_middle::ty::Ty::new_var(tcx, *vid),
1757            TyKind::Uninit
1758            | TyKind::Ptr(_, _)
1759            | TyKind::Discr(..)
1760            | TyKind::Downcast(..)
1761            | TyKind::Blocked(_) => bug!("TODO: to_rustc for `{self:?}`"),
1762        }
1763    }
1764}
1765
/// The shapes a refinement type [`Ty`] can take.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
pub enum TyKind {
    /// A base type indexed by a refinement expression.
    Indexed(BaseTy, Expr),
    /// An existential type over the variables bound by the binder.
    Exists(Binder<Ty>),
    /// A type constrained by a predicate (see [`Ty::unconstr`]).
    Constr(Expr, Ty),
    /// NOTE(review): presumably the type of an uninitialized location; confirm.
    Uninit,
    /// A strong reference to a location (lowered to `&mut` by `to_rustc`).
    StrgRef(Region, Path, Ty),
    /// A pointer of the given kind to a location.
    Ptr(PtrKind, Path),
    /// This is a bit of a hack. We use this type internally to represent the result of
    /// [`Rvalue::Discriminant`] in a way that we can recover the necessary control information
    /// when checking a [`match`]. The hack is that we assume the discriminant remains the same from
    /// the creation of this type until we use it in a [`match`].
    ///
    ///
    /// [`Rvalue::Discriminant`]: flux_rustc_bridge::mir::Rvalue::Discriminant
    /// [`match`]: flux_rustc_bridge::mir::TerminatorKind::SwitchInt
    Discr(AdtDef, Place),
    /// A type parameter.
    Param(ParamTy),
    /// A value of an ADT downcast to a specific variant, with the types of that variant's fields.
    Downcast(AdtDef, GenericArgs, Ty, VariantIdx, List<Ty>),
    /// A wrapped type that can be recovered with [`Ty::unblocked`].
    Blocked(Ty),
    /// A type that needs to be inferred by matching the signature against a rust signature.
    /// [`TyKind::Infer`] appears as an intermediate step during `conv` and should not be present
    /// in the final signature.
    Infer(TyVid),
}
1791
/// The kind of a [`TyKind::Ptr`].
#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum PtrKind {
    /// NOTE(review): presumably a mutable reference with the given region; confirm.
    Mut(Region),
    /// NOTE(review): presumably a pointer owned through a `Box`; confirm.
    Box,
}
1797
/// A base type: the Rust type constructor that a refinement index is attached to (e.g., via
/// [`TyKind::Indexed`]).
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BaseTy {
    Int(IntTy),
    Uint(UintTy),
    Bool,
    Str,
    Char,
    Slice(Ty),
    /// An ADT applied to generic arguments.
    Adt(AdtDef, GenericArgs),
    Float(FloatTy),
    RawPtr(Ty, Mutability),
    RawPtrMetadata(Ty),
    /// A reference with a region, pointee type, and mutability.
    Ref(Region, Ty, Mutability),
    FnPtr(PolyFnSig),
    FnDef(DefId, GenericArgs),
    Tuple(List<Ty>),
    Alias(AliasKind, AliasTy),
    Array(Ty, Const),
    Never,
    /// A closure: its `DefId`, upvar types, generic arguments, and `no_panic` flag (see
    /// [`Ty::closure`]).
    Closure(DefId, /* upvar_tys */ List<Ty>, flux_rustc_bridge::ty::GenericArgs, bool),
    /// A coroutine: its `DefId`, resume type, upvar types, and generic arguments.
    Coroutine(
        DefId,
        /*resume_ty: */ Ty,
        /* upvar_tys: */ List<Ty>,
        flux_rustc_bridge::ty::GenericArgs,
    ),
    /// A trait object with its existential predicates and region bound.
    Dynamic(List<Binder<ExistentialPredicate>>, Region),
    Param(ParamTy),
    Infer(TyVid),
    Foreign(DefId),
    Pat,
}
1830
1831impl BaseTy {
1832    pub fn opaque(alias_ty: AliasTy) -> BaseTy {
1833        BaseTy::Alias(AliasKind::Opaque, alias_ty)
1834    }
1835
1836    pub fn projection(alias_ty: AliasTy) -> BaseTy {
1837        BaseTy::Alias(AliasKind::Projection, alias_ty)
1838    }
1839
1840    pub fn adt(adt_def: AdtDef, args: GenericArgs) -> BaseTy {
1841        BaseTy::Adt(adt_def, args)
1842    }
1843
1844    pub fn fn_def(def_id: DefId, args: impl Into<GenericArgs>) -> BaseTy {
1845        BaseTy::FnDef(def_id, args.into())
1846    }
1847
1848    pub fn from_primitive_str(s: &str) -> Option<BaseTy> {
1849        match s {
1850            "i8" => Some(BaseTy::Int(IntTy::I8)),
1851            "i16" => Some(BaseTy::Int(IntTy::I16)),
1852            "i32" => Some(BaseTy::Int(IntTy::I32)),
1853            "i64" => Some(BaseTy::Int(IntTy::I64)),
1854            "i128" => Some(BaseTy::Int(IntTy::I128)),
1855            "u8" => Some(BaseTy::Uint(UintTy::U8)),
1856            "u16" => Some(BaseTy::Uint(UintTy::U16)),
1857            "u32" => Some(BaseTy::Uint(UintTy::U32)),
1858            "u64" => Some(BaseTy::Uint(UintTy::U64)),
1859            "u128" => Some(BaseTy::Uint(UintTy::U128)),
1860            "f16" => Some(BaseTy::Float(FloatTy::F16)),
1861            "f32" => Some(BaseTy::Float(FloatTy::F32)),
1862            "f64" => Some(BaseTy::Float(FloatTy::F64)),
1863            "f128" => Some(BaseTy::Float(FloatTy::F128)),
1864            "isize" => Some(BaseTy::Int(IntTy::Isize)),
1865            "usize" => Some(BaseTy::Uint(UintTy::Usize)),
1866            "bool" => Some(BaseTy::Bool),
1867            "char" => Some(BaseTy::Char),
1868            "str" => Some(BaseTy::Str),
1869            _ => None,
1870        }
1871    }
1872
1873    /// If `self` is a primitive, return its [`Symbol`].
1874    pub fn primitive_symbol(&self) -> Option<Symbol> {
1875        match self {
1876            BaseTy::Bool => Some(sym::bool),
1877            BaseTy::Char => Some(sym::char),
1878            BaseTy::Float(f) => {
1879                match f {
1880                    FloatTy::F16 => Some(sym::f16),
1881                    FloatTy::F32 => Some(sym::f32),
1882                    FloatTy::F64 => Some(sym::f64),
1883                    FloatTy::F128 => Some(sym::f128),
1884                }
1885            }
1886            BaseTy::Int(f) => {
1887                match f {
1888                    IntTy::Isize => Some(sym::isize),
1889                    IntTy::I8 => Some(sym::i8),
1890                    IntTy::I16 => Some(sym::i16),
1891                    IntTy::I32 => Some(sym::i32),
1892                    IntTy::I64 => Some(sym::i64),
1893                    IntTy::I128 => Some(sym::i128),
1894                }
1895            }
1896            BaseTy::Uint(f) => {
1897                match f {
1898                    UintTy::Usize => Some(sym::usize),
1899                    UintTy::U8 => Some(sym::u8),
1900                    UintTy::U16 => Some(sym::u16),
1901                    UintTy::U32 => Some(sym::u32),
1902                    UintTy::U64 => Some(sym::u64),
1903                    UintTy::U128 => Some(sym::u128),
1904                }
1905            }
1906            BaseTy::Str => Some(sym::str),
1907            _ => None,
1908        }
1909    }
1910
1911    pub fn is_integral(&self) -> bool {
1912        matches!(self, BaseTy::Int(_) | BaseTy::Uint(_))
1913    }
1914
1915    pub fn is_signed(&self) -> bool {
1916        matches!(self, BaseTy::Int(_))
1917    }
1918
1919    pub fn is_unsigned(&self) -> bool {
1920        matches!(self, BaseTy::Uint(_))
1921    }
1922
1923    pub fn is_float(&self) -> bool {
1924        matches!(self, BaseTy::Float(_))
1925    }
1926
1927    pub fn is_bool(&self) -> bool {
1928        matches!(self, BaseTy::Bool)
1929    }
1930
1931    fn is_struct(&self) -> bool {
1932        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_struct())
1933    }
1934
1935    fn is_array(&self) -> bool {
1936        matches!(self, BaseTy::Array(..))
1937    }
1938
1939    fn is_slice(&self) -> bool {
1940        matches!(self, BaseTy::Slice(..))
1941    }
1942
1943    pub fn is_box(&self) -> bool {
1944        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_box())
1945    }
1946
1947    pub fn is_char(&self) -> bool {
1948        matches!(self, BaseTy::Char)
1949    }
1950
1951    pub fn is_str(&self) -> bool {
1952        matches!(self, BaseTy::Str)
1953    }
1954
1955    pub fn invariants(
1956        &self,
1957        tcx: TyCtxt,
1958        overflow_mode: OverflowMode,
1959    ) -> impl Iterator<Item = Invariant> {
1960        let (invariants, args) = match self {
1961            BaseTy::Adt(adt_def, args) => (adt_def.invariants().skip_binder(), &args[..]),
1962            BaseTy::Uint(uint_ty) => (uint_invariants(*uint_ty, overflow_mode), &[][..]),
1963            BaseTy::Int(int_ty) => (int_invariants(*int_ty, overflow_mode), &[][..]),
1964            BaseTy::Char => (char_invariants(), &[][..]),
1965            BaseTy::Slice(_) => (slice_invariants(overflow_mode), &[][..]),
1966            _ => (&[][..], &[][..]),
1967        };
1968        invariants
1969            .iter()
1970            .map(move |inv| EarlyBinder(inv).instantiate_ref(tcx, args, &[]))
1971    }
1972
1973    pub fn to_ty(&self) -> Ty {
1974        let sort = self.sort();
1975        if sort.is_unit() {
1976            Ty::indexed(self.clone(), Expr::unit())
1977        } else {
1978            Ty::exists(Binder::bind_with_sort(
1979                Ty::indexed(self.shift_in_escaping(1), Expr::nu()),
1980                sort,
1981            ))
1982        }
1983    }
1984
1985    pub fn to_subset_ty_ctor(&self) -> SubsetTyCtor {
1986        let sort = self.sort();
1987        Binder::bind_with_sort(SubsetTy::trivial(self.clone(), Expr::nu()), sort)
1988    }
1989
1990    #[track_caller]
1991    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg]) {
1992        if let BaseTy::Adt(adt_def, args) = self {
1993            (adt_def, args)
1994        } else {
1995            tracked_span_bug!("expected adt `{self:?}`")
1996        }
1997    }
1998
1999    /// A type is an *atom* if it is "self-delimiting", i.e., it has a clear boundary
2000    /// when printed. This is used to avoid unnecessary parenthesis when pretty printing.
2001    pub fn is_atom(&self) -> bool {
2002        // (nilehmann) I'm not sure about this list, please adjust if you get any odd behavior
2003        matches!(
2004            self,
2005            BaseTy::Int(_)
2006                | BaseTy::Uint(_)
2007                | BaseTy::Slice(_)
2008                | BaseTy::Bool
2009                | BaseTy::Char
2010                | BaseTy::Str
2011                | BaseTy::Adt(..)
2012                | BaseTy::Tuple(..)
2013                | BaseTy::Param(_)
2014                | BaseTy::Array(..)
2015                | BaseTy::Never
2016                | BaseTy::Closure(..)
2017                | BaseTy::Coroutine(..)
2018                // opaque alias are atoms the way we print them now, but they won't
2019                // be if we print them as `impl Trait`
2020                | BaseTy::Alias(..)
2021        )
2022    }
2023
2024    /// Similar to [`rustc_infer::infer::canonical::ir::fast_reject::simplify_type`].
2025    ///
2026    /// This implementation is currently incomplete, so it should only be used in contexts
2027    /// where completeness is not required. Currently, it's used to find incoherent
2028    /// implementations when resolving associated constants. In this context, incompleteness
2029    /// is acceptable since the worst case outcome is simply failing to resolve a type-relative
2030    /// constant.
2031    fn simplify_type(&self) -> Option<SimplifiedType> {
2032        match self {
2033            BaseTy::Bool => Some(SimplifiedType::Bool),
2034            BaseTy::Char => Some(SimplifiedType::Char),
2035            BaseTy::Int(int_type) => Some(SimplifiedType::Int(*int_type)),
2036            BaseTy::Uint(uint_type) => Some(SimplifiedType::Uint(*uint_type)),
2037            BaseTy::Float(float_type) => Some(SimplifiedType::Float(*float_type)),
2038            BaseTy::Adt(def, _) => Some(SimplifiedType::Adt(def.did())),
2039            BaseTy::Str => Some(SimplifiedType::Str),
2040            BaseTy::Array(..) => Some(SimplifiedType::Array),
2041            BaseTy::Slice(..) => Some(SimplifiedType::Slice),
2042            BaseTy::RawPtr(_, mutbl) => Some(SimplifiedType::Ptr(*mutbl)),
2043            BaseTy::Ref(_, _, mutbl) => Some(SimplifiedType::Ref(*mutbl)),
2044            BaseTy::FnDef(def_id, _) | BaseTy::Closure(def_id, ..) => {
2045                Some(SimplifiedType::Closure(*def_id))
2046            }
2047            BaseTy::Coroutine(def_id, ..) => Some(SimplifiedType::Coroutine(*def_id)),
2048            BaseTy::Never => Some(SimplifiedType::Never),
2049            BaseTy::Tuple(tys) => Some(SimplifiedType::Tuple(tys.len())),
2050            BaseTy::FnPtr(poly_fn_sig) => {
2051                Some(SimplifiedType::Function(poly_fn_sig.skip_binder_ref().inputs().len()))
2052            }
2053            BaseTy::Foreign(def_id) => Some(SimplifiedType::Foreign(*def_id)),
2054            BaseTy::RawPtrMetadata(_)
2055            | BaseTy::Alias(..)
2056            | BaseTy::Param(_)
2057            | BaseTy::Dynamic(..)
2058            | BaseTy::Infer(_) => None,
2059            BaseTy::Pat => todo!(),
2060        }
2061    }
2062}
2063
2064impl<'tcx> ToRustc<'tcx> for BaseTy {
2065    type T = rustc_middle::ty::Ty<'tcx>;
2066
2067    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2068        use rustc_middle::ty;
2069        match self {
2070            BaseTy::Int(i) => ty::Ty::new_int(tcx, *i),
2071            BaseTy::Uint(i) => ty::Ty::new_uint(tcx, *i),
2072            BaseTy::Param(pty) => pty.to_ty(tcx),
2073            BaseTy::Slice(ty) => ty::Ty::new_slice(tcx, ty.to_rustc(tcx)),
2074            BaseTy::Bool => tcx.types.bool,
2075            BaseTy::Char => tcx.types.char,
2076            BaseTy::Str => tcx.types.str_,
2077            BaseTy::Adt(adt_def, args) => {
2078                let did = adt_def.did();
2079                let adt_def = tcx.adt_def(did);
2080                let args = args.to_rustc(tcx);
2081                ty::Ty::new_adt(tcx, adt_def, args)
2082            }
2083            BaseTy::FnDef(def_id, args) => {
2084                let args = args.to_rustc(tcx);
2085                ty::Ty::new_fn_def(tcx, *def_id, args)
2086            }
2087            BaseTy::Float(f) => ty::Ty::new_float(tcx, *f),
2088            BaseTy::RawPtr(ty, mutbl) => ty::Ty::new_ptr(tcx, ty.to_rustc(tcx), *mutbl),
2089            BaseTy::Ref(re, ty, mutbl) => {
2090                ty::Ty::new_ref(tcx, re.to_rustc(tcx), ty.to_rustc(tcx), *mutbl)
2091            }
2092            BaseTy::FnPtr(poly_sig) => ty::Ty::new_fn_ptr(tcx, poly_sig.to_rustc(tcx)),
2093            BaseTy::Tuple(tys) => {
2094                let ts = tys.iter().map(|ty| ty.to_rustc(tcx)).collect_vec();
2095                ty::Ty::new_tup(tcx, &ts)
2096            }
2097            BaseTy::Alias(kind, alias_ty) => {
2098                ty::Ty::new_alias(tcx, kind.to_rustc(tcx), alias_ty.to_rustc(tcx))
2099            }
2100            BaseTy::Array(ty, n) => {
2101                let ty = ty.to_rustc(tcx);
2102                let n = n.to_rustc(tcx);
2103                ty::Ty::new_array_with_const_len(tcx, ty, n)
2104            }
2105            BaseTy::Never => tcx.types.never,
2106            BaseTy::Closure(did, _, args, _) => ty::Ty::new_closure(tcx, *did, args.to_rustc(tcx)),
2107            BaseTy::Dynamic(exi_preds, re) => {
2108                let preds: Vec<_> = exi_preds
2109                    .iter()
2110                    .map(|pred| pred.to_rustc(tcx))
2111                    .collect_vec();
2112                let preds = tcx.mk_poly_existential_predicates(&preds);
2113                ty::Ty::new_dynamic(tcx, preds, re.to_rustc(tcx))
2114            }
2115            BaseTy::Coroutine(did, _, _, args) => {
2116                ty::Ty::new_coroutine(tcx, *did, args.to_rustc(tcx))
2117            }
2118            BaseTy::Infer(ty_vid) => ty::Ty::new_var(tcx, *ty_vid),
2119            BaseTy::Foreign(def_id) => ty::Ty::new_foreign(tcx, *def_id),
2120            BaseTy::RawPtrMetadata(ty) => {
2121                ty::Ty::new_ptr(
2122                    tcx,
2123                    ty.to_rustc(tcx),
2124                    RawPtrKind::FakeForPtrMetadata.to_mutbl_lossy(),
2125                )
2126            }
2127            BaseTy::Pat => todo!(),
2128        }
2129    }
2130}
2131
/// The payload of an alias type: either an associated-type projection or an opaque type (see
/// [`BaseTy::projection`] and [`BaseTy::opaque`]).
#[derive(
    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct AliasTy {
    /// The [`DefId`] of the aliased item.
    pub def_id: DefId,
    /// The generic arguments of the alias. For projections the first argument is the `Self`
    /// type.
    pub args: GenericArgs,
    /// Holds the refinement-arguments for opaque-types; empty for projections
    pub refine_args: RefineArgs,
}
2141
2142impl AliasTy {
2143    pub fn new(def_id: DefId, args: GenericArgs, refine_args: RefineArgs) -> Self {
2144        AliasTy { args, refine_args, def_id }
2145    }
2146}
2147
2148/// This methods work only with associated type projections (i.e., no opaque types)
2149impl AliasTy {
2150    pub fn self_ty(&self) -> SubsetTyCtor {
2151        self.args[0].expect_base().clone()
2152    }
2153
2154    pub fn with_self_ty(&self, self_ty: SubsetTyCtor) -> Self {
2155        Self {
2156            def_id: self.def_id,
2157            args: [GenericArg::Base(self_ty)]
2158                .into_iter()
2159                .chain(self.args.iter().skip(1).cloned())
2160                .collect(),
2161            refine_args: self.refine_args.clone(),
2162        }
2163    }
2164}
2165
2166impl<'tcx> ToRustc<'tcx> for AliasTy {
2167    type T = rustc_middle::ty::AliasTy<'tcx>;
2168
2169    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2170        rustc_middle::ty::AliasTy::new(tcx, self.def_id, self.args.to_rustc(tcx))
2171    }
2172}
2173
2174pub type RefineArgs = List<Expr>;
2175
2176#[extension(pub trait RefineArgsExt)]
2177impl RefineArgs {
2178    fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<RefineArgs> {
2179        Self::for_item(genv, def_id, |param, index| {
2180            Ok(Expr::var(Var::EarlyParam(EarlyReftParam {
2181                index: index as u32,
2182                name: param.name(),
2183            })))
2184        })
2185    }
2186
2187    fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult<RefineArgs>
2188    where
2189        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<Expr>,
2190    {
2191        let reft_generics = genv.refinement_generics_of(def_id)?;
2192        let count = reft_generics.count();
2193        let mut args = Vec::with_capacity(count);
2194        reft_generics.fill_item(genv, &mut args, &mut mk)?;
2195        Ok(List::from_vec(args))
2196    }
2197}
2198
2199/// A type constructor meant to be used as generic a argument of [kind base]. This is just an alias
2200/// to [`Binder<SubsetTy>`], but we expect the binder to have a single bound variable of the sort of
2201/// the underlying [base type].
2202///
2203/// [kind base]: GenericParamDefKind::Base
2204/// [base type]: SubsetTy::bty
2205pub type SubsetTyCtor = Binder<SubsetTy>;
2206
2207impl SubsetTyCtor {
2208    pub fn as_bty_skipping_binder(&self) -> &BaseTy {
2209        &self.as_ref().skip_binder().bty
2210    }
2211
2212    pub fn to_ty(&self) -> Ty {
2213        let sort = self.sort();
2214        if sort.is_unit() {
2215            self.replace_bound_reft(&Expr::unit()).to_ty()
2216        } else if let Some(def_id) = sort.is_unit_adt() {
2217            self.replace_bound_reft(&Expr::unit_struct(def_id)).to_ty()
2218        } else {
2219            Ty::exists(self.as_ref().map(SubsetTy::to_ty))
2220        }
2221    }
2222
2223    pub fn to_ty_ctor(&self) -> TyCtor {
2224        self.as_ref().map(SubsetTy::to_ty)
2225    }
2226}
2227
/// A subset type is a simplified version of a type that has the form `{b[e] | p}` where `b` is a
/// [`BaseTy`], `e` a refinement index, and `p` a predicate.
///
/// These are mainly found under a [`Binder`] with a single variable of the base type's sort. This
/// can be interpreted as a type constructor or an existential type. For example, under a binder with a
/// variable `v` of sort `int`, we can interpret `{i32[v] | v > 0}` as:
/// - A lambda `λv:int. {i32[v] | v > 0}` that "constructs" types when applied to ints, or
/// - An existential type `∃v:int. {i32[v] | v > 0}`.
///
/// This second interpretation is the reason we call this a subset type, i.e., the type `∃v. {b[v] | p}`
/// corresponds to the subset of values of type `b` whose index satisfies `p`. These are the types
/// written as `B{v: p}` in the surface syntax and correspond to the types supported in other
/// refinement type systems like Liquid Haskell (with the difference that we are explicit
/// about separating refinements from program values via an index).
///
/// The main purpose of subset types is to be used as generic arguments of [kind base] when
/// interpreted as type constructors. They have two key properties that make them suitable
/// for this:
///
/// 1. **Syntactic Restriction**: Subset types are syntactically restricted, making it easier to
///    relate them structurally (e.g., for subtyping). For instance, given two types `S<λv. T1>` and
///    `S<λv. T2>`, if `T1` and `T2` are subset types, we know they match structurally (at least
///    shallowly). In particular, the syntactic restriction rules out complex types like
///    `S<λv. (i32[v], i32[v])>`, which simplifies some operations.
///
/// 2. **Eager Canonicalization**: Subset types can be eagerly canonicalized via [*strengthening*]
///    during substitution. For example, suppose we have a function:
///    ```text
///    fn foo<T>(x: T[@a], y: { T[@b] | b == a }) { }
///    ```
///    If we instantiate `T` with `λv. { i32[v] | v > 0}`, after substitution and applying the
///    lambda (the indexing syntax `T[a]` corresponds to an application of the lambda), we get:
///    ```text
///    fn foo(x: {i32[@a] | a > 0}, y: { { i32[@b] | b > 0 } | b == a }) { }
///    ```
///    Via *strengthening* we can canonicalize this to
///    ```text
///    fn foo(x: {i32[@a] | a > 0}, y: { i32[@b] | b == a && b > 0 }) { }
///    ```
///    As a result, we can guarantee the syntactic restriction through substitution.
///
/// [kind base]: GenericParamDefKind::Base
/// [*strengthening*]: https://arxiv.org/pdf/2010.07763.pdf
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub struct SubsetTy {
    /// The base type `b` in the subset type `{b[e] | p}`.
    ///
    /// **NOTE:** This is mostly going to be under a [`Binder`]. It is not yet clear to me whether
    /// this [`BaseTy`] should be able to mention variables in the binder. In general, in a type
    /// `∃v. {b[e] | p}`, it's fine to mention `v` inside `b`, but since [`SubsetTy`] is meant to
    /// facilitate syntactic manipulation we may want to restrict this.
    pub bty: BaseTy,
    /// The refinement index `e` in the subset type `{b[e] | p}`. This can be an arbitrary expression,
    /// which makes manipulation easier. However, since this is mostly found under a binder, we expect
    /// it to be [`Expr::nu()`].
    pub idx: Expr,
    /// The predicate `p` in the subset type `{b[e] | p}`.
    pub pred: Expr,
}
2287
2288impl SubsetTy {
2289    pub fn new(bty: BaseTy, idx: impl Into<Expr>, pred: impl Into<Expr>) -> Self {
2290        Self { bty, idx: idx.into(), pred: pred.into() }
2291    }
2292
2293    pub fn trivial(bty: BaseTy, idx: impl Into<Expr>) -> Self {
2294        Self::new(bty, idx, Expr::tt())
2295    }
2296
2297    pub fn strengthen(&self, pred: impl Into<Expr>) -> Self {
2298        let this = self.clone();
2299        let pred = Expr::and(this.pred, pred).simplify(&SnapshotMap::default());
2300        Self { bty: this.bty, idx: this.idx, pred }
2301    }
2302
2303    pub fn to_ty(&self) -> Ty {
2304        let bty = self.bty.clone();
2305        if self.pred.is_trivially_true() {
2306            Ty::indexed(bty, &self.idx)
2307        } else {
2308            Ty::constr(&self.pred, Ty::indexed(bty, &self.idx))
2309        }
2310    }
2311}
2312
2313impl<'tcx> ToRustc<'tcx> for SubsetTy {
2314    type T = rustc_middle::ty::Ty<'tcx>;
2315
2316    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::Ty<'tcx> {
2317        self.bty.to_rustc(tcx)
2318    }
2319}
2320
/// A refined generic argument. There is one variant per kind of generic parameter; see
/// [`GenericArg::from_param_def`] for how parameter kinds map to variants.
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub enum GenericArg {
    /// Argument for a type parameter ([`GenericParamDefKind::Type`]).
    Ty(Ty),
    /// Argument for a parameter of kind base ([`GenericParamDefKind::Base`]), given as a
    /// subset-type constructor.
    Base(SubsetTyCtor),
    /// Argument for a lifetime parameter.
    Lifetime(Region),
    /// Argument for a const parameter.
    Const(Const),
}
2328
2329impl GenericArg {
2330    #[track_caller]
2331    pub fn expect_type(&self) -> &Ty {
2332        if let GenericArg::Ty(ty) = self {
2333            ty
2334        } else {
2335            bug!("expected `rty::GenericArg::Ty`, found `{self:?}`")
2336        }
2337    }
2338
2339    #[track_caller]
2340    pub fn expect_base(&self) -> &SubsetTyCtor {
2341        if let GenericArg::Base(ctor) = self {
2342            ctor
2343        } else {
2344            bug!("expected `rty::GenericArg::Base`, found `{self:?}`")
2345        }
2346    }
2347
2348    pub fn from_param_def(param: &GenericParamDef) -> Self {
2349        match param.kind {
2350            GenericParamDefKind::Type { .. } => {
2351                let param_ty = ParamTy { index: param.index, name: param.name };
2352                GenericArg::Ty(Ty::param(param_ty))
2353            }
2354            GenericParamDefKind::Base { .. } => {
2355                // λv. T[v]
2356                let param_ty = ParamTy { index: param.index, name: param.name };
2357                GenericArg::Base(Binder::bind_with_sort(
2358                    SubsetTy::trivial(BaseTy::Param(param_ty), Expr::nu()),
2359                    Sort::Param(param_ty),
2360                ))
2361            }
2362            GenericParamDefKind::Lifetime => {
2363                let region = EarlyParamRegion { index: param.index, name: param.name };
2364                GenericArg::Lifetime(Region::ReEarlyParam(region))
2365            }
2366            GenericParamDefKind::Const { .. } => {
2367                let param_const = ParamConst { index: param.index, name: param.name };
2368                let kind = ConstKind::Param(param_const);
2369                GenericArg::Const(Const { kind })
2370            }
2371        }
2372    }
2373
2374    /// Creates a `GenericArgs` from the definition of generic parameters, by calling a closure to
2375    /// obtain arg. The closures get to observe the `GenericArgs` as they're being built, which can
2376    /// be used to correctly replace defaults of generic parameters.
2377    pub fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk_kind: F) -> QueryResult<GenericArgs>
2378    where
2379        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2380    {
2381        let defs = genv.generics_of(def_id)?;
2382        let count = defs.count();
2383        let mut args = Vec::with_capacity(count);
2384        Self::fill_item(genv, &mut args, &defs, &mut mk_kind)?;
2385        Ok(List::from_vec(args))
2386    }
2387
2388    pub fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<GenericArgs> {
2389        Self::for_item(genv, def_id, |param, _| GenericArg::from_param_def(param))
2390    }
2391
2392    fn fill_item<F>(
2393        genv: GlobalEnv,
2394        args: &mut Vec<GenericArg>,
2395        generics: &Generics,
2396        mk_kind: &mut F,
2397    ) -> QueryResult<()>
2398    where
2399        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2400    {
2401        if let Some(def_id) = generics.parent {
2402            let parent_generics = genv.generics_of(def_id)?;
2403            Self::fill_item(genv, args, &parent_generics, mk_kind)?;
2404        }
2405        for param in &generics.own_params {
2406            let kind = mk_kind(param, args);
2407            tracked_span_assert_eq!(param.index as usize, args.len());
2408            args.push(kind);
2409        }
2410        Ok(())
2411    }
2412}
2413
2414impl From<TyOrBase> for GenericArg {
2415    fn from(v: TyOrBase) -> Self {
2416        match v {
2417            TyOrBase::Ty(ty) => GenericArg::Ty(ty),
2418            TyOrBase::Base(ctor) => GenericArg::Base(ctor),
2419        }
2420    }
2421}
2422
2423impl<'tcx> ToRustc<'tcx> for GenericArg {
2424    type T = rustc_middle::ty::GenericArg<'tcx>;
2425
2426    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2427        use rustc_middle::ty;
2428        match self {
2429            GenericArg::Ty(ty) => ty::GenericArg::from(ty.to_rustc(tcx)),
2430            GenericArg::Base(ctor) => ty::GenericArg::from(ctor.skip_binder_ref().to_rustc(tcx)),
2431            GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)),
2432            GenericArg::Const(c) => ty::GenericArg::from(c.to_rustc(tcx)),
2433        }
2434    }
2435}
2436
2437pub type GenericArgs = List<GenericArg>;
2438
2439#[extension(pub trait GenericArgsExt)]
2440impl GenericArgs {
2441    #[track_caller]
2442    fn box_args(&self) -> (&Ty, &Ty) {
2443        if let [GenericArg::Ty(deref), GenericArg::Ty(alloc)] = &self[..] {
2444            (deref, alloc)
2445        } else {
2446            bug!("invalid generic arguments for box");
2447        }
2448    }
2449
2450    // We can't implement [`ToRustc`] because of coherence so we add it here
2451    fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::GenericArgsRef<'tcx> {
2452        tcx.mk_args_from_iter(self.iter().map(|arg| arg.to_rustc(tcx)))
2453    }
2454
2455    fn rebase_onto(
2456        &self,
2457        tcx: &TyCtxt,
2458        source_ancestor: DefId,
2459        target_args: &GenericArgs,
2460    ) -> List<GenericArg> {
2461        let defs = tcx.generics_of(source_ancestor);
2462        target_args
2463            .iter()
2464            .chain(self.iter().skip(defs.count()))
2465            .cloned()
2466            .collect()
2467    }
2468}
2469
/// Either a type or a subset-type constructor, i.e., something usable as an argument for a type
/// parameter or for a parameter of kind base.
#[derive(Debug)]
pub enum TyOrBase {
    /// A plain type.
    Ty(Ty),
    /// A subset-type constructor (see [`SubsetTyCtor`]).
    Base(SubsetTyCtor),
}
2475
2476impl TyOrBase {
2477    pub fn into_ty(self) -> Ty {
2478        match self {
2479            TyOrBase::Ty(ty) => ty,
2480            TyOrBase::Base(ctor) => ctor.to_ty(),
2481        }
2482    }
2483
2484    #[track_caller]
2485    pub fn expect_base(self) -> SubsetTyCtor {
2486        match self {
2487            TyOrBase::Base(ctor) => ctor,
2488            TyOrBase::Ty(_) => tracked_span_bug!("expected `TyOrBase::Base`"),
2489        }
2490    }
2491
2492    pub fn as_base(self) -> Option<SubsetTyCtor> {
2493        match self {
2494            TyOrBase::Base(ctor) => Some(ctor),
2495            TyOrBase::Ty(_) => None,
2496        }
2497    }
2498}
2499
/// Either a type or a general type constructor ([`TyCtor`]).
#[derive(Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum TyOrCtor {
    /// A plain type.
    Ty(Ty),
    /// A type constructor.
    Ctor(TyCtor),
}
2505
2506impl TyOrCtor {
2507    #[track_caller]
2508    pub fn expect_ctor(self) -> TyCtor {
2509        match self {
2510            TyOrCtor::Ctor(ctor) => ctor,
2511            TyOrCtor::Ty(_) => tracked_span_bug!("expected `TyOrCtor::Ctor`"),
2512        }
2513    }
2514
2515    pub fn expect_subset_ty_ctor(self) -> SubsetTyCtor {
2516        self.expect_ctor().map(|ty| {
2517            if let canonicalize::CanonicalTy::Constr(constr_ty) = ty.shallow_canonicalize()
2518                && let TyKind::Indexed(bty, idx) = constr_ty.ty().kind()
2519                && idx.is_nu()
2520            {
2521                SubsetTy::new(bty.clone(), Expr::nu(), constr_ty.pred())
2522            } else {
2523                tracked_span_bug!()
2524            }
2525        })
2526    }
2527
2528    pub fn to_ty(&self) -> Ty {
2529        match self {
2530            TyOrCtor::Ctor(ctor) => ctor.to_ty(),
2531            TyOrCtor::Ty(ty) => ty.clone(),
2532        }
2533    }
2534}
2535
2536impl From<TyOrBase> for TyOrCtor {
2537    fn from(v: TyOrBase) -> Self {
2538        match v {
2539            TyOrBase::Ty(ty) => TyOrCtor::Ty(ty),
2540            TyOrBase::Base(ctor) => TyOrCtor::Ctor(ctor.to_ty_ctor()),
2541        }
2542    }
2543}
2544
2545impl CoroutineObligPredicate {
2546    pub fn to_poly_fn_sig(&self) -> PolyFnSig {
2547        let vars = vec![];
2548
2549        let resume_ty = &self.resume_ty;
2550        let env_ty = Ty::coroutine(
2551            self.def_id,
2552            resume_ty.clone(),
2553            self.upvar_tys.clone(),
2554            self.args.clone(),
2555        );
2556
2557        let inputs = List::from_arr([env_ty, resume_ty.clone()]);
2558        let output =
2559            Binder::bind_with_vars(FnOutput::new(self.output.clone(), vec![]), List::empty());
2560
2561        PolyFnSig::bind_with_vars(
2562            FnSig::new(
2563                Safety::Safe,
2564                rustc_abi::ExternAbi::RustCall,
2565                List::empty(),
2566                inputs,
2567                output,
2568                Expr::ff(),
2569                false,
2570            ),
2571            List::from(vars),
2572        )
2573    }
2574}
2575
2576impl RefinementGenerics {
2577    pub fn count(&self) -> usize {
2578        self.parent_count + self.own_params.len()
2579    }
2580
2581    pub fn own_count(&self) -> usize {
2582        self.own_params.len()
2583    }
2584}
2585
2586impl EarlyBinder<RefinementGenerics> {
2587    pub fn parent(&self) -> Option<DefId> {
2588        self.skip_binder_ref().parent
2589    }
2590
2591    pub fn parent_count(&self) -> usize {
2592        self.skip_binder_ref().parent_count
2593    }
2594
2595    pub fn count(&self) -> usize {
2596        self.skip_binder_ref().count()
2597    }
2598
2599    pub fn own_count(&self) -> usize {
2600        self.skip_binder_ref().own_count()
2601    }
2602
2603    pub fn own_param_at(&self, index: usize) -> EarlyBinder<RefineParam> {
2604        self.as_ref().map(|this| this.own_params[index].clone())
2605    }
2606
2607    pub fn param_at(
2608        &self,
2609        param_index: usize,
2610        genv: GlobalEnv,
2611    ) -> QueryResult<EarlyBinder<RefineParam>> {
2612        if let Some(index) = param_index.checked_sub(self.parent_count()) {
2613            Ok(self.own_param_at(index))
2614        } else {
2615            let parent = self.parent().expect("parent_count > 0 but no parent?");
2616            genv.refinement_generics_of(parent)?
2617                .param_at(param_index, genv)
2618        }
2619    }
2620
2621    pub fn iter_own_params(&self) -> impl Iterator<Item = EarlyBinder<RefineParam>> + use<'_> {
2622        self.skip_binder_ref()
2623            .own_params
2624            .iter()
2625            .cloned()
2626            .map(EarlyBinder)
2627    }
2628
2629    pub fn fill_item<F, R>(&self, genv: GlobalEnv, vec: &mut Vec<R>, mk: &mut F) -> QueryResult
2630    where
2631        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<R>,
2632    {
2633        if let Some(def_id) = self.parent() {
2634            genv.refinement_generics_of(def_id)?
2635                .fill_item(genv, vec, mk)?;
2636        }
2637        for param in self.iter_own_params() {
2638            vec.push(mk(param, vec.len())?);
2639        }
2640        Ok(())
2641    }
2642}
2643
2644impl EarlyBinder<GenericPredicates> {
2645    pub fn predicates(&self) -> EarlyBinder<List<Clause>> {
2646        EarlyBinder(self.0.predicates.clone())
2647    }
2648}
2649
2650impl EarlyBinder<FuncSort> {
2651    /// See [`subst::GenericsSubstForSort`]
2652    pub fn instantiate_func_sort<E>(
2653        self,
2654        sort_for_param: impl FnMut(ParamTy) -> Result<Sort, E>,
2655    ) -> Result<FuncSort, E> {
2656        self.0.try_fold_with(&mut subst::GenericsSubstFolder::new(
2657            subst::GenericsSubstForSort { sort_for_param },
2658            &[],
2659        ))
2660    }
2661}
2662
2663impl VariantSig {
2664    pub fn new(
2665        adt_def: AdtDef,
2666        args: GenericArgs,
2667        fields: List<Ty>,
2668        idx: Expr,
2669        requires: List<Expr>,
2670    ) -> Self {
2671        VariantSig { adt_def, args, fields, idx, requires }
2672    }
2673
2674    pub fn fields(&self) -> &[Ty] {
2675        &self.fields
2676    }
2677
2678    pub fn ret(&self) -> Ty {
2679        let bty = BaseTy::Adt(self.adt_def.clone(), self.args.clone());
2680        let idx = self.idx.clone();
2681        Ty::indexed(bty, idx)
2682    }
2683}
2684
2685impl FnSig {
2686    pub fn new(
2687        safety: Safety,
2688        abi: rustc_abi::ExternAbi,
2689        requires: List<Expr>,
2690        inputs: List<Ty>,
2691        output: Binder<FnOutput>,
2692        no_panic: Expr,
2693        lifted: bool,
2694    ) -> Self {
2695        FnSig { safety, abi, requires, inputs, output, no_panic, lifted }
2696    }
2697
2698    pub fn requires(&self) -> &[Expr] {
2699        &self.requires
2700    }
2701
2702    pub fn inputs(&self) -> &[Ty] {
2703        &self.inputs
2704    }
2705
2706    pub fn no_panic(&self) -> Expr {
2707        self.no_panic.clone()
2708    }
2709
2710    pub fn output(&self) -> Binder<FnOutput> {
2711        self.output.clone()
2712    }
2713}
2714
2715impl<'tcx> ToRustc<'tcx> for FnSig {
2716    type T = rustc_middle::ty::FnSig<'tcx>;
2717
2718    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2719        tcx.mk_fn_sig(
2720            self.inputs().iter().map(|ty| ty.to_rustc(tcx)),
2721            self.output().as_ref().skip_binder().to_rustc(tcx),
2722            false,
2723            self.safety,
2724            self.abi,
2725        )
2726    }
2727}
2728
2729impl FnOutput {
2730    pub fn new(ret: Ty, ensures: impl Into<List<Ensures>>) -> Self {
2731        Self { ret, ensures: ensures.into() }
2732    }
2733}
2734
2735impl<'tcx> ToRustc<'tcx> for FnOutput {
2736    type T = rustc_middle::ty::Ty<'tcx>;
2737
2738    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2739        self.ret.to_rustc(tcx)
2740    }
2741}
2742
2743impl AdtDef {
2744    pub fn new(
2745        rustc: ty::AdtDef,
2746        sort_def: AdtSortDef,
2747        invariants: Vec<Invariant>,
2748        opaque: bool,
2749    ) -> Self {
2750        AdtDef(Interned::new(AdtDefData { invariants, sort_def, opaque, rustc }))
2751    }
2752
2753    pub fn did(&self) -> DefId {
2754        self.0.rustc.did()
2755    }
2756
2757    pub fn sort_def(&self) -> &AdtSortDef {
2758        &self.0.sort_def
2759    }
2760
2761    pub fn sort(&self, args: &[GenericArg]) -> Sort {
2762        self.sort_def().to_sort(args)
2763    }
2764
2765    pub fn is_box(&self) -> bool {
2766        self.0.rustc.is_box()
2767    }
2768
2769    pub fn is_enum(&self) -> bool {
2770        self.0.rustc.is_enum()
2771    }
2772
2773    pub fn is_struct(&self) -> bool {
2774        self.0.rustc.is_struct()
2775    }
2776
2777    pub fn is_union(&self) -> bool {
2778        self.0.rustc.is_union()
2779    }
2780
2781    pub fn variants(&self) -> &IndexSlice<VariantIdx, VariantDef> {
2782        self.0.rustc.variants()
2783    }
2784
2785    pub fn variant(&self, idx: VariantIdx) -> &VariantDef {
2786        self.0.rustc.variant(idx)
2787    }
2788
2789    pub fn invariants(&self) -> EarlyBinder<&[Invariant]> {
2790        EarlyBinder(&self.0.invariants)
2791    }
2792
2793    pub fn discriminants(&self) -> impl Iterator<Item = (VariantIdx, u128)> + '_ {
2794        self.0.rustc.discriminants()
2795    }
2796
2797    pub fn is_opaque(&self) -> bool {
2798        self.0.opaque
2799    }
2800}
2801
2802impl EarlyBinder<PolyVariant> {
2803    // The field_idx is `Some(i)` when we have the `i`-th field of a `union`, in which case,
2804    // the `inputs` are _just_ the `i`-th type (and not all the types...)
2805    pub fn to_poly_fn_sig(&self, field_idx: Option<crate::FieldIdx>) -> EarlyBinder<PolyFnSig> {
2806        self.as_ref().map(|poly_variant| {
2807            poly_variant.as_ref().map(|variant| {
2808                let ret = variant.ret().shift_in_escaping(1);
2809                let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
2810                let inputs = match field_idx {
2811                    None => variant.fields.clone(),
2812                    Some(i) => List::singleton(variant.fields[i.index()].clone()),
2813                };
2814                FnSig::new(
2815                    Safety::Safe,
2816                    rustc_abi::ExternAbi::Rust,
2817                    variant.requires.clone(),
2818                    inputs,
2819                    output,
2820                    Expr::tt(),
2821                    false,
2822                )
2823            })
2824        })
2825    }
2826}
2827
2828impl TyKind {
2829    fn intern(self) -> Ty {
2830        Ty(Interned::new(self))
2831    }
2832}
2833
2834/// returns the same invariants as for `usize` which is the length of a slice
2835fn slice_invariants(overflow_mode: OverflowMode) -> &'static [Invariant] {
2836    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2837        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2838    });
2839    static OVERFLOW: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2840        [
2841            Invariant {
2842                pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2843            },
2844            Invariant {
2845                pred: Binder::bind_with_sort(
2846                    Expr::le(Expr::nu(), Expr::uint_max(UintTy::Usize)),
2847                    Sort::Int,
2848                ),
2849            },
2850        ]
2851    });
2852    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2853        &*OVERFLOW
2854    } else {
2855        &*DEFAULT
2856    }
2857}
2858
2859fn uint_invariants(uint_ty: UintTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2860    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2861        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2862    });
2863
2864    static OVERFLOW: LazyLock<UnordMap<UintTy, [Invariant; 2]>> = LazyLock::new(|| {
2865        UINT_TYS
2866            .into_iter()
2867            .map(|uint_ty| {
2868                let invariants = [
2869                    Invariant {
2870                        pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2871                    },
2872                    Invariant {
2873                        pred: Binder::bind_with_sort(
2874                            Expr::le(Expr::nu(), Expr::uint_max(uint_ty)),
2875                            Sort::Int,
2876                        ),
2877                    },
2878                ];
2879                (uint_ty, invariants)
2880            })
2881            .collect()
2882    });
2883    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2884        &OVERFLOW[&uint_ty]
2885    } else {
2886        &*DEFAULT
2887    }
2888}
2889
2890fn char_invariants() -> &'static [Invariant] {
2891    static INVARIANTS: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2892        [
2893            Invariant {
2894                pred: Binder::bind_with_sort(
2895                    Expr::le(
2896                        Expr::cast(Sort::Char, Sort::Int, Expr::nu()),
2897                        Expr::constant((char::MAX as u32).into()),
2898                    ),
2899                    Sort::Int,
2900                ),
2901            },
2902            Invariant {
2903                pred: Binder::bind_with_sort(
2904                    Expr::le(Expr::zero(), Expr::cast(Sort::Char, Sort::Int, Expr::nu())),
2905                    Sort::Int,
2906                ),
2907            },
2908        ]
2909    });
2910    &*INVARIANTS
2911}
2912
2913fn int_invariants(int_ty: IntTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2914    static DEFAULT: [Invariant; 0] = [];
2915
2916    static OVERFLOW: LazyLock<UnordMap<IntTy, [Invariant; 2]>> = LazyLock::new(|| {
2917        INT_TYS
2918            .into_iter()
2919            .map(|int_ty| {
2920                let invariants = [
2921                    Invariant {
2922                        pred: Binder::bind_with_sort(
2923                            Expr::ge(Expr::nu(), Expr::int_min(int_ty)),
2924                            Sort::Int,
2925                        ),
2926                    },
2927                    Invariant {
2928                        pred: Binder::bind_with_sort(
2929                            Expr::le(Expr::nu(), Expr::int_max(int_ty)),
2930                            Sort::Int,
2931                        ),
2932                    },
2933                ];
2934                (int_ty, invariants)
2935            })
2936            .collect()
2937    });
2938    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2939        &OVERFLOW[&int_ty]
2940    } else {
2941        &DEFAULT
2942    }
2943}
2944
// Interner registrations (macros from `flux_arc_interner`).
// `impl_internable!` lists types interned one value at a time; this file calls
// `Interned::new` on `AdtDefData` and `TyKind`.
// `impl_slice_internable!` presumably lists element types of interned slices such as
// `List<Ty>` — NOTE(review): the generated impls live in `flux_arc_interner`, not visible here.
impl_internable!(AdtDefData, AdtSortDefData, TyKind);
impl_slice_internable!(
    Ty,
    GenericArg,
    Ensures,
    InferMode,
    Sort,
    SortArg,
    GenericParamDef,
    TraitRef,
    Binder<ExistentialPredicate>,
    Clause,
    PolyVariant,
    Invariant,
    RefineParam,
    FluxDefId,
    SortParamKind,
    AssocReft
);
2961    SortParamKind,
2962    AssocReft
2963);
2964
/// Pattern for an indexed integer type.
///
/// `Int!(int_ty, idx)` expands to `TyKind::Indexed(BaseTy::Int(int_ty), idx)`.
///
/// Paths are spelled through `$crate::rty` (as `_Ref` does), so the pattern works even where
/// `TyKind` and `BaseTy` are not imported by the caller.
#[macro_export]
macro_rules! _Int {
    ($int_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Int($int_ty), $idxs)
    };
}
pub use crate::_Int as Int;
2972
/// Pattern for an indexed unsigned integer type.
///
/// `Uint!(uint_ty, idx)` expands to `TyKind::Indexed(BaseTy::Uint(uint_ty), idx)`.
///
/// Paths are spelled through `$crate::rty` (as `_Ref` does), so the pattern works even where
/// `TyKind` and `BaseTy` are not imported by the caller.
#[macro_export]
macro_rules! _Uint {
    ($uint_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Uint($uint_ty), $idxs)
    };
}
pub use crate::_Uint as Uint;
2980
/// Pattern for an indexed `bool` type.
///
/// `Bool!(idx)` expands to `TyKind::Indexed(BaseTy::Bool, idx)`.
///
/// Paths are spelled through `$crate::rty` (as `_Ref` does), so the pattern works even where
/// `TyKind` and `BaseTy` are not imported by the caller.
#[macro_export]
macro_rules! _Bool {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Bool, $idxs)
    };
}
pub use crate::_Bool as Bool;
2988
/// Pattern for an indexed `char` type.
///
/// `Char!(idx)` expands to `TyKind::Indexed(BaseTy::Char, idx)`.
///
/// Paths are spelled through `$crate::rty` (as `_Ref` does), so the pattern works even where
/// `TyKind` and `BaseTy` are not imported by the caller.
#[macro_export]
macro_rules! _Char {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Char, $idxs)
    };
}
pub use crate::_Char as Char;
2996
/// Pattern for an indexed reference type, ignoring the index.
///
/// `Ref!(re, ty, mutbl)` expands to `TyKind::Indexed(BaseTy::Ref(re, ty, mutbl), _)`.
/// A trailing comma is accepted.
#[macro_export]
macro_rules! _Ref {
    ($($field:pat),+ $(,)?) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Ref($($field),+), _)
    };
}
pub use crate::_Ref as Ref;
3004
/// Side tables recorded for a single owner (`owner`).
///
/// Except for `param_sorts`, every table is keyed by `fhir::ItemLocalId` and accessed through
/// `LocalTableInContext`/`LocalTableInContextMut`, which check that ids belong to `owner`.
/// NOTE(review): the name suggests these are produced by the well-formedness checker ("wfck") —
/// confirm against the producer.
pub struct WfckResults {
    pub owner: FluxOwnerId,
    /// Sort per parameter, keyed by `fhir::ParamId` (not owner-checked).
    param_sorts: UnordMap<fhir::ParamId, Sort>,
    /// A sort per binary-operation node.
    bin_op_sorts: ItemLocalMap<Sort>,
    /// Sort arguments per function-application node.
    fn_app_sorts: ItemLocalMap<List<SortArg>>,
    /// Zero or more `Coercion`s per node.
    coercions: ItemLocalMap<Vec<Coercion>>,
    /// A `FieldProj` per node.
    field_projs: ItemLocalMap<FieldProj>,
    /// A sort per node.
    node_sorts: ItemLocalMap<Sort>,
    /// A `DefId` per record-constructor node.
    record_ctors: ItemLocalMap<DefId>,
}
3015
/// A coercion recorded in `WfckResults::coercions`; each variant carries a `DefId`.
///
/// NOTE(review): presumably `Inject` wraps a value into, and `Project` extracts a value out of,
/// the item identified by the `DefId` — confirm with the code that records these.
#[derive(Clone, Copy, Debug)]
pub enum Coercion {
    Inject(DefId),
    Project(DefId),
}
3021
/// Map keyed by the owner-local part (`fhir::ItemLocalId`) of a fhir id.
pub type ItemLocalMap<T> = UnordMap<fhir::ItemLocalId, T>;
3023
/// Read-only view of an `ItemLocalMap` together with the owner its keys are relative to.
///
/// Lookups take a full `FhirId` and assert that its owner equals `owner`.
#[derive(Debug)]
pub struct LocalTableInContext<'a, T> {
    owner: FluxOwnerId,
    data: &'a ItemLocalMap<T>,
}
3029
/// Mutable view of an `ItemLocalMap` together with the owner its keys are relative to.
///
/// Insertions take a full `FhirId` and assert that its owner equals `owner`.
pub struct LocalTableInContextMut<'a, T> {
    owner: FluxOwnerId,
    data: &'a mut ItemLocalMap<T>,
}
3034
3035impl WfckResults {
3036    pub fn new(owner: impl Into<FluxOwnerId>) -> Self {
3037        Self {
3038            owner: owner.into(),
3039            param_sorts: UnordMap::default(),
3040            bin_op_sorts: ItemLocalMap::default(),
3041            coercions: ItemLocalMap::default(),
3042            field_projs: ItemLocalMap::default(),
3043            node_sorts: ItemLocalMap::default(),
3044            record_ctors: ItemLocalMap::default(),
3045            fn_app_sorts: ItemLocalMap::default(),
3046        }
3047    }
3048
3049    pub fn param_sorts_mut(&mut self) -> &mut UnordMap<fhir::ParamId, Sort> {
3050        &mut self.param_sorts
3051    }
3052
3053    pub fn param_sorts(&self) -> &UnordMap<fhir::ParamId, Sort> {
3054        &self.param_sorts
3055    }
3056
3057    pub fn bin_op_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3058        LocalTableInContextMut { owner: self.owner, data: &mut self.bin_op_sorts }
3059    }
3060
3061    pub fn fn_app_sorts_mut(&mut self) -> LocalTableInContextMut<'_, List<SortArg>> {
3062        LocalTableInContextMut { owner: self.owner, data: &mut self.fn_app_sorts }
3063    }
3064
3065    pub fn fn_app_sorts(&self) -> LocalTableInContext<'_, List<SortArg>> {
3066        LocalTableInContext { owner: self.owner, data: &self.fn_app_sorts }
3067    }
3068
3069    pub fn bin_op_sorts(&self) -> LocalTableInContext<'_, Sort> {
3070        LocalTableInContext { owner: self.owner, data: &self.bin_op_sorts }
3071    }
3072
3073    pub fn coercions_mut(&mut self) -> LocalTableInContextMut<'_, Vec<Coercion>> {
3074        LocalTableInContextMut { owner: self.owner, data: &mut self.coercions }
3075    }
3076
3077    pub fn coercions(&self) -> LocalTableInContext<'_, Vec<Coercion>> {
3078        LocalTableInContext { owner: self.owner, data: &self.coercions }
3079    }
3080
3081    pub fn field_projs_mut(&mut self) -> LocalTableInContextMut<'_, FieldProj> {
3082        LocalTableInContextMut { owner: self.owner, data: &mut self.field_projs }
3083    }
3084
3085    pub fn field_projs(&self) -> LocalTableInContext<'_, FieldProj> {
3086        LocalTableInContext { owner: self.owner, data: &self.field_projs }
3087    }
3088
3089    pub fn node_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3090        LocalTableInContextMut { owner: self.owner, data: &mut self.node_sorts }
3091    }
3092
3093    pub fn node_sorts(&self) -> LocalTableInContext<'_, Sort> {
3094        LocalTableInContext { owner: self.owner, data: &self.node_sorts }
3095    }
3096
3097    pub fn record_ctors_mut(&mut self) -> LocalTableInContextMut<'_, DefId> {
3098        LocalTableInContextMut { owner: self.owner, data: &mut self.record_ctors }
3099    }
3100
3101    pub fn record_ctors(&self) -> LocalTableInContext<'_, DefId> {
3102        LocalTableInContext { owner: self.owner, data: &self.record_ctors }
3103    }
3104}
3105
3106impl<T> LocalTableInContextMut<'_, T> {
3107    pub fn insert(&mut self, fhir_id: FhirId, value: T) {
3108        tracked_span_assert_eq!(self.owner, fhir_id.owner);
3109        self.data.insert(fhir_id.local_id, value);
3110    }
3111}
3112
3113impl<'a, T> LocalTableInContext<'a, T> {
3114    pub fn get(&self, fhir_id: FhirId) -> Option<&'a T> {
3115        tracked_span_assert_eq!(self.owner, fhir_id.owner);
3116        self.data.get(&fhir_id.local_id)
3117    }
3118}
3119
3120fn can_auto_strong(fn_sig: &PolyFnSig) -> bool {
3121    struct RegionDetector {
3122        has_region: bool,
3123    }
3124
3125    impl fold::TypeFolder for RegionDetector {
3126        fn fold_region(&mut self, re: &Region) -> Region {
3127            self.has_region = true;
3128            *re
3129        }
3130    }
3131    let mut detector = RegionDetector { has_region: false };
3132    fn_sig
3133        .skip_binder_ref()
3134        .output()
3135        .skip_binder_ref()
3136        .ret
3137        .fold_with(&mut detector);
3138
3139    !detector.has_region
3140}
3141/// The [`auto_strong`] function transforms function signatures by automatically converting
3142/// mutable reference parameters into strong references with associated ensures clauses. This
3143/// transformation is applied only when the function signature does not already contain region
3144/// variables in its return type.
3145///
3146/// Specifically, given a source function of type
3147///
3148///    fn (x: &mut InnerTy) -> bool
3149///
3150/// By default the above gives us an `rty::FnSig`
3151///
3152///    forall<>. fn (x: &mut InnerTy) -> bool
3153///
3154/// Which this function then transforms to
3155///
3156///     forall<l0: Loc>. fn (x: &strg<l0:InnerTy>) -> bool ensures l0:InnerTy
3157pub fn auto_strong(
3158    genv: GlobalEnv,
3159    def_id: impl IntoQueryParam<DefId>,
3160    fn_sig: PolyFnSig,
3161) -> PolyFnSig {
3162    // TODO(auto-strong): we only *really* need the first check `can_auto_strong` here.
3163    // The other two skip `auto-strong` as doing it breaks various downstream things
3164    // that should be fixed.
3165    if !can_auto_strong(&fn_sig)
3166        || matches!(genv.def_kind(def_id), rustc_hir::def::DefKind::Closure)
3167        || !fn_sig.skip_binder_ref().lifted
3168    {
3169        return fn_sig;
3170    }
3171    let kind = BoundReftKind::Anon;
3172    let mut vars = fn_sig.vars().to_vec();
3173    let fn_sig = fn_sig.skip_binder();
3174    // new list of (bound_var, inner_ty)
3175    let mut strg_bvars = vec![];
3176    // new list of input types
3177    let mut strg_inputs = vec![];
3178    // 1. Traverse inputs collecting strong locations
3179    for ty in &fn_sig.inputs {
3180        let strg_ty = if let TyKind::Indexed(BaseTy::Ref(re, inner_ty, Mutability::Mut), _) =
3181            ty.kind()
3182            && !inner_ty.is_slice()
3183        // TODO(auto-strong): including `slice` breaks `tock` for some reason we should replicate in our own tests...
3184        {
3185            // if input is &mut InnerTy create a new bound var `loc` for the strong location
3186            let var = {
3187                let idx = vars.len() + strg_bvars.len();
3188                BoundVar::from_usize(idx)
3189            };
3190            strg_bvars.push((var, inner_ty.clone()));
3191            let loc = Loc::Var(Var::Bound(INNERMOST, BoundReft { var, kind }));
3192            // and transform to &strg<loc:InnerTy>
3193            Ty::strg_ref(*re, Path::new(loc, List::empty()), inner_ty.clone())
3194        } else {
3195            // else leave input type unchanged
3196            ty.clone()
3197        };
3198        strg_inputs.push(strg_ty);
3199    }
3200    // 2. Add bound vars for strong locations
3201    for _ in 0..strg_bvars.len() {
3202        vars.push(BoundVariableKind::Refine(Sort::Loc, InferMode::EVar, kind));
3203    }
3204    // 3. Add ensures for strong locations
3205    let output = fn_sig.output.map(|out| {
3206        let mut ens = out.ensures.to_vec();
3207        for (var, inner_ty) in strg_bvars {
3208            let loc = Loc::Var(Var::Bound(INNERMOST.shifted_in(1), BoundReft { var, kind }));
3209            let path = Path::new(loc, List::empty());
3210            ens.push(Ensures::Type(path, inner_ty.shift_in_escaping(1)));
3211        }
3212        FnOutput { ensures: List::from_vec(ens), ..out }
3213    });
3214
3215    // 4. Reconstruct fn sig with new inputs and output and vars
3216    let fn_sig = FnSig { inputs: List::from_vec(strg_inputs), output, ..fn_sig };
3217    Binder::bind_with_vars(fn_sig, vars.into())
3218}