Skip to content

Commit d4cb697

Browse files
pkwasnie-intel authored and igcbot committed
SubGroupReductionPattern pass - support for i64 (2nd try)
A shuffle of the i64 type can be implemented as a shuffle of the vector type <2 x i32>, that is, as two separate shuffles of i32. This commit adds support for such a pattern in the SubGroupReductionPattern pass.
1 parent ff92f23 commit d4cb697

File tree

2 files changed

+342
-16
lines changed

2 files changed

+342
-16
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/SubGroupReductionPattern/SubGroupReductionPattern.cpp

Lines changed: 191 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ SPDX-License-Identifier: MIT
1212
#include "common/LLVMWarningsPush.hpp"
1313
#include <llvm/IR/InstVisitor.h>
1414
#include <llvm/Transforms/Utils/Local.h>
15+
#include "llvmWrapper/IR/DerivedTypes.h"
1516
#include <llvmWrapper/IR/PatternMatch.h>
1617
#include "common/LLVMWarningsPop.hpp"
1718

1819
#include "common/igc_regkeys.hpp"
20+
#include "Compiler/CISACodeGen/OpenCLKernelCodeGen.hpp"
1921
#include "Compiler/MetaDataUtilsWrapper.h"
2022
#include "GenISAIntrinsics/GenIntrinsicInst.h"
2123

@@ -42,26 +44,50 @@ class ShufflePattern
4244

4345
struct PatternStep
4446
{
45-
PatternStep(GenIntrinsicInst *ShuffleOp, Instruction *Op, uint64_t Lane)
46-
: ShuffleOp(ShuffleOp), Op(Op), Lane(Lane) { }
47+
PatternStep(Value *InputValue, GenIntrinsicInst *ShuffleOp, Instruction *Op, uint64_t Lane)
48+
: InputValue(InputValue), ShuffleOp(ShuffleOp), Op(Op), Lane(Lane) { }
4749

50+
Value *InputValue;
51+
52+
// Shuffle InputValue with other SIMD lane.
4853
GenIntrinsicInst *ShuffleOp;
49-
Instruction *Op;
5054
uint64_t Lane;
55+
56+
// Op on InputValue and ShuffleOp result.
57+
Instruction *Op;
5158
};
5259

53-
ShufflePattern(GenIntrinsicInst *ShuffleOp, Instruction *Op, WaveOps OpType, uint64_t Lane)
60+
ShufflePattern(Value *InputValue, GenIntrinsicInst *ShuffleOp, Instruction *Op, WaveOps OpType, uint64_t Lane)
5461
: OpType(OpType)
5562
{
56-
Steps.emplace_back(ShuffleOp, Op, Lane);
63+
Steps.emplace_back(InputValue, ShuffleOp, Op, Lane);
5764
}
5865

59-
bool append(GenIntrinsicInst *ShuffleOp, Instruction *Op, WaveOps OpType, uint64_t Lane);
66+
bool append(Value *InputValue, GenIntrinsicInst *ShuffleOp, Instruction *Op, WaveOps OpType, uint64_t Lane);
6067

6168
WaveOps OpType;
6269
SmallVector<PatternStep, 8> Steps;
6370
};
6471

72+
// A half of i64 shuffled as <2 x i32>.
73+
// See SubGroupReductionPattern::matchVectorShufflePattern for details.
74+
struct VectorShufflePattern
75+
{
76+
VectorShufflePattern(Instruction *Op, uint64_t Lane, uint64_t VectorIndex)
77+
: Op(Op), Lane(Lane), VectorIndex(VectorIndex) {}
78+
79+
bool match(Instruction* Op, uint64_t Lane, uint64_t VectorIndex)
80+
{
81+
return this->Op == Op &&
82+
this->Lane == Lane &&
83+
((this->VectorIndex == 0 && VectorIndex == 1) || (this->VectorIndex == 1 && VectorIndex == 0));
84+
}
85+
86+
Instruction *Op;
87+
uint64_t Lane;
88+
uint64_t VectorIndex;
89+
};
90+
6591
// Pass for matching common manual subgroup reduction pattern and replacing them
6692
// with corresponding GenISA.Wave* call.
6793
class SubGroupReductionPattern : public llvm::FunctionPass, public llvm::InstVisitor<SubGroupReductionPattern>
@@ -88,16 +114,25 @@ class SubGroupReductionPattern : public llvm::FunctionPass, public llvm::InstVis
88114
void visitWaveShuffleIndex(GenIntrinsicInst &ShuffleOp);
89115

90116
void matchShufflePattern(GenIntrinsicInst &ShuffleOp, uint64_t Lane);
117+
void matchVectorShufflePattern(GenIntrinsicInst &ShuffleOp, uint64_t Lane);
118+
void addShufflePattern(Value *InputValue, GenIntrinsicInst &ShuffleOp, Instruction *Op, uint64_t Lane);
91119

92120
bool reduce(ShufflePattern &Pattern);
93121
GenISAIntrinsic::ID getReductionType(uint64_t XorMask);
94122

95123
static WaveOps getWaveOp(Instruction* Op);
96124

125+
CodeGenContext *CGC = nullptr;
126+
97127
int SubGroupSize = 0;
98128
bool Modified = false;
99129

100130
SmallVector<ShufflePattern, 8> Matches;
131+
132+
// For i64 shuffle done as two i32 shuffles (vector <2 x i32>), each
133+
// i32 shuffle is matched as a separate pattern. This map temporarily
134+
// holds the first of the matched pair.
135+
DenseMap<Instruction*, VectorShufflePattern> VectorShufflePatterns;
101136
};
102137

103138
SubGroupReductionPattern::SubGroupReductionPattern() : FunctionPass(ID)
@@ -107,6 +142,7 @@ SubGroupReductionPattern::SubGroupReductionPattern() : FunctionPass(ID)
107142

108143
void SubGroupReductionPattern::getAnalysisUsage(llvm::AnalysisUsage &AU) const
109144
{
145+
AU.addRequired<CodeGenContextWrapper>();
110146
AU.addRequired<MetaDataUtilsWrapper>();
111147
AU.setPreservesCFG();
112148
}
@@ -116,6 +152,7 @@ bool SubGroupReductionPattern::runOnFunction(llvm::Function &F)
116152
if (F.hasOptNone())
117153
return false;
118154

155+
CGC = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
119156
auto MDU = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
120157
auto FII = MDU->findFunctionsInfoItem(&F);
121158
if (FII == MDU->end_FunctionsInfo())
@@ -206,29 +243,165 @@ void SubGroupReductionPattern::matchShufflePattern(GenIntrinsicInst &ShuffleOp,
206243
return;
207244

208245
Instruction *Op = ShuffleOp.user_back();
209-
if (!Op)
246+
247+
if (isa<InsertElementInst>(Op))
248+
{
249+
return matchVectorShufflePattern(ShuffleOp, Lane);
250+
}
251+
252+
Value *InputValue = ShuffleOp.getOperand(0);
253+
if (((Op->getOperand(0) == InputValue && Op->getOperand(1) == &ShuffleOp) || (Op->getOperand(0) == &ShuffleOp && Op->getOperand(1) == InputValue)) == false)
210254
return;
211255

256+
addShufflePattern(InputValue, ShuffleOp, Op, Lane);
257+
}
258+
259+
void SubGroupReductionPattern::addShufflePattern(Value *InputValue, GenIntrinsicInst &ShuffleOp, Instruction *Op, uint64_t Lane)
260+
{
212261
WaveOps OpType = getWaveOp(Op);
213262
if (OpType == WaveOps::UNDEF)
214263
return;
215264

216-
Value* Other = ShuffleOp.getOperand(0);
217-
if (((Op->getOperand(0) == Other && Op->getOperand(1) == &ShuffleOp) || (Op->getOperand(0) == &ShuffleOp && Op->getOperand(1) == Other)) == false)
265+
// Check if type is supported.
266+
if (Op->getType()->isIntegerTy(64) && CGC->platform.need64BitEmulation())
218267
return;
219268

220269
// Continues previous pattern?
221270
for (auto &Match : Matches)
222271
{
223-
if (Match.append(&ShuffleOp, Op, OpType, Lane))
272+
if (Match.append(InputValue, &ShuffleOp, Op, OpType, Lane))
224273
return;
225274
}
226275

227276
// New pattern.
228-
Matches.emplace_back(&ShuffleOp, Op, OpType, Lane);
277+
Matches.emplace_back(InputValue, &ShuffleOp, Op, OpType, Lane);
278+
}
279+
280+
// Shuffle of i64 type can be split into two shuffles of i32 type. This method handles
281+
// such case; it matches the following pattern:
282+
//
283+
// %3 = bitcast i64 %value to <2 x i32>
284+
// %value1 = extractelement <2 x i32> %3, i64 0
285+
// %value2 = extractelement <2 x i32> %3, i64 1
286+
// %simdShuffleXor1 = call i32 @llvm.genx.GenISA.simdShuffleXor.i32(i32 %value1, i32 8)
287+
// %simdShuffleXor2 = call i32 @llvm.genx.GenISA.simdShuffleXor.i32(i32 %value2, i32 8)
288+
// %shuffledVec1 = insertelement <2 x i32> undef, i32 %simdShuffleXor1, i64 0
289+
// %shuffledVec2 = insertelement <2 x i32> %shuffledVec1, i32 %simdShuffleXor2, i64 1
290+
// %shuffled = bitcast <2 x i32> %shuffledVec2 to i64
291+
// %result = <op> i64 %value, %shuffled
292+
void SubGroupReductionPattern::matchVectorShufflePattern(GenIntrinsicInst &ShuffleOp, uint64_t Lane)
293+
{
294+
auto CheckVectorType = [](Type *Ty)
295+
{
296+
if (auto *VTy = dyn_cast<IGCLLVM::FixedVectorType>(Ty))
297+
return VTy->getNumElements() == 2;
298+
return false;
299+
};
300+
301+
// Match instructions that happen before shuffle, that is:
302+
//
303+
// %3 = bitcast i64 %value to <2 x i32>
304+
// %value1 = extractelement <2 x i32> %3, i64 0
305+
// %value2 = extractelement <2 x i32> %3, i64 1
306+
// %simdShuffleXor1 = call i32 @llvm.genx.GenISA.simdShuffleXor.i32(i32 %value1, i32 8)
307+
// %simdShuffleXor2 = call i32 @llvm.genx.GenISA.simdShuffleXor.i32(i32 %value2, i32 8)
308+
//
309+
// Collect:
310+
// 1. Input value.
311+
// 2. BitCast instruction (to validate type).
312+
// 3. ExtractElement index.
313+
314+
Value *InputValue = nullptr;
315+
Instruction *BitCast = nullptr;
316+
uint64_t VectorIndex = 0;
317+
318+
if (!match(ShuffleOp.getOperand(0), m_OneUse(
319+
m_ExtractElt(
320+
m_CombineAnd(
321+
m_Instruction(BitCast),
322+
m_BitCast(m_Value(InputValue)
323+
)),
324+
m_ConstantInt(VectorIndex)))))
325+
return;
326+
327+
if (VectorIndex != 0 && VectorIndex != 1)
328+
return;
329+
330+
if (InputValue->getType()->isVectorTy() || !CheckVectorType(BitCast->getType()))
331+
return;
332+
333+
// Match instructions that happen after shuffle, that is:
334+
//
335+
// %simdShuffleXor1 = call i32 @llvm.genx.GenISA.simdShuffleXor.i32(i32 %value1, i32 8)
336+
// %simdShuffleXor2 = call i32 @llvm.genx.GenISA.simdShuffleXor.i32(i32 %value2, i32 8)
337+
// %shuffledVec1 = insertelement <2 x i32> undef, i32 %simdShuffleXor1, i64 0
338+
// %shuffledVec2 = insertelement <2 x i32> %shuffledVec1, i32 %simdShuffleXor2, i64 1
339+
// %shuffled = bitcast <2 x i32> %shuffledVec2 to i64
340+
// %result = <op> i64 %value, %shuffled
341+
//
342+
// Collect:
343+
// 1. InsertElement instruction (to validate type).
344+
345+
Instruction *InsertElement = nullptr;
346+
347+
// First pattern - shuffle is used in second insertelement.
348+
auto Pattern1 = m_OneUse(
349+
m_BitCast(
350+
m_OneUse(
351+
m_CombineAnd(
352+
m_Instruction(InsertElement),
353+
m_InsertElt(
354+
m_Value(),
355+
m_OneUse(m_Specific(&ShuffleOp)),
356+
m_SpecificInt(VectorIndex)
357+
)))));
358+
359+
// Second pattern - shuffle is used in first insertelement.
360+
auto Pattern2 = m_OneUse(
361+
m_BitCast(
362+
m_OneUse(
363+
m_InsertElt(
364+
m_OneUse(
365+
m_CombineAnd(
366+
m_Instruction(InsertElement),
367+
m_InsertElt(
368+
m_Undef(),
369+
m_OneUse(m_Specific(&ShuffleOp)),
370+
m_SpecificInt(VectorIndex)))),
371+
m_Value(),
372+
m_SpecificInt(VectorIndex ? 0 : 1)
373+
))));
374+
375+
for (auto *User : InputValue->users())
376+
{
377+
Instruction *Op = dyn_cast<Instruction>(User);
378+
379+
if (match(Op, m_c_BinOp(m_Specific(InputValue), Pattern1)) || match(Op, m_c_BinOp(m_Specific(InputValue), Pattern2)))
380+
{
381+
if (!CheckVectorType(InsertElement->getType()))
382+
return;
383+
384+
// Now that pattern is matched, check if this is the first or second shuffle of the pair.
385+
if (VectorShufflePatterns.count(Op))
386+
{
387+
if (VectorShufflePatterns.find(Op)->second.match(Op, Lane, VectorIndex))
388+
{
389+
// Collected two parts of i64 shuffle, create new pattern.
390+
addShufflePattern(InputValue, ShuffleOp, Op, Lane);
391+
}
392+
}
393+
else
394+
{
395+
// First part of i64 shuffle, store for later matching.
396+
VectorShufflePatterns.try_emplace(Op, Op, Lane, VectorIndex);
397+
}
398+
399+
return;
400+
}
401+
}
229402
}
230403

231-
bool ShufflePattern::append(GenIntrinsicInst *ShuffleOp, Instruction *Op, WaveOps OpType, uint64_t Lane)
404+
bool ShufflePattern::append(Value *InputValue, GenIntrinsicInst *ShuffleOp, Instruction *Op, WaveOps OpType, uint64_t Lane)
232405
{
233406
if (this->OpType != OpType)
234407
return false;
@@ -238,13 +411,13 @@ bool ShufflePattern::append(GenIntrinsicInst *ShuffleOp, Instruction *Op, WaveOp
238411
if (PreviousValue->getNumUses() != 2)
239412
return false;
240413

241-
if (ShuffleOp->getOperand(0) != PreviousValue)
414+
if (InputValue != PreviousValue)
242415
return false;
243416

244-
if (Op->getOperand(0) != PreviousValue && Op->getOperand(1) != PreviousValue)
417+
if (Op->getOperand(0) != InputValue && Op->getOperand(1) != InputValue)
245418
return false;
246419

247-
Steps.emplace_back(ShuffleOp, Op, Lane);
420+
Steps.emplace_back(InputValue, ShuffleOp, Op, Lane);
248421
return true;
249422
}
250423

@@ -263,14 +436,15 @@ bool SubGroupReductionPattern::reduce(ShufflePattern &Pattern)
263436
if (ReductionType == GenISAIntrinsic::no_intrinsic)
264437
return false;
265438

439+
auto *InputValue = Pattern.Steps.front().InputValue;
266440
auto *FirstShuffle = Pattern.Steps.front().ShuffleOp;
267441
auto *LastOp = Pattern.Steps.back().Op;
268442

269443
IRBuilder<> IRB(FirstShuffle);
270444
IRB.SetCurrentDebugLocation(LastOp->getDebugLoc());
271445

272446
SmallVector<Value*, 4> Args;
273-
Args.push_back(FirstShuffle->getOperand(0));
447+
Args.push_back(InputValue);
274448
Args.push_back(IRB.getInt8((uint8_t) Pattern.OpType));
275449
if (ReductionType == GenISAIntrinsic::GenISA_WaveClustered)
276450
Args.push_back(IRB.getInt32(XorMask + 1));
@@ -431,6 +605,7 @@ WaveOps SubGroupReductionPattern::getWaveOp(Instruction *Op)
431605
#define PASS_CFG_ONLY false
432606
#define PASS_ANALYSIS false
433607
IGC_INITIALIZE_PASS_BEGIN(SubGroupReductionPattern, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
608+
IGC_INITIALIZE_PASS_DEPENDENCY(CodeGenContextWrapper)
434609
IGC_INITIALIZE_PASS_DEPENDENCY(MetaDataUtilsWrapper)
435610
IGC_INITIALIZE_PASS_END(SubGroupReductionPattern, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
436611

0 commit comments

Comments
 (0)