22
22
#include " llvm/IR/PatternMatch.h"
23
23
#include " llvm/InitializePasses.h"
24
24
#include " llvm/Pass.h"
25
+ #include " llvm/Support/CommandLine.h"
25
26
#include " llvm/Support/Debug.h"
26
27
#include " llvm/Support/MathExtras.h"
27
28
#include " llvm/Support/raw_ostream.h"
@@ -33,6 +34,11 @@ using namespace llvm;
33
34
34
35
STATISTIC (NumEliminatedCanonicalIV, " Number of canonical IVs we eliminated" );
35
36
37
+ static cl::opt<bool > EnableEVLIndVarSimplify (
38
+ " enable-evl-indvar-simplify" ,
39
+ cl::desc (" Enable EVL-based induction variable simplify Pass" ), cl::Hidden,
40
+ cl::init(true ));
41
+
36
42
namespace {
37
43
struct EVLIndVarSimplifyImpl {
38
44
ScalarEvolution &SE;
@@ -62,10 +68,9 @@ struct EVLIndVarSimplify : public LoopPass {
62
68
};
63
69
} // anonymous namespace
64
70
65
- static std::optional<uint32_t > getVFFromIndVar (const SCEV *Step,
66
- const Function &F) {
71
+ static uint32_t getVFFromIndVar (const SCEV *Step, const Function &F) {
67
72
if (!Step)
68
- return std::nullopt ;
73
+ return 0U ;
69
74
70
75
// Looking for loops with IV step value in the form of `(<constant VF> x
71
76
// vscale)`.
@@ -95,14 +100,18 @@ static std::optional<uint32_t> getVFFromIndVar(const SCEV *Step,
95
100
}
96
101
}
97
102
98
- return std::nullopt ;
103
+ return 0U ;
99
104
}
100
105
101
106
// Remove the original induction variable if it's not used anywhere.
102
- static void cleanupOriginalIndVar (PHINode *OrigIndVar, BasicBlock *InitBlock,
103
- BasicBlock *BackEdgeBlock) {
104
- Value *InitValue = OrigIndVar->getIncomingValueForBlock (InitBlock);
105
- Value *RecValue = OrigIndVar->getIncomingValueForBlock (BackEdgeBlock);
107
+ static void tryCleanupOriginalIndVar (PHINode *OrigIndVar,
108
+ const InductionDescriptor &IVD) {
109
+ if (OrigIndVar->getNumIncomingValues () != 2 )
110
+ return ;
111
+ Value *InitValue = OrigIndVar->getIncomingValue (0 );
112
+ Value *RecValue = OrigIndVar->getIncomingValue (1 );
113
+ if (InitValue != IVD.getStartValue ())
114
+ std::swap (InitValue, RecValue);
106
115
107
116
// If the only user of OrigIndVar is the one produces RecValue, then we can
108
117
// safely remove it.
@@ -117,6 +126,9 @@ static void cleanupOriginalIndVar(PHINode *OrigIndVar, BasicBlock *InitBlock,
117
126
}
118
127
119
128
bool EVLIndVarSimplifyImpl::run (Loop &L) {
129
+ if (!EnableEVLIndVarSimplify)
130
+ return false ;
131
+
120
132
InductionDescriptor IVD;
121
133
PHINode *IndVar = L.getInductionVariable (SE);
122
134
if (!IndVar || !L.getInductionDescriptor (SE, IVD)) {
@@ -143,23 +155,23 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
143
155
Value *CanonicalIVFinal = &Bounds->getFinalIVValue ();
144
156
145
157
const SCEV *StepV = IVD.getStep ();
146
- auto VF = getVFFromIndVar (StepV, *L.getHeader ()->getParent ());
158
+ uint32_t VF = getVFFromIndVar (StepV, *L.getHeader ()->getParent ());
147
159
if (!VF) {
148
160
LLVM_DEBUG (dbgs () << " Could not infer VF from IndVar step '" << *StepV
149
161
<< " '\n " );
150
162
return false ;
151
163
}
152
- LLVM_DEBUG (dbgs () << " Using VF=" << * VF << " for loop " << L.getName ()
164
+ LLVM_DEBUG (dbgs () << " Using VF=" << VF << " for loop " << L.getName ()
153
165
<< " \n " );
154
166
155
167
// Try to find the EVL-based induction variable.
156
168
using namespace PatternMatch ;
157
169
BasicBlock *BB = IndVar->getParent ();
158
170
159
- Value *EVLIndex = nullptr ;
160
- Value *RemVL = nullptr , *AVL = nullptr ;
171
+ Value *EVLIndVar = nullptr ;
172
+ Value *RemTC = nullptr , *TC = nullptr ;
161
173
auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
162
- m_Value (RemVL ), m_SpecificInt (* VF),
174
+ m_Value (RemTC ), m_SpecificInt (VF),
163
175
/* Scalable=*/ m_SpecificInt (1 ));
164
176
for (auto &PN : BB->phis ()) {
165
177
if (&PN == IndVar)
@@ -198,19 +210,19 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
198
210
<< " \n " );
199
211
200
212
// Check 3: Pattern match to find the EVL-based index and total trip count
201
- // (AVL ).
213
+ // (TC ).
202
214
if (match (RecValue,
203
215
m_c_Add (m_ZExtOrSelf (IntrinsicMatch), m_Specific (&PN))) &&
204
- match (RemVL , m_Sub (m_Value (AVL ), m_Specific (&PN)))) {
205
- EVLIndex = RecValue;
216
+ match (RemTC , m_Sub (m_Value (TC ), m_Specific (&PN)))) {
217
+ EVLIndVar = RecValue;
206
218
break ;
207
219
}
208
220
}
209
221
210
- if (!EVLIndex || !AVL )
222
+ if (!EVLIndVar || !TC )
211
223
return false ;
212
224
213
- LLVM_DEBUG (dbgs () << " Using " << *EVLIndex << " for EVL-based IndVar\n " );
225
+ LLVM_DEBUG (dbgs () << " Using " << *EVLIndVar << " for EVL-based IndVar\n " );
214
226
215
227
// Create an EVL-based comparison and replace the branch to use it as
216
228
// predicate.
@@ -220,10 +232,10 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
220
232
return false ;
221
233
222
234
IRBuilder<> Builder (OrigLatchCmp);
223
- auto *NewPred = Builder.CreateICmp (Pred, EVLIndex, AVL );
235
+ auto *NewPred = Builder.CreateICmp (Pred, EVLIndVar, TC );
224
236
OrigLatchCmp->replaceAllUsesWith (NewPred);
225
237
226
- cleanupOriginalIndVar (IndVar, InitBlock, BackEdgeBlock );
238
+ tryCleanupOriginalIndVar (IndVar, IVD );
227
239
228
240
++NumEliminatedCanonicalIV;
229
241
0 commit comments