use std::{fmt, iter, sync::OnceLock};
use flux_arc_interner::{impl_internable, impl_slice_internable, Interned, List};
use flux_common::bug;
use flux_macros::{TypeFoldable, TypeVisitable};
use flux_rustc_bridge::{
const_eval::{scalar_to_bits, scalar_to_int, scalar_to_uint},
ty::{Const, ConstKind, ValTree},
ToRustc,
};
use itertools::Itertools;
use rustc_hir::def_id::DefId;
use rustc_index::newtype_index;
use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable};
use rustc_middle::{
mir::Local,
ty::{ParamConst, ScalarInt, TyCtxt},
};
use rustc_span::{Span, Symbol};
use rustc_target::abi::FieldIdx;
use rustc_type_ir::{BoundVar, DebruijnIndex, INNERMOST};
use super::{
evars::EVar, BaseTy, Binder, BoundReftKind, BoundVariableKinds, FuncSort, GenericArgs,
GenericArgsExt as _, IntTy, Sort, UintTy,
};
use crate::{
big_int::BigInt,
fhir::SpecFuncKind,
global_env::GlobalEnv,
queries::QueryResult,
rty::{
fold::{TypeFoldable, TypeFolder, TypeSuperFoldable},
BoundVariableKind, SortCtor,
},
};
/// A refinement-level lambda abstraction: `body` is bound over the lambda's inputs and
/// `output` is the sort of the value it produces (see [`Lambda::fsort`]).
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct Lambda {
/// The body, under a binder whose bound variables are the lambda's parameters.
body: Binder<Expr>,
/// Sort of the result of `body`.
output: Sort,
}
impl Lambda {
pub fn bind_with_vars(body: Expr, inputs: BoundVariableKinds, output: Sort) -> Self {
debug_assert!(inputs.iter().all(BoundVariableKind::is_refine));
Self { body: Binder::bind_with_vars(body, inputs), output }
}
pub fn bind_with_fsort(body: Expr, fsort: FuncSort) -> Self {
Self { body: Binder::bind_with_sorts(body, fsort.inputs()), output: fsort.output().clone() }
}
pub fn apply(&self, args: &[Expr]) -> Expr {
self.body.replace_bound_refts(args)
}
pub fn vars(&self) -> &BoundVariableKinds {
self.body.vars()
}
pub fn output(&self) -> Sort {
self.output.clone()
}
pub fn fsort(&self) -> FuncSort {
let inputs_and_output = self
.vars()
.iter()
.map(|kind| kind.expect_sort().clone())
.chain(iter::once(self.output()))
.collect();
FuncSort { inputs_and_output }
}
}
/// An associated refinement `name` declared by trait `trait_id`, instantiated with `args`.
/// When pretty-printed, `args[0]` is shown in the self position
/// (`<(args[0]) as Trait<args[1..]>>::name`).
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct AliasReft {
/// The trait that declares the associated refinement.
pub trait_id: DefId,
/// Name of the associated refinement within the trait.
pub name: Symbol,
/// Generic arguments; may contain more than the trait's own generics
/// (`to_rustc_trait_ref` truncates them to the trait's generics).
pub args: GenericArgs,
}
impl AliasReft {
    /// Lowers this alias to a rustc `TraitRef`, keeping only the arguments that belong
    /// to the trait's own generics.
    pub fn to_rustc_trait_ref<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::TraitRef<'tcx> {
        let trait_generics_args = self
            .args
            .to_rustc(tcx)
            .truncate_to(tcx, tcx.generics_of(self.trait_id));
        rustc_middle::ty::TraitRef::new(tcx, self.trait_id, trait_generics_args)
    }
}
/// A refinement expression: an interned [`ExprKind`] plus an optional source span.
///
/// Note that `espan` participates in the derived `PartialEq`/`Hash`, so two structurally
/// identical expressions carrying different spans compare unequal.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct Expr {
/// The interned shape of the expression.
kind: Interned<ExprKind>,
/// Source location used for diagnostics, if known.
espan: Option<ESpan>,
}
/// Source span of an expression.
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
pub struct ESpan {
/// Primary span of the expression.
pub span: Span,
/// Optional secondary span, set by `ESpan::with_base` / `Expr::at_base` to the primary
/// span of another `ESpan`.
pub base: Option<Span>,
}
impl ESpan {
pub fn new(span: Span) -> Self {
Self { span, base: None }
}
pub fn with_base(&self, espan: ESpan) -> Self {
Self { span: self.span, base: Some(espan.span) }
}
}
/// Binary operators of refinement expressions.
///
/// The comparison operators carry a [`Sort`] (presumably the sort at which the comparison
/// is performed; `Expr::gt`/`ge`/`lt`/`le` always use `Sort::Int`). Constant folding in
/// `Expr::simplify` only evaluates comparisons at `Sort::Int`.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable)]
pub enum BinOp {
/// Logical equivalence (`⇔`).
Iff,
/// Logical implication (`⇒`).
Imp,
Or,
And,
Eq,
Ne,
Gt(Sort),
Ge(Sort),
Lt(Sort),
Le(Sort),
Add,
Sub,
Mul,
Div,
/// Printed as `mod`.
Mod,
}
/// Unary operators of refinement expressions.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Encodable, Decodable)]
pub enum UnOp {
/// Logical negation (printed `¬`).
Not,
/// Arithmetic negation (printed `-`).
Neg,
}
/// The shape of a refinement expression; interned inside [`Expr`].
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum ExprKind {
/// A variable (free, bound, early-bound parameter, evar, or const generic).
Var(Var),
/// A MIR local.
Local(Local),
/// A literal constant.
Constant(Constant),
/// A constant item referenced by its `DefId`.
ConstDefId(DefId),
/// A binary operation `e1 op e2`.
BinaryOp(BinOp, Expr, Expr),
/// A global function referenced by name.
GlobalFunc(Symbol, SpecFuncKind),
/// A unary operation.
UnaryOp(UnOp, Expr),
/// Projection of a field out of a tuple- or ADT-sorted expression.
FieldProj(Expr, FieldProj),
/// Construction of a tuple or ADT value from its fields.
Aggregate(AggregateKind, List<Expr>),
/// A field projection on a place; this is what `Path::to_expr` builds.
PathProj(Expr, FieldIdx),
/// `if p { e1 } else { e2 }`.
IfThenElse(Expr, Expr, Expr),
/// A kvar application (presumably an unknown predicate to be inferred).
KVar(KVar),
/// Application of an associated refinement to arguments.
Alias(AliasReft, List<Expr>),
/// Application of a function-valued expression to arguments.
App(Expr, List<Expr>),
/// A lambda abstraction.
Abs(Lambda),
/// A placeholder (printed as `*`).
Hole(HoleKind),
/// Universal quantification over the binder's variables.
ForAll(Binder<Expr>),
}
/// What an `ExprKind::Aggregate` constructs.
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
pub enum AggregateKind {
/// A tuple with the given number of fields.
Tuple(usize),
/// A value of the ADT sort identified by the `DefId`.
Adt(DefId),
}
impl AggregateKind {
    /// Builds the projection that selects field number `field` from an aggregate of
    /// this kind.
    pub fn to_proj(self, field: u32) -> FieldProj {
        match self {
            Self::Adt(def_id) => FieldProj::Adt { def_id, field },
            Self::Tuple(arity) => FieldProj::Tuple { arity, field },
        }
    }
}
/// A field projection out of a tuple or ADT sort.
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)]
pub enum FieldProj {
/// Field `field` of a tuple with `arity` fields.
Tuple { arity: usize, field: u32 },
/// Field `field` of the ADT sort `def_id`; its arity is looked up on demand.
Adt { def_id: DefId, field: u32 },
}
impl FieldProj {
    /// Number of fields in the projected aggregate.
    ///
    /// Tuple arities are stored inline; ADT arities come from the ADT's sort definition,
    /// and any query error is propagated.
    pub fn arity(&self, genv: GlobalEnv) -> QueryResult<usize> {
        let arity = match *self {
            FieldProj::Tuple { arity, .. } => arity,
            FieldProj::Adt { def_id, .. } => genv.adt_sort_def_of(def_id)?.fields(),
        };
        Ok(arity)
    }

    /// Index of the projected field.
    pub fn field_idx(&self) -> u32 {
        let (FieldProj::Tuple { field, .. } | FieldProj::Adt { field, .. }) = *self;
        field
    }
}
/// The kind of an `ExprKind::Hole` placeholder.
#[derive(
Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable,
)]
pub enum HoleKind {
/// A hole standing for a predicate.
Pred,
/// A hole standing for an expression of the given sort.
Expr(Sort),
}
/// Application of the kvar `kvid` to `args`.
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub struct KVar {
pub kvid: KVid,
/// Number of leading entries of `args` that are the kvar's own arguments; the
/// remaining entries are its scope (see `KVar::self_args` / `KVar::scope`).
pub self_args: usize,
pub args: List<Expr>,
}
/// An early-bound refinement parameter, identified by position and printed by name.
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Encodable, Decodable)]
pub struct EarlyReftParam {
pub index: u32,
pub name: Symbol,
}
/// A refinement variable bound by a binder: its position in the binder plus its kind.
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Encodable, Decodable, Debug)]
pub struct BoundReft {
pub var: BoundVar,
pub kind: BoundReftKind,
}
/// Variables that may appear in refinement expressions.
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, TyEncodable, TyDecodable)]
pub enum Var {
/// A free variable identified by a [`Name`].
Free(Name),
/// A variable bound by the binder `DebruijnIndex` levels out.
Bound(DebruijnIndex, BoundReft),
/// An early-bound refinement parameter.
EarlyParam(EarlyReftParam),
/// An evar (see the `evars` module).
EVar(EVar),
/// A const generic parameter.
ConstGeneric(ParamConst),
}
/// A place: a base location followed by a sequence of field projections.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, TyEncodable, TyDecodable)]
pub struct Path {
pub loc: Loc,
/// Field indices applied in order to `loc`.
projection: List<FieldIdx>,
}
/// The base of a [`Path`]: a MIR local or a refinement variable.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, TyEncodable, TyDecodable)]
pub enum Loc {
Local(Local),
Var(Var),
}
newtype_index! {
/// Identifier of a kvar (debug-printed as `$k<n>`).
#[debug_format = "$k{}"]
#[encodable]
pub struct KVid {}
}
newtype_index! {
/// Name of a free variable (`Var::Free`), debug-printed as `a<n>`.
#[debug_format = "a{}"]
#[orderable]
#[encodable]
pub struct Name {}
}
impl ExprKind {
    /// Interns this kind into a fresh [`Expr`] that carries no span.
    fn intern(self) -> Expr {
        let kind = Interned::new(self);
        Expr { espan: None, kind }
    }
}
impl Expr {
pub fn at_opt(self, espan: Option<ESpan>) -> Expr {
Expr { kind: self.kind, espan }
}
pub fn at(self, espan: ESpan) -> Expr {
self.at_opt(Some(espan))
}
pub fn at_base(self, base: ESpan) -> Expr {
if let Some(espan) = self.espan {
self.at(espan.with_base(base))
} else {
self
}
}
pub fn span(&self) -> Option<ESpan> {
self.espan
}
pub fn tt() -> Expr {
static TRUE: OnceLock<Expr> = OnceLock::new();
TRUE.get_or_init(|| ExprKind::Constant(Constant::Bool(true)).intern())
.clone()
}
pub fn ff() -> Expr {
static FALSE: OnceLock<Expr> = OnceLock::new();
FALSE
.get_or_init(|| ExprKind::Constant(Constant::Bool(false)).intern())
.clone()
}
pub fn and_from_iter(exprs: impl IntoIterator<Item = Expr>) -> Expr {
exprs
.into_iter()
.reduce(|acc, e| Expr::binary_op(BinOp::And, acc, e))
.unwrap_or_else(Expr::tt)
}
pub fn or_from_iter(exprs: impl IntoIterator<Item = Expr>) -> Expr {
exprs
.into_iter()
.reduce(|acc, e| Expr::binary_op(BinOp::Or, acc, e))
.unwrap_or_else(Expr::ff)
}
pub fn and(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
Expr::and_from_iter([e1.into(), e2.into()])
}
pub fn or(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
Expr::or_from_iter([e1.into(), e2.into()])
}
pub fn zero() -> Expr {
static ZERO: OnceLock<Expr> = OnceLock::new();
ZERO.get_or_init(|| ExprKind::Constant(Constant::ZERO).intern())
.clone()
}
pub fn int_max(int_ty: IntTy) -> Expr {
let bit_width: u64 = int_ty
.bit_width()
.unwrap_or(flux_config::pointer_width().bits());
Expr::constant(Constant::int_max(bit_width.try_into().unwrap()))
}
pub fn int_min(int_ty: IntTy) -> Expr {
let bit_width: u64 = int_ty
.bit_width()
.unwrap_or(flux_config::pointer_width().bits());
Expr::constant(Constant::int_min(bit_width.try_into().unwrap()))
}
pub fn uint_max(uint_ty: UintTy) -> Expr {
let bit_width: u64 = uint_ty
.bit_width()
.unwrap_or(flux_config::pointer_width().bits());
Expr::constant(Constant::uint_max(bit_width.try_into().unwrap()))
}
pub fn nu() -> Expr {
Expr::bvar(INNERMOST, BoundVar::ZERO, BoundReftKind::Annon)
}
pub fn is_nu(&self) -> bool {
if let ExprKind::Var(Var::Bound(INNERMOST, var)) = self.kind()
&& var.var == BoundVar::ZERO
{
true
} else {
false
}
}
#[track_caller]
pub fn expect_adt(&self) -> (DefId, List<Expr>) {
if let ExprKind::Aggregate(AggregateKind::Adt(def_id), flds) = self.kind() {
(*def_id, flds.clone())
} else {
bug!("expected record, found {self:?}")
}
}
pub fn unit() -> Expr {
Expr::tuple(List::empty())
}
pub fn var(var: Var) -> Expr {
ExprKind::Var(var).intern()
}
pub fn fvar(name: Name) -> Expr {
Var::Free(name).to_expr()
}
pub fn evar(evar: EVar) -> Expr {
Var::EVar(evar).to_expr()
}
pub fn bvar(debruijn: DebruijnIndex, var: BoundVar, kind: BoundReftKind) -> Expr {
Var::Bound(debruijn, BoundReft { var, kind }).to_expr()
}
pub fn early_param(index: u32, name: Symbol) -> Expr {
Var::EarlyParam(EarlyReftParam { index, name }).to_expr()
}
pub fn local(local: Local) -> Expr {
ExprKind::Local(local).intern()
}
pub fn constant(c: Constant) -> Expr {
ExprKind::Constant(c).intern()
}
pub fn const_def_id(c: DefId) -> Expr {
ExprKind::ConstDefId(c).intern()
}
pub fn const_generic(param: ParamConst) -> Expr {
ExprKind::Var(Var::ConstGeneric(param)).intern()
}
pub fn aggregate(kind: AggregateKind, flds: List<Expr>) -> Expr {
ExprKind::Aggregate(kind, flds).intern()
}
pub fn tuple(flds: List<Expr>) -> Expr {
Expr::aggregate(AggregateKind::Tuple(flds.len()), flds)
}
pub fn adt(def_id: DefId, flds: List<Expr>) -> Expr {
ExprKind::Aggregate(AggregateKind::Adt(def_id), flds).intern()
}
pub fn from_bits(bty: &BaseTy, bits: u128) -> Expr {
match bty {
BaseTy::Int(_) => {
let bits = bits as i128;
ExprKind::Constant(Constant::from(bits)).intern()
}
BaseTy::Uint(_) => ExprKind::Constant(Constant::from(bits)).intern(),
BaseTy::Bool => ExprKind::Constant(Constant::Bool(bits != 0)).intern(),
BaseTy::Char => {
let c = char::from_u32(bits.try_into().unwrap()).unwrap();
ExprKind::Constant(Constant::Char(c)).intern()
}
_ => bug!(),
}
}
pub fn ite(p: impl Into<Expr>, e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::IfThenElse(p.into(), e1.into(), e2.into()).intern()
}
pub fn abs(lam: Lambda) -> Expr {
ExprKind::Abs(lam).intern()
}
pub fn hole(kind: HoleKind) -> Expr {
ExprKind::Hole(kind).intern()
}
pub fn kvar(kvar: KVar) -> Expr {
ExprKind::KVar(kvar).intern()
}
pub fn alias(alias: AliasReft, args: List<Expr>) -> Expr {
ExprKind::Alias(alias, args).intern()
}
pub fn forall(expr: Binder<Expr>) -> Expr {
ExprKind::ForAll(expr).intern()
}
pub fn binary_op(op: BinOp, e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::BinaryOp(op, e1.into(), e2.into()).intern()
}
pub fn unit_adt(def_id: DefId) -> Expr {
Expr::adt(def_id, List::empty())
}
pub fn app(func: impl Into<Expr>, args: List<Expr>) -> Expr {
ExprKind::App(func.into(), args).intern()
}
pub fn global_func(func: Symbol, kind: SpecFuncKind) -> Expr {
ExprKind::GlobalFunc(func, kind).intern()
}
pub fn eq(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::BinaryOp(BinOp::Eq, e1.into(), e2.into()).intern()
}
pub fn unary_op(op: UnOp, e: impl Into<Expr>) -> Expr {
ExprKind::UnaryOp(op, e.into()).intern()
}
pub fn ne(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::BinaryOp(BinOp::Ne, e1.into(), e2.into()).intern()
}
pub fn ge(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::BinaryOp(BinOp::Ge(Sort::Int), e1.into(), e2.into()).intern()
}
pub fn gt(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::BinaryOp(BinOp::Gt(Sort::Int), e1.into(), e2.into()).intern()
}
pub fn lt(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::BinaryOp(BinOp::Lt(Sort::Int), e1.into(), e2.into()).intern()
}
pub fn le(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::BinaryOp(BinOp::Le(Sort::Int), e1.into(), e2.into()).intern()
}
pub fn implies(e1: impl Into<Expr>, e2: impl Into<Expr>) -> Expr {
ExprKind::BinaryOp(BinOp::Imp, e1.into(), e2.into()).intern()
}
pub fn field_proj(e: impl Into<Expr>, proj: FieldProj) -> Expr {
ExprKind::FieldProj(e.into(), proj).intern()
}
pub fn field_projs(e: impl Into<Expr>, projs: &[FieldProj]) -> Expr {
projs.iter().copied().fold(e.into(), Expr::field_proj)
}
pub fn path_proj(base: Expr, field: FieldIdx) -> Expr {
ExprKind::PathProj(base, field).intern()
}
pub fn not(&self) -> Expr {
ExprKind::UnaryOp(UnOp::Not, self.clone()).intern()
}
pub fn neg(&self) -> Expr {
ExprKind::UnaryOp(UnOp::Neg, self.clone()).intern()
}
pub fn kind(&self) -> &ExprKind {
&self.kind
}
pub fn is_atom(&self) -> bool {
!matches!(self.kind(), ExprKind::Abs(..) | ExprKind::BinaryOp(..) | ExprKind::ForAll(..))
}
pub fn is_trivially_true(&self) -> bool {
self.is_true()
|| matches!(self.kind(), ExprKind::BinaryOp(BinOp::Eq | BinOp::Iff | BinOp::Imp, e1, e2) if e1 == e2)
}
pub fn is_trivially_false(&self) -> bool {
self.is_false()
}
fn is_true(&self) -> bool {
matches!(self.kind(), ExprKind::Constant(Constant::Bool(true)))
}
fn is_false(&self) -> bool {
matches!(self.kind(), ExprKind::Constant(Constant::Bool(false)))
}
pub fn from_const(tcx: TyCtxt, c: &Const) -> Expr {
match &c.kind {
ConstKind::Param(param_const) => Expr::const_generic(*param_const),
ConstKind::Value(ty, ValTree::Leaf(scalar)) => {
Expr::constant(Constant::from_scalar_int(tcx, *scalar, ty).unwrap())
}
ConstKind::Value(_ty, ValTree::Branch(_)) => {
bug!("todo: ValTree::Branch {c:?}")
}
ConstKind::Unevaluated(_) => bug!("unexpected `ConstKind::Unevaluated`"),
ConstKind::Infer(_) => bug!("unexpected `ConstKind::Infer`"),
}
}
pub fn is_binary_op(&self) -> bool {
matches!(self.kind(), ExprKind::BinaryOp(..))
}
fn const_op(op: &BinOp, c1: &Constant, c2: &Constant) -> Option<Constant> {
match op {
BinOp::Iff => c1.iff(c2),
BinOp::Imp => c1.imp(c2),
BinOp::Or => c1.or(c2),
BinOp::And => c1.and(c2),
BinOp::Gt(Sort::Int) => c1.gt(c2),
BinOp::Ge(Sort::Int) => c1.ge(c2),
BinOp::Lt(Sort::Int) => c2.gt(c1),
BinOp::Le(Sort::Int) => c2.ge(c1),
BinOp::Eq => Some(c1.eq(c2)),
BinOp::Ne => Some(c1.ne(c2)),
_ => None,
}
}
pub fn simplify(&self) -> Expr {
struct Simplify;
impl TypeFolder for Simplify {
fn fold_expr(&mut self, expr: &Expr) -> Expr {
let span = expr.span();
match expr.kind() {
ExprKind::BinaryOp(op, e1, e2) => {
let e1 = e1.fold_with(self);
let e2 = e2.fold_with(self);
match (op, e1.kind(), e2.kind()) {
(BinOp::And, ExprKind::Constant(Constant::Bool(false)), _) => {
Expr::constant(Constant::Bool(false)).at_opt(e1.span())
}
(BinOp::And, _, ExprKind::Constant(Constant::Bool(false))) => {
Expr::constant(Constant::Bool(false)).at_opt(e2.span())
}
(BinOp::And, ExprKind::Constant(Constant::Bool(true)), _) => e2,
(BinOp::And, _, ExprKind::Constant(Constant::Bool(true))) => e1,
(op, ExprKind::Constant(c1), ExprKind::Constant(c2)) => {
if let Some(c) = Expr::const_op(op, c1, c2) {
Expr::constant(c).at_opt(span.or(e2.span()))
} else {
Expr::binary_op(op.clone(), e1, e2).at_opt(span)
}
}
_ => Expr::binary_op(op.clone(), e1, e2).at_opt(span),
}
}
ExprKind::UnaryOp(UnOp::Not, e) => {
let e = e.fold_with(self);
match e.kind() {
ExprKind::Constant(Constant::Bool(b)) => {
Expr::constant(Constant::Bool(!b))
}
ExprKind::UnaryOp(UnOp::Not, e) => e.clone(),
ExprKind::BinaryOp(BinOp::Eq, e1, e2) => {
Expr::binary_op(BinOp::Ne, e1, e2).at_opt(span)
}
_ => Expr::unary_op(UnOp::Not, e).at_opt(span),
}
}
ExprKind::IfThenElse(p, e1, e2) => {
let p = p.fold_with(self);
if p.is_trivially_true() {
e1.fold_with(self).at_opt(span)
} else if p.is_trivially_false() {
e2.fold_with(self).at_opt(span)
} else {
Expr::ite(p, e1.fold_with(self), e2.fold_with(self)).at_opt(span)
}
}
_ => expr.super_fold_with(self),
}
}
}
self.fold_with(&mut Simplify)
}
pub fn to_loc(&self) -> Option<Loc> {
match self.kind() {
ExprKind::Local(local) => Some(Loc::Local(*local)),
ExprKind::Var(var) => Some(Loc::Var(*var)),
_ => None,
}
}
pub fn to_path(&self) -> Option<Path> {
let mut expr = self;
let mut proj = vec![];
while let ExprKind::PathProj(e, field) = expr.kind() {
proj.push(*field);
expr = e;
}
proj.reverse();
Some(Path::new(expr.to_loc()?, proj))
}
pub fn is_abs(&self) -> bool {
matches!(self.kind(), ExprKind::Abs(..))
}
pub fn is_unit(&self) -> bool {
matches!(self.kind(), ExprKind::Aggregate(_, flds) if flds.is_empty())
}
pub fn eta_expand_abs(&self, inputs: &BoundVariableKinds, output: Sort) -> Lambda {
let args = (0..inputs.len())
.map(|idx| Expr::bvar(INNERMOST, BoundVar::from_usize(idx), BoundReftKind::Annon))
.collect();
let body = Expr::app(self, args);
Lambda::bind_with_vars(body, inputs.clone(), output)
}
pub fn fold_sort(sort: &Sort, mut f: impl FnMut(&Sort) -> Expr) -> Expr {
fn go(sort: &Sort, f: &mut impl FnMut(&Sort) -> Expr) -> Expr {
match sort {
Sort::Tuple(sorts) => Expr::tuple(sorts.iter().map(|sort| go(sort, f)).collect()),
Sort::App(SortCtor::Adt(adt_sort_def), args) => {
let flds = adt_sort_def.field_sorts(args);
Expr::adt(adt_sort_def.did(), flds.iter().map(|sort| go(sort, f)).collect())
}
_ => f(sort),
}
}
go(sort, &mut f)
}
pub fn proj_and_reduce(&self, proj: FieldProj) -> Expr {
match self.kind() {
ExprKind::Aggregate(_, flds) => flds[proj.field_idx() as usize].clone(),
_ => Expr::field_proj(self.clone(), proj),
}
}
pub fn flatten_conjs(&self) -> Vec<&Expr> {
fn go<'a>(e: &'a Expr, vec: &mut Vec<&'a Expr>) {
if let ExprKind::BinaryOp(BinOp::And, e1, e2) = e.kind() {
go(e1, vec);
go(e2, vec);
} else {
vec.push(e);
}
}
let mut vec = vec![];
go(self, &mut vec);
vec
}
}
impl KVar {
    /// Creates a kvar application. The first `self_args` entries of `args` are the
    /// kvar's own arguments, and the rest form its scope.
    pub fn new(kvid: KVid, self_args: usize, args: Vec<Expr>) -> Self {
        let args = List::from_vec(args);
        Self { kvid, self_args, args }
    }

    /// The kvar's own arguments (the first `self_args` entries of `args`).
    fn self_args(&self) -> &[Expr] {
        let (own, _scope) = self.args.split_at(self.self_args);
        own
    }

    /// The scope arguments (every entry of `args` after the first `self_args`).
    fn scope(&self) -> &[Expr] {
        let (_own, scope) = self.args.split_at(self.self_args);
        scope
    }
}
impl Var {
pub fn to_expr(&self) -> Expr {
Expr::var(*self)
}
}
impl Path {
    /// Creates a path from a base location and a sequence of field projections.
    pub fn new(loc: Loc, projection: impl Into<List<FieldIdx>>) -> Path {
        let projection = projection.into();
        Path { loc, projection }
    }

    /// The field projections applied to `loc`, in order.
    pub fn projection(&self) -> &[FieldIdx] {
        &self.projection[..]
    }

    /// Converts the path to an expression made of nested `PathProj`s over `loc`.
    pub fn to_expr(&self) -> Expr {
        let mut expr = self.loc.to_expr();
        for field in self.projection.iter() {
            expr = Expr::path_proj(expr, *field);
        }
        expr
    }

    /// The base location, provided the path has no projections.
    pub fn to_loc(&self) -> Option<Loc> {
        self.projection.is_empty().then_some(self.loc)
    }
}
impl Loc {
    /// Converts the location to a local or variable expression.
    pub fn to_expr(&self) -> Expr {
        match *self {
            Self::Var(var) => Expr::var(var),
            Self::Local(local) => Expr::local(local),
        }
    }
}
/// Implements the `std::ops` trait `$op` for both `Expr` and `&Expr`, building a
/// `BinOp::$op` node. The right-hand side may be anything convertible into an `Expr`.
macro_rules! impl_ops {
    // Internal arm: one operator impl for one receiver type.
    (@impl $op:ident, $method:ident, $self_ty:ty) => {
        impl<Rhs: Into<Expr>> std::ops::$op<Rhs> for $self_ty {
            type Output = Expr;

            fn $method(self, rhs: Rhs) -> Self::Output {
                Expr::binary_op(BinOp::$op, self, rhs)
            }
        }
    };
    ($($op:ident: $method:ident),*) => {
        $(
            impl_ops!(@impl $op, $method, Expr);
            impl_ops!(@impl $op, $method, &Expr);
        )*
    };
}
impl_ops!(Add: add, Sub: sub, Mul: mul, Div: div);
impl From<i32> for Expr {
fn from(value: i32) -> Self {
Expr::constant(Constant::from(value))
}
}
impl From<&Expr> for Expr {
fn from(e: &Expr) -> Self {
e.clone()
}
}
impl From<Path> for Expr {
fn from(path: Path) -> Self {
path.to_expr()
}
}
impl From<Name> for Expr {
fn from(name: Name) -> Self {
Expr::fvar(name)
}
}
impl From<Var> for Expr {
fn from(var: Var) -> Self {
Expr::var(var)
}
}
impl From<Loc> for Path {
fn from(loc: Loc) -> Self {
Path::new(loc, vec![])
}
}
impl From<Name> for Loc {
fn from(name: Name) -> Self {
Loc::Var(Var::Free(name))
}
}
impl From<Local> for Loc {
fn from(local: Local) -> Self {
Loc::Local(local)
}
}
/// A real-number literal holding an integral value; it is printed with a `.0` suffix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Encodable, Decodable)]
pub struct Real(pub i128);
impl liquid_fixpoint::FixpointFmt for Real {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.0 < 0 {
write!(f, "(- {}.0)", self.0.unsigned_abs())
} else {
write!(f, "{}.0", self.0)
}
}
}
/// Literal constants appearing in refinement expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Encodable, Decodable)]
pub enum Constant {
/// An arbitrary-precision integer.
Int(BigInt),
/// A real literal (see [`Real`]).
Real(Real),
Bool(bool),
/// A string literal, stored as an interned symbol.
Str(Symbol),
Char(char),
}
impl Constant {
    pub const ZERO: Constant = Constant::Int(BigInt::ZERO);
    pub const ONE: Constant = Constant::Int(BigInt::ONE);
    pub const TRUE: Constant = Constant::Bool(true);

    fn to_bool(self) -> Option<bool> {
        if let Constant::Bool(b) = self { Some(b) } else { None }
    }

    fn to_int(self) -> Option<BigInt> {
        if let Constant::Int(n) = self { Some(n) } else { None }
    }

    /// Applies the boolean operator `op` when both operands are boolean constants.
    fn lift_bool_op(&self, other: &Constant, op: impl FnOnce(bool, bool) -> bool) -> Option<Constant> {
        let lhs = self.to_bool()?;
        let rhs = other.to_bool()?;
        Some(Constant::Bool(op(lhs, rhs)))
    }

    /// Applies the integer comparison `cmp` when both operands are integer constants.
    fn lift_int_cmp(
        &self,
        other: &Constant,
        cmp: impl FnOnce(BigInt, BigInt) -> bool,
    ) -> Option<Constant> {
        let lhs = self.to_int()?;
        let rhs = other.to_int()?;
        Some(Constant::Bool(cmp(lhs, rhs)))
    }

    /// `self ⇔ other`, if both are booleans.
    pub fn iff(&self, other: &Constant) -> Option<Constant> {
        self.lift_bool_op(other, |a, b| a == b)
    }

    /// `self ⇒ other`, if both are booleans.
    pub fn imp(&self, other: &Constant) -> Option<Constant> {
        self.lift_bool_op(other, |a, b| !a || b)
    }

    /// `self ∨ other`, if both are booleans.
    pub fn or(&self, other: &Constant) -> Option<Constant> {
        self.lift_bool_op(other, |a, b| a || b)
    }

    /// `self ∧ other`, if both are booleans.
    pub fn and(&self, other: &Constant) -> Option<Constant> {
        self.lift_bool_op(other, |a, b| a && b)
    }

    /// Structural equality of the two constants, as a boolean constant.
    pub fn eq(&self, other: &Constant) -> Constant {
        Constant::Bool(self == other)
    }

    /// Structural inequality of the two constants, as a boolean constant.
    pub fn ne(&self, other: &Constant) -> Constant {
        Constant::Bool(self != other)
    }

    /// `self > other`, if both are integers.
    pub fn gt(&self, other: &Constant) -> Option<Constant> {
        self.lift_int_cmp(other, |a, b| a > b)
    }

    /// `self ≥ other`, if both are integers.
    pub fn ge(&self, other: &Constant) -> Option<Constant> {
        self.lift_int_cmp(other, |a, b| a >= b)
    }

    /// Converts a rustc scalar of type `t` into a constant.
    ///
    /// Returns `None` when the scalar's bits cannot be extracted or do not form a valid
    /// `char`. Reports a bug for types other than integers, `bool`, and `char`.
    pub fn from_scalar_int<'tcx, T>(tcx: TyCtxt<'tcx>, scalar: ScalarInt, t: &T) -> Option<Self>
    where
        T: ToRustc<'tcx, T = rustc_middle::ty::Ty<'tcx>>,
    {
        use rustc_middle::ty::TyKind;

        let ty = t.to_rustc(tcx);
        let constant = match ty.kind() {
            TyKind::Int(int_ty) => Constant::from(scalar_to_int(tcx, scalar, *int_ty)),
            TyKind::Uint(uint_ty) => Constant::from(scalar_to_uint(tcx, scalar, *uint_ty)),
            TyKind::Bool => Constant::Bool(scalar_to_bits(tcx, scalar, ty)? != 0),
            TyKind::Char => {
                let code = scalar_to_bits(tcx, scalar, ty)?;
                Constant::Char(char::from_u32(code as u32)?)
            }
            _ => bug!(),
        };
        Some(constant)
    }

    /// Smallest signed integer representable in `bit_width` bits.
    pub fn int_min(bit_width: u32) -> Constant {
        Self::Int(BigInt::int_min(bit_width))
    }

    /// Largest signed integer representable in `bit_width` bits.
    pub fn int_max(bit_width: u32) -> Constant {
        Self::Int(BigInt::int_max(bit_width))
    }

    /// Largest unsigned integer representable in `bit_width` bits.
    pub fn uint_max(bit_width: u32) -> Constant {
        Self::Int(BigInt::uint_max(bit_width))
    }
}
impl From<i32> for Constant {
fn from(c: i32) -> Self {
Constant::Int(c.into())
}
}
impl From<usize> for Constant {
fn from(u: usize) -> Self {
Constant::Int(u.into())
}
}
impl From<u128> for Constant {
fn from(c: u128) -> Self {
Constant::Int(c.into())
}
}
impl From<i128> for Constant {
fn from(c: i128) -> Self {
Constant::Int(c.into())
}
}
impl From<bool> for Constant {
fn from(b: bool) -> Self {
Constant::Bool(b)
}
}
impl From<Symbol> for Constant {
fn from(s: Symbol) -> Self {
Constant::Str(s)
}
}
impl From<char> for Constant {
fn from(c: char) -> Self {
Constant::Char(c)
}
}
// Make `ExprKind` internable; it backs the `Interned<ExprKind>` stored in `Expr`.
impl_internable!(ExprKind);
// Make slices of `Expr` and `KVar` internable (presumably what `List<Expr>` / `List<KVar>`
// require; confirm against `flux_arc_interner`).
impl_slice_internable!(Expr, KVar);
mod pretty {
    use flux_rustc_bridge::def_id_to_string;

    use super::*;
    use crate::pretty::*;

    /// Binding strength of binary operators, from loosest (`Iff`) to tightest
    /// (`MulDiv`); the derived `Ord` follows declaration order.
    #[derive(PartialEq, Eq, PartialOrd, Ord)]
    enum Precedence {
        Iff,
        Imp,
        Or,
        And,
        Cmp,
        AddSub,
        MulDiv,
    }

    impl BinOp {
        /// Precedence level used when deciding where to print parentheses.
        fn precedence(&self) -> Precedence {
            match self {
                BinOp::Iff => Precedence::Iff,
                BinOp::Imp => Precedence::Imp,
                BinOp::Or => Precedence::Or,
                BinOp::And => Precedence::And,
                BinOp::Eq
                | BinOp::Ne
                | BinOp::Gt(_)
                | BinOp::Lt(_)
                | BinOp::Ge(_)
                | BinOp::Le(_) => Precedence::Cmp,
                BinOp::Add | BinOp::Sub => Precedence::AddSub,
                BinOp::Mul | BinOp::Div | BinOp::Mod => Precedence::MulDiv,
            }
        }

        /// Whether `a op (b op c)` always denotes the same value as `(a op b) op c`, so a
        /// right operand using the same operator may be printed without parentheses.
        fn is_fully_associative(&self) -> bool {
            matches!(self, BinOp::Iff | BinOp::Or | BinOp::And | BinOp::Add | BinOp::Mul)
        }
    }

    impl Precedence {
        /// Whether operators at this level may be chained (`a op b op c`) at all;
        /// implications and comparisons always get explicit parentheses.
        pub fn is_associative(&self) -> bool {
            !matches!(self, Precedence::Imp | Precedence::Cmp)
        }
    }

    impl Pretty for Expr {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);

            /// Whether `child`, printed as the left (`is_right == false`) or right operand
            /// of `op`, needs surrounding parentheses.
            ///
            /// Same-precedence chains are read left-associatively (`a - b - c` means
            /// `(a - b) - c`). A right operand at the same precedence therefore keeps its
            /// parentheses unless both levels use the same fully associative operator.
            /// Without this, `a - (b - c)` or `a * (b / c)` would print with a different
            /// meaning.
            fn should_parenthesize(op: &BinOp, child: &Expr, is_right: bool) -> bool {
                let ExprKind::BinaryOp(child_op, ..) = child.kind() else {
                    return false;
                };
                let (op_prec, child_prec) = (op.precedence(), child_op.precedence());
                if child_prec != op_prec {
                    return child_prec < op_prec;
                }
                if !op_prec.is_associative() {
                    return true;
                }
                is_right && !(child_op == op && op.is_fully_associative())
            }

            let e = if cx.simplify_exprs { self.simplify() } else { self.clone() };
            match e.kind() {
                ExprKind::Var(var) => w!("{:?}", var),
                ExprKind::Local(local) => w!("{:?}", ^local),
                ExprKind::ConstDefId(did) => w!("{}", ^def_id_to_string(*did)),
                ExprKind::Constant(c) => w!("{:?}", c),
                ExprKind::BinaryOp(op, e1, e2) => {
                    if should_parenthesize(op, e1, false) {
                        w!("({:?})", e1)?;
                    } else {
                        w!("{:?}", e1)?;
                    }
                    if matches!(op, BinOp::Div) {
                        w!("{:?}", op)?;
                    } else {
                        w!(" {:?} ", op)?;
                    }
                    if should_parenthesize(op, e2, true) {
                        w!("({:?})", e2)?;
                    } else {
                        w!("{:?}", e2)?;
                    }
                    Ok(())
                }
                ExprKind::UnaryOp(op, e) => {
                    if e.is_atom() {
                        w!("{:?}{:?}", op, e)
                    } else {
                        w!("{:?}({:?})", op, e)
                    }
                }
                ExprKind::FieldProj(e, proj) => {
                    if e.is_atom() {
                        w!("{:?}.{:?}", e, ^proj.field_idx())
                    } else {
                        w!("({:?}).{:?}", e, ^proj.field_idx())
                    }
                }
                ExprKind::Aggregate(AggregateKind::Tuple(_), flds) => {
                    if let [e] = &flds[..] {
                        w!("({:?},)", e)
                    } else {
                        w!("({:?})", join!(", ", flds))
                    }
                }
                ExprKind::Aggregate(AggregateKind::Adt(def_id), flds) => {
                    w!("{:?} {{ {:?} }}", def_id, join!(", ", flds))
                }
                ExprKind::PathProj(e, field) => {
                    if e.is_atom() {
                        w!("{:?}.{:?}", e, field)
                    } else {
                        w!("({:?}).{:?}", e, field)
                    }
                }
                ExprKind::App(func, args) => {
                    w!("({:?})({})",
                        func,
                        ^args
                            .iter()
                            .format_with(", ", |arg, f| f(&format_args_cx!("{:?}", arg)))
                    )
                }
                ExprKind::IfThenElse(p, e1, e2) => {
                    w!("if {:?} {{ {:?} }} else {{ {:?} }}", p, e1, e2)
                }
                ExprKind::Hole(_) => {
                    w!("*")
                }
                ExprKind::KVar(kvar) => {
                    w!("{:?}", kvar)
                }
                ExprKind::Alias(alias, args) => {
                    w!("{:?}({:?})", alias, join!(", ", args))
                }
                ExprKind::Abs(lam) => {
                    w!("{:?}", lam)
                }
                ExprKind::GlobalFunc(func, _) => w!("{}", ^func),
                ExprKind::ForAll(expr) => {
                    let vars = expr.vars();
                    cx.with_bound_vars(vars, || {
                        if !vars.is_empty() {
                            cx.fmt_bound_vars(false, "∀", vars, ". ", f)?;
                        }
                        w!("{:?}", expr.as_ref().skip_binder())
                    })
                }
            }
        }
    }

    impl Pretty for Constant {
        fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            match self {
                Constant::Int(i) => w!("{i}"),
                Constant::Real(r) => w!("{}.0", ^r.0),
                Constant::Bool(b) => w!("{b}"),
                Constant::Str(sym) => w!("\"{sym}\""),
                Constant::Char(c) => write!(f, "\'{c}\'"),
            }
        }
    }

    impl Pretty for AliasReft {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            w!("<({:?}) as {:?}", &self.args[0], self.trait_id)?;
            let args = &self.args[1..];
            if !args.is_empty() {
                w!("<{:?}>", join!(", ", args))?;
            }
            w!(">::{}", ^self.name)
        }
    }

    impl Pretty for Lambda {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            let vars = self.body.vars();
            cx.with_bound_vars(vars, || {
                cx.fmt_bound_vars(false, "λ", vars, ". ", f)?;
                w!("{:?}", self.body.as_ref().skip_binder())
            })
        }
    }

    impl Pretty for Var {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            match self {
                Var::Bound(debruijn, var) => cx.fmt_bound_reft(*debruijn, *var, f),
                Var::EarlyParam(var) => w!("{}", ^var.name),
                Var::Free(name) => w!("{:?}", ^name),
                Var::EVar(evar) => w!("{:?}", evar),
                Var::ConstGeneric(param) => w!("{}", ^param.name),
            }
        }
    }

    impl Pretty for KVar {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            w!("{:?}", ^self.kvid)?;
            match cx.kvar_args {
                KVarArgs::All => {
                    w!("({:?})[{:?}]", join!(", ", self.self_args()), join!(", ", self.scope()))?;
                }
                KVarArgs::SelfOnly => w!("({:?})", join!(", ", self.self_args()))?,
                KVarArgs::Hide => {}
            }
            Ok(())
        }
    }

    impl Pretty for Path {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            w!("{:?}", &self.loc)?;
            for field in &self.projection {
                w!(".{}", ^u32::from(*field))?;
            }
            Ok(())
        }
    }

    impl Pretty for Loc {
        fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            match self {
                Loc::Local(local) => w!("{:?}", ^local),
                Loc::Var(var) => w!("{:?}", var),
            }
        }
    }

    impl Pretty for BinOp {
        fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            match self {
                BinOp::Iff => w!("⇔"),
                BinOp::Imp => w!("⇒"),
                BinOp::Or => w!("∨"),
                BinOp::And => w!("∧"),
                BinOp::Eq => w!("="),
                BinOp::Ne => w!("≠"),
                BinOp::Gt(_) => w!(">"),
                BinOp::Ge(_) => w!("≥"),
                BinOp::Lt(_) => w!("<"),
                BinOp::Le(_) => w!("≤"),
                BinOp::Add => w!("+"),
                BinOp::Sub => w!("-"),
                BinOp::Mul => w!("*"),
                BinOp::Div => w!("/"),
                BinOp::Mod => w!("mod"),
            }
        }
    }

    impl Pretty for UnOp {
        fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            define_scoped!(cx, f);
            match self {
                UnOp::Not => w!("¬"),
                UnOp::Neg => w!("-"),
            }
        }
    }

    impl_debug_with_default_cx!(Expr, Loc, Path, Var, KVar, Lambda, AliasReft);
}