flux_middle/rty/
mod.rs

1//! Defines how flux represents refinement types internally. Definitions in this module are used
2//! during refinement type checking. A couple of important differences between definitions in this
3//! module and in [`crate::fhir`] are:
4//!
//! * Types in this module use de Bruijn indices to represent local binders.
6//! * Data structures are interned so they can be cheaply cloned.
7mod binder;
8pub mod canonicalize;
9mod expr;
10pub mod fold;
11pub mod normalize;
12mod pretty;
13pub mod refining;
14pub mod region_matching;
15pub mod subst;
16use std::{borrow::Cow, cmp::Ordering, fmt, hash::Hash, sync::LazyLock};
17
18pub use binder::{Binder, BoundReftKind, BoundVariableKind, BoundVariableKinds, EarlyBinder};
19use bitflags::bitflags;
20pub use expr::{
21    AggregateKind, AliasReft, BinOp, BoundReft, Constant, Ctor, ESpan, EVid, EarlyReftParam, Expr,
22    ExprKind, FieldProj, HoleKind, InternalFuncKind, KVar, KVid, Lambda, Loc, Name, Path, Real,
23    SpecFuncKind, UnOp, Var,
24};
25pub use flux_arc_interner::List;
26use flux_arc_interner::{Interned, impl_internable, impl_slice_internable};
27use flux_common::{bug, tracked_span_assert_eq, tracked_span_bug};
28use flux_config::OverflowMode;
29use flux_macros::{TypeFoldable, TypeVisitable};
30pub use flux_rustc_bridge::ty::{
31    AliasKind, BoundRegion, BoundRegionKind, BoundVar, Const, ConstKind, ConstVid, DebruijnIndex,
32    EarlyParamRegion, LateParamRegion, LateParamRegionKind,
33    Region::{self, *},
34    RegionVid,
35};
36use flux_rustc_bridge::{
37    ToRustc,
38    mir::{Place, RawPtrKind},
39    ty::{self, GenericArgsExt as _, VariantDef},
40};
41use itertools::Itertools;
42pub use normalize::{NormalizeInfo, NormalizedDefns, local_deps};
43use refining::{Refine as _, Refiner};
44use rustc_abi;
45pub use rustc_abi::{FIRST_VARIANT, VariantIdx};
46use rustc_data_structures::{fx::FxIndexMap, snapshot_map::SnapshotMap, unord::UnordMap};
47use rustc_hir::{LangItem, Safety, def_id::DefId};
48use rustc_index::{IndexSlice, IndexVec, newtype_index};
49use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable, extension};
50pub use rustc_middle::{
51    mir::Mutability,
52    ty::{AdtFlags, ClosureKind, FloatTy, IntTy, ParamConst, ParamTy, ScalarInt, UintTy},
53};
54use rustc_middle::{
55    query::IntoQueryParam,
56    ty::{TyCtxt, fast_reject::SimplifiedType},
57};
58use rustc_span::{DUMMY_SP, Span, Symbol, sym, symbol::kw};
59use rustc_type_ir::Upcast as _;
60pub use rustc_type_ir::{INNERMOST, TyVid};
61
62use self::fold::TypeFoldable;
63pub use crate::fhir::InferMode;
64use crate::{
65    LocalDefId,
66    def_id::{FluxDefId, FluxLocalDefId},
67    fhir::{self, FhirId, FluxOwnerId},
68    global_env::GlobalEnv,
69    pretty::{Pretty, PrettyCx},
70    queries::{QueryErr, QueryResult},
71    rty::subst::SortSubst,
72};
73
74/// The definition of the data sort automatically generated for a struct or enum.
75#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
76pub struct AdtSortDef(Interned<AdtSortDefData>);
77
78#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
79pub struct AdtSortVariant {
80    /// The list of field names as declared in the `#[flux::refined_by(...)]` annotation
81    field_names: Vec<Symbol>,
82    /// The sort of each of the fields. Note that these can contain [sort variables]. Methods used
83    /// to access these sorts guarantee they are properly instantiated.
84    ///
85    /// [sort variables]: Sort::Var
86    sorts: List<Sort>,
87}
88
89impl AdtSortVariant {
90    pub fn new(fields: Vec<(Symbol, Sort)>) -> Self {
91        let (field_names, sorts) = fields.into_iter().unzip();
92        AdtSortVariant { field_names, sorts: List::from_vec(sorts) }
93    }
94
95    pub fn fields(&self) -> usize {
96        self.sorts.len()
97    }
98
99    pub fn field_names(&self) -> &Vec<Symbol> {
100        &self.field_names
101    }
102
103    pub fn sort_by_field_name(&self, args: &[Sort]) -> FxIndexMap<Symbol, Sort> {
104        std::iter::zip(&self.field_names, &self.sorts.fold_with(&mut SortSubst::new(args)))
105            .map(|(name, sort)| (*name, sort.clone()))
106            .collect()
107    }
108
109    pub fn field_by_name(
110        &self,
111        def_id: DefId,
112        args: &[Sort],
113        name: Symbol,
114    ) -> Option<(FieldProj, Sort)> {
115        let idx = self.field_names.iter().position(|it| name == *it)?;
116        let proj = FieldProj::Adt { def_id, field: idx as u32 };
117        let sort = self.sorts[idx].fold_with(&mut SortSubst::new(args));
118        Some((proj, sort))
119    }
120
121    pub fn field_sorts(&self, args: &[Sort]) -> List<Sort> {
122        self.sorts.fold_with(&mut SortSubst::new(args))
123    }
124
125    pub fn field_sorts_instantiate_identity(&self) -> List<Sort> {
126        self.sorts.clone()
127    }
128
129    pub fn projections(&self, def_id: DefId) -> impl Iterator<Item = FieldProj> {
130        (0..self.fields()).map(move |i| FieldProj::Adt { def_id, field: i as u32 })
131    }
132}
133
134#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
135struct AdtSortDefData {
136    /// [`DefId`] of the struct or enum this data sort is associated to.
137    def_id: DefId,
138    /// The list of the type parameters used in the `#[flux::refined_by(..)]` annotation.
139    ///
140    /// See [`fhir::RefinedBy::sort_params`] for more details. This is a version of that but using
141    /// [`ParamTy`] instead of [`DefId`].
142    ///
143    /// The length of this list corresponds to the number of sort variables bound by this definition.
144    params: Vec<ParamTy>,
145    /// A vec of variants of the ADT;
146    /// - a `struct` sort -- used for types with a `refined_by` has a single variant;
147    /// - a `reflected` sort -- used for `reflected` enums have multiple variants
148    variants: IndexVec<VariantIdx, AdtSortVariant>,
149    is_reflected: bool,
150    is_struct: bool,
151}
152
153impl AdtSortDef {
154    pub fn new(
155        def_id: DefId,
156        params: Vec<ParamTy>,
157        variants: IndexVec<VariantIdx, AdtSortVariant>,
158        is_reflected: bool,
159        is_struct: bool,
160    ) -> Self {
161        Self(Interned::new(AdtSortDefData { def_id, params, variants, is_reflected, is_struct }))
162    }
163
164    pub fn did(&self) -> DefId {
165        self.0.def_id
166    }
167
168    pub fn variant(&self, idx: VariantIdx) -> &AdtSortVariant {
169        &self.0.variants[idx]
170    }
171
172    pub fn variants(&self) -> &IndexSlice<VariantIdx, AdtSortVariant> {
173        &self.0.variants
174    }
175
176    pub fn opt_struct_variant(&self) -> Option<&AdtSortVariant> {
177        if self.is_struct() { Some(self.struct_variant()) } else { None }
178    }
179
180    #[track_caller]
181    pub fn struct_variant(&self) -> &AdtSortVariant {
182        tracked_span_assert_eq!(self.0.is_struct, true);
183        &self.0.variants[FIRST_VARIANT]
184    }
185
186    pub fn is_reflected(&self) -> bool {
187        self.0.is_reflected
188    }
189
190    pub fn is_struct(&self) -> bool {
191        self.0.is_struct
192    }
193
194    pub fn to_sort(&self, args: &[GenericArg]) -> Sort {
195        let sorts = self
196            .filter_generic_args(args)
197            .map(|arg| arg.expect_base().sort())
198            .collect();
199
200        Sort::App(SortCtor::Adt(self.clone()), sorts)
201    }
202
203    /// Given a list of generic args, returns an iterator of the generic arguments that should be
204    /// mapped to sorts for instantiation.
205    pub fn filter_generic_args<'a, A>(&'a self, args: &'a [A]) -> impl Iterator<Item = &'a A> + 'a {
206        self.0.params.iter().map(|p| &args[p.index as usize])
207    }
208
209    pub fn identity_args(&self) -> List<Sort> {
210        (0..self.0.params.len())
211            .map(|i| Sort::Var(ParamSort::from(i)))
212            .collect()
213    }
214
215    /// Gives the number of sort variables bound by this definition.
216    pub fn param_count(&self) -> usize {
217        self.0.params.len()
218    }
219}
220
221#[derive(Debug, Clone, Default, Encodable, Decodable)]
222pub struct Generics {
223    pub parent: Option<DefId>,
224    pub parent_count: usize,
225    pub own_params: List<GenericParamDef>,
226    pub has_self: bool,
227}
228
229impl Generics {
230    pub fn count(&self) -> usize {
231        self.parent_count + self.own_params.len()
232    }
233
234    pub fn own_default_count(&self) -> usize {
235        self.own_params
236            .iter()
237            .filter(|param| {
238                match param.kind {
239                    GenericParamDefKind::Type { has_default }
240                    | GenericParamDefKind::Const { has_default }
241                    | GenericParamDefKind::Base { has_default } => has_default,
242                    GenericParamDefKind::Lifetime => false,
243                }
244            })
245            .count()
246    }
247
248    pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult<GenericParamDef> {
249        if let Some(index) = param_index.checked_sub(self.parent_count) {
250            Ok(self.own_params[index].clone())
251        } else {
252            let parent = self.parent.expect("parent_count > 0 but no parent?");
253            genv.generics_of(parent)?.param_at(param_index, genv)
254        }
255    }
256
257    pub fn const_params(&self, genv: GlobalEnv) -> QueryResult<Vec<(ParamConst, Sort)>> {
258        let mut res = vec![];
259        for i in 0..self.count() {
260            let param = self.param_at(i, genv)?;
261            if let GenericParamDefKind::Const { .. } = param.kind
262                && let Some(sort) = genv.sort_of_def_id(param.def_id)?
263            {
264                let param_const = ParamConst { name: param.name, index: param.index };
265                res.push((param_const, sort));
266            }
267        }
268        Ok(res)
269    }
270}
271
272#[derive(Debug, Clone, TyEncodable, TyDecodable)]
273pub struct RefinementGenerics {
274    pub parent: Option<DefId>,
275    pub parent_count: usize,
276    pub own_params: List<RefineParam>,
277}
278
279#[derive(
280    PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
281)]
282pub struct RefineParam {
283    pub sort: Sort,
284    pub name: Symbol,
285    pub mode: InferMode,
286}
287
288#[derive(Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
289pub struct GenericParamDef {
290    pub kind: GenericParamDefKind,
291    pub def_id: DefId,
292    pub index: u32,
293    pub name: Symbol,
294}
295
296#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
297pub enum GenericParamDefKind {
298    Type { has_default: bool },
299    Base { has_default: bool },
300    Lifetime,
301    Const { has_default: bool },
302}
303
304pub const SELF_PARAM_TY: ParamTy = ParamTy { index: 0, name: kw::SelfUpper };
305
306#[derive(Debug, Clone, TyEncodable, TyDecodable)]
307pub struct GenericPredicates {
308    pub parent: Option<DefId>,
309    pub predicates: List<Clause>,
310}
311
312#[derive(
313    Debug, PartialEq, Eq, Hash, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
314)]
315pub struct Clause {
316    kind: Binder<ClauseKind>,
317}
318
319impl Clause {
320    pub fn new(vars: impl Into<List<BoundVariableKind>>, kind: ClauseKind) -> Self {
321        Clause { kind: Binder::bind_with_vars(kind, vars.into()) }
322    }
323
324    pub fn kind(&self) -> Binder<ClauseKind> {
325        self.kind.clone()
326    }
327
328    fn as_trait_clause(&self) -> Option<Binder<TraitPredicate>> {
329        let clause = self.kind();
330        if let ClauseKind::Trait(trait_clause) = clause.skip_binder_ref() {
331            Some(clause.rebind(trait_clause.clone()))
332        } else {
333            None
334        }
335    }
336
337    pub fn as_projection_clause(&self) -> Option<Binder<ProjectionPredicate>> {
338        let clause = self.kind();
339        if let ClauseKind::Projection(proj_clause) = clause.skip_binder_ref() {
340            Some(clause.rebind(proj_clause.clone()))
341        } else {
342            None
343        }
344    }
345
346    // FIXME(nilehmann) we should deal with the binder in all the places this is used instead of
347    // blindly skipping it here
348    pub fn kind_skipping_binder(&self) -> ClauseKind {
349        self.kind.clone().skip_binder()
350    }
351
352    /// Group `Fn` trait clauses with their corresponding `FnOnce::Output` projection
353    /// predicate. This assumes there's exactly one corresponding projection predicate and will
354    /// crash otherwise.
355    pub fn split_off_fn_trait_clauses(
356        genv: GlobalEnv,
357        clauses: &Clauses,
358    ) -> (Vec<Clause>, Vec<Binder<FnTraitPredicate>>) {
359        let mut fn_trait_clauses = vec![];
360        let mut fn_trait_output_clauses = vec![];
361        let mut rest = vec![];
362        for clause in clauses {
363            if let Some(trait_clause) = clause.as_trait_clause()
364                && let Some(kind) = genv.tcx().fn_trait_kind_from_def_id(trait_clause.def_id())
365            {
366                fn_trait_clauses.push((kind, trait_clause));
367            } else if let Some(proj_clause) = clause.as_projection_clause()
368                && genv.is_fn_output(proj_clause.projection_def_id())
369            {
370                fn_trait_output_clauses.push(proj_clause);
371            } else {
372                rest.push(clause.clone());
373            }
374        }
375        let fn_trait_clauses = fn_trait_clauses
376            .into_iter()
377            .map(|(kind, fn_trait_clause)| {
378                let mut candidates = vec![];
379                for fn_trait_output_clause in &fn_trait_output_clauses {
380                    if fn_trait_output_clause.self_ty() == fn_trait_clause.self_ty() {
381                        candidates.push(fn_trait_output_clause.clone());
382                    }
383                }
384                tracked_span_assert_eq!(candidates.len(), 1);
385                let proj_pred = candidates.pop().unwrap().skip_binder();
386                fn_trait_clause.map(|fn_trait_clause| {
387                    FnTraitPredicate {
388                        kind,
389                        self_ty: fn_trait_clause.self_ty().to_ty(),
390                        tupled_args: fn_trait_clause.trait_ref.args[1].expect_base().to_ty(),
391                        output: proj_pred.term.to_ty(),
392                    }
393                })
394            })
395            .collect_vec();
396        (rest, fn_trait_clauses)
397    }
398}
399
400impl<'tcx> ToRustc<'tcx> for Clause {
401    type T = rustc_middle::ty::Clause<'tcx>;
402
403    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
404        self.kind.to_rustc(tcx).upcast(tcx)
405    }
406}
407
408impl From<Binder<ClauseKind>> for Clause {
409    fn from(kind: Binder<ClauseKind>) -> Self {
410        Clause { kind }
411    }
412}
413
414pub type Clauses = List<Clause>;
415
416#[derive(
417    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
418)]
419pub enum ClauseKind {
420    Trait(TraitPredicate),
421    Projection(ProjectionPredicate),
422    RegionOutlives(RegionOutlivesPredicate),
423    TypeOutlives(TypeOutlivesPredicate),
424    ConstArgHasType(Const, Ty),
425}
426
427impl<'tcx> ToRustc<'tcx> for ClauseKind {
428    type T = rustc_middle::ty::ClauseKind<'tcx>;
429
430    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
431        match self {
432            ClauseKind::Trait(trait_predicate) => {
433                rustc_middle::ty::ClauseKind::Trait(trait_predicate.to_rustc(tcx))
434            }
435            ClauseKind::Projection(projection_predicate) => {
436                rustc_middle::ty::ClauseKind::Projection(projection_predicate.to_rustc(tcx))
437            }
438            ClauseKind::RegionOutlives(outlives_predicate) => {
439                rustc_middle::ty::ClauseKind::RegionOutlives(outlives_predicate.to_rustc(tcx))
440            }
441            ClauseKind::TypeOutlives(outlives_predicate) => {
442                rustc_middle::ty::ClauseKind::TypeOutlives(outlives_predicate.to_rustc(tcx))
443            }
444            ClauseKind::ConstArgHasType(constant, ty) => {
445                rustc_middle::ty::ClauseKind::ConstArgHasType(
446                    constant.to_rustc(tcx),
447                    ty.to_rustc(tcx),
448                )
449            }
450        }
451    }
452}
453
454#[derive(Eq, PartialEq, Hash, Clone, Debug, TyEncodable, TyDecodable)]
455pub struct OutlivesPredicate<T>(pub T, pub Region);
456
457pub type TypeOutlivesPredicate = OutlivesPredicate<Ty>;
458pub type RegionOutlivesPredicate = OutlivesPredicate<Region>;
459
460impl<'tcx, V: ToRustc<'tcx>> ToRustc<'tcx> for OutlivesPredicate<V> {
461    type T = rustc_middle::ty::OutlivesPredicate<'tcx, V::T>;
462
463    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
464        rustc_middle::ty::OutlivesPredicate(self.0.to_rustc(tcx), self.1.to_rustc(tcx))
465    }
466}
467
468#[derive(
469    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
470)]
471pub struct TraitPredicate {
472    pub trait_ref: TraitRef,
473}
474
475impl TraitPredicate {
476    fn self_ty(&self) -> SubsetTyCtor {
477        self.trait_ref.self_ty()
478    }
479}
480
481impl<'tcx> ToRustc<'tcx> for TraitPredicate {
482    type T = rustc_middle::ty::TraitPredicate<'tcx>;
483
484    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
485        rustc_middle::ty::TraitPredicate {
486            polarity: rustc_middle::ty::PredicatePolarity::Positive,
487            trait_ref: self.trait_ref.to_rustc(tcx),
488        }
489    }
490}
491
492pub type PolyTraitPredicate = Binder<TraitPredicate>;
493
494impl PolyTraitPredicate {
495    fn def_id(&self) -> DefId {
496        self.skip_binder_ref().trait_ref.def_id
497    }
498
499    fn self_ty(&self) -> Binder<SubsetTyCtor> {
500        self.clone().map(|predicate| predicate.self_ty())
501    }
502}
503
504#[derive(
505    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
506)]
507pub struct TraitRef {
508    pub def_id: DefId,
509    pub args: GenericArgs,
510}
511
512impl TraitRef {
513    pub fn self_ty(&self) -> SubsetTyCtor {
514        self.args[0].expect_base().clone()
515    }
516}
517
518impl<'tcx> ToRustc<'tcx> for TraitRef {
519    type T = rustc_middle::ty::TraitRef<'tcx>;
520
521    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
522        rustc_middle::ty::TraitRef::new(tcx, self.def_id, self.args.to_rustc(tcx))
523    }
524}
525
526pub type PolyTraitRef = Binder<TraitRef>;
527
528impl PolyTraitRef {
529    pub fn def_id(&self) -> DefId {
530        self.as_ref().skip_binder().def_id
531    }
532}
533
534#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
535pub enum ExistentialPredicate {
536    Trait(ExistentialTraitRef),
537    Projection(ExistentialProjection),
538    AutoTrait(DefId),
539}
540
541pub type PolyExistentialPredicate = Binder<ExistentialPredicate>;
542
543impl ExistentialPredicate {
544    /// See [`rustc_middle::ty::ExistentialPredicateStableCmpExt`]
545    pub fn stable_cmp(&self, tcx: TyCtxt, other: &Self) -> Ordering {
546        match (self, other) {
547            (ExistentialPredicate::Trait(_), ExistentialPredicate::Trait(_)) => Ordering::Equal,
548            (ExistentialPredicate::Projection(a), ExistentialPredicate::Projection(b)) => {
549                tcx.def_path_hash(a.def_id)
550                    .cmp(&tcx.def_path_hash(b.def_id))
551            }
552            (ExistentialPredicate::AutoTrait(a), ExistentialPredicate::AutoTrait(b)) => {
553                tcx.def_path_hash(*a).cmp(&tcx.def_path_hash(*b))
554            }
555            (ExistentialPredicate::Trait(_), _) => Ordering::Less,
556            (ExistentialPredicate::Projection(_), ExistentialPredicate::Trait(_)) => {
557                Ordering::Greater
558            }
559            (ExistentialPredicate::Projection(_), _) => Ordering::Less,
560            (ExistentialPredicate::AutoTrait(_), _) => Ordering::Greater,
561        }
562    }
563}
564
565impl<'tcx> ToRustc<'tcx> for ExistentialPredicate {
566    type T = rustc_middle::ty::ExistentialPredicate<'tcx>;
567
568    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
569        match self {
570            ExistentialPredicate::Trait(trait_ref) => {
571                let trait_ref = rustc_middle::ty::ExistentialTraitRef::new_from_args(
572                    tcx,
573                    trait_ref.def_id,
574                    trait_ref.args.to_rustc(tcx),
575                );
576                rustc_middle::ty::ExistentialPredicate::Trait(trait_ref)
577            }
578            ExistentialPredicate::Projection(projection) => {
579                rustc_middle::ty::ExistentialPredicate::Projection(
580                    rustc_middle::ty::ExistentialProjection::new_from_args(
581                        tcx,
582                        projection.def_id,
583                        projection.args.to_rustc(tcx),
584                        projection.term.skip_binder_ref().to_rustc(tcx).into(),
585                    ),
586                )
587            }
588            ExistentialPredicate::AutoTrait(def_id) => {
589                rustc_middle::ty::ExistentialPredicate::AutoTrait(*def_id)
590            }
591        }
592    }
593}
594
595#[derive(
596    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
597)]
598pub struct ExistentialTraitRef {
599    pub def_id: DefId,
600    pub args: GenericArgs,
601}
602
603pub type PolyExistentialTraitRef = Binder<ExistentialTraitRef>;
604
605impl PolyExistentialTraitRef {
606    pub fn def_id(&self) -> DefId {
607        self.as_ref().skip_binder().def_id
608    }
609}
610
611#[derive(
612    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
613)]
614pub struct ExistentialProjection {
615    pub def_id: DefId,
616    pub args: GenericArgs,
617    pub term: SubsetTyCtor,
618}
619
620#[derive(
621    PartialEq, Eq, Hash, Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
622)]
623pub struct ProjectionPredicate {
624    pub projection_ty: AliasTy,
625    pub term: SubsetTyCtor,
626}
627
628impl Pretty for ProjectionPredicate {
629    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
630        write!(
631            f,
632            "ProjectionPredicate << projection_ty = {:?}, term = {:?} >>",
633            self.projection_ty, self.term
634        )
635    }
636}
637
638impl ProjectionPredicate {
639    pub fn self_ty(&self) -> SubsetTyCtor {
640        self.projection_ty.self_ty().clone()
641    }
642}
643
644impl<'tcx> ToRustc<'tcx> for ProjectionPredicate {
645    type T = rustc_middle::ty::ProjectionPredicate<'tcx>;
646
647    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
648        rustc_middle::ty::ProjectionPredicate {
649            projection_term: rustc_middle::ty::AliasTerm::new_from_args(
650                tcx,
651                self.projection_ty.def_id,
652                self.projection_ty.args.to_rustc(tcx),
653            ),
654            term: self.term.as_bty_skipping_binder().to_rustc(tcx).into(),
655        }
656    }
657}
658
659pub type PolyProjectionPredicate = Binder<ProjectionPredicate>;
660
661impl PolyProjectionPredicate {
662    pub fn projection_def_id(&self) -> DefId {
663        self.skip_binder_ref().projection_ty.def_id
664    }
665
666    pub fn self_ty(&self) -> Binder<SubsetTyCtor> {
667        self.clone().map(|predicate| predicate.self_ty())
668    }
669}
670
671#[derive(
672    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
673)]
674pub struct FnTraitPredicate {
675    pub self_ty: Ty,
676    pub tupled_args: Ty,
677    pub output: Ty,
678    pub kind: ClosureKind,
679}
680
681impl Pretty for FnTraitPredicate {
682    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
683        write!(
684            f,
685            "self = {:?}, args = {:?}, output = {:?}, kind = {}",
686            self.self_ty, self.tupled_args, self.output, self.kind
687        )
688    }
689}
690
691impl FnTraitPredicate {
692    pub fn fndef_sig(&self) -> FnSig {
693        let inputs = self.tupled_args.expect_tuple().iter().cloned().collect();
694        let ret = self.output.clone().shift_in_escaping(1);
695        let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
696        FnSig::new(Safety::Safe, rustc_abi::ExternAbi::Rust, List::empty(), inputs, output, false)
697    }
698}
699
700pub fn to_closure_sig(
701    tcx: TyCtxt,
702    closure_id: LocalDefId,
703    tys: &[Ty],
704    args: &flux_rustc_bridge::ty::GenericArgs,
705    poly_sig: &PolyFnSig,
706) -> PolyFnSig {
707    let closure_args = args.as_closure();
708    let kind_ty = closure_args.kind_ty().to_rustc(tcx);
709    let Some(kind) = kind_ty.to_opt_closure_kind() else {
710        bug!("to_closure_sig: expected closure kind, found {kind_ty:?}");
711    };
712
713    let mut vars = poly_sig.vars().clone().to_vec();
714    let fn_sig = poly_sig.clone().skip_binder();
715    let closure_ty = Ty::closure(closure_id.into(), tys, args);
716    let env_ty = match kind {
717        ClosureKind::Fn => {
718            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
719            let br = BoundRegion {
720                var: BoundVar::from_usize(vars.len() - 1),
721                kind: BoundRegionKind::ClosureEnv,
722            };
723            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Not)
724        }
725        ClosureKind::FnMut => {
726            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
727            let br = BoundRegion {
728                var: BoundVar::from_usize(vars.len() - 1),
729                kind: BoundRegionKind::ClosureEnv,
730            };
731            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Mut)
732        }
733        ClosureKind::FnOnce => closure_ty,
734    };
735
736    let inputs = std::iter::once(env_ty)
737        .chain(fn_sig.inputs().iter().cloned())
738        .collect::<Vec<_>>();
739    let output = fn_sig.output().clone();
740
741    let fn_sig = crate::rty::FnSig::new(
742        fn_sig.safety,
743        fn_sig.abi,
744        fn_sig.requires.clone(),
745        inputs.into(),
746        output,
747        false,
748    );
749
750    PolyFnSig::bind_with_vars(fn_sig, List::from(vars))
751}
752
753#[derive(
754    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
755)]
756pub struct CoroutineObligPredicate {
757    pub def_id: DefId,
758    pub resume_ty: Ty,
759    pub upvar_tys: List<Ty>,
760    pub output: Ty,
761}
762
763#[derive(Copy, Clone, Encodable, Decodable, Hash, PartialEq, Eq)]
764pub struct AssocReft {
765    pub def_id: FluxDefId,
766    // NOTE: Field is used to denote final associated generic refinements on Traits
767    pub final_: bool,
768    pub span: Span,
769}
770
771impl AssocReft {
772    pub fn new(def_id: FluxDefId, final_: bool, span: Span) -> Self {
773        Self { def_id, final_, span }
774    }
775
776    pub fn name(&self) -> Symbol {
777        self.def_id.name()
778    }
779
780    pub fn def_id(&self) -> FluxDefId {
781        self.def_id
782    }
783}
784
785#[derive(Clone, Encodable, Decodable)]
786pub struct AssocRefinements {
787    pub items: List<AssocReft>,
788}
789
790impl Default for AssocRefinements {
791    fn default() -> Self {
792        Self { items: List::empty() }
793    }
794}
795
796impl AssocRefinements {
797    pub fn get(&self, assoc_id: FluxDefId) -> AssocReft {
798        *self
799            .items
800            .into_iter()
801            .find(|it| it.def_id == assoc_id)
802            .unwrap_or_else(|| bug!("caller should guarantee existence of associated refinement"))
803    }
804
805    pub fn find(&self, name: Symbol) -> Option<AssocReft> {
806        Some(*self.items.into_iter().find(|it| it.name() == name)?)
807    }
808}
809
810#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
811pub enum SortCtor {
812    Set,
813    Map,
814    Adt(AdtSortDef),
815    User(FluxDefId),
816}
817
818newtype_index! {
819    /// [`ParamSort`] is used for polymorphic sorts (`Set`, `Map`, etc.) and [bit-vector size parameters].
820    /// They should occur "bound" under a [`PolyFuncSort`] or an [`AdtSortDef`]. We assume there's a
821    /// single binder and a [`ParamSort`] represents a variable as an index into the list of variables
822    /// bound by that binder, i.e., the representation doesnt't support higher-ranked sorts.
823    ///
824    /// [bit-vector size parameters]: BvSize::Param
825    #[debug_format = "?{}s"]
826    #[encodable]
827    pub struct ParamSort {}
828}
829
830newtype_index! {
831    /// A *sort* *v*variable *id*
832    #[debug_format = "?{}s"]
833    #[encodable]
834    pub struct SortVid {}
835}
836
837impl ena::unify::UnifyKey for SortVid {
838    type Value = SortVarVal;
839
840    #[inline]
841    fn index(&self) -> u32 {
842        self.as_u32()
843    }
844
845    #[inline]
846    fn from_index(u: u32) -> Self {
847        SortVid::from_u32(u)
848    }
849
850    fn tag() -> &'static str {
851        "SortVid"
852    }
853}
854
855bitflags! {
856    /// A *sort constraint* is a set of operations a sort must support.
857    ///
858    /// During sort checking, we accumulate the operations required for each sort variable. As
859    /// unification progresses, these constraints become more specific, i.e, a sort must support
860    /// more operations to satisfy them.
861    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
862    pub struct SortCstr: u16 {
863        /// An empty constraint (any sort satisfies it)
864        const BOT     = 0b0000000000;
865        /// `*`
866        const MUL     = 0b0000000001;
867        /// `/`
868        const DIV     = 0b0000000010;
869        /// `%`
870        const MOD     = 0b0000000100;
871        /// `+`
872        const ADD     = 0b0000001000;
873        /// `-`
874        const SUB     = 0b0000010000;
875        /// `|`
876        const BIT_OR  = 0b0000100000;
877        /// `&`
878        const BIT_AND = 0b0001000000;
879        /// `>>`
880        const BIT_SHL = 0b0010000000;
881        /// `<<`
882        const BIT_SHR = 0b0100000000;
883        /// `^`
884        const BIT_XOR = 0b1000000000;
885
886        /// The set of operations supported by all _numeric_ sorts.
887        const NUMERIC = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
888        /// The set of operations supported by integers.
889        const INT = Self::DIV.bits()
890            | Self::MUL.bits()
891            | Self::MOD.bits()
892            | Self::ADD.bits()
893            | Self::SUB.bits();
894        /// The set of operations supported by reals.
895        const REAL = Self::ADD.bits() | Self::SUB.bits() | Self::MUL.bits() | Self::DIV.bits();
896        /// The set of operations supported by bit vectors.
897        const BITVEC =  Self::DIV.bits()
898            | Self::MUL.bits()
899            | Self::MOD.bits()
900            | Self::ADD.bits()
901            | Self::SUB.bits()
902            | Self::BIT_OR.bits()
903            | Self::BIT_AND.bits()
904            | Self::BIT_SHL.bits()
905            | Self::BIT_SHR.bits()
906            | Self::BIT_XOR.bits();
907        /// The set of operations supported by sets.
908        const SET = Self::SUB.bits() | Self::BIT_OR.bits() | Self::BIT_AND.bits();
909    }
910}
911
912impl SortCstr {
913    /// Returns a constraint that only requires the specified binary operation.
914    pub fn from_bin_op(op: fhir::BinOp) -> Self {
915        match op {
916            fhir::BinOp::Add => Self::ADD,
917            fhir::BinOp::Sub => Self::SUB,
918            fhir::BinOp::Mul => Self::MUL,
919            fhir::BinOp::Div => Self::DIV,
920            fhir::BinOp::Mod => Self::MOD,
921            fhir::BinOp::BitAnd => Self::BIT_AND,
922            fhir::BinOp::BitOr => Self::BIT_OR,
923            fhir::BinOp::BitXor => Self::BIT_XOR,
924            fhir::BinOp::BitShl => Self::BIT_SHL,
925            fhir::BinOp::BitShr => Self::BIT_SHR,
926            _ => bug!("{op:?} not supported as a constraint"),
927        }
928    }
929
930    /// Returns whether a sort satisfies this constraint
931    fn satisfy(self, sort: &Sort) -> bool {
932        match sort {
933            Sort::Int => SortCstr::INT.contains(self),
934            Sort::Real => SortCstr::REAL.contains(self),
935            Sort::BitVec(_) => SortCstr::BITVEC.contains(self),
936            Sort::App(SortCtor::Set, _) => SortCstr::SET.contains(self),
937            _ => self == SortCstr::BOT,
938        }
939    }
940}
941
/// Unification value for sort variables used during sort checking.
///
/// Values are combined by the `ena::unify::UnifyValue` impl for this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SortVarVal {
    /// The variable is not yet solved but the solution must satisfy some constraint.
    Unsolved(SortCstr),
    /// The variable has been solved to a sort.
    Solved(Sort),
}
950
951impl Default for SortVarVal {
952    fn default() -> Self {
953        SortVarVal::Unsolved(SortCstr::BOT)
954    }
955}
956
957impl SortVarVal {
958    pub fn solved_or(&self, sort: &Sort) -> Sort {
959        match self {
960            SortVarVal::Unsolved(_) => sort.clone(),
961            SortVarVal::Solved(sort) => sort.clone(),
962        }
963    }
964
965    pub fn map_solved(&self, f: impl FnOnce(&Sort) -> Sort) -> SortVarVal {
966        match self {
967            SortVarVal::Unsolved(cstr) => SortVarVal::Unsolved(*cstr),
968            SortVarVal::Solved(sort) => SortVarVal::Solved(f(sort)),
969        }
970    }
971}
972
973impl ena::unify::UnifyValue for SortVarVal {
974    type Error = ();
975
976    fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
977        match (value1, value2) {
978            (SortVarVal::Solved(s1), SortVarVal::Solved(s2)) if s1 == s2 => {
979                Ok(SortVarVal::Solved(s1.clone()))
980            }
981            (SortVarVal::Unsolved(a), SortVarVal::Unsolved(b)) => Ok(SortVarVal::Unsolved(*a | *b)),
982            (SortVarVal::Unsolved(v), SortVarVal::Solved(sort))
983            | (SortVarVal::Solved(sort), SortVarVal::Unsolved(v))
984                if v.satisfy(sort) =>
985            {
986                Ok(SortVarVal::Solved(sort.clone()))
987            }
988            _ => Err(()),
989        }
990    }
991}
992
newtype_index! {
    /// A *b*it *v*ector *size* *v*ariable *id*: an inference variable standing for a
    /// [`BvSize`], unified through `ena` during sort checking.
    #[debug_format = "?{}size"]
    #[encodable]
    pub struct BvSizeVid {}
}
999
1000impl ena::unify::UnifyKey for BvSizeVid {
1001    type Value = Option<BvSize>;
1002
1003    #[inline]
1004    fn index(&self) -> u32 {
1005        self.as_u32()
1006    }
1007
1008    #[inline]
1009    fn from_index(u: u32) -> Self {
1010        BvSizeVid::from_u32(u)
1011    }
1012
1013    fn tag() -> &'static str {
1014        "BvSizeVid"
1015    }
1016}
1017
1018impl ena::unify::EqUnifyValue for BvSize {}
1019
/// The sort of a refinement (index) expression.
///
/// NOTE: variant order is part of the derived `TyEncodable`/`TyDecodable` encoding.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum Sort {
    /// The integer sort.
    Int,
    /// The boolean sort.
    Bool,
    /// The real-number sort.
    Real,
    /// Bit-vectors of the given size.
    BitVec(BvSize),
    Str,
    Char,
    /// The sort of locations (see [`Sort::is_loc`]).
    Loc,
    /// The sort associated with a generic type parameter — NOTE(review): confirm exact meaning.
    Param(ParamTy),
    /// A tuple of sorts; the empty tuple is the unit sort (see [`Sort::unit`]).
    Tuple(List<Sort>),
    /// The sort of an alias type (e.g. a projection).
    Alias(AliasKind, AliasTy),
    /// A (possibly polymorphic) function sort.
    Func(PolyFuncSort),
    /// A sort constructor applied to arguments, e.g. sets or ADT sorts.
    App(SortCtor, List<Sort>),
    /// A sort parameter, e.g. one bound by a [`PolyFuncSort`].
    Var(ParamSort),
    /// A sort inference variable, used during sort checking.
    Infer(SortVid),
    /// Presumably used to recover after a sort error — TODO confirm.
    Err,
}
1038
/// Classification of a cast between two sorts, as computed by [`Sort::cast_kind`].
pub enum CastKind {
    /// Identity cast, which is erasable: the sorts are equal, or both are `int`/`char`
    /// (e.g. int -> int, char -> int).
    Identity,
    /// From bool to int.
    BoolToInt,
    /// Cast whose target sort is unit (e.g. int -> float).
    IntoUnit,
    /// Any other cast. Uninterpreted casts are only allowed with an explicit flag.
    Uninterpreted,
}
1049
1050impl Sort {
1051    pub fn tuple(sorts: impl Into<List<Sort>>) -> Self {
1052        Sort::Tuple(sorts.into())
1053    }
1054
1055    pub fn app(ctor: SortCtor, sorts: List<Sort>) -> Self {
1056        Sort::App(ctor, sorts)
1057    }
1058
1059    pub fn unit() -> Self {
1060        Self::tuple(vec![])
1061    }
1062
1063    #[track_caller]
1064    pub fn expect_func(&self) -> &PolyFuncSort {
1065        if let Sort::Func(sort) = self { sort } else { bug!("expected `Sort::Func`") }
1066    }
1067
1068    pub fn is_loc(&self) -> bool {
1069        matches!(self, Sort::Loc)
1070    }
1071
1072    pub fn is_unit(&self) -> bool {
1073        matches!(self, Sort::Tuple(sorts) if sorts.is_empty())
1074    }
1075
1076    pub fn is_unit_adt(&self) -> Option<DefId> {
1077        if let Sort::App(SortCtor::Adt(sort_def), _) = self
1078            && let Some(variant) = sort_def.opt_struct_variant()
1079            && variant.fields() == 0
1080        {
1081            Some(sort_def.did())
1082        } else {
1083            None
1084        }
1085    }
1086
1087    /// Whether the sort is a function with return sort bool
1088    pub fn is_pred(&self) -> bool {
1089        matches!(self, Sort::Func(fsort) if fsort.skip_binders().output().is_bool())
1090    }
1091
1092    /// Returns `true` if the sort is [`Bool`].
1093    ///
1094    /// [`Bool`]: Sort::Bool
1095    #[must_use]
1096    pub fn is_bool(&self) -> bool {
1097        matches!(self, Self::Bool)
1098    }
1099
1100    pub fn cast_kind(self: &Sort, to: &Sort) -> CastKind {
1101        if self == to
1102            || (matches!(self, Sort::Char | Sort::Int) && matches!(to, Sort::Char | Sort::Int))
1103        {
1104            CastKind::Identity
1105        } else if matches!(self, Sort::Bool) && matches!(to, Sort::Int) {
1106            CastKind::BoolToInt
1107        } else if to.is_unit() {
1108            CastKind::IntoUnit
1109        } else {
1110            CastKind::Uninterpreted
1111        }
1112    }
1113
1114    pub fn walk(&self, mut f: impl FnMut(&Sort, &[FieldProj])) {
1115        fn go(sort: &Sort, f: &mut impl FnMut(&Sort, &[FieldProj]), proj: &mut Vec<FieldProj>) {
1116            match sort {
1117                Sort::Tuple(flds) => {
1118                    for (i, sort) in flds.iter().enumerate() {
1119                        proj.push(FieldProj::Tuple { arity: flds.len(), field: i as u32 });
1120                        go(sort, f, proj);
1121                        proj.pop();
1122                    }
1123                }
1124                Sort::App(SortCtor::Adt(sort_def), args) if sort_def.is_struct() => {
1125                    let field_sorts = sort_def.struct_variant().field_sorts(args);
1126                    for (i, sort) in field_sorts.iter().enumerate() {
1127                        proj.push(FieldProj::Adt { def_id: sort_def.did(), field: i as u32 });
1128                        go(sort, f, proj);
1129                        proj.pop();
1130                    }
1131                }
1132                _ => {
1133                    f(sort, proj);
1134                }
1135            }
1136        }
1137        go(self, &mut f, &mut vec![]);
1138    }
1139}
1140
/// The size of a [bit-vector]
///
/// [bit-vector]: Sort::BitVec
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BvSize {
    /// A fixed size, in bits (presumably — TODO confirm unit).
    Fixed(u32),
    /// A size that has been parameterized, e.g., bound under a [`PolyFuncSort`]
    Param(ParamSort),
    /// A size that needs to be inferred. Used during sort checking to instantiate bit-vector
    /// sizes at call-sites.
    Infer(BvSizeVid),
}
1154
1155impl rustc_errors::IntoDiagArg for Sort {
1156    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1157        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1158    }
1159}
1160
1161impl rustc_errors::IntoDiagArg for FuncSort {
1162    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1163        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1164    }
1165}
1166
/// A monomorphic function sort.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FuncSort {
    /// The input sorts followed by the output sort, which is always the last element
    /// (see [`FuncSort::new`]).
    pub inputs_and_output: List<Sort>,
}
1171
1172impl FuncSort {
1173    pub fn new(mut inputs: Vec<Sort>, output: Sort) -> Self {
1174        inputs.push(output);
1175        FuncSort { inputs_and_output: List::from_vec(inputs) }
1176    }
1177
1178    pub fn inputs(&self) -> &[Sort] {
1179        &self.inputs_and_output[0..self.inputs_and_output.len() - 1]
1180    }
1181
1182    pub fn output(&self) -> &Sort {
1183        &self.inputs_and_output[self.inputs_and_output.len() - 1]
1184    }
1185
1186    pub fn to_poly(&self) -> PolyFuncSort {
1187        PolyFuncSort::new(List::empty(), self.clone())
1188    }
1189}
1190
/// The kind of a parameter of a [`PolyFuncSort`]: a sort or a bit-vector size.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub enum SortParamKind {
    /// A sort parameter.
    Sort,
    /// A bit-vector size parameter.
    BvSize,
}
1197
/// A polymorphic function sort parametric over [sorts] or [bit-vector sizes].
///
/// Parameterizing over bit-vector sizes is a bit of a stretch, because smtlib doesn't support full
/// parametric reasoning over them. As long as we use functions parameterized over a size
/// monomorphically, we should be fine. Right now, we can guarantee this, because size parameters
/// are not exposed in the surface syntax and they are only used for predefined (interpreted)
/// theory functions.
///
/// [sorts]: Sort
/// [bit-vector sizes]: BvSize::Param
#[derive(Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub struct PolyFuncSort {
    /// The list of parameters including sorts and bit vector sizes
    params: List<SortParamKind>,
    /// The underlying function sort, in which the parameters are free.
    fsort: FuncSort,
}
1213
1214impl PolyFuncSort {
1215    pub fn new(params: List<SortParamKind>, fsort: FuncSort) -> Self {
1216        PolyFuncSort { params, fsort }
1217    }
1218
1219    pub fn skip_binders(&self) -> FuncSort {
1220        self.fsort.clone()
1221    }
1222
1223    pub fn instantiate_identity(&self) -> FuncSort {
1224        self.fsort.clone()
1225    }
1226
1227    pub fn expect_mono(&self) -> FuncSort {
1228        assert!(self.params.is_empty());
1229        self.fsort.clone()
1230    }
1231
1232    pub fn params(&self) -> impl ExactSizeIterator<Item = SortParamKind> + '_ {
1233        self.params.iter().copied()
1234    }
1235
1236    pub fn instantiate(&self, args: &[SortArg]) -> FuncSort {
1237        self.fsort.fold_with(&mut SortSubst::new(args))
1238    }
1239}
1240
/// An argument for a generic parameter in a [`Sort`] which can be either a generic sort or a
/// generic bit-vector size.
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub enum SortArg {
    /// An argument for a [`SortParamKind::Sort`] parameter.
    Sort(Sort),
    /// An argument for a [`SortParamKind::BvSize`] parameter.
    BvSize(BvSize),
}
1250
/// How a constant is interpreted.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum ConstantInfo {
    /// An uninterpreted constant
    Uninterpreted,
    /// A non-integral constant whose value is specified by the user (presumably the `Expr` is
    /// that value and the `Sort` is its sort — TODO confirm)
    Interpreted(Expr, Sort),
}
1258
/// Flux's view of an algebraic data type. Interned, so it can be cloned cheaply.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDef(Interned<AdtDefData>);

/// The data behind an [`AdtDef`].
#[derive(Debug, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDefData {
    /// The invariants of the ADT (see [`Invariant`]).
    invariants: Vec<Invariant>,
    /// The ADT's sort definition.
    sort_def: AdtSortDef,
    /// Presumably whether the ADT was marked `#[flux::opaque]` (see [`Opaqueness`]) — TODO confirm.
    opaque: bool,
    /// The underlying rustc ADT definition.
    rustc: ty::AdtDef,
}
1269
/// Option-like enum to explicitly mark that we don't have information about an ADT because it was
/// annotated with `#[flux::opaque]`. Note that only structs can be marked as opaque.
#[derive(Clone, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum Opaqueness<T> {
    /// No information is available (the ADT is opaque).
    Opaque,
    /// The information is available.
    Transparent(T),
}
1277
1278impl<T> Opaqueness<T> {
1279    pub fn map<S>(self, f: impl FnOnce(T) -> S) -> Opaqueness<S> {
1280        match self {
1281            Opaqueness::Opaque => Opaqueness::Opaque,
1282            Opaqueness::Transparent(value) => Opaqueness::Transparent(f(value)),
1283        }
1284    }
1285
1286    pub fn as_ref(&self) -> Opaqueness<&T> {
1287        match self {
1288            Opaqueness::Opaque => Opaqueness::Opaque,
1289            Opaqueness::Transparent(value) => Opaqueness::Transparent(value),
1290        }
1291    }
1292
1293    pub fn as_deref(&self) -> Opaqueness<&T::Target>
1294    where
1295        T: std::ops::Deref,
1296    {
1297        match self {
1298            Opaqueness::Opaque => Opaqueness::Opaque,
1299            Opaqueness::Transparent(value) => Opaqueness::Transparent(value.deref()),
1300        }
1301    }
1302
1303    pub fn ok_or_else<E>(self, err: impl FnOnce() -> E) -> Result<T, E> {
1304        match self {
1305            Opaqueness::Transparent(v) => Ok(v),
1306            Opaqueness::Opaque => Err(err()),
1307        }
1308    }
1309
1310    #[track_caller]
1311    pub fn expect(self, msg: &str) -> T {
1312        match self {
1313            Opaqueness::Transparent(val) => val,
1314            Opaqueness::Opaque => bug!("{}", msg),
1315        }
1316    }
1317
1318    pub fn ok_or_query_err(self, struct_id: DefId) -> Result<T, QueryErr> {
1319        self.ok_or_else(|| QueryErr::OpaqueStruct { struct_id })
1320    }
1321}
1322
1323impl<T, E> Opaqueness<Result<T, E>> {
1324    pub fn transpose(self) -> Result<Opaqueness<T>, E> {
1325        match self {
1326            Opaqueness::Transparent(Ok(x)) => Ok(Opaqueness::Transparent(x)),
1327            Opaqueness::Transparent(Err(e)) => Err(e),
1328            Opaqueness::Opaque => Ok(Opaqueness::Opaque),
1329        }
1330    }
1331}
1332
/// All signed integer types.
pub static INT_TYS: [IntTy; 6] =
    [IntTy::Isize, IntTy::I8, IntTy::I16, IntTy::I32, IntTy::I64, IntTy::I128];
/// All unsigned integer types.
pub static UINT_TYS: [UintTy; 6] =
    [UintTy::Usize, UintTy::U8, UintTy::U16, UintTy::U32, UintTy::U64, UintTy::U128];
1337
/// An ADT invariant: a predicate over the ADT's index, which is the variable bound by `pred`
/// (see [`Invariant::apply`]).
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable,
)]
pub struct Invariant {
    // This predicate may have sort variables, but we don't explicitly mark it like in `PolyFuncSort`.
    // See comment on `apply` for details.
    pred: Binder<Expr>,
}
1346
1347impl Invariant {
1348    pub fn new(pred: Binder<Expr>) -> Self {
1349        Self { pred }
1350    }
1351
1352    pub fn apply(&self, idx: &Expr) -> Expr {
1353        // The predicate may have sort variables but we don't explicitly instantiate them. This
1354        // works because within an expression, sort variables can only appear inside the sort
1355        // annotation for a lambda and invariants cannot have lambdas. It remains to instantiate
1356        // variables in the sort of the binder itself, but since we are removing it, we can avoid
1357        // the explicit instantiation. Ultimately, this works because the expression we generate in
1358        // fixpoint doesn't need sort annotations (sorts are re-inferred).
1359        self.pred.replace_bound_reft(idx)
1360    }
1361}
1362
/// The signatures of all variants of an ADT, each under its own binder.
pub type PolyVariants = List<Binder<VariantSig>>;
/// A single variant signature under a binder.
pub type PolyVariant = Binder<VariantSig>;

/// The signature of a variant of an ADT.
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct VariantSig {
    /// The ADT the variant belongs to.
    pub adt_def: AdtDef,
    /// The generic arguments the ADT is instantiated with.
    pub args: GenericArgs,
    /// The types of the variant's fields.
    pub fields: List<Ty>,
    /// The refinement index of a value built with this variant (presumably in terms of the
    /// enclosing binder's variables — TODO confirm).
    pub idx: Expr,
    /// Predicates required by the variant (presumably constructor preconditions — TODO confirm).
    pub requires: List<Expr>,
}
1374
/// A function signature under a binder.
pub type PolyFnSig = Binder<FnSig>;

/// A refined function signature.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FnSig {
    /// Whether the function is `unsafe`.
    pub safety: Safety,
    /// The function's ABI.
    pub abi: rustc_abi::ExternAbi,
    /// The signature's `requires` clauses.
    pub requires: List<Expr>,
    /// The argument types.
    pub inputs: List<Ty>,
    /// The output, under its own binder.
    pub output: Binder<FnOutput>,
    /// Whether this signature was auto-lifted (`true`) rather than coming from a spec (`false`).
    pub lifted: bool,
}
1387
/// The output of a function signature: its return type plus its `ensures` clauses.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct FnOutput {
    /// The return type.
    pub ret: Ty,
    /// The postconditions.
    pub ensures: List<Ensures>,
}
1395
/// A single `ensures` clause of a function signature.
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum Ensures {
    /// The location `Path` holds a value of type `Ty` (presumably on return — TODO confirm).
    Type(Path, Ty),
    /// A predicate that holds (presumably on return — TODO confirm).
    Pred(Expr),
}
1401
/// A qualifier: a named predicate over the variables bound by `body`.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct Qualifier {
    /// The id of the item that defines the qualifier.
    pub def_id: FluxLocalDefId,
    /// The predicate, abstracted over its parameters.
    pub body: Binder<Expr>,
    /// Where the qualifier applies (see [`QualifierKind`]).
    pub kind: QualifierKind,
}

/// The kind of a [`Qualifier`]. NOTE(review): the exact scope of each kind is defined by the
/// code that consumes qualifiers and is not visible here.
#[derive(Debug, TypeFoldable, TypeVisitable, Copy, Clone)]
pub enum QualifierKind {
    Global,
    Local,
    Hint,
}
1415
/// A `PrimOpProp` is a single property for a primitive operation which
/// can be conjoined to get the definition of the [`PrimRel`] for that
/// primitive operation.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimOpProp {
    /// The id of the item that defines the property.
    pub def_id: FluxLocalDefId,
    /// The primitive operation the property is about.
    pub op: BinOp,
    /// The property, abstracted over the operation's operands (presumably — TODO confirm).
    pub body: Binder<Expr>,
}

/// The relation defining a primitive operation: the conjunction of its [`PrimOpProp`]s.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimRel {
    pub body: Binder<Expr>,
}
1430
1431pub type TyCtor = Binder<Ty>;
1432
1433impl TyCtor {
1434    pub fn to_ty(&self) -> Ty {
1435        match &self.vars()[..] {
1436            [] => {
1437                return self.skip_binder_ref().shift_out_escaping(1);
1438            }
1439            [BoundVariableKind::Refine(sort, ..)] => {
1440                if sort.is_unit() {
1441                    return self.replace_bound_reft(&Expr::unit());
1442                }
1443                if let Some(def_id) = sort.is_unit_adt() {
1444                    return self.replace_bound_reft(&Expr::unit_struct(def_id));
1445                }
1446            }
1447            _ => {}
1448        }
1449        Ty::exists(self.clone())
1450    }
1451}
1452
/// A refined type. Interned, so it can be cloned cheaply; inspect it with [`Ty::kind`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct Ty(Interned<TyKind>);
1455
1456impl Ty {
1457    pub fn kind(&self) -> &TyKind {
1458        &self.0
1459    }
1460
1461    /// Dummy type used for the `Self` of a `TraitRef` created when converting a trait object, and
1462    /// which gets removed in `ExistentialTraitRef`. This type must not appear anywhere in other
1463    /// converted types and must be a valid `rustc` type (i.e., we must be able to call `to_rustc`
1464    /// on it). `TyKind::Infer(TyVid(0))` does the job, with the caveat that we must skip 0 when
1465    /// generating `TyKind::Infer` for "type holes".
1466    pub fn trait_object_dummy_self() -> Ty {
1467        Ty::infer(TyVid::from_u32(0))
1468    }
1469
1470    pub fn dynamic(preds: impl Into<List<Binder<ExistentialPredicate>>>, region: Region) -> Ty {
1471        BaseTy::Dynamic(preds.into(), region).to_ty()
1472    }
1473
1474    pub fn strg_ref(re: Region, path: Path, ty: Ty) -> Ty {
1475        TyKind::StrgRef(re, path, ty).intern()
1476    }
1477
1478    pub fn ptr(pk: impl Into<PtrKind>, path: impl Into<Path>) -> Ty {
1479        TyKind::Ptr(pk.into(), path.into()).intern()
1480    }
1481
1482    pub fn constr(p: impl Into<Expr>, ty: Ty) -> Ty {
1483        TyKind::Constr(p.into(), ty).intern()
1484    }
1485
1486    pub fn uninit() -> Ty {
1487        TyKind::Uninit.intern()
1488    }
1489
1490    pub fn indexed(bty: BaseTy, idx: impl Into<Expr>) -> Ty {
1491        TyKind::Indexed(bty, idx.into()).intern()
1492    }
1493
1494    pub fn exists(ty: Binder<Ty>) -> Ty {
1495        TyKind::Exists(ty).intern()
1496    }
1497
1498    pub fn exists_with_constr(bty: BaseTy, pred: Expr) -> Ty {
1499        let sort = bty.sort();
1500        let ty = Ty::indexed(bty, Expr::nu());
1501        Ty::exists(Binder::bind_with_sort(Ty::constr(pred, ty), sort))
1502    }
1503
1504    pub fn discr(adt_def: AdtDef, place: Place) -> Ty {
1505        TyKind::Discr(adt_def, place).intern()
1506    }
1507
1508    pub fn unit() -> Ty {
1509        Ty::tuple(vec![])
1510    }
1511
1512    pub fn bool() -> Ty {
1513        BaseTy::Bool.to_ty()
1514    }
1515
1516    pub fn int(int_ty: IntTy) -> Ty {
1517        BaseTy::Int(int_ty).to_ty()
1518    }
1519
1520    pub fn uint(uint_ty: UintTy) -> Ty {
1521        BaseTy::Uint(uint_ty).to_ty()
1522    }
1523
1524    pub fn param(param_ty: ParamTy) -> Ty {
1525        TyKind::Param(param_ty).intern()
1526    }
1527
1528    pub fn downcast(
1529        adt: AdtDef,
1530        args: GenericArgs,
1531        ty: Ty,
1532        variant: VariantIdx,
1533        fields: List<Ty>,
1534    ) -> Ty {
1535        TyKind::Downcast(adt, args, ty, variant, fields).intern()
1536    }
1537
1538    pub fn blocked(ty: Ty) -> Ty {
1539        TyKind::Blocked(ty).intern()
1540    }
1541
1542    pub fn str() -> Ty {
1543        BaseTy::Str.to_ty()
1544    }
1545
1546    pub fn char() -> Ty {
1547        BaseTy::Char.to_ty()
1548    }
1549
1550    pub fn float(float_ty: FloatTy) -> Ty {
1551        BaseTy::Float(float_ty).to_ty()
1552    }
1553
1554    pub fn mk_ref(region: Region, ty: Ty, mutbl: Mutability) -> Ty {
1555        BaseTy::Ref(region, ty, mutbl).to_ty()
1556    }
1557
1558    pub fn mk_slice(ty: Ty) -> Ty {
1559        BaseTy::Slice(ty).to_ty()
1560    }
1561
1562    pub fn mk_box(genv: GlobalEnv, deref_ty: Ty, alloc_ty: Ty) -> QueryResult<Ty> {
1563        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1564        let adt_def = genv.adt_def(def_id)?;
1565
1566        let args = List::from_arr([GenericArg::Ty(deref_ty), GenericArg::Ty(alloc_ty)]);
1567
1568        let bty = BaseTy::adt(adt_def, args);
1569        Ok(Ty::indexed(bty, Expr::unit_struct(def_id)))
1570    }
1571
1572    pub fn mk_box_with_default_alloc(genv: GlobalEnv, deref_ty: Ty) -> QueryResult<Ty> {
1573        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1574
1575        let generics = genv.generics_of(def_id)?;
1576        let alloc_ty = genv
1577            .lower_type_of(generics.own_params[1].def_id)?
1578            .skip_binder()
1579            .refine(&Refiner::default_for_item(genv, def_id)?)?;
1580
1581        Ty::mk_box(genv, deref_ty, alloc_ty)
1582    }
1583
1584    pub fn tuple(tys: impl Into<List<Ty>>) -> Ty {
1585        BaseTy::Tuple(tys.into()).to_ty()
1586    }
1587
1588    pub fn array(ty: Ty, c: Const) -> Ty {
1589        BaseTy::Array(ty, c).to_ty()
1590    }
1591
1592    pub fn closure(
1593        did: DefId,
1594        tys: impl Into<List<Ty>>,
1595        args: &flux_rustc_bridge::ty::GenericArgs,
1596    ) -> Ty {
1597        BaseTy::Closure(did, tys.into(), args.clone()).to_ty()
1598    }
1599
1600    pub fn coroutine(did: DefId, resume_ty: Ty, upvar_tys: List<Ty>) -> Ty {
1601        BaseTy::Coroutine(did, resume_ty, upvar_tys).to_ty()
1602    }
1603
1604    pub fn never() -> Ty {
1605        BaseTy::Never.to_ty()
1606    }
1607
1608    pub fn infer(vid: TyVid) -> Ty {
1609        TyKind::Infer(vid).intern()
1610    }
1611
1612    pub fn unconstr(&self) -> (Ty, Expr) {
1613        fn go(this: &Ty, preds: &mut Vec<Expr>) -> Ty {
1614            if let TyKind::Constr(pred, ty) = this.kind() {
1615                preds.push(pred.clone());
1616                go(ty, preds)
1617            } else {
1618                this.clone()
1619            }
1620        }
1621        let mut preds = vec![];
1622        (go(self, &mut preds), Expr::and_from_iter(preds))
1623    }
1624
1625    pub fn unblocked(&self) -> Ty {
1626        match self.kind() {
1627            TyKind::Blocked(ty) => ty.clone(),
1628            _ => self.clone(),
1629        }
1630    }
1631
1632    /// Whether the type is an `int` or a `uint`
1633    pub fn is_integral(&self) -> bool {
1634        self.as_bty_skipping_existentials()
1635            .map(BaseTy::is_integral)
1636            .unwrap_or_default()
1637    }
1638
1639    /// Whether the type is a `bool`
1640    pub fn is_bool(&self) -> bool {
1641        self.as_bty_skipping_existentials()
1642            .map(BaseTy::is_bool)
1643            .unwrap_or_default()
1644    }
1645
1646    /// Whether the type is a `char`
1647    pub fn is_char(&self) -> bool {
1648        self.as_bty_skipping_existentials()
1649            .map(BaseTy::is_char)
1650            .unwrap_or_default()
1651    }
1652
1653    pub fn is_uninit(&self) -> bool {
1654        matches!(self.kind(), TyKind::Uninit)
1655    }
1656
1657    pub fn is_box(&self) -> bool {
1658        self.as_bty_skipping_existentials()
1659            .map(BaseTy::is_box)
1660            .unwrap_or_default()
1661    }
1662
1663    pub fn is_struct(&self) -> bool {
1664        self.as_bty_skipping_existentials()
1665            .map(BaseTy::is_struct)
1666            .unwrap_or_default()
1667    }
1668
1669    pub fn is_array(&self) -> bool {
1670        self.as_bty_skipping_existentials()
1671            .map(BaseTy::is_array)
1672            .unwrap_or_default()
1673    }
1674
1675    pub fn is_slice(&self) -> bool {
1676        self.as_bty_skipping_existentials()
1677            .map(BaseTy::is_slice)
1678            .unwrap_or_default()
1679    }
1680
1681    pub fn as_bty_skipping_existentials(&self) -> Option<&BaseTy> {
1682        match self.kind() {
1683            TyKind::Indexed(bty, _) => Some(bty),
1684            TyKind::Exists(ty) => Some(ty.skip_binder_ref().as_bty_skipping_existentials()?),
1685            TyKind::Constr(_, ty) => ty.as_bty_skipping_existentials(),
1686            _ => None,
1687        }
1688    }
1689
1690    #[track_caller]
1691    pub fn expect_discr(&self) -> (&AdtDef, &Place) {
1692        if let TyKind::Discr(adt_def, place) = self.kind() {
1693            (adt_def, place)
1694        } else {
1695            tracked_span_bug!("expected discr")
1696        }
1697    }
1698
1699    #[track_caller]
1700    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg], &Expr) {
1701        if let TyKind::Indexed(BaseTy::Adt(adt_def, args), idx) = self.kind() {
1702            (adt_def, args, idx)
1703        } else {
1704            tracked_span_bug!("expected adt `{self:?}`")
1705        }
1706    }
1707
1708    #[track_caller]
1709    pub fn expect_tuple(&self) -> &[Ty] {
1710        if let TyKind::Indexed(BaseTy::Tuple(tys), _) = self.kind() {
1711            tys
1712        } else {
1713            tracked_span_bug!("expected tuple found `{self:?}` (kind: `{:?}`)", self.kind())
1714        }
1715    }
1716
1717    pub fn simplify_type(&self) -> Option<SimplifiedType> {
1718        self.as_bty_skipping_existentials()
1719            .and_then(BaseTy::simplify_type)
1720    }
1721}
1722
1723impl<'tcx> ToRustc<'tcx> for Ty {
1724    type T = rustc_middle::ty::Ty<'tcx>;
1725
1726    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
1727        match self.kind() {
1728            TyKind::Indexed(bty, _) => bty.to_rustc(tcx),
1729            TyKind::Exists(ty) => ty.skip_binder_ref().to_rustc(tcx),
1730            TyKind::Constr(_, ty) => ty.to_rustc(tcx),
1731            TyKind::Param(pty) => pty.to_ty(tcx),
1732            TyKind::StrgRef(re, _, ty) => {
1733                rustc_middle::ty::Ty::new_ref(
1734                    tcx,
1735                    re.to_rustc(tcx),
1736                    ty.to_rustc(tcx),
1737                    Mutability::Mut,
1738                )
1739            }
1740            TyKind::Infer(vid) => rustc_middle::ty::Ty::new_var(tcx, *vid),
1741            TyKind::Uninit
1742            | TyKind::Ptr(_, _)
1743            | TyKind::Discr(..)
1744            | TyKind::Downcast(..)
1745            | TyKind::Blocked(_) => bug!("TODO: to_rustc for `{self:?}`"),
1746        }
1747    }
1748}
1749
/// The different kinds of refined types; see [`Ty::kind`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
pub enum TyKind {
    /// A base type indexed by a refinement expression (see [`Ty::indexed`]).
    Indexed(BaseTy, Expr),
    /// An existential type; the binder introduces the refinement variables (see [`Ty::exists`]).
    Exists(Binder<Ty>),
    /// A type guarded by a predicate (see [`Ty::constr`] and [`Ty::unconstr`]).
    Constr(Expr, Ty),
    /// Uninitialized memory.
    Uninit,
    /// A strong reference to the location `Path`; the `Ty` is presumably the type of the
    /// pointee — TODO confirm. Lowered to a `&mut` reference by `to_rustc`.
    StrgRef(Region, Path, Ty),
    /// A pointer of the given kind to the location `Path`.
    Ptr(PtrKind, Path),
    /// This is a bit of a hack. We use this type internally to represent the result of
    /// [`Rvalue::Discriminant`] in a way that we can recover the necessary control information
    /// when checking a [`match`]. The hack is that we assume the discriminant remains the same from
    /// the creation of this type until we use it in a [`match`].
    ///
    /// [`Rvalue::Discriminant`]: flux_rustc_bridge::mir::Rvalue::Discriminant
    /// [`match`]: flux_rustc_bridge::mir::TerminatorKind::SwitchInt
    Discr(AdtDef, Place),
    /// A generic type parameter.
    Param(ParamTy),
    /// An ADT downcast to the variant `VariantIdx`; the final list presumably holds the field
    /// types of that variant — TODO confirm.
    Downcast(AdtDef, GenericArgs, Ty, VariantIdx, List<Ty>),
    /// A blocked type (see [`Ty::unblocked`]).
    Blocked(Ty),
    /// A type that needs to be inferred by matching the signature against a rust signature.
    /// [`TyKind::Infer`] appears as an intermediate step during `conv` and should not be present
    /// in the final signature.
    Infer(TyVid),
}
1775
/// The kind of a [`TyKind::Ptr`].
#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum PtrKind {
    /// A mutable pointer with the given region.
    Mut(Region),
    /// A box pointer.
    Box,
}
1781
/// An unrefined base type mirroring rustc's type constructors. Refinements are layered on top
/// of it through [`TyKind::Indexed`] and [`TyKind::Exists`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BaseTy {
    Int(IntTy),
    Uint(UintTy),
    Bool,
    Str,
    Char,
    Slice(Ty),
    Adt(AdtDef, GenericArgs),
    Float(FloatTy),
    /// A raw pointer `*const T` / `*mut T`.
    RawPtr(Ty, Mutability),
    /// The pointer metadata of a type — NOTE(review): confirm exact meaning.
    RawPtrMetadata(Ty),
    /// A reference `&'r T` / `&'r mut T`.
    Ref(Region, Ty, Mutability),
    /// A function pointer.
    FnPtr(PolyFnSig),
    /// The type of a function item, with its generic arguments.
    FnDef(DefId, GenericArgs),
    Tuple(List<Ty>),
    /// An alias type, e.g. a projection or an opaque type.
    Alias(AliasKind, AliasTy),
    Array(Ty, Const),
    Never,
    Closure(DefId, /* upvar_tys */ List<Ty>, flux_rustc_bridge::ty::GenericArgs),
    Coroutine(DefId, /*resume_ty: */ Ty, /* upvar_tys: */ List<Ty>),
    /// A trait object with the given existential predicates and region.
    Dynamic(List<Binder<ExistentialPredicate>>, Region),
    Param(ParamTy),
    Infer(TyVid),
    /// A foreign (extern) type.
    Foreign(DefId),
}
1808
1809impl BaseTy {
1810    pub fn opaque(alias_ty: AliasTy) -> BaseTy {
1811        BaseTy::Alias(AliasKind::Opaque, alias_ty)
1812    }
1813
1814    pub fn projection(alias_ty: AliasTy) -> BaseTy {
1815        BaseTy::Alias(AliasKind::Projection, alias_ty)
1816    }
1817
1818    pub fn adt(adt_def: AdtDef, args: GenericArgs) -> BaseTy {
1819        BaseTy::Adt(adt_def, args)
1820    }
1821
1822    pub fn fn_def(def_id: DefId, args: impl Into<GenericArgs>) -> BaseTy {
1823        BaseTy::FnDef(def_id, args.into())
1824    }
1825
1826    pub fn from_primitive_str(s: &str) -> Option<BaseTy> {
1827        match s {
1828            "i8" => Some(BaseTy::Int(IntTy::I8)),
1829            "i16" => Some(BaseTy::Int(IntTy::I16)),
1830            "i32" => Some(BaseTy::Int(IntTy::I32)),
1831            "i64" => Some(BaseTy::Int(IntTy::I64)),
1832            "i128" => Some(BaseTy::Int(IntTy::I128)),
1833            "u8" => Some(BaseTy::Uint(UintTy::U8)),
1834            "u16" => Some(BaseTy::Uint(UintTy::U16)),
1835            "u32" => Some(BaseTy::Uint(UintTy::U32)),
1836            "u64" => Some(BaseTy::Uint(UintTy::U64)),
1837            "u128" => Some(BaseTy::Uint(UintTy::U128)),
1838            "f16" => Some(BaseTy::Float(FloatTy::F16)),
1839            "f32" => Some(BaseTy::Float(FloatTy::F32)),
1840            "f64" => Some(BaseTy::Float(FloatTy::F64)),
1841            "f128" => Some(BaseTy::Float(FloatTy::F128)),
1842            "isize" => Some(BaseTy::Int(IntTy::Isize)),
1843            "usize" => Some(BaseTy::Uint(UintTy::Usize)),
1844            "bool" => Some(BaseTy::Bool),
1845            "char" => Some(BaseTy::Char),
1846            "str" => Some(BaseTy::Str),
1847            _ => None,
1848        }
1849    }
1850
1851    /// If `self` is a primitive, return its [`Symbol`].
1852    pub fn primitive_symbol(&self) -> Option<Symbol> {
1853        match self {
1854            BaseTy::Bool => Some(sym::bool),
1855            BaseTy::Char => Some(sym::char),
1856            BaseTy::Float(f) => {
1857                match f {
1858                    FloatTy::F16 => Some(sym::f16),
1859                    FloatTy::F32 => Some(sym::f32),
1860                    FloatTy::F64 => Some(sym::f64),
1861                    FloatTy::F128 => Some(sym::f128),
1862                }
1863            }
1864            BaseTy::Int(f) => {
1865                match f {
1866                    IntTy::Isize => Some(sym::isize),
1867                    IntTy::I8 => Some(sym::i8),
1868                    IntTy::I16 => Some(sym::i16),
1869                    IntTy::I32 => Some(sym::i32),
1870                    IntTy::I64 => Some(sym::i64),
1871                    IntTy::I128 => Some(sym::i128),
1872                }
1873            }
1874            BaseTy::Uint(f) => {
1875                match f {
1876                    UintTy::Usize => Some(sym::usize),
1877                    UintTy::U8 => Some(sym::u8),
1878                    UintTy::U16 => Some(sym::u16),
1879                    UintTy::U32 => Some(sym::u32),
1880                    UintTy::U64 => Some(sym::u64),
1881                    UintTy::U128 => Some(sym::u128),
1882                }
1883            }
1884            BaseTy::Str => Some(sym::str),
1885            _ => None,
1886        }
1887    }
1888
1889    pub fn is_integral(&self) -> bool {
1890        matches!(self, BaseTy::Int(_) | BaseTy::Uint(_))
1891    }
1892
1893    pub fn is_signed(&self) -> bool {
1894        matches!(self, BaseTy::Int(_))
1895    }
1896
1897    pub fn is_unsigned(&self) -> bool {
1898        matches!(self, BaseTy::Uint(_))
1899    }
1900
1901    pub fn is_float(&self) -> bool {
1902        matches!(self, BaseTy::Float(_))
1903    }
1904
1905    pub fn is_bool(&self) -> bool {
1906        matches!(self, BaseTy::Bool)
1907    }
1908
1909    fn is_struct(&self) -> bool {
1910        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_struct())
1911    }
1912
1913    fn is_array(&self) -> bool {
1914        matches!(self, BaseTy::Array(..))
1915    }
1916
1917    fn is_slice(&self) -> bool {
1918        matches!(self, BaseTy::Slice(..))
1919    }
1920
1921    pub fn is_box(&self) -> bool {
1922        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_box())
1923    }
1924
1925    pub fn is_char(&self) -> bool {
1926        matches!(self, BaseTy::Char)
1927    }
1928
1929    pub fn is_str(&self) -> bool {
1930        matches!(self, BaseTy::Str)
1931    }
1932
1933    pub fn invariants(
1934        &self,
1935        tcx: TyCtxt,
1936        overflow_mode: OverflowMode,
1937    ) -> impl Iterator<Item = Invariant> {
1938        let (invariants, args) = match self {
1939            BaseTy::Adt(adt_def, args) => (adt_def.invariants().skip_binder(), &args[..]),
1940            BaseTy::Uint(uint_ty) => (uint_invariants(*uint_ty, overflow_mode), &[][..]),
1941            BaseTy::Int(int_ty) => (int_invariants(*int_ty, overflow_mode), &[][..]),
1942            BaseTy::Char => (char_invariants(), &[][..]),
1943            BaseTy::Slice(_) => (slice_invariants(overflow_mode), &[][..]),
1944            _ => (&[][..], &[][..]),
1945        };
1946        invariants
1947            .iter()
1948            .map(move |inv| EarlyBinder(inv).instantiate_ref(tcx, args, &[]))
1949    }
1950
1951    pub fn to_ty(&self) -> Ty {
1952        let sort = self.sort();
1953        if sort.is_unit() {
1954            Ty::indexed(self.clone(), Expr::unit())
1955        } else {
1956            Ty::exists(Binder::bind_with_sort(
1957                Ty::indexed(self.shift_in_escaping(1), Expr::nu()),
1958                sort,
1959            ))
1960        }
1961    }
1962
1963    pub fn to_subset_ty_ctor(&self) -> SubsetTyCtor {
1964        let sort = self.sort();
1965        Binder::bind_with_sort(SubsetTy::trivial(self.clone(), Expr::nu()), sort)
1966    }
1967
1968    #[track_caller]
1969    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg]) {
1970        if let BaseTy::Adt(adt_def, args) = self {
1971            (adt_def, args)
1972        } else {
1973            tracked_span_bug!("expected adt `{self:?}`")
1974        }
1975    }
1976
1977    /// A type is an *atom* if it is "self-delimiting", i.e., it has a clear boundary
1978    /// when printed. This is used to avoid unnecessary parenthesis when pretty printing.
1979    pub fn is_atom(&self) -> bool {
1980        // (nilehmann) I'm not sure about this list, please adjust if you get any odd behavior
1981        matches!(
1982            self,
1983            BaseTy::Int(_)
1984                | BaseTy::Uint(_)
1985                | BaseTy::Slice(_)
1986                | BaseTy::Bool
1987                | BaseTy::Char
1988                | BaseTy::Str
1989                | BaseTy::Adt(..)
1990                | BaseTy::Tuple(..)
1991                | BaseTy::Param(_)
1992                | BaseTy::Array(..)
1993                | BaseTy::Never
1994                | BaseTy::Closure(..)
1995                | BaseTy::Coroutine(..)
1996                // opaque alias are atoms the way we print them now, but they won't
1997                // be if we print them as `impl Trait`
1998                | BaseTy::Alias(..)
1999        )
2000    }
2001
2002    /// Similar to [`rustc_infer::infer::canonical::ir::fast_reject::simplify_type`].
2003    ///
2004    /// This implementation is currently incomplete, so it should only be used in contexts
2005    /// where completeness is not required. Currently, it's used to find incoherent
2006    /// implementations when resolving associated constants. In this context, incompleteness
2007    /// is acceptable since the worst case outcome is simply failing to resolve a type-relative
2008    /// constant.
2009    fn simplify_type(&self) -> Option<SimplifiedType> {
2010        match self {
2011            BaseTy::Bool => Some(SimplifiedType::Bool),
2012            BaseTy::Char => Some(SimplifiedType::Char),
2013            BaseTy::Int(int_type) => Some(SimplifiedType::Int(*int_type)),
2014            BaseTy::Uint(uint_type) => Some(SimplifiedType::Uint(*uint_type)),
2015            BaseTy::Float(float_type) => Some(SimplifiedType::Float(*float_type)),
2016            BaseTy::Adt(def, _) => Some(SimplifiedType::Adt(def.did())),
2017            BaseTy::Str => Some(SimplifiedType::Str),
2018            BaseTy::Array(..) => Some(SimplifiedType::Array),
2019            BaseTy::Slice(..) => Some(SimplifiedType::Slice),
2020            BaseTy::RawPtr(_, mutbl) => Some(SimplifiedType::Ptr(*mutbl)),
2021            BaseTy::Ref(_, _, mutbl) => Some(SimplifiedType::Ref(*mutbl)),
2022            BaseTy::FnDef(def_id, _) | BaseTy::Closure(def_id, ..) => {
2023                Some(SimplifiedType::Closure(*def_id))
2024            }
2025            BaseTy::Coroutine(def_id, ..) => Some(SimplifiedType::Coroutine(*def_id)),
2026            BaseTy::Never => Some(SimplifiedType::Never),
2027            BaseTy::Tuple(tys) => Some(SimplifiedType::Tuple(tys.len())),
2028            BaseTy::FnPtr(poly_fn_sig) => {
2029                Some(SimplifiedType::Function(poly_fn_sig.skip_binder_ref().inputs().len()))
2030            }
2031            BaseTy::Foreign(def_id) => Some(SimplifiedType::Foreign(*def_id)),
2032            BaseTy::RawPtrMetadata(_)
2033            | BaseTy::Alias(..)
2034            | BaseTy::Param(_)
2035            | BaseTy::Dynamic(..)
2036            | BaseTy::Infer(_) => None,
2037        }
2038    }
2039}
2040
2041impl<'tcx> ToRustc<'tcx> for BaseTy {
2042    type T = rustc_middle::ty::Ty<'tcx>;
2043
2044    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2045        use rustc_middle::ty;
2046        match self {
2047            BaseTy::Int(i) => ty::Ty::new_int(tcx, *i),
2048            BaseTy::Uint(i) => ty::Ty::new_uint(tcx, *i),
2049            BaseTy::Param(pty) => pty.to_ty(tcx),
2050            BaseTy::Slice(ty) => ty::Ty::new_slice(tcx, ty.to_rustc(tcx)),
2051            BaseTy::Bool => tcx.types.bool,
2052            BaseTy::Char => tcx.types.char,
2053            BaseTy::Str => tcx.types.str_,
2054            BaseTy::Adt(adt_def, args) => {
2055                let did = adt_def.did();
2056                let adt_def = tcx.adt_def(did);
2057                let args = args.to_rustc(tcx);
2058                ty::Ty::new_adt(tcx, adt_def, args)
2059            }
2060            BaseTy::FnDef(def_id, args) => {
2061                let args = args.to_rustc(tcx);
2062                ty::Ty::new_fn_def(tcx, *def_id, args)
2063            }
2064            BaseTy::Float(f) => ty::Ty::new_float(tcx, *f),
2065            BaseTy::RawPtr(ty, mutbl) => ty::Ty::new_ptr(tcx, ty.to_rustc(tcx), *mutbl),
2066            BaseTy::Ref(re, ty, mutbl) => {
2067                ty::Ty::new_ref(tcx, re.to_rustc(tcx), ty.to_rustc(tcx), *mutbl)
2068            }
2069            BaseTy::FnPtr(poly_sig) => ty::Ty::new_fn_ptr(tcx, poly_sig.to_rustc(tcx)),
2070            BaseTy::Tuple(tys) => {
2071                let ts = tys.iter().map(|ty| ty.to_rustc(tcx)).collect_vec();
2072                ty::Ty::new_tup(tcx, &ts)
2073            }
2074            BaseTy::Alias(kind, alias_ty) => {
2075                ty::Ty::new_alias(tcx, kind.to_rustc(tcx), alias_ty.to_rustc(tcx))
2076            }
2077            BaseTy::Array(ty, n) => {
2078                let ty = ty.to_rustc(tcx);
2079                let n = n.to_rustc(tcx);
2080                ty::Ty::new_array_with_const_len(tcx, ty, n)
2081            }
2082            BaseTy::Never => tcx.types.never,
2083            BaseTy::Closure(did, _, args) => ty::Ty::new_closure(tcx, *did, args.to_rustc(tcx)),
2084            BaseTy::Dynamic(exi_preds, re) => {
2085                let preds: Vec<_> = exi_preds
2086                    .iter()
2087                    .map(|pred| pred.to_rustc(tcx))
2088                    .collect_vec();
2089                let preds = tcx.mk_poly_existential_predicates(&preds);
2090                ty::Ty::new_dynamic(tcx, preds, re.to_rustc(tcx))
2091            }
2092            BaseTy::Coroutine(def_id, resume_ty, upvars) => {
2093                bug!("TODO: Generator {def_id:?} {resume_ty:?} {upvars:?}")
2094                // let args = args.iter().map(|arg| into_rustc_generic_arg(tcx, arg));
2095                // let args = tcx.mk_args_from_iter(args);
2096                // ty::Ty::new_generator(*tcx, *def_id, args, mov)
2097            }
2098            BaseTy::Infer(ty_vid) => ty::Ty::new_var(tcx, *ty_vid),
2099            BaseTy::Foreign(def_id) => ty::Ty::new_foreign(tcx, *def_id),
2100            BaseTy::RawPtrMetadata(ty) => {
2101                ty::Ty::new_ptr(
2102                    tcx,
2103                    ty.to_rustc(tcx),
2104                    RawPtrKind::FakeForPtrMetadata.to_mutbl_lossy(),
2105                )
2106            }
2107        }
2108    }
2109}
2110
/// An alias type: either an opaque type or an associated-type projection (see
/// [`BaseTy::opaque`] and [`BaseTy::projection`]), identified by the definition it refers to
/// together with its arguments.
#[derive(
    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct AliasTy {
    /// The definition the alias refers to.
    pub def_id: DefId,
    /// The generic arguments of the alias.
    pub args: GenericArgs,
    /// Holds the refinement-arguments for opaque-types; empty for projections
    pub refine_args: RefineArgs,
}
2120
2121impl AliasTy {
2122    pub fn new(def_id: DefId, args: GenericArgs, refine_args: RefineArgs) -> Self {
2123        AliasTy { args, refine_args, def_id }
2124    }
2125}
2126
2127/// This methods work only with associated type projections (i.e., no opaque types)
2128impl AliasTy {
2129    pub fn self_ty(&self) -> SubsetTyCtor {
2130        self.args[0].expect_base().clone()
2131    }
2132
2133    pub fn with_self_ty(&self, self_ty: SubsetTyCtor) -> Self {
2134        Self {
2135            def_id: self.def_id,
2136            args: [GenericArg::Base(self_ty)]
2137                .into_iter()
2138                .chain(self.args.iter().skip(1).cloned())
2139                .collect(),
2140            refine_args: self.refine_args.clone(),
2141        }
2142    }
2143}
2144
2145impl<'tcx> ToRustc<'tcx> for AliasTy {
2146    type T = rustc_middle::ty::AliasTy<'tcx>;
2147
2148    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2149        rustc_middle::ty::AliasTy::new(tcx, self.def_id, self.args.to_rustc(tcx))
2150    }
2151}
2152
/// A list of refinement arguments (refinement expressions). See [`RefineArgsExt`] for
/// constructors that build them from an item's refinement generics.
pub type RefineArgs = List<Expr>;
2154
2155#[extension(pub trait RefineArgsExt)]
2156impl RefineArgs {
2157    fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<RefineArgs> {
2158        Self::for_item(genv, def_id, |param, index| {
2159            Ok(Expr::var(Var::EarlyParam(EarlyReftParam {
2160                index: index as u32,
2161                name: param.name(),
2162            })))
2163        })
2164    }
2165
2166    fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult<RefineArgs>
2167    where
2168        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<Expr>,
2169    {
2170        let reft_generics = genv.refinement_generics_of(def_id)?;
2171        let count = reft_generics.count();
2172        let mut args = Vec::with_capacity(count);
2173        reft_generics.fill_item(genv, &mut args, &mut mk)?;
2174        Ok(List::from_vec(args))
2175    }
2176}
2177
/// A type constructor meant to be used as a generic argument of [kind base]. This is just an alias
/// to [`Binder<SubsetTy>`], but we expect the binder to have a single bound variable of the sort of
/// the underlying [base type].
///
/// [kind base]: GenericParamDefKind::Base
/// [base type]: SubsetTy::bty
pub type SubsetTyCtor = Binder<SubsetTy>;
2185
2186impl SubsetTyCtor {
2187    pub fn as_bty_skipping_binder(&self) -> &BaseTy {
2188        &self.as_ref().skip_binder().bty
2189    }
2190
2191    pub fn to_ty(&self) -> Ty {
2192        let sort = self.sort();
2193        if sort.is_unit() {
2194            self.replace_bound_reft(&Expr::unit()).to_ty()
2195        } else if let Some(def_id) = sort.is_unit_adt() {
2196            self.replace_bound_reft(&Expr::unit_struct(def_id)).to_ty()
2197        } else {
2198            Ty::exists(self.as_ref().map(SubsetTy::to_ty))
2199        }
2200    }
2201
2202    pub fn to_ty_ctor(&self) -> TyCtor {
2203        self.as_ref().map(SubsetTy::to_ty)
2204    }
2205}
2206
/// A subset type is a simplified version of a type that has the form `{b[e] | p}` where `b` is a
/// [`BaseTy`], `e` a refinement index, and `p` a predicate.
///
/// These are mainly found under a [`Binder`] with a single variable of the base type's sort. This
/// can be interpreted as a type constructor or an existential type. For example, under a binder with a
/// variable `v` of sort `int`, we can interpret `{i32[v] | v > 0}` as:
/// - A lambda `λv:int. {i32[v] | v > 0}` that "constructs" types when applied to ints, or
/// - An existential type `∃v:int. {i32[v] | v > 0}`.
///
/// This second interpretation is the reason we call this a subset type, i.e., the type `∃v. {b[v] | p}`
/// corresponds to the subset of values of type `b` whose index satisfies `p`. These are the types
/// written as `B{v: p}` in the surface syntax and correspond to the types supported in other
/// refinement type systems like Liquid Haskell (with the difference that we are explicit
/// about separating refinements from program values via an index).
///
/// The main purpose of subset types is to be used as generic arguments of [kind base] when
/// interpreted as type constructors. They have two key properties that make them suitable
/// for this:
///
/// 1. **Syntactic Restriction**: Subset types are syntactically restricted, making it easier to
///    relate them structurally (e.g., for subtyping). For instance, given two types `S<λv. T1>` and
///    `S<λv. T2>`, if `T1` and `T2` are subset types, we know they match structurally (at least
///    shallowly). In particular, the syntactic restriction rules out complex types like
///    `S<λv. (i32[v], i32[v])>`, simplifying some operations.
///
/// 2. **Eager Canonicalization**: Subset types can be eagerly canonicalized via [*strengthening*]
///    during substitution. For example, suppose we have a function:
///    ```text
///    fn foo<T>(x: T[@a], y: { T[@b] | b == a }) { }
///    ```
///    If we instantiate `T` with `λv. { i32[v] | v > 0}`, after substitution and applying the
///    lambda (the indexing syntax `T[a]` corresponds to an application of the lambda), we get:
///    ```text
///    fn foo(x: {i32[@a] | a > 0}, y: { { i32[@b] | b > 0 } | b == a }) { }
///    ```
///    Via *strengthening* we can canonicalize this to
///    ```text
///    fn foo(x: {i32[@a] | a > 0}, y: { i32[@b] | b == a && b > 0 }) { }
///    ```
///    As a result, we can guarantee the syntactic restriction through substitution.
///
/// [kind base]: GenericParamDefKind::Base
/// [*strengthening*]: https://arxiv.org/pdf/2010.07763.pdf
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub struct SubsetTy {
    /// The base type `b` in the subset type `{b[e] | p}`.
    ///
    /// **NOTE:** This is mostly going to be under a [`Binder`]. It is not yet clear to me whether
    /// this [`BaseTy`] should be able to mention variables in the binder. In general, in a type
    /// `∃v. {b[e] | p}`, it's fine to mention `v` inside `b`, but since [`SubsetTy`] is meant to
    /// facilitate syntactic manipulation we may want to restrict this.
    pub bty: BaseTy,
    /// The refinement index `e` in the subset type `{b[e] | p}`. This can be an arbitrary expression,
    /// which makes manipulation easier. However, since this is mostly found under a binder, we expect
    /// it to be [`Expr::nu()`].
    pub idx: Expr,
    /// The predicate `p` in the subset type `{b[e] | p}`.
    pub pred: Expr,
}
2266
2267impl SubsetTy {
2268    pub fn new(bty: BaseTy, idx: impl Into<Expr>, pred: impl Into<Expr>) -> Self {
2269        Self { bty, idx: idx.into(), pred: pred.into() }
2270    }
2271
2272    pub fn trivial(bty: BaseTy, idx: impl Into<Expr>) -> Self {
2273        Self::new(bty, idx, Expr::tt())
2274    }
2275
2276    pub fn strengthen(&self, pred: impl Into<Expr>) -> Self {
2277        let this = self.clone();
2278        let pred = Expr::and(this.pred, pred).simplify(&SnapshotMap::default());
2279        Self { bty: this.bty, idx: this.idx, pred }
2280    }
2281
2282    pub fn to_ty(&self) -> Ty {
2283        let bty = self.bty.clone();
2284        if self.pred.is_trivially_true() {
2285            Ty::indexed(bty, &self.idx)
2286        } else {
2287            Ty::constr(&self.pred, Ty::indexed(bty, &self.idx))
2288        }
2289    }
2290}
2291
2292impl<'tcx> ToRustc<'tcx> for SubsetTy {
2293    type T = rustc_middle::ty::Ty<'tcx>;
2294
2295    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::Ty<'tcx> {
2296        self.bty.to_rustc(tcx)
2297    }
2298}
2299
/// A generic argument in the refined representation (lowered to rustc via [`ToRustc`]).
#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
pub enum GenericArg {
    /// Argument for a type parameter.
    Ty(Ty),
    /// Argument for a parameter of kind base ([`GenericParamDefKind::Base`]): a subset type
    /// constructor.
    Base(SubsetTyCtor),
    /// Argument for a lifetime parameter.
    Lifetime(Region),
    /// Argument for a const parameter.
    Const(Const),
}
2307
2308impl GenericArg {
2309    #[track_caller]
2310    pub fn expect_type(&self) -> &Ty {
2311        if let GenericArg::Ty(ty) = self {
2312            ty
2313        } else {
2314            bug!("expected `rty::GenericArg::Ty`, found `{self:?}`")
2315        }
2316    }
2317
2318    #[track_caller]
2319    pub fn expect_base(&self) -> &SubsetTyCtor {
2320        if let GenericArg::Base(ctor) = self {
2321            ctor
2322        } else {
2323            bug!("expected `rty::GenericArg::Base`, found `{self:?}`")
2324        }
2325    }
2326
2327    pub fn from_param_def(param: &GenericParamDef) -> Self {
2328        match param.kind {
2329            GenericParamDefKind::Type { .. } => {
2330                let param_ty = ParamTy { index: param.index, name: param.name };
2331                GenericArg::Ty(Ty::param(param_ty))
2332            }
2333            GenericParamDefKind::Base { .. } => {
2334                // λv. T[v]
2335                let param_ty = ParamTy { index: param.index, name: param.name };
2336                GenericArg::Base(Binder::bind_with_sort(
2337                    SubsetTy::trivial(BaseTy::Param(param_ty), Expr::nu()),
2338                    Sort::Param(param_ty),
2339                ))
2340            }
2341            GenericParamDefKind::Lifetime => {
2342                let region = EarlyParamRegion { index: param.index, name: param.name };
2343                GenericArg::Lifetime(Region::ReEarlyParam(region))
2344            }
2345            GenericParamDefKind::Const { .. } => {
2346                let param_const = ParamConst { index: param.index, name: param.name };
2347                let kind = ConstKind::Param(param_const);
2348                GenericArg::Const(Const { kind })
2349            }
2350        }
2351    }
2352
2353    /// Creates a `GenericArgs` from the definition of generic parameters, by calling a closure to
2354    /// obtain arg. The closures get to observe the `GenericArgs` as they're being built, which can
2355    /// be used to correctly replace defaults of generic parameters.
2356    pub fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk_kind: F) -> QueryResult<GenericArgs>
2357    where
2358        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2359    {
2360        let defs = genv.generics_of(def_id)?;
2361        let count = defs.count();
2362        let mut args = Vec::with_capacity(count);
2363        Self::fill_item(genv, &mut args, &defs, &mut mk_kind)?;
2364        Ok(List::from_vec(args))
2365    }
2366
2367    pub fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<GenericArgs> {
2368        Self::for_item(genv, def_id, |param, _| GenericArg::from_param_def(param))
2369    }
2370
2371    fn fill_item<F>(
2372        genv: GlobalEnv,
2373        args: &mut Vec<GenericArg>,
2374        generics: &Generics,
2375        mk_kind: &mut F,
2376    ) -> QueryResult<()>
2377    where
2378        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2379    {
2380        if let Some(def_id) = generics.parent {
2381            let parent_generics = genv.generics_of(def_id)?;
2382            Self::fill_item(genv, args, &parent_generics, mk_kind)?;
2383        }
2384        for param in &generics.own_params {
2385            let kind = mk_kind(param, args);
2386            tracked_span_assert_eq!(param.index as usize, args.len());
2387            args.push(kind);
2388        }
2389        Ok(())
2390    }
2391}
2392
2393impl From<TyOrBase> for GenericArg {
2394    fn from(v: TyOrBase) -> Self {
2395        match v {
2396            TyOrBase::Ty(ty) => GenericArg::Ty(ty),
2397            TyOrBase::Base(ctor) => GenericArg::Base(ctor),
2398        }
2399    }
2400}
2401
2402impl<'tcx> ToRustc<'tcx> for GenericArg {
2403    type T = rustc_middle::ty::GenericArg<'tcx>;
2404
2405    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2406        use rustc_middle::ty;
2407        match self {
2408            GenericArg::Ty(ty) => ty::GenericArg::from(ty.to_rustc(tcx)),
2409            GenericArg::Base(ctor) => ty::GenericArg::from(ctor.skip_binder_ref().to_rustc(tcx)),
2410            GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)),
2411            GenericArg::Const(c) => ty::GenericArg::from(c.to_rustc(tcx)),
2412        }
2413    }
2414}
2415
/// A list of (refined) generic arguments. See [`GenericArgsExt`] for helper methods.
pub type GenericArgs = List<GenericArg>;
2417
2418#[extension(pub trait GenericArgsExt)]
2419impl GenericArgs {
2420    #[track_caller]
2421    fn box_args(&self) -> (&Ty, &Ty) {
2422        if let [GenericArg::Ty(deref), GenericArg::Ty(alloc)] = &self[..] {
2423            (deref, alloc)
2424        } else {
2425            bug!("invalid generic arguments for box");
2426        }
2427    }
2428
2429    // We can't implement [`ToRustc`] because of coherence so we add it here
2430    fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::GenericArgsRef<'tcx> {
2431        tcx.mk_args_from_iter(self.iter().map(|arg| arg.to_rustc(tcx)))
2432    }
2433
2434    fn rebase_onto(
2435        &self,
2436        tcx: &TyCtxt,
2437        source_ancestor: DefId,
2438        target_args: &GenericArgs,
2439    ) -> List<GenericArg> {
2440        let defs = tcx.generics_of(source_ancestor);
2441        target_args
2442            .iter()
2443            .chain(self.iter().skip(defs.count()))
2444            .cloned()
2445            .collect()
2446    }
2447}
2448
/// Either a plain type or a subset type constructor; converts into a [`GenericArg::Ty`] or
/// [`GenericArg::Base`] respectively.
#[derive(Debug)]
pub enum TyOrBase {
    Ty(Ty),
    Base(SubsetTyCtor),
}
2454
2455impl TyOrBase {
2456    pub fn into_ty(self) -> Ty {
2457        match self {
2458            TyOrBase::Ty(ty) => ty,
2459            TyOrBase::Base(ctor) => ctor.to_ty(),
2460        }
2461    }
2462
2463    #[track_caller]
2464    pub fn expect_base(self) -> SubsetTyCtor {
2465        match self {
2466            TyOrBase::Base(ctor) => ctor,
2467            TyOrBase::Ty(_) => tracked_span_bug!("expected `TyOrBase::Base`"),
2468        }
2469    }
2470
2471    pub fn as_base(self) -> Option<SubsetTyCtor> {
2472        match self {
2473            TyOrBase::Base(ctor) => Some(ctor),
2474            TyOrBase::Ty(_) => None,
2475        }
2476    }
2477}
2478
/// Either a plain type or a general type constructor ([`TyCtor`]).
#[derive(Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum TyOrCtor {
    Ty(Ty),
    Ctor(TyCtor),
}
2484
2485impl TyOrCtor {
2486    #[track_caller]
2487    pub fn expect_ctor(self) -> TyCtor {
2488        match self {
2489            TyOrCtor::Ctor(ctor) => ctor,
2490            TyOrCtor::Ty(_) => tracked_span_bug!("expected `TyOrCtor::Ctor`"),
2491        }
2492    }
2493
2494    pub fn expect_subset_ty_ctor(self) -> SubsetTyCtor {
2495        self.expect_ctor().map(|ty| {
2496            if let canonicalize::CanonicalTy::Constr(constr_ty) = ty.shallow_canonicalize()
2497                && let TyKind::Indexed(bty, idx) = constr_ty.ty().kind()
2498                && idx.is_nu()
2499            {
2500                SubsetTy::new(bty.clone(), Expr::nu(), constr_ty.pred())
2501            } else {
2502                tracked_span_bug!()
2503            }
2504        })
2505    }
2506
2507    pub fn to_ty(&self) -> Ty {
2508        match self {
2509            TyOrCtor::Ctor(ctor) => ctor.to_ty(),
2510            TyOrCtor::Ty(ty) => ty.clone(),
2511        }
2512    }
2513}
2514
2515impl From<TyOrBase> for TyOrCtor {
2516    fn from(v: TyOrBase) -> Self {
2517        match v {
2518            TyOrBase::Ty(ty) => TyOrCtor::Ty(ty),
2519            TyOrBase::Base(ctor) => TyOrCtor::Ctor(ctor.to_ty_ctor()),
2520        }
2521    }
2522}
2523
2524impl CoroutineObligPredicate {
2525    pub fn to_poly_fn_sig(&self) -> PolyFnSig {
2526        let vars = vec![];
2527
2528        let resume_ty = &self.resume_ty;
2529        let env_ty = Ty::coroutine(self.def_id, resume_ty.clone(), self.upvar_tys.clone());
2530
2531        let inputs = List::from_arr([env_ty, resume_ty.clone()]);
2532        let output =
2533            Binder::bind_with_vars(FnOutput::new(self.output.clone(), vec![]), List::empty());
2534
2535        PolyFnSig::bind_with_vars(
2536            FnSig::new(
2537                Safety::Safe,
2538                rustc_abi::ExternAbi::RustCall,
2539                List::empty(),
2540                inputs,
2541                output,
2542                false,
2543            ),
2544            List::from(vars),
2545        )
2546    }
2547}
2548
2549impl RefinementGenerics {
2550    pub fn count(&self) -> usize {
2551        self.parent_count + self.own_params.len()
2552    }
2553
2554    pub fn own_count(&self) -> usize {
2555        self.own_params.len()
2556    }
2557}
2558
2559impl EarlyBinder<RefinementGenerics> {
2560    pub fn parent(&self) -> Option<DefId> {
2561        self.skip_binder_ref().parent
2562    }
2563
2564    pub fn parent_count(&self) -> usize {
2565        self.skip_binder_ref().parent_count
2566    }
2567
2568    pub fn count(&self) -> usize {
2569        self.skip_binder_ref().count()
2570    }
2571
2572    pub fn own_count(&self) -> usize {
2573        self.skip_binder_ref().own_count()
2574    }
2575
2576    pub fn own_param_at(&self, index: usize) -> EarlyBinder<RefineParam> {
2577        self.as_ref().map(|this| this.own_params[index].clone())
2578    }
2579
2580    pub fn param_at(
2581        &self,
2582        param_index: usize,
2583        genv: GlobalEnv,
2584    ) -> QueryResult<EarlyBinder<RefineParam>> {
2585        if let Some(index) = param_index.checked_sub(self.parent_count()) {
2586            Ok(self.own_param_at(index))
2587        } else {
2588            let parent = self.parent().expect("parent_count > 0 but no parent?");
2589            genv.refinement_generics_of(parent)?
2590                .param_at(param_index, genv)
2591        }
2592    }
2593
2594    pub fn iter_own_params(&self) -> impl Iterator<Item = EarlyBinder<RefineParam>> + use<'_> {
2595        self.skip_binder_ref()
2596            .own_params
2597            .iter()
2598            .cloned()
2599            .map(EarlyBinder)
2600    }
2601
2602    pub fn fill_item<F, R>(&self, genv: GlobalEnv, vec: &mut Vec<R>, mk: &mut F) -> QueryResult
2603    where
2604        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<R>,
2605    {
2606        if let Some(def_id) = self.parent() {
2607            genv.refinement_generics_of(def_id)?
2608                .fill_item(genv, vec, mk)?;
2609        }
2610        for param in self.iter_own_params() {
2611            vec.push(mk(param, vec.len())?);
2612        }
2613        Ok(())
2614    }
2615}
2616
2617impl EarlyBinder<GenericPredicates> {
2618    pub fn predicates(&self) -> EarlyBinder<List<Clause>> {
2619        EarlyBinder(self.0.predicates.clone())
2620    }
2621}
2622
2623impl EarlyBinder<FuncSort> {
2624    /// See [`subst::GenericsSubstForSort`]
2625    pub fn instantiate_func_sort<E>(
2626        self,
2627        sort_for_param: impl FnMut(ParamTy) -> Result<Sort, E>,
2628    ) -> Result<FuncSort, E> {
2629        self.0.try_fold_with(&mut subst::GenericsSubstFolder::new(
2630            subst::GenericsSubstForSort { sort_for_param },
2631            &[],
2632        ))
2633    }
2634}
2635
2636impl VariantSig {
2637    pub fn new(
2638        adt_def: AdtDef,
2639        args: GenericArgs,
2640        fields: List<Ty>,
2641        idx: Expr,
2642        requires: List<Expr>,
2643    ) -> Self {
2644        VariantSig { adt_def, args, fields, idx, requires }
2645    }
2646
2647    pub fn fields(&self) -> &[Ty] {
2648        &self.fields
2649    }
2650
2651    pub fn ret(&self) -> Ty {
2652        let bty = BaseTy::Adt(self.adt_def.clone(), self.args.clone());
2653        let idx = self.idx.clone();
2654        Ty::indexed(bty, idx)
2655    }
2656}
2657
2658impl FnSig {
2659    pub fn new(
2660        safety: Safety,
2661        abi: rustc_abi::ExternAbi,
2662        requires: List<Expr>,
2663        inputs: List<Ty>,
2664        output: Binder<FnOutput>,
2665        lifted: bool,
2666    ) -> Self {
2667        FnSig { safety, abi, requires, inputs, output, lifted }
2668    }
2669
2670    pub fn requires(&self) -> &[Expr] {
2671        &self.requires
2672    }
2673
2674    pub fn inputs(&self) -> &[Ty] {
2675        &self.inputs
2676    }
2677
2678    pub fn output(&self) -> Binder<FnOutput> {
2679        self.output.clone()
2680    }
2681}
2682
2683impl<'tcx> ToRustc<'tcx> for FnSig {
2684    type T = rustc_middle::ty::FnSig<'tcx>;
2685
2686    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2687        tcx.mk_fn_sig(
2688            self.inputs().iter().map(|ty| ty.to_rustc(tcx)),
2689            self.output().as_ref().skip_binder().to_rustc(tcx),
2690            false,
2691            self.safety,
2692            self.abi,
2693        )
2694    }
2695}
2696
2697impl FnOutput {
2698    pub fn new(ret: Ty, ensures: impl Into<List<Ensures>>) -> Self {
2699        Self { ret, ensures: ensures.into() }
2700    }
2701}
2702
2703impl<'tcx> ToRustc<'tcx> for FnOutput {
2704    type T = rustc_middle::ty::Ty<'tcx>;
2705
2706    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2707        self.ret.to_rustc(tcx)
2708    }
2709}
2710
2711impl AdtDef {
2712    pub fn new(
2713        rustc: ty::AdtDef,
2714        sort_def: AdtSortDef,
2715        invariants: Vec<Invariant>,
2716        opaque: bool,
2717    ) -> Self {
2718        AdtDef(Interned::new(AdtDefData { invariants, sort_def, opaque, rustc }))
2719    }
2720
2721    pub fn did(&self) -> DefId {
2722        self.0.rustc.did()
2723    }
2724
2725    pub fn sort_def(&self) -> &AdtSortDef {
2726        &self.0.sort_def
2727    }
2728
2729    pub fn sort(&self, args: &[GenericArg]) -> Sort {
2730        self.sort_def().to_sort(args)
2731    }
2732
2733    pub fn is_box(&self) -> bool {
2734        self.0.rustc.is_box()
2735    }
2736
2737    pub fn is_enum(&self) -> bool {
2738        self.0.rustc.is_enum()
2739    }
2740
2741    pub fn is_struct(&self) -> bool {
2742        self.0.rustc.is_struct()
2743    }
2744
2745    pub fn is_union(&self) -> bool {
2746        self.0.rustc.is_union()
2747    }
2748
2749    pub fn variants(&self) -> &IndexSlice<VariantIdx, VariantDef> {
2750        self.0.rustc.variants()
2751    }
2752
2753    pub fn variant(&self, idx: VariantIdx) -> &VariantDef {
2754        self.0.rustc.variant(idx)
2755    }
2756
2757    pub fn invariants(&self) -> EarlyBinder<&[Invariant]> {
2758        EarlyBinder(&self.0.invariants)
2759    }
2760
2761    pub fn discriminants(&self) -> impl Iterator<Item = (VariantIdx, u128)> + '_ {
2762        self.0.rustc.discriminants()
2763    }
2764
2765    pub fn is_opaque(&self) -> bool {
2766        self.0.opaque
2767    }
2768}
2769
2770impl EarlyBinder<PolyVariant> {
2771    // The field_idx is `Some(i)` when we have the `i`-th field of a `union`, in which case,
2772    // the `inputs` are _just_ the `i`-th type (and not all the types...)
2773    pub fn to_poly_fn_sig(&self, field_idx: Option<crate::FieldIdx>) -> EarlyBinder<PolyFnSig> {
2774        self.as_ref().map(|poly_variant| {
2775            poly_variant.as_ref().map(|variant| {
2776                let ret = variant.ret().shift_in_escaping(1);
2777                let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
2778                let inputs = match field_idx {
2779                    None => variant.fields.clone(),
2780                    Some(i) => List::singleton(variant.fields[i.index()].clone()),
2781                };
2782                FnSig::new(
2783                    Safety::Safe,
2784                    rustc_abi::ExternAbi::Rust,
2785                    variant.requires.clone(),
2786                    inputs,
2787                    output,
2788                    false,
2789                )
2790            })
2791        })
2792    }
2793}
2794
2795impl TyKind {
2796    fn intern(self) -> Ty {
2797        Ty(Interned::new(self))
2798    }
2799}
2800
2801/// returns the same invariants as for `usize` which is the length of a slice
2802fn slice_invariants(overflow_mode: OverflowMode) -> &'static [Invariant] {
2803    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2804        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2805    });
2806    static OVERFLOW: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2807        [
2808            Invariant {
2809                pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2810            },
2811            Invariant {
2812                pred: Binder::bind_with_sort(
2813                    Expr::le(Expr::nu(), Expr::uint_max(UintTy::Usize)),
2814                    Sort::Int,
2815                ),
2816            },
2817        ]
2818    });
2819    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2820        &*OVERFLOW
2821    } else {
2822        &*DEFAULT
2823    }
2824}
2825
2826fn uint_invariants(uint_ty: UintTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2827    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2828        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2829    });
2830
2831    static OVERFLOW: LazyLock<UnordMap<UintTy, [Invariant; 2]>> = LazyLock::new(|| {
2832        UINT_TYS
2833            .into_iter()
2834            .map(|uint_ty| {
2835                let invariants = [
2836                    Invariant {
2837                        pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2838                    },
2839                    Invariant {
2840                        pred: Binder::bind_with_sort(
2841                            Expr::le(Expr::nu(), Expr::uint_max(uint_ty)),
2842                            Sort::Int,
2843                        ),
2844                    },
2845                ];
2846                (uint_ty, invariants)
2847            })
2848            .collect()
2849    });
2850    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2851        &OVERFLOW[&uint_ty]
2852    } else {
2853        &*DEFAULT
2854    }
2855}
2856
2857fn char_invariants() -> &'static [Invariant] {
2858    static INVARIANTS: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2859        [
2860            Invariant {
2861                pred: Binder::bind_with_sort(
2862                    Expr::le(
2863                        Expr::cast(Sort::Char, Sort::Int, Expr::nu()),
2864                        Expr::constant((char::MAX as u32).into()),
2865                    ),
2866                    Sort::Int,
2867                ),
2868            },
2869            Invariant {
2870                pred: Binder::bind_with_sort(
2871                    Expr::le(Expr::zero(), Expr::cast(Sort::Char, Sort::Int, Expr::nu())),
2872                    Sort::Int,
2873                ),
2874            },
2875        ]
2876    });
2877    &*INVARIANTS
2878}
2879
2880fn int_invariants(int_ty: IntTy, overflow_mode: OverflowMode) -> &'static [Invariant] {
2881    static DEFAULT: [Invariant; 0] = [];
2882
2883    static OVERFLOW: LazyLock<UnordMap<IntTy, [Invariant; 2]>> = LazyLock::new(|| {
2884        INT_TYS
2885            .into_iter()
2886            .map(|int_ty| {
2887                let invariants = [
2888                    Invariant {
2889                        pred: Binder::bind_with_sort(
2890                            Expr::ge(Expr::nu(), Expr::int_min(int_ty)),
2891                            Sort::Int,
2892                        ),
2893                    },
2894                    Invariant {
2895                        pred: Binder::bind_with_sort(
2896                            Expr::le(Expr::nu(), Expr::int_max(int_ty)),
2897                            Sort::Int,
2898                        ),
2899                    },
2900                ];
2901                (int_ty, invariants)
2902            })
2903            .collect()
2904    });
2905    if matches!(overflow_mode, OverflowMode::Strict | OverflowMode::Lazy) {
2906        &OVERFLOW[&int_ty]
2907    } else {
2908        &DEFAULT
2909    }
2910}
2911
2912impl_internable!(AdtDefData, AdtSortDefData, TyKind);
2913impl_slice_internable!(
2914    Ty,
2915    GenericArg,
2916    Ensures,
2917    InferMode,
2918    Sort,
2919    SortArg,
2920    GenericParamDef,
2921    TraitRef,
2922    Binder<ExistentialPredicate>,
2923    Clause,
2924    PolyVariant,
2925    Invariant,
2926    RefineParam,
2927    FluxDefId,
2928    SortParamKind,
2929    AssocReft
2930);
2931
/// Pattern for an indexed signed integer type: `Int!(int_ty, idx)` matches
/// `TyKind::Indexed(BaseTy::Int(int_ty), idx)`.
///
/// Paths are `$crate`-qualified so the pattern works even where `TyKind`/`BaseTy` are not
/// imported at the use site.
#[macro_export]
macro_rules! _Int {
    ($int_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Int($int_ty), $idxs)
    };
}
pub use crate::_Int as Int;

/// Pattern for an indexed unsigned integer type: `Uint!(uint_ty, idx)` matches
/// `TyKind::Indexed(BaseTy::Uint(uint_ty), idx)`.
#[macro_export]
macro_rules! _Uint {
    ($uint_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Uint($uint_ty), $idxs)
    };
}
pub use crate::_Uint as Uint;

/// Pattern for an indexed `bool`: `Bool!(idx)` matches `TyKind::Indexed(BaseTy::Bool, idx)`.
#[macro_export]
macro_rules! _Bool {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Bool, $idxs)
    };
}
pub use crate::_Bool as Bool;

/// Pattern for an indexed `char`: `Char!(idx)` matches `TyKind::Indexed(BaseTy::Char, idx)`.
#[macro_export]
macro_rules! _Char {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Char, $idxs)
    };
}
pub use crate::_Char as Char;

/// Pattern for a reference type: `Ref!(pats…)` matches `TyKind::Indexed(BaseTy::Ref(pats…), _)`,
/// ignoring the index.
#[macro_export]
macro_rules! _Ref {
    ($($pats:pat),+ $(,)?) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Ref($($pats),+), _)
    };
}
pub use crate::_Ref as Ref;
2971
2972pub struct WfckResults {
2973    pub owner: FluxOwnerId,
2974    param_sorts: UnordMap<fhir::ParamId, Sort>,
2975    bin_op_sorts: ItemLocalMap<Sort>,
2976    fn_app_sorts: ItemLocalMap<List<SortArg>>,
2977    coercions: ItemLocalMap<Vec<Coercion>>,
2978    field_projs: ItemLocalMap<FieldProj>,
2979    node_sorts: ItemLocalMap<Sort>,
2980    record_ctors: ItemLocalMap<DefId>,
2981}
2982
2983#[derive(Clone, Copy, Debug)]
2984pub enum Coercion {
2985    Inject(DefId),
2986    Project(DefId),
2987}
2988
2989pub type ItemLocalMap<T> = UnordMap<fhir::ItemLocalId, T>;
2990
2991#[derive(Debug)]
2992pub struct LocalTableInContext<'a, T> {
2993    owner: FluxOwnerId,
2994    data: &'a ItemLocalMap<T>,
2995}
2996
2997pub struct LocalTableInContextMut<'a, T> {
2998    owner: FluxOwnerId,
2999    data: &'a mut ItemLocalMap<T>,
3000}
3001
3002impl WfckResults {
3003    pub fn new(owner: impl Into<FluxOwnerId>) -> Self {
3004        Self {
3005            owner: owner.into(),
3006            param_sorts: UnordMap::default(),
3007            bin_op_sorts: ItemLocalMap::default(),
3008            coercions: ItemLocalMap::default(),
3009            field_projs: ItemLocalMap::default(),
3010            node_sorts: ItemLocalMap::default(),
3011            record_ctors: ItemLocalMap::default(),
3012            fn_app_sorts: ItemLocalMap::default(),
3013        }
3014    }
3015
3016    pub fn param_sorts_mut(&mut self) -> &mut UnordMap<fhir::ParamId, Sort> {
3017        &mut self.param_sorts
3018    }
3019
3020    pub fn param_sorts(&self) -> &UnordMap<fhir::ParamId, Sort> {
3021        &self.param_sorts
3022    }
3023
3024    pub fn bin_op_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3025        LocalTableInContextMut { owner: self.owner, data: &mut self.bin_op_sorts }
3026    }
3027
3028    pub fn fn_app_sorts_mut(&mut self) -> LocalTableInContextMut<'_, List<SortArg>> {
3029        LocalTableInContextMut { owner: self.owner, data: &mut self.fn_app_sorts }
3030    }
3031
3032    pub fn fn_app_sorts(&self) -> LocalTableInContext<'_, List<SortArg>> {
3033        LocalTableInContext { owner: self.owner, data: &self.fn_app_sorts }
3034    }
3035
3036    pub fn bin_op_sorts(&self) -> LocalTableInContext<'_, Sort> {
3037        LocalTableInContext { owner: self.owner, data: &self.bin_op_sorts }
3038    }
3039
3040    pub fn coercions_mut(&mut self) -> LocalTableInContextMut<'_, Vec<Coercion>> {
3041        LocalTableInContextMut { owner: self.owner, data: &mut self.coercions }
3042    }
3043
3044    pub fn coercions(&self) -> LocalTableInContext<'_, Vec<Coercion>> {
3045        LocalTableInContext { owner: self.owner, data: &self.coercions }
3046    }
3047
3048    pub fn field_projs_mut(&mut self) -> LocalTableInContextMut<'_, FieldProj> {
3049        LocalTableInContextMut { owner: self.owner, data: &mut self.field_projs }
3050    }
3051
3052    pub fn field_projs(&self) -> LocalTableInContext<'_, FieldProj> {
3053        LocalTableInContext { owner: self.owner, data: &self.field_projs }
3054    }
3055
3056    pub fn node_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
3057        LocalTableInContextMut { owner: self.owner, data: &mut self.node_sorts }
3058    }
3059
3060    pub fn node_sorts(&self) -> LocalTableInContext<'_, Sort> {
3061        LocalTableInContext { owner: self.owner, data: &self.node_sorts }
3062    }
3063
3064    pub fn record_ctors_mut(&mut self) -> LocalTableInContextMut<'_, DefId> {
3065        LocalTableInContextMut { owner: self.owner, data: &mut self.record_ctors }
3066    }
3067
3068    pub fn record_ctors(&self) -> LocalTableInContext<'_, DefId> {
3069        LocalTableInContext { owner: self.owner, data: &self.record_ctors }
3070    }
3071}
3072
3073impl<T> LocalTableInContextMut<'_, T> {
3074    pub fn insert(&mut self, fhir_id: FhirId, value: T) {
3075        tracked_span_assert_eq!(self.owner, fhir_id.owner);
3076        self.data.insert(fhir_id.local_id, value);
3077    }
3078}
3079
3080impl<'a, T> LocalTableInContext<'a, T> {
3081    pub fn get(&self, fhir_id: FhirId) -> Option<&'a T> {
3082        tracked_span_assert_eq!(self.owner, fhir_id.owner);
3083        self.data.get(&fhir_id.local_id)
3084    }
3085}
3086
3087fn can_auto_strong(fn_sig: &PolyFnSig) -> bool {
3088    struct RegionDetector {
3089        has_region: bool,
3090    }
3091
3092    impl fold::TypeFolder for RegionDetector {
3093        fn fold_region(&mut self, re: &Region) -> Region {
3094            self.has_region = true;
3095            *re
3096        }
3097    }
3098    let mut detector = RegionDetector { has_region: false };
3099    fn_sig
3100        .skip_binder_ref()
3101        .output()
3102        .skip_binder_ref()
3103        .ret
3104        .fold_with(&mut detector);
3105
3106    !detector.has_region
3107}
3108/// The [`auto_strong`] function transforms function signatures by automatically converting
3109/// mutable reference parameters into strong references with associated ensures clauses. This
3110/// transformation is applied only when the function signature does not already contain region
3111/// variables in its return type.
3112///
3113/// Specifically, given a source function of type
3114///
3115///    fn (x: &mut InnerTy) -> bool
3116///
3117/// By default the above gives us an `rty::FnSig`
3118///
3119///    forall<>. fn (x: &mut InnerTy) -> bool
3120///
3121/// Which this function then transforms to
3122///
3123///     forall<l0: Loc>. fn (x: &strg<l0:InnerTy>) -> bool ensures l0:InnerTy
3124pub fn auto_strong(
3125    genv: GlobalEnv,
3126    def_id: impl IntoQueryParam<DefId>,
3127    fn_sig: PolyFnSig,
3128) -> PolyFnSig {
3129    // TODO(auto-strong): we only *really* need the first check `can_auto_strong` here.
3130    // The other two skip `auto-strong` as doing it breaks various downstream things
3131    // that should be fixed.
3132    if !can_auto_strong(&fn_sig)
3133        || matches!(genv.def_kind(def_id), rustc_hir::def::DefKind::Closure)
3134        || !fn_sig.skip_binder_ref().lifted
3135    {
3136        return fn_sig;
3137    }
3138    let kind = BoundReftKind::Anon;
3139    let mut vars = fn_sig.vars().to_vec();
3140    let fn_sig = fn_sig.skip_binder();
3141    // new list of (bound_var, inner_ty)
3142    let mut strg_bvars = vec![];
3143    // new list of input types
3144    let mut strg_inputs = vec![];
3145    // 1. Traverse inputs collecting strong locations
3146    for ty in &fn_sig.inputs {
3147        let strg_ty = if let TyKind::Indexed(BaseTy::Ref(re, inner_ty, Mutability::Mut), _) =
3148            ty.kind()
3149            && !inner_ty.is_slice()
3150        // TODO(auto-strong): including `slice` breaks `tock` for some reason we should replicate in our own tests...
3151        {
3152            // if input is &mut InnerTy create a new bound var `loc` for the strong location
3153            let var = {
3154                let idx = vars.len() + strg_bvars.len();
3155                BoundVar::from_usize(idx)
3156            };
3157            strg_bvars.push((var, inner_ty.clone()));
3158            let loc = Loc::Var(Var::Bound(INNERMOST, BoundReft { var, kind }));
3159            // and transform to &strg<loc:InnerTy>
3160            Ty::strg_ref(*re, Path::new(loc, List::empty()), inner_ty.clone())
3161        } else {
3162            // else leave input type unchanged
3163            ty.clone()
3164        };
3165        strg_inputs.push(strg_ty);
3166    }
3167    // 2. Add bound vars for strong locations
3168    for _ in 0..strg_bvars.len() {
3169        vars.push(BoundVariableKind::Refine(Sort::Loc, InferMode::EVar, kind));
3170    }
3171    // 3. Add ensures for strong locations
3172    let output = fn_sig.output.map(|out| {
3173        let mut ens = out.ensures.to_vec();
3174        for (var, inner_ty) in strg_bvars {
3175            let loc = Loc::Var(Var::Bound(INNERMOST.shifted_in(1), BoundReft { var, kind }));
3176            let path = Path::new(loc, List::empty());
3177            ens.push(Ensures::Type(path, inner_ty.shift_in_escaping(1)));
3178        }
3179        FnOutput { ensures: List::from_vec(ens), ..out }
3180    });
3181
3182    // 4. Reconstruct fn sig with new inputs and output and vars
3183    let fn_sig = FnSig { inputs: List::from_vec(strg_inputs), output, ..fn_sig };
3184    Binder::bind_with_vars(fn_sig, vars.into())
3185}