Skip to content

Commit 14f4a7e

Browse files
committed
Address review comments
1 parent 6bc5c94 commit 14f4a7e

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

llvm/lib/CodeGen/EVLIndVarSimplify.cpp

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/IR/PatternMatch.h"
2323
#include "llvm/InitializePasses.h"
2424
#include "llvm/Pass.h"
25+
#include "llvm/Support/CommandLine.h"
2526
#include "llvm/Support/Debug.h"
2627
#include "llvm/Support/MathExtras.h"
2728
#include "llvm/Support/raw_ostream.h"
@@ -33,6 +34,11 @@ using namespace llvm;
3334

3435
STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated");
3536

37+
static cl::opt<bool> EnableEVLIndVarSimplify(
38+
"enable-evl-indvar-simplify",
39+
cl::desc("Enable EVL-based induction variable simplify Pass"), cl::Hidden,
40+
cl::init(true));
41+
3642
namespace {
3743
struct EVLIndVarSimplifyImpl {
3844
ScalarEvolution &SE;
@@ -62,10 +68,9 @@ struct EVLIndVarSimplify : public LoopPass {
6268
};
6369
} // anonymous namespace
6470

65-
static std::optional<uint32_t> getVFFromIndVar(const SCEV *Step,
66-
const Function &F) {
71+
static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) {
6772
if (!Step)
68-
return std::nullopt;
73+
return 0U;
6974

7075
// Looking for loops with IV step value in the form of `(<constant VF> x
7176
// vscale)`.
@@ -95,14 +100,18 @@ static std::optional<uint32_t> getVFFromIndVar(const SCEV *Step,
95100
}
96101
}
97102

98-
return std::nullopt;
103+
return 0U;
99104
}
100105

101106
// Remove the original induction variable if it's not used anywhere.
102-
static void cleanupOriginalIndVar(PHINode *OrigIndVar, BasicBlock *InitBlock,
103-
BasicBlock *BackEdgeBlock) {
104-
Value *InitValue = OrigIndVar->getIncomingValueForBlock(InitBlock);
105-
Value *RecValue = OrigIndVar->getIncomingValueForBlock(BackEdgeBlock);
107+
static void tryCleanupOriginalIndVar(PHINode *OrigIndVar,
108+
const InductionDescriptor &IVD) {
109+
if (OrigIndVar->getNumIncomingValues() != 2)
110+
return;
111+
Value *InitValue = OrigIndVar->getIncomingValue(0);
112+
Value *RecValue = OrigIndVar->getIncomingValue(1);
113+
if (InitValue != IVD.getStartValue())
114+
std::swap(InitValue, RecValue);
106115

107116
// If the only user of OrigIndVar is the one produces RecValue, then we can
108117
// safely remove it.
@@ -117,6 +126,9 @@ static void cleanupOriginalIndVar(PHINode *OrigIndVar, BasicBlock *InitBlock,
117126
}
118127

119128
bool EVLIndVarSimplifyImpl::run(Loop &L) {
129+
if (!EnableEVLIndVarSimplify)
130+
return false;
131+
120132
InductionDescriptor IVD;
121133
PHINode *IndVar = L.getInductionVariable(SE);
122134
if (!IndVar || !L.getInductionDescriptor(SE, IVD)) {
@@ -143,23 +155,23 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
143155
Value *CanonicalIVFinal = &Bounds->getFinalIVValue();
144156

145157
const SCEV *StepV = IVD.getStep();
146-
auto VF = getVFFromIndVar(StepV, *L.getHeader()->getParent());
158+
uint32_t VF = getVFFromIndVar(StepV, *L.getHeader()->getParent());
147159
if (!VF) {
148160
LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV
149161
<< "'\n");
150162
return false;
151163
}
152-
LLVM_DEBUG(dbgs() << "Using VF=" << *VF << " for loop " << L.getName()
164+
LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName()
153165
<< "\n");
154166

155167
// Try to find the EVL-based induction variable.
156168
using namespace PatternMatch;
157169
BasicBlock *BB = IndVar->getParent();
158170

159-
Value *EVLIndex = nullptr;
160-
Value *RemVL = nullptr, *AVL = nullptr;
171+
Value *EVLIndVar = nullptr;
172+
Value *RemTC = nullptr, *TC = nullptr;
161173
auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
162-
m_Value(RemVL), m_SpecificInt(*VF),
174+
m_Value(RemTC), m_SpecificInt(VF),
163175
/*Scalable=*/m_SpecificInt(1));
164176
for (auto &PN : BB->phis()) {
165177
if (&PN == IndVar)
@@ -198,19 +210,19 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
198210
<< "\n");
199211

200212
// Check 3: Pattern match to find the EVL-based index and total trip count
201-
// (AVL).
213+
// (TC).
202214
if (match(RecValue,
203215
m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) &&
204-
match(RemVL, m_Sub(m_Value(AVL), m_Specific(&PN)))) {
205-
EVLIndex = RecValue;
216+
match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) {
217+
EVLIndVar = RecValue;
206218
break;
207219
}
208220
}
209221

210-
if (!EVLIndex || !AVL)
222+
if (!EVLIndVar || !TC)
211223
return false;
212224

213-
LLVM_DEBUG(dbgs() << "Using " << *EVLIndex << " for EVL-based IndVar\n");
225+
LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");
214226

215227
// Create an EVL-based comparison and replace the branch to use it as
216228
// predicate.
@@ -220,10 +232,10 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
220232
return false;
221233

222234
IRBuilder<> Builder(OrigLatchCmp);
223-
auto *NewPred = Builder.CreateICmp(Pred, EVLIndex, AVL);
235+
auto *NewPred = Builder.CreateICmp(Pred, EVLIndVar, TC);
224236
OrigLatchCmp->replaceAllUsesWith(NewPred);
225237

226-
cleanupOriginalIndVar(IndVar, InitBlock, BackEdgeBlock);
238+
tryCleanupOriginalIndVar(IndVar, IVD);
227239

228240
++NumEliminatedCanonicalIV;
229241

0 commit comments

Comments
 (0)