Skip to content

[SandboxIR] Implement ConstantFP #106648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ namespace sandboxir {

class BasicBlock;
class ConstantInt;
class ConstantFP;
class Context;
class Function;
class Instruction;
Expand Down Expand Up @@ -597,6 +598,94 @@ class ConstantInt : public Constant {
#endif
};

// TODO: This should inherit from ConstantData.
class ConstantFP final : public Constant {
ConstantFP(llvm::ConstantFP *C, Context &Ctx)
: Constant(ClassID::ConstantFP, C, Ctx) {}
friend class Context; // For constructor.

public:
/// This returns a ConstantFP, or a vector containing a splat of a ConstantFP,
/// for the specified value in the specified type. This should only be used
/// for simple constant values like 2.0/1.0 etc, that are known-valid both as
/// host double and as the target format.
static Constant *get(Type *Ty, double V);

/// If Ty is a vector type, return a Constant with a splat of the given
/// value. Otherwise return a ConstantFP for the given value.
static Constant *get(Type *Ty, const APFloat &V);

static Constant *get(Type *Ty, StringRef Str);

static ConstantFP *get(const APFloat &V, Context &Ctx);

static Constant *getNaN(Type *Ty, bool Negative = false,
uint64_t Payload = 0);
static Constant *getQNaN(Type *Ty, bool Negative = false,
APInt *Payload = nullptr);
static Constant *getSNaN(Type *Ty, bool Negative = false,
APInt *Payload = nullptr);
static Constant *getZero(Type *Ty, bool Negative = false);

static Constant *getNegativeZero(Type *Ty);
static Constant *getInfinity(Type *Ty, bool Negative = false);

/// Return true if Ty is big enough to represent V.
static bool isValueValidForType(Type *Ty, const APFloat &V);

inline const APFloat &getValueAPF() const {
return cast<llvm::ConstantFP>(Val)->getValueAPF();
}
inline const APFloat &getValue() const {
return cast<llvm::ConstantFP>(Val)->getValue();
}

/// Return true if the value is positive or negative zero.
bool isZero() const { return cast<llvm::ConstantFP>(Val)->isZero(); }

/// Return true if the sign bit is set.
bool isNegative() const { return cast<llvm::ConstantFP>(Val)->isNegative(); }

/// Return true if the value is infinity
bool isInfinity() const { return cast<llvm::ConstantFP>(Val)->isInfinity(); }

/// Return true if the value is a NaN.
bool isNaN() const { return cast<llvm::ConstantFP>(Val)->isNaN(); }

/// We don't rely on operator== working on double values, as it returns true
/// for things that are clearly not equal, like -0.0 and 0.0.
/// As such, this method can be used to do an exact bit-for-bit comparison of
/// two floating point values. The version with a double operand is retained
/// because it's so convenient to write isExactlyValue(2.0), but please use
/// it only for simple constants.
bool isExactlyValue(const APFloat &V) const {
return cast<llvm::ConstantFP>(Val)->isExactlyValue(V);
}

bool isExactlyValue(double V) const {
return cast<llvm::ConstantFP>(Val)->isExactlyValue(V);
}

/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
return From->getSubclassID() == ClassID::ConstantFP;
}

// TODO: Better name: getOperandNo(const Use&). Should be private.
unsigned getUseOperandNo(const Use &Use) const final {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to match this with LLVM IR. Why not just call this "getOperandNo" with the Use as parameter?

Copy link
Contributor Author

@vporpo vporpo Aug 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a SandboxIR-specific function, it is not wrapping any LLVM IR one. We should probably make this private at some point. But yeah getOperandNo(const Use &) is a much better name.

llvm_unreachable("ConstantFP has no operands!");
}
#ifndef NDEBUG
void verify() const override {
assert(isa<llvm::ConstantFP>(Val) && "Expected a ConstantFP!");
}
void dumpOS(raw_ostream &OS) const override {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}
#endif
};

/// Iterator for `Instruction`s in a `BasicBlock.
/// \Returns an sandboxir::Instruction & when derereferenced.
class BBIterator {
Expand Down Expand Up @@ -3156,7 +3245,10 @@ class Context {
Constant *getOrCreateConstant(llvm::Constant *LLVMC) {
return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
}
friend class ConstantInt; // For getOrCreateConstant().
// Friends for getOrCreateConstant().
#define DEF_CONST(ID, CLASS) friend class CLASS;
#include "llvm/SandboxIR/SandboxIRValues.def"

/// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
/// also create all contents of the block.
BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ DEF_USER(User, User)
DEF_VALUE(Block, BasicBlock)
DEF_CONST(Constant, Constant)
DEF_CONST(ConstantInt, ConstantInt)
DEF_CONST(ConstantFP, ConstantFP)

#ifndef DEF_INSTR
#define DEF_INSTR(ID, OPCODE, CLASS)
Expand Down
3 changes: 2 additions & 1 deletion llvm/include/llvm/SandboxIR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PointerType;
class VectorType;
class FunctionType;
#define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
#define DEF_CONST(ID, CLASS) class CLASS;
#include "llvm/SandboxIR/SandboxIRValues.def"

/// Just like llvm::Type these are immutable, unique, never get freed and can
Expand All @@ -42,7 +43,7 @@ class Type {
friend class ConstantInt; // For LLVMTy.
// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
// TODO: Friend DEF_CONST()
#define DEF_CONST(ID, CLASS) friend class CLASS;
#include "llvm/SandboxIR/SandboxIRValues.def"
Context &Ctx;

Expand Down
52 changes: 52 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2248,6 +2248,54 @@ ConstantInt *ConstantInt::get(Type *Ty, uint64_t V, bool IsSigned) {
return cast<ConstantInt>(Ty->getContext().getOrCreateConstant(LLVMC));
}

Constant *ConstantFP::get(Type *Ty, double V) {
auto *LLVMC = llvm::ConstantFP::get(Ty->LLVMTy, V);
return Ty->getContext().getOrCreateConstant(LLVMC);
}

Constant *ConstantFP::get(Type *Ty, const APFloat &V) {
auto *LLVMC = llvm::ConstantFP::get(Ty->LLVMTy, V);
return Ty->getContext().getOrCreateConstant(LLVMC);
}

Constant *ConstantFP::get(Type *Ty, StringRef Str) {
auto *LLVMC = llvm::ConstantFP::get(Ty->LLVMTy, Str);
return Ty->getContext().getOrCreateConstant(LLVMC);
}

ConstantFP *ConstantFP::get(const APFloat &V, Context &Ctx) {
auto *LLVMC = llvm::ConstantFP::get(Ctx.LLVMCtx, V);
return cast<ConstantFP>(Ctx.getOrCreateConstant(LLVMC));
}

Constant *ConstantFP::getNaN(Type *Ty, bool Negative, uint64_t Payload) {
auto *LLVMC = llvm::ConstantFP::getNaN(Ty->LLVMTy, Negative, Payload);
return cast<Constant>(Ty->getContext().getOrCreateConstant(LLVMC));
}
Constant *ConstantFP::getQNaN(Type *Ty, bool Negative, APInt *Payload) {
auto *LLVMC = llvm::ConstantFP::getQNaN(Ty->LLVMTy, Negative, Payload);
return cast<Constant>(Ty->getContext().getOrCreateConstant(LLVMC));
}
Constant *ConstantFP::getSNaN(Type *Ty, bool Negative, APInt *Payload) {
auto *LLVMC = llvm::ConstantFP::getSNaN(Ty->LLVMTy, Negative, Payload);
return cast<Constant>(Ty->getContext().getOrCreateConstant(LLVMC));
}
Constant *ConstantFP::getZero(Type *Ty, bool Negative) {
auto *LLVMC = llvm::ConstantFP::getZero(Ty->LLVMTy, Negative);
return cast<Constant>(Ty->getContext().getOrCreateConstant(LLVMC));
}
Constant *ConstantFP::getNegativeZero(Type *Ty) {
auto *LLVMC = llvm::ConstantFP::getNegativeZero(Ty->LLVMTy);
return cast<Constant>(Ty->getContext().getOrCreateConstant(LLVMC));
}
Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
auto *LLVMC = llvm::ConstantFP::getInfinity(Ty->LLVMTy, Negative);
return cast<Constant>(Ty->getContext().getOrCreateConstant(LLVMC));
}
bool ConstantFP::isValueValidForType(Type *Ty, const APFloat &V) {
return llvm::ConstantFP::isValueValidForType(Ty->LLVMTy, V);
}

FunctionType *Function::getFunctionType() const {
return cast<FunctionType>(
Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
Expand Down Expand Up @@ -2339,6 +2387,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<ConstantInt>(new ConstantInt(CI, *this));
return It->second.get();
}
if (auto *CF = dyn_cast<llvm::ConstantFP>(C)) {
It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
return It->second.get();
}
if (auto *F = dyn_cast<llvm::Function>(LLVMV))
It->second = std::unique_ptr<Function>(new Function(F, *this));
else
Expand Down
155 changes: 155 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,161 @@ define void @foo(i32 %v0) {
EXPECT_NE(FortyThree, FortyTwo);
}

TEST_F(SandboxIRTest, ConstantFP) {
parseIR(C, R"IR(
define void @foo(float %v0, double %v1) {
%fadd0 = fadd float %v0, 42.0
%fadd1 = fadd double %v1, 43.0
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
auto It = BB.begin();
auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
auto *FortyTwo = cast<sandboxir::ConstantFP>(FAdd0->getOperand(1));
[[maybe_unused]] auto *FortyThree =
cast<sandboxir::ConstantFP>(FAdd1->getOperand(1));

auto *FloatTy = sandboxir::Type::getFloatTy(Ctx);
auto *DoubleTy = sandboxir::Type::getDoubleTy(Ctx);
auto *LLVMFloatTy = Type::getFloatTy(C);
auto *LLVMDoubleTy = Type::getDoubleTy(C);
// Check that creating an identical constant gives us the same object.
auto *NewFortyTwo = sandboxir::ConstantFP::get(FloatTy, 42.0);
EXPECT_EQ(NewFortyTwo, FortyTwo);
// Check get(Type, double).
auto *FortyFour =
cast<sandboxir::ConstantFP>(sandboxir::ConstantFP::get(FloatTy, 44.0));
auto *LLVMFortyFour =
cast<llvm::ConstantFP>(llvm::ConstantFP::get(LLVMFloatTy, 44.0));
EXPECT_NE(FortyFour, FortyTwo);
EXPECT_EQ(FortyFour, Ctx.getValue(LLVMFortyFour));
// Check get(Type, APFloat).
auto *FortyFive = cast<sandboxir::ConstantFP>(
sandboxir::ConstantFP::get(DoubleTy, APFloat(45.0)));
auto *LLVMFortyFive = cast<llvm::ConstantFP>(
llvm::ConstantFP::get(LLVMDoubleTy, APFloat(45.0)));
EXPECT_EQ(FortyFive, Ctx.getValue(LLVMFortyFive));
// Check get(Type, StringRef).
auto *FortySix = sandboxir::ConstantFP::get(FloatTy, "46.0");
EXPECT_EQ(FortySix, Ctx.getValue(llvm::ConstantFP::get(LLVMFloatTy, "46.0")));
// Check get(APFloat).
auto *FortySeven = sandboxir::ConstantFP::get(APFloat(47.0), Ctx);
EXPECT_EQ(FortySeven, Ctx.getValue(llvm::ConstantFP::get(C, APFloat(47.0))));
// Check getNaN().
{
auto *NaN = sandboxir::ConstantFP::getNaN(FloatTy);
EXPECT_EQ(NaN, Ctx.getValue(llvm::ConstantFP::getNaN(LLVMFloatTy)));
}
{
auto *NaN = sandboxir::ConstantFP::getNaN(FloatTy, /*Negative=*/true);
EXPECT_EQ(NaN, Ctx.getValue(llvm::ConstantFP::getNaN(LLVMFloatTy,
/*Negative=*/true)));
}
{
auto *NaN = sandboxir::ConstantFP::getNaN(FloatTy, /*Negative=*/true,
/*Payload=*/1);
EXPECT_EQ(NaN, Ctx.getValue(llvm::ConstantFP::getNaN(
LLVMFloatTy, /*Negative=*/true, /*Payload=*/1)));
}
// Check getQNaN().
{
auto *QNaN = sandboxir::ConstantFP::getQNaN(FloatTy);
EXPECT_EQ(QNaN, Ctx.getValue(llvm::ConstantFP::getQNaN(LLVMFloatTy)));
}
{
auto *QNaN = sandboxir::ConstantFP::getQNaN(FloatTy, /*Negative=*/true);
EXPECT_EQ(QNaN, Ctx.getValue(llvm::ConstantFP::getQNaN(LLVMFloatTy,
/*Negative=*/true)));
}
{
APInt Payload(1, 1);
auto *QNaN =
sandboxir::ConstantFP::getQNaN(FloatTy, /*Negative=*/true, &Payload);
EXPECT_EQ(QNaN, Ctx.getValue(llvm::ConstantFP::getQNaN(
LLVMFloatTy, /*Negative=*/true, &Payload)));
}
// Check getSNaN().
{
auto *SNaN = sandboxir::ConstantFP::getSNaN(FloatTy);
EXPECT_EQ(SNaN, Ctx.getValue(llvm::ConstantFP::getSNaN(LLVMFloatTy)));
}
{
auto *SNaN = sandboxir::ConstantFP::getSNaN(FloatTy, /*Negative=*/true);
EXPECT_EQ(SNaN, Ctx.getValue(llvm::ConstantFP::getSNaN(LLVMFloatTy,
/*Negative=*/true)));
}
{
APInt Payload(1, 1);
auto *SNaN =
sandboxir::ConstantFP::getSNaN(FloatTy, /*Negative=*/true, &Payload);
EXPECT_EQ(SNaN, Ctx.getValue(llvm::ConstantFP::getSNaN(
LLVMFloatTy, /*Negative=*/true, &Payload)));
}

// Check getZero().
{
auto *Zero = sandboxir::ConstantFP::getZero(FloatTy);
EXPECT_EQ(Zero, Ctx.getValue(llvm::ConstantFP::getZero(LLVMFloatTy)));
}
{
auto *Zero = sandboxir::ConstantFP::getZero(FloatTy, /*Negative=*/true);
EXPECT_EQ(Zero, Ctx.getValue(llvm::ConstantFP::getZero(LLVMFloatTy,
/*Negative=*/true)));
}

// Check getNegativeZero().
auto *NegZero = cast<sandboxir::ConstantFP>(
sandboxir::ConstantFP::getNegativeZero(FloatTy));
EXPECT_EQ(NegZero,
Ctx.getValue(llvm::ConstantFP::getNegativeZero(LLVMFloatTy)));

// Check getInfinity().
{
auto *Inf = sandboxir::ConstantFP::getInfinity(FloatTy);
EXPECT_EQ(Inf, Ctx.getValue(llvm::ConstantFP::getInfinity(LLVMFloatTy)));
}
{
auto *Inf = sandboxir::ConstantFP::getInfinity(FloatTy, /*Negative=*/true);
EXPECT_EQ(Inf, Ctx.getValue(llvm::ConstantFP::getInfinity(
LLVMFloatTy, /*Negative=*/true)));
}

// Check isValueValidForType().
APFloat V(1.1);
EXPECT_EQ(sandboxir::ConstantFP::isValueValidForType(FloatTy, V),
llvm::ConstantFP::isValueValidForType(LLVMFloatTy, V));
// Check getValueAPF().
EXPECT_EQ(FortyFour->getValueAPF(), LLVMFortyFour->getValueAPF());
// Check getValue().
EXPECT_EQ(FortyFour->getValue(), LLVMFortyFour->getValue());
// Check isZero().
EXPECT_EQ(FortyFour->isZero(), LLVMFortyFour->isZero());
EXPECT_TRUE(sandboxir::ConstantFP::getZero(FloatTy));
EXPECT_TRUE(sandboxir::ConstantFP::getZero(FloatTy, /*Negative=*/true));
// Check isNegative().
EXPECT_TRUE(cast<sandboxir::ConstantFP>(
sandboxir::ConstantFP::getZero(FloatTy, /*Negative=*/true))
->isNegative());
// Check isInfinity().
EXPECT_TRUE(
cast<sandboxir::ConstantFP>(sandboxir::ConstantFP::getInfinity(FloatTy))
->isInfinity());
// Check isNaN().
EXPECT_TRUE(
cast<sandboxir::ConstantFP>(sandboxir::ConstantFP::getNaN(FloatTy))
->isNaN());
// Check isExactlyValue(APFloat).
EXPECT_TRUE(NegZero->isExactlyValue(NegZero->getValueAPF()));
// Check isExactlyValue(double).
EXPECT_TRUE(NegZero->isExactlyValue(-0.0));
}

TEST_F(SandboxIRTest, Use) {
parseIR(C, R"IR(
define i32 @foo(i32 %v0, i32 %v1) {
Expand Down
Loading