Skip to content

[SandboxIR] Add more Instruction member functions #98588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ namespace llvm {

namespace sandboxir {

class Function;
class BasicBlock;
class Context;
class Function;
class Instruction;
class User;
class Value;
Expand Down Expand Up @@ -508,6 +509,14 @@ class Instruction : public sandboxir::User {

Opcode Opc;

/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
virtual SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const = 0;

public:
static const char *getOpcodeName(Opcode Opc);
#ifndef NDEBUG
Expand All @@ -518,6 +527,40 @@ class Instruction : public sandboxir::User {
#endif
/// This is used by BasicBlock::iterator.
virtual unsigned getNumOfIRInstrs() const = 0;
/// \Returns a BasicBlock::iterator for this Instruction.
BBIterator getIterator() const;
/// \Returns the next sandboxir::Instruction in the block, or nullptr if at
/// the end of the block.
Instruction *getNextNode() const;
/// \Returns the previous sandboxir::Instruction in the block, or nullptr if
/// at the beginning of the block.
Instruction *getPrevNode() const;
/// \Returns this Instruction's opcode. Note that SandboxIR has its own opcode
/// state to allow for new SandboxIR-specific instructions.
Opcode getOpcode() const { return Opc; }
/// Detach this from its parent BasicBlock without deleting it.
void removeFromParent();
/// Detach this Value from its parent and delete it.
void eraseFromParent();
/// Insert this detached instruction before \p BeforeI.
void insertBefore(Instruction *BeforeI);
/// Insert this detached instruction after \p AfterI.
void insertAfter(Instruction *AfterI);
/// Insert this detached instruction into \p BB at \p WhereIt.
void insertInto(BasicBlock *BB, const BBIterator &WhereIt);
/// Move this instruction to \p WhereIt.
void moveBefore(BasicBlock &BB, const BBIterator &WhereIt);
/// Move this instruction before \p Before.
void moveBefore(Instruction *Before) {
moveBefore(*Before->getParent(), Before->getIterator());
}
/// Move this instruction after \p After.
void moveAfter(Instruction *After) {
moveBefore(*After->getParent(), std::next(After->getIterator()));
}
/// \Returns the BasicBlock containing this Instruction, or null if it is
/// detached.
BasicBlock *getParent() const;
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From);

Expand All @@ -543,6 +586,9 @@ class OpaqueInst : public sandboxir::Instruction {
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
static bool classof(const sandboxir::Value *From) {
Expand Down Expand Up @@ -570,7 +616,8 @@ class BasicBlock : public Value {
/// Builds a graph that contains all values in \p BB in their original form
/// i.e., no vectorization is taking place here.
void buildBasicBlockFromLLVMIR(llvm::BasicBlock *LLVMBB);
friend class Context; // For `buildBasicBlockFromIR`
friend class Context; // For `buildBasicBlockFromIR`
friend class Instruction; // For LLVM Val.

BasicBlock(llvm::BasicBlock *BB, Context &SBCtx)
: Value(ClassID::Block, BB, SBCtx) {
Expand Down Expand Up @@ -623,6 +670,12 @@ class Context {
DenseMap<llvm::Value *, std::unique_ptr<sandboxir::Value>>
LLVMValueToValueMap;

/// Remove \p V from the maps and returns the unique_ptr.
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm debating whether or not this should have [[nodiscard]]. Unsure how pervasive you'd want [[nodiscard]] to be throughout these APIs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't feel strongly either way, but these detach() functions are meant to be used either for detaching or for just deleting the value. I think it's currently only used for deleting them, so the value is discarded. There is no real harm in discarding the returned value as far as I can tell.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we wanted to keep around Values in case we rollback

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These APIs are for internal use only. The user won't be able to detach values. The only way a user can delete a value is through the eraseFromParent() API, and this will be tracked and can be rolled back.

/// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
/// detaches \p V from the underlying IR.
std::unique_ptr<Value> detach(Value *V);
friend void Instruction::eraseFromParent(); // For detach().
/// Take ownership of VPtr and store it in `LLVMValueToValueMap`.
Value *registerValue(std::unique_ptr<Value> &&VPtr);

Expand Down
127 changes: 127 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,115 @@ const char *Instruction::getOpcodeName(Opcode Opc) {
llvm_unreachable("Unknown Opcode");
}

llvm::Instruction *Instruction::getTopmostLLVMInstruction() const {
Instruction *Prev = getPrevNode();
if (Prev == nullptr) {
// If at top of the BB, return the first BB instruction.
return &*cast<llvm::BasicBlock>(getParent()->Val)->begin();
}
// Else get the Previous sandbox IR instruction's bottom IR instruction and
// return its successor.
llvm::Instruction *PrevBotI = cast<llvm::Instruction>(Prev->Val);
return PrevBotI->getNextNode();
}

BBIterator Instruction::getIterator() const {
auto *I = cast<llvm::Instruction>(Val);
return BasicBlock::iterator(I->getParent(), I->getIterator(), &Ctx);
}

Instruction *Instruction::getNextNode() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this necessary? this is part of ilist_node_with_parent, so the name doesn't really make sense here. if we do want to support this operation, I'd at least rename it to something like getNextInstruction()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well getNextNode() is part of the Instruction API, so the user expects it to be there. So I am not sure we should remove it, or use a different name for a function that does the same thing. Perhaps we could have both getNextInstruction() and getNextNode()? Wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think that Node is an implementation detail and we can just have getNextInstruction(), but don't feel that strongly. This is fine for now.

assert(getParent() != nullptr && "Detached!");
assert(getIterator() != getParent()->end() && "Already at end!");
auto *LLVMI = cast<llvm::Instruction>(Val);
assert(LLVMI->getParent() != nullptr && "LLVM IR instr is detached!");
auto *NextLLVMI = LLVMI->getNextNode();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about multi-instruction instructions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here is how this works: NextLLVMI is an LLVM instruction that belongs to the next SandboxIR instruction, whether it's a single-Instruction or multi-Instruction. If it's a multi-Instruction, then NextLLVMI will be the topmost LLVM instruction. In either case Ctxt.getValue(NextLLVMI) returns the correct SandboxIR Instruction that maps to NextLLVMI. Does this make sense?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant if this is a multi-instruction, then NextLLVMI will point to the second LLVM instruction in this. Or are you going to override getNextNode() in the multi-instruction subclass?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah the confusing part is that Val always points to the bottom LLVM Instruction of a sandboxir::Instruction, not the topmost one.
So if this is a multi-instruction, LLVMI is the bottom-most LLVM Instruction in this, so LLVMI->getNextNode() points to an LLVM instruction that does not belong to this, but instead to the next sandboxir::Instruction.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I see. can you update the comments around Val (separate PR) to say that?

auto *NextI = cast_or_null<Instruction>(Ctx.getValue(NextLLVMI));
if (NextI == nullptr)
return nullptr;
return NextI;
}

Instruction *Instruction::getPrevNode() const {
assert(getParent() != nullptr && "Detached!");
auto It = getIterator();
if (It != getParent()->begin())
return std::prev(getIterator()).get();
return nullptr;
}

void Instruction::removeFromParent() {
// Detach all the LLVM IR instructions from their parent BB.
for (llvm::Instruction *I : getLLVMInstrs())
I->removeFromParent();
}

void Instruction::eraseFromParent() {
assert(users().empty() && "Still connected to users, can't erase!");
// We don't have Tracking yet, so just erase the LLVM IR instructions.
// Erase in reverse to avoid erasing nstructions with attached uses.
for (llvm::Instruction *I : reverse(getLLVMInstrs()))
I->eraseFromParent();
}

void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
if (std::next(getIterator()) == WhereIt)
// Destination is same as origin, nothing to do.
return;
auto *LLVMBB = cast<llvm::BasicBlock>(BB.Val);
llvm::BasicBlock::iterator It;
if (WhereIt == BB.end()) {
It = LLVMBB->end();
} else {
Instruction *WhereI = &*WhereIt;
It = WhereI->getTopmostLLVMInstruction()->getIterator();
}
// TODO: Move this to the verifier of sandboxir::Instruction.
assert(is_sorted(getLLVMInstrs(),
[](auto *I1, auto *I2) { return I1->comesBefore(I2); }) &&
"Expected program order!");
// Do the actual move in LLVM IR.
for (auto *I : getLLVMInstrs())
I->moveBefore(*LLVMBB, It);
}

void Instruction::insertBefore(Instruction *BeforeI) {
llvm::Instruction *BeforeTopI = BeforeI->getTopmostLLVMInstruction();
// TODO: Move this to the verifier of sandboxir::Instruction.
assert(is_sorted(getLLVMInstrs(),
[](auto *I1, auto *I2) { return I1->comesBefore(I2); }) &&
"Expected program order!");
for (llvm::Instruction *I : getLLVMInstrs())
I->insertBefore(BeforeTopI);
}

void Instruction::insertAfter(Instruction *AfterI) {
insertInto(AfterI->getParent(), std::next(AfterI->getIterator()));
}

void Instruction::insertInto(BasicBlock *BB, const BBIterator &WhereIt) {
llvm::BasicBlock *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
llvm::Instruction *LLVMBeforeI;
llvm::BasicBlock::iterator LLVMBeforeIt;
if (WhereIt != BB->end()) {
Instruction *BeforeI = &*WhereIt;
LLVMBeforeI = BeforeI->getTopmostLLVMInstruction();
LLVMBeforeIt = LLVMBeforeI->getIterator();
} else {
LLVMBeforeI = nullptr;
LLVMBeforeIt = LLVMBB->end();
}
for (llvm::Instruction *I : getLLVMInstrs())
I->insertInto(LLVMBB, LLVMBeforeIt);
}

BasicBlock *Instruction::getParent() const {
auto *BB = cast<llvm::Instruction>(Val)->getParent();
if (BB == nullptr)
return nullptr;
return cast<BasicBlock>(Ctx.getValue(BB));
}

bool Instruction::classof(const sandboxir::Value *From) {
switch (From->getSubclassID()) {
#define DEF_INSTR(ID, OPC, CLASS) \
Expand Down Expand Up @@ -344,6 +453,24 @@ BasicBlock::iterator::getInstr(llvm::BasicBlock::iterator It) const {
return cast_or_null<Instruction>(Ctx->getValue(&*It));
}

std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
std::unique_ptr<Value> Erased;
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end()) {
auto *Val = It->second.release();
Erased = std::unique_ptr<Value>(Val);
LLVMValueToValueMap.erase(It);
}
return Erased;
}

std::unique_ptr<Value> Context::detach(Value *V) {
assert(V->getSubclassID() != Value::ClassID::Constant &&
"Can't detach a constant!");
assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
return detachLLVMValue(V->Val);
}

Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
assert(VPtr->getSubclassID() != Value::ClassID::User &&
"Can't register a user!");
Expand Down
89 changes: 89 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,92 @@ define void @foo(i32 %v1) {
}
#endif // NDEBUG
}

TEST_F(SandboxIRTest, Instruction) {
parseIR(C, R"IR(
define void @foo(i8 %v1) {
%add0 = add i8 %v1, %v1
%sub1 = sub i8 %add0, %v1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
sandboxir::Function *F = Ctx.createFunction(LLVMF);
auto *Arg = F->getArg(0);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *I0 = &*It++;
auto *I1 = &*It++;
auto *Ret = &*It++;

// Check getPrevNode().
EXPECT_EQ(Ret->getPrevNode(), I1);
EXPECT_EQ(I1->getPrevNode(), I0);
EXPECT_EQ(I0->getPrevNode(), nullptr);

// Check getNextNode().
EXPECT_EQ(I0->getNextNode(), I1);
EXPECT_EQ(I1->getNextNode(), Ret);
EXPECT_EQ(Ret->getNextNode(), nullptr);

// Check getIterator().
EXPECT_EQ(I0->getIterator(), std::next(BB->begin(), 0));
EXPECT_EQ(I1->getIterator(), std::next(BB->begin(), 1));
EXPECT_EQ(Ret->getIterator(), std::next(BB->begin(), 2));

// Check getOpcode().
EXPECT_EQ(I0->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
EXPECT_EQ(I1->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Opaque);

// Check moveBefore(I).
I1->moveBefore(I0);
EXPECT_EQ(I0->getPrevNode(), I1);
EXPECT_EQ(I1->getNextNode(), I0);

// Check moveAfter(I).
I1->moveAfter(I0);
EXPECT_EQ(I0->getNextNode(), I1);
EXPECT_EQ(I1->getPrevNode(), I0);

// Check moveBefore(BB, It).
I1->moveBefore(*BB, BB->begin());
EXPECT_EQ(I1->getPrevNode(), nullptr);
EXPECT_EQ(I1->getNextNode(), I0);
I1->moveBefore(*BB, BB->end());
EXPECT_EQ(I1->getNextNode(), nullptr);
EXPECT_EQ(Ret->getNextNode(), I1);
I1->moveBefore(*BB, std::next(BB->begin()));
EXPECT_EQ(I0->getNextNode(), I1);
EXPECT_EQ(I1->getNextNode(), Ret);

// Check removeFromParent().
I0->removeFromParent();
#ifndef NDEBUG
EXPECT_DEATH(I0->getPrevNode(), ".*Detached.*");
EXPECT_DEATH(I0->getNextNode(), ".*Detached.*");
#endif // NDEBUG
EXPECT_EQ(I0->getParent(), nullptr);
EXPECT_EQ(I1->getPrevNode(), nullptr);
EXPECT_EQ(I0->getOperand(0), Arg);

// Check insertBefore().
I0->insertBefore(I1);
EXPECT_EQ(I1->getPrevNode(), I0);

// Check insertInto().
I0->removeFromParent();
I0->insertInto(BB, BB->end());
EXPECT_EQ(Ret->getNextNode(), I0);
I0->moveBefore(I1);
EXPECT_EQ(I0->getNextNode(), I1);

// Check eraseFromParent().
#ifndef NDEBUG
EXPECT_DEATH(I0->eraseFromParent(), "Still connected to users.*");
#endif
I1->eraseFromParent();
EXPECT_EQ(I0->getNumUses(), 0u);
EXPECT_EQ(I0->getNextNode(), Ret);
}
Loading