Skip to content

[SandboxIR] Implement PHINodes #101111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ class Value {
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.
friend class CastInst; // For getting `Val`.
friend class PHINode; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -567,6 +568,7 @@ class Instruction : public sandboxir::User {
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
friend class CastInst; // For getTopmostLLVMInstruction().
friend class PHINode; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
Expand Down Expand Up @@ -1464,6 +1466,100 @@ class IntToPtrInst final : public CastInst {
#endif // NDEBUG
};

class PHINode final : public Instruction {
/// Use Context::createPHINode(). Don't call the constructor directly.
PHINode(llvm::PHINode *PHI, Context &Ctx)
: Instruction(ClassID::PHI, Opcode::PHI, PHI, Ctx) {}
friend Context; // for PHINode()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}
/// Helper for mapped_iterator.
struct LLVMBBToBB {
Context &Ctx;
LLVMBBToBB(Context &Ctx) : Ctx(Ctx) {}
BasicBlock *operator()(llvm::BasicBlock *LLVMBB) const;
};

public:
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
static PHINode *create(Type *Ty, unsigned NumReservedValues,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
/// For isa/dyn_cast.
static bool classof(const Value *From);

using const_block_iterator =
mapped_iterator<llvm::PHINode::const_block_iterator, LLVMBBToBB>;

const_block_iterator block_begin() const {
LLVMBBToBB BBGetter(Ctx);
return const_block_iterator(cast<llvm::PHINode>(Val)->block_begin(),
BBGetter);
}
const_block_iterator block_end() const {
LLVMBBToBB BBGetter(Ctx);
return const_block_iterator(cast<llvm::PHINode>(Val)->block_end(),
BBGetter);
}
iterator_range<const_block_iterator> blocks() const {
return make_range(block_begin(), block_end());
}

op_range incoming_values() { return operands(); }

const_op_range incoming_values() const { return operands(); }

unsigned getNumIncomingValues() const {
return cast<llvm::PHINode>(Val)->getNumIncomingValues();
}
Value *getIncomingValue(unsigned Idx) const;
void setIncomingValue(unsigned Idx, Value *V);
static unsigned getOperandNumForIncomingValue(unsigned Idx) {
return llvm::PHINode::getOperandNumForIncomingValue(Idx);
}
static unsigned getIncomingValueNumForOperand(unsigned Idx) {
return llvm::PHINode::getIncomingValueNumForOperand(Idx);
}
BasicBlock *getIncomingBlock(unsigned Idx) const;
BasicBlock *getIncomingBlock(const Use &U) const;

void setIncomingBlock(unsigned Idx, BasicBlock *BB);

void addIncoming(Value *V, BasicBlock *BB);

Value *removeIncomingValue(unsigned Idx);
Value *removeIncomingValue(BasicBlock *BB);

int getBasicBlockIndex(const BasicBlock *BB) const;
Value *getIncomingValueForBlock(const BasicBlock *BB) const;

Value *hasConstantValue() const;

bool hasConstantOrUndefValue() const {
return cast<llvm::PHINode>(Val)->hasConstantOrUndefValue();
}
bool isComplete() const { return cast<llvm::PHINode>(Val)->isComplete(); }
// TODO: Implement the below functions:
// void replaceIncomingBlockWith (const BasicBlock *Old, BasicBlock *New);
// void copyIncomingBlocks(iterator_range<const_block_iterator> BBRange,
// uint32_t ToIdx = 0)
// void removeIncomingValueIf(function_ref< bool(unsigned)> Predicate,
// bool DeletePHIIfEmpty=true)
#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::PHINode>(Val) && "Expected PHINode!");
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};
class PtrToIntInst final : public CastInst {
public:
static Value *create(Value *Src, Type *DestTy, BBIterator WhereIt,
Expand Down Expand Up @@ -1700,6 +1796,8 @@ class Context {
friend GetElementPtrInst; // For createGetElementPtrInst()
CastInst *createCastInst(llvm::CastInst *I);
friend CastInst; // For createCastInst()
PHINode *createPHINode(llvm::PHINode *I);
friend PHINode; // For createPHINode()

public:
Context(LLVMContext &LLVMCtx)
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ DEF_INSTR(Cast, OPCODES(\
OP(BitCast) \
OP(AddrSpaceCast) \
), CastInst)
DEF_INSTR(PHI, OP(PHI), PHINode)

// clang-format on
#ifdef DEF_VALUE
#undef DEF_VALUE
Expand Down
58 changes: 58 additions & 0 deletions llvm/include/llvm/SandboxIR/Tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,64 @@ class UseSet : public IRChangeBase {
#endif
};

class PHISetIncoming : public IRChangeBase {
PHINode &PHI;
unsigned Idx;
PointerUnion<Value *, BasicBlock *> OrigValueOrBB;

public:
enum class What {
Value,
Block,
};
PHISetIncoming(PHINode &PHI, unsigned Idx, What What, Tracker &Tracker);
void revert() final;
void accept() final {}
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "PHISetIncoming";
}
LLVM_DUMP_METHOD void dump() const final;
#endif
};

class PHIRemoveIncoming : public IRChangeBase {
PHINode &PHI;
unsigned RemovedIdx;
Value *RemovedV;
BasicBlock *RemovedBB;

public:
PHIRemoveIncoming(PHINode &PHI, unsigned RemovedIdx, Tracker &Tracker);
void revert() final;
void accept() final {}
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "PHISetIncoming";
}
LLVM_DUMP_METHOD void dump() const final;
#endif
};

class PHIAddIncoming : public IRChangeBase {
PHINode &PHI;
unsigned Idx;

public:
PHIAddIncoming(PHINode &PHI, Tracker &Tracker);
void revert() final;
void accept() final {}
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "PHISetIncoming";
}
LLVM_DUMP_METHOD void dump() const final;
#endif
};

/// Tracks swapping a Use with another Use.
class UseSwap : public IRChangeBase {
Use ThisUse;
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/SandboxIR/Use.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Context;
class Value;
class User;
class CallBase;
class PHINode;

/// Represents a Def-use/Use-def edge in SandboxIR.
/// NOTE: Unlike llvm::Use, this is not an integral part of the use-def chains.
Expand All @@ -43,6 +44,7 @@ class Use {
friend class UserUseIterator; // For accessing members
friend class CallBase; // For LLVMUse
friend class CallBrInst; // For constructor
friend class PHINode; // For LLVMUse

public:
operator Value *() const { return get(); }
Expand Down
108 changes: 108 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,95 @@ void GetElementPtrInst::dump() const {
}
#endif // NDEBUG

BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const {
return cast<BasicBlock>(Ctx.getValue(LLVMBB));
}

PHINode *PHINode::create(Type *Ty, unsigned NumReservedValues,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
llvm::PHINode *NewPHI = llvm::PHINode::Create(
Ty, NumReservedValues, Name, InsertBefore->getTopmostLLVMInstruction());
return Ctx.createPHINode(NewPHI);
}

bool PHINode::classof(const Value *From) {
return From->getSubclassID() == ClassID::PHI;
}

Value *PHINode::getIncomingValue(unsigned Idx) const {
return Ctx.getValue(cast<llvm::PHINode>(Val)->getIncomingValue(Idx));
}
void PHINode::setIncomingValue(unsigned Idx, Value *V) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHISetIncoming>(
*this, Idx, PHISetIncoming::What::Value, Tracker));

cast<llvm::PHINode>(Val)->setIncomingValue(Idx, V->Val);
}
BasicBlock *PHINode::getIncomingBlock(unsigned Idx) const {
return cast<BasicBlock>(
Ctx.getValue(cast<llvm::PHINode>(Val)->getIncomingBlock(Idx)));
}
BasicBlock *PHINode::getIncomingBlock(const Use &U) const {
llvm::Use *LLVMUse = U.LLVMUse;
llvm::BasicBlock *BB = cast<llvm::PHINode>(Val)->getIncomingBlock(*LLVMUse);
return cast<BasicBlock>(Ctx.getValue(BB));
}
void PHINode::setIncomingBlock(unsigned Idx, BasicBlock *BB) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHISetIncoming>(
*this, Idx, PHISetIncoming::What::Block, Tracker));
cast<llvm::PHINode>(Val)->setIncomingBlock(Idx,
cast<llvm::BasicBlock>(BB->Val));
}
void PHINode::addIncoming(Value *V, BasicBlock *BB) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHIAddIncoming>(*this, Tracker));

cast<llvm::PHINode>(Val)->addIncoming(V->Val,
cast<llvm::BasicBlock>(BB->Val));
}
Value *PHINode::removeIncomingValue(unsigned Idx) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHIRemoveIncoming>(*this, Idx, Tracker));

llvm::Value *LLVMV =
cast<llvm::PHINode>(Val)->removeIncomingValue(Idx,
/*DeletePHIIfEmpty=*/false);
return Ctx.getValue(LLVMV);
}
Value *PHINode::removeIncomingValue(BasicBlock *BB) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHIRemoveIncoming>(
*this, getBasicBlockIndex(BB), Tracker));

auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
llvm::Value *LLVMV =
cast<llvm::PHINode>(Val)->removeIncomingValue(LLVMBB,
/*DeletePHIIfEmpty=*/false);
return Ctx.getValue(LLVMV);
}
int PHINode::getBasicBlockIndex(const BasicBlock *BB) const {
auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
return cast<llvm::PHINode>(Val)->getBasicBlockIndex(LLVMBB);
}
Value *PHINode::getIncomingValueForBlock(const BasicBlock *BB) const {
auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
llvm::Value *LLVMV =
cast<llvm::PHINode>(Val)->getIncomingValueForBlock(LLVMBB);
return Ctx.getValue(LLVMV);
}
Value *PHINode::hasConstantValue() const {
llvm::Value *LLVMV = cast<llvm::PHINode>(Val)->hasConstantValue();
return LLVMV != nullptr ? Ctx.getValue(LLVMV) : nullptr;
}

static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
switch (Opc) {
case Instruction::Opcode::ZExt:
Expand Down Expand Up @@ -1272,6 +1361,16 @@ Value *PtrToIntInst::create(Value *Src, Type *DestTy, BasicBlock *InsertAtEnd,
}

#ifndef NDEBUG
void PHINode::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void PHINode::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void PtrToIntInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
Expand Down Expand Up @@ -1537,6 +1636,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
return It->second.get();
}
case llvm::Instruction::PHI: {
auto *LLVMPhi = cast<llvm::PHINode>(LLVMV);
It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
return It->second.get();
}
default:
break;
}
Expand Down Expand Up @@ -1606,6 +1710,10 @@ CastInst *Context::createCastInst(llvm::CastInst *I) {
auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
return cast<CastInst>(registerValue(std::move(NewPtr)));
}
PHINode *Context::createPHINode(llvm::PHINode *I) {
auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
return cast<PHINode>(registerValue(std::move(NewPtr)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Perehaps return cast<PHINode>(registerValue(std::unique_ptr<PHINode>(new PHINode(I, *this)))); ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That ends up being different from the others, like on the code just above this one. Should we change them all?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, your call. Either is fine with me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean we can keep it as it is too, I don't feel strongly about changing them.

}

Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
Expand Down
Loading
Loading