12
12
// ===----------------------------------------------------------------------===//
13
13
14
14
#include " llvm/CodeGen/ExpandReductions.h"
15
+ #include " llvm/Analysis/DomTreeUpdater.h"
15
16
#include " llvm/Analysis/TargetTransformInfo.h"
16
17
#include " llvm/CodeGen/Passes.h"
18
+ #include " llvm/IR/BasicBlock.h"
19
+ #include " llvm/IR/Constants.h"
20
+ #include " llvm/IR/DerivedTypes.h"
21
+ #include " llvm/IR/Dominators.h"
17
22
#include " llvm/IR/IRBuilder.h"
18
23
#include " llvm/IR/InstIterator.h"
24
+ #include " llvm/IR/Instruction.h"
19
25
#include " llvm/IR/IntrinsicInst.h"
20
26
#include " llvm/IR/Intrinsics.h"
21
27
#include " llvm/InitializePasses.h"
22
28
#include " llvm/Pass.h"
29
+ #include " llvm/Support/ErrorHandling.h"
30
+ #include " llvm/Support/MathExtras.h"
31
+ #include " llvm/Transforms/Utils/BasicBlockUtils.h"
23
32
#include " llvm/Transforms/Utils/LoopUtils.h"
24
33
25
34
using namespace llvm ;
26
35
27
36
namespace {
28
37
29
- bool expandReductions (Function &F, const TargetTransformInfo *TTI) {
30
- bool Changed = false ;
38
+ void updateDomTreeForScalableExpansion (DominatorTree *DT, BasicBlock *Preheader,
39
+ BasicBlock *Loop, BasicBlock *Exit) {
40
+ DT->addNewBlock (Loop, Preheader);
41
+ DT->changeImmediateDominator (Exit, Loop);
42
+ assert (DT->verify (DominatorTree::VerificationLevel::Fast));
43
+ }
44
+
45
+ // / Expand a reduction on a scalable vector into a loop
46
+ // / that iterates over one element after the other.
47
+ Value *expandScalableReduction (IRBuilderBase &Builder, IntrinsicInst *II,
48
+ Value *Acc, Value *Vec,
49
+ Instruction::BinaryOps BinOp,
50
+ DominatorTree *DT) {
51
+ ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType ());
52
+
53
+ // Split the original BB in two and create a new BB between them,
54
+ // which will be a loop.
55
+ BasicBlock *BeforeBB = II->getParent ();
56
+ BasicBlock *AfterBB = SplitBlock (BeforeBB, II, DT);
57
+ BasicBlock *LoopBB = BasicBlock::Create (Builder.getContext (), " rdx.loop" ,
58
+ BeforeBB->getParent (), AfterBB);
59
+ BeforeBB->getTerminator ()->setSuccessor (0 , LoopBB);
60
+
61
+ // Calculate the number of elements in the vector:
62
+ Builder.SetInsertPoint (BeforeBB->getTerminator ());
63
+ Value *NumElts =
64
+ Builder.CreateVScale (Builder.getInt64 (VecTy->getMinNumElements ()));
65
+
66
+ // Create two PHIs, one for the index of the current lane and one for
67
+ // the actuall reduction.
68
+ Builder.SetInsertPoint (LoopBB);
69
+ PHINode *IV = Builder.CreatePHI (Builder.getInt64Ty (), 2 , " index" );
70
+ IV->addIncoming (Builder.getInt64 (0 ), BeforeBB);
71
+ PHINode *RdxPhi = Builder.CreatePHI (VecTy->getScalarType (), 2 , " rdx.phi" );
72
+ RdxPhi->addIncoming (Acc, BeforeBB);
73
+
74
+ Value *IVInc =
75
+ Builder.CreateAdd (IV, Builder.getInt64 (1 ), " index.next" , true , true );
76
+ IV->addIncoming (IVInc, LoopBB);
77
+
78
+ // Extract the value at the current lane from the vector and perform
79
+ // the scalar reduction binop:
80
+ Value *Lane = Builder.CreateExtractElement (Vec, IV, " elm" );
81
+ Value *Rdx = Builder.CreateBinOp (BinOp, RdxPhi, Lane, " rdx" );
82
+ RdxPhi->addIncoming (Rdx, LoopBB);
83
+
84
+ // Exit when all lanes have been treated (assuming there will be at least
85
+ // one element in the vector):
86
+ Value *Done = Builder.CreateCmp (CmpInst::ICMP_EQ, IVInc, NumElts, " exitcond" );
87
+ Builder.CreateCondBr (Done, AfterBB, LoopBB);
88
+
89
+ if (DT)
90
+ updateDomTreeForScalableExpansion (DT, BeforeBB, LoopBB, AfterBB);
91
+
92
+ return Rdx;
93
+ }
94
+
95
+ // / Expand a reduction on a scalable vector in a parallel-tree like
96
+ // / manner, meaning halving the number of elements to treat in every
97
+ // / iteration.
98
+ Value *expandScalableTreeReduction (
99
+ IRBuilderBase &Builder, IntrinsicInst *II, std::optional<Value *> Acc,
100
+ Value *Vec, Instruction::BinaryOps BinOp,
101
+ function_ref<bool (Constant *)> IsNeutralElement, DominatorTree *DT,
102
+ std::optional<unsigned> FixedVScale) {
103
+ ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType ());
104
+ ScalableVectorType *VecTyX2 = ScalableVectorType::get (
105
+ VecTy->getScalarType (), VecTy->getMinNumElements () * 2 );
106
+
107
+ // If the VScale is fixed, do not generate a loop, and instead to
108
+ // something similar to llvm::getShuffleReduction(). That function
109
+ // cannot be used directly because it uses shuffle masks, which
110
+ // are not avaiable for scalable vectors (even if vscale is fixed).
111
+ // The approach is effectively the same.
112
+ if (FixedVScale.has_value ()) {
113
+ unsigned VF = VecTy->getMinNumElements () * FixedVScale.value ();
114
+ assert (isPowerOf2_64 (VF));
115
+ for (unsigned I = VF; I != 1 ; I >>= 1 ) {
116
+ Value *Extended = Builder.CreateInsertVector (
117
+ VecTyX2, PoisonValue::get (VecTyX2), Vec, Builder.getInt64 (0 ));
118
+ Value *Pair = Builder.CreateIntrinsic (Intrinsic::vector_deinterleave2,
119
+ {VecTyX2}, {Extended});
120
+
121
+ Value *Vec1 = Builder.CreateExtractValue (Pair, {0 });
122
+ Value *Vec2 = Builder.CreateExtractValue (Pair, {1 });
123
+ Vec = Builder.CreateBinOp (BinOp, Vec1, Vec2, " rdx" );
124
+ }
125
+ Value *FinalVal = Builder.CreateExtractElement (Vec, uint64_t (0 ));
126
+ if (Acc)
127
+ if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement (C))
128
+ FinalVal = Builder.CreateBinOp (BinOp, *Acc, FinalVal, " rdx.final" );
129
+ return FinalVal;
130
+ }
131
+
132
+ // Split the original BB in two and create a new BB between them,
133
+ // which will be a loop.
134
+ BasicBlock *BeforeBB = II->getParent ();
135
+ BasicBlock *AfterBB = SplitBlock (BeforeBB, II, DT);
136
+ BasicBlock *LoopBB = BasicBlock::Create (Builder.getContext (), " rdx.loop" ,
137
+ BeforeBB->getParent (), AfterBB);
138
+ BeforeBB->getTerminator ()->setSuccessor (0 , LoopBB);
139
+
140
+ // This tree reduction only needs to do log2(N) iterations.
141
+ // Note: Calculating log2(N) using count-trailing-zeros (cttz) only works if
142
+ // `vscale` is a power-of-two. This is the case for every architecture known
143
+ // right now, but could a check be added with a fallback to some other algo.?
144
+ assert (isPowerOf2_64 (VecTy->getMinNumElements ()));
145
+ Builder.SetInsertPoint (BeforeBB->getTerminator ());
146
+ Value *NumElts =
147
+ Builder.CreateVScale (Builder.getInt64 (VecTy->getMinNumElements ()));
148
+ Value *NumIters = Builder.CreateIntrinsic (NumElts->getType (), Intrinsic::cttz,
149
+ {NumElts, Builder.getTrue ()});
150
+
151
+ // Create two PHIs, one for the IV and one for the actuall reduction.
152
+ Builder.SetInsertPoint (LoopBB);
153
+ PHINode *IV = Builder.CreatePHI (Builder.getInt64Ty (), 2 , " iter" );
154
+ IV->addIncoming (Builder.getInt64 (0 ), BeforeBB);
155
+ PHINode *VecPhi = Builder.CreatePHI (VecTy, 2 , " rdx.phi" );
156
+ VecPhi->addIncoming (Vec, BeforeBB);
157
+
158
+ // Note that instead of calculating log2(N) beforehand and having the IV
159
+ // increment by one every iteration, we could also have a IV more similar to:
160
+ // for (size_t active_lanes = N; active_lanes > 1; active_lanes /= 2) ...
161
+ // The IV is only used for the loop's exit condition, so how it is
162
+ // calculated does not matter to the tree reduction.
163
+ Value *IVInc =
164
+ Builder.CreateAdd (IV, Builder.getInt64 (1 ), " iter.next" , true , true );
165
+ IV->addIncoming (IVInc, LoopBB);
166
+
167
+ // The deinterleave intrinsic takes a vector of, for example, type
168
+ // <vscale x 8 x float> and produces a pair of vectors with half the size,
169
+ // so 2 x <vscale x 4 x float>. An insert vector operation is used to
170
+ // create a double-sized vector where the upper half is poison, because
171
+ // we never care about that upper half anyways!
172
+ Value *Extended = Builder.CreateInsertVector (
173
+ VecTyX2, PoisonValue::get (VecTyX2), VecPhi, Builder.getInt64 (0 ));
174
+ Value *Pair = Builder.CreateIntrinsic (Intrinsic::vector_deinterleave2,
175
+ {VecTyX2}, {Extended});
176
+
177
+ // Take the two vectors and multiply them together. Note that in the first
178
+ // iteration, the results of 1/2 of the lanes is used, in the second one
179
+ // 1/4, in the thrid one 1/8, etc.. It could be nice to create a mask
180
+ // for this? However, on SVE at least, the instr. latency does not depend
181
+ // on the number of active lanes (as far as I know), so this might just
182
+ // be useless.
183
+ Value *Vec1 = Builder.CreateExtractValue (Pair, {0 });
184
+ Value *Vec2 = Builder.CreateExtractValue (Pair, {1 });
185
+ Value *Rdx = Builder.CreateBinOp (BinOp, Vec1, Vec2, " rdx" );
186
+ VecPhi->addIncoming (Rdx, LoopBB);
187
+
188
+ // Reduction-loop exit condition:
189
+ Value *Done =
190
+ Builder.CreateCmp (CmpInst::ICMP_EQ, IVInc, NumIters, " exitcond" );
191
+ Builder.CreateCondBr (Done, AfterBB, LoopBB);
192
+ Builder.SetInsertPoint (AfterBB, AfterBB->getFirstInsertionPt ());
193
+ Value *FinalVal = Builder.CreateExtractElement (Rdx, uint64_t (0 ));
194
+
195
+ // If the Acc value is not the neutral element of the reduction operation,
196
+ // then we need to do the binop one last time with the end result of the
197
+ // tree reduction. Sidenote: LLVM's loop-vectorizer will actually generate
198
+ // code where Acc is zero for addition and one for multiplication most of
199
+ // the time.
200
+ if (Acc)
201
+ if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement (C))
202
+ FinalVal = Builder.CreateBinOp (BinOp, *Acc, FinalVal, " rdx.final" );
203
+
204
+ if (DT)
205
+ updateDomTreeForScalableExpansion (DT, BeforeBB, LoopBB, AfterBB);
206
+
207
+ return FinalVal;
208
+ }
209
+
210
+ std::pair<bool , bool > expandReductions (Function &F,
211
+ const TargetTransformInfo *TTI,
212
+ DominatorTree *DT) {
213
+ bool Changed = false , CFGChanged = false ;
31
214
SmallVector<IntrinsicInst *, 4 > Worklist;
32
215
for (auto &I : instructions (F)) {
33
216
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
@@ -54,6 +237,9 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
54
237
}
55
238
}
56
239
240
+ std::optional<unsigned > FixedVScale =
241
+ F.getAttributes ().getFnAttrs ().getFixedVScale ();
242
+
57
243
for (auto *II : Worklist) {
58
244
FastMathFlags FMF =
59
245
isa<FPMathOperator>(II) ? II->getFastMathFlags () : FastMathFlags{};
@@ -74,7 +260,31 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
74
260
// and it can't be handled by generating a shuffle sequence.
75
261
Value *Acc = II->getArgOperand (0 );
76
262
Value *Vec = II->getArgOperand (1 );
77
- unsigned RdxOpcode = getArithmeticReductionInstruction (ID);
263
+ auto RdxOpcode =
264
+ Instruction::BinaryOps (getArithmeticReductionInstruction (ID));
265
+
266
+ bool ScalableTy = Vec->getType ()->isScalableTy ();
267
+ if (ScalableTy && (!FixedVScale || FMF.allowReassoc ())) {
268
+ CFGChanged |= !FixedVScale;
269
+ if (FMF.allowReassoc ())
270
+ Rdx = expandScalableTreeReduction (
271
+ Builder, II, Acc, Vec, RdxOpcode,
272
+ [&](Constant *C) {
273
+ switch (ID) {
274
+ case Intrinsic::vector_reduce_fadd:
275
+ return C->isZeroValue ();
276
+ case Intrinsic::vector_reduce_fmul:
277
+ return C->isOneValue ();
278
+ default :
279
+ llvm_unreachable (" Binop not handled" );
280
+ }
281
+ },
282
+ DT, FixedVScale);
283
+ else
284
+ Rdx = expandScalableReduction (Builder, II, Acc, Vec, RdxOpcode, DT);
285
+ break ;
286
+ }
287
+
78
288
if (!FMF.allowReassoc ())
79
289
Rdx = getOrderedReduction (Builder, Acc, Vec, RdxOpcode, RK);
80
290
else {
@@ -125,10 +335,22 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
125
335
case Intrinsic::vector_reduce_umax:
126
336
case Intrinsic::vector_reduce_umin: {
127
337
Value *Vec = II->getArgOperand (0 );
338
+ unsigned RdxOpcode = getArithmeticReductionInstruction (ID);
339
+ if (Vec->getType ()->isScalableTy ()) {
340
+ CFGChanged |= !FixedVScale;
341
+ Rdx = expandScalableTreeReduction (
342
+ Builder, II, std::nullopt, Vec, Instruction::BinaryOps (RdxOpcode),
343
+ [](Constant *C) -> bool {
344
+ llvm_unreachable (
345
+ " No accumulator, so this should never be called!" );
346
+ },
347
+ DT, FixedVScale);
348
+ break ;
349
+ }
350
+
128
351
if (!isPowerOf2_32 (
129
352
cast<FixedVectorType>(Vec->getType ())->getNumElements ()))
130
353
continue ;
131
- unsigned RdxOpcode = getArithmeticReductionInstruction (ID);
132
354
Rdx = getShuffleReduction (Builder, Vec, RdxOpcode, RS, RK);
133
355
break ;
134
356
}
@@ -150,7 +372,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
150
372
II->eraseFromParent ();
151
373
Changed = true ;
152
374
}
153
- return Changed;
375
+ return {CFGChanged, Changed} ;
154
376
}
155
377
156
378
class ExpandReductions : public FunctionPass {
@@ -161,13 +383,15 @@ class ExpandReductions : public FunctionPass {
161
383
}
162
384
163
385
bool runOnFunction (Function &F) override {
164
- const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI (F);
165
- return expandReductions (F, TTI);
386
+ const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI (F);
387
+ auto *DTA = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
388
+ return expandReductions (F, TTI, DTA ? &DTA->getDomTree () : nullptr ).second ;
166
389
}
167
390
168
391
void getAnalysisUsage (AnalysisUsage &AU) const override {
169
392
AU.addRequired <TargetTransformInfoWrapperPass>();
170
- AU.setPreservesCFG ();
393
+ AU.addUsedIfAvailable <DominatorTreeWrapperPass>();
394
+ AU.addPreserved <DominatorTreeWrapperPass>();
171
395
}
172
396
};
173
397
}
@@ -186,9 +410,14 @@ FunctionPass *llvm::createExpandReductionsPass() {
186
410
PreservedAnalyses ExpandReductionsPass::run (Function &F,
187
411
FunctionAnalysisManager &AM) {
188
412
const auto &TTI = AM.getResult <TargetIRAnalysis>(F);
189
- if (!expandReductions (F, &TTI))
413
+ auto *DT = AM.getCachedResult <DominatorTreeAnalysis>(F);
414
+ auto [CFGChanged, Changed] = expandReductions (F, &TTI, DT);
415
+ if (!Changed)
190
416
return PreservedAnalyses::all ();
191
417
PreservedAnalyses PA;
192
- PA.preserveSet <CFGAnalyses>();
418
+ if (!CFGChanged)
419
+ PA.preserveSet <CFGAnalyses>();
420
+ else
421
+ PA.preserve <DominatorTreeAnalysis>();
193
422
return PA;
194
423
}
0 commit comments