Skip to content

Commit adc5f60

Browse files
iamloukJulienVillette
authored andcommitted
[CodeGen] Expand-Support for Scalable Reductions (!100)
1 parent c60961c commit adc5f60

File tree

8 files changed

+444
-20
lines changed

8 files changed

+444
-20
lines changed

llvm/include/llvm/IR/Attributes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ class AttributeSet {
433433
const;
434434
unsigned getVScaleRangeMin() const;
435435
std::optional<unsigned> getVScaleRangeMax() const;
436+
std::optional<unsigned> getFixedVScale() const;
436437
UWTableKind getUWTableKind() const;
437438
AllocFnKind getAllocKind() const;
438439
MemoryEffects getMemoryEffects() const;

llvm/lib/CodeGen/ExpandReductions.cpp

Lines changed: 239 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,205 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "llvm/CodeGen/ExpandReductions.h"
15+
#include "llvm/Analysis/DomTreeUpdater.h"
1516
#include "llvm/Analysis/TargetTransformInfo.h"
1617
#include "llvm/CodeGen/Passes.h"
18+
#include "llvm/IR/BasicBlock.h"
19+
#include "llvm/IR/Constants.h"
20+
#include "llvm/IR/DerivedTypes.h"
21+
#include "llvm/IR/Dominators.h"
1722
#include "llvm/IR/IRBuilder.h"
1823
#include "llvm/IR/InstIterator.h"
24+
#include "llvm/IR/Instruction.h"
1925
#include "llvm/IR/IntrinsicInst.h"
2026
#include "llvm/IR/Intrinsics.h"
2127
#include "llvm/InitializePasses.h"
2228
#include "llvm/Pass.h"
29+
#include "llvm/Support/ErrorHandling.h"
30+
#include "llvm/Support/MathExtras.h"
31+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
2332
#include "llvm/Transforms/Utils/LoopUtils.h"
2433

2534
using namespace llvm;
2635

2736
namespace {
2837

29-
bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
30-
bool Changed = false;
38+
void updateDomTreeForScalableExpansion(DominatorTree *DT, BasicBlock *Preheader,
39+
BasicBlock *Loop, BasicBlock *Exit) {
40+
DT->addNewBlock(Loop, Preheader);
41+
DT->changeImmediateDominator(Exit, Loop);
42+
assert(DT->verify(DominatorTree::VerificationLevel::Fast));
43+
}
44+
45+
/// Expand a reduction on a scalable vector into a loop
46+
/// that iterates over one element after the other.
47+
Value *expandScalableReduction(IRBuilderBase &Builder, IntrinsicInst *II,
48+
Value *Acc, Value *Vec,
49+
Instruction::BinaryOps BinOp,
50+
DominatorTree *DT) {
51+
ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType());
52+
53+
// Split the original BB in two and create a new BB between them,
54+
// which will be a loop.
55+
BasicBlock *BeforeBB = II->getParent();
56+
BasicBlock *AfterBB = SplitBlock(BeforeBB, II, DT);
57+
BasicBlock *LoopBB = BasicBlock::Create(Builder.getContext(), "rdx.loop",
58+
BeforeBB->getParent(), AfterBB);
59+
BeforeBB->getTerminator()->setSuccessor(0, LoopBB);
60+
61+
// Calculate the number of elements in the vector:
62+
Builder.SetInsertPoint(BeforeBB->getTerminator());
63+
Value *NumElts =
64+
Builder.CreateVScale(Builder.getInt64(VecTy->getMinNumElements()));
65+
66+
// Create two PHIs, one for the index of the current lane and one for
67+
// the actual reduction.
68+
Builder.SetInsertPoint(LoopBB);
69+
PHINode *IV = Builder.CreatePHI(Builder.getInt64Ty(), 2, "index");
70+
IV->addIncoming(Builder.getInt64(0), BeforeBB);
71+
PHINode *RdxPhi = Builder.CreatePHI(VecTy->getScalarType(), 2, "rdx.phi");
72+
RdxPhi->addIncoming(Acc, BeforeBB);
73+
74+
Value *IVInc =
75+
Builder.CreateAdd(IV, Builder.getInt64(1), "index.next", true, true);
76+
IV->addIncoming(IVInc, LoopBB);
77+
78+
// Extract the value at the current lane from the vector and perform
79+
// the scalar reduction binop:
80+
Value *Lane = Builder.CreateExtractElement(Vec, IV, "elm");
81+
Value *Rdx = Builder.CreateBinOp(BinOp, RdxPhi, Lane, "rdx");
82+
RdxPhi->addIncoming(Rdx, LoopBB);
83+
84+
// Exit when all lanes have been treated (assuming there will be at least
85+
// one element in the vector):
86+
Value *Done = Builder.CreateCmp(CmpInst::ICMP_EQ, IVInc, NumElts, "exitcond");
87+
Builder.CreateCondBr(Done, AfterBB, LoopBB);
88+
89+
if (DT)
90+
updateDomTreeForScalableExpansion(DT, BeforeBB, LoopBB, AfterBB);
91+
92+
return Rdx;
93+
}
94+
95+
/// Expand a reduction on a scalable vector in a parallel-tree like
96+
/// manner, meaning halving the number of elements to treat in every
97+
/// iteration.
98+
Value *expandScalableTreeReduction(
99+
IRBuilderBase &Builder, IntrinsicInst *II, std::optional<Value *> Acc,
100+
Value *Vec, Instruction::BinaryOps BinOp,
101+
function_ref<bool(Constant *)> IsNeutralElement, DominatorTree *DT,
102+
std::optional<unsigned> FixedVScale) {
103+
ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType());
104+
ScalableVectorType *VecTyX2 = ScalableVectorType::get(
105+
VecTy->getScalarType(), VecTy->getMinNumElements() * 2);
106+
107+
// If the VScale is fixed, do not generate a loop, and instead do
108+
// something similar to llvm::getShuffleReduction(). That function
109+
// cannot be used directly because it uses shuffle masks, which
110+
// are not available for scalable vectors (even if vscale is fixed).
111+
// The approach is effectively the same.
112+
if (FixedVScale.has_value()) {
113+
unsigned VF = VecTy->getMinNumElements() * FixedVScale.value();
114+
assert(isPowerOf2_64(VF));
115+
for (unsigned I = VF; I != 1; I >>= 1) {
116+
Value *Extended = Builder.CreateInsertVector(
117+
VecTyX2, PoisonValue::get(VecTyX2), Vec, Builder.getInt64(0));
118+
Value *Pair = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
119+
{VecTyX2}, {Extended});
120+
121+
Value *Vec1 = Builder.CreateExtractValue(Pair, {0});
122+
Value *Vec2 = Builder.CreateExtractValue(Pair, {1});
123+
Vec = Builder.CreateBinOp(BinOp, Vec1, Vec2, "rdx");
124+
}
125+
Value *FinalVal = Builder.CreateExtractElement(Vec, uint64_t(0));
126+
if (Acc)
127+
if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement(C))
128+
FinalVal = Builder.CreateBinOp(BinOp, *Acc, FinalVal, "rdx.final");
129+
return FinalVal;
130+
}
131+
132+
// Split the original BB in two and create a new BB between them,
133+
// which will be a loop.
134+
BasicBlock *BeforeBB = II->getParent();
135+
BasicBlock *AfterBB = SplitBlock(BeforeBB, II, DT);
136+
BasicBlock *LoopBB = BasicBlock::Create(Builder.getContext(), "rdx.loop",
137+
BeforeBB->getParent(), AfterBB);
138+
BeforeBB->getTerminator()->setSuccessor(0, LoopBB);
139+
140+
// This tree reduction only needs to do log2(N) iterations.
141+
// Note: Calculating log2(N) using count-trailing-zeros (cttz) only works if
142+
// `vscale` is a power-of-two. This is the case for every architecture known
143+
// right now, but a check could be added with a fallback to some other algorithm.
144+
assert(isPowerOf2_64(VecTy->getMinNumElements()));
145+
Builder.SetInsertPoint(BeforeBB->getTerminator());
146+
Value *NumElts =
147+
Builder.CreateVScale(Builder.getInt64(VecTy->getMinNumElements()));
148+
Value *NumIters = Builder.CreateIntrinsic(NumElts->getType(), Intrinsic::cttz,
149+
{NumElts, Builder.getTrue()});
150+
151+
// Create two PHIs, one for the IV and one for the actual reduction.
152+
Builder.SetInsertPoint(LoopBB);
153+
PHINode *IV = Builder.CreatePHI(Builder.getInt64Ty(), 2, "iter");
154+
IV->addIncoming(Builder.getInt64(0), BeforeBB);
155+
PHINode *VecPhi = Builder.CreatePHI(VecTy, 2, "rdx.phi");
156+
VecPhi->addIncoming(Vec, BeforeBB);
157+
158+
// Note that instead of calculating log2(N) beforehand and having the IV
159+
// increment by one every iteration, we could also have a IV more similar to:
160+
// for (size_t active_lanes = N; active_lanes > 1; active_lanes /= 2) ...
161+
// The IV is only used for the loop's exit condition, so how it is
162+
// calculated does not matter to the tree reduction.
163+
Value *IVInc =
164+
Builder.CreateAdd(IV, Builder.getInt64(1), "iter.next", true, true);
165+
IV->addIncoming(IVInc, LoopBB);
166+
167+
// The deinterleave intrinsic takes a vector of, for example, type
168+
// <vscale x 8 x float> and produces a pair of vectors with half the size,
169+
// so 2 x <vscale x 4 x float>. An insert vector operation is used to
170+
// create a double-sized vector where the upper half is poison, because
171+
// we never care about that upper half anyways!
172+
Value *Extended = Builder.CreateInsertVector(
173+
VecTyX2, PoisonValue::get(VecTyX2), VecPhi, Builder.getInt64(0));
174+
Value *Pair = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
175+
{VecTyX2}, {Extended});
176+
177+
// Take the two vectors and combine them with the binary operation. Note that in the first
178+
// iteration, the results of 1/2 of the lanes are used, in the second one
179+
// 1/4, in the third one 1/8, etc. It could be nice to create a mask
180+
// for this? However, on SVE at least, the instr. latency does not depend
181+
// on the number of active lanes (as far as I know), so this might just
182+
// be useless.
183+
Value *Vec1 = Builder.CreateExtractValue(Pair, {0});
184+
Value *Vec2 = Builder.CreateExtractValue(Pair, {1});
185+
Value *Rdx = Builder.CreateBinOp(BinOp, Vec1, Vec2, "rdx");
186+
VecPhi->addIncoming(Rdx, LoopBB);
187+
188+
// Reduction-loop exit condition:
189+
Value *Done =
190+
Builder.CreateCmp(CmpInst::ICMP_EQ, IVInc, NumIters, "exitcond");
191+
Builder.CreateCondBr(Done, AfterBB, LoopBB);
192+
Builder.SetInsertPoint(AfterBB, AfterBB->getFirstInsertionPt());
193+
Value *FinalVal = Builder.CreateExtractElement(Rdx, uint64_t(0));
194+
195+
// If the Acc value is not the neutral element of the reduction operation,
196+
// then we need to do the binop one last time with the end result of the
197+
// tree reduction. Sidenote: LLVM's loop-vectorizer will actually generate
198+
// code where Acc is zero for addition and one for multiplication most of
199+
// the time.
200+
if (Acc)
201+
if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement(C))
202+
FinalVal = Builder.CreateBinOp(BinOp, *Acc, FinalVal, "rdx.final");
203+
204+
if (DT)
205+
updateDomTreeForScalableExpansion(DT, BeforeBB, LoopBB, AfterBB);
206+
207+
return FinalVal;
208+
}
209+
210+
std::pair<bool, bool> expandReductions(Function &F,
211+
const TargetTransformInfo *TTI,
212+
DominatorTree *DT) {
213+
bool Changed = false, CFGChanged = false;
31214
SmallVector<IntrinsicInst *, 4> Worklist;
32215
for (auto &I : instructions(F)) {
33216
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
@@ -54,6 +237,9 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
54237
}
55238
}
56239

240+
std::optional<unsigned> FixedVScale =
241+
F.getAttributes().getFnAttrs().getFixedVScale();
242+
57243
for (auto *II : Worklist) {
58244
FastMathFlags FMF =
59245
isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
@@ -74,7 +260,31 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
74260
// and it can't be handled by generating a shuffle sequence.
75261
Value *Acc = II->getArgOperand(0);
76262
Value *Vec = II->getArgOperand(1);
77-
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
263+
auto RdxOpcode =
264+
Instruction::BinaryOps(getArithmeticReductionInstruction(ID));
265+
266+
bool ScalableTy = Vec->getType()->isScalableTy();
267+
if (ScalableTy && (!FixedVScale || FMF.allowReassoc())) {
268+
CFGChanged |= !FixedVScale;
269+
if (FMF.allowReassoc())
270+
Rdx = expandScalableTreeReduction(
271+
Builder, II, Acc, Vec, RdxOpcode,
272+
[&](Constant *C) {
273+
switch (ID) {
274+
case Intrinsic::vector_reduce_fadd:
275+
return C->isZeroValue();
276+
case Intrinsic::vector_reduce_fmul:
277+
return C->isOneValue();
278+
default:
279+
llvm_unreachable("Binop not handled");
280+
}
281+
},
282+
DT, FixedVScale);
283+
else
284+
Rdx = expandScalableReduction(Builder, II, Acc, Vec, RdxOpcode, DT);
285+
break;
286+
}
287+
78288
if (!FMF.allowReassoc())
79289
Rdx = getOrderedReduction(Builder, Acc, Vec, RdxOpcode, RK);
80290
else {
@@ -125,10 +335,22 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
125335
case Intrinsic::vector_reduce_umax:
126336
case Intrinsic::vector_reduce_umin: {
127337
Value *Vec = II->getArgOperand(0);
338+
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
339+
if (Vec->getType()->isScalableTy()) {
340+
CFGChanged |= !FixedVScale;
341+
Rdx = expandScalableTreeReduction(
342+
Builder, II, std::nullopt, Vec, Instruction::BinaryOps(RdxOpcode),
343+
[](Constant *C) -> bool {
344+
llvm_unreachable(
345+
"No accumulator, so this should never be called!");
346+
},
347+
DT, FixedVScale);
348+
break;
349+
}
350+
128351
if (!isPowerOf2_32(
129352
cast<FixedVectorType>(Vec->getType())->getNumElements()))
130353
continue;
131-
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
132354
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
133355
break;
134356
}
@@ -150,7 +372,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
150372
II->eraseFromParent();
151373
Changed = true;
152374
}
153-
return Changed;
375+
return {CFGChanged, Changed};
154376
}
155377

156378
class ExpandReductions : public FunctionPass {
@@ -161,13 +383,15 @@ class ExpandReductions : public FunctionPass {
161383
}
162384

163385
bool runOnFunction(Function &F) override {
164-
const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
165-
return expandReductions(F, TTI);
386+
const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
387+
auto *DTA = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
388+
return expandReductions(F, TTI, DTA ? &DTA->getDomTree() : nullptr).second;
166389
}
167390

168391
void getAnalysisUsage(AnalysisUsage &AU) const override {
169392
AU.addRequired<TargetTransformInfoWrapperPass>();
170-
AU.setPreservesCFG();
393+
AU.addUsedIfAvailable<DominatorTreeWrapperPass>();
394+
AU.addPreserved<DominatorTreeWrapperPass>();
171395
}
172396
};
173397
}
@@ -186,9 +410,14 @@ FunctionPass *llvm::createExpandReductionsPass() {
186410
PreservedAnalyses ExpandReductionsPass::run(Function &F,
187411
FunctionAnalysisManager &AM) {
188412
const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
189-
if (!expandReductions(F, &TTI))
413+
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
414+
auto [CFGChanged, Changed] = expandReductions(F, &TTI, DT);
415+
if (!Changed)
190416
return PreservedAnalyses::all();
191417
PreservedAnalyses PA;
192-
PA.preserveSet<CFGAnalyses>();
418+
if (!CFGChanged)
419+
PA.preserveSet<CFGAnalyses>();
420+
else
421+
PA.preserve<DominatorTreeAnalysis>();
193422
return PA;
194423
}

llvm/lib/IR/Attributes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,14 @@ std::optional<unsigned> AttributeSet::getVScaleRangeMax() const {
11581158
return SetNode ? SetNode->getVScaleRangeMax() : std::nullopt;
11591159
}
11601160

1161+
std::optional<unsigned> AttributeSet::getFixedVScale() const {
1162+
unsigned Min = getVScaleRangeMin();
1163+
std::optional<unsigned> Max = getVScaleRangeMax();
1164+
if (Min != 0 && Max.has_value() && Max.value() == Min)
1165+
return Min;
1166+
return std::nullopt;
1167+
}
1168+
11611169
UWTableKind AttributeSet::getUWTableKind() const {
11621170
return SetNode ? SetNode->getUWTableKind() : UWTableKind::None;
11631171
}

0 commit comments

Comments
 (0)