Skip to content

[SandboxIR] Implement LoadInst #99597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 60 additions & 6 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
#define LLVM_SANDBOXIR_SANDBOXIR_H

#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/SandboxIR/Tracker.h"
Expand All @@ -74,6 +75,7 @@ class BasicBlock;
class Context;
class Function;
class Instruction;
class LoadInst;
class User;
class Value;

Expand Down Expand Up @@ -170,9 +172,10 @@ class Value {
/// order.
llvm::Value *Val = nullptr;

friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class LoadInst; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -262,11 +265,14 @@ class Value {
llvm::function_ref<bool(const Use &)> ShouldReplace);
void replaceAllUsesWith(Value *Other);

/// \Returns the LLVM IR name of the bottom-most LLVM value.
StringRef getName() const { return Val->getName(); }

#ifndef NDEBUG
/// Should crash if there is something wrong with the instruction.
virtual void verify() const = 0;
/// Returns the name in the form 'SB<number>.' like 'SB1.'
std::string getName() const;
/// Returns the unique id in the form 'SB<number>.' like 'SB1.'
std::string getUid() const;
virtual void dumpCommonHeader(raw_ostream &OS) const;
void dumpCommonFooter(raw_ostream &OS) const;
void dumpCommonPrefix(raw_ostream &OS) const;
Expand Down Expand Up @@ -489,6 +495,7 @@ class Instruction : public sandboxir::User {
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class LoadInst; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
Expand Down Expand Up @@ -553,6 +560,45 @@ class Instruction : public sandboxir::User {
#endif
};

class LoadInst final : public Instruction {
/// Use LoadInst::create() instead of calling the constructor.
LoadInst(llvm::LoadInst *LI, Context &Ctx)
: Instruction(ClassID::Load, Opcode::Load, LI, Ctx) {}
friend Context; // for LoadInst()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
unsigned getUseOperandNo(const Use &Use) const final {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The final confused me. It comes from User . For style and safeness, I would mark it as override first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm why would you mark it as override ? LoadInst won't have subclasses so it should be safe to mark it final.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. But getUseOperandNo overrides a virtual = 0 from User. Once you mark it override the compiler will do checks for you. override final.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think override final is redundant, and I am not sure it's being used in LLVM. If you use final the compiler will check that you are actually overriding and print an error if not, and on top of that it will also check that it is final.

Copy link

@tschuett tschuett Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. I found one occurrence.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that LoadInst is final, the finals are kind of redundant. Feel free to ignore, but override would be better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't they still useful though? If you misspell a final function name you will still get an error, but you won't get one if you just rely on the class being marked final.
Well, I think I would prefer to stick with final because it conveys both that it's overriden and that it's final.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries. The class is marked final. You cannot inherit from anyway and overwriting the functions. My only confusion is that final is kind of redundant. override is less confusing and gives the same compiler checks.

return getUseOperandNoDefault(Use);
}

unsigned getNumOfIRInstrs() const final { return 1u; }
static LoadInst *create(Type *Ty, Value *Ptr, MaybeAlign Align,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
static LoadInst *create(Type *Ty, Value *Ptr, MaybeAlign Align,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name = "");
/// For isa/dyn_cast.
static bool classof(const Value *From);
Value *getPointerOperand() const;
Align getAlign() const { return cast<llvm::LoadInst>(Val)->getAlign(); }
bool isUnordered() const { return cast<llvm::LoadInst>(Val)->isUnordered(); }
bool isSimple() const { return cast<llvm::LoadInst>(Val)->isSimple(); }
#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::LoadInst>(Val) && "Expected LoadInst!");
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInstr.
class OpaqueInst : public sandboxir::Instruction {
Expand Down Expand Up @@ -683,8 +729,16 @@ class Context {

friend class BasicBlock; // For getOrCreateValue().

IRBuilder<ConstantFolder> LLVMIRBuilder;
auto &getLLVMIRBuilder() { return LLVMIRBuilder; }

LoadInst *createLoadInst(llvm::LoadInst *LI);
friend LoadInst; // For createLoadInst()

public:
Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx), IRTracker(*this) {}
Context(LLVMContext &LLVMCtx)
: LLVMCtx(LLVMCtx), IRTracker(*this),
LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}

Tracker &getTracker() { return IRTracker; }
/// Convenience function for `getTracker().save()`
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ DEF_USER(Constant, Constant)
#endif
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(Load, OP(Load), LoadInst)

#ifdef DEF_VALUE
#undef DEF_VALUE
Expand Down
69 changes: 64 additions & 5 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,14 @@ void Value::replaceAllUsesWith(Value *Other) {
}

#ifndef NDEBUG
std::string Value::getName() const {
std::string Value::getUid() const {
std::stringstream SS;
SS << "SB" << UID << ".";
return SS.str();
}

void Value::dumpCommonHeader(raw_ostream &OS) const {
OS << getName() << " " << getSubclassIDStr(SubclassID) << " ";
OS << getUid() << " " << getSubclassIDStr(SubclassID) << " ";
}

void Value::dumpCommonFooter(raw_ostream &OS) const {
Expand All @@ -167,7 +167,7 @@ void Value::dumpCommonPrefix(raw_ostream &OS) const {
}

void Value::dumpCommonSuffix(raw_ostream &OS) const {
OS << " ; " << getName() << " (" << getSubclassIDStr(SubclassID) << ")";
OS << " ; " << getUid() << " (" << getSubclassIDStr(SubclassID) << ")";
}

void Value::printAsOperandCommon(raw_ostream &OS) const {
Expand Down Expand Up @@ -453,6 +453,49 @@ void Instruction::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(BeforeIR);
auto *NewLI = Builder.CreateAlignedLoad(Ty, Ptr->Val, Align,
/*isVolatile=*/false, Name);
auto *NewSBI = Ctx.createLoadInst(NewLI);
return NewSBI;
}

LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
auto *NewLI = Builder.CreateAlignedLoad(Ty, Ptr->Val, Align,
/*isVolatile=*/false, Name);
auto *NewSBI = Ctx.createLoadInst(NewLI);
return NewSBI;
}

bool LoadInst::classof(const Value *From) {
return From->getSubclassID() == ClassID::Load;
}

Value *LoadInst::getPointerOperand() const {
return Ctx.getValue(cast<llvm::LoadInst>(Val)->getPointerOperand());
}

#ifndef NDEBUG
void LoadInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void LoadInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void OpaqueInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
Expand Down Expand Up @@ -538,8 +581,8 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
assert(VPtr->getSubclassID() != Value::ClassID::User &&
"Can't register a user!");
Value *V = VPtr.get();
llvm::Value *Key = V->Val;
LLVMValueToValueMap[Key] = std::move(VPtr);
auto Pair = LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
assert(Pair.second && "Already exists!");
return V;
}

Expand Down Expand Up @@ -568,6 +611,17 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return nullptr;
}
assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");

switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
case llvm::Instruction::Load: {
auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use std::make_unique.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just like the lines above we can't use make_unique because the constructors are private.

return It->second.get();
}
default:
break;
}

It->second = std::unique_ptr<OpaqueInst>(
new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
return It->second.get();
Expand All @@ -582,6 +636,11 @@ BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
return BB;
}

LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
return cast<LoadInst>(registerValue(std::move(NewPtr)));
}

Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end())
Expand Down
31 changes: 31 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,34 @@ define void @foo(i8 %v1) {
EXPECT_EQ(I0->getNumUses(), 0u);
EXPECT_EQ(I0->getNextNode(), Ret);
}

TEST_F(SandboxIRTest, LoadInst) {
parseIR(C, R"IR(
define void @foo(ptr %arg0, ptr %arg1) {
%ld = load i8, ptr %arg0, align 64
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
sandboxir::Function *F = Ctx.createFunction(LLVMF);
auto *Arg0 = F->getArg(0);
auto *Arg1 = F->getArg(1);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *Ld = cast<sandboxir::LoadInst>(&*It++);
auto *Ret = &*It++;

// Check getPointerOperand()
EXPECT_EQ(Ld->getPointerOperand(), Arg0);
// Check getAlign()
EXPECT_EQ(Ld->getAlign(), 64);
// Check create(InsertBefore)
sandboxir::LoadInst *NewLd =
sandboxir::LoadInst::create(Ld->getType(), Arg1, Align(8),
/*InsertBefore=*/Ret, Ctx, "NewLd");
EXPECT_EQ(NewLd->getType(), Ld->getType());
EXPECT_EQ(NewLd->getPointerOperand(), Arg1);
EXPECT_EQ(NewLd->getAlign(), 8);
EXPECT_EQ(NewLd->getName(), "NewLd");
}
Loading