flux_middle/rty/
mod.rs

1//! Defines how flux represents refinement types internally. Definitions in this module are used
2//! during refinement type checking. A couple of important differences between definitions in this
3//! module and in [`crate::fhir`] are:
4//!
5//! * Types in this module use debruijn indices to represent local binders.
6//! * Data structures are interned so they can be cheaply cloned.
7mod binder;
8pub mod canonicalize;
9mod expr;
10pub mod fold;
11pub mod normalize;
12mod pretty;
13pub mod refining;
14pub mod region_matching;
15pub mod subst;
16use std::{borrow::Cow, cmp::Ordering, fmt, hash::Hash, sync::LazyLock};
17
18pub use SortInfer::*;
19pub use binder::{Binder, BoundReftKind, BoundVariableKind, BoundVariableKinds, EarlyBinder};
20pub use expr::{
21    AggregateKind, AliasReft, BinOp, BoundReft, Constant, Ctor, ESpan, EVid, EarlyReftParam, Expr,
22    ExprKind, FieldProj, HoleKind, InternalFuncKind, KVar, KVid, Lambda, Loc, Name, Path, Real,
23    SpecFuncKind, UnOp, Var,
24};
25pub use flux_arc_interner::List;
26use flux_arc_interner::{Interned, impl_internable, impl_slice_internable};
27use flux_common::{bug, tracked_span_assert_eq, tracked_span_bug};
28use flux_macros::{TypeFoldable, TypeVisitable};
29pub use flux_rustc_bridge::ty::{
30    AliasKind, BoundRegion, BoundRegionKind, BoundVar, Const, ConstKind, ConstVid, DebruijnIndex,
31    EarlyParamRegion, LateParamRegion, LateParamRegionKind,
32    Region::{self, *},
33    RegionVid,
34};
35use flux_rustc_bridge::{
36    ToRustc,
37    mir::{Place, RawPtrKind},
38    ty::{self, GenericArgsExt as _, VariantDef},
39};
40use itertools::Itertools;
41pub use normalize::{NormalizeInfo, NormalizedDefns, local_deps};
42use refining::{Refine as _, Refiner};
43use rustc_abi;
44pub use rustc_abi::{FIRST_VARIANT, VariantIdx};
45use rustc_data_structures::{fx::FxIndexMap, snapshot_map::SnapshotMap, unord::UnordMap};
46use rustc_hir::{LangItem, Safety, def_id::DefId};
47use rustc_index::{IndexSlice, IndexVec, newtype_index};
48use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable, extension};
49use rustc_middle::ty::TyCtxt;
50pub use rustc_middle::{
51    mir::Mutability,
52    ty::{AdtFlags, ClosureKind, FloatTy, IntTy, ParamConst, ParamTy, ScalarInt, UintTy},
53};
54use rustc_span::{DUMMY_SP, Symbol, sym, symbol::kw};
55use rustc_type_ir::Upcast as _;
56pub use rustc_type_ir::{INNERMOST, TyVid};
57
58use self::fold::TypeFoldable;
59pub use crate::fhir::InferMode;
60use crate::{
61    LocalDefId,
62    def_id::{FluxDefId, FluxLocalDefId},
63    fhir::{self, FhirId, FluxOwnerId},
64    global_env::GlobalEnv,
65    pretty::{Pretty, PrettyCx},
66    queries::{QueryErr, QueryResult},
67    rty::subst::SortSubst,
68};
69
/// The definition of the data sort automatically generated for a struct or enum.
///
/// This is a handle around an interned [`AdtSortDefData`]; the data is only reachable through the
/// accessor methods defined on this type.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtSortDef(Interned<AdtSortDefData>);
73
/// The fields of a single variant of an [`AdtSortDef`]: their names and their sorts, stored as
/// two parallel sequences.
#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct AdtSortVariant {
    /// The list of field names as declared in the `#[flux::refined_by(...)]` annotation
    field_names: Vec<Symbol>,
    /// The sort of each of the fields. Note that these can contain [sort variables]. Methods used
    /// to access these sorts guarantee they are properly instantiated.
    ///
    /// [sort variables]: Sort::Var
    sorts: List<Sort>,
}
84
85impl AdtSortVariant {
86    pub fn new(fields: Vec<(Symbol, Sort)>) -> Self {
87        let (field_names, sorts) = fields.into_iter().unzip();
88        AdtSortVariant { field_names, sorts: List::from_vec(sorts) }
89    }
90
91    pub fn fields(&self) -> usize {
92        self.sorts.len()
93    }
94
95    pub fn field_names(&self) -> &Vec<Symbol> {
96        &self.field_names
97    }
98
99    pub fn sort_by_field_name(&self, args: &[Sort]) -> FxIndexMap<Symbol, Sort> {
100        std::iter::zip(&self.field_names, &self.sorts.fold_with(&mut SortSubst::new(args)))
101            .map(|(name, sort)| (*name, sort.clone()))
102            .collect()
103    }
104
105    pub fn field_by_name(
106        &self,
107        def_id: DefId,
108        args: &[Sort],
109        name: Symbol,
110    ) -> Option<(FieldProj, Sort)> {
111        let idx = self.field_names.iter().position(|it| name == *it)?;
112        let proj = FieldProj::Adt { def_id, field: idx as u32 };
113        let sort = self.sorts[idx].fold_with(&mut SortSubst::new(args));
114        Some((proj, sort))
115    }
116
117    pub fn field_sorts(&self, args: &[Sort]) -> List<Sort> {
118        self.sorts.fold_with(&mut SortSubst::new(args))
119    }
120
121    pub fn projections(&self, def_id: DefId) -> impl Iterator<Item = FieldProj> {
122        (0..self.fields()).map(move |i| FieldProj::Adt { def_id, field: i as u32 })
123    }
124}
125
/// The interned payload behind an [`AdtSortDef`].
#[derive(Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
struct AdtSortDefData {
    /// [`DefId`] of the struct or enum this data sort is associated to.
    def_id: DefId,
    /// The list of the type parameters used in the `#[flux::refined_by(..)]` annotation.
    ///
    /// See [`fhir::RefinedBy::sort_params`] for more details. This is a version of that but using
    /// [`ParamTy`] instead of [`DefId`].
    ///
    /// The length of this list corresponds to the number of sort variables bound by this definition.
    params: Vec<ParamTy>,
    /// The variants of the ADT:
    /// - a `struct` sort -- used for types with a `refined_by` -- has a single variant;
    /// - a `reflected` sort -- used for `reflected` enums -- has multiple variants.
    variants: IndexVec<VariantIdx, AdtSortVariant>,
    /// Whether this is a `reflected` sort (see `variants`).
    is_reflected: bool,
    /// Whether this is a `struct` sort (see `variants`). [`AdtSortDef::struct_variant`] asserts
    /// this flag before returning the first variant.
    is_struct: bool,
}
144
145impl AdtSortDef {
146    pub fn new(
147        def_id: DefId,
148        params: Vec<ParamTy>,
149        variants: IndexVec<VariantIdx, AdtSortVariant>,
150        is_reflected: bool,
151        is_struct: bool,
152    ) -> Self {
153        Self(Interned::new(AdtSortDefData { def_id, params, variants, is_reflected, is_struct }))
154    }
155
156    pub fn did(&self) -> DefId {
157        self.0.def_id
158    }
159
160    pub fn variant(&self, idx: VariantIdx) -> &AdtSortVariant {
161        &self.0.variants[idx]
162    }
163
164    pub fn variants(&self) -> &IndexSlice<VariantIdx, AdtSortVariant> {
165        &self.0.variants
166    }
167
168    pub fn opt_struct_variant(&self) -> Option<&AdtSortVariant> {
169        if self.is_struct() { Some(self.struct_variant()) } else { None }
170    }
171
172    #[track_caller]
173    pub fn struct_variant(&self) -> &AdtSortVariant {
174        tracked_span_assert_eq!(self.0.is_struct, true);
175        &self.0.variants[FIRST_VARIANT]
176    }
177
178    pub fn is_reflected(&self) -> bool {
179        self.0.is_reflected
180    }
181
182    pub fn is_struct(&self) -> bool {
183        self.0.is_struct
184    }
185
186    pub fn to_sort(&self, args: &[GenericArg]) -> Sort {
187        let sorts = self
188            .filter_generic_args(args)
189            .map(|arg| arg.expect_base().sort())
190            .collect();
191
192        Sort::App(SortCtor::Adt(self.clone()), sorts)
193    }
194
195    /// Given a list of generic args, returns an iterator of the generic arguments that should be
196    /// mapped to sorts for instantiation.
197    pub fn filter_generic_args<'a, A>(&'a self, args: &'a [A]) -> impl Iterator<Item = &'a A> + 'a {
198        self.0.params.iter().map(|p| &args[p.index as usize])
199    }
200
201    pub fn identity_args(&self) -> List<Sort> {
202        (0..self.0.params.len())
203            .map(|i| Sort::Var(ParamSort::from(i)))
204            .collect()
205    }
206
207    /// Gives the number of sort variables bound by this definition.
208    pub fn param_count(&self) -> usize {
209        self.0.params.len()
210    }
211}
212
/// The generic parameters of an item. Parameters are numbered with the parent's parameters
/// first (indices `0..parent_count`), followed by `own_params`.
#[derive(Debug, Clone, Default, Encodable, Decodable)]
pub struct Generics {
    /// The item whose generics provide the first `parent_count` parameters, if any.
    pub parent: Option<DefId>,
    /// Number of parameters that come from `parent` (and its ancestors).
    pub parent_count: usize,
    /// Parameters that belong to this item itself.
    pub own_params: List<GenericParamDef>,
    // NOTE(review): presumably mirrors rustc's `Generics::has_self` -- confirm.
    pub has_self: bool,
}
220
221impl Generics {
222    pub fn count(&self) -> usize {
223        self.parent_count + self.own_params.len()
224    }
225
226    pub fn own_default_count(&self) -> usize {
227        self.own_params
228            .iter()
229            .filter(|param| {
230                match param.kind {
231                    GenericParamDefKind::Type { has_default }
232                    | GenericParamDefKind::Const { has_default }
233                    | GenericParamDefKind::Base { has_default } => has_default,
234                    GenericParamDefKind::Lifetime => false,
235                }
236            })
237            .count()
238    }
239
240    pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult<GenericParamDef> {
241        if let Some(index) = param_index.checked_sub(self.parent_count) {
242            Ok(self.own_params[index].clone())
243        } else {
244            let parent = self.parent.expect("parent_count > 0 but no parent?");
245            genv.generics_of(parent)?.param_at(param_index, genv)
246        }
247    }
248
249    pub fn const_params(&self, genv: GlobalEnv) -> QueryResult<Vec<(ParamConst, Sort)>> {
250        // FIXME(nilehmann) this shouldn't use the methods in `flux_middle::sort_of` to get the sort
251        let mut res = vec![];
252        for i in 0..self.count() {
253            let param = self.param_at(i, genv)?;
254            if let GenericParamDefKind::Const { .. } = param.kind
255                && let Some(sort) = genv.sort_of_generic_param(param.def_id)?
256            {
257                let param_const = ParamConst { name: param.name, index: param.index };
258                res.push((param_const, sort));
259            }
260        }
261        Ok(res)
262    }
263}
264
/// Refinement-parameter counterpart of [`Generics`]: the same `parent`/`parent_count` layout,
/// with [`RefineParam`]s as own parameters.
#[derive(Debug, Clone, TyEncodable, TyDecodable)]
pub struct RefinementGenerics {
    pub parent: Option<DefId>,
    pub parent_count: usize,
    pub own_params: List<RefineParam>,
}
271
/// A single refinement parameter: its sort, its name and its [`InferMode`].
#[derive(
    PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct RefineParam {
    pub sort: Sort,
    pub name: Symbol,
    pub mode: InferMode,
}
280
/// The definition of one generic parameter stored in [`Generics`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
pub struct GenericParamDef {
    /// What kind of parameter this is (type, base, lifetime or const).
    pub kind: GenericParamDefKind,
    pub def_id: DefId,
    /// Index of the parameter; reused as the index of the [`ParamConst`] built in
    /// [`Generics::const_params`].
    pub index: u32,
    pub name: Symbol,
}
288
/// The kind of a [`GenericParamDef`]. Every kind except `Lifetime` records whether the parameter
/// has a default.
// NOTE(review): `Base` presumably marks a type parameter restricted to base types -- confirm.
#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
pub enum GenericParamDefKind {
    Type { has_default: bool },
    Base { has_default: bool },
    Lifetime,
    Const { has_default: bool },
}
296
/// The `Self` type parameter: index `0`, named with the `Self` keyword.
pub const SELF_PARAM_TY: ParamTy = ParamTy { index: 0, name: kw::SelfUpper };
298
/// The list of [`Clause`]s attached to an item, plus the optional parent item.
#[derive(Debug, Clone, TyEncodable, TyDecodable)]
pub struct GenericPredicates {
    pub parent: Option<DefId>,
    pub predicates: List<Clause>,
}
304
/// A [`ClauseKind`] under a [`Binder`].
///
/// The field is private: build a clause with [`Clause::new`] or from a `Binder<ClauseKind>`, and
/// read it back with [`Clause::kind`].
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct Clause {
    kind: Binder<ClauseKind>,
}
311
312impl Clause {
313    pub fn new(vars: impl Into<List<BoundVariableKind>>, kind: ClauseKind) -> Self {
314        Clause { kind: Binder::bind_with_vars(kind, vars.into()) }
315    }
316
317    pub fn kind(&self) -> Binder<ClauseKind> {
318        self.kind.clone()
319    }
320
321    fn as_trait_clause(&self) -> Option<Binder<TraitPredicate>> {
322        let clause = self.kind();
323        if let ClauseKind::Trait(trait_clause) = clause.skip_binder_ref() {
324            Some(clause.rebind(trait_clause.clone()))
325        } else {
326            None
327        }
328    }
329
330    pub fn as_projection_clause(&self) -> Option<Binder<ProjectionPredicate>> {
331        let clause = self.kind();
332        if let ClauseKind::Projection(proj_clause) = clause.skip_binder_ref() {
333            Some(clause.rebind(proj_clause.clone()))
334        } else {
335            None
336        }
337    }
338
339    // FIXME(nilehmann) we should deal with the binder in all the places this is used instead of
340    // blindly skipping it here
341    pub fn kind_skipping_binder(&self) -> ClauseKind {
342        self.kind.clone().skip_binder()
343    }
344
345    /// Group `Fn` trait clauses with their corresponding `FnOnce::Output` projection
346    /// predicate. This assumes there's exactly one corresponding projection predicate and will
347    /// crash otherwise.
348    pub fn split_off_fn_trait_clauses(
349        genv: GlobalEnv,
350        clauses: &Clauses,
351    ) -> (Vec<Clause>, Vec<Binder<FnTraitPredicate>>) {
352        let mut fn_trait_clauses = vec![];
353        let mut fn_trait_output_clauses = vec![];
354        let mut rest = vec![];
355        for clause in clauses {
356            if let Some(trait_clause) = clause.as_trait_clause()
357                && let Some(kind) = genv.tcx().fn_trait_kind_from_def_id(trait_clause.def_id())
358            {
359                fn_trait_clauses.push((kind, trait_clause));
360            } else if let Some(proj_clause) = clause.as_projection_clause()
361                && genv.is_fn_output(proj_clause.projection_def_id())
362            {
363                fn_trait_output_clauses.push(proj_clause);
364            } else {
365                rest.push(clause.clone());
366            }
367        }
368        let fn_trait_clauses = fn_trait_clauses
369            .into_iter()
370            .map(|(kind, fn_trait_clause)| {
371                let mut candidates = vec![];
372                for fn_trait_output_clause in &fn_trait_output_clauses {
373                    if fn_trait_output_clause.self_ty() == fn_trait_clause.self_ty() {
374                        candidates.push(fn_trait_output_clause.clone());
375                    }
376                }
377                tracked_span_assert_eq!(candidates.len(), 1);
378                let proj_pred = candidates.pop().unwrap().skip_binder();
379                fn_trait_clause.map(|fn_trait_clause| {
380                    FnTraitPredicate {
381                        kind,
382                        self_ty: fn_trait_clause.self_ty().to_ty(),
383                        tupled_args: fn_trait_clause.trait_ref.args[1].expect_base().to_ty(),
384                        output: proj_pred.term.to_ty(),
385                    }
386                })
387            })
388            .collect_vec();
389        (rest, fn_trait_clauses)
390    }
391}
392
393impl<'tcx> ToRustc<'tcx> for Clause {
394    type T = rustc_middle::ty::Clause<'tcx>;
395
396    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
397        self.kind.to_rustc(tcx).upcast(tcx)
398    }
399}
400
401impl From<Binder<ClauseKind>> for Clause {
402    fn from(kind: Binder<ClauseKind>) -> Self {
403        Clause { kind }
404    }
405}
406
/// An interned list of [`Clause`]s.
pub type Clauses = List<Clause>;
408
/// The different kinds of clauses. Each variant is converted by the `ToRustc` impl to the
/// variant of `rustc_middle::ty::ClauseKind` with the same name.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub enum ClauseKind {
    Trait(TraitPredicate),
    Projection(ProjectionPredicate),
    RegionOutlives(RegionOutlivesPredicate),
    TypeOutlives(TypeOutlivesPredicate),
    ConstArgHasType(Const, Ty),
}
419
420impl<'tcx> ToRustc<'tcx> for ClauseKind {
421    type T = rustc_middle::ty::ClauseKind<'tcx>;
422
423    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
424        match self {
425            ClauseKind::Trait(trait_predicate) => {
426                rustc_middle::ty::ClauseKind::Trait(trait_predicate.to_rustc(tcx))
427            }
428            ClauseKind::Projection(projection_predicate) => {
429                rustc_middle::ty::ClauseKind::Projection(projection_predicate.to_rustc(tcx))
430            }
431            ClauseKind::RegionOutlives(outlives_predicate) => {
432                rustc_middle::ty::ClauseKind::RegionOutlives(outlives_predicate.to_rustc(tcx))
433            }
434            ClauseKind::TypeOutlives(outlives_predicate) => {
435                rustc_middle::ty::ClauseKind::TypeOutlives(outlives_predicate.to_rustc(tcx))
436            }
437            ClauseKind::ConstArgHasType(constant, ty) => {
438                rustc_middle::ty::ClauseKind::ConstArgHasType(
439                    constant.to_rustc(tcx),
440                    ty.to_rustc(tcx),
441                )
442            }
443        }
444    }
445}
446
/// A pair of a subject (`.0`) and a [`Region`] (`.1`). It is converted field by field into
/// `rustc_middle::ty::OutlivesPredicate` by the `ToRustc` impl.
// NOTE(review): presumably reads as "`.0` outlives `.1`", as in rustc -- confirm.
#[derive(Eq, PartialEq, Hash, Clone, Debug, TyEncodable, TyDecodable)]
pub struct OutlivesPredicate<T>(pub T, pub Region);

/// An [`OutlivesPredicate`] whose subject is a type.
pub type TypeOutlivesPredicate = OutlivesPredicate<Ty>;
/// An [`OutlivesPredicate`] whose subject is a region.
pub type RegionOutlivesPredicate = OutlivesPredicate<Region>;
452
453impl<'tcx, V: ToRustc<'tcx>> ToRustc<'tcx> for OutlivesPredicate<V> {
454    type T = rustc_middle::ty::OutlivesPredicate<'tcx, V::T>;
455
456    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
457        rustc_middle::ty::OutlivesPredicate(self.0.to_rustc(tcx), self.1.to_rustc(tcx))
458    }
459}
460
/// A predicate stating that a [`TraitRef`] holds. It is always converted to a rustc trait
/// predicate with positive polarity.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct TraitPredicate {
    pub trait_ref: TraitRef,
}
467
468impl TraitPredicate {
469    fn self_ty(&self) -> SubsetTyCtor {
470        self.trait_ref.self_ty()
471    }
472}
473
474impl<'tcx> ToRustc<'tcx> for TraitPredicate {
475    type T = rustc_middle::ty::TraitPredicate<'tcx>;
476
477    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
478        rustc_middle::ty::TraitPredicate {
479            polarity: rustc_middle::ty::PredicatePolarity::Positive,
480            trait_ref: self.trait_ref.to_rustc(tcx),
481        }
482    }
483}
484
485pub type PolyTraitPredicate = Binder<TraitPredicate>;
486
487impl PolyTraitPredicate {
488    fn def_id(&self) -> DefId {
489        self.skip_binder_ref().trait_ref.def_id
490    }
491
492    fn self_ty(&self) -> Binder<SubsetTyCtor> {
493        self.clone().map(|predicate| predicate.self_ty())
494    }
495}
496
/// A reference to a trait: the trait's [`DefId`] and the generic arguments it is applied to.
/// The first argument (`args[0]`) is the self type, see [`TraitRef::self_ty`].
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct TraitRef {
    pub def_id: DefId,
    pub args: GenericArgs,
}
504
505impl TraitRef {
506    pub fn self_ty(&self) -> SubsetTyCtor {
507        self.args[0].expect_base().clone()
508    }
509}
510
511impl<'tcx> ToRustc<'tcx> for TraitRef {
512    type T = rustc_middle::ty::TraitRef<'tcx>;
513
514    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
515        rustc_middle::ty::TraitRef::new(tcx, self.def_id, self.args.to_rustc(tcx))
516    }
517}
518
519pub type PolyTraitRef = Binder<TraitRef>;
520
521impl PolyTraitRef {
522    pub fn def_id(&self) -> DefId {
523        self.as_ref().skip_binder().def_id
524    }
525}
526
/// Counterpart of `rustc_middle::ty::ExistentialPredicate`; each variant is converted to the
/// rustc variant of the same name by the `ToRustc` impl.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum ExistentialPredicate {
    Trait(ExistentialTraitRef),
    Projection(ExistentialProjection),
    AutoTrait(DefId),
}

/// An [`ExistentialPredicate`] under a [`Binder`].
pub type PolyExistentialPredicate = Binder<ExistentialPredicate>;
535
536impl ExistentialPredicate {
537    /// See [`rustc_middle::ty::ExistentialPredicateStableCmpExt`]
538    pub fn stable_cmp(&self, tcx: TyCtxt, other: &Self) -> Ordering {
539        match (self, other) {
540            (ExistentialPredicate::Trait(_), ExistentialPredicate::Trait(_)) => Ordering::Equal,
541            (ExistentialPredicate::Projection(a), ExistentialPredicate::Projection(b)) => {
542                tcx.def_path_hash(a.def_id)
543                    .cmp(&tcx.def_path_hash(b.def_id))
544            }
545            (ExistentialPredicate::AutoTrait(a), ExistentialPredicate::AutoTrait(b)) => {
546                tcx.def_path_hash(*a).cmp(&tcx.def_path_hash(*b))
547            }
548            (ExistentialPredicate::Trait(_), _) => Ordering::Less,
549            (ExistentialPredicate::Projection(_), ExistentialPredicate::Trait(_)) => {
550                Ordering::Greater
551            }
552            (ExistentialPredicate::Projection(_), _) => Ordering::Less,
553            (ExistentialPredicate::AutoTrait(_), _) => Ordering::Greater,
554        }
555    }
556}
557
558impl<'tcx> ToRustc<'tcx> for ExistentialPredicate {
559    type T = rustc_middle::ty::ExistentialPredicate<'tcx>;
560
561    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
562        match self {
563            ExistentialPredicate::Trait(trait_ref) => {
564                let trait_ref = rustc_middle::ty::ExistentialTraitRef::new_from_args(
565                    tcx,
566                    trait_ref.def_id,
567                    trait_ref.args.to_rustc(tcx),
568                );
569                rustc_middle::ty::ExistentialPredicate::Trait(trait_ref)
570            }
571            ExistentialPredicate::Projection(projection) => {
572                rustc_middle::ty::ExistentialPredicate::Projection(
573                    rustc_middle::ty::ExistentialProjection::new_from_args(
574                        tcx,
575                        projection.def_id,
576                        projection.args.to_rustc(tcx),
577                        projection.term.skip_binder_ref().to_rustc(tcx).into(),
578                    ),
579                )
580            }
581            ExistentialPredicate::AutoTrait(def_id) => {
582                rustc_middle::ty::ExistentialPredicate::AutoTrait(*def_id)
583            }
584        }
585    }
586}
587
/// Payload of [`ExistentialPredicate::Trait`]: a trait's [`DefId`] and its generic arguments.
/// Converted with `rustc_middle::ty::ExistentialTraitRef::new_from_args`.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct ExistentialTraitRef {
    pub def_id: DefId,
    pub args: GenericArgs,
}
595
596pub type PolyExistentialTraitRef = Binder<ExistentialTraitRef>;
597
598impl PolyExistentialTraitRef {
599    pub fn def_id(&self) -> DefId {
600        self.as_ref().skip_binder().def_id
601    }
602}
603
/// Payload of [`ExistentialPredicate::Projection`]. Converted with
/// `rustc_middle::ty::ExistentialProjection::new_from_args` from `def_id`, `args` and `term`.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct ExistentialProjection {
    pub def_id: DefId,
    pub args: GenericArgs,
    pub term: SubsetTyCtor,
}
612
/// Counterpart of `rustc_middle::ty::ProjectionPredicate` (see the `ToRustc` impl): relates the
/// alias `projection_ty` to `term`.
#[derive(
    PartialEq, Eq, Hash, Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct ProjectionPredicate {
    pub projection_ty: AliasTy,
    pub term: SubsetTyCtor,
}
620
621impl Pretty for ProjectionPredicate {
622    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
623        write!(
624            f,
625            "ProjectionPredicate << projection_ty = {:?}, term = {:?} >>",
626            self.projection_ty, self.term
627        )
628    }
629}
630
631impl ProjectionPredicate {
632    pub fn self_ty(&self) -> SubsetTyCtor {
633        self.projection_ty.self_ty().clone()
634    }
635}
636
637impl<'tcx> ToRustc<'tcx> for ProjectionPredicate {
638    type T = rustc_middle::ty::ProjectionPredicate<'tcx>;
639
640    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
641        rustc_middle::ty::ProjectionPredicate {
642            projection_term: rustc_middle::ty::AliasTerm::new_from_args(
643                tcx,
644                self.projection_ty.def_id,
645                self.projection_ty.args.to_rustc(tcx),
646            ),
647            term: self.term.as_bty_skipping_binder().to_rustc(tcx).into(),
648        }
649    }
650}
651
652pub type PolyProjectionPredicate = Binder<ProjectionPredicate>;
653
654impl PolyProjectionPredicate {
655    pub fn projection_def_id(&self) -> DefId {
656        self.skip_binder_ref().projection_ty.def_id
657    }
658
659    pub fn self_ty(&self) -> Binder<SubsetTyCtor> {
660        self.clone().map(|predicate| predicate.self_ty())
661    }
662}
663
/// An `Fn*` trait clause combined with its output projection, as produced by
/// [`Clause::split_off_fn_trait_clauses`].
#[derive(
    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct FnTraitPredicate {
    /// Self type of the trait clause.
    pub self_ty: Ty,
    /// The argument types as a single tuple type (second generic argument of the trait clause).
    pub tupled_args: Ty,
    /// Term of the matching output projection.
    pub output: Ty,
    /// Which of the `Fn*` traits the clause refers to.
    pub kind: ClosureKind,
}
673
674impl Pretty for FnTraitPredicate {
675    fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
676        write!(
677            f,
678            "self = {:?}, args = {:?}, output = {:?}, kind = {}",
679            self.self_ty, self.tupled_args, self.output, self.kind
680        )
681    }
682}
683
684impl FnTraitPredicate {
685    pub fn fndef_sig(&self) -> FnSig {
686        let inputs = self.tupled_args.expect_tuple().iter().cloned().collect();
687        let ret = self.output.clone().shift_in_escaping(1);
688        let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
689        FnSig::new(Safety::Safe, rustc_abi::ExternAbi::Rust, List::empty(), inputs, output)
690    }
691}
692
693pub fn to_closure_sig(
694    tcx: TyCtxt,
695    closure_id: LocalDefId,
696    tys: &[Ty],
697    args: &flux_rustc_bridge::ty::GenericArgs,
698    poly_sig: &PolyFnSig,
699) -> PolyFnSig {
700    let closure_args = args.as_closure();
701    let kind_ty = closure_args.kind_ty().to_rustc(tcx);
702    let Some(kind) = kind_ty.to_opt_closure_kind() else {
703        bug!("to_closure_sig: expected closure kind, found {kind_ty:?}");
704    };
705
706    let mut vars = poly_sig.vars().clone().to_vec();
707    let fn_sig = poly_sig.clone().skip_binder();
708    let closure_ty = Ty::closure(closure_id.into(), tys, args);
709    let env_ty = match kind {
710        ClosureKind::Fn => {
711            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
712            let br = BoundRegion {
713                var: BoundVar::from_usize(vars.len() - 1),
714                kind: BoundRegionKind::ClosureEnv,
715            };
716            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Not)
717        }
718        ClosureKind::FnMut => {
719            vars.push(BoundVariableKind::Region(BoundRegionKind::ClosureEnv));
720            let br = BoundRegion {
721                var: BoundVar::from_usize(vars.len() - 1),
722                kind: BoundRegionKind::ClosureEnv,
723            };
724            Ty::mk_ref(ReBound(INNERMOST, br), closure_ty, Mutability::Mut)
725        }
726        ClosureKind::FnOnce => closure_ty,
727    };
728
729    let inputs = std::iter::once(env_ty)
730        .chain(fn_sig.inputs().iter().cloned())
731        .collect::<Vec<_>>();
732    let output = fn_sig.output().clone();
733
734    let fn_sig = crate::rty::FnSig::new(
735        fn_sig.safety,
736        fn_sig.abi,
737        fn_sig.requires.clone(), // crate::rty::List::empty(),
738        inputs.into(),
739        output,
740    );
741
742    PolyFnSig::bind_with_vars(fn_sig, List::from(vars))
743}
744
// NOTE(review): presumably the obligation generated for a coroutine (`def_id`) with its resume
// type, upvar types and output type -- no use is visible in this part of the file, confirm.
#[derive(
    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct CoroutineObligPredicate {
    pub def_id: DefId,
    pub resume_ty: Ty,
    pub upvar_tys: List<Ty>,
    pub output: Ty,
}
754
/// An associated refinement, identified by its [`FluxDefId`].
#[derive(Copy, Clone, Encodable, Decodable, Hash, PartialEq, Eq)]
pub struct AssocReft {
    pub def_id: FluxDefId,
    // NOTE: Field is used to denote final associated generic refinements on Traits
    pub final_: bool,
}
761
762impl AssocReft {
763    pub fn new(def_id: FluxDefId, final_: bool) -> Self {
764        Self { def_id, final_ }
765    }
766
767    pub fn name(&self) -> Symbol {
768        self.def_id.name()
769    }
770
771    pub fn def_id(&self) -> FluxDefId {
772        self.def_id
773    }
774}
775
/// A list of [`AssocReft`]s that can be searched by def id ([`AssocRefinements::get`]) or by
/// name ([`AssocRefinements::find`]).
#[derive(Clone, Encodable, Decodable)]
pub struct AssocRefinements {
    pub items: List<AssocReft>,
}
780
781impl Default for AssocRefinements {
782    fn default() -> Self {
783        Self { items: List::empty() }
784    }
785}
786
787impl AssocRefinements {
788    pub fn get(&self, assoc_id: FluxDefId) -> AssocReft {
789        *self
790            .items
791            .into_iter()
792            .find(|it| it.def_id == assoc_id)
793            .unwrap_or_else(|| bug!("caller should guarantee existence of associated refinement"))
794    }
795
796    pub fn find(&self, name: Symbol) -> Option<FluxDefId> {
797        Some(
798            self.items
799                .into_iter()
800                .find(|it| it.name() == name)?
801                .def_id(),
802        )
803    }
804}
805
/// A sort constructor, i.e., the head of a [`Sort::App`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum SortCtor {
    Set,
    Map,
    /// The data sort generated for a struct or enum.
    Adt(AdtSortDef),
    /// A constructor identified only by its name.
    User { name: Symbol },
}
813
newtype_index! {
    /// [`ParamSort`] is used for polymorphic sorts (`Set`, `Map`, etc.) and [bit-vector size parameters].
    /// They should occur "bound" under a [`PolyFuncSort`] or an [`AdtSortDef`]. We assume there's a
    /// single binder and a [`ParamSort`] represents a variable as an index into the list of variables
    /// bound by that binder, i.e., the representation doesn't support higher-ranked sorts.
    ///
    /// [bit-vector size parameters]: BvSize::Param
    #[debug_format = "?{}s"]
    #[encodable]
    pub struct ParamSort {}
}
825
newtype_index! {
    /// A *sort* *v*ariable *id*, used as the unification key for sorts being inferred.
    #[debug_format = "?{}s"]
    #[encodable]
    pub struct SortVid {}
}
832
833impl ena::unify::UnifyKey for SortVid {
834    type Value = Option<Sort>;
835
836    #[inline]
837    fn index(&self) -> u32 {
838        self.as_u32()
839    }
840
841    #[inline]
842    fn from_index(u: u32) -> Self {
843        SortVid::from_u32(u)
844    }
845
846    fn tag() -> &'static str {
847        "SortVid"
848    }
849}
850
851impl ena::unify::EqUnifyValue for Sort {}
852
newtype_index! {
    /// A *num*eric *v*ariable *id*, used as the unification key for numeric sorts being inferred.
    #[debug_format = "?{}n"]
    #[encodable]
    pub struct NumVid {}
}
859
/// The value a numeric variable ([`NumVid`]) can be unified with. Each variant corresponds to the
/// [`Sort`] returned by [`NumVarValue::to_sort`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumVarValue {
    Real,
    Int,
    BitVec(BvSize),
}
866
867impl NumVarValue {
868    pub fn to_sort(self) -> Sort {
869        match self {
870            NumVarValue::Real => Sort::Real,
871            NumVarValue::Int => Sort::Int,
872            NumVarValue::BitVec(size) => Sort::BitVec(size),
873        }
874    }
875}
876
877impl ena::unify::UnifyKey for NumVid {
878    type Value = Option<NumVarValue>;
879
880    #[inline]
881    fn index(&self) -> u32 {
882        self.as_u32()
883    }
884
885    #[inline]
886    fn from_index(u: u32) -> Self {
887        NumVid::from_u32(u)
888    }
889
890    fn tag() -> &'static str {
891        "NumVid"
892    }
893}
894
895impl ena::unify::EqUnifyValue for NumVarValue {}
896
/// A placeholder for a sort that needs to be inferred
///
/// The variants are re-exported from this module (`pub use SortInfer::*`).
#[derive(PartialEq, Eq, Clone, Copy, Hash, Encodable, Decodable)]
pub enum SortInfer {
    /// A sort variable.
    SortVar(SortVid),
    /// A numeric sort variable.
    NumVar(NumVid),
}
905
newtype_index! {
    /// A *b*it *v*ector *size* *v*ariable *id*, used as the unification key for bit-vector sizes
    /// being inferred.
    #[debug_format = "?{}size"]
    #[encodable]
    pub struct BvSizeVid {}
}
912
913impl ena::unify::UnifyKey for BvSizeVid {
914    type Value = Option<BvSize>;
915
916    #[inline]
917    fn index(&self) -> u32 {
918        self.as_u32()
919    }
920
921    #[inline]
922    fn from_index(u: u32) -> Self {
923        BvSizeVid::from_u32(u)
924    }
925
926    fn tag() -> &'static str {
927        "BvSizeVid"
928    }
929}
930
931impl ena::unify::EqUnifyValue for BvSize {}
932
/// The sorts of refinement expressions.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum Sort {
    Int,
    Bool,
    Real,
    /// A bit-vector of the given size.
    BitVec(BvSize),
    Str,
    Char,
    Loc,
    Param(ParamTy),
    /// A tuple of sorts; the empty tuple is the unit sort.
    Tuple(List<Sort>),
    Alias(AliasKind, AliasTy),
    /// A (possibly polymorphic) function sort.
    Func(PolyFuncSort),
    /// A sort constructor applied to a list of sort arguments.
    App(SortCtor, List<Sort>),
    /// A sort variable bound by a [`PolyFuncSort`] or an [`AdtSortDef`], see [`ParamSort`].
    Var(ParamSort),
    /// A sort that still needs to be inferred.
    Infer(SortInfer),
    Err,
}
951
952impl Sort {
953    pub fn tuple(sorts: impl Into<List<Sort>>) -> Self {
954        Sort::Tuple(sorts.into())
955    }
956
957    pub fn app(ctor: SortCtor, sorts: List<Sort>) -> Self {
958        Sort::App(ctor, sorts)
959    }
960
961    pub fn unit() -> Self {
962        Self::tuple(vec![])
963    }
964
965    #[track_caller]
966    pub fn expect_func(&self) -> &PolyFuncSort {
967        if let Sort::Func(sort) = self { sort } else { bug!("expected `Sort::Func`") }
968    }
969
970    pub fn is_loc(&self) -> bool {
971        matches!(self, Sort::Loc)
972    }
973
974    pub fn is_unit(&self) -> bool {
975        matches!(self, Sort::Tuple(sorts) if sorts.is_empty())
976    }
977
978    pub fn is_unit_adt(&self) -> Option<DefId> {
979        if let Sort::App(SortCtor::Adt(sort_def), _) = self
980            && let Some(variant) = sort_def.opt_struct_variant()
981            && variant.fields() == 0
982        {
983            Some(sort_def.did())
984        } else {
985            None
986        }
987    }
988
989    /// Whether the sort is a function with return sort bool
990    pub fn is_pred(&self) -> bool {
991        matches!(self, Sort::Func(fsort) if fsort.skip_binders().output().is_bool())
992    }
993
994    /// Returns `true` if the sort is [`Bool`].
995    ///
996    /// [`Bool`]: Sort::Bool
997    #[must_use]
998    pub fn is_bool(&self) -> bool {
999        matches!(self, Self::Bool)
1000    }
1001
1002    pub fn is_numeric(&self) -> bool {
1003        matches!(self, Self::Int | Self::Real)
1004    }
1005
1006    pub fn walk(&self, mut f: impl FnMut(&Sort, &[FieldProj])) {
1007        fn go(sort: &Sort, f: &mut impl FnMut(&Sort, &[FieldProj]), proj: &mut Vec<FieldProj>) {
1008            match sort {
1009                Sort::Tuple(flds) => {
1010                    for (i, sort) in flds.iter().enumerate() {
1011                        proj.push(FieldProj::Tuple { arity: flds.len(), field: i as u32 });
1012                        go(sort, f, proj);
1013                        proj.pop();
1014                    }
1015                }
1016                Sort::App(SortCtor::Adt(sort_def), args) if sort_def.is_struct() => {
1017                    let field_sorts = sort_def.struct_variant().field_sorts(args);
1018                    for (i, sort) in field_sorts.iter().enumerate() {
1019                        proj.push(FieldProj::Adt { def_id: sort_def.did(), field: i as u32 });
1020                        go(sort, f, proj);
1021                        proj.pop();
1022                    }
1023                }
1024                _ => {
1025                    f(sort, proj);
1026                }
1027            }
1028        }
1029        go(self, &mut f, &mut vec![]);
1030    }
1031}
1032
/// The size of a [bit-vector]
///
/// [bit-vector]: Sort::BitVec
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BvSize {
    /// A fixed size
    Fixed(u32),
    /// A size that has been parameterized, e.g., bound under a [`PolyFuncSort`]
    Param(ParamSort),
    /// A size that needs to be inferred. Used during sort checking to instantiate bit-vector
    /// sizes at call-sites. The payload is a unification variable whose value, once known, is
    /// itself a [`BvSize`].
    Infer(BvSizeVid),
}
1046
1047impl rustc_errors::IntoDiagArg for Sort {
1048    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1049        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1050    }
1051}
1052
1053impl rustc_errors::IntoDiagArg for FuncSort {
1054    fn into_diag_arg(self, _path: &mut Option<std::path::PathBuf>) -> rustc_errors::DiagArgValue {
1055        rustc_errors::DiagArgValue::Str(Cow::Owned(format!("{self:?}")))
1056    }
1057}
1058
/// A monomorphic function sort.
///
/// Input sorts and the output sort are stored in a single list with the output last;
/// [`FuncSort::inputs`] and [`FuncSort::output`] split them apart again.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FuncSort {
    /// The input sorts followed by the output sort. [`FuncSort::new`] always produces a
    /// non-empty list, and the accessors index it assuming at least one element.
    pub inputs_and_output: List<Sort>,
}
1063
1064impl FuncSort {
1065    pub fn new(mut inputs: Vec<Sort>, output: Sort) -> Self {
1066        inputs.push(output);
1067        FuncSort { inputs_and_output: List::from_vec(inputs) }
1068    }
1069
1070    pub fn inputs(&self) -> &[Sort] {
1071        &self.inputs_and_output[0..self.inputs_and_output.len() - 1]
1072    }
1073
1074    pub fn output(&self) -> &Sort {
1075        &self.inputs_and_output[self.inputs_and_output.len() - 1]
1076    }
1077
1078    pub fn to_poly(&self) -> PolyFuncSort {
1079        PolyFuncSort::new(List::empty(), self.clone())
1080    }
1081}
1082
/// The kind of a parameter of a polymorphic function sort. See [`PolyFuncSort`]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub enum SortParamKind {
    /// The parameter ranges over sorts.
    Sort,
    /// The parameter ranges over bit-vector sizes.
    BvSize,
}
1089
/// A polymorphic function sort parametric over [sorts] or [bit-vector sizes].
///
/// Parameterizing over bit-vector sizes is a bit of a stretch, because smtlib doesn't support full
/// parametric reasoning over them. As long as we use functions parameterized over a size monomorphically
/// we should be fine. Right now, we can guarantee this, because size parameters are not exposed in
/// the surface syntax and they are only used for predefined (interpreted) theory functions.
///
/// [sorts]: Sort
/// [bit-vector sizes]: BvSize::Param
#[derive(Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
pub struct PolyFuncSort {
    /// The list of parameters including sorts and bit vector sizes
    params: List<SortParamKind>,
    /// The function sort under the binder; it may mention the parameters in `params`.
    fsort: FuncSort,
}
1105
1106impl PolyFuncSort {
1107    pub fn new(params: List<SortParamKind>, fsort: FuncSort) -> Self {
1108        PolyFuncSort { params, fsort }
1109    }
1110
1111    pub fn skip_binders(&self) -> FuncSort {
1112        self.fsort.clone()
1113    }
1114
1115    pub fn instantiate_identity(&self) -> FuncSort {
1116        self.fsort.clone()
1117    }
1118
1119    pub fn expect_mono(&self) -> FuncSort {
1120        assert!(self.params.is_empty());
1121        self.fsort.clone()
1122    }
1123
1124    pub fn params(&self) -> impl ExactSizeIterator<Item = SortParamKind> + '_ {
1125        self.params.iter().copied()
1126    }
1127
1128    pub fn instantiate(&self, args: &[SortArg]) -> FuncSort {
1129        self.fsort.fold_with(&mut SortSubst::new(args))
1130    }
1131}
1132
/// An argument for a generic parameter in a [`Sort`] which can be either a generic sort or a
/// generic bit-vector size.
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub enum SortArg {
    /// A sort argument.
    Sort(Sort),
    /// A bit-vector size argument.
    BvSize(BvSize),
}
1142
/// Information about a constant: either uninterpreted, or given as an expression.
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum ConstantInfo {
    /// An uninterpreted constant
    Uninterpreted,
    /// A non-integral constant whose value is specified by the user
    // NOTE(review): the `Sort` is presumably the sort of the `Expr` — confirm.
    Interpreted(Expr, Sort),
}
1150
/// Handle to an interned [`AdtDefData`].
#[derive(Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDef(Interned<AdtDefData>);
1153
/// The data behind an [`AdtDef`].
#[derive(Debug, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct AdtDefData {
    /// Invariants attached to the ADT.
    invariants: Vec<Invariant>,
    /// The sort definition associated with the ADT.
    sort_def: AdtSortDef,
    // NOTE(review): presumably set when the ADT is annotated with `#[flux::opaque]` — confirm.
    opaque: bool,
    /// The corresponding ADT definition from `flux_rustc_bridge::ty`.
    rustc: ty::AdtDef,
}
1161
/// Option-like enum to explicitly mark that we don't have information about an ADT because it was
/// annotated with `#[flux::opaque]`. Note that only structs can be marked as opaque.
#[derive(Clone, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum Opaqueness<T> {
    /// No information is available.
    Opaque,
    /// The information is available.
    Transparent(T),
}
1169
1170impl<T> Opaqueness<T> {
1171    pub fn map<S>(self, f: impl FnOnce(T) -> S) -> Opaqueness<S> {
1172        match self {
1173            Opaqueness::Opaque => Opaqueness::Opaque,
1174            Opaqueness::Transparent(value) => Opaqueness::Transparent(f(value)),
1175        }
1176    }
1177
1178    pub fn as_ref(&self) -> Opaqueness<&T> {
1179        match self {
1180            Opaqueness::Opaque => Opaqueness::Opaque,
1181            Opaqueness::Transparent(value) => Opaqueness::Transparent(value),
1182        }
1183    }
1184
1185    pub fn as_deref(&self) -> Opaqueness<&T::Target>
1186    where
1187        T: std::ops::Deref,
1188    {
1189        match self {
1190            Opaqueness::Opaque => Opaqueness::Opaque,
1191            Opaqueness::Transparent(value) => Opaqueness::Transparent(value.deref()),
1192        }
1193    }
1194
1195    pub fn ok_or_else<E>(self, err: impl FnOnce() -> E) -> Result<T, E> {
1196        match self {
1197            Opaqueness::Transparent(v) => Ok(v),
1198            Opaqueness::Opaque => Err(err()),
1199        }
1200    }
1201
1202    #[track_caller]
1203    pub fn expect(self, msg: &str) -> T {
1204        match self {
1205            Opaqueness::Transparent(val) => val,
1206            Opaqueness::Opaque => bug!("{}", msg),
1207        }
1208    }
1209
1210    pub fn ok_or_query_err(self, struct_id: DefId) -> Result<T, QueryErr> {
1211        self.ok_or_else(|| QueryErr::OpaqueStruct { struct_id })
1212    }
1213}
1214
1215impl<T, E> Opaqueness<Result<T, E>> {
1216    pub fn transpose(self) -> Result<Opaqueness<T>, E> {
1217        match self {
1218            Opaqueness::Transparent(Ok(x)) => Ok(Opaqueness::Transparent(x)),
1219            Opaqueness::Transparent(Err(e)) => Err(e),
1220            Opaqueness::Opaque => Ok(Opaqueness::Opaque),
1221        }
1222    }
1223}
1224
/// The signed integer types: `isize`, `i8`, `i16`, `i32`, `i64` and `i128`.
pub static INT_TYS: [IntTy; 6] =
    [IntTy::Isize, IntTy::I8, IntTy::I16, IntTy::I32, IntTy::I64, IntTy::I128];
/// The unsigned integer types: `usize`, `u8`, `u16`, `u32`, `u64` and `u128`.
pub static UINT_TYS: [UintTy; 6] =
    [UintTy::Usize, UintTy::U8, UintTy::U16, UintTy::U32, UintTy::U64, UintTy::U128];
1229
/// An invariant, represented as a predicate under a binder. Use [`Invariant::apply`] to
/// instantiate the bound variable with a concrete index.
#[derive(
    Debug, Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable,
)]
pub struct Invariant {
    // This predicate may have sort variables, but we don't explicitly mark it like in `PolyFuncSort`.
    // See comment on `apply` for details.
    pred: Binder<Expr>,
}
1238
1239impl Invariant {
1240    pub fn new(pred: Binder<Expr>) -> Self {
1241        Self { pred }
1242    }
1243
1244    pub fn apply(&self, idx: &Expr) -> Expr {
1245        // The predicate may have sort variables but we don't explicitly instantiate them. This
1246        // works because within an expression, sort variables can only appear inside the sort
1247        // annotation for a lambda and invariants cannot have lambdas. It remains to instantiate
1248        // variables in the sort of the binder itself, but since we are removing it, we can avoid
1249        // the explicit instantiation. Ultimately, this works because the expression we generate in
1250        // fixpoint doesn't need sort annotations (sorts are re-inferred).
1251        self.pred.replace_bound_reft(idx)
1252    }
1253}
1254
/// A list of variant signatures, each under its own binder.
pub type PolyVariants = List<Binder<VariantSig>>;
/// A single variant signature under a binder.
pub type PolyVariant = Binder<VariantSig>;
1257
/// The signature of a variant of an ADT.
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct VariantSig {
    /// The ADT this variant belongs to.
    pub adt_def: AdtDef,
    /// Generic arguments for the ADT.
    pub args: GenericArgs,
    /// The types of the variant's fields.
    pub fields: List<Ty>,
    // NOTE(review): presumably the index of the ADT value produced by the variant — confirm.
    pub idx: Expr,
    // NOTE(review): presumably predicates that must hold to use the variant — confirm.
    pub requires: List<Expr>,
}
1266
/// A function signature under a binder.
pub type PolyFnSig = Binder<FnSig>;
1268
/// A function signature.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct FnSig {
    /// Safety of the function, as represented by `rustc_hir`.
    pub safety: Safety,
    /// ABI of the function, as represented by `rustc_abi`.
    pub abi: rustc_abi::ExternAbi,
    // NOTE(review): presumably the preconditions of the function — confirm.
    pub requires: List<Expr>,
    /// The types of the inputs.
    pub inputs: List<Ty>,
    /// The output of the function, under its own binder.
    pub output: Binder<FnOutput>,
}
1277
/// The output part of a [`FnSig`]: the returned type plus a list of [`Ensures`] clauses.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
)]
pub struct FnOutput {
    /// The return type.
    pub ret: Ty,
    /// The `ensures` clauses attached to the output.
    pub ensures: List<Ensures>,
}
1285
/// A clause in the `ensures` list of a [`FnOutput`].
#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub enum Ensures {
    /// A clause relating a path with a type.
    // NOTE(review): presumably the type the location at `Path` has on return — confirm.
    Type(Path, Ty),
    /// A clause given by a predicate expression.
    Pred(Expr),
}
1291
/// A qualifier definition.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct Qualifier {
    /// The local id of the definition.
    pub def_id: FluxLocalDefId,
    /// The body of the qualifier, under a binder.
    pub body: Binder<Expr>,
    // NOTE(review): presumably whether the qualifier applies globally rather than only
    // where explicitly requested — confirm.
    pub global: bool,
}
1298
/// A [`PrimProp`] is a single property for a primitive operation which
/// can be conjoined to get the definition of the [`PrimRel`] for that
/// primitive operation.
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimProp {
    /// The local id of the definition.
    pub def_id: FluxLocalDefId,
    /// The primitive (binary) operation this property is about.
    pub op: BinOp,
    /// The property itself, under a binder.
    pub body: Binder<Expr>,
}
1308
/// The relation defining a primitive operation (see [`PrimProp`]).
#[derive(Debug, TypeVisitable, TypeFoldable)]
pub struct PrimRel {
    /// The body of the relation, under a binder.
    pub body: Binder<Expr>,
}
1313
/// A type under a binder.
pub type TyCtor = Binder<Ty>;
1315
1316impl TyCtor {
1317    pub fn to_ty(&self) -> Ty {
1318        match &self.vars()[..] {
1319            [] => {
1320                return self.skip_binder_ref().shift_out_escaping(1);
1321            }
1322            [BoundVariableKind::Refine(sort, ..)] => {
1323                if sort.is_unit() {
1324                    return self.replace_bound_reft(&Expr::unit());
1325                }
1326                if let Some(def_id) = sort.is_unit_adt() {
1327                    return self.replace_bound_reft(&Expr::unit_struct(def_id));
1328                }
1329            }
1330            _ => {}
1331        }
1332        Ty::exists(self.clone())
1333    }
1334}
1335
/// A type, represented as a handle to an interned [`TyKind`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct Ty(Interned<TyKind>);
1338
1339impl Ty {
1340    pub fn kind(&self) -> &TyKind {
1341        &self.0
1342    }
1343
1344    /// Dummy type used for the `Self` of a `TraitRef` created when converting a trait object, and
1345    /// which gets removed in `ExistentialTraitRef`. This type must not appear anywhere in other
1346    /// converted types and must be a valid `rustc` type (i.e., we must be able to call `to_rustc`
1347    /// on it). `TyKind::Infer(TyVid(0))` does the job, with the caveat that we must skip 0 when
1348    /// generating `TyKind::Infer` for "type holes".
1349    pub fn trait_object_dummy_self() -> Ty {
1350        Ty::infer(TyVid::from_u32(0))
1351    }
1352
1353    pub fn dynamic(preds: impl Into<List<Binder<ExistentialPredicate>>>, region: Region) -> Ty {
1354        BaseTy::Dynamic(preds.into(), region).to_ty()
1355    }
1356
1357    pub fn strg_ref(re: Region, path: Path, ty: Ty) -> Ty {
1358        TyKind::StrgRef(re, path, ty).intern()
1359    }
1360
1361    pub fn ptr(pk: impl Into<PtrKind>, path: impl Into<Path>) -> Ty {
1362        TyKind::Ptr(pk.into(), path.into()).intern()
1363    }
1364
1365    pub fn constr(p: impl Into<Expr>, ty: Ty) -> Ty {
1366        TyKind::Constr(p.into(), ty).intern()
1367    }
1368
1369    pub fn uninit() -> Ty {
1370        TyKind::Uninit.intern()
1371    }
1372
1373    pub fn indexed(bty: BaseTy, idx: impl Into<Expr>) -> Ty {
1374        TyKind::Indexed(bty, idx.into()).intern()
1375    }
1376
1377    pub fn exists(ty: Binder<Ty>) -> Ty {
1378        TyKind::Exists(ty).intern()
1379    }
1380
1381    pub fn exists_with_constr(bty: BaseTy, pred: Expr) -> Ty {
1382        let sort = bty.sort();
1383        let ty = Ty::indexed(bty, Expr::nu());
1384        Ty::exists(Binder::bind_with_sort(Ty::constr(pred, ty), sort))
1385    }
1386
1387    pub fn discr(adt_def: AdtDef, place: Place) -> Ty {
1388        TyKind::Discr(adt_def, place).intern()
1389    }
1390
1391    pub fn unit() -> Ty {
1392        Ty::tuple(vec![])
1393    }
1394
1395    pub fn bool() -> Ty {
1396        BaseTy::Bool.to_ty()
1397    }
1398
1399    pub fn int(int_ty: IntTy) -> Ty {
1400        BaseTy::Int(int_ty).to_ty()
1401    }
1402
1403    pub fn uint(uint_ty: UintTy) -> Ty {
1404        BaseTy::Uint(uint_ty).to_ty()
1405    }
1406
1407    pub fn param(param_ty: ParamTy) -> Ty {
1408        TyKind::Param(param_ty).intern()
1409    }
1410
1411    pub fn downcast(
1412        adt: AdtDef,
1413        args: GenericArgs,
1414        ty: Ty,
1415        variant: VariantIdx,
1416        fields: List<Ty>,
1417    ) -> Ty {
1418        TyKind::Downcast(adt, args, ty, variant, fields).intern()
1419    }
1420
1421    pub fn blocked(ty: Ty) -> Ty {
1422        TyKind::Blocked(ty).intern()
1423    }
1424
1425    pub fn str() -> Ty {
1426        BaseTy::Str.to_ty()
1427    }
1428
1429    pub fn char() -> Ty {
1430        BaseTy::Char.to_ty()
1431    }
1432
1433    pub fn float(float_ty: FloatTy) -> Ty {
1434        BaseTy::Float(float_ty).to_ty()
1435    }
1436
1437    pub fn mk_ref(region: Region, ty: Ty, mutbl: Mutability) -> Ty {
1438        BaseTy::Ref(region, ty, mutbl).to_ty()
1439    }
1440
1441    pub fn mk_slice(ty: Ty) -> Ty {
1442        BaseTy::Slice(ty).to_ty()
1443    }
1444
1445    pub fn mk_box(genv: GlobalEnv, deref_ty: Ty, alloc_ty: Ty) -> QueryResult<Ty> {
1446        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1447        let adt_def = genv.adt_def(def_id)?;
1448
1449        let args = List::from_arr([GenericArg::Ty(deref_ty), GenericArg::Ty(alloc_ty)]);
1450
1451        let bty = BaseTy::adt(adt_def, args);
1452        Ok(Ty::indexed(bty, Expr::unit_struct(def_id)))
1453    }
1454
1455    pub fn mk_box_with_default_alloc(genv: GlobalEnv, deref_ty: Ty) -> QueryResult<Ty> {
1456        let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, DUMMY_SP);
1457
1458        let generics = genv.generics_of(def_id)?;
1459        let alloc_ty = genv
1460            .lower_type_of(generics.own_params[1].def_id)?
1461            .skip_binder()
1462            .refine(&Refiner::default_for_item(genv, def_id)?)?;
1463
1464        Ty::mk_box(genv, deref_ty, alloc_ty)
1465    }
1466
1467    pub fn tuple(tys: impl Into<List<Ty>>) -> Ty {
1468        BaseTy::Tuple(tys.into()).to_ty()
1469    }
1470
1471    pub fn array(ty: Ty, c: Const) -> Ty {
1472        BaseTy::Array(ty, c).to_ty()
1473    }
1474
1475    pub fn closure(
1476        did: DefId,
1477        tys: impl Into<List<Ty>>,
1478        args: &flux_rustc_bridge::ty::GenericArgs,
1479    ) -> Ty {
1480        BaseTy::Closure(did, tys.into(), args.clone()).to_ty()
1481    }
1482
1483    pub fn coroutine(did: DefId, resume_ty: Ty, upvar_tys: List<Ty>) -> Ty {
1484        BaseTy::Coroutine(did, resume_ty, upvar_tys).to_ty()
1485    }
1486
1487    pub fn never() -> Ty {
1488        BaseTy::Never.to_ty()
1489    }
1490
1491    pub fn infer(vid: TyVid) -> Ty {
1492        TyKind::Infer(vid).intern()
1493    }
1494
1495    pub fn unconstr(&self) -> (Ty, Expr) {
1496        fn go(this: &Ty, preds: &mut Vec<Expr>) -> Ty {
1497            if let TyKind::Constr(pred, ty) = this.kind() {
1498                preds.push(pred.clone());
1499                go(ty, preds)
1500            } else {
1501                this.clone()
1502            }
1503        }
1504        let mut preds = vec![];
1505        (go(self, &mut preds), Expr::and_from_iter(preds))
1506    }
1507
1508    pub fn unblocked(&self) -> Ty {
1509        match self.kind() {
1510            TyKind::Blocked(ty) => ty.clone(),
1511            _ => self.clone(),
1512        }
1513    }
1514
1515    /// Whether the type is an `int` or a `uint`
1516    pub fn is_integral(&self) -> bool {
1517        self.as_bty_skipping_existentials()
1518            .map(BaseTy::is_integral)
1519            .unwrap_or_default()
1520    }
1521
1522    /// Whether the type is a `bool`
1523    pub fn is_bool(&self) -> bool {
1524        self.as_bty_skipping_existentials()
1525            .map(BaseTy::is_bool)
1526            .unwrap_or_default()
1527    }
1528
1529    /// Whether the type is a `char`
1530    pub fn is_char(&self) -> bool {
1531        self.as_bty_skipping_existentials()
1532            .map(BaseTy::is_char)
1533            .unwrap_or_default()
1534    }
1535
1536    pub fn is_uninit(&self) -> bool {
1537        matches!(self.kind(), TyKind::Uninit)
1538    }
1539
1540    pub fn is_box(&self) -> bool {
1541        self.as_bty_skipping_existentials()
1542            .map(BaseTy::is_box)
1543            .unwrap_or_default()
1544    }
1545
1546    pub fn is_struct(&self) -> bool {
1547        self.as_bty_skipping_existentials()
1548            .map(BaseTy::is_struct)
1549            .unwrap_or_default()
1550    }
1551
1552    pub fn is_array(&self) -> bool {
1553        self.as_bty_skipping_existentials()
1554            .map(BaseTy::is_array)
1555            .unwrap_or_default()
1556    }
1557
1558    pub fn is_slice(&self) -> bool {
1559        self.as_bty_skipping_existentials()
1560            .map(BaseTy::is_slice)
1561            .unwrap_or_default()
1562    }
1563
1564    pub fn as_bty_skipping_existentials(&self) -> Option<&BaseTy> {
1565        match self.kind() {
1566            TyKind::Indexed(bty, _) => Some(bty),
1567            TyKind::Exists(ty) => Some(ty.skip_binder_ref().as_bty_skipping_existentials()?),
1568            TyKind::Constr(_, ty) => ty.as_bty_skipping_existentials(),
1569            _ => None,
1570        }
1571    }
1572
1573    #[track_caller]
1574    pub fn expect_discr(&self) -> (&AdtDef, &Place) {
1575        if let TyKind::Discr(adt_def, place) = self.kind() {
1576            (adt_def, place)
1577        } else {
1578            tracked_span_bug!("expected discr")
1579        }
1580    }
1581
1582    #[track_caller]
1583    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg], &Expr) {
1584        if let TyKind::Indexed(BaseTy::Adt(adt_def, args), idx) = self.kind() {
1585            (adt_def, args, idx)
1586        } else {
1587            tracked_span_bug!("expected adt `{self:?}`")
1588        }
1589    }
1590
1591    #[track_caller]
1592    pub fn expect_tuple(&self) -> &[Ty] {
1593        if let TyKind::Indexed(BaseTy::Tuple(tys), _) = self.kind() {
1594            tys
1595        } else {
1596            tracked_span_bug!("expected tuple found `{self:?}` (kind: `{:?}`)", self.kind())
1597        }
1598    }
1599}
1600
1601impl<'tcx> ToRustc<'tcx> for Ty {
1602    type T = rustc_middle::ty::Ty<'tcx>;
1603
1604    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
1605        match self.kind() {
1606            TyKind::Indexed(bty, _) => bty.to_rustc(tcx),
1607            TyKind::Exists(ty) => ty.skip_binder_ref().to_rustc(tcx),
1608            TyKind::Constr(_, ty) => ty.to_rustc(tcx),
1609            TyKind::Param(pty) => pty.to_ty(tcx),
1610            TyKind::StrgRef(re, _, ty) => {
1611                rustc_middle::ty::Ty::new_ref(
1612                    tcx,
1613                    re.to_rustc(tcx),
1614                    ty.to_rustc(tcx),
1615                    Mutability::Mut,
1616                )
1617            }
1618            TyKind::Infer(vid) => rustc_middle::ty::Ty::new_var(tcx, *vid),
1619            TyKind::Uninit
1620            | TyKind::Ptr(_, _)
1621            | TyKind::Discr(_, _)
1622            | TyKind::Downcast(_, _, _, _, _)
1623            | TyKind::Blocked(_) => bug!("TODO: to_rustc for `{self:?}`"),
1624        }
1625    }
1626}
1627
/// The different kinds of types. Values are interned behind a [`Ty`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
pub enum TyKind {
    /// A base type paired with an index expression.
    Indexed(BaseTy, Expr),
    /// A type under a binder.
    Exists(Binder<Ty>),
    /// A type paired with a predicate expression.
    Constr(Expr, Ty),
    Uninit,
    StrgRef(Region, Path, Ty),
    Ptr(PtrKind, Path),
    /// This is a bit of a hack. We use this type internally to represent the result of
    /// [`Rvalue::Discriminant`] in a way that we can recover the necessary control information
    /// when checking a [`match`]. The hack is that we assume the discriminant remains the same from
    /// the creation of this type until we use it in a [`match`].
    ///
    ///
    /// [`Rvalue::Discriminant`]: flux_rustc_bridge::mir::Rvalue::Discriminant
    /// [`match`]: flux_rustc_bridge::mir::TerminatorKind::SwitchInt
    Discr(AdtDef, Place),
    Param(ParamTy),
    Downcast(AdtDef, GenericArgs, Ty, VariantIdx, List<Ty>),
    Blocked(Ty),
    /// A type that needs to be inferred by matching the signature against a rust signature.
    /// [`TyKind::Infer`] appear as an intermediate step during `conv` and should not be present in
    /// the final signature.
    Infer(TyVid),
}
1653
/// The kind of pointer in a [`TyKind::Ptr`].
#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum PtrKind {
    // NOTE(review): presumably a pointer originating from a mutable reference with the
    // given region — confirm.
    Mut(Region),
    // NOTE(review): presumably a pointer originating from a box — confirm.
    Box,
}
1659
/// A base type, i.e., the part of a type that can be paired with an index in
/// [`TyKind::Indexed`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum BaseTy {
    Int(IntTy),
    Uint(UintTy),
    Bool,
    Str,
    Char,
    Slice(Ty),
    /// An ADT together with its generic arguments.
    Adt(AdtDef, GenericArgs),
    Float(FloatTy),
    RawPtr(Ty, Mutability),
    RawPtrMetadata(Ty),
    Ref(Region, Ty, Mutability),
    FnPtr(PolyFnSig),
    FnDef(DefId, GenericArgs),
    Tuple(List<Ty>),
    Alias(AliasKind, AliasTy),
    Array(Ty, Const),
    Never,
    Closure(DefId, /* upvar_tys */ List<Ty>, flux_rustc_bridge::ty::GenericArgs),
    Coroutine(DefId, /*resume_ty: */ Ty, /* upvar_tys: */ List<Ty>),
    Dynamic(List<Binder<ExistentialPredicate>>, Region),
    Param(ParamTy),
    Infer(TyVid),
    Foreign(DefId),
}
1686
1687impl BaseTy {
1688    pub fn opaque(alias_ty: AliasTy) -> BaseTy {
1689        BaseTy::Alias(AliasKind::Opaque, alias_ty)
1690    }
1691
1692    pub fn projection(alias_ty: AliasTy) -> BaseTy {
1693        BaseTy::Alias(AliasKind::Projection, alias_ty)
1694    }
1695
1696    pub fn adt(adt_def: AdtDef, args: GenericArgs) -> BaseTy {
1697        BaseTy::Adt(adt_def, args)
1698    }
1699
1700    pub fn fn_def(def_id: DefId, args: impl Into<GenericArgs>) -> BaseTy {
1701        BaseTy::FnDef(def_id, args.into())
1702    }
1703
1704    pub fn from_primitive_str(s: &str) -> Option<BaseTy> {
1705        match s {
1706            "i8" => Some(BaseTy::Int(IntTy::I8)),
1707            "i16" => Some(BaseTy::Int(IntTy::I16)),
1708            "i32" => Some(BaseTy::Int(IntTy::I32)),
1709            "i64" => Some(BaseTy::Int(IntTy::I64)),
1710            "i128" => Some(BaseTy::Int(IntTy::I128)),
1711            "u8" => Some(BaseTy::Uint(UintTy::U8)),
1712            "u16" => Some(BaseTy::Uint(UintTy::U16)),
1713            "u32" => Some(BaseTy::Uint(UintTy::U32)),
1714            "u64" => Some(BaseTy::Uint(UintTy::U64)),
1715            "u128" => Some(BaseTy::Uint(UintTy::U128)),
1716            "f16" => Some(BaseTy::Float(FloatTy::F16)),
1717            "f32" => Some(BaseTy::Float(FloatTy::F32)),
1718            "f64" => Some(BaseTy::Float(FloatTy::F64)),
1719            "f128" => Some(BaseTy::Float(FloatTy::F128)),
1720            "isize" => Some(BaseTy::Int(IntTy::Isize)),
1721            "usize" => Some(BaseTy::Uint(UintTy::Usize)),
1722            "bool" => Some(BaseTy::Bool),
1723            "char" => Some(BaseTy::Char),
1724            "str" => Some(BaseTy::Str),
1725            _ => None,
1726        }
1727    }
1728
1729    /// If `self` is a primitive, return its [`Symbol`].
1730    pub fn primitive_symbol(&self) -> Option<Symbol> {
1731        match self {
1732            BaseTy::Bool => Some(sym::bool),
1733            BaseTy::Char => Some(sym::char),
1734            BaseTy::Float(f) => {
1735                match f {
1736                    FloatTy::F16 => Some(sym::f16),
1737                    FloatTy::F32 => Some(sym::f32),
1738                    FloatTy::F64 => Some(sym::f64),
1739                    FloatTy::F128 => Some(sym::f128),
1740                }
1741            }
1742            BaseTy::Int(f) => {
1743                match f {
1744                    IntTy::Isize => Some(sym::isize),
1745                    IntTy::I8 => Some(sym::i8),
1746                    IntTy::I16 => Some(sym::i16),
1747                    IntTy::I32 => Some(sym::i32),
1748                    IntTy::I64 => Some(sym::i64),
1749                    IntTy::I128 => Some(sym::i128),
1750                }
1751            }
1752            BaseTy::Uint(f) => {
1753                match f {
1754                    UintTy::Usize => Some(sym::usize),
1755                    UintTy::U8 => Some(sym::u8),
1756                    UintTy::U16 => Some(sym::u16),
1757                    UintTy::U32 => Some(sym::u32),
1758                    UintTy::U64 => Some(sym::u64),
1759                    UintTy::U128 => Some(sym::u128),
1760                }
1761            }
1762            BaseTy::Str => Some(sym::str),
1763            _ => None,
1764        }
1765    }
1766
1767    pub fn is_integral(&self) -> bool {
1768        matches!(self, BaseTy::Int(_) | BaseTy::Uint(_))
1769    }
1770
1771    pub fn is_numeric(&self) -> bool {
1772        matches!(self, BaseTy::Int(_) | BaseTy::Uint(_) | BaseTy::Float(_))
1773    }
1774
1775    pub fn is_signed(&self) -> bool {
1776        matches!(self, BaseTy::Int(_))
1777    }
1778
1779    pub fn is_unsigned(&self) -> bool {
1780        matches!(self, BaseTy::Uint(_))
1781    }
1782
1783    pub fn is_float(&self) -> bool {
1784        matches!(self, BaseTy::Float(_))
1785    }
1786
1787    pub fn is_bool(&self) -> bool {
1788        matches!(self, BaseTy::Bool)
1789    }
1790
1791    fn is_struct(&self) -> bool {
1792        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_struct())
1793    }
1794
1795    fn is_array(&self) -> bool {
1796        matches!(self, BaseTy::Array(..))
1797    }
1798
1799    fn is_slice(&self) -> bool {
1800        matches!(self, BaseTy::Slice(..))
1801    }
1802
1803    pub fn is_box(&self) -> bool {
1804        matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_box())
1805    }
1806
1807    pub fn is_char(&self) -> bool {
1808        matches!(self, BaseTy::Char)
1809    }
1810
1811    pub fn is_str(&self) -> bool {
1812        matches!(self, BaseTy::Str)
1813    }
1814
1815    pub fn unpack_box(&self) -> Option<(&Ty, &Ty)> {
1816        if let BaseTy::Adt(adt_def, args) = self
1817            && adt_def.is_box()
1818        {
1819            Some(args.box_args())
1820        } else {
1821            None
1822        }
1823    }
1824
1825    pub fn invariants(
1826        &self,
1827        tcx: TyCtxt,
1828        overflow_checking: bool,
1829    ) -> impl Iterator<Item = Invariant> {
1830        let (invariants, args) = match self {
1831            BaseTy::Adt(adt_def, args) => (adt_def.invariants().skip_binder(), &args[..]),
1832            BaseTy::Uint(uint_ty) => (uint_invariants(*uint_ty, overflow_checking), &[][..]),
1833            BaseTy::Int(int_ty) => (int_invariants(*int_ty, overflow_checking), &[][..]),
1834            BaseTy::Slice(_) => (slice_invariants(overflow_checking), &[][..]),
1835            _ => (&[][..], &[][..]),
1836        };
1837        invariants
1838            .iter()
1839            .map(move |inv| EarlyBinder(inv).instantiate_ref(tcx, args, &[]))
1840    }
1841
1842    pub fn to_ty(&self) -> Ty {
1843        let sort = self.sort();
1844        if sort.is_unit() {
1845            Ty::indexed(self.clone(), Expr::unit())
1846        } else {
1847            Ty::exists(Binder::bind_with_sort(Ty::indexed(self.clone(), Expr::nu()), sort))
1848        }
1849    }
1850
1851    pub fn to_subset_ty_ctor(&self) -> SubsetTyCtor {
1852        let sort = self.sort();
1853        Binder::bind_with_sort(SubsetTy::trivial(self.clone(), Expr::nu()), sort)
1854    }
1855
1856    #[track_caller]
1857    pub fn expect_adt(&self) -> (&AdtDef, &[GenericArg]) {
1858        if let BaseTy::Adt(adt_def, args) = self {
1859            (adt_def, args)
1860        } else {
1861            tracked_span_bug!("expected adt `{self:?}`")
1862        }
1863    }
1864
1865    /// A type is an *atom* if it is "self-delimiting", i.e., it has a clear boundary
1866    /// when printed. This is used to avoid unnecessary parenthesis when pretty printing.
1867    pub fn is_atom(&self) -> bool {
1868        // (nilehmann) I'm not sure about this list, please adjust if you get any odd behavior
1869        matches!(
1870            self,
1871            BaseTy::Int(_)
1872                | BaseTy::Uint(_)
1873                | BaseTy::Slice(_)
1874                | BaseTy::Bool
1875                | BaseTy::Char
1876                | BaseTy::Str
1877                | BaseTy::Adt(..)
1878                | BaseTy::Tuple(..)
1879                | BaseTy::Param(_)
1880                | BaseTy::Array(..)
1881                | BaseTy::Never
1882                | BaseTy::Closure(..)
1883                | BaseTy::Coroutine(..)
1884                // opaque alias are atoms the way we print them now, but they won't
1885                // be if we print them as `impl Trait`
1886                | BaseTy::Alias(..)
1887        )
1888    }
1889}
1890
1891impl<'tcx> ToRustc<'tcx> for BaseTy {
1892    type T = rustc_middle::ty::Ty<'tcx>;
1893
1894    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
1895        use rustc_middle::ty;
1896        match self {
1897            BaseTy::Int(i) => ty::Ty::new_int(tcx, *i),
1898            BaseTy::Uint(i) => ty::Ty::new_uint(tcx, *i),
1899            BaseTy::Param(pty) => pty.to_ty(tcx),
1900            BaseTy::Slice(ty) => ty::Ty::new_slice(tcx, ty.to_rustc(tcx)),
1901            BaseTy::Bool => tcx.types.bool,
1902            BaseTy::Char => tcx.types.char,
1903            BaseTy::Str => tcx.types.str_,
1904            BaseTy::Adt(adt_def, args) => {
1905                let did = adt_def.did();
1906                let adt_def = tcx.adt_def(did);
1907                let args = args.to_rustc(tcx);
1908                ty::Ty::new_adt(tcx, adt_def, args)
1909            }
1910            BaseTy::FnDef(def_id, args) => {
1911                let args = args.to_rustc(tcx);
1912                ty::Ty::new_fn_def(tcx, *def_id, args)
1913            }
1914            BaseTy::Float(f) => ty::Ty::new_float(tcx, *f),
1915            BaseTy::RawPtr(ty, mutbl) => ty::Ty::new_ptr(tcx, ty.to_rustc(tcx), *mutbl),
1916            BaseTy::Ref(re, ty, mutbl) => {
1917                ty::Ty::new_ref(tcx, re.to_rustc(tcx), ty.to_rustc(tcx), *mutbl)
1918            }
1919            BaseTy::FnPtr(poly_sig) => ty::Ty::new_fn_ptr(tcx, poly_sig.to_rustc(tcx)),
1920            BaseTy::Tuple(tys) => {
1921                let ts = tys.iter().map(|ty| ty.to_rustc(tcx)).collect_vec();
1922                ty::Ty::new_tup(tcx, &ts)
1923            }
1924            BaseTy::Alias(kind, alias_ty) => {
1925                ty::Ty::new_alias(tcx, kind.to_rustc(tcx), alias_ty.to_rustc(tcx))
1926            }
1927            BaseTy::Array(ty, n) => {
1928                let ty = ty.to_rustc(tcx);
1929                let n = n.to_rustc(tcx);
1930                ty::Ty::new_array_with_const_len(tcx, ty, n)
1931            }
1932            BaseTy::Never => tcx.types.never,
1933            BaseTy::Closure(did, _, args) => ty::Ty::new_closure(tcx, *did, args.to_rustc(tcx)),
1934            BaseTy::Dynamic(exi_preds, re) => {
1935                let preds: Vec<_> = exi_preds
1936                    .iter()
1937                    .map(|pred| pred.to_rustc(tcx))
1938                    .collect_vec();
1939                let preds = tcx.mk_poly_existential_predicates(&preds);
1940                ty::Ty::new_dynamic(tcx, preds, re.to_rustc(tcx), rustc_middle::ty::DynKind::Dyn)
1941            }
1942            BaseTy::Coroutine(def_id, resume_ty, upvars) => {
1943                bug!("TODO: Generator {def_id:?} {resume_ty:?} {upvars:?}")
1944                // let args = args.iter().map(|arg| into_rustc_generic_arg(tcx, arg));
1945                // let args = tcx.mk_args_from_iter(args);
1946                // ty::Ty::new_generator(*tcx, *def_id, args, mov)
1947            }
1948            BaseTy::Infer(ty_vid) => ty::Ty::new_var(tcx, *ty_vid),
1949            BaseTy::Foreign(def_id) => ty::Ty::new_foreign(tcx, *def_id),
1950            BaseTy::RawPtrMetadata(ty) => {
1951                ty::Ty::new_ptr(
1952                    tcx,
1953                    ty.to_rustc(tcx),
1954                    RawPtrKind::FakeForPtrMetadata.to_mutbl_lossy(),
1955                )
1956            }
1957        }
1958    }
1959}
1960
1961#[derive(
1962    Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable,
1963)]
1964pub struct AliasTy {
1965    pub def_id: DefId,
1966    pub args: GenericArgs,
1967    /// Holds the refinement-arguments for opaque-types; empty for projections
1968    pub refine_args: RefineArgs,
1969}
1970
1971impl AliasTy {
1972    pub fn new(def_id: DefId, args: GenericArgs, refine_args: RefineArgs) -> Self {
1973        AliasTy { args, refine_args, def_id }
1974    }
1975}
1976
1977/// This methods work only with associated type projections (i.e., no opaque types)
1978impl AliasTy {
1979    pub fn self_ty(&self) -> SubsetTyCtor {
1980        self.args[0].expect_base().clone()
1981    }
1982
1983    pub fn with_self_ty(&self, self_ty: SubsetTyCtor) -> Self {
1984        Self {
1985            def_id: self.def_id,
1986            args: [GenericArg::Base(self_ty)]
1987                .into_iter()
1988                .chain(self.args.iter().skip(1).cloned())
1989                .collect(),
1990            refine_args: self.refine_args.clone(),
1991        }
1992    }
1993}
1994
1995impl<'tcx> ToRustc<'tcx> for AliasTy {
1996    type T = rustc_middle::ty::AliasTy<'tcx>;
1997
1998    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
1999        rustc_middle::ty::AliasTy::new(tcx, self.def_id, self.args.to_rustc(tcx))
2000    }
2001}
2002
2003pub type RefineArgs = List<Expr>;
2004
2005#[extension(pub trait RefineArgsExt)]
2006impl RefineArgs {
2007    fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<RefineArgs> {
2008        Self::for_item(genv, def_id, |param, index| {
2009            Ok(Expr::var(Var::EarlyParam(EarlyReftParam {
2010                index: index as u32,
2011                name: param.name(),
2012            })))
2013        })
2014    }
2015
2016    fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult<RefineArgs>
2017    where
2018        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<Expr>,
2019    {
2020        let reft_generics = genv.refinement_generics_of(def_id)?;
2021        let count = reft_generics.count();
2022        let mut args = Vec::with_capacity(count);
2023        reft_generics.fill_item(genv, &mut args, &mut mk)?;
2024        Ok(List::from_vec(args))
2025    }
2026}
2027
2028/// A type constructor meant to be used as generic a argument of [kind base]. This is just an alias
2029/// to [`Binder<SubsetTy>`], but we expect the binder to have a single bound variable of the sort of
2030/// the underlying [base type].
2031///
2032/// [kind base]: GenericParamDefKind::Base
2033/// [base type]: SubsetTy::bty
2034pub type SubsetTyCtor = Binder<SubsetTy>;
2035
2036impl SubsetTyCtor {
2037    pub fn as_bty_skipping_binder(&self) -> &BaseTy {
2038        &self.as_ref().skip_binder().bty
2039    }
2040
2041    pub fn to_ty(&self) -> Ty {
2042        let sort = self.sort();
2043        if sort.is_unit() {
2044            self.replace_bound_reft(&Expr::unit()).to_ty()
2045        } else if let Some(def_id) = sort.is_unit_adt() {
2046            self.replace_bound_reft(&Expr::unit_struct(def_id)).to_ty()
2047        } else {
2048            Ty::exists(self.as_ref().map(SubsetTy::to_ty))
2049        }
2050    }
2051
2052    pub fn to_ty_ctor(&self) -> TyCtor {
2053        self.as_ref().map(SubsetTy::to_ty)
2054    }
2055}
2056
2057/// A subset type is a simplified version of a type that has the form `{b[e] | p}` where `b` is a
2058/// [`BaseTy`], `e` a refinement index, and `p` a predicate.
2059///
2060/// These are mainly found under a [`Binder`] with a single variable of the base type's sort. This
2061/// can be interpreted as a type constructor or an existential type. For example, under a binder with a
2062/// variable `v` of sort `int`, we can interpret `{i32[v] | v > 0}` as:
2063/// - A lambda `λv:int. {i32[v] | v > 0}` that "constructs" types when applied to ints, or
2064/// - An existential type `∃v:int. {i32[v] | v > 0}`.
2065///
2066/// This second interpretation is the reason we call this a subset type, i.e., the type `∃v. {b[v] | p}`
2067/// corresponds to the subset of values of  type `b` whose index satisfies `p`. These are the types
2068/// written as `B{v: p}` in the surface syntax and correspond to the types supported in other
2069/// refinement type systems like Liquid Haskell (with the difference that we are explicit
2070/// about separating refinements from program values via an index).
2071///
2072/// The main purpose for subset types is to be used as generic arguments of [kind base] when
2073/// interpreted as type constructors. They have two key properties that makes them suitable
2074/// for this:
2075///
2076/// 1. **Syntactic Restriction**: Subset types are syntactically restricted, making it easier to
2077///    relate them structurally (e.g., for subtyping). For instance, given two types `S<λv. T1>` and
2078///    `S<λ. T2>`, if `T1` and `T2` are subset types, we know they match structurally (at least
2079///    shallowly). In particularly, the syntactic restriction rules out complex types like
2080///    `S<λv. (i32[v], i32[v])>` simplifying some operations.
2081///
2082/// 2. **Eager Canonicalization**: Subset types can be eagerly canonicalized via [*strengthening*]
2083///    during substitution. For example, suppose we have a function:
2084///    ```text
2085///    fn foo<T>(x: T[@a], y: { T[@b] | b == a }) { }
2086///    ```
2087///    If we instantiate `T` with `λv. { i32[v] | v > 0}`, after substitution and applying the
2088///    lambda (the indexing syntax `T[a]` corresponds to an application of the lambda), we get:
2089///    ```text
2090///    fn foo(x: {i32[@a] | a > 0}, y: { { i32[@b] | b > 0 } | b == a }) { }
2091///    ```
2092///    Via *strengthening* we can canonicalize this to
2093///    ```text
2094///    fn foo(x: {i32[@a] | a > 0}, y: { i32[@b] | b == a && b > 0 }) { }
2095///    ```
2096///    As a result, we can guarantee the syntactic restriction through substitution.
2097///
2098/// [kind base]: GenericParamDefKind::Base
2099/// [*strengthening*]: https://arxiv.org/pdf/2010.07763.pdf
2100#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
2101pub struct SubsetTy {
2102    /// The base type `b` in the subset type `{b[e] | p}`.
2103    ///
2104    /// **NOTE:** This is mostly going to be under a [`Binder`]. It is not yet clear to me whether
2105    /// this [`BaseTy`] should be able to mention variables in the binder. In general, in a type
2106    /// `∃v. {b[e] | p}`, it's fine to mention `v` inside `b`, but since [`SubsetTy`] is meant to
2107    /// facilitate syntactic manipulation we may want to restrict this.
2108    pub bty: BaseTy,
2109    /// The refinement index `e` in the subset type `{b[e] | p}`. This can be an arbitrary expression,
2110    /// which makes manipulation easier. However, since this is mostly found under a binder, we expect
2111    /// it to be [`Expr::nu()`].
2112    pub idx: Expr,
2113    /// The predicate `p` in the subset type `{b[e] | p}`.
2114    pub pred: Expr,
2115}
2116
2117impl SubsetTy {
2118    pub fn new(bty: BaseTy, idx: impl Into<Expr>, pred: impl Into<Expr>) -> Self {
2119        Self { bty, idx: idx.into(), pred: pred.into() }
2120    }
2121
2122    pub fn trivial(bty: BaseTy, idx: impl Into<Expr>) -> Self {
2123        Self::new(bty, idx, Expr::tt())
2124    }
2125
2126    pub fn strengthen(&self, pred: impl Into<Expr>) -> Self {
2127        let this = self.clone();
2128        let pred = Expr::and(this.pred, pred).simplify(&SnapshotMap::default());
2129        Self { bty: this.bty, idx: this.idx, pred }
2130    }
2131
2132    pub fn to_ty(&self) -> Ty {
2133        let bty = self.bty.clone();
2134        if self.pred.is_trivially_true() {
2135            Ty::indexed(bty, &self.idx)
2136        } else {
2137            Ty::constr(&self.pred, Ty::indexed(bty, &self.idx))
2138        }
2139    }
2140}
2141
2142impl<'tcx> ToRustc<'tcx> for SubsetTy {
2143    type T = rustc_middle::ty::Ty<'tcx>;
2144
2145    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::Ty<'tcx> {
2146        self.bty.to_rustc(tcx)
2147    }
2148}
2149
2150#[derive(PartialEq, Clone, Eq, Hash, TyEncodable, TyDecodable)]
2151pub enum GenericArg {
2152    Ty(Ty),
2153    Base(SubsetTyCtor),
2154    Lifetime(Region),
2155    Const(Const),
2156}
2157
2158impl GenericArg {
2159    #[track_caller]
2160    pub fn expect_type(&self) -> &Ty {
2161        if let GenericArg::Ty(ty) = self {
2162            ty
2163        } else {
2164            bug!("expected `rty::GenericArg::Ty`, found `{self:?}`")
2165        }
2166    }
2167
2168    #[track_caller]
2169    pub fn expect_base(&self) -> &SubsetTyCtor {
2170        if let GenericArg::Base(ctor) = self {
2171            ctor
2172        } else {
2173            bug!("expected `rty::GenericArg::Base`, found `{self:?}`")
2174        }
2175    }
2176
2177    pub fn from_param_def(param: &GenericParamDef) -> Self {
2178        match param.kind {
2179            GenericParamDefKind::Type { .. } => {
2180                let param_ty = ParamTy { index: param.index, name: param.name };
2181                GenericArg::Ty(Ty::param(param_ty))
2182            }
2183            GenericParamDefKind::Base { .. } => {
2184                // λv. T[v]
2185                let param_ty = ParamTy { index: param.index, name: param.name };
2186                GenericArg::Base(Binder::bind_with_sort(
2187                    SubsetTy::trivial(BaseTy::Param(param_ty), Expr::nu()),
2188                    Sort::Param(param_ty),
2189                ))
2190            }
2191            GenericParamDefKind::Lifetime => {
2192                let region = EarlyParamRegion { index: param.index, name: param.name };
2193                GenericArg::Lifetime(Region::ReEarlyParam(region))
2194            }
2195            GenericParamDefKind::Const { .. } => {
2196                let param_const = ParamConst { index: param.index, name: param.name };
2197                let kind = ConstKind::Param(param_const);
2198                GenericArg::Const(Const { kind })
2199            }
2200        }
2201    }
2202
2203    /// Creates a `GenericArgs` from the definition of generic parameters, by calling a closure to
2204    /// obtain arg. The closures get to observe the `GenericArgs` as they're being built, which can
2205    /// be used to correctly replace defaults of generic parameters.
2206    pub fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk_kind: F) -> QueryResult<GenericArgs>
2207    where
2208        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2209    {
2210        let defs = genv.generics_of(def_id)?;
2211        let count = defs.count();
2212        let mut args = Vec::with_capacity(count);
2213        Self::fill_item(genv, &mut args, &defs, &mut mk_kind)?;
2214        Ok(List::from_vec(args))
2215    }
2216
2217    pub fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<GenericArgs> {
2218        Self::for_item(genv, def_id, |param, _| GenericArg::from_param_def(param))
2219    }
2220
2221    fn fill_item<F>(
2222        genv: GlobalEnv,
2223        args: &mut Vec<GenericArg>,
2224        generics: &Generics,
2225        mk_kind: &mut F,
2226    ) -> QueryResult<()>
2227    where
2228        F: FnMut(&GenericParamDef, &[GenericArg]) -> GenericArg,
2229    {
2230        if let Some(def_id) = generics.parent {
2231            let parent_generics = genv.generics_of(def_id)?;
2232            Self::fill_item(genv, args, &parent_generics, mk_kind)?;
2233        }
2234        for param in &generics.own_params {
2235            let kind = mk_kind(param, args);
2236            tracked_span_assert_eq!(param.index as usize, args.len());
2237            args.push(kind);
2238        }
2239        Ok(())
2240    }
2241}
2242
2243impl From<TyOrBase> for GenericArg {
2244    fn from(v: TyOrBase) -> Self {
2245        match v {
2246            TyOrBase::Ty(ty) => GenericArg::Ty(ty),
2247            TyOrBase::Base(ctor) => GenericArg::Base(ctor),
2248        }
2249    }
2250}
2251
2252impl<'tcx> ToRustc<'tcx> for GenericArg {
2253    type T = rustc_middle::ty::GenericArg<'tcx>;
2254
2255    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2256        use rustc_middle::ty;
2257        match self {
2258            GenericArg::Ty(ty) => ty::GenericArg::from(ty.to_rustc(tcx)),
2259            GenericArg::Base(ctor) => ty::GenericArg::from(ctor.skip_binder_ref().to_rustc(tcx)),
2260            GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)),
2261            GenericArg::Const(c) => ty::GenericArg::from(c.to_rustc(tcx)),
2262        }
2263    }
2264}
2265
2266pub type GenericArgs = List<GenericArg>;
2267
2268#[extension(pub trait GenericArgsExt)]
2269impl GenericArgs {
2270    #[track_caller]
2271    fn box_args(&self) -> (&Ty, &Ty) {
2272        if let [GenericArg::Ty(deref), GenericArg::Ty(alloc)] = &self[..] {
2273            (deref, alloc)
2274        } else {
2275            bug!("invalid generic arguments for box");
2276        }
2277    }
2278
2279    // We can't implement [`ToRustc`] because of coherence so we add it here
2280    fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::GenericArgsRef<'tcx> {
2281        tcx.mk_args_from_iter(self.iter().map(|arg| arg.to_rustc(tcx)))
2282    }
2283
2284    fn rebase_onto(
2285        &self,
2286        tcx: &TyCtxt,
2287        source_ancestor: DefId,
2288        target_args: &GenericArgs,
2289    ) -> List<GenericArg> {
2290        let defs = tcx.generics_of(source_ancestor);
2291        target_args
2292            .iter()
2293            .chain(self.iter().skip(defs.count()))
2294            .cloned()
2295            .collect()
2296    }
2297}
2298
2299#[derive(Debug)]
2300pub enum TyOrBase {
2301    Ty(Ty),
2302    Base(SubsetTyCtor),
2303}
2304
2305impl TyOrBase {
2306    pub fn into_ty(self) -> Ty {
2307        match self {
2308            TyOrBase::Ty(ty) => ty,
2309            TyOrBase::Base(ctor) => ctor.to_ty(),
2310        }
2311    }
2312
2313    #[track_caller]
2314    pub fn expect_base(self) -> SubsetTyCtor {
2315        match self {
2316            TyOrBase::Base(ctor) => ctor,
2317            TyOrBase::Ty(_) => tracked_span_bug!("expected `TyOrBase::Base`"),
2318        }
2319    }
2320
2321    pub fn as_base(self) -> Option<SubsetTyCtor> {
2322        match self {
2323            TyOrBase::Base(ctor) => Some(ctor),
2324            TyOrBase::Ty(_) => None,
2325        }
2326    }
2327}
2328
2329#[derive(Debug, Clone, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
2330pub enum TyOrCtor {
2331    Ty(Ty),
2332    Ctor(TyCtor),
2333}
2334
2335impl TyOrCtor {
2336    #[track_caller]
2337    pub fn expect_ctor(self) -> TyCtor {
2338        match self {
2339            TyOrCtor::Ctor(ctor) => ctor,
2340            TyOrCtor::Ty(_) => tracked_span_bug!("expected `TyOrCtor::Ctor`"),
2341        }
2342    }
2343
2344    pub fn expect_subset_ty_ctor(self) -> SubsetTyCtor {
2345        self.expect_ctor().map(|ty| {
2346            if let canonicalize::CanonicalTy::Constr(constr_ty) = ty.shallow_canonicalize()
2347                && let TyKind::Indexed(bty, idx) = constr_ty.ty().kind()
2348                && idx.is_nu()
2349            {
2350                SubsetTy::new(bty.clone(), Expr::nu(), constr_ty.pred())
2351            } else {
2352                tracked_span_bug!()
2353            }
2354        })
2355    }
2356
2357    pub fn to_ty(&self) -> Ty {
2358        match self {
2359            TyOrCtor::Ctor(ctor) => ctor.to_ty(),
2360            TyOrCtor::Ty(ty) => ty.clone(),
2361        }
2362    }
2363}
2364
2365impl From<TyOrBase> for TyOrCtor {
2366    fn from(v: TyOrBase) -> Self {
2367        match v {
2368            TyOrBase::Ty(ty) => TyOrCtor::Ty(ty),
2369            TyOrBase::Base(ctor) => TyOrCtor::Ctor(ctor.to_ty_ctor()),
2370        }
2371    }
2372}
2373
2374impl CoroutineObligPredicate {
2375    pub fn to_poly_fn_sig(&self) -> PolyFnSig {
2376        let vars = vec![];
2377
2378        let resume_ty = &self.resume_ty;
2379        let env_ty = Ty::coroutine(self.def_id, resume_ty.clone(), self.upvar_tys.clone());
2380
2381        let inputs = List::from_arr([env_ty, resume_ty.clone()]);
2382        let output =
2383            Binder::bind_with_vars(FnOutput::new(self.output.clone(), vec![]), List::empty());
2384
2385        PolyFnSig::bind_with_vars(
2386            FnSig::new(Safety::Safe, rustc_abi::ExternAbi::RustCall, List::empty(), inputs, output),
2387            List::from(vars),
2388        )
2389    }
2390}
2391
2392impl RefinementGenerics {
2393    pub fn count(&self) -> usize {
2394        self.parent_count + self.own_params.len()
2395    }
2396
2397    pub fn own_count(&self) -> usize {
2398        self.own_params.len()
2399    }
2400
2401    // /// Iterate and collect all parameters in this item including parents
2402    // pub fn collect_all_params<T, S>(
2403    //     &self,
2404    //     genv: GlobalEnv,
2405    //     mut f: impl FnMut(RefineParam) -> T,
2406    // ) -> QueryResult<S>
2407    // where
2408    //     S: FromIterator<T>,
2409    // {
2410    //     (0..self.count())
2411    //         .map(|i| Ok(f(self.param_at(i, genv)?)))
2412    //         .try_collect()
2413    // }
2414}
2415
2416impl EarlyBinder<RefinementGenerics> {
2417    pub fn parent(&self) -> Option<DefId> {
2418        self.skip_binder_ref().parent
2419    }
2420
2421    pub fn parent_count(&self) -> usize {
2422        self.skip_binder_ref().parent_count
2423    }
2424
2425    pub fn count(&self) -> usize {
2426        self.skip_binder_ref().count()
2427    }
2428
2429    pub fn own_count(&self) -> usize {
2430        self.skip_binder_ref().own_count()
2431    }
2432
2433    pub fn own_param_at(&self, index: usize) -> EarlyBinder<RefineParam> {
2434        self.as_ref().map(|this| this.own_params[index].clone())
2435    }
2436
2437    pub fn param_at(
2438        &self,
2439        param_index: usize,
2440        genv: GlobalEnv,
2441    ) -> QueryResult<EarlyBinder<RefineParam>> {
2442        if let Some(index) = param_index.checked_sub(self.parent_count()) {
2443            Ok(self.own_param_at(index))
2444        } else {
2445            let parent = self.parent().expect("parent_count > 0 but no parent?");
2446            genv.refinement_generics_of(parent)?
2447                .param_at(param_index, genv)
2448        }
2449    }
2450
2451    pub fn iter_own_params(&self) -> impl Iterator<Item = EarlyBinder<RefineParam>> + use<'_> {
2452        self.skip_binder_ref()
2453            .own_params
2454            .iter()
2455            .cloned()
2456            .map(EarlyBinder)
2457    }
2458
2459    pub fn fill_item<F, R>(&self, genv: GlobalEnv, vec: &mut Vec<R>, mk: &mut F) -> QueryResult
2460    where
2461        F: FnMut(EarlyBinder<RefineParam>, usize) -> QueryResult<R>,
2462    {
2463        if let Some(def_id) = self.parent() {
2464            genv.refinement_generics_of(def_id)?
2465                .fill_item(genv, vec, mk)?;
2466        }
2467        for param in self.iter_own_params() {
2468            vec.push(mk(param, vec.len())?);
2469        }
2470        Ok(())
2471    }
2472}
2473
2474impl EarlyBinder<GenericPredicates> {
2475    pub fn predicates(&self) -> EarlyBinder<List<Clause>> {
2476        EarlyBinder(self.0.predicates.clone())
2477    }
2478}
2479
2480impl EarlyBinder<FuncSort> {
2481    /// See [`subst::GenericsSubstForSort`]
2482    pub fn instantiate_func_sort<E>(
2483        self,
2484        sort_for_param: impl FnMut(ParamTy) -> Result<Sort, E>,
2485    ) -> Result<FuncSort, E> {
2486        self.0.try_fold_with(&mut subst::GenericsSubstFolder::new(
2487            subst::GenericsSubstForSort { sort_for_param },
2488            &[],
2489        ))
2490    }
2491}
2492
2493impl VariantSig {
2494    pub fn new(
2495        adt_def: AdtDef,
2496        args: GenericArgs,
2497        fields: List<Ty>,
2498        idx: Expr,
2499        requires: List<Expr>,
2500    ) -> Self {
2501        VariantSig { adt_def, args, fields, idx, requires }
2502    }
2503
2504    pub fn fields(&self) -> &[Ty] {
2505        &self.fields
2506    }
2507
2508    pub fn ret(&self) -> Ty {
2509        let bty = BaseTy::Adt(self.adt_def.clone(), self.args.clone());
2510        let idx = self.idx.clone();
2511        Ty::indexed(bty, idx)
2512    }
2513}
2514
2515impl FnSig {
2516    pub fn new(
2517        safety: Safety,
2518        abi: rustc_abi::ExternAbi,
2519        requires: List<Expr>,
2520        inputs: List<Ty>,
2521        output: Binder<FnOutput>,
2522    ) -> Self {
2523        FnSig { safety, abi, requires, inputs, output }
2524    }
2525
2526    pub fn requires(&self) -> &[Expr] {
2527        &self.requires
2528    }
2529
2530    pub fn inputs(&self) -> &[Ty] {
2531        &self.inputs
2532    }
2533
2534    pub fn output(&self) -> Binder<FnOutput> {
2535        self.output.clone()
2536    }
2537}
2538
2539impl<'tcx> ToRustc<'tcx> for FnSig {
2540    type T = rustc_middle::ty::FnSig<'tcx>;
2541
2542    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2543        tcx.mk_fn_sig(
2544            self.inputs().iter().map(|ty| ty.to_rustc(tcx)),
2545            self.output().as_ref().skip_binder().to_rustc(tcx),
2546            false,
2547            self.safety,
2548            self.abi,
2549        )
2550    }
2551}
2552
2553impl FnOutput {
2554    pub fn new(ret: Ty, ensures: impl Into<List<Ensures>>) -> Self {
2555        Self { ret, ensures: ensures.into() }
2556    }
2557}
2558
2559impl<'tcx> ToRustc<'tcx> for FnOutput {
2560    type T = rustc_middle::ty::Ty<'tcx>;
2561
2562    fn to_rustc(&self, tcx: TyCtxt<'tcx>) -> Self::T {
2563        self.ret.to_rustc(tcx)
2564    }
2565}
2566
2567impl AdtDef {
2568    pub fn new(
2569        rustc: ty::AdtDef,
2570        sort_def: AdtSortDef,
2571        invariants: Vec<Invariant>,
2572        opaque: bool,
2573    ) -> Self {
2574        AdtDef(Interned::new(AdtDefData { invariants, sort_def, opaque, rustc }))
2575    }
2576
2577    pub fn did(&self) -> DefId {
2578        self.0.rustc.did()
2579    }
2580
2581    pub fn sort_def(&self) -> &AdtSortDef {
2582        &self.0.sort_def
2583    }
2584
2585    pub fn sort(&self, args: &[GenericArg]) -> Sort {
2586        self.sort_def().to_sort(args)
2587    }
2588
2589    pub fn is_box(&self) -> bool {
2590        self.0.rustc.is_box()
2591    }
2592
2593    pub fn is_enum(&self) -> bool {
2594        self.0.rustc.is_enum()
2595    }
2596
2597    pub fn is_struct(&self) -> bool {
2598        self.0.rustc.is_struct()
2599    }
2600
2601    pub fn is_union(&self) -> bool {
2602        self.0.rustc.is_union()
2603    }
2604
2605    pub fn variants(&self) -> &IndexSlice<VariantIdx, VariantDef> {
2606        self.0.rustc.variants()
2607    }
2608
2609    pub fn variant(&self, idx: VariantIdx) -> &VariantDef {
2610        self.0.rustc.variant(idx)
2611    }
2612
2613    pub fn invariants(&self) -> EarlyBinder<&[Invariant]> {
2614        EarlyBinder(&self.0.invariants)
2615    }
2616
2617    pub fn discriminants(&self) -> impl Iterator<Item = (VariantIdx, u128)> + '_ {
2618        self.0.rustc.discriminants()
2619    }
2620
2621    pub fn is_opaque(&self) -> bool {
2622        self.0.opaque
2623    }
2624}
2625
2626impl EarlyBinder<PolyVariant> {
2627    // The field_idx is `Some(i)` when we have the `i`-th field of a `union`, in which case,
2628    // the `inputs` are _just_ the `i`-th type (and not all the types...)
2629    pub fn to_poly_fn_sig(&self, field_idx: Option<crate::FieldIdx>) -> EarlyBinder<PolyFnSig> {
2630        self.as_ref().map(|poly_variant| {
2631            poly_variant.as_ref().map(|variant| {
2632                let ret = variant.ret().shift_in_escaping(1);
2633                let output = Binder::bind_with_vars(FnOutput::new(ret, vec![]), List::empty());
2634                let inputs = match field_idx {
2635                    None => variant.fields.clone(),
2636                    Some(i) => List::singleton(variant.fields[i.index()].clone()),
2637                };
2638                FnSig::new(
2639                    Safety::Safe,
2640                    rustc_abi::ExternAbi::Rust,
2641                    variant.requires.clone(),
2642                    inputs,
2643                    output,
2644                )
2645            })
2646        })
2647    }
2648}
2649
2650impl TyKind {
2651    fn intern(self) -> Ty {
2652        Ty(Interned::new(self))
2653    }
2654}
2655
2656/// returns the same invariants as for `usize` which is the length of a slice
2657fn slice_invariants(overflow_checking: bool) -> &'static [Invariant] {
2658    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2659        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2660    });
2661    static OVERFLOW: LazyLock<[Invariant; 2]> = LazyLock::new(|| {
2662        [
2663            Invariant {
2664                pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2665            },
2666            Invariant {
2667                pred: Binder::bind_with_sort(
2668                    Expr::le(Expr::nu(), Expr::uint_max(UintTy::Usize)),
2669                    Sort::Int,
2670                ),
2671            },
2672        ]
2673    });
2674    if overflow_checking { &*OVERFLOW } else { &*DEFAULT }
2675}
2676
2677fn uint_invariants(uint_ty: UintTy, overflow_checking: bool) -> &'static [Invariant] {
2678    static DEFAULT: LazyLock<[Invariant; 1]> = LazyLock::new(|| {
2679        [Invariant { pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int) }]
2680    });
2681
2682    static OVERFLOW: LazyLock<UnordMap<UintTy, [Invariant; 2]>> = LazyLock::new(|| {
2683        UINT_TYS
2684            .into_iter()
2685            .map(|uint_ty| {
2686                let invariants = [
2687                    Invariant {
2688                        pred: Binder::bind_with_sort(Expr::ge(Expr::nu(), Expr::zero()), Sort::Int),
2689                    },
2690                    Invariant {
2691                        pred: Binder::bind_with_sort(
2692                            Expr::le(Expr::nu(), Expr::uint_max(uint_ty)),
2693                            Sort::Int,
2694                        ),
2695                    },
2696                ];
2697                (uint_ty, invariants)
2698            })
2699            .collect()
2700    });
2701    if overflow_checking { &OVERFLOW[&uint_ty] } else { &*DEFAULT }
2702}
2703
2704fn int_invariants(int_ty: IntTy, overflow_checking: bool) -> &'static [Invariant] {
2705    static DEFAULT: [Invariant; 0] = [];
2706
2707    static OVERFLOW: LazyLock<UnordMap<IntTy, [Invariant; 2]>> = LazyLock::new(|| {
2708        INT_TYS
2709            .into_iter()
2710            .map(|int_ty| {
2711                let invariants = [
2712                    Invariant {
2713                        pred: Binder::bind_with_sort(
2714                            Expr::ge(Expr::nu(), Expr::int_min(int_ty)),
2715                            Sort::Int,
2716                        ),
2717                    },
2718                    Invariant {
2719                        pred: Binder::bind_with_sort(
2720                            Expr::le(Expr::nu(), Expr::int_max(int_ty)),
2721                            Sort::Int,
2722                        ),
2723                    },
2724                ];
2725                (int_ty, invariants)
2726            })
2727            .collect()
2728    });
2729    if overflow_checking { &OVERFLOW[&int_ty] } else { &DEFAULT }
2730}
2731
// Interning support for the types of this module (per the module docs, data structures here are
// interned so they can be cheaply cloned). The first invocation lists types interned as single
// values; the second lists element types that are interned as slices.
// NOTE(review): the exact impls generated are defined by the `flux_arc_interner` macros — confirm
// there before relying on this description.
impl_internable!(AdtDefData, AdtSortDefData, TyKind);
impl_slice_internable!(
    Ty,
    GenericArg,
    Ensures,
    InferMode,
    Sort,
    GenericParamDef,
    TraitRef,
    Binder<ExistentialPredicate>,
    Clause,
    PolyVariant,
    Invariant,
    RefineParam,
    FluxDefId,
    SortParamKind,
    AssocReft
);
2750
/// Pattern for a `TyKind::Indexed` whose base type is `BaseTy::Int`.
///
/// `$int_ty` is matched against the payload of `BaseTy::Int` and `$idxs` against the second
/// field of `TyKind::Indexed`. Re-exported from this module as `Int!`.
///
/// The expansion names both types through `$crate::rty`, so the macro can be used from any
/// module or crate without first importing `TyKind` and `BaseTy`.
#[macro_export]
macro_rules! _Int {
    ($int_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Int($int_ty), $idxs)
    };
}
pub use crate::_Int as Int;
2758
/// Pattern for a `TyKind::Indexed` whose base type is `BaseTy::Uint`.
///
/// `$uint_ty` is matched against the payload of `BaseTy::Uint` and `$idxs` against the second
/// field of `TyKind::Indexed`. Re-exported from this module as `Uint!`.
///
/// The expansion names both types through `$crate::rty`, so the macro can be used from any
/// module or crate without first importing `TyKind` and `BaseTy`.
#[macro_export]
macro_rules! _Uint {
    ($uint_ty:pat, $idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Uint($uint_ty), $idxs)
    };
}
pub use crate::_Uint as Uint;
2766
/// Pattern for a `TyKind::Indexed` whose base type is `BaseTy::Bool`.
///
/// `$idxs` is matched against the second field of `TyKind::Indexed`. Re-exported from this
/// module as `Bool!`.
///
/// The expansion names both types through `$crate::rty`, so the macro can be used from any
/// module or crate without first importing `TyKind` and `BaseTy`.
#[macro_export]
macro_rules! _Bool {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Bool, $idxs)
    };
}
pub use crate::_Bool as Bool;
2774
/// Pattern for a `TyKind::Indexed` whose base type is `BaseTy::Char`.
///
/// `$idxs` is matched against the second field of `TyKind::Indexed`. Re-exported from this
/// module as `Char!`.
///
/// The expansion names both types through `$crate::rty`, so the macro can be used from any
/// module or crate without first importing `TyKind` and `BaseTy`.
#[macro_export]
macro_rules! _Char {
    ($idxs:pat) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Char, $idxs)
    };
}
pub use crate::_Char as Char;
2782
/// Pattern for a `TyKind::Indexed` whose base type is `BaseTy::Ref`, with any index.
///
/// The comma-separated patterns (a trailing comma is accepted) are matched against the fields
/// of `BaseTy::Ref`; the second field of `TyKind::Indexed` is always matched with `_`.
/// Re-exported from this module as `Ref!`.
#[macro_export]
macro_rules! _Ref {
    ($($args:pat),+ $(,)?) => {
        $crate::rty::TyKind::Indexed($crate::rty::BaseTy::Ref($($args),+), _)
    };
}
pub use crate::_Ref as Ref;
2790
/// Side tables, keyed by [`fhir::ItemLocalId`], collected for a single owner.
///
/// The tables are private; they are read and written through the accessor methods of this type,
/// which pair each table with `owner`.
// NOTE(review): the name suggests these are the results of well-formedness checking — confirm.
pub struct WfckResults {
    /// Owner that the entries of every table below belong to.
    pub owner: FluxOwnerId,
    /// A [`Sort`] per node; presumably the sort at which a binary operation is used — TODO confirm.
    bin_op_sorts: ItemLocalMap<Sort>,
    /// A list of [`Coercion`]s per node.
    coercions: ItemLocalMap<Vec<Coercion>>,
    /// A [`FieldProj`] per node; presumably the resolved target of a field access — TODO confirm.
    field_projs: ItemLocalMap<FieldProj>,
    /// A [`Sort`] per node.
    node_sorts: ItemLocalMap<Sort>,
    /// A [`DefId`] per node; presumably the definition a record constructor refers to — TODO confirm.
    record_ctors: ItemLocalMap<DefId>,
}
2799
/// A coercion stored in the `coercions` table of [`WfckResults`]. Both variants carry a [`DefId`].
// NOTE(review): presumably `Inject` wraps a value into the definition identified by the `DefId`
// and `Project` is the inverse — confirm against the code that records and consumes them.
#[derive(Clone, Copy, Debug)]
pub enum Coercion {
    Inject(DefId),
    Project(DefId),
}
2805
/// An [`UnordMap`] keyed by [`fhir::ItemLocalId`].
pub type ItemLocalMap<T> = UnordMap<fhir::ItemLocalId, T>;
2807
/// Shared view of an [`ItemLocalMap`] together with the owner it belongs to.
///
/// Lookups take a full `FhirId` and assert that its owner equals `owner` before consulting the
/// map.
#[derive(Debug)]
pub struct LocalTableInContext<'a, T> {
    /// Owner every looked-up id is expected to have.
    owner: FluxOwnerId,
    /// The underlying table, keyed by the local part of the id.
    data: &'a ItemLocalMap<T>,
}
2813
/// Mutable view of an [`ItemLocalMap`] together with the owner it belongs to.
///
/// Insertions take a full `FhirId` and assert that its owner equals `owner` before writing to
/// the map.
pub struct LocalTableInContextMut<'a, T> {
    /// Owner every inserted id is expected to have.
    owner: FluxOwnerId,
    /// The underlying table, keyed by the local part of the id.
    data: &'a mut ItemLocalMap<T>,
}
2818
2819impl WfckResults {
2820    pub fn new(owner: impl Into<FluxOwnerId>) -> Self {
2821        Self {
2822            owner: owner.into(),
2823            bin_op_sorts: ItemLocalMap::default(),
2824            coercions: ItemLocalMap::default(),
2825            field_projs: ItemLocalMap::default(),
2826            node_sorts: ItemLocalMap::default(),
2827            record_ctors: ItemLocalMap::default(),
2828        }
2829    }
2830
2831    pub fn bin_op_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
2832        LocalTableInContextMut { owner: self.owner, data: &mut self.bin_op_sorts }
2833    }
2834
2835    pub fn bin_op_sorts(&self) -> LocalTableInContext<'_, Sort> {
2836        LocalTableInContext { owner: self.owner, data: &self.bin_op_sorts }
2837    }
2838
2839    pub fn coercions_mut(&mut self) -> LocalTableInContextMut<'_, Vec<Coercion>> {
2840        LocalTableInContextMut { owner: self.owner, data: &mut self.coercions }
2841    }
2842
2843    pub fn coercions(&self) -> LocalTableInContext<'_, Vec<Coercion>> {
2844        LocalTableInContext { owner: self.owner, data: &self.coercions }
2845    }
2846
2847    pub fn field_projs_mut(&mut self) -> LocalTableInContextMut<'_, FieldProj> {
2848        LocalTableInContextMut { owner: self.owner, data: &mut self.field_projs }
2849    }
2850
2851    pub fn field_projs(&self) -> LocalTableInContext<'_, FieldProj> {
2852        LocalTableInContext { owner: self.owner, data: &self.field_projs }
2853    }
2854
2855    pub fn node_sorts_mut(&mut self) -> LocalTableInContextMut<'_, Sort> {
2856        LocalTableInContextMut { owner: self.owner, data: &mut self.node_sorts }
2857    }
2858
2859    pub fn node_sorts(&self) -> LocalTableInContext<'_, Sort> {
2860        LocalTableInContext { owner: self.owner, data: &self.node_sorts }
2861    }
2862
2863    pub fn record_ctors_mut(&mut self) -> LocalTableInContextMut<'_, DefId> {
2864        LocalTableInContextMut { owner: self.owner, data: &mut self.record_ctors }
2865    }
2866
2867    pub fn record_ctors(&self) -> LocalTableInContext<'_, DefId> {
2868        LocalTableInContext { owner: self.owner, data: &self.record_ctors }
2869    }
2870}
2871
2872impl<T> LocalTableInContextMut<'_, T> {
2873    pub fn insert(&mut self, fhir_id: FhirId, value: T) {
2874        tracked_span_assert_eq!(self.owner, fhir_id.owner);
2875        self.data.insert(fhir_id.local_id, value);
2876    }
2877}
2878
2879impl<'a, T> LocalTableInContext<'a, T> {
2880    pub fn get(&self, fhir_id: FhirId) -> Option<&'a T> {
2881        tracked_span_assert_eq!(self.owner, fhir_id.owner);
2882        self.data.get(&fhir_id.local_id)
2883    }
2884}