flux_middle/rty/
mod.rs

//! Defines how flux represents refinement types internally. Definitions in this module are used
//! during refinement type checking. A couple of important differences between definitions in this
//! module and in [`crate::fhir`] are:
//!
//! * Types in this module use de Bruijn indices to represent local binders.
//! * Data structures are interned so they can be cheaply cloned.
7mod binder;
8pub mod canonicalize;
9mod expr;
10pub mod fold;
11pub mod normalize;
12mod pretty;
13pub mod refining;
14pub mod region_matching;
15pub mod subst;
16use std::{borrow::Cow, cmp::Ordering, fmt, hash::Hash, sync::LazyLock};
17
18pub use binder::{Binder, BoundReftKind, BoundVariableKind, BoundVariableKinds, EarlyBinder};
19use bitflags::bitflags;
20pub use expr::{
21    AggregateKind, AliasReft, BinOp, BoundReft, Constant, Ctor, ESpan, EVid, EarlyReftParam, Expr,
22    ExprKind, FieldProj, HoleKind, InternalFuncKind, KVar, KVid, Lambda, Loc, Name, NameProvenance,
23    Path, PrettyMap, PrettyVar, Real, SpecFuncKind, UnOp, Var,
24};
25pub use flux_arc_interner::List;
26use flux_arc_interner::{Interned, impl_internable, impl_slice_internable};
27use flux_common::{bug, tracked_span_assert_eq, tracked_span_bug};
28use flux_config::OverflowMode;
29use flux_macros::{TypeFoldable, TypeVisitable};
30pub use flux_rustc_bridge::ty::{
31    AliasKind, BoundRegion, BoundRegionKind, BoundVar, Const, ConstKind, ConstVid, DebruijnIndex,
32    EarlyParamRegion, LateParamRegion, LateParamRegionKind,
33    Region::{self, *},
34    RegionVid,
35};
36use flux_rustc_bridge::{
37    ToRustc,
38    mir::{Place, RawPtrKind},
39    ty::{self, GenericArgsExt as _, VariantDef},
40};
41use itertools::Itertools;
42pub use normalize::{FuncInfo, NormalizedDefns, local_deps};
43use refining::{Refine as _, Refiner};
44use rustc_abi;
45pub use rustc_abi::{FIRST_VARIANT, VariantIdx};
46use rustc_data_structures::{fx::FxIndexMap, snapshot_map::SnapshotMap, unord::UnordMap};
47use rustc_hir::{LangItem, Safety, def_id::DefId};
48use rustc_index::{IndexSlice, IndexVec, newtype_index};
49use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable, extension};
50pub use rustc_middle::{
51    mir::Mutability,
52    ty::{AdtFlags, ClosureKind, FloatTy, IntTy, ParamConst, ParamTy, ScalarInt, UintTy},
53};
54use rustc_middle::{
55    query::IntoQueryParam,
56    ty::{TyCtxt, fast_reject::SimplifiedType},
57};
58use rustc_span::{DUMMY_SP, Span, Symbol, sym, symbol::kw};
59use rustc_type_ir::Upcast as _;
60pub use rustc_type_ir::{INNERMOST, TyVid};
61
62use self::fold::TypeFoldable;
63pub use crate::fhir::InferMode;
64use crate::{
65    LocalDefId,
66    def_id::{FluxDefId, FluxLocalDefId},
67    fhir::{self, FhirId, FluxOwnerId},
68    global_env::GlobalEnv,
69    pretty::{Pretty, PrettyCx},
70    queries::{QueryErr, QueryResult},
71    rty::subst::SortSubst,
72};
73
74/// The definition of the data sort automatically generated for a struct or enum.
75#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
76pub struct AdtSortDef(Interned<AdtSortDefData>);
77
78#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
79pub struct AdtSortVariant {
80    /// The list of field names as declared in the `#[flux::refined_by(...)]` annotation
81    field_names: Vec<Symbol>,
82    /// The sort of each of the fields. Note that these can contain [sort variables]. Methods used
83    /// to access these sorts guarantee they are properly instantiated.
84    ///
85    /// [sort variables]: Sort::Var
86    sorts: List<Sort>,
87}
88
89impl AdtSortVariant {
90    pub fn new(fields: Vec<(Symbol, Sort)>) -> Self {
91        let (field_names, sorts) = fields.into_iter().unzip();
92        AdtSortVariant { field_names, sorts: List::from_vec(sorts) }
93    }
94
95    pub fn fields(&self) -> usize {
96        self.sorts.len()
97    }
98
99    pub fn field_names(&self) -> &Vec<Symbol> {
100        &self.field_names
101    }
102
103    pub fn sort_by_field_name(&self, args: &[Sort]) -> FxIndexMap<Symbol, Sort> {
104        std::iter::zip(&self.field_names, &self.sorts.fold_with(&mut SortSubst::new(args)))
105            .map(|(name, sort)| (*name, sort.clone()))
106            .collect()
107    }
108
109    pub fn field_by_name(
110        &self,
111        def_id: DefId,
112        args: &[Sort],
113        name: Symbol,
114    ) -> Option<(FieldProj, Sort)> {
115        let idx = self.field_names.iter().position(|it| name == *it)?;
116        let proj = FieldProj::Adt { def_id, field: idx as u32 };
117        let sort = self.sorts[idx].fold_with(&mut SortSubst::new(args));
118        Some((proj, sort))
119    }
120
121    pub fn field_sorts(&self, args: &[Sort]) -> List<Sort> {
122        self.sorts.fold_with(&mut SortSubst::new(args))
123    }
124
125    pub fn field_sorts_instantiate_identity(&self) -> List<Sort> {
126        self.sorts.clone()
127    }
128
129    pub fn projections(&self, def_id: DefId) -> impl Iterator<Item = FieldProj> {
130        (0..self.fields()).map(move |i| FieldProj::Adt { def_id, field: i as u32 })
131    }
132}
133
134#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
135struct AdtSortDefData {
136    /// [`DefId`] of the struct or enum this data sort is associated to.
137    def_id: DefId,
138    /// The list of the type parameters used in the `#[flux::refined_by(..)]` annotation.
139    ///
140    /// See [`fhir::RefinedBy::sort_params`] for more details. This is a version of that but using
141    /// [`ParamTy`] instead of [`DefId`].
142    ///
143    /// The length of this list corresponds to the number of sort variables bound by this definition.
144    params: Vec<ParamTy>,
145    /// A vec of variants of the ADT;
146    /// - a `struct` sort -- used for types with a `refined_by` has a single variant;
147    /// - a `reflected` sort -- used for `reflected` enums have multiple variants
148    variants: IndexVec<VariantIdx, AdtSortVariant>,
149    is_reflected: bool,
150    is_struct: bool,
151}
152
153impl AdtSortDef {
154    pub fn new(
155        def_id: DefId,
156        params: Vec<ParamTy>,
157        variants: IndexVec<VariantIdx, AdtSortVariant>,
158        is_reflected: bool,
159        is_struct: bool,
160    ) -> Self {
161        Self(Interned::new(AdtSortDefData { def_id, params, variants, is_reflected, is_struct }))
162    }
163
164    pub fn did(&self) -> DefId {
165        self.0.def_id
166    }
167
168    pub fn variant(&self, idx: VariantIdx) -> &AdtSortVariant {
169        &self.0.variants[idx]
170    }
171
172    pub fn variants(&self) -> &IndexSlice<VariantIdx, AdtSortVariant> {
173        &self.0.variants
174    }
175
176    pub fn opt_struct_variant(&self) -> Option<&AdtSortVariant> {
177        if self.is_struct() { Some(self.struct_variant()) } else { None }
178    }
179
180    #[track_caller]
181    pub fn struct_variant(&self) -> &AdtSortVariant {
182        tracked_span_assert_eq!(self.0.is_struct, true);
183        &self.0.variants[FIRST_VARIANT]
184    }
185
186    pub fn is_reflected(&self) -> bool {
187        self.0.is_reflected
188    }
189
190    pub fn is_struct(&self) -> bool {
191        self.0.is_struct
192    }
193
194    pub fn to_sort(&self, args: &[GenericArg]) -> Sort {
195        let sorts = self
196            .filter_generic_args(args)
197            .map(|arg| arg.expect_base().sort())
198            .collect();
199
200        Sort::App(SortCtor::Adt(self.clone()), sorts)
201    }
202
203    /// Given a list of generic args, returns an iterator of the generic arguments that should be
204    /// mapped to sorts for instantiation.
205    pub fn filter_generic_args<'a, A>(&'a self, args: &'a [A]) -> impl Iterator<Item = &'a A> + 'a {
206        self.0.params.iter().map(|p| &args[p.index as usize])
207    }
208
209    pub fn identity_args(&self) -> List<Sort> {
210        (0..self.0.params.len())
211            .map(|i| Sort::Var(ParamSort::from(i)))
212            .collect()
213    }
214
215    /// Gives the number of sort variables bound by this definition.
216    pub fn param_count(&self) -> usize {
217        self.0.params.len()
218    }
219}
220
221#[derive(Debug, Clone, Default, Encodable, Decodable)]
222pub struct Generics {
223    pub parent: Option<DefId>,
224    pub parent_count: usize,
225    pub own_params: List<GenericParamDef>,
226    pub has_self: bool,
227}
228
229impl Generics {
230    pub fn count(&self) -> usize {
231        self.parent_count + self.own_params.len()
232    }
233
234    pub fn own_default_count(&self) -> usize {
235        self.own_params
236            .iter()
237            .filter(|param| {
238                match param.kind {
239                    GenericParamDefKind::Type { has_default }
240                    | GenericParamDefKind::Const { has_default }
241                    | GenericParamDefKind::Base { has_default } => has_default,
242                    GenericParamDefKind::Lifetime => false,
243                }
244            })
245            .count()
246    }
247
248    pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult<GenericParamDef> {
249        if let Some(index) = param_index.checked_sub(self.parent_count) {
250            Ok(self.own_params[index].clone())
251        } else {
252            let parent = self.parent.expect("parent_count > 0 but no parent?");
253            genv.generics_of(parent)?.param_at(param_index, genv)
254        }
255    }
256
257    pub fn const_params(&self, genv: GlobalEnv) -> QueryResult<Vec<(ParamConst, Sort)>> {
258        let mut res = vec![];
259        for i in 0..self.count() {
260            let param = self.param_at(i, genv)?;
261            if let GenericParamDefKind::Const { .. } = param.kind
262                && let Some(sort) = genv.sort_of_def_id(param.def_id)?
263            {
264                let param_const = ParamConst { name: param.name, index: param.index };
265                res.push((param_const, sort));
266            }
267        }
268        Ok(res)
269    }
270}
271
272#[derive(Debug, Clone, TyEncodable, TyDecodable)]
273pub struct RefinementGenerics {
274    pub parent: Option<DefId>,
275    pub parent_count: usize,
276    pub own_params: List<RefineParam>,
277}
278
279#[derive(
280    PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
281)]
282pub struct RefineParam {
283    pub sort: Sort,
284    pub name: Symbol,
285    pub mode: InferMode,
286}
287
288#[derive(Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
289pub struct GenericParamDef {
290    pub kind: GenericParamDefKind,
291    pub def_id: DefId,
292    pub index: u32,
293    pub name: Symbol,
294}
295
296#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
297pub enum GenericParamDefKind {
298    Type { has_default: bool },
299    Base { has_default: bool },
300    Lifetime,
301    Const { has_default: bool },
302}
303
304pub const SELF_PARAM_TY: ParamTy = ParamTy { index: 0, name: kw::SelfUpper };
305
306#[derive(Debug, Clone, TyEncodable, TyDecodable)]
307pub struct GenericPredicates {
308    pub parent: Option<DefId>,
309    pub predicates: List<Clause>,
310}
311
312#[derive(
313    Debug, PartialEq, Eq, Hash, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
314)]
315pub struct Clause {
316    kind: Binder<ClauseKind>,
317}
318
319impl Clause {
320    pub fn new(vars: impl Into<List<BoundVariableKind>>, kind: ClauseKind) -> Self {
321        Clause { kind: Binder::bind_with_vars(kind, vars.into()) }
322    }
323
324    pub fn kind(&self) -> Binder<ClauseKind> {
325        self.kind.clone()
326    }
327
328    fn as_trait_clause(&self) -> Option<Binder<TraitPredicate>> {
329        let clause = self.kind();
330        if let ClauseKind::Trait(trait_clause) = clause.skip_binder_ref() {
331            Some(clause.rebind(trait_clause.clone()))
332        } else {
333            None
334        }
335    }
336
337    pub fn as_projection_clause(&self) -> Option<Binder<ProjectionPredicate>> {
338        let clause = self.kind();
339        if let ClauseKind::Projection(proj_clause) = clause.skip_binder_ref() {
340            Some(clause.rebind(proj_clause.clone()))
341        } else {
342            None
343        }
344    }
345
346    // FIXME(nilehmann) we should deal with the binder in all the places this is used instead of
347    // blindly skipping it here
348    pub fn kind_skipping_binder(&self) -> ClauseKind {
349        self.kind.clone().skip_binder()
350    }
351
352    /// Group `Fn` trait clauses with their corresponding `FnOnce::Output` projection
353    /// predicate. This assumes there's exactly one corresponding projection predicate and will
354    /// crash otherwise.
355    pub fn split_off_fn_trait_clauses(
356        genv: GlobalEnv,
357        clauses: &Clauses,
358    ) -> (Vec<Clause>, Vec<Binder<FnTraitPredicate>>) {
359        let mut fn_trait_clauses = vec![];
360        let mut fn_trait_output_clauses = vec![];
361        let mut rest = vec![];
362        for clause in clauses {
363            if let Some(trait_clause) = clause.as_trait_clause()
364                && let Some(kind) = genv.tcx().fn_trait_kind_from_def_id(trait_clause.def_id())
365            {
366                fn_trait_clauses.push((kind, trait_clause));
367            } else if let Some(proj_clause) = clause.as_projection_clause()
368                && genv.is_fn_output(proj_clause.projection_def_id())
369            {
370                fn_trait_output_clauses.push(proj_clause);
371            } else {
372                rest.push(clause.clone());
373            }
374        }
375        let fn_trait_clauses = fn_trait_clauses
376            .into_iter()
377            .map(|(kind, fn_trait_clause)| {
378                let mut candidates = vec![];
379                for fn_trait_output_clause in &fn_trait_output_clauses {
380                    if fn_trait_output_clause.self_ty() == fn_trait_clause.self_ty() {
381                        candidates.push(fn_trait_output_clause.clone());
382                    }
383                }
384                tracked_span_assert_eq!(candidates.len(), 1);
385                let proj_pred = candidates.pop().unwrap().skip_binder();
386                fn_trait_clause.map(|fn_trait_clause| {
387                    FnTraitPredicate {
388                        kind,
389                        self_ty: fn_trait_clause.self_ty().to_ty(),
390                        tupled_args: fn_trait_clause.trait_ref.args[1].expect_base().to_ty(),
391                        output: proj_pred.term.to_ty(),
392                    }
393                })
394            })
395            .collect_vec();
396        (rest, fn_trait_clauses)
397    }
398}
399
400impl<'tcx> ToRustc<'tcx> for Clause {
401    type T = rustc_middle::ty::Clause<'tcx>;
402
403    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
404        self.kind.to_rustc(tcx).upcast(tcx)
405    }
406}
407
408impl From<Binder<ClauseKind>> for Clause {
409    fn from(kind: Binder<ClauseKind>) -> Self {
410        Clause { kind }
411    }
412}
413
414pub type Clauses = List<Clause>;
415
416#[derive(
417    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
418)]
419pub enum ClauseKind {
420    Trait(TraitPredicate),
421    Projection(ProjectionPredicate),
422    RegionOutlives(RegionOutlivesPredicate),
423    TypeOutlives(TypeOutlivesPredicate),
424    ConstArgHasType(Const, Ty),
425}
426
427impl<'tcx> ToRustc<'tcx> for ClauseKind {
428    type T = rustc_middle::ty::ClauseKind<'tcx>;
429
430    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
431        match self {
432            ClauseKind::Trait(trait_predicate) => {
433                rustc_middle::ty::ClauseKind::Trait(trait_predicate.to_rustc(tcx))
434            }
435            ClauseKind::Projection(projection_predicate) => {
436                rustc_middle::ty::ClauseKind::Projection(projection_predicate.to_rustc(tcx))
437            }
438            ClauseKind::RegionOutlives(outlives_predicate) => {
439                rustc_middle::ty::ClauseKind::RegionOutlives(outlives_predicate.to_rustc(tcx))
440            }
441            ClauseKind::TypeOutlives(outlives_predicate) => {
442                rustc_middle::ty::ClauseKind::TypeOutlives(outlives_predicate.to_rustc(tcx))
443            }
444            ClauseKind::ConstArgHasType(constant, ty) => {
445                rustc_middle::ty::ClauseKind::ConstArgHasType(
446                    constant.to_rustc(tcx),
447                    ty.to_rustc(tcx),
448                )
449            }
450        }
451    }
452}
453
454#[derive(Eq, PartialEq, Hash, Clone, Debug, TyEncodable, TyDecodable)]
455pub struct OutlivesPredicate<T>(pub T, pub Region);
456
457pub type TypeOutlivesPredicate = OutlivesPredicate<Ty>;
458pub type RegionOutlivesPredicate = OutlivesPredicate<Region>;
459
460impl<'tcx, V: ToRustc<'tcx>> ToRustc<'tcx> for OutlivesPredicate<V> {
461    type T = rustc_middle::ty::OutlivesPredicate<'tcx, V::T>;
462
463    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
464        rustc_middle::ty::OutlivesPredicate(self.0.to_rustc(tcx), self.1.to_rustc(tcx))
465    }
466}
467
468#[derive(
469    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
470)]
471pub struct TraitPredicate {
472    pub trait_ref: TraitRef,
473}
474
475impl TraitPredicate {
476    fn self_ty(&self) -> SubsetTyCtor {
477        self.trait_ref.self_ty()
478    }
479}
480
481impl<'tcx> ToRustc<'tcx> for TraitPredicate {
482    type T = rustc_middle::ty::TraitPredicate<'tcx>;
483
484    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
485        rustc_middle::ty::TraitPredicate {
486            polarity: rustc_middle::ty::PredicatePolarity::Positive,
487            trait_ref: self.trait_ref.to_rustc(tcx),
488        }
489    }
490}
491
492pub type PolyTraitPredicate = Binder<TraitPredicate>;
493
494impl PolyTraitPredicate {
495    fn def_id(&self) -> DefId {
496        self.skip_binder_ref().trait_ref.def_id
497    }
498
499    fn self_ty(&self) -> Binder<SubsetTyCtor> {
500        self.clone().map(|predicate| predicate.self_ty())
501    }
502}
503
504#[derive(
505    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
506)]
507pub struct TraitRef {
508    pub def_id: DefId,
509    pub args: GenericArgs,
510}
511
512impl TraitRef {
513    pub fn self_ty(&self) -> SubsetTyCtor {
514        self.args[0].expect_base().clone()
515    }
516}
517
518impl<'tcx> ToRustc<'tcx> for TraitRef {
519    type T = rustc_middle::ty::TraitRef<'tcx>;
520
521    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
522        rustc_middle::ty::TraitRef::new(tcx, self.def_id, self.args.to_rustc(tcx))
523    }
524}
525
526pub type PolyTraitRef = Binder<TraitRef>;
527
528impl PolyTraitRef {
529    pub fn def_id(&self) -> DefId {
530        self.as_ref().skip_binder().def_id
531    }
532}
533
534#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
535pub enum ExistentialPredicate {
536    Trait(ExistentialTraitRef),
537    Projection(ExistentialProjection),
538    AutoTrait(DefId),
539}
540
541pub type PolyExistentialPredicate = Binder<ExistentialPredicate>;
542
543impl ExistentialPredicate {
544    /// See [`rustc_middle::ty::ExistentialPredicateStableCmpExt`]
545    pub fn stable_cmp(&self, tcx: TyCtxt, other: &Self) -> Ordering {
546        match (self, other) {
547            (ExistentialPredicate::Trait(_), ExistentialPredicate::Trait(_)) => Ordering::Equal,
548            (ExistentialPredicate::Projection(a), ExistentialPredicate::Projection(b)) => {
549                tcx.def_path_hash(a.def_id)
550                    .cmp(&tcx.def_path_hash(b.def_id))
551            }
552            (ExistentialPredicate::AutoTrait(a), ExistentialPredicate::AutoTrait(b)) => {
553                tcx.def_path_hash(*a).cmp(&tcx.def_path_hash(*b))
554            }
555            (ExistentialPredicate::Trait(_), _) => Ordering::Less,
556            (ExistentialPredicate::Projection(_), ExistentialPredicate::Trait(_)) => {
557                Ordering::Greater
558            }
559            (ExistentialPredicate::Projection(_), _) => Ordering::Less,
560            (ExistentialPredicate::AutoTrait(_), _) => Ordering::Greater,
561        }
562    }
563}
564
565impl<'tcx> ToRustc<'tcx> for ExistentialPredicate {
566    type T = rustc_middle::ty::ExistentialPredicate<'tcx>;
567
568    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
569        match self {
570            ExistentialPredicate::Trait(trait_ref) => {
571                let trait_ref = rustc_middle::ty::ExistentialTraitRef::new_from_args(
572                    tcx,
573                    trait_ref.def_id,
574                    trait_ref.args.to_rustc(tcx),
575                );
576                rustc_middle::ty::ExistentialPredicate::Trait(trait_ref)
577            }
578            ExistentialPredicate::Projection(projection) => {
579                rustc_middle::ty::ExistentialPredicate::Projection(
580                    rustc_middle::ty::ExistentialProjection::new_from_args(
581                        tcx,
582                        projection.def_id,
583                        projection.args.to_rustc(tcx),
584                        projection.term.skip_binder_ref().to_rustc(tcx).into(),
585                    ),
586                )
587            }
588            ExistentialPredicate::AutoTrait(def_id) => {
589                rustc_middle::ty::ExistentialPredicate::AutoTrait(*def_id)
590            }
591        }
592    }
593}
594
595#[derive(
596    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
597)]
598pub struct ExistentialTraitRef {
599    pub def_id: DefId,
600    pub args: GenericArgs,
601}
602
603pub type PolyExistentialTraitRef = Binder<ExistentialTraitRef>;
604
605impl PolyExistentialTraitRef {
606    pub fn def_id(&self) -> DefId {
607        self.as_ref().skip_binder().def_id
608    }
609}
610
611#[derive(
612    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
613)]
614pub struct ExistentialProjection {
615    pub def_id: DefId,
616    pub args: GenericArgs,
617    pub term: SubsetTyCtor,
618}
619
620#[derive(
621    PartialEq, Eq, Hash, Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
622)]
623pub struct ProjectionPredicate {
624    pub projection_ty: AliasTy,
625    pub term: SubsetTyCtor,
626}
627
628impl Pretty for ProjectionPredicate {
629    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
630        write!(
631            f,
632            "ProjectionPredicate << projection_ty = {:?}, term = {:?} >>",
633            self.projection_ty, self.term
634        )
635    }
636}
637
638impl ProjectionPredicate {
639    pub fn self_ty(&self) -> SubsetTyCtor {
640        self.projection_ty.self_ty().clone()
641    }
642}
643
644impl<'tcx> ToRustc<'tcx> for ProjectionPredicate {
645    type T = rustc_middle::ty::ProjectionPredicate<'tcx>;
646
647    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
648        rustc_middle::ty::ProjectionPredicate {
649            projection_term: rustc_middle::ty::AliasTerm::new_from_args(
650                tcx,
651                self.projection_ty.def_id,
652                self.projection_ty.args.to_rustc(tcx),
653            ),
654            term: self.term.as_bty_skipping_binder().to_rustc(tcx).into(),
655        }
656    }
657}
658
659pub type PolyProjectionPredicate = Binder<ProjectionPredicate>;
660
661impl PolyProjectionPredicate {
662    pub fn projection_def_id(&self) -> DefId {
663        self.skip_binder_ref().projection_ty.def_id
664    }
665
666    pub fn self_ty(&self) -> Binder<SubsetTyCtor> {
667        self.clone().map(|predicate| predicate.self_ty())
668    }
669}
670
671#[derive(
672    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
673)]
674pub struct FnTraitPredicate {
675    pub self_ty: Ty,
676    pub tupled_args: Ty,
677    pub output: Ty,
678    pub kind: ClosureKind,
679}
680
681impl Pretty for FnTraitPredicate {
682    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
683        write!(
684            f,
685            "self = {:?}, args = {:?}, output = {:?}, kind = {}",
686            self.self_ty, self.tupled_args, self.output, self.kind
687        )
688    }
689}
690
691impl FnTraitPredicate {
692    pub fn fndef_sig(&self) -> FnSig {
693        let inputs = self.tupled_args.expect_tuple().iter().cloned().collect();
694        let ret = self.output.clone().shift_in_escaping(1);
695        let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
696        FnSig::new(
697            Safety::Safe,
698            rustc_abi::ExternAbi::Rust,
699            List::empty(),
700            inputs,
701            output,
702            false,
703            false,
704        )
705    }
706}
707
708pub fn to_closure_sig(
709    tcx: TyCtxt,
710    closure_id: LocalDefId,
711    tys: &[Ty],
712    args: &flux_rustc_bridge::ty::GenericArgs,
713    poly_sig: &PolyFnSig,
714    no_panic: bool,
715) -> PolyFnSig {
716    let closure_args = args.as_closure();
717    let kind_ty = closure_args.kind_ty().to_rustc(tcx);
718    let Some(kind) = kind_ty.to_opt_closure_kind() else {
719        bug!("to_closure_sig: expected closure kind, found {kind_ty:?}");
720    };
721
722    let mut vars = poly_sig.vars().clone().to_vec();
723    let fn_sig = poly_sig.clone().skip_binder();
724    let closure_ty = Ty::closure(closure_id.into(), tys, args);
725    let env_ty = match kind {
726        ClosureKind::Fn => {
727            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
728            let br = BoundRegion {
729                var: BoundVar::from_usize(vars.len() - 1),
730                kind: BoundRegionKind::ClosureEnv,
731            };
732            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Not)
733        }
734        ClosureKind::FnMut => {
735            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
736            let br = BoundRegion {
737                var: BoundVar::from_usize(vars.len() - 1),
738                kind: BoundRegionKind::ClosureEnv,
739            };
740            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Mut)
741        }
742        ClosureKind::FnOnce => closure_ty,
743    };
744
745    let inputs = std::iter::once(env_ty)
746        .chain(fn_sig.inputs().iter().cloned())
747        .collect::<Vec<_>>();
748    let output = fn_sig.output().clone();
749
750    let fn_sig = crate::rty::FnSig::new(
751        fn_sig.safety,
752        fn_sig.abi,
753        fn_sig.requires.clone(),
754        inputs.into(),
755        output,
756        no_panic,
757        false,
758    );
759
760    PolyFnSig::bind_with_vars(fn_sig, List::from(vars))
761}
762
763#[derive(
764    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
765)]
766pub struct CoroutineObligPredicate {
767    pub def_id: DefId,
768    pub resume_ty: Ty,
769    pub upvar_tys: List<Ty>,
770    pub output: Ty,
771}
772
773#[derive(Copy, Clone, Encodable, Decodable, Hash, PartialEq, Eq)]
774pub struct AssocReft {
775    pub def_id: FluxDefId,
776    // NOTE: Field is used to denote final associated generic refinements on Traits
777    pub final_: bool,
778    pub span: Span,
779}
780
781impl AssocReft {
782    pub fn new(def_id: FluxDefId, final_: bool, span: Span) -> Self {
783        Self { def_id, final_, span }
784    }
785
786    pub fn name(&self) -> Symbol {
787        self.def_id.name()
788    }
789
790    pub fn def_id(&self) -> FluxDefId {
791        self.def_id
792    }
793}
794
795#[derive(Clone, Encodable, Decodable)]
796pub struct AssocRefinements {
797    pub items: List<AssocReft>,
798}
799
800impl Default for AssocRefinements {
801    fn default() -> Self {
802        Self { items: List::empty() }
803    }
804}
805
806impl AssocRefinements {
807    pub fn get(&self, assoc_id: FluxDefId) -> AssocReft {
808        *self
809            .items
810            .into_iter()
811            .find(|it| it.def_id == assoc_id)
812            .unwrap_or_else(|| bug!("caller should guarantee existence of associated refinement"))
813    }
814
815    pub fn find(&self, name: Symbol) -> Option<AssocReft> {
816        Some(*self.items.into_iter().find(|it| it.name() == name)?)
817    }
818}
819
820#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
821pub enum SortCtor {
822    Set,
823    Map,
824    Adt(AdtSortDef),
825    User(FluxDefId),
826}
827
828newtype_index! {
829    /// [`ParamSort`] is used for polymorphic sorts (`Set`, `Map`, etc.) and [bit-vector size parameters].
830    /// They should occur "bound" under a [`PolyFuncSort`] or an [`AdtSortDef`]. We assume there's a
831    /// single binder and a [`ParamSort`] represents a variable as an index into the list of variables
832    /// bound by that binder, i.e., the representation doesnt't support higher-ranked sorts.
833    ///
834    /// [bit-vector size parameters]: BvSize::Param
835    #[debug_format = "?{}s"]
836    #[encodable]
837    pub struct ParamSort {}
838}
839
840newtype_index! {
841    /// A *sort* *v*variable *id*
842    #[debug_format = "?{}s"]
843    #[encodable]
844    pub struct SortVid {}
845}
846
847impl ena::unify::UnifyKey for SortVid {
848    type Value = SortVarVal;
849
850    #[inline]
851    fn index(&self) -> u32 {
852        self.as_u32()
853    }
854
855    #[inline]
856    fn from_index(u: u32) -> Self {
857        SortVid::from_u32(u)
858    }
859
860    fn tag() -> &'static str {
861        "SortVid"
862    }
863}
864
865bitflags! {
    /// A *sort constraint* is a set of operations a sort must support.
    ///
    /// During sort checking, we accumulate the operations required for each sort variable. As
    /// unification progresses, these constraints become more specific, i.e., a sort must support
    /// more operations to satisfy them.
    ///
    /// Each single-operator flag below stands for one binary operator. The composite flags
    /// (`NUMERIC`, `INT`, `REAL`, `BITVEC`, `SET`) are the unions of the operators supported by
    /// a family of sorts.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
    pub struct SortCstr: u16 {
        /// An empty constraint (any sort satisfies it)
        const BOT     = 0b0000000000;
        /// `*`
        const MUL     = 0b0000000001;
        /// `/`
        const DIV     = 0b0000000010;
        /// `%`
        const MOD     = 0b0000000100;
        /// `+`
        const ADD     = 0b0000001000;
        /// `-`
        const SUB     = 0b0000010000;
        /// `|`
        const BIT_OR  = 0b0000100000;
        /// `&`
        const BIT_AND = 0b0001000000;
        /// `<<`
        const BIT_SHL = 0b0010000000;
        /// `>>`
        const BIT_SHR = 0b0100000000;
        /// `^`
        const BIT_XOR = 0b1000000000;

        /// The set of operations supported by all _numeric_ sorts: `+`, `-`, `*`, and `/`.
        const NUMERIC = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
        /// The set of operations supported by integers: the numeric operations plus `%`.
        const INT = Self::DIV.bits()
            | Self::MUL.bits()
            | Self::MOD.bits()
            | Self::ADD.bits()
            | Self::SUB.bits();
        /// The set of operations supported by reals (currently the same bits as `NUMERIC`).
        const REAL = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
        /// The set of operations supported by bit vectors: every single-operator flag above.
        const BITVEC =  Self::DIV.bits()
            | Self::MUL.bits()
            | Self::MOD.bits()
            | Self::ADD.bits()
            | Self::SUB.bits()
            | Self::BIT_OR.bits()
            | Self::BIT_AND.bits()
            | Self::BIT_SHL.bits()
            | Self::BIT_SHR.bits()
            | Self::BIT_XOR.bits();
        /// The set of operations supported by sets: `-`, `|`, and `&`.
        const SET = Self::SUB.bits() | Self::BIT_OR.bits() | Self::BIT_AND.bits();
    }
920}
921
922impl SortCstr {
923    /// Returns a constraint that only requires the specified binary operation.
924    pub fn from_bin_op(op: fhir::BinOp) -> Self {
925        match op {
926            fhir::BinOp::Add => Self::ADD,
927            fhir::BinOp::Sub => Self::SUB,
928            fhir::BinOp::Mul => Self::MUL,
929            fhir::BinOp::Div => Self::DIV,
930            fhir::BinOp::Mod => Self::MOD,
931            fhir::BinOp::BitAnd => Self::BIT_AND,
932            fhir::BinOp::BitOr => Self::BIT_OR,
933            fhir::BinOp::BitXor => Self::BIT_XOR,
934            fhir::BinOp::BitShl => Self::BIT_SHL,
935            fhir::BinOp::BitShr => Self::BIT_SHR,
936            _ => bug!("{op:?} not supported as a constraint"),
937        }
938    }
939
940    /// Returns whether a sort satisfies this constraint
941    fn satisfy(self, sort: &Sort) -> bool {
942        match sort {
943            Sort::Int => SortCstr::INT.contains(self),
944            Sort::Real => SortCstr::REAL.contains(self),
945            Sort::BitVec(_) => SortCstr::BITVEC.contains(self),
946            Sort::App(SortCtor::Set, _) => SortCstr::SET.contains(self),
947            _ => self == SortCstr::BOT,
948        }
949    }
950}
951
/// Unification value for sort variables used during sort checking.
///
/// A variable starts out [`SortVarVal::Unsolved`] with an empty constraint (see the `Default`
/// impl). Its constraint grows as required operations are discovered, until the variable is
/// [`SortVarVal::Solved`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SortVarVal {
    /// The variable is not yet solved but the solution must satisfy some constraint.
    Unsolved(SortCstr),
    /// The variable has been solved to a sort.
    Solved(Sort),
}
960
961impl Default for SortVarVal {
962    fn default() -> Self {
963        SortVarVal::Unsolved(SortCstr::BOT)
964    }
965}
966
967impl SortVarVal {
968    pub fn solved_or(&self, sort: &Sort) -> Sort {
969        match self {
970            SortVarVal::Unsolved(_) => sort.clone(),
971            SortVarVal::Solved(sort) => sort.clone(),
972        }
973    }
974
975    pub fn map_solved(&self, f: impl FnOnce(&Sort) -> Sort) -> SortVarVal {
976        match self {
977            SortVarVal::Unsolved(cstr) => SortVarVal::Unsolved(*cstr),
978            SortVarVal::Solved(sort) => SortVarVal::Solved(f(sort)),
979        }
980    }
981}
982
983impl ena::unify::UnifyValue for SortVarVal {
984    type Error = ();
985
986    fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
987        match (value1, value2) {
988            (SortVarVal::Solved(s1), SortVarVal::Solved(s2)) if s1 == s2 => {
989                Ok(SortVarVal::Solved(s1.clone()))
990            }
991            (SortVarVal::Unsolved(a), SortVarVal::Unsolved(b)) => Ok(SortVarVal::Unsolved(*a | *b)),
992            (SortVarVal::Unsolved(v), SortVarVal::Solved(sort))
993            | (SortVarVal::Solved(sort), SortVarVal::Unsolved(v))
994                if v.satisfy(sort) =>
995            {
996                Ok(SortVarVal::Solved(sort.clone()))
997            }
998            _ => Err(()),
999        }
1000    }
1001}
1002
newtype_index! {
    /// A *b*it *v*ector *size* *v*ariable *id*
    ///
    /// Used as the key of a unification table whose values are `Option<BvSize>` (see its
    /// `ena::unify::UnifyKey` impl).
    #[debug_format = "?{}size"]
    #[encodable]
    pub struct BvSizeVid {}
}
1009
1010impl ena::unify::UnifyKey for BvSizeVid {
1011    type Value = Option<BvSize>;
1012
1013    #[inline]
1014    fn index(&self) -> u32 {
1015        self.as_u32()
1016    }
1017
1018    #[inline]
1019    fn from_index(u: u32) -> Self {
1020        BvSizeVid::from_u32(u)
1021    }
1022
1023    fn tag() -> &'static str {
1024        "BvSizeVid"
1025    }
1026}
1027
1028impl ena::unify::EqUnifyValue for BvSize {}
1029
/// The sort of a refinement-level expression.
///
/// NOTE: variant order matters for the derived `Hash`/`TyEncodable`/`TyDecodable` impls.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum Sort {
    /// Integers; they support the operations in [`SortCstr::INT`].
    Int,
    /// Booleans.
    Bool,
    /// Reals; they support the operations in [`SortCstr::REAL`].
    Real,
    /// Bit vectors of the given [`BvSize`]; they support the operations in [`SortCstr::BITVEC`].
    BitVec(BvSize),
    Str,
    /// Characters. Casting between `Char` and `Int` is an identity cast (see [`Sort::cast_kind`]).
    Char,
    /// Locations (see [`Sort::is_loc`]).
    Loc,
    /// NOTE(review): presumably the sort associated with a generic type parameter — confirm.
    Param(ParamTy),
    /// A tuple of sorts; the empty tuple is the unit sort ([`Sort::unit`]).
    Tuple(List<Sort>),
    /// A sort attached to an alias type (see [`AliasKind`]).
    Alias(AliasKind, AliasTy),
    /// A (possibly polymorphic) function sort.
    Func(PolyFuncSort),
    /// A sort constructor applied to sort arguments, e.g. sets or ADT sorts (see [`SortCtor`]).
    App(SortCtor, List<Sort>),
    /// NOTE(review): presumably a sort parameter bound by a [`PolyFuncSort`] — confirm.
    Var(ParamSort),
    /// A sort variable to be inferred during sort checking.
    Infer(SortVid),
    /// NOTE(review): presumably a placeholder used after an error was reported — confirm.
    Err,
}
1048
/// The kind of a cast between two sorts, as computed by [`Sort::cast_kind`].
pub enum CastKind {
    /// Identity cast, which is erasable (e.g. int -> int, char -> int)
    Identity,
    /// From bool to int
    BoolToInt,
    /// Casts whose target is the unit sort (e.g. int -> float)
    IntoUnit,
    /// Uninterpreted casts, only allowed with an explicit flag
    Uninterpreted,
}
1059
1060impl Sort {
1061    pub fn tuple(sorts: impl Into<List<Sort>>) -> Self {
1062        Sort::Tuple(sorts.into())
1063    }
1064
1065    pub fn app(ctor: SortCtor, sorts: List<Sort>) -> Self {
1066        Sort::App(ctor, sorts)
1067    }
1068
1069    pub fn unit() -> Self {
1070        Self::tuple(vec![])
1071    }
1072
1073    #[track_caller]
1074    pub fn expect_func(&self) -> &PolyFuncSort {
1075        if let Sort::Func(sort) = self { sort } else { bug!("expected `Sort::Func`") }
1076    }
1077
1078    pub fn is_loc(&self) -> bool {
1079        matches!(self, Sort::Loc)
1080    }
1081
1082    pub fn is_unit(&self) -> bool {
1083        matches!(self, Sort::Tuple(sorts) if sorts.is_empty())
1084    }
1085
1086    pub fn is_unit_adt(&self) -> Option<DefId> {
1087        if let Sort::App(SortCtor::Adt(sort_def), _) = self
1088            && let Some(variant) = sort_def.opt_struct_variant()
1089            && variant.fields() == 0
1090        {
1091            Some(sort_def.did())
1092        } else {
1093            None
1094        }
1095    }
1096
1097    /// Whether the sort is a function with return sort bool
1098    pub fn is_pred(&self) -> bool {
1099        matches!(self, Sort::Func(fsort) if fsort.skip_binders().output().is_bool())
1100    }
1101
1102    /// Returns `true` if the sort is [`Bool`].
1103    ///
1104    /// [`Bool`]: Sort::Bool
1105    #[must_use]
1106    pub fn is_bool(&self) -> bool {
1107        matches!(self, Self::Bool)
1108    }
1109
1110    pub fn cast_kind(self: &Sort, to: &Sort) -> CastKind {
1111        if self == to
1112            || (matches!(self, Sort::Char | Sort::Int) && matches!(to, Sort::Char | Sort::Int))
1113        {
1114            CastKind::Identity
1115        } else if matches!(self, Sort::Bool) && matches!(to, Sort::Int) {
1116            CastKind::BoolToInt
1117        } else if to.is_unit() {
1118            CastKind::IntoUnit
1119        } else {
1120            CastKind::Uninterpreted
1121        }
1122    }
1123
1124    pub fn walk(&self, mut f: impl FnMut(&Sort, &[FieldProj])) {
1125        fn go(sort: &Sort, f: &mut impl FnMut(&Sort, &[FieldProj]), proj: &mut Vec<FieldProj>) {
1126            match sort {
1127                Sort::Tuple(flds) => {
1128                    for (i, sort) in flds.iter().enumerate() {
1129                        proj.push(FieldProj::Tuple { arity: flds.len(), field: i as u32 });
1130                        go(sort, f, proj);
1131                        proj.pop();
1132                    }
1133                }
1134                Sort::App(SortCtor::Adt(sort_def), args) if sort_def.is_struct() => {
1135                    let field_sorts = sort_def.struct_variant().field_sorts(args);
1136                    for (i, sort) in field_sorts.iter().enumerate() {
1137                        proj.push(FieldProj::Adt { def_id: sort_def.did(), field: i as u32 });
1138                        go(sort, f, proj);
1139                        proj.pop();
1140                    }
1141                }
1142                _ => {
1143                    f(sort, proj);
1144                }
1145            }
1146        }
1147        go(self, &mut f, &mut vec![]);
1148    }
1149}
1150
/// The size of a [bit-vector]
///
/// During sort checking a size may still be unknown ([`BvSize::Infer`]); sizes of polymorphic
/// theory functions are [`BvSize::Param`].
///
/// [bit-vector]: Sort::BitVec
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BvSize {
    /// A fixed, statically known size
    Fixed(u32),
    /// A size that has been parameterized, e.g., bound under a [`PolyFuncSort`]
    Param(ParamSort),
    /// A size that needs to be inferred. Used during sort checking to instantiate bit-vector
    /// sizes at call-sites.
    Infer(BvSizeVid),
}
1164
1165impl rustc_errors::IntoDiagArg for Sort {
1166    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1167        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1168    }
1169}
1170
1171impl rustc_errors::IntoDiagArg for FuncSort {
1172    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1173        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1174    }
1175}
1176
/// A monomorphic function sort.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FuncSort {
    /// The input sorts followed by the output sort: the last element is always the output
    /// (see [`FuncSort::new`]).
    pub inputs_and_output: List<Sort>,
}
1181
1182impl FuncSort {
1183    pub fn new(mut inputs: Vec<Sort>, output: Sort) -> Self {
1184        inputs.push(output);
1185        FuncSort { inputs_and_output: List::from_vec(inputs) }
1186    }
1187
1188    pub fn inputs(&self) -> &[Sort] {
1189        &self.inputs_and_output[0..self.inputs_and_output.len() - 1]
1190    }
1191
1192    pub fn output(&self) -> &Sort {
1193        &self.inputs_and_output[self.inputs_and_output.len() - 1]
1194    }
1195
1196    pub fn to_poly(&self) -> PolyFuncSort {
1197        PolyFuncSort::new(List::empty(), self.clone())
1198    }
1199}
1200
/// The kind of a parameter of a [`PolyFuncSort`].
///
/// See [`PolyFuncSort`]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub enum SortParamKind {
    /// A parameter ranging over sorts (instantiated with a [`SortArg::Sort`]).
    Sort,
    /// A parameter ranging over bit-vector sizes (instantiated with a [`SortArg::BvSize`]).
    BvSize,
}
1207
/// A polymorphic function sort parametric over [sorts] or [bit-vector sizes].
///
/// Parameterizing over bit-vector sizes is a bit of a stretch, because smtlib doesn't support full
/// parametric reasoning over them. As long as we use functions parameterized over a size
/// monomorphically, we should be fine. Right now we can guarantee this, because size parameters
/// are not exposed in the surface syntax and they are only used for predefined (interpreted)
/// theory functions.
///
/// [sorts]: Sort
/// [bit-vector sizes]: BvSize::Param
#[derive(Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub struct PolyFuncSort {
    /// The list of parameters including sorts and bit vector sizes
    params: List<SortParamKind>,
    /// The function sort in which the parameters may occur.
    fsort: FuncSort,
}
1223
1224impl PolyFuncSort {
1225    pub fn new(params: List<SortParamKind>, fsort: FuncSort) -> Self {
1226        PolyFuncSort { params, fsort }
1227    }
1228
1229    pub fn skip_binders(&self) -> FuncSort {
1230        self.fsort.clone()
1231    }
1232
1233    pub fn instantiate_identity(&self) -> FuncSort {
1234        self.fsort.clone()
1235    }
1236
1237    pub fn expect_mono(&self) -> FuncSort {
1238        assert!(self.params.is_empty());
1239        self.fsort.clone()
1240    }
1241
1242    pub fn params(&self) -> impl ExactSizeIterator<Item = SortParamKind> + '_ {
1243        self.params.iter().copied()
1244    }
1245
1246    pub fn instantiate(&self, args: &[SortArg]) -> FuncSort {
1247        self.fsort.fold_with(&mut SortSubst::new(args))
1248    }
1249}
1250
/// An argument for a generic parameter in a [`Sort`] which can be either a generic sort or a
/// generic bit-vector size.
///
/// Used to instantiate the parameters of a [`PolyFuncSort`] (see [`PolyFuncSort::instantiate`]).
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub enum SortArg {
    /// Argument for a [`SortParamKind::Sort`] parameter.
    Sort(Sort),
    /// Argument for a [`SortParamKind::BvSize`] parameter.
    BvSize(BvSize),
}
1260
/// What is known about a refinement-level constant.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum ConstantInfo {
    /// An uninterpreted constant
    Uninterpreted,
    /// A non-integral constant whose value is specified by the user, given as the value
    /// expression together with its sort
    Interpreted(Expr, Sort),
}
1268
/// The refined definition of an ADT, interned so that it can be cheaply cloned.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDef(Interned<AdtDefData>);
1271
/// The data behind an [`AdtDef`].
#[derive(Debug, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDefData {
    /// Invariants that hold for every value of the ADT (see [`Invariant::apply`]).
    invariants: Vec<Invariant>,
    /// The sort of the ADT's refinement indices.
    sort_def: AdtSortDef,
    /// NOTE(review): presumably set for structs annotated with `#[flux::opaque]` — confirm.
    opaque: bool,
    /// The corresponding rustc ADT definition.
    rustc: ty::AdtDef,
}
1279
/// Option-like enum to explicitly mark that we don't have information about an ADT because it was
/// annotated with `#[flux::opaque]`. Note that only structs can be marked as opaque.
#[derive(Clone, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum Opaqueness<T> {
    /// No information is available.
    Opaque,
    /// The information is available.
    Transparent(T),
}
1287
1288impl<T> Opaqueness<T> {
1289    pub fn map<S>(self, f: impl FnOnce(T) -> S) -> Opaqueness<S> {
1290        match self {
1291            Opaqueness::Opaque => Opaqueness::Opaque,
1292            Opaqueness::Transparent(value) => Opaqueness::Transparent(f(value)),
1293        }
1294    }
1295
1296    pub fn as_ref(&self) -> Opaqueness<&T> {
1297        match self {
1298            Opaqueness::Opaque => Opaqueness::Opaque,
1299            Opaqueness::Transparent(value) => Opaqueness::Transparent(value),
1300        }
1301    }
1302
1303    pub fn as_deref(&self) -> Opaqueness<&T::Target>
1304    where
1305        T: std::ops::Deref,
1306    {
1307        match self {
1308            Opaqueness::Opaque => Opaqueness::Opaque,
1309            Opaqueness::Transparent(value) => Opaqueness::Transparent(value.deref()),
1310        }
1311    }
1312
1313    pub fn ok_or_else<E>(self, err: impl FnOnce() -> E) -> Result<T, E> {
1314        match self {
1315            Opaqueness::Transparent(v) => Ok(v),
1316            Opaqueness::Opaque => Err(err()),
1317        }
1318    }
1319
1320    #[track_caller]
1321    pub fn expect(self, msg: &str) -> T {
1322        match self {
1323            Opaqueness::Transparent(val) => val,
1324            Opaqueness::Opaque => bug!("{}", msg),
1325        }
1326    }
1327
1328    pub fn ok_or_query_err(self, struct_id: DefId) -> Result<T, QueryErr> {
1329        self.ok_or_else(|| QueryErr::OpaqueStruct { struct_id })
1330    }
1331}
1332
1333impl<T, E> Opaqueness<Result<T, E>> {
1334    pub fn transpose(self) -> Result<Opaqueness<T>, E> {
1335        match self {
1336            Opaqueness::Transparent(Ok(x)) => Ok(Opaqueness::Transparent(x)),
1337            Opaqueness::Transparent(Err(e)) => Err(e),
1338            Opaqueness::Opaque => Ok(Opaqueness::Opaque),
1339        }
1340    }
1341}
1342
/// Every signed integer type, starting with `isize`.
pub static INT_TYS: [IntTy; 6] =
    [IntTy::Isize, IntTy::I8, IntTy::I16, IntTy::I32, IntTy::I64, IntTy::I128];
/// Every unsigned integer type, starting with `usize`.
pub static UINT_TYS: [UintTy; 6] =
    [UintTy::Usize, UintTy::U8, UintTy::U16, UintTy::U32, UintTy::U64, UintTy::U128];
1347
/// An invariant of an ADT: a predicate over the ADT's index, instantiated with
/// [`Invariant::apply`].
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable,
)]
pub struct Invariant {
    // This predicate may have sort variables, but we don't explicitly mark it like in `PolyFuncSort`.
    // See comment on `apply` for details.
    pred: Binder<Expr>,
}
1356
1357impl Invariant {
1358    pub fn new(pred: Binder<Expr>) -> Self {
1359        Self { pred }
1360    }
1361
1362    pub fn apply(&self, idx: &Expr) -> Expr {
1363        // The predicate may have sort variables but we don't explicitly instantiate them. This
1364        // works because within an expression, sort variables can only appear inside the sort
1365        // annotation for a lambda and invariants cannot have lambdas. It remains to instantiate
1366        // variables in the sort of the binder itself, but since we are removing it, we can avoid
1367        // the explicit instantiation. Ultimately, this works because the expression we generate in
1368        // fixpoint doesn't need sort annotations (sorts are re-inferred).
1369        self.pred.replace_bound_reft(idx)
1370    }
1371}
1372
/// The signatures of all variants of an ADT.
pub type PolyVariants = List<Binder<VariantSig>>;
/// The signature of one ADT variant, under a binder.
pub type PolyVariant = Binder<VariantSig>;

/// The signature of an ADT variant's constructor.
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct VariantSig {
    /// The ADT the variant belongs to.
    pub adt_def: AdtDef,
    /// The generic arguments of the ADT.
    pub args: GenericArgs,
    /// The types of the variant's fields.
    pub fields: List<Ty>,
    /// NOTE(review): presumably the refinement index of the constructed value — confirm.
    pub idx: Expr,
    /// Predicates that must hold to construct the variant.
    pub requires: List<Expr>,
}
1384
/// A function signature under a binder.
pub type PolyFnSig = Binder<FnSig>;

/// A refined function signature.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FnSig {
    pub safety: Safety,
    pub abi: rustc_abi::ExternAbi,
    /// The predicates of the `requires` clauses.
    pub requires: List<Expr>,
    /// The types of the inputs.
    pub inputs: List<Ty>,
    /// The output, under its own binder.
    pub output: Binder<FnOutput>,
    /// NOTE(review): presumably whether the function is specified not to panic — confirm.
    pub no_panic: bool,
    /// Whether this signature was auto-lifted (as opposed to coming from a spec).
    pub lifted: bool,
}
1398
/// The output of a [`FnSig`]: the return type and the `ensures` clauses.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct FnOutput {
    pub ret: Ty,
    pub ensures: List<Ensures>,
}
1406
/// A single `ensures` clause of a [`FnOutput`].
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum Ensures {
    /// NOTE(review): presumably "the location at `Path` has type `Ty` on return" — confirm.
    Type(Path, Ty),
    /// A predicate.
    Pred(Expr),
}
1412
/// A qualifier: a predicate `body` under a binder, identified by `def_id`.
///
/// NOTE(review): presumably used as a template when solving kvars — confirm.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct Qualifier {
    pub def_id: FluxLocalDefId,
    pub body: Binder<Expr>,
    pub kind: QualifierKind,
}
1419
/// The kind of a [`Qualifier`].
///
/// NOTE(review): presumably controls where the qualifier is used (everywhere, only where
/// enabled locally, or as a hint) — confirm.
#[derive(Debug, TypeFoldable, TypeVisitable, Copy, Clone)]
pub enum QualifierKind {
    Global,
    Local,
    Hint,
}
1426
/// A `PrimOpProp` is a single property for a primitive operation which
/// can be conjoined to get the definition of the [`PrimRel`] for that
/// primitive operation.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimOpProp {
    pub def_id: FluxLocalDefId,
    /// The primitive operation this property is about.
    pub op: BinOp,
    /// The property itself, under a binder.
    pub body: Binder<Expr>,
}
1436
/// The definition of a primitive operation, obtained by conjoining its [`PrimOpProp`]s.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimRel {
    pub body: Binder<Expr>,
}
1441
1442pub type TyCtor = Binder<Ty>;
1443
1444impl TyCtor {
1445    pub fn to_ty(&self) -> Ty {
1446        match &self.vars()[..] {
1447            [] => {
1448                return self.skip_binder_ref().shift_out_escaping(1);
1449            }
1450            [BoundVariableKind::Refine(sort, ..)] => {
1451                if sort.is_unit() {
1452                    return self.replace_bound_reft(&Expr::unit());
1453                }
1454                if let Some(def_id) = sort.is_unit_adt() {
1455                    return self.replace_bound_reft(&Expr::unit_struct(def_id));
1456                }
1457            }
1458            _ => {}
1459        }
1460        Ty::exists(self.clone())
1461    }
1462}
1463
/// A refinement type, interned so that it can be cheaply cloned. See [`TyKind`] for its shapes.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct Ty(Interned<TyKind>);
1466
1467impl Ty {
1468    pub fn kind(&self) -> &TyKind {
1469        &self.0
1470    }
1471
1472    /// Dummy type used for the `Self` of a `TraitRef` created when converting a trait object, and
1473    /// which gets removed in `ExistentialTraitRef`. This type must not appear anywhere in other
1474    /// converted types and must be a valid `rustc` type (i.e., we must be able to call `to_rustc`
1475    /// on it). `TyKind::Infer(TyVid(0))` does the job, with the caveat that we must skip 0 when
1476    /// generating `TyKind::Infer` for "type holes".
1477    pub fn trait_object_dummy_self() -> Ty {
1478        Ty::infer(TyVid::from_u32(0))
1479    }
1480
1481    pub fn dynamic(preds: impl Into<List<Binder<ExistentialPredicate>>>, region: Region) -> Ty {
1482        BaseTy::Dynamic(preds.into(), region).to_ty()
1483    }
1484
1485    pub fn strg_ref(re: Region, path: Path, ty: Ty) -> Ty {
1486        TyKind::StrgRef(re, path, ty).intern()
1487    }
1488
1489    pub fn ptr(pk: impl Into<PtrKind>, path: impl Into<Path>) -> Ty {
1490        TyKind::Ptr(pk.into(), path.into()).intern()
1491    }
1492
1493    pub fn constr(p: impl Into<Expr>, ty: Ty) -> Ty {
1494        TyKind::Constr(p.into(), ty).intern()
1495    }
1496
1497    pub fn uninit() -> Ty {
1498        TyKind::Uninit.intern()
1499    }
1500
1501    pub fn indexed(bty: BaseTy, idx: impl Into<Expr>) -> Ty {
1502        TyKind::Indexed(bty, idx.into()).intern()
1503    }
1504
1505    pub fn exists(ty: Binder<Ty>) -> Ty {
1506        TyKind::Exists(ty).intern()
1507    }
1508
1509    pub fn exists_with_constr(bty: BaseTy, pred: Expr) -> Ty {
1510        let sort = bty.sort();
1511        let ty = Ty::indexed(bty, Expr::nu());
1512        Ty::exists(Binder::bind_with_sort(Ty::constr(pred, ty), sort))
1513    }
1514
1515    pub fn discr(adt_def: AdtDef, place: Place) -> Ty {
1516        TyKind::Discr(adt_def, place).intern()
1517    }
1518
1519    pub fn unit() -> Ty {
1520        Ty::tuple(vec![])
1521    }
1522
1523    pub fn bool() -> Ty {
1524        BaseTy::Bool.to_ty()
1525    }
1526
1527    pub fn int(int_ty: IntTy) -> Ty {
1528        BaseTy::Int(int_ty).to_ty()
1529    }
1530
1531    pub fn uint(uint_ty: UintTy) -> Ty {
1532        BaseTy::Uint(uint_ty).to_ty()
1533    }
1534
1535    pub fn param(param_ty: ParamTy) -> Ty {
1536        TyKind::Param(param_ty).intern()
1537    }
1538
1539    pub fn downcast(
1540        adt: AdtDef,
1541        args: GenericArgs,
1542        ty: Ty,
1543        variant: VariantIdx,
1544        fields: List<Ty>,
1545    ) -> Ty {
1546        TyKind::Downcast(adt, args, ty, variant, fields).intern()
1547    }
1548
1549    pub fn blocked(ty: Ty) -> Ty {
1550        TyKind::Blocked(ty).intern()
1551    }
1552
1553    pub fn str() -> Ty {
1554        BaseTy::Str.to_ty()
1555    }
1556
1557    pub fn char() -> Ty {
1558        BaseTy::Char.to_ty()
1559    }
1560
1561    pub fn float(float_ty: FloatTy) -> Ty {
1562        BaseTy::Float(float_ty).to_ty()
1563    }
1564
1565    pub fn mk_ref(region: Region, ty: Ty, mutbl: Mutability) -> Ty {
1566        BaseTy::Ref(region, ty, mutbl).to_ty()
1567    }
1568
1569    pub fn mk_slice(ty: Ty) -> Ty {
1570        BaseTy::Slice(ty).to_ty()
1571    }
1572
1573    pub fn mk_box(genv: GlobalEnv, deref_ty: Ty, alloc_ty: Ty) -> QueryResult<Ty> {
1574        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1575        let adt_def = genv.adt_def(def_id)?;
1576
1577        let args = List::from_arr([GenericArg::Ty(deref_ty), GenericArg::Ty(alloc_ty)]);
1578
1579        let bty = BaseTy::adt(adt_def, args);
1580        Ok(Ty::indexed(bty, Expr::unit_struct(def_id)))
1581    }
1582
1583    pub fn mk_box_with_default_alloc(genv: GlobalEnv, deref_ty: Ty) -> QueryResult<Ty> {
1584        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1585
1586        let generics = genv.generics_of(def_id)?;
1587        let alloc_ty = genv
1588            .lower_type_of(generics.own_params[1].def_id)?
1589            .skip_binder()
1590            .refine(&Refiner::default_for_item(genv, def_id)?)?;
1591
1592        Ty::mk_box(genv, deref_ty, alloc_ty)
1593    }
1594
1595    pub fn tuple(tys: impl Into<List<Ty>>) -> Ty {
1596        BaseTy::Tuple(tys.into()).to_ty()
1597    }
1598
1599    pub fn array(ty: Ty, c: Const) -> Ty {
1600        BaseTy::Array(ty, c).to_ty()
1601    }
1602
1603    pub fn closure(
1604        did: DefId,
1605        tys: impl Into<List<Ty>>,
1606        args: &flux_rustc_bridge::ty::GenericArgs,
1607    ) -> Ty {
1608        BaseTy::Closure(did, tys.into(), args.clone()).to_ty()
1609    }
1610
1611    pub fn coroutine(did: DefId, resume_ty: Ty, upvar_tys: List<Ty>) -> Ty {
1612        BaseTy::Coroutine(did, resume_ty, upvar_tys).to_ty()
1613    }
1614
1615    pub fn never() -> Ty {
1616        BaseTy::Never.to_ty()
1617    }
1618
1619    pub fn infer(vid: TyVid) -> Ty {
1620        TyKind::Infer(vid).intern()
1621    }
1622
1623    pub fn unconstr(&self) -> (Ty, Expr) {
1624        fn go(this: &Ty, preds: &mut Vec<Expr>) -> Ty {
1625            if let TyKind::Constr(pred, ty) = this.kind() {
1626                preds.push(pred.clone());
1627                go(ty, preds)
1628            } else {
1629                this.clone()
1630            }
1631        }
1632        let mut preds = vec![];
1633        (go(self, &mut preds), Expr::and_from_iter(preds))
1634    }
1635
1636    pub fn unblocked(&self) -> Ty {
1637        match self.kind() {
1638            TyKind::Blocked(ty) => ty.clone(),
1639            _ => self.clone(),
1640        }
1641    }
1642
1643    /// Whether the type is an `int` or a `uint`
1644    pub fn is_integral(&self) -> bool {
1645        self.as_bty_skipping_existentials()
1646            .map(BaseTy::is_integral)
1647            .unwrap_or_default()
1648    }
1649
1650    /// Whether the type is a `bool`
1651    pub fn is_bool(&self) -> bool {
1652        self.as_bty_skipping_existentials()
1653            .map(BaseTy::is_bool)
1654            .unwrap_or_default()
1655    }
1656
1657    /// Whether the type is a `char`
1658    pub fn is_char(&self) -> bool {
1659        self.as_bty_skipping_existentials()
1660            .map(BaseTy::is_char)
1661            .unwrap_or_default()
1662    }
1663
1664    pub fn is_uninit(&self) -> bool {
1665        matches!(self.kind(), TyKind::Uninit)
1666    }
1667
1668    pub fn is_box(&self) -> bool {
1669        self.as_bty_skipping_existentials()
1670            .map(BaseTy::is_box)
1671            .unwrap_or_default()
1672    }
1673
1674    pub fn is_struct(&self) -> bool {
1675        self.as_bty_skipping_existentials()
1676            .map(BaseTy::is_struct)
1677            .unwrap_or_default()
1678    }
1679
1680    pub fn is_array(&self) -> bool {
1681        self.as_bty_skipping_existentials()
1682            .map(BaseTy::is_array)
1683            .unwrap_or_default()
1684    }
1685
1686    pub fn is_slice(&self) -> bool {
1687        self.as_bty_skipping_existentials()
1688            .map(BaseTy::is_slice)
1689            .unwrap_or_default()
1690    }
1691
1692    pub fn as_bty_skipping_existentials(&self) -> Option<&BaseTy> {
1693        match self.kind() {
1694            TyKind::Indexed(bty, _) => Some(bty),
1695            TyKind::Exists(ty) => Some(ty.skip_binder_ref().as_bty_skipping_existentials()?),
1696            TyKind::Constr(_, ty) => ty.as_bty_skipping_existentials(),
1697            _ => None,
1698        }
1699    }
1700
1701    #[track_caller]
1702    pub fn expect_discr(&self) -> (&AdtDef, &Place) {
1703        if let TyKind::Discr(adt_def, place) = self.kind() {
1704            (adt_def, place)
1705        } else {
1706            tracked_span_bug!("expected discr")
1707        }
1708    }
1709
1710    #[track_caller]
1711    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg], &Expr) {
1712        if let TyKind::Indexed(BaseTy::Adt(adt_def, args), idx) = self.kind() {
1713            (adt_def, args, idx)
1714        } else {
1715            tracked_span_bug!("expected adt `{self:?}`")
1716        }
1717    }
1718
1719    #[track_caller]
1720    pub fn expect_tuple(&self) -> &[Ty] {
1721        if let TyKind::Indexed(BaseTy::Tuple(tys), _) = self.kind() {
1722            tys
1723        } else {
1724            tracked_span_bug!("expected tuple found `{self:?}` (kind: `{:?}`)", self.kind())
1725        }
1726    }
1727
1728    pub fn simplify_type(&self) -> Option<SimplifiedType> {
1729        self.as_bty_skipping_existentials()
1730            .and_then(BaseTy::simplify_type)
1731    }
1732}
1733
1734impl<'tcx> ToRustc<'tcx> for Ty {
1735    type T = rustc_middle::ty::Ty<'tcx>;
1736
1737    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
1738        match self.kind() {
1739            TyKind::Indexed(bty, _) => bty.to_rustc(tcx),
1740            TyKind::Exists(ty) => ty.skip_binder_ref().to_rustc(tcx),
1741            TyKind::Constr(_, ty) => ty.to_rustc(tcx),
1742            TyKind::Param(pty) => pty.to_ty(tcx),
1743            TyKind::StrgRef(re, _, ty) => {
1744                rustc_middle::ty::Ty::new_ref(
1745                    tcx,
1746                    re.to_rustc(tcx),
1747                    ty.to_rustc(tcx),
1748                    Mutability::Mut,
1749                )
1750            }
1751            TyKind::Infer(vid) => rustc_middle::ty::Ty::new_var(tcx, *vid),
1752            TyKind::Uninit
1753            | TyKind::Ptr(_, _)
1754            | TyKind::Discr(..)
1755            | TyKind::Downcast(..)
1756            | TyKind::Blocked(_) => bug!("TODO: to_rustc for `{self:?}`"),
1757        }
1758    }
1759}
1760
/// The shapes a refinement [`Ty`] can take.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
pub enum TyKind {
    /// A base type indexed by a refinement expression.
    Indexed(BaseTy, Expr),
    /// An existential type; see [`Ty::exists_with_constr`] for a typical construction.
    Exists(Binder<Ty>),
    /// A type guarded by a predicate (stripped by [`Ty::unconstr`]).
    Constr(Expr, Ty),
    /// NOTE(review): presumably the type of an uninitialized place — confirm.
    Uninit,
    /// A strong reference to a path; it is lowered to a rustc `&mut` reference.
    StrgRef(Region, Path, Ty),
    /// A pointer of the given [`PtrKind`] to a path.
    Ptr(PtrKind, Path),
    /// This is a bit of a hack. We use this type internally to represent the result of
    /// [`Rvalue::Discriminant`] in a way that we can recover the necessary control information
    /// when checking a [`match`]. The hack is that we assume the discriminant remains the same from
    /// the creation of this type until we use it in a [`match`].
    ///
    /// [`Rvalue::Discriminant`]: flux_rustc_bridge::mir::Rvalue::Discriminant
    /// [`match`]: flux_rustc_bridge::mir::TerminatorKind::SwitchInt
    Discr(AdtDef, Place),
    /// A generic type parameter.
    Param(ParamTy),
    /// NOTE(review): presumably a value downcast to the variant `VariantIdx` of the ADT, with
    /// the variant's field types — confirm.
    Downcast(AdtDef, GenericArgs, Ty, VariantIdx, List<Ty>),
    /// A blocked type (see [`Ty::unblocked`]).
    Blocked(Ty),
    /// A type that needs to be inferred by matching the signature against a rust signature.
    /// [`TyKind::Infer`] appears as an intermediate step during `conv` and should not be present in
    /// the final signature.
    Infer(TyVid),
}
1786
/// The kind of a [`TyKind::Ptr`].
#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum PtrKind {
    /// NOTE(review): presumably a pointer obtained from a mutable borrow with this region — confirm.
    Mut(Region),
    /// NOTE(review): presumably a pointer owned by a `Box` — confirm.
    Box,
}
1792
/// The base (unrefined) part of a type. A [`TyKind::Indexed`] pairs a `BaseTy` with an index.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BaseTy {
    /// A signed integer type.
    Int(IntTy),
    /// An unsigned integer type.
    Uint(UintTy),
    Bool,
    Str,
    Char,
    /// A slice `[T]`.
    Slice(Ty),
    /// An ADT with its generic arguments.
    Adt(AdtDef, GenericArgs),
    /// A floating-point type.
    Float(FloatTy),
    RawPtr(Ty, Mutability),
    RawPtrMetadata(Ty),
    /// A reference with the given region and mutability.
    Ref(Region, Ty, Mutability),
    FnPtr(PolyFnSig),
    FnDef(DefId, GenericArgs),
    Tuple(List<Ty>),
    /// An alias type (opaque or projection, see [`BaseTy::opaque`] and [`BaseTy::projection`]).
    Alias(AliasKind, AliasTy),
    /// An array with element type `Ty` and length `Const`.
    Array(Ty, Const),
    /// The never type `!`.
    Never,
    Closure(DefId, /* upvar_tys */ List<Ty>, flux_rustc_bridge::ty::GenericArgs),
    Coroutine(DefId, /*resume_ty: */ Ty, /* upvar_tys: */ List<Ty>),
    /// A trait object with its existential predicates and region.
    Dynamic(List<Binder<ExistentialPredicate>>, Region),
    Param(ParamTy),
    Infer(TyVid),
    Foreign(DefId),
}
1819
1820impl BaseTy {
1821    pub fn opaque(alias_ty: AliasTy) -> BaseTy {
1822        BaseTy::Alias(AliasKind::Opaque, alias_ty)
1823    }
1824
1825    pub fn projection(alias_ty: AliasTy) -> BaseTy {
1826        BaseTy::Alias(AliasKind::Projection, alias_ty)
1827    }
1828
1829    pub fn adt(adt_def: AdtDef, args: GenericArgs) -> BaseTy {
1830        BaseTy::Adt(adt_def, args)
1831    }
1832
1833    pub fn fn_def(def_id: DefId, args: impl Into<GenericArgs>) -> BaseTy {
1834        BaseTy::FnDef(def_id, args.into())
1835    }
1836
1837    pub fn from_primitive_str(s: &str) -> Option<BaseTy> {
1838        match s {
1839            "i8" => Some(BaseTy::Int(IntTy::I8)),
1840            "i16" => Some(BaseTy::Int(IntTy::I16)),
1841            "i32" => Some(BaseTy::Int(IntTy::I32)),
1842            "i64" => Some(BaseTy::Int(IntTy::I64)),
1843            "i128" => Some(BaseTy::Int(IntTy::I128)),
1844            "u8" => Some(BaseTy::Uint(UintTy::U8)),
1845            "u16" => Some(BaseTy::Uint(UintTy::U16)),
1846            "u32" => Some(BaseTy::Uint(UintTy::U32)),
1847            "u64" => Some(BaseTy::Uint(UintTy::U64)),
1848            "u128" => Some(BaseTy::Uint(UintTy::U128)),
1849            "f16" => Some(BaseTy::Float(FloatTy::F16)),
1850            "f32" => Some(BaseTy::Float(FloatTy::F32)),
1851            "f64" => Some(BaseTy::Float(FloatTy::F64)),
1852            "f128" => Some(BaseTy::Float(FloatTy::F128)),
1853            "isize" => Some(BaseTy::Int(IntTy::Isize)),
1854            "usize" => Some(BaseTy::Uint(UintTy::Usize)),
1855            "bool" => Some(BaseTy::Bool),
1856            "char" => Some(BaseTy::Char),
1857            "str" => Some(BaseTy::Str),
1858            _ => None,
1859        }
1860    }
1861
1862    /// If `self` is a primitive, return its [`Symbol`].
1863    pub fn primitive_symbol(&self) -> Option<Symbol> {
1864        match self {
1865            BaseTy::Bool => Some(sym::bool),
1866            BaseTy::Char => Some(sym::char),
1867            BaseTy::Float(f) => {
1868                match f {
1869                    FloatTy::F16 => Some(sym::f16),
1870                    FloatTy::F32 => Some(sym::f32),
1871                    FloatTy::F64 => Some(sym::f64),
1872                    FloatTy::F128 => Some(sym::f128),
1873                }
1874            }
1875            BaseTy::Int(f) => {
1876                match f {
1877                    IntTy::Isize => Some(sym::isize),
1878                    IntTy::I8 => Some(sym::i8),
1879                    IntTy::I16 => Some(sym::i16),
1880                    IntTy::I32 => Some(sym::i32),
1881                    IntTy::I64 => Some(sym::i64),
1882                    IntTy::I128 => Some(sym::i128),
1883                }
1884            }
1885            BaseTy::Uint(f) => {
1886                match f {
1887                    UintTy::Usize => Some(sym::usize),
1888                    UintTy::U8 => Some(sym::u8),
1889                    UintTy::U16 => Some(sym::u16),
1890                    UintTy::U32 => Some(sym::u32),
1891                    UintTy::U64 => Some(sym::u64),
1892                    UintTy::U128 => Some(sym::u128),
1893                }
1894            }
1895            BaseTy::Str => Some(sym::str),
1896            _ => None,
1897        }
1898    }
1899
1900    pub fn is_integral(&self) -> bool {
1901        matches!(self, BaseTy::Int(_) | BaseTy::Uint(_))
1902    }
1903
1904    pub fn is_signed(&self) -> bool {
1905        matches!(self, BaseTy::Int(_))
1906    }
1907
1908    pub fn is_unsigned(&self) -> bool {
1909        matches!(self, BaseTy::Uint(_))
1910    }
1911
1912    pub fn is_float(&self) -> bool {
1913        matches!(self, BaseTy::Float(_))
1914    }
1915
1916    pub fn is_bool(&self) -> bool {
1917        matches!(self, BaseTy::Bool)
1918    }
1919
1920    fn is_struct(&self) -> bool {
1921        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_struct())
1922    }
1923
1924    fn is_array(&self) -> bool {
1925        matches!(self, BaseTy::Array(..))
1926    }
1927
1928    fn is_slice(&self) -> bool {
1929        matches!(self, BaseTy::Slice(..))
1930    }
1931
1932    pub fn is_box(&self) -> bool {
1933        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_box())
1934    }
1935
1936    pub fn is_char(&self) -> bool {
1937        matches!(self, BaseTy::Char)
1938    }
1939
1940    pub fn is_str(&self) -> bool {
1941        matches!(self, BaseTy::Str)
1942    }
1943
1944    pub fn invariants(
1945        &self,
1946        tcx: TyCtxt,
1947        overflow_mode: OverflowMode,
1948    ) -> impl Iterator<Item = Invariant> {
1949        let (invariants, args) = match self {
1950            BaseTy::Adt(adt_def, args) => (adt_def.invariants().skip_binder(), &args[..]),
1951            BaseTy::Uint(uint_ty) => (uint_invariants(*uint_ty, overflow_mode), &[][..]),
1952            BaseTy::Int(int_ty) => (int_invariants(*int_ty, overflow_mode), &[][..]),
1953            BaseTy::Char => (char_invariants(), &[][..]),
1954            BaseTy::Slice(_) => (slice_invariants(overflow_mode), &[][..]),
1955            _ => (&[][..], &[][..]),
1956        };
1957        invariants
1958            .iter()
1959            .map(move |inv| EarlyBinder(inv).instantiate_ref(tcx, args, &[]))
1960    }
1961
1962    pub fn to_ty(&self) -> Ty {
1963        let sort = self.sort();
1964        if sort.is_unit() {
1965            Ty::indexed(self.clone(), Expr::unit())
1966        } else {
1967            Ty::exists(Binder::bind_with_sort(
1968                Ty::indexed(self.shift_in_escaping(1), Expr::nu()),
1969                sort,
1970            ))
1971        }
1972    }
1973
1974    pub fn to_subset_ty_ctor(&self) -> SubsetTyCtor {
1975        let sort = self.sort();
1976        Binder::bind_with_sort(SubsetTy::trivial(self.clone(), Expr::nu()), sort)
1977    }
1978
1979    #[track_caller]
1980    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg]) {
1981        if let BaseTy::Adt(adt_def, args) = self {
1982            (adt_def, args)
1983        } else {
1984            tracked_span_bug!("expected adt `{self:?}`")
1985        }
1986    }
1987
1988    /// A type is an *atom* if it is "self-delimiting", i.e., it has a clear boundary
1989    /// when printed. This is used to avoid unnecessary parenthesis when pretty printing.
1990    pub fn is_atom(&self) -> bool {
1991        // (nilehmann) I'm not sure about this list, please adjust if you get any odd behavior
1992        matches!(
1993            self,
1994            BaseTy::Int(_)
1995                | BaseTy::Uint(_)
1996                | BaseTy::Slice(_)
1997                | BaseTy::Bool
1998                | BaseTy::Char
1999                | BaseTy::Str
2000                | BaseTy::Adt(..)
2001                | BaseTy::Tuple(..)
2002                | BaseTy::Param(_)
2003                | BaseTy::Array(..)
2004                | BaseTy::Never
2005                | BaseTy::Closure(..)
2006                | BaseTy::Coroutine(..)
2007                // opaque alias are atoms the way we print them now, but they won't
2008                // be if we print them as `impl Trait`
2009                | BaseTy::Alias(..)
2010        )
2011    }
2012
2013    /// Similar to [`rustc_infer::infer::canonical::ir::fast_reject::simplify_type`].
2014    ///
2015    /// This implementation is currently incomplete, so it should only be used in contexts
2016    /// where completeness is not required. Currently, it's used to find incoherent
2017    /// implementations when resolving associated constants. In this context, incompleteness
2018    /// is acceptable since the worst case outcome is simply failing to resolve a type-relative
2019    /// constant.
2020    fn simplify_type(&self) -> Option<SimplifiedType> {
2021        match self {
2022            BaseTy::Bool => Some(SimplifiedType::Bool),
2023            BaseTy::Char => Some(SimplifiedType::Char),
2024            BaseTy::Int(int_type) => Some(SimplifiedType::Int(*int_type)),
2025            BaseTy::Uint(uint_type) => Some(SimplifiedType::Uint(*uint_type)),
2026            BaseTy::Float(float_type) => Some(SimplifiedType::Float(*float_type)),
2027            BaseTy::Adt(def, _) => Some(SimplifiedType::Adt(def.did())),
2028            BaseTy::Str => Some(SimplifiedType::Str),
2029            BaseTy::Array(..) => Some(SimplifiedType::Array),
2030            BaseTy::Slice(..) => Some(SimplifiedType::Slice),
2031            BaseTy::RawPtr(_, mutbl) => Some(SimplifiedType::Ptr(*mutbl)),
2032            BaseTy::Ref(_, _, mutbl) => Some(SimplifiedType::Ref(*mutbl)),
2033            BaseTy::FnDef(def_id, _) | BaseTy::Closure(def_id, ..) => {
2034                Some(SimplifiedType::Closure(*def_id))
2035            }
2036            BaseTy::Coroutine(def_id, ..) => Some(SimplifiedType::Coroutine(*def_id)),
2037            BaseTy::Never => Some(SimplifiedType::Never),
2038            BaseTy::Tuple(tys) => Some(SimplifiedType::Tuple(tys.len())),
2039            BaseTy::FnPtr(poly_fn_sig) => {
2040                Some(SimplifiedType::Function(poly_fn_sig.skip_binder_ref().inputs().len()))
2041            }
2042            BaseTy::Foreign(def_id) => Some(SimplifiedType::Foreign(*def_id)),
2043            BaseTy::RawPtrMetadata(_)
2044            | BaseTy::Alias(..)
2045            | BaseTy::Param(_)
2046            | BaseTy::Dynamic(..)
2047            | BaseTy::Infer(_) => None,
2048        }
2049    }
2050}
2051
2052impl<'tcx> ToRustc<'tcx> for BaseTy {
2053    type T = rustc_middle::ty::Ty<'tcx>;
2054
2055    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2056        use rustc_middle::ty;
2057        match self {
2058            BaseTy::Int(i) => ty::Ty::new_int(tcx, *i),
2059            BaseTy::Uint(i) => ty::Ty::new_uint(tcx, *i),
2060            BaseTy::Param(pty) => pty.to_ty(tcx),
2061            BaseTy::Slice(ty) => ty::Ty::new_slice(tcx, ty.to_rustc(tcx)),
2062            BaseTy::Bool => tcx.types.bool,
2063            BaseTy::Char => tcx.types.char,
2064            BaseTy::Str => tcx.types.str_,
2065            BaseTy::Adt(adt_def, args) => {
2066                let did = adt_def.did();
2067                let adt_def = tcx.adt_def(did);
2068                let args = args.to_rustc(tcx);
2069                ty::Ty::new_adt(tcx, adt_def, args)
2070            }
2071            BaseTy::FnDef(def_id, args) => {
2072                let args = args.to_rustc(tcx);
2073                ty::Ty::new_fn_def(tcx, *def_id, args)
2074            }
2075            BaseTy::Float(f) => ty::Ty::new_float(tcx, *f),
2076            BaseTy::RawPtr(ty, mutbl) => ty::Ty::new_ptr(tcx, ty.to_rustc(tcx), *mutbl),
2077            BaseTy::Ref(re, ty, mutbl) => {
2078                ty::Ty::new_ref(tcx, re.to_rustc(tcx), ty.to_rustc(tcx), *mutbl)
2079            }
2080            BaseTy::FnPtr(poly_sig) => ty::Ty::new_fn_ptr(tcx, poly_sig.to_rustc(tcx)),
2081            BaseTy::Tuple(tys) => {
2082                let ts = tys.iter().map(|ty| ty.to_rustc(tcx)).collect_vec();
2083                ty::Ty::new_tup(tcx, &ts)
2084            }
2085            BaseTy::Alias(kind, alias_ty) => {
2086                ty::Ty::new_alias(tcx, kind.to_rustc(tcx), alias_ty.to_rustc(tcx))
2087            }
2088            BaseTy::Array(ty, n) => {
2089                let ty = ty.to_rustc(tcx);
2090                let n = n.to_rustc(tcx);
2091                ty::Ty::new_array_with_const_len(tcx, ty, n)
2092            }
2093            BaseTy::Never => tcx.types.never,
2094            BaseTy::Closure(did, _, args) => ty::Ty::new_closure(tcx, *did, args.to_rustc(tcx)),
2095            BaseTy::Dynamic(exi_preds, re) => {
2096                let preds: Vec<_> = exi_preds
2097                    .iter()
2098                    .map(|pred| pred.to_rustc(tcx))
2099                    .collect_vec();
2100                let preds = tcx.mk_poly_existential_predicates(&preds);
2101                ty::Ty::new_dynamic(tcx, preds, re.to_rustc(tcx))
2102            }
2103            BaseTy::Coroutine(def_id, resume_ty, upvars) => {
2104                bug!("TODO: Generator {def_id:?} {resume_ty:?} {upvars:?}")
2105                // let args = args.iter().map(|arg| into_rustc_generic_arg(tcx, arg));
2106                // let args = tcx.mk_args_from_iter(args);
2107                // ty::Ty::new_generator(*tcx, *def_id, args, mov)
2108            }
2109            BaseTy::Infer(ty_vid) => ty::Ty::new_var(tcx, *ty_vid),
2110            BaseTy::Foreign(def_id) => ty::Ty::new_foreign(tcx, *def_id),
2111            BaseTy::RawPtrMetadata(ty) => {
2112                ty::Ty::new_ptr(
2113                    tcx,
2114                    ty.to_rustc(tcx),
2115                    RawPtrKind::FakeForPtrMetadata.to_mutbl_lossy(),
2116                )
2117            }
2118        }
2119    }
2120}
2121
/// An alias type: either an opaque type or an associated-type projection (see
/// [`BaseTy::opaque`] and [`BaseTy::projection`]), naming the alias item `def_id` instantiated
/// with `args`.
#[derive(
    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct AliasTy {
    /// The alias item this type refers to.
    pub def_id: DefId,
    /// Generic arguments of the alias. For projections the first argument is the self type
    /// (see [`AliasTy::self_ty`]).
    pub args: GenericArgs,
    /// Holds the refinement-arguments for opaque-types; empty for projections
    pub refine_args: RefineArgs,
}
2131
2132impl AliasTy {
2133    pub fn new(def_id: DefId, args: GenericArgs, refine_args: RefineArgs) -> Self {
2134        AliasTy { args, refine_args, def_id }
2135    }
2136}
2137
2138/// This methods work only with associated type projections (i.e., no opaque types)
2139impl AliasTy {
2140    pub fn self_ty(&self) -> SubsetTyCtor {
2141        self.args[0].expect_base().clone()
2142    }
2143
2144    pub fn with_self_ty(&self, self_ty: SubsetTyCtor) -> Self {
2145        Self {
2146            def_id: self.def_id,
2147            args: [GenericArg::Base(self_ty)]
2148                .into_iter()
2149                .chain(self.args.iter().skip(1).cloned())
2150                .collect(),
2151            refine_args: self.refine_args.clone(),
2152        }
2153    }
2154}
2155
2156impl<'tcx> ToRustc<'tcx> for AliasTy {
2157    type T = rustc_middle::ty::AliasTy<'tcx>;
2158
2159    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2160        rustc_middle::ty::AliasTy::new(tcx, self.def_id, self.args.to_rustc(tcx))
2161    }
2162}
2163
2164pub type RefineArgs = List<Expr>;
2165
2166#[extension(pub trait RefineArgsExt)]
2167impl RefineArgs {
2168    fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<RefineArgs> {
2169        Self::for_item(genv, def_id, |param, index| {
2170            Ok(Expr::var(Var::EarlyParam(EarlyReftParam {
2171                index: index as u32,
2172                name: param.name(),
2173            })))
2174        })
2175    }
2176
2177    fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult<RefineArgs>
2178    where
2179        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<Expr>,
2180    {
2181        let reft_generics = genv.refinement_generics_of(def_id)?;
2182        let count = reft_generics.count();
2183        let mut args = Vec::with_capacity(count);
2184        reft_generics.fill_item(genv, &mut args, &mut mk)?;
2185        Ok(List::from_vec(args))
2186    }
2187}
2188
2189/// A type constructor meant to be used as generic a argument of [kind base]. This is just an alias
2190/// to [`Binder<SubsetTy>`], but we expect the binder to have a single bound variable of the sort of
2191/// the underlying [base type].
2192///
2193/// [kind base]: GenericParamDefKind::Base
2194/// [base type]: SubsetTy::bty
2195pub type SubsetTyCtor = Binder<SubsetTy>;
2196
2197impl SubsetTyCtor {
2198    pub fn as_bty_skipping_binder(&self) -> &BaseTy {
2199        &self.as_ref().skip_binder().bty
2200    }
2201
2202    pub fn to_ty(&self) -> Ty {
2203        let sort = self.sort();
2204        if sort.is_unit() {
2205            self.replace_bound_reft(&Expr::unit()).to_ty()
2206        } else if let Some(def_id) = sort.is_unit_adt() {
2207            self.replace_bound_reft(&Expr::unit_struct(def_id)).to_ty()
2208        } else {
2209            Ty::exists(self.as_ref().map(SubsetTy::to_ty))
2210        }
2211    }
2212
2213    pub fn to_ty_ctor(&self) -> TyCtor {
2214        self.as_ref().map(SubsetTy::to_ty)
2215    }
2216}
2217
/// A subset type is a simplified version of a type that has the form `{b[e] | p}` where `b` is a
/// [`BaseTy`], `e` a refinement index, and `p` a predicate.
///
/// These are mainly found under a [`Binder`] with a single variable of the base type's sort. This
/// can be interpreted as a type constructor or an existential type. For example, under a binder with a
/// variable `v` of sort `int`, we can interpret `{i32[v] | v > 0}` as:
/// - A lambda `λv:int. {i32[v] | v > 0}` that "constructs" types when applied to ints, or
/// - An existential type `∃v:int. {i32[v] | v > 0}`.
///
/// This second interpretation is the reason we call this a subset type, i.e., the type `∃v. {b[v] | p}`
/// corresponds to the subset of values of type `b` whose index satisfies `p`. These are the types
/// written as `B{v: p}` in the surface syntax and correspond to the types supported in other
/// refinement type systems like Liquid Haskell (with the difference that we are explicit
/// about separating refinements from program values via an index).
///
/// The main purpose of subset types is to be used as generic arguments of [kind base] when
/// interpreted as type constructors. They have two key properties that make them suitable
/// for this:
///
/// 1. **Syntactic Restriction**: Subset types are syntactically restricted, making it easier to
///    relate them structurally (e.g., for subtyping). For instance, given two types `S<λv. T1>` and
///    `S<λv. T2>`, if `T1` and `T2` are subset types, we know they match structurally (at least
///    shallowly). In particular, the syntactic restriction rules out complex types like
///    `S<λv. (i32[v], i32[v])>`, which simplifies some operations.
///
/// 2. **Eager Canonicalization**: Subset types can be eagerly canonicalized via [*strengthening*]
///    during substitution. For example, suppose we have a function:
///    ```text
///    fn foo<T>(x: T[@a], y: { T[@b] | b == a }) { }
///    ```
///    If we instantiate `T` with `λv. { i32[v] | v > 0}`, after substitution and applying the
///    lambda (the indexing syntax `T[a]` corresponds to an application of the lambda), we get:
///    ```text
///    fn foo(x: {i32[@a] | a > 0}, y: { { i32[@b] | b > 0 } | b == a }) { }
///    ```
///    Via *strengthening* we can canonicalize this to
///    ```text
///    fn foo(x: {i32[@a] | a > 0}, y: { i32[@b] | b == a && b > 0 }) { }
///    ```
///    As a result, we can guarantee the syntactic restriction through substitution.
///
/// [kind base]: GenericParamDefKind::Base
/// [*strengthening*]: https://arxiv.org/pdf/2010.07763.pdf
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub struct SubsetTy {
    /// The base type `b` in the subset type `{b[e] | p}`.
    ///
    /// **NOTE:** This is mostly going to be under a [`Binder`]. It is not yet clear to me whether
    /// this [`BaseTy`] should be able to mention variables in the binder. In general, in a type
    /// `∃v. {b[e] | p}`, it's fine to mention `v` inside `b`, but since [`SubsetTy`] is meant to
    /// facilitate syntactic manipulation we may want to restrict this.
    pub bty: BaseTy,
    /// The refinement index `e` in the subset type `{b[e] | p}`. This can be an arbitrary expression,
    /// which makes manipulation easier. However, since this is mostly found under a binder, we expect
    /// it to be [`Expr::nu()`].
    pub idx: Expr,
    /// The predicate `p` in the subset type `{b[e] | p}`.
    pub pred: Expr,
}
2277
2278impl SubsetTy {
2279    pub fn new(bty: BaseTy, idx: impl Into<Expr>, pred: impl Into<Expr>) -> Self {
2280        Self { bty, idx: idx.into(), pred: pred.into() }
2281    }
2282
2283    pub fn trivial(bty: BaseTy, idx: impl Into<Expr>) -> Self {
2284        Self::new(bty, idx, Expr::tt())
2285    }
2286
2287    pub fn strengthen(&self, pred: impl Into<Expr>) -> Self {
2288        let this = self.clone();
2289        let pred = Expr::and(this.pred, pred).simplify(&SnapshotMap::default());
2290        Self { bty: this.bty, idx: this.idx, pred }
2291    }
2292
2293    pub fn to_ty(&self) -> Ty {
2294        let bty = self.bty.clone();
2295        if self.pred.is_trivially_true() {
2296            Ty::indexed(bty, &self.idx)
2297        } else {
2298            Ty::constr(&self.pred, Ty::indexed(bty, &self.idx))
2299        }
2300    }
2301}
2302
2303impl<'tcx> ToRustc<'tcx> for SubsetTy {
2304    type T = rustc_middle::ty::Ty<'tcx>;
2305
2306    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::Ty<'tcx> {
2307        self.bty.to_rustc(tcx)
2308    }
2309}
2310
/// A generic argument instantiating one generic parameter (see [`GenericArg::from_param_def`]
/// for the parameter kind each variant corresponds to).
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub enum GenericArg {
    /// Argument for a type parameter.
    Ty(Ty),
    /// Argument for a parameter of kind base, given as a subset-type constructor.
    Base(SubsetTyCtor),
    /// Argument for a lifetime parameter.
    Lifetime(Region),
    /// Argument for a const parameter.
    Const(Const),
}
2318
2319impl GenericArg {
2320    #[track_caller]
2321    pub fn expect_type(&self) -> &Ty {
2322        if let GenericArg::Ty(ty) = self {
2323            ty
2324        } else {
2325            bug!("expected `rty::GenericArg::Ty`, found `{self:?}`")
2326        }
2327    }
2328
2329    #[track_caller]
2330    pub fn expect_base(&self) -> &SubsetTyCtor {
2331        if let GenericArg::Base(ctor) = self {
2332            ctor
2333        } else {
2334            bug!("expected `rty::GenericArg::Base`, found `{self:?}`")
2335        }
2336    }
2337
2338    pub fn from_param_def(param: &GenericParamDef) -> Self {
2339        match param.kind {
2340            GenericParamDefKind::Type { .. } => {
2341                let param_ty = ParamTy { index: param.index, name: param.name };
2342                GenericArg::Ty(Ty::param(param_ty))
2343            }
2344            GenericParamDefKind::Base { .. } => {
2345                // λv. T[v]
2346                let param_ty = ParamTy { index: param.index, name: param.name };
2347                GenericArg::Base(Binder::bind_with_sort(
2348                    SubsetTy::trivial(BaseTy::Param(param_ty), Expr::nu()),
2349                    Sort::Param(param_ty),
2350                ))
2351            }
2352            GenericParamDefKind::Lifetime => {
2353                let region = EarlyParamRegion { index: param.index, name: param.name };
2354                GenericArg::Lifetime(Region::ReEarlyParam(region))
2355            }
2356            GenericParamDefKind::Const { .. } => {
2357                let param_const = ParamConst { index: param.index, name: param.name };
2358                let kind = ConstKind::Param(param_const);
2359                GenericArg::Const(Const { kind })
2360            }
2361        }
2362    }
2363
2364    /// Creates a `GenericArgs` from the definition of generic parameters, by calling a closure to
2365    /// obtain arg. The closures get to observe the `GenericArgs` as they're being built, which can
2366    /// be used to correctly replace defaults of generic parameters.
2367    pub fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk_kind: F) -> QueryResult<GenericArgs>
2368    where
2369        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2370    {
2371        let defs = genv.generics_of(def_id)?;
2372        let count = defs.count();
2373        let mut args = Vec::with_capacity(count);
2374        Self::fill_item(genv, &mut args, &defs, &mut mk_kind)?;
2375        Ok(List::from_vec(args))
2376    }
2377
2378    pub fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<GenericArgs> {
2379        Self::for_item(genv, def_id, |param, _| GenericArg::from_param_def(param))
2380    }
2381
2382    fn fill_item<F>(
2383        genv: GlobalEnv,
2384        args: &mut Vec<GenericArg>,
2385        generics: &Generics,
2386        mk_kind: &mut F,
2387    ) -> QueryResult<()>
2388    where
2389        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2390    {
2391        if let Some(def_id) = generics.parent {
2392            let parent_generics = genv.generics_of(def_id)?;
2393            Self::fill_item(genv, args, &parent_generics, mk_kind)?;
2394        }
2395        for param in &generics.own_params {
2396            let kind = mk_kind(param, args);
2397            tracked_span_assert_eq!(param.index as usize, args.len());
2398            args.push(kind);
2399        }
2400        Ok(())
2401    }
2402}
2403
2404impl From<TyOrBase> for GenericArg {
2405    fn from(v: TyOrBase) -> Self {
2406        match v {
2407            TyOrBase::Ty(ty) => GenericArg::Ty(ty),
2408            TyOrBase::Base(ctor) => GenericArg::Base(ctor),
2409        }
2410    }
2411}
2412
2413impl<'tcx> ToRustc<'tcx> for GenericArg {
2414    type T = rustc_middle::ty::GenericArg<'tcx>;
2415
2416    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2417        use rustc_middle::ty;
2418        match self {
2419            GenericArg::Ty(ty) => ty::GenericArg::from(ty.to_rustc(tcx)),
2420            GenericArg::Base(ctor) => ty::GenericArg::from(ctor.skip_binder_ref().to_rustc(tcx)),
2421            GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)),
2422            GenericArg::Const(c) => ty::GenericArg::from(c.to_rustc(tcx)),
2423        }
2424    }
2425}
2426
2427pub type GenericArgs = List<GenericArg>;
2428
2429#[extension(pub trait GenericArgsExt)]
2430impl GenericArgs {
2431    #[track_caller]
2432    fn box_args(&self) -> (&Ty, &Ty) {
2433        if let [GenericArg::Ty(deref), GenericArg::Ty(alloc)] = &self[..] {
2434            (deref, alloc)
2435        } else {
2436            bug!("invalid generic arguments for box");
2437        }
2438    }
2439
2440    // We can't implement [`ToRustc`] because of coherence so we add it here
2441    fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::GenericArgsRef<'tcx> {
2442        tcx.mk_args_from_iter(self.iter().map(|arg| arg.to_rustc(tcx)))
2443    }
2444
2445    fn rebase_onto(
2446        &self,
2447        tcx: &TyCtxt,
2448        source_ancestor: DefId,
2449        target_args: &GenericArgs,
2450    ) -> List<GenericArg> {
2451        let defs = tcx.generics_of(source_ancestor);
2452        target_args
2453            .iter()
2454            .chain(self.iter().skip(defs.count()))
2455            .cloned()
2456            .collect()
2457    }
2458}
2459
/// Either a regular type or a subset-type constructor (for parameters of kind base).
/// [`TyOrBase::into_ty`] collapses either form into a [`Ty`].
#[derive(Debug)]
pub enum TyOrBase {
    Ty(Ty),
    Base(SubsetTyCtor),
}
2465
2466impl TyOrBase {
2467    pub fn into_ty(self) -> Ty {
2468        match self {
2469            TyOrBase::Ty(ty) => ty,
2470            TyOrBase::Base(ctor) => ctor.to_ty(),
2471        }
2472    }
2473
2474    #[track_caller]
2475    pub fn expect_base(self) -> SubsetTyCtor {
2476        match self {
2477            TyOrBase::Base(ctor) => ctor,
2478            TyOrBase::Ty(_) => tracked_span_bug!("expected `TyOrBase::Base`"),
2479        }
2480    }
2481
2482    pub fn as_base(self) -> Option<SubsetTyCtor> {
2483        match self {
2484            TyOrBase::Base(ctor) => Some(ctor),
2485            TyOrBase::Ty(_) => None,
2486        }
2487    }
2488}
2489
/// Either a type or a type constructor ([`TyCtor`]). The constructor's body is an arbitrary
/// type; [`TyOrCtor::expect_subset_ty_ctor`] recovers a [`SubsetTyCtor`] when the body
/// canonicalizes to a subset type.
#[derive(Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum TyOrCtor {
    Ty(Ty),
    Ctor(TyCtor),
}
2495
2496impl TyOrCtor {
2497    #[track_caller]
2498    pub fn expect_ctor(self) -> TyCtor {
2499        match self {
2500            TyOrCtor::Ctor(ctor) => ctor,
2501            TyOrCtor::Ty(_) => tracked_span_bug!("expected `TyOrCtor::Ctor`"),
2502        }
2503    }
2504
2505    pub fn expect_subset_ty_ctor(self) -> SubsetTyCtor {
2506        self.expect_ctor().map(|ty| {
2507            if let canonicalize::CanonicalTy::Constr(constr_ty) = ty.shallow_canonicalize()
2508                && let TyKind::Indexed(bty, idx) = constr_ty.ty().kind()
2509                && idx.is_nu()
2510            {
2511                SubsetTy::new(bty.clone(), Expr::nu(), constr_ty.pred())
2512            } else {
2513                tracked_span_bug!()
2514            }
2515        })
2516    }
2517
2518    pub fn to_ty(&self) -> Ty {
2519        match self {
2520            TyOrCtor::Ctor(ctor) => ctor.to_ty(),
2521            TyOrCtor::Ty(ty) => ty.clone(),
2522        }
2523    }
2524}
2525
2526impl From<TyOrBase> for TyOrCtor {
2527    fn from(v: TyOrBase) -> Self {
2528        match v {
2529            TyOrBase::Ty(ty) => TyOrCtor::Ty(ty),
2530            TyOrBase::Base(ctor) => TyOrCtor::Ctor(ctor.to_ty_ctor()),
2531        }
2532    }
2533}
2534
2535impl CoroutineObligPredicate {
2536    pub fn to_poly_fn_sig(&self) -> PolyFnSig {
2537        let vars = vec![];
2538
2539        let resume_ty = &self.resume_ty;
2540        let env_ty = Ty::coroutine(self.def_id, resume_ty.clone(), self.upvar_tys.clone());
2541
2542        let inputs = List::from_arr([env_ty, resume_ty.clone()]);
2543        let output =
2544            Binder::bind_with_vars(FnOutput::new(self.output.clone(), vec![]), List::empty());
2545
2546        PolyFnSig::bind_with_vars(
2547            FnSig::new(
2548                Safety::Safe,
2549                rustc_abi::ExternAbi::RustCall,
2550                List::empty(),
2551                inputs,
2552                output,
2553                false,
2554                false,
2555            ),
2556            List::from(vars),
2557        )
2558    }
2559}
2560
2561impl RefinementGenerics {
2562    pub fn count(&self) -> usize {
2563        self.parent_count + self.own_params.len()
2564    }
2565
2566    pub fn own_count(&self) -> usize {
2567        self.own_params.len()
2568    }
2569}
2570
2571impl EarlyBinder<RefinementGenerics> {
2572    pub fn parent(&self) -> Option<DefId> {
2573        self.skip_binder_ref().parent
2574    }
2575
2576    pub fn parent_count(&self) -> usize {
2577        self.skip_binder_ref().parent_count
2578    }
2579
2580    pub fn count(&self) -> usize {
2581        self.skip_binder_ref().count()
2582    }
2583
2584    pub fn own_count(&self) -> usize {
2585        self.skip_binder_ref().own_count()
2586    }
2587
2588    pub fn own_param_at(&self, index: usize) -> EarlyBinder<RefineParam> {
2589        self.as_ref().map(|this| this.own_params[index].clone())
2590    }
2591
2592    pub fn param_at(
2593        &self,
2594        param_index: usize,
2595        genv: GlobalEnv,
2596    ) -> QueryResult<EarlyBinder<RefineParam>> {
2597        if let Some(index) = param_index.checked_sub(self.parent_count()) {
2598            Ok(self.own_param_at(index))
2599        } else {
2600            let parent = self.parent().expect("parent_count > 0 but no parent?");
2601            genv.refinement_generics_of(parent)?
2602                .param_at(param_index, genv)
2603        }
2604    }
2605
2606    pub fn iter_own_params(&self) -> impl Iterator<Item = EarlyBinder<RefineParam>> + use<'_> {
2607        self.skip_binder_ref()
2608            .own_params
2609            .iter()
2610            .cloned()
2611            .map(EarlyBinder)
2612    }
2613
2614    pub fn fill_item<F, R>(&self, genv: GlobalEnv, vec: &mut Vec<R>, mk: &mut F) -> QueryResult
2615    where
2616        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<R>,
2617    {
2618        if let Some(def_id) = self.parent() {
2619            genv.refinement_generics_of(def_id)?
2620                .fill_item(genv, vec, mk)?;
2621        }
2622        for param in self.iter_own_params() {
2623            vec.push(mk(param, vec.len())?);
2624        }
2625        Ok(())
2626    }
2627}
2628
2629impl EarlyBinder<GenericPredicates> {
2630    pub fn predicates(&self) -> EarlyBinder<List<Clause>> {
2631        EarlyBinder(self.0.predicates.clone())
2632    }
2633}
2634
2635impl EarlyBinder<FuncSort> {
2636    /// See [`subst::GenericsSubstForSort`]
2637    pub fn instantiate_func_sort<E>(
2638        self,
2639        sort_for_param: impl FnMut(ParamTy) -> Result<Sort, E>,
2640    ) -> Result<FuncSort, E> {
2641        self.0.try_fold_with(&mut subst::GenericsSubstFolder::new(
2642            subst::GenericsSubstForSort { sort_for_param },
2643            &[],
2644        ))
2645    }
2646}
2647
2648impl VariantSig {
2649    pub fn new(
2650        adt_def: AdtDef,
2651        args: GenericArgs,
2652        fields: List<Ty>,
2653        idx: Expr,
2654        requires: List<Expr>,
2655    ) -> Self {
2656        VariantSig { adt_def, args, fields, idx, requires }
2657    }
2658
2659    pub fn fields(&self) -> &[Ty] {
2660        &self.fields
2661    }
2662
2663    pub fn ret(&self) -> Ty {
2664        let bty = BaseTy::Adt(self.adt_def.clone(), self.args.clone());
2665        let idx = self.idx.clone();
2666        Ty::indexed(bty, idx)
2667    }
2668}
2669
2670impl FnSig {
2671    pub fn new(
2672        safety: Safety,
2673        abi: rustc_abi::ExternAbi,
2674        requires: List<Expr>,
2675        inputs: List<Ty>,
2676        output: Binder<FnOutput>,
2677        no_panic: bool,
2678        lifted: bool,
2679    ) -> Self {
2680        FnSig { safety, abi, requires, inputs, output, no_panic, lifted }
2681    }
2682
2683    pub fn requires(&self) -> &[Expr] {
2684        &self.requires
2685    }
2686
2687    pub fn inputs(&self) -> &[Ty] {
2688        &self.inputs
2689    }
2690
2691    pub fn no_panic(&self) -> bool {
2692        self.no_panic
2693    }
2694
2695    pub fn output(&self) -> Binder<FnOutput> {
2696        self.output.clone()
2697    }
2698}
2699
2700impl<'tcx> ToRustc<'tcx> for FnSig {
2701    type T = rustc_middle::ty::FnSig<'tcx>;
2702
2703    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2704        tcx.mk_fn_sig(
2705            self.inputs().iter().map(|ty| ty.to_rustc(tcx)),
2706            self.output().as_ref().skip_binder().to_rustc(tcx),
2707            false,
2708            self.safety,
2709            self.abi,
2710        )
2711    }
2712}
2713
2714impl FnOutput {
2715    pub fn new(ret: Ty, ensures: impl Into<List<Ensures>>) -> Self {
2716        Self { ret, ensures: ensures.into() }
2717    }
2718}
2719
2720impl<'tcx> ToRustc<'tcx> for FnOutput {
2721    type T = rustc_middle::ty::Ty<'tcx>;
2722
2723    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2724        self.ret.to_rustc(tcx)
2725    }
2726}
2727
2728impl AdtDef {
2729    pub fn new(
2730        rustc: ty::AdtDef,
2731        sort_def: AdtSortDef,
2732        invariants: Vec<Invariant>,
2733        opaque: bool,
2734    ) -> Self {
2735        AdtDef(Interned::new(AdtDefData { invariants, sort_def, opaque, rustc }))
2736    }
2737
2738    pub fn did(&self) -> DefId {
2739        self.0.rustc.did()
2740    }
2741
2742    pub fn sort_def(&self) -> &AdtSortDef {
2743        &self.0.sort_def
2744    }
2745
2746    pub fn sort(&self, args: &[GenericArg]) -> Sort {
2747        self.sort_def().to_sort(args)
2748    }
2749
2750    pub fn is_box(&self) -> bool {
2751        self.0.rustc.is_box()
2752    }
2753
2754    pub fn is_enum(&self) -> bool {
2755        self.0.rustc.is_enum()
2756    }
2757
2758    pub fn is_struct(&self) -> bool {
2759        self.0.rustc.is_struct()
2760    }
2761
2762    pub fn is_union(&self) -> bool {
2763        self.0.rustc.is_union()
2764    }
2765
2766    pub fn variants(&self) -> &IndexSlice<VariantIdx, VariantDef> {
2767        self.0.rustc.variants()
2768    }
2769
2770    pub fn variant(&self, idx: VariantIdx) -> &VariantDef {
2771        self.0.rustc.variant(idx)
2772    }
2773
2774    pub fn invariants(&self) -> EarlyBinder<&[Invariant]> {
2775        EarlyBinder(&self.0.invariants)
2776    }
2777
2778    pub fn discriminants(&self) -> impl Iterator<Item = (VariantIdx, u128)> + '_ {
2779        self.0.rustc.discriminants()
2780    }
2781
2782    pub fn is_opaque(&self) -> bool {
2783        self.0.opaque
2784    }
2785}
2786
2787impl EarlyBinder<PolyVariant> {
2788    // The field_idx is `Some(i)` when we have the `i`-th field of a `union`, in which case,
2789    // the `inputs` are _just_ the `i`-th type (and not all the types...)
2790    pub fn to_poly_fn_sig(&self, field_idx: Option<crate::FieldIdx>) -> EarlyBinder<PolyFnSig> {
2791        self.as_ref().map(|poly_variant| {
2792            poly_variant.as_ref().map(|variant| {
2793                let ret = variant.ret().shift_in_escaping(1);
2794                let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
2795                let inputs = match field_idx {
2796                    None => variant.fields.clone(),
2797                    Some(i) => List::singleton(variant.fields[i.index()].clone()),
2798                };
2799                FnSig::new(
2800                    Safety::Safe,
2801                    rustc_abi::ExternAbi::Rust,
2802                    variant.requires.clone(),
2803                    inputs,
2804                    output,
2805                    true,
2806                    false,
2807                )
2808            })
2809        })
2810    }
2811}
2812
2813impl TyKind {
2814    fn intern(self) -> Ty {
2815        Ty(Interned::new(self))
2816    }
2817}
2818
2819/// returns the same invariants as for `usize` which is the length of a slice
2820fn slice_invariants(overflow_mode: OverflowMode) -> &'static [Invariant] {
2821    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2822        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2823    });
2824    static OVERFLOW: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2825        [
2826            Invariant {
2827                pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2828            },
2829            Invariant {
2830                pred: Binder::bind_with_sort(
2831                    Expr::le(Expr::nu(), Expr::uint_max(UintTy::Usize)),
2832                    Sort::Int,
2833                ),
2834            },
2835        ]
2836    });
2837    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2838        &*OVERFLOW
2839    } else {
2840        &*DEFAULT
2841    }
2842}
2843
2844fn uint_invariants(uint_ty: UintTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2845    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2846        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2847    });
2848
2849    static OVERFLOW: LazyLock<UnordMap<UintTy, [Invariant; 2]>> = LazyLock::new(|| {
2850        UINT_TYS
2851            .into_iter()
2852            .map(|uint_ty| {
2853                let invariants = [
2854                    Invariant {
2855                        pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2856                    },
2857                    Invariant {
2858                        pred: Binder::bind_with_sort(
2859                            Expr::le(Expr::nu(), Expr::uint_max(uint_ty)),
2860                            Sort::Int,
2861                        ),
2862                    },
2863                ];
2864                (uint_ty, invariants)
2865            })
2866            .collect()
2867    });
2868    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2869        &OVERFLOW[&uint_ty]
2870    } else {
2871        &*DEFAULT
2872    }
2873}
2874
2875fn char_invariants() -> &'static [Invariant] {
2876    static INVARIANTS: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2877        [
2878            Invariant {
2879                pred: Binder::bind_with_sort(
2880                    Expr::le(
2881                        Expr::cast(Sort::Char, Sort::Int, Expr::nu()),
2882                        Expr::constant((char::MAX as u32).into()),
2883                    ),
2884                    Sort::Int,
2885                ),
2886            },
2887            Invariant {
2888                pred: Binder::bind_with_sort(
2889                    Expr::le(Expr::zero(), Expr::cast(Sort::Char, Sort::Int, Expr::nu())),
2890                    Sort::Int,
2891                ),
2892            },
2893        ]
2894    });
2895    &*INVARIANTS
2896}
2897
2898fn int_invariants(int_ty: IntTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2899    static DEFAULT: [Invariant; 0] = [];
2900
2901    static OVERFLOW: LazyLock<UnordMap<IntTy, [Invariant; 2]>> = LazyLock::new(|| {
2902        INT_TYS
2903            .into_iter()
2904            .map(|int_ty| {
2905                let invariants = [
2906                    Invariant {
2907                        pred: Binder::bind_with_sort(
2908                            Expr::ge(Expr::nu(), Expr::int_min(int_ty)),
2909                            Sort::Int,
2910                        ),
2911                    },
2912                    Invariant {
2913                        pred: Binder::bind_with_sort(
2914                            Expr::le(Expr::nu(), Expr::int_max(int_ty)),
2915                            Sort::Int,
2916                        ),
2917                    },
2918                ];
2919                (int_ty, invariants)
2920            })
2921            .collect()
2922    });
2923    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2924        &OVERFLOW[&int_ty]
2925    } else {
2926        &DEFAULT
2927    }
2928}
2929
// Make these types internable with `flux_arc_interner`, so they can be wrapped in
// `Interned<_>` (e.g. `TyKind::intern` and `AdtDef::new` call `Interned::new`).
impl_internable!(AdtDefData, AdtSortDefData, TyKind);
// Element types that can be interned as slices. Presumably this is what backs `List<_>` for
// them (e.g. `List::from_vec` is used on `Ty` and `Ensures` in this file) — confirm in
// `flux_arc_interner`.
impl_slice_internable!(
    Ty,
    GenericArg,
    Ensures,
    InferMode,
    Sort,
    SortArg,
    GenericParamDef,
    TraitRef,
    Binder<ExistentialPredicate>,
    Clause,
    PolyVariant,
    Invariant,
    RefineParam,
    FluxDefId,
    SortParamKind,
    AssocReft
);
2949
/// Pattern macro: `Int!(int_ty, idx)` matches `TyKind::Indexed(BaseTy::Int(int_ty), idx)`.
///
/// Paths are `$crate`-qualified (as in `Ref!`), so the macro works without `TyKind`/`BaseTy`
/// being in scope at the use site.
#[macro_export]
macro_rules! _Int {
    ($int_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Int($int_ty), $idxs)
    };
}
pub use crate::_Int as Int;
2957
/// Pattern macro: `Uint!(uint_ty, idx)` matches `TyKind::Indexed(BaseTy::Uint(uint_ty), idx)`.
///
/// Paths are `$crate`-qualified (as in `Ref!`), so the macro works without `TyKind`/`BaseTy`
/// being in scope at the use site.
#[macro_export]
macro_rules! _Uint {
    ($uint_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Uint($uint_ty), $idxs)
    };
}
pub use crate::_Uint as Uint;
2965
/// Pattern macro: `Bool!(idx)` matches `TyKind::Indexed(BaseTy::Bool, idx)`.
///
/// Paths are `$crate`-qualified (as in `Ref!`), so the macro works without `TyKind`/`BaseTy`
/// being in scope at the use site.
#[macro_export]
macro_rules! _Bool {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Bool, $idxs)
    };
}
pub use crate::_Bool as Bool;
2973
/// Pattern macro: `Char!(idx)` matches `TyKind::Indexed(BaseTy::Char, idx)`.
///
/// Paths are `$crate`-qualified (as in `Ref!`), so the macro works without `TyKind`/`BaseTy`
/// being in scope at the use site.
#[macro_export]
macro_rules! _Char {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Char, $idxs)
    };
}
pub use crate::_Char as Char;
2981
/// Pattern macro: `Ref!(pats..)` matches `TyKind::Indexed(BaseTy::Ref(pats..), _)`, forwarding
/// its patterns to the fields of `BaseTy::Ref` and ignoring the index.
#[macro_export]
macro_rules! _Ref {
    ($($pats:pat),+ $(,)?) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Ref($($pats),+), _)
    };
}
pub use crate::_Ref as Ref;
2989
/// Side tables recorded for a single owner (`owner`).
///
/// Apart from `param_sorts`, every table is an [`ItemLocalMap`] keyed by the local part of a
/// `FhirId`. The accessors on this type hand them out wrapped in [`LocalTableInContext`] /
/// [`LocalTableInContextMut`], which assert that ids belong to `owner`.
pub struct WfckResults {
    pub owner: FluxOwnerId,
    /// Sort of each refinement parameter, keyed by `fhir::ParamId`.
    param_sorts: UnordMap<fhir::ParamId, Sort>,
    /// A sort recorded per node (presumably binary-operator nodes, per the name).
    bin_op_sorts: ItemLocalMap<Sort>,
    /// Sort arguments recorded per node (presumably function applications, per the name).
    fn_app_sorts: ItemLocalMap<List<SortArg>>,
    /// Coercions recorded per node.
    coercions: ItemLocalMap<Vec<Coercion>>,
    /// Field projection recorded per node.
    field_projs: ItemLocalMap<FieldProj>,
    /// Sort recorded per node.
    node_sorts: ItemLocalMap<Sort>,
    /// `DefId` recorded per node (presumably the record constructor, per the name).
    record_ctors: ItemLocalMap<DefId>,
}
3000
/// A coercion recorded in [`WfckResults`] for a node.
///
/// NOTE(review): the precise meaning of `Inject`/`Project` and of their `DefId` is not visible
/// here; presumably they convert into / out of the sort identified by the `DefId` — confirm.
#[derive(Clone, Copy, Debug)]
pub enum Coercion {
    Inject(DefId),
    Project(DefId),
}
3006
/// Map keyed by `fhir::ItemLocalId`, i.e. by the local part of a `FhirId` within one owner.
pub type ItemLocalMap<T> = UnordMap<fhir::ItemLocalId, T>;
3008
/// Read-only view of an [`ItemLocalMap`], tagged with the owner its keys belong to so lookups
/// by `FhirId` can be checked against it.
#[derive(Debug)]
pub struct LocalTableInContext<'a, T> {
    owner: FluxOwnerId,
    data: &'a ItemLocalMap<T>,
}
3014
/// Mutable view of an [`ItemLocalMap`], tagged with the owner its keys belong to so insertions
/// by `FhirId` can be checked against it.
pub struct LocalTableInContextMut<'a, T> {
    owner: FluxOwnerId,
    data: &'a mut ItemLocalMap<T>,
}
3019
3020impl WfckResults {
3021    pub fn new(owner: impl Into<FluxOwnerId>) -> Self {
3022        Self {
3023            owner: owner.into(),
3024            param_sorts: UnordMap::default(),
3025            bin_op_sorts: ItemLocalMap::default(),
3026            coercions: ItemLocalMap::default(),
3027            field_projs: ItemLocalMap::default(),
3028            node_sorts: ItemLocalMap::default(),
3029            record_ctors: ItemLocalMap::default(),
3030            fn_app_sorts: ItemLocalMap::default(),
3031        }
3032    }
3033
3034    pub fn param_sorts_mut(&mut self) -> &mut UnordMap<fhir::ParamId, Sort> {
3035        &mut self.param_sorts
3036    }
3037
3038    pub fn param_sorts(&self) -> &UnordMap<fhir::ParamId, Sort> {
3039        &self.param_sorts
3040    }
3041
3042    pub fn bin_op_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3043        LocalTableInContextMut { owner: self.owner, data: &mut self.bin_op_sorts }
3044    }
3045
3046    pub fn fn_app_sorts_mut(&mut self) -> LocalTableInContextMut<'_, List<SortArg>> {
3047        LocalTableInContextMut { owner: self.owner, data: &mut self.fn_app_sorts }
3048    }
3049
3050    pub fn fn_app_sorts(&self) -> LocalTableInContext<'_, List<SortArg>> {
3051        LocalTableInContext { owner: self.owner, data: &self.fn_app_sorts }
3052    }
3053
3054    pub fn bin_op_sorts(&self) -> LocalTableInContext<'_, Sort> {
3055        LocalTableInContext { owner: self.owner, data: &self.bin_op_sorts }
3056    }
3057
3058    pub fn coercions_mut(&mut self) -> LocalTableInContextMut<'_, Vec<Coercion>> {
3059        LocalTableInContextMut { owner: self.owner, data: &mut self.coercions }
3060    }
3061
3062    pub fn coercions(&self) -> LocalTableInContext<'_, Vec<Coercion>> {
3063        LocalTableInContext { owner: self.owner, data: &self.coercions }
3064    }
3065
3066    pub fn field_projs_mut(&mut self) -> LocalTableInContextMut<'_, FieldProj> {
3067        LocalTableInContextMut { owner: self.owner, data: &mut self.field_projs }
3068    }
3069
3070    pub fn field_projs(&self) -> LocalTableInContext<'_, FieldProj> {
3071        LocalTableInContext { owner: self.owner, data: &self.field_projs }
3072    }
3073
3074    pub fn node_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3075        LocalTableInContextMut { owner: self.owner, data: &mut self.node_sorts }
3076    }
3077
3078    pub fn node_sorts(&self) -> LocalTableInContext<'_, Sort> {
3079        LocalTableInContext { owner: self.owner, data: &self.node_sorts }
3080    }
3081
3082    pub fn record_ctors_mut(&mut self) -> LocalTableInContextMut<'_, DefId> {
3083        LocalTableInContextMut { owner: self.owner, data: &mut self.record_ctors }
3084    }
3085
3086    pub fn record_ctors(&self) -> LocalTableInContext<'_, DefId> {
3087        LocalTableInContext { owner: self.owner, data: &self.record_ctors }
3088    }
3089}
3090
impl<T> LocalTableInContextMut<'_, T> {
    /// Stores `value` for `fhir_id`, keyed by `fhir_id.local_id`.
    ///
    /// Asserts (via `tracked_span_assert_eq!`) that `fhir_id.owner` is this table's owner. The
    /// previous value for the same id, if any, is not returned.
    pub fn insert(&mut self, fhir_id: FhirId, value: T) {
        tracked_span_assert_eq!(self.owner, fhir_id.owner);
        self.data.insert(fhir_id.local_id, value);
    }
}
3097
impl<'a, T> LocalTableInContext<'a, T> {
    /// Looks up the entry for `fhir_id` by its `local_id`, borrowing for the table's lifetime
    /// `'a`.
    ///
    /// Asserts (via `tracked_span_assert_eq!`) that `fhir_id.owner` is this table's owner.
    pub fn get(&self, fhir_id: FhirId) -> Option<&'a T> {
        tracked_span_assert_eq!(self.owner, fhir_id.owner);
        self.data.get(&fhir_id.local_id)
    }
}
3104
3105fn can_auto_strong(fn_sig: &PolyFnSig) -> bool {
3106    struct RegionDetector {
3107        has_region: bool,
3108    }
3109
3110    impl fold::TypeFolder for RegionDetector {
3111        fn fold_region(&mut self, re: &Region) -> Region {
3112            self.has_region = true;
3113            *re
3114        }
3115    }
3116    let mut detector = RegionDetector { has_region: false };
3117    fn_sig
3118        .skip_binder_ref()
3119        .output()
3120        .skip_binder_ref()
3121        .ret
3122        .fold_with(&mut detector);
3123
3124    !detector.has_region
3125}
3126/// The [`auto_strong`] function transforms function signatures by automatically converting
3127/// mutable reference parameters into strong references with associated ensures clauses. This
3128/// transformation is applied only when the function signature does not already contain region
3129/// variables in its return type.
3130///
3131/// Specifically, given a source function of type
3132///
3133///    fn (x: &mut InnerTy) -> bool
3134///
3135/// By default the above gives us an `rty::FnSig`
3136///
3137///    forall<>. fn (x: &mut InnerTy) -> bool
3138///
3139/// Which this function then transforms to
3140///
3141///     forall<l0: Loc>. fn (x: &strg<l0:InnerTy>) -> bool ensures l0:InnerTy
3142pub fn auto_strong(
3143    genv: GlobalEnv,
3144    def_id: impl IntoQueryParam<DefId>,
3145    fn_sig: PolyFnSig,
3146) -> PolyFnSig {
3147    // TODO(auto-strong): we only *really* need the first check `can_auto_strong` here.
3148    // The other two skip `auto-strong` as doing it breaks various downstream things
3149    // that should be fixed.
3150    if !can_auto_strong(&fn_sig)
3151        || matches!(genv.def_kind(def_id), rustc_hir::def::DefKind::Closure)
3152        || !fn_sig.skip_binder_ref().lifted
3153    {
3154        return fn_sig;
3155    }
3156    let kind = BoundReftKind::Anon;
3157    let mut vars = fn_sig.vars().to_vec();
3158    let fn_sig = fn_sig.skip_binder();
3159    // new list of (bound_var, inner_ty)
3160    let mut strg_bvars = vec![];
3161    // new list of input types
3162    let mut strg_inputs = vec![];
3163    // 1. Traverse inputs collecting strong locations
3164    for ty in &fn_sig.inputs {
3165        let strg_ty = if let TyKind::Indexed(BaseTy::Ref(re, inner_ty, Mutability::Mut), _) =
3166            ty.kind()
3167            && !inner_ty.is_slice()
3168        // TODO(auto-strong): including `slice` breaks `tock` for some reason we should replicate in our own tests...
3169        {
3170            // if input is &mut InnerTy create a new bound var `loc` for the strong location
3171            let var = {
3172                let idx = vars.len() + strg_bvars.len();
3173                BoundVar::from_usize(idx)
3174            };
3175            strg_bvars.push((var, inner_ty.clone()));
3176            let loc = Loc::Var(Var::Bound(INNERMOST, BoundReft { var, kind }));
3177            // and transform to &strg<loc:InnerTy>
3178            Ty::strg_ref(*re, Path::new(loc, List::empty()), inner_ty.clone())
3179        } else {
3180            // else leave input type unchanged
3181            ty.clone()
3182        };
3183        strg_inputs.push(strg_ty);
3184    }
3185    // 2. Add bound vars for strong locations
3186    for _ in 0..strg_bvars.len() {
3187        vars.push(BoundVariableKind::Refine(Sort::Loc, InferMode::EVar, kind));
3188    }
3189    // 3. Add ensures for strong locations
3190    let output = fn_sig.output.map(|out| {
3191        let mut ens = out.ensures.to_vec();
3192        for (var, inner_ty) in strg_bvars {
3193            let loc = Loc::Var(Var::Bound(INNERMOST.shifted_in(1), BoundReft { var, kind }));
3194            let path = Path::new(loc, List::empty());
3195            ens.push(Ensures::Type(path, inner_ty.shift_in_escaping(1)));
3196        }
3197        FnOutput { ensures: List::from_vec(ens), ..out }
3198    });
3199
3200    // 4. Reconstruct fn sig with new inputs and output and vars
3201    let fn_sig = FnSig { inputs: List::from_vec(strg_inputs), output, ..fn_sig };
3202    Binder::bind_with_vars(fn_sig, vars.into())
3203}