Skip to content

Commit 2c5d339

Browse files
committed
[LV][EVL] Introduce the EVLIVSimplify Pass for EVL-vectorized loops
TBA...
1 parent 54d01d8 commit 2c5d339

File tree

8 files changed

+617
-0
lines changed

8 files changed

+617
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===------ EVLIndVarSimplify.h - Optimize vectorized loops w/ EVL IV------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This pass rewrites a vectorized loop that uses a canonical IV so that it
10+
// uses an EVL-based IV instead, if the loop was tail-folded by predicated EVL.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_TRANSFORMS_VECTORIZE_EVLINDVARSIMPLIFY_H
15+
#define LLVM_TRANSFORMS_VECTORIZE_EVLINDVARSIMPLIFY_H
16+
17+
#include "llvm/Analysis/LoopAnalysisManager.h"
18+
#include "llvm/IR/PassManager.h"
19+
20+
namespace llvm {
21+
class Loop;
22+
class LPMUpdater;
23+
24+
/// Turn vectorized loops with canonical induction variables into loops that
25+
/// only use a single EVL-based induction variable.
26+
struct EVLIndVarSimplifyPass : public PassInfoMixin<EVLIndVarSimplifyPass> {
27+
PreservedAnalyses run(Loop &L, LoopAnalysisManager &LAM,
28+
LoopStandardAnalysisResults &AR, LPMUpdater &U);
29+
};
30+
} // namespace llvm
31+
#endif

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@
370370
#include "llvm/Transforms/Utils/SymbolRewriter.h"
371371
#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
372372
#include "llvm/Transforms/Utils/UnifyLoopExits.h"
373+
#include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h"
373374
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
374375
#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
375376
#include "llvm/Transforms/Vectorize/LoopVectorize.h"

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@
142142
#include "llvm/Transforms/Utils/NameAnonGlobals.h"
143143
#include "llvm/Transforms/Utils/RelLookupTableConverter.h"
144144
#include "llvm/Transforms/Utils/SimplifyCFGOptions.h"
145+
#include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h"
145146
#include "llvm/Transforms/Vectorize/LoopVectorize.h"
146147
#include "llvm/Transforms/Vectorize/SLPVectorizer.h"
147148
#include "llvm/Transforms/Vectorize/VectorCombine.h"

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,7 @@ LOOP_ANALYSIS("should-run-extra-simple-loop-unswitch",
672672
#endif
673673
LOOP_PASS("canon-freeze", CanonicalizeFreezeInLoopsPass())
674674
LOOP_PASS("dot-ddg", DDGDotPrinterPass())
675+
LOOP_PASS("evl-iv-simplify", EVLIndVarSimplifyPass())
675676
LOOP_PASS("guard-widening", GuardWideningPass())
676677
LOOP_PASS("extra-simple-loop-unswitch-passes",
677678
ExtraLoopPassManager<ShouldRunExtraSimpleLoopUnswitch>())

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "llvm/Target/TargetOptions.h"
3838
#include "llvm/Transforms/IPO.h"
3939
#include "llvm/Transforms/Scalar.h"
40+
#include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h"
4041
#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
4142
#include <optional>
4243
using namespace llvm;
@@ -644,6 +645,12 @@ void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
644645
OptimizationLevel Level) {
645646
LPM.addPass(LoopIdiomVectorizePass(LoopIdiomVectorizeStyle::Predicated));
646647
});
648+
649+
PB.registerVectorizerEndEPCallback(
650+
[](FunctionPassManager &FPM, OptimizationLevel Level) {
651+
if (Level.isOptimizingForSpeed())
652+
FPM.addPass(createFunctionToLoopPassAdaptor(EVLIndVarSimplifyPass()));
653+
});
647654
}
648655

649656
yaml::MachineFunctionInfo *

llvm/lib/Transforms/Vectorize/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_llvm_component_library(LLVMVectorize
2+
EVLIndVarSimplify.cpp
23
LoadStoreVectorizer.cpp
34
LoopIdiomVectorize.cpp
45
LoopVectorizationLegality.cpp
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
//===---- EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This pass rewrites a vectorized loop that uses a canonical IV so that it
10+
// uses an EVL-based IV instead, if the loop was tail-folded by predicated EVL.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h"
15+
#include "llvm/ADT/Statistic.h"
16+
#include "llvm/Analysis/IVDescriptors.h"
17+
#include "llvm/Analysis/LoopInfo.h"
18+
#include "llvm/Analysis/LoopPass.h"
19+
#include "llvm/Analysis/ScalarEvolution.h"
20+
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
21+
#include "llvm/Analysis/ValueTracking.h"
22+
#include "llvm/IR/IRBuilder.h"
23+
#include "llvm/IR/PatternMatch.h"
24+
#include "llvm/Support/CommandLine.h"
25+
#include "llvm/Support/Debug.h"
26+
#include "llvm/Support/MathExtras.h"
27+
#include "llvm/Support/raw_ostream.h"
28+
#include "llvm/Transforms/Scalar/LoopPassManager.h"
29+
#include "llvm/Transforms/Utils/Local.h"
30+
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
31+
32+
#define DEBUG_TYPE "evl-iv-simplify"
33+
34+
using namespace llvm;
35+
36+
STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated");
37+
38+
static cl::opt<bool> EnableEVLIndVarSimplify(
39+
"enable-evl-indvar-simplify",
40+
cl::desc("Enable EVL-based induction variable simplify Pass"), cl::Hidden,
41+
cl::init(true));
42+
43+
namespace {
44+
struct EVLIndVarSimplifyImpl {
45+
ScalarEvolution &SE;
46+
47+
explicit EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR)
48+
: SE(LAR.SE) {}
49+
50+
explicit EVLIndVarSimplifyImpl(ScalarEvolution &SE) : SE(SE) {}
51+
52+
// Returns true if modify the loop.
53+
bool run(Loop &L);
54+
};
55+
} // anonymous namespace
56+
57+
// Returns the constant part of vectorization factor from the induction
58+
// variable's step value SCEV expression.
59+
static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) {
60+
if (!Step)
61+
return 0U;
62+
63+
// Looking for loops with IV step value in the form of `(<constant VF> x
64+
// vscale)`.
65+
if (auto *Mul = dyn_cast<SCEVMulExpr>(Step)) {
66+
if (Mul->getNumOperands() == 2) {
67+
const SCEV *LHS = Mul->getOperand(0);
68+
const SCEV *RHS = Mul->getOperand(1);
69+
if (auto *Const = dyn_cast<SCEVConstant>(LHS)) {
70+
uint64_t V = Const->getAPInt().getLimitedValue();
71+
if (isa<SCEVVScale>(RHS) && llvm::isUInt<32>(V))
72+
return V;
73+
}
74+
}
75+
}
76+
77+
// If not, see if the vscale_range of the parent function is a fixed value,
78+
// which makes the step value to be replaced by a constant.
79+
if (F.hasFnAttribute(Attribute::VScaleRange))
80+
if (auto *ConstStep = dyn_cast<SCEVConstant>(Step)) {
81+
APInt V = ConstStep->getAPInt().abs();
82+
ConstantRange CR = llvm::getVScaleRange(&F, 64);
83+
if (const APInt *Fixed = CR.getSingleElement()) {
84+
V = V.zextOrTrunc(Fixed->getBitWidth());
85+
uint64_t VF = V.udiv(*Fixed).getLimitedValue();
86+
if (VF && llvm::isUInt<32>(VF) &&
87+
// Make sure step is divisible by vscale.
88+
V.urem(*Fixed).isZero())
89+
return VF;
90+
}
91+
}
92+
93+
return 0U;
94+
}
95+
96+
/// Replaces the latch exit test of an EVL tail-folded vectorized loop with a
/// comparison between the EVL-based induction variable and the trip count,
/// then deletes the canonical induction variable if that left it dead.
///
/// The loop is left untouched (and false is returned) when the pass is
/// disabled, when the loop does not carry both the "llvm.loop.isvectorized"
/// and "llvm.loop.isvectorized.withevl" attributes, or when any of the pieces
/// the rewrite needs (latch compare, canonical IV, bounds, VF, EVL-based IV)
/// cannot be found.
///
/// \returns true if the loop was modified.
bool EVLIndVarSimplifyImpl::run(Loop &L) {
  if (!EnableEVLIndVarSimplify)
    return false;

  if (!getBooleanLoopAttribute(&L, "llvm.loop.isvectorized") ||
      !getBooleanLoopAttribute(&L, "llvm.loop.isvectorized.withevl"))
    return false;

  BasicBlock *LatchBlock = L.getLoopLatch();
  ICmpInst *OrigLatchCmp = L.getLatchCmpInst();
  if (!LatchBlock || !OrigLatchCmp)
    return false;

  InductionDescriptor IVD;
  PHINode *IndVar = L.getInductionVariable(SE);
  if (!IndVar || !L.getInductionDescriptor(SE, IVD)) {
    LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName()
                      << "\n");
    return false;
  }

  BasicBlock *InitBlock = nullptr, *BackEdgeBlock = nullptr;
  if (!L.getIncomingAndBackEdge(InitBlock, BackEdgeBlock)) {
    LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in "
                      << L.getName() << "\n");
    return false;
  }

  // Retrieve the loop bounds.
  std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE);
  if (!Bounds) {
    LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName()
                      << "\n");
    return false;
  }
  Value *CanonicalIVInit = &Bounds->getInitialIVValue();
  Value *CanonicalIVFinal = &Bounds->getFinalIVValue();

  const SCEV *StepV = IVD.getStep();
  uint32_t VF = getVFFromIndVar(StepV, *L.getHeader()->getParent());
  if (!VF) {
    LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV
                      << "'\n");
    return false;
  }
  LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName()
                    << "\n");

  // Try to find the EVL-based induction variable.
  using namespace PatternMatch;
  BasicBlock *BB = IndVar->getParent();

  Value *EVLIndVar = nullptr;
  Value *RemTC = nullptr;
  Value *TC = nullptr;
  auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
      m_Value(RemTC), m_SpecificInt(VF),
      /*Scalable=*/m_SpecificInt(1));
  for (PHINode &PN : BB->phis()) {
    if (&PN == IndVar)
      continue;

    // Check 1: it has to contain both incoming (init) & backedge blocks
    // from IndVar.
    if (PN.getBasicBlockIndex(InitBlock) < 0 ||
        PN.getBasicBlockIndex(BackEdgeBlock) < 0)
      continue;
    // Check 2: EVL index is always increasing, thus its initial value has to
    // be equal to either the initial IV value (when the canonical IV is also
    // increasing) or the last IV value (when canonical IV is decreasing).
    Value *Init = PN.getIncomingValueForBlock(InitBlock);
    using Direction = Loop::LoopBounds::Direction;
    switch (Bounds->getDirection()) {
    case Direction::Increasing:
      if (Init != CanonicalIVInit)
        continue;
      break;
    case Direction::Decreasing:
      if (Init != CanonicalIVFinal)
        continue;
      break;
    case Direction::Unknown:
      // To be more permissive and see if either the initial or final IV value
      // matches PN's init value.
      if (Init != CanonicalIVInit && Init != CanonicalIVFinal)
        continue;
      break;
    }
    Value *RecValue = PN.getIncomingValueForBlock(BackEdgeBlock);
    assert(RecValue);

    LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN
                      << "\n");

    // Check 3: Pattern match to find the EVL-based index and total trip count
    // (TC). The backedge value must be `PN + get_vector_length(TC - PN, VF,
    // scalable)`, optionally with a zext around the intrinsic call.
    if (match(RecValue,
              m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) &&
        match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) {
      EVLIndVar = RecValue;
      break;
    }
  }

  if (!EVLIndVar || !TC)
    return false;

  LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");

  // Create an EVL-based comparison and replace the branch to use it as
  // predicate.

  // Loop::getLatchCmpInst check at the beginning of this function has ensured
  // that latch block ends in a conditional branch.
  auto *LatchBranch = cast<BranchInst>(LatchBlock->getTerminator());
  assert(LatchBranch->isConditional() &&
         "expecting a conditional latch branch");
  // Keep looping while the EVL-based IV has not reached TC: which predicate
  // expresses that depends on which successor is the loop header.
  ICmpInst::Predicate Pred = LatchBranch->getSuccessor(0) == L.getHeader()
                                 ? ICmpInst::ICMP_NE
                                 : ICmpInst::ICMP_EQ;

  // EVLIndVar is the value flowing along the backedge, so it is only
  // guaranteed to be available at the end of the latch block; it may well be
  // defined *after* OrigLatchCmp. Emit the new compare right before the latch
  // terminator rather than before OrigLatchCmp so that it can never use a
  // value ahead of its definition.
  IRBuilder<> Builder(LatchBranch);
  Value *NewLatchCmp = Builder.CreateICmp(Pred, EVLIndVar, TC);
  // Only the branch is rewired. Any other user of OrigLatchCmp may sit above
  // the new compare and therefore must keep using the original one.
  LatchBranch->setCondition(NewLatchCmp);

  // llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are
  // not used outside the cycles. However, OrigLatchCmp, which no longer feeds
  // the branch, would still be considered a use outside the cycle while in
  // reality it's practically dead. Thus we need to remove it before calling
  // RecursivelyDeleteDeadPHINode. This is a no-op if OrigLatchCmp still has
  // other users.
  (void)RecursivelyDeleteTriviallyDeadInstructions(OrigLatchCmp);
  if (llvm::RecursivelyDeleteDeadPHINode(IndVar))
    LLVM_DEBUG(dbgs() << "Removed original IndVar\n");

  ++NumEliminatedCanonicalIV;

  return true;
}
235+
236+
/// New pass manager entry point: runs the simplification on \p L with the
/// ScalarEvolution taken from \p AR.
///
/// \returns all analyses as preserved when the loop was left untouched,
/// otherwise only the analyses in the CFGAnalyses set.
PreservedAnalyses EVLIndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &LAM,
                                             LoopStandardAnalysisResults &AR,
                                             LPMUpdater &U) {
  EVLIndVarSimplifyImpl Impl(AR);
  if (!Impl.run(L))
    return PreservedAnalyses::all();
  return PreservedAnalyses::allInSet<CFGAnalyses>();
}

0 commit comments

Comments
 (0)