Skip to content

[SandboxIR] Implement Instruction flags #103343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,98 @@ class Instruction : public sandboxir::User {
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From);

/// Determine whether the no signed wrap flag is set.
bool hasNoUnsignedWrap() const {
return cast<llvm::Instruction>(Val)->hasNoUnsignedWrap();
}
/// Set or clear the nuw flag on this instruction, which must be an operator
/// which supports this flag. See LangRef.html for the meaning of this flag.
void setHasNoUnsignedWrap(bool B = true);
/// Determine whether the no signed wrap flag is set.
bool hasNoSignedWrap() const {
return cast<llvm::Instruction>(Val)->hasNoSignedWrap();
}
/// Set or clear the nsw flag on this instruction, which must be an operator
/// which supports this flag. See LangRef.html for the meaning of this flag.
void setHasNoSignedWrap(bool B = true);
/// Determine whether all fast-math-flags are set.
bool isFast() const { return cast<llvm::Instruction>(Val)->isFast(); }
/// Set or clear all fast-math-flags on this instruction, which must be an
/// operator which supports this flag. See LangRef.html for the meaning of
/// this flag.
void setFast(bool B);
/// Determine whether the allow-reassociation flag is set.
bool hasAllowReassoc() const {
return cast<llvm::Instruction>(Val)->hasAllowReassoc();
}
/// Set or clear the reassociation flag on this instruction, which must be
/// an operator which supports this flag. See LangRef.html for the meaning of
/// this flag.
void setHasAllowReassoc(bool B);
/// Determine whether the exact flag is set.
bool isExact() const { return cast<llvm::Instruction>(Val)->isExact(); }
/// Set or clear the exact flag on this instruction, which must be an operator
/// which supports this flag. See LangRef.html for the meaning of this flag.
void setIsExact(bool B = true);
/// Determine whether the no-NaNs flag is set.
bool hasNoNaNs() const { return cast<llvm::Instruction>(Val)->hasNoNaNs(); }
/// Set or clear the no-nans flag on this instruction, which must be an
/// operator which supports this flag. See LangRef.html for the meaning of
/// this flag.
void setHasNoNaNs(bool B);
/// Determine whether the no-infs flag is set.
bool hasNoInfs() const { return cast<llvm::Instruction>(Val)->hasNoInfs(); }
/// Set or clear the no-infs flag on this instruction, which must be an
/// operator which supports this flag. See LangRef.html for the meaning of
/// this flag.
void setHasNoInfs(bool B);
/// Determine whether the no-signed-zeros flag is set.
bool hasNoSignedZeros() const {
return cast<llvm::Instruction>(Val)->hasNoSignedZeros();
}
/// Set or clear the no-signed-zeros flag on this instruction, which must be
/// an operator which supports this flag. See LangRef.html for the meaning of
/// this flag.
void setHasNoSignedZeros(bool B);
/// Determine whether the allow-reciprocal flag is set.
bool hasAllowReciprocal() const {
return cast<llvm::Instruction>(Val)->hasAllowReciprocal();
}
/// Set or clear the allow-reciprocal flag on this instruction, which must be
/// an operator which supports this flag. See LangRef.html for the meaning of
/// this flag.
void setHasAllowReciprocal(bool B);
/// Determine whether the allow-contract flag is set.
bool hasAllowContract() const {
return cast<llvm::Instruction>(Val)->hasAllowContract();
}
/// Set or clear the allow-contract flag on this instruction, which must be
/// an operator which supports this flag. See LangRef.html for the meaning of
/// this flag.
void setHasAllowContract(bool B);
/// Determine whether the approximate-math-functions flag is set.
bool hasApproxFunc() const {
return cast<llvm::Instruction>(Val)->hasApproxFunc();
}
/// Set or clear the approximate-math-functions flag on this instruction,
/// which must be an operator which supports this flag. See LangRef.html for
/// the meaning of this flag.
void setHasApproxFunc(bool B);
/// Convenience function for getting all the fast-math flags, which must be an
/// operator which supports these flags. See LangRef.html for the meaning of
/// these flags.
FastMathFlags getFastMathFlags() const {
return cast<llvm::Instruction>(Val)->getFastMathFlags();
}
/// Convenience function for setting multiple fast-math flags on this
/// instruction, which must be an operator which supports these flags. See
/// LangRef.html for the meaning of these flags.
void setFastMathFlags(FastMathFlags FMF);
/// Convenience function for transferring all fast-math flag values to this
/// instruction, which must be an operator which supports these flags. See
/// LangRef.html for the meaning of these flags.
void copyFastMathFlags(FastMathFlags FMF);

#ifndef NDEBUG
void dumpOS(raw_ostream &OS) const override;
#endif
Expand Down
14 changes: 7 additions & 7 deletions llvm/include/llvm/SandboxIR/Tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,13 @@ class RemoveFromParent : public IRChangeBase {
///
template <auto GetterFn, auto SetterFn>
class GenericSetter final : public IRChangeBase {
/// Helper for getting the class type from the getter
template <typename ClassT, typename RetT>
static ClassT getClassTypeFromGetter(RetT (ClassT::*Fn)() const);
template <typename ClassT, typename RetT>
static ClassT getClassTypeFromGetter(RetT (ClassT::*Fn)());

using InstrT = decltype(getClassTypeFromGetter(GetterFn));
/// Traits for getting the class type from GetterFn type.
template <typename> struct GetClassTypeFromGetter;
template <typename RetT, typename ClassT>
struct GetClassTypeFromGetter<RetT (ClassT::*)() const> {
using ClassType = ClassT;
};
using InstrT = typename GetClassTypeFromGetter<decltype(GetterFn)>::ClassType;
using SavedValT = std::invoke_result_t<decltype(GetterFn), InstrT>;
InstrT *I;
SavedValT OrigVal;
Expand Down
97 changes: 97 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,103 @@ bool Instruction::classof(const sandboxir::Value *From) {
}
}

void Instruction::setHasNoUnsignedWrap(bool B) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::hasNoUnsignedWrap,
&Instruction::setHasNoUnsignedWrap>>(
this);
cast<llvm::Instruction>(Val)->setHasNoUnsignedWrap(B);
}

void Instruction::setHasNoSignedWrap(bool B) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::hasNoSignedWrap,
&Instruction::setHasNoSignedWrap>>(this);
cast<llvm::Instruction>(Val)->setHasNoSignedWrap(B);
}

void Instruction::setFast(bool B) {
Ctx.getTracker()
.emplaceIfTracking<
GenericSetter<&Instruction::isFast, &Instruction::setFast>>(this);
cast<llvm::Instruction>(Val)->setFast(B);
}

void Instruction::setIsExact(bool B) {
Ctx.getTracker()
.emplaceIfTracking<
GenericSetter<&Instruction::isExact, &Instruction::setIsExact>>(this);
cast<llvm::Instruction>(Val)->setIsExact(B);
}

void Instruction::setHasAllowReassoc(bool B) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::hasAllowReassoc,
&Instruction::setHasAllowReassoc>>(this);
cast<llvm::Instruction>(Val)->setHasAllowReassoc(B);
}

void Instruction::setHasNoNaNs(bool B) {
Ctx.getTracker()
.emplaceIfTracking<
GenericSetter<&Instruction::hasNoNaNs, &Instruction::setHasNoNaNs>>(
this);
cast<llvm::Instruction>(Val)->setHasNoNaNs(B);
}

void Instruction::setHasNoInfs(bool B) {
Ctx.getTracker()
.emplaceIfTracking<
GenericSetter<&Instruction::hasNoInfs, &Instruction::setHasNoInfs>>(
this);
cast<llvm::Instruction>(Val)->setHasNoInfs(B);
}

void Instruction::setHasNoSignedZeros(bool B) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::hasNoSignedZeros,
&Instruction::setHasNoSignedZeros>>(
this);
cast<llvm::Instruction>(Val)->setHasNoSignedZeros(B);
}

void Instruction::setHasAllowReciprocal(bool B) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::hasAllowReciprocal,
&Instruction::setHasAllowReciprocal>>(
this);
cast<llvm::Instruction>(Val)->setHasAllowReciprocal(B);
}

void Instruction::setHasAllowContract(bool B) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::hasAllowContract,
&Instruction::setHasAllowContract>>(
this);
cast<llvm::Instruction>(Val)->setHasAllowContract(B);
}

void Instruction::setFastMathFlags(FastMathFlags FMF) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::getFastMathFlags,
&Instruction::copyFastMathFlags>>(this);
cast<llvm::Instruction>(Val)->setFastMathFlags(FMF);
}

void Instruction::copyFastMathFlags(FastMathFlags FMF) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::getFastMathFlags,
&Instruction::copyFastMathFlags>>(this);
cast<llvm::Instruction>(Val)->copyFastMathFlags(FMF);
}

void Instruction::setHasApproxFunc(bool B) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&Instruction::hasApproxFunc,
&Instruction::setHasApproxFunc>>(this);
cast<llvm::Instruction>(Val)->setHasApproxFunc(B);
}

#ifndef NDEBUG
void Instruction::dumpOS(raw_ostream &OS) const {
OS << "Unimplemented! Please override dump().";
Expand Down
59 changes: 59 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,65 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) {
EXPECT_EQ(NewGEP2->getNextNode(), nullptr);
}

TEST_F(SandboxIRTest, Flags) {
parseIR(C, R"IR(
define void @foo(i32 %arg, float %farg) {
%add = add i32 %arg, %arg
%fadd = fadd float %farg, %farg
%udiv = udiv i32 %arg, %arg
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
BasicBlock *LLVMBB = &*LLVMF.begin();
auto LLVMIt = LLVMBB->begin();
auto *LLVMAdd = &*LLVMIt++;
auto *LLVMFAdd = &*LLVMIt++;
auto *LLVMUDiv = &*LLVMIt++;

sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Add = &*It++;
auto *FAdd = &*It++;
auto *UDiv = &*It++;

#define CHECK_FLAG(I, LLVMI, GETTER, SETTER) \
{ \
EXPECT_EQ(I->GETTER(), LLVMI->GETTER()); \
bool NewFlagVal = !I->GETTER(); \
I->SETTER(NewFlagVal); \
EXPECT_EQ(I->GETTER(), NewFlagVal); \
EXPECT_EQ(I->GETTER(), LLVMI->GETTER()); \
}

CHECK_FLAG(Add, LLVMAdd, hasNoUnsignedWrap, setHasNoUnsignedWrap);
CHECK_FLAG(Add, LLVMAdd, hasNoSignedWrap, setHasNoSignedWrap);
CHECK_FLAG(FAdd, LLVMFAdd, isFast, setFast);
CHECK_FLAG(FAdd, LLVMFAdd, hasAllowReassoc, setHasAllowReassoc);
CHECK_FLAG(UDiv, LLVMUDiv, isExact, setIsExact);
CHECK_FLAG(FAdd, LLVMFAdd, hasNoNaNs, setHasNoNaNs);
CHECK_FLAG(FAdd, LLVMFAdd, hasNoInfs, setHasNoInfs);
CHECK_FLAG(FAdd, LLVMFAdd, hasNoSignedZeros, setHasNoSignedZeros);
CHECK_FLAG(FAdd, LLVMFAdd, hasAllowReciprocal, setHasAllowReciprocal);
CHECK_FLAG(FAdd, LLVMFAdd, hasAllowContract, setHasAllowContract);
CHECK_FLAG(FAdd, LLVMFAdd, hasApproxFunc, setHasApproxFunc);

// Check getFastMathFlags(), copyFastMathFlags().
FAdd->setFastMathFlags(FastMathFlags::getFast());
EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
FastMathFlags OrigFMF = FAdd->getFastMathFlags();
FastMathFlags NewFMF;
NewFMF.setAllowReassoc(true);
EXPECT_TRUE(NewFMF != OrigFMF);
FAdd->setFastMathFlags(NewFMF);
EXPECT_FALSE(FAdd->getFastMathFlags() != OrigFMF);
FAdd->copyFastMathFlags(NewFMF);
EXPECT_FALSE(FAdd->getFastMathFlags() != NewFMF);
EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
}

TEST_F(SandboxIRTest, AtomicCmpXchgInst) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %cmp, i8 %new) {
Expand Down
61 changes: 61 additions & 0 deletions llvm/unittests/SandboxIR/TrackerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,3 +968,64 @@ define void @foo(ptr %arg0, i8 %val) {
Ctx.revert();
EXPECT_FALSE(Store->isVolatile());
}

TEST_F(TrackerTest, Flags) {
parseIR(C, R"IR(
define void @foo(i32 %arg, float %farg) {
%add = add i32 %arg, %arg
%fadd = fadd float %farg, %farg
%udiv = udiv i32 %arg, %arg
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Add = &*It++;
auto *FAdd = &*It++;
auto *UDiv = &*It++;

#define CHECK_FLAG(I, GETTER, SETTER) \
{ \
Ctx.save(); \
bool OrigFlag = I->GETTER(); \
bool NewFlag = !OrigFlag; \
I->SETTER(NewFlag); \
EXPECT_EQ(I->GETTER(), NewFlag); \
Ctx.revert(); \
EXPECT_EQ(I->GETTER(), OrigFlag); \
}

CHECK_FLAG(Add, hasNoUnsignedWrap, setHasNoUnsignedWrap);
CHECK_FLAG(Add, hasNoSignedWrap, setHasNoSignedWrap);
CHECK_FLAG(FAdd, isFast, setFast);
CHECK_FLAG(FAdd, hasAllowReassoc, setHasAllowReassoc);
CHECK_FLAG(UDiv, isExact, setIsExact);
CHECK_FLAG(FAdd, hasNoNaNs, setHasNoNaNs);
CHECK_FLAG(FAdd, hasNoInfs, setHasNoInfs);
CHECK_FLAG(FAdd, hasNoSignedZeros, setHasNoSignedZeros);
CHECK_FLAG(FAdd, hasAllowReciprocal, setHasAllowReciprocal);
CHECK_FLAG(FAdd, hasAllowContract, setHasAllowContract);
CHECK_FLAG(FAdd, hasApproxFunc, setHasApproxFunc);

// Check setFastMathFlags().
FastMathFlags OrigFMF = FAdd->getFastMathFlags();
FastMathFlags NewFMF;
NewFMF.setAllowReassoc(true);
EXPECT_TRUE(NewFMF != OrigFMF);

Ctx.save();
FAdd->setFastMathFlags(NewFMF);
EXPECT_FALSE(FAdd->getFastMathFlags() != NewFMF);
Ctx.revert();
EXPECT_FALSE(FAdd->getFastMathFlags() != OrigFMF);

// Check copyFastMathFlags().
Ctx.save();
FAdd->copyFastMathFlags(NewFMF);
EXPECT_FALSE(FAdd->getFastMathFlags() != NewFMF);
Ctx.revert();
EXPECT_FALSE(FAdd->getFastMathFlags() != OrigFMF);
}
Loading