Skip to content

[GVN] MemorySSA for GVN: embed the memory state in symbolic expressions #123218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion llvm/include/llvm/Transforms/Scalar/GVN.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class ImplicitControlFlowTracking;
class LoadInst;
class LoopInfo;
class MemDepResult;
class MemoryAccess;
class MemoryDependenceResults;
class MemoryLocation;
class MemorySSA;
class MemorySSAUpdater;
class NonLocalDepResult;
Expand Down Expand Up @@ -170,13 +172,20 @@ class GVNPass : public PassInfoMixin<GVNPass> {
// Value number to PHINode mapping. Used for phi-translate in scalarpre.
DenseMap<uint32_t, PHINode *> NumberingPhi;

// Value number to BasicBlock mapping. Used for phi-translate across
// MemoryPhis.
DenseMap<uint32_t, BasicBlock *> NumberingBB;

// Cache for phi-translate in scalarpre.
using PhiTranslateMap =
DenseMap<std::pair<uint32_t, const BasicBlock *>, uint32_t>;
PhiTranslateMap PhiTranslateTable;

AAResults *AA = nullptr;
MemoryDependenceResults *MD = nullptr;
bool IsMDEnabled = false;
MemorySSA *MSSA = nullptr;
bool IsMSSAEnabled = false;
DominatorTree *DT = nullptr;

uint32_t NextValueNumber = 1;
Expand All @@ -187,12 +196,14 @@ class GVNPass : public PassInfoMixin<GVNPass> {
Expression createExtractvalueExpr(ExtractValueInst *EI);
Expression createGEPExpr(GetElementPtrInst *GEP);
uint32_t lookupOrAddCall(CallInst *C);
uint32_t computeLoadStoreVN(Instruction *I);
uint32_t phiTranslateImpl(const BasicBlock *BB, const BasicBlock *PhiBlock,
uint32_t Num, GVNPass &GVN);
bool areCallValsEqual(uint32_t Num, uint32_t NewNum, const BasicBlock *Pred,
const BasicBlock *PhiBlock, GVNPass &GVN);
std::pair<uint32_t, bool> assignExpNewValueNum(Expression &Exp);
bool areAllValsInBB(uint32_t Num, const BasicBlock *BB, GVNPass &GVN);
void addMemoryStateToExp(Instruction *I, Expression &Exp);

public:
LLVM_ABI ValueTable();
Expand All @@ -201,6 +212,7 @@ class GVNPass : public PassInfoMixin<GVNPass> {
LLVM_ABI ~ValueTable();
LLVM_ABI ValueTable &operator=(const ValueTable &Arg);

LLVM_ABI uint32_t lookupOrAdd(MemoryAccess *MA);
LLVM_ABI uint32_t lookupOrAdd(Value *V);
LLVM_ABI uint32_t lookup(Value *V, bool Verify = true) const;
LLVM_ABI uint32_t lookupOrAddCmp(unsigned Opcode, CmpInst::Predicate Pred,
Expand All @@ -216,7 +228,14 @@ class GVNPass : public PassInfoMixin<GVNPass> {
LLVM_ABI void erase(Value *V);
void setAliasAnalysis(AAResults *A) { AA = A; }
AAResults *getAliasAnalysis() const { return AA; }
void setMemDep(MemoryDependenceResults *M) { MD = M; }
void setMemDep(MemoryDependenceResults *M, bool MDEnabled = true) {
MD = M;
IsMDEnabled = MDEnabled;
}
void setMemorySSA(MemorySSA *M, bool MSSAEnabled = false) {
MSSA = M;
IsMSSAEnabled = MSSAEnabled;
}
void setDomTree(DominatorTree *D) { DT = D; }
uint32_t getNextUnusedValueNumber() { return NextValueNumber; }
LLVM_ABI void verifyRemoved(const Value *) const;
Expand Down
89 changes: 84 additions & 5 deletions llvm/lib/Transforms/Scalar/GVN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,19 @@ void GVNPass::ValueTable::add(Value *V, uint32_t Num) {
NumberingPhi[Num] = PN;
}

/// Include the incoming memory state into the hash of the expression for the
/// given instruction. If the incoming memory state is:
/// * LiveOnEntry, add the value number of the entry block,
/// * a MemoryPhi, add the value number of the basic block corresponding to that
/// MemoryPhi,
/// * a MemoryDef, add the value number of the memory setting instruction.
void GVNPass::ValueTable::addMemoryStateToExp(Instruction *I, Expression &Exp) {
assert(MSSA && "addMemoryStateToExp should not be called without MemorySSA");
assert(MSSA->getMemoryAccess(I) && "Instruction does not access memory");
MemoryAccess *MA = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(I);
Exp.VarArgs.push_back(lookupOrAdd(MA));
}

uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) {
// FIXME: Currently the calls which may access the thread id may
// be considered as not accessing the memory. But this is
Expand Down Expand Up @@ -594,15 +607,48 @@ uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) {
return V;
}

if (MSSA && IsMSSAEnabled && AA->onlyReadsMemory(C)) {
Expression Exp = createExpr(C);
addMemoryStateToExp(C, Exp);
auto [V, _] = assignExpNewValueNum(Exp);
ValueNumbering[C] = V;
return V;
}

ValueNumbering[C] = NextValueNumber;
return NextValueNumber++;
}

/// Returns the value number for the specified load or store instruction.
uint32_t GVNPass::ValueTable::computeLoadStoreVN(Instruction *I) {
if (!MSSA || !IsMSSAEnabled) {
ValueNumbering[I] = NextValueNumber;
return NextValueNumber++;
}

Expression Exp;
Exp.Ty = I->getType();
Exp.Opcode = I->getOpcode();
for (Use &Op : I->operands())
Exp.VarArgs.push_back(lookupOrAdd(Op));
addMemoryStateToExp(I, Exp);

auto [V, _] = assignExpNewValueNum(Exp);
ValueNumbering[I] = V;
return V;
}

/// Returns true if a value number exists for the specified value.
bool GVNPass::ValueTable::exists(Value *V) const {
return ValueNumbering.contains(V);
}

uint32_t GVNPass::ValueTable::lookupOrAdd(MemoryAccess *MA) {
return MSSA->isLiveOnEntryDef(MA) || isa<MemoryPhi>(MA)
? lookupOrAdd(MA->getBlock())
: lookupOrAdd(cast<MemoryUseOrDef>(MA)->getMemoryInst());
}

/// lookupOrAdd - Returns the value number for the specified value, assigning
/// it a new number if it did not have one before.
uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) {
Expand All @@ -613,6 +659,8 @@ uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) {
auto *I = dyn_cast<Instruction>(V);
if (!I) {
ValueNumbering[V] = NextValueNumber;
if (isa<BasicBlock>(V))
NumberingBB[NextValueNumber] = cast<BasicBlock>(V);
return NextValueNumber++;
}

Expand Down Expand Up @@ -672,6 +720,9 @@ uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) {
ValueNumbering[V] = NextValueNumber;
NumberingPhi[NextValueNumber] = cast<PHINode>(V);
return NextValueNumber++;
case Instruction::Load:
case Instruction::Store:
return computeLoadStoreVN(I);
default:
ValueNumbering[V] = NextValueNumber;
return NextValueNumber++;
Expand Down Expand Up @@ -709,6 +760,7 @@ void GVNPass::ValueTable::clear() {
ValueNumbering.clear();
ExpressionNumbering.clear();
NumberingPhi.clear();
NumberingBB.clear();
PhiTranslateTable.clear();
NextValueNumber = 1;
Expressions.clear();
Expand All @@ -723,6 +775,8 @@ void GVNPass::ValueTable::erase(Value *V) {
// If V is PHINode, V <--> value number is an one-to-one mapping.
if (isa<PHINode>(V))
NumberingPhi.erase(Num);
else if (isa<BasicBlock>(V))
NumberingBB.erase(Num);
}

/// verifyRemoved - Verify that the value is removed from all internal data
Expand Down Expand Up @@ -2310,15 +2364,39 @@ bool GVNPass::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum,
uint32_t GVNPass::ValueTable::phiTranslateImpl(const BasicBlock *Pred,
const BasicBlock *PhiBlock,
uint32_t Num, GVNPass &GVN) {
// See if we can refine the value number by looking at the PN incoming value
// for the given predecessor.
if (PHINode *PN = NumberingPhi[Num]) {
for (unsigned I = 0; I != PN->getNumIncomingValues(); ++I) {
if (PN->getParent() == PhiBlock && PN->getIncomingBlock(I) == Pred)
if (uint32_t TransVal = lookup(PN->getIncomingValue(I), false))
return TransVal;
}
if (PN->getParent() == PhiBlock)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just hoisted the check out of the loop, can also drop as not strictly related to the patch.

for (unsigned I = 0; I != PN->getNumIncomingValues(); ++I)
if (PN->getIncomingBlock(I) == Pred)
if (uint32_t TransVal = lookup(PN->getIncomingValue(I), false))
return TransVal;
return Num;
}

if (BasicBlock *BB = NumberingBB[Num]) {
assert(MSSA && "NumberingBB is non-empty only when using MemorySSA");
// Value numbers of basic blocks are used to represent memory state in
// load/store instructions and read-only function calls when said state is
// set by a MemoryPhi.
if (BB != PhiBlock)
return Num;
MemoryPhi *MPhi = MSSA->getMemoryAccess(BB);
for (unsigned i = 0, N = MPhi->getNumIncomingValues(); i != N; ++i) {
if (MPhi->getIncomingBlock(i) != Pred)
continue;
MemoryAccess *MA = MPhi->getIncomingValue(i);
if (auto *PredPhi = dyn_cast<MemoryPhi>(MA))
return lookupOrAdd(PredPhi->getBlock());
if (MSSA->isLiveOnEntryDef(MA))
return lookupOrAdd(&BB->getParent()->getEntryBlock());
return lookupOrAdd(cast<MemoryUseOrDef>(MA)->getMemoryInst());
}
llvm_unreachable(
"CFG/MemorySSA mismatch: predecessor not found among incoming blocks");
}

// If there is any value related with Num is defined in a BB other than
// PhiBlock, it cannot depend on a phi in PhiBlock without going through
// a backedge. We can do an early exit in that case to save compile time.
Expand Down Expand Up @@ -2761,6 +2839,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
ICF = &ImplicitCFT;
this->LI = &LI;
VN.setMemDep(MD);
VN.setMemorySSA(MSSA);
ORE = RunORE;
InvalidBlockRPONumbers = true;
MemorySSAUpdater Updater(MSSA);
Expand Down
Loading