Skip to content

Commit 64e5b37

Browse files
vporpoyuxuanchen1997
authored andcommitted
[SandboxIR] Implement LoadInst (#99597)
Summary: This patch implements a `LoadInst` instruction in SandboxIR. It mirrors `llvm::LoadInst`. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251256
1 parent 7dc2e88 commit 64e5b37

File tree

4 files changed

+156
-11
lines changed

4 files changed

+156
-11
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
#define LLVM_SANDBOXIR_SANDBOXIR_H
6060

6161
#include "llvm/IR/Function.h"
62+
#include "llvm/IR/IRBuilder.h"
6263
#include "llvm/IR/User.h"
6364
#include "llvm/IR/Value.h"
6465
#include "llvm/SandboxIR/Tracker.h"
@@ -74,6 +75,7 @@ class BasicBlock;
7475
class Context;
7576
class Function;
7677
class Instruction;
78+
class LoadInst;
7779
class User;
7880
class Value;
7981

@@ -170,9 +172,10 @@ class Value {
170172
/// order.
171173
llvm::Value *Val = nullptr;
172174

173-
friend class Context; // For getting `Val`.
174-
friend class User; // For getting `Val`.
175-
friend class Use; // For getting `Val`.
175+
friend class Context; // For getting `Val`.
176+
friend class User; // For getting `Val`.
177+
friend class Use; // For getting `Val`.
178+
friend class LoadInst; // For getting `Val`.
176179

177180
/// All values point to the context.
178181
Context &Ctx;
@@ -262,11 +265,14 @@ class Value {
262265
llvm::function_ref<bool(const Use &)> ShouldReplace);
263266
void replaceAllUsesWith(Value *Other);
264267

268+
/// \Returns the LLVM IR name of the bottom-most LLVM value.
269+
StringRef getName() const { return Val->getName(); }
270+
265271
#ifndef NDEBUG
266272
/// Should crash if there is something wrong with the instruction.
267273
virtual void verify() const = 0;
268-
/// Returns the name in the form 'SB<number>.' like 'SB1.'
269-
std::string getName() const;
274+
/// Returns the unique id in the form 'SB<number>.' like 'SB1.'
275+
std::string getUid() const;
270276
virtual void dumpCommonHeader(raw_ostream &OS) const;
271277
void dumpCommonFooter(raw_ostream &OS) const;
272278
void dumpCommonPrefix(raw_ostream &OS) const;
@@ -489,6 +495,7 @@ class Instruction : public sandboxir::User {
489495
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
490496
/// returns its topmost LLVM IR instruction.
491497
llvm::Instruction *getTopmostLLVMInstruction() const;
498+
friend class LoadInst; // For getTopmostLLVMInstruction().
492499

493500
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
494501
/// order.
@@ -553,6 +560,45 @@ class Instruction : public sandboxir::User {
553560
#endif
554561
};
555562

563+
class LoadInst final : public Instruction {
564+
/// Use LoadInst::create() instead of calling the constructor.
565+
LoadInst(llvm::LoadInst *LI, Context &Ctx)
566+
: Instruction(ClassID::Load, Opcode::Load, LI, Ctx) {}
567+
friend Context; // for LoadInst()
568+
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
569+
return getOperandUseDefault(OpIdx, Verify);
570+
}
571+
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
572+
return {cast<llvm::Instruction>(Val)};
573+
}
574+
575+
public:
576+
unsigned getUseOperandNo(const Use &Use) const final {
577+
return getUseOperandNoDefault(Use);
578+
}
579+
580+
unsigned getNumOfIRInstrs() const final { return 1u; }
581+
static LoadInst *create(Type *Ty, Value *Ptr, MaybeAlign Align,
582+
Instruction *InsertBefore, Context &Ctx,
583+
const Twine &Name = "");
584+
static LoadInst *create(Type *Ty, Value *Ptr, MaybeAlign Align,
585+
BasicBlock *InsertAtEnd, Context &Ctx,
586+
const Twine &Name = "");
587+
/// For isa/dyn_cast.
588+
static bool classof(const Value *From);
589+
Value *getPointerOperand() const;
590+
Align getAlign() const { return cast<llvm::LoadInst>(Val)->getAlign(); }
591+
bool isUnordered() const { return cast<llvm::LoadInst>(Val)->isUnordered(); }
592+
bool isSimple() const { return cast<llvm::LoadInst>(Val)->isSimple(); }
593+
#ifndef NDEBUG
594+
void verify() const final {
595+
assert(isa<llvm::LoadInst>(Val) && "Expected LoadInst!");
596+
}
597+
void dump(raw_ostream &OS) const override;
598+
LLVM_DUMP_METHOD void dump() const override;
599+
#endif
600+
};
601+
556602
/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
557603
/// an OpaqueInstr.
558604
class OpaqueInst : public sandboxir::Instruction {
@@ -683,8 +729,16 @@ class Context {
683729

684730
friend class BasicBlock; // For getOrCreateValue().
685731

732+
IRBuilder<ConstantFolder> LLVMIRBuilder;
733+
auto &getLLVMIRBuilder() { return LLVMIRBuilder; }
734+
735+
LoadInst *createLoadInst(llvm::LoadInst *LI);
736+
friend LoadInst; // For createLoadInst()
737+
686738
public:
687-
Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx), IRTracker(*this) {}
739+
Context(LLVMContext &LLVMCtx)
740+
: LLVMCtx(LLVMCtx), IRTracker(*this),
741+
LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
688742

689743
Tracker &getTracker() { return IRTracker; }
690744
/// Convenience function for `getTracker().save()`

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ DEF_USER(Constant, Constant)
2525
#endif
2626
// ClassID, Opcode(s), Class
2727
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
28+
DEF_INSTR(Load, OP(Load), LoadInst)
2829

2930
#ifdef DEF_VALUE
3031
#undef DEF_VALUE

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@ void Value::replaceAllUsesWith(Value *Other) {
140140
}
141141

142142
#ifndef NDEBUG
143-
std::string Value::getName() const {
143+
std::string Value::getUid() const {
144144
std::stringstream SS;
145145
SS << "SB" << UID << ".";
146146
return SS.str();
147147
}
148148

149149
void Value::dumpCommonHeader(raw_ostream &OS) const {
150-
OS << getName() << " " << getSubclassIDStr(SubclassID) << " ";
150+
OS << getUid() << " " << getSubclassIDStr(SubclassID) << " ";
151151
}
152152

153153
void Value::dumpCommonFooter(raw_ostream &OS) const {
@@ -167,7 +167,7 @@ void Value::dumpCommonPrefix(raw_ostream &OS) const {
167167
}
168168

169169
void Value::dumpCommonSuffix(raw_ostream &OS) const {
170-
OS << " ; " << getName() << " (" << getSubclassIDStr(SubclassID) << ")";
170+
OS << " ; " << getUid() << " (" << getSubclassIDStr(SubclassID) << ")";
171171
}
172172

173173
void Value::printAsOperandCommon(raw_ostream &OS) const {
@@ -453,6 +453,49 @@ void Instruction::dump() const {
453453
dump(dbgs());
454454
dbgs() << "\n";
455455
}
456+
#endif // NDEBUG
457+
458+
LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
459+
Instruction *InsertBefore, Context &Ctx,
460+
const Twine &Name) {
461+
llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
462+
auto &Builder = Ctx.getLLVMIRBuilder();
463+
Builder.SetInsertPoint(BeforeIR);
464+
auto *NewLI = Builder.CreateAlignedLoad(Ty, Ptr->Val, Align,
465+
/*isVolatile=*/false, Name);
466+
auto *NewSBI = Ctx.createLoadInst(NewLI);
467+
return NewSBI;
468+
}
469+
470+
LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
471+
BasicBlock *InsertAtEnd, Context &Ctx,
472+
const Twine &Name) {
473+
auto &Builder = Ctx.getLLVMIRBuilder();
474+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
475+
auto *NewLI = Builder.CreateAlignedLoad(Ty, Ptr->Val, Align,
476+
/*isVolatile=*/false, Name);
477+
auto *NewSBI = Ctx.createLoadInst(NewLI);
478+
return NewSBI;
479+
}
480+
481+
bool LoadInst::classof(const Value *From) {
482+
return From->getSubclassID() == ClassID::Load;
483+
}
484+
485+
Value *LoadInst::getPointerOperand() const {
486+
return Ctx.getValue(cast<llvm::LoadInst>(Val)->getPointerOperand());
487+
}
488+
489+
#ifndef NDEBUG
490+
void LoadInst::dump(raw_ostream &OS) const {
491+
dumpCommonPrefix(OS);
492+
dumpCommonSuffix(OS);
493+
}
494+
495+
void LoadInst::dump() const {
496+
dump(dbgs());
497+
dbgs() << "\n";
498+
}
456499

457500
void OpaqueInst::dump(raw_ostream &OS) const {
458501
dumpCommonPrefix(OS);
@@ -538,8 +581,8 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
538581
assert(VPtr->getSubclassID() != Value::ClassID::User &&
539582
"Can't register a user!");
540583
Value *V = VPtr.get();
541-
llvm::Value *Key = V->Val;
542-
LLVMValueToValueMap[Key] = std::move(VPtr);
584+
auto Pair = LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
585+
assert(Pair.second && "Already exists!");
543586
return V;
544587
}
545588

@@ -568,6 +611,17 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
568611
return nullptr;
569612
}
570613
assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");
614+
615+
switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
616+
case llvm::Instruction::Load: {
617+
auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
618+
It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
619+
return It->second.get();
620+
}
621+
default:
622+
break;
623+
}
624+
571625
It->second = std::unique_ptr<OpaqueInst>(
572626
new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
573627
return It->second.get();
@@ -582,6 +636,11 @@ BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
582636
return BB;
583637
}
584638

639+
LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
640+
auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
641+
return cast<LoadInst>(registerValue(std::move(NewPtr)));
642+
}
643+
585644
Value *Context::getValue(llvm::Value *V) const {
586645
auto It = LLVMValueToValueMap.find(V);
587646
if (It != LLVMValueToValueMap.end())

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,3 +560,34 @@ define void @foo(i8 %v1) {
560560
EXPECT_EQ(I0->getNumUses(), 0u);
561561
EXPECT_EQ(I0->getNextNode(), Ret);
562562
}
563+
564+
TEST_F(SandboxIRTest, LoadInst) {
565+
parseIR(C, R"IR(
566+
define void @foo(ptr %arg0, ptr %arg1) {
567+
%ld = load i8, ptr %arg0, align 64
568+
ret void
569+
}
570+
)IR");
571+
llvm::Function *LLVMF = &*M->getFunction("foo");
572+
sandboxir::Context Ctx(C);
573+
sandboxir::Function *F = Ctx.createFunction(LLVMF);
574+
auto *Arg0 = F->getArg(0);
575+
auto *Arg1 = F->getArg(1);
576+
auto *BB = &*F->begin();
577+
auto It = BB->begin();
578+
auto *Ld = cast<sandboxir::LoadInst>(&*It++);
579+
auto *Ret = &*It++;
580+
581+
// Check getPointerOperand()
582+
EXPECT_EQ(Ld->getPointerOperand(), Arg0);
583+
// Check getAlign()
584+
EXPECT_EQ(Ld->getAlign(), 64);
585+
// Check create(InsertBefore)
586+
sandboxir::LoadInst *NewLd =
587+
sandboxir::LoadInst::create(Ld->getType(), Arg1, Align(8),
588+
/*InsertBefore=*/Ret, Ctx, "NewLd");
589+
EXPECT_EQ(NewLd->getType(), Ld->getType());
590+
EXPECT_EQ(NewLd->getPointerOperand(), Arg1);
591+
EXPECT_EQ(NewLd->getAlign(), 8);
592+
EXPECT_EQ(NewLd->getName(), "NewLd");
593+
}

0 commit comments

Comments
 (0)