Skip to content

Commit ee05ac8

Browse files
pkwasnie-intelpszymich
authored andcommitted
New GenISA intrinsic: WaveInterleave
Adds new GenISA intrinsic WaveInterleave that does subgroup reduction on each n-th work item. For example, for SIMD8 and interleave step = 2, the result is reduction of work items 0,2,4,6 and separate reduction of work items 1,3,5,7. Change includes pattern match for interleave reduction implemented with subgroup shuffles. (cherry picked from commit 86502fa)
1 parent e729dad commit ee05ac8

File tree

13 files changed

+460
-38
lines changed

13 files changed

+460
-38
lines changed

IGC/Compiler/CISACodeGen/CheckInstrTypes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ void CheckInstrTypes::visitCallInst(CallInst& C)
334334
case GenISAIntrinsic::GenISA_WaveInverseBallot:
335335
case GenISAIntrinsic::GenISA_WavePrefix:
336336
case GenISAIntrinsic::GenISA_WaveClustered:
337+
case GenISAIntrinsic::GenISA_WaveInterleave:
337338
case GenISAIntrinsic::GenISA_QuadPrefix:
338339
case GenISAIntrinsic::GenISA_simdShuffleDown:
339340
case GenISAIntrinsic::GenISA_simdShuffleXor:

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 113 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8905,6 +8905,9 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
89058905
case GenISAIntrinsic::GenISA_WaveAll:
89068906
emitWaveAll(inst);
89078907
break;
8908+
case GenISAIntrinsic::GenISA_WaveInterleave:
8909+
emitWaveInterleave(inst);
8910+
break;
89088911
case GenISAIntrinsic::GenISA_WaveClustered:
89098912
emitWaveClustered(inst);
89108913
break;
@@ -13167,8 +13170,45 @@ CVariable* EmitPass::ScanReducePrepareSrc(VISA_Type type, uint64_t identityValue
1316713170
}
1316813171

1316913172
// Reduction all reduce helper: dst_lane{k} = src_lane{simd + k} OP src_lane{k}, k = 0..(simd-1)
13170-
CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src)
13173+
CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src, CVariable* srcSecondHalf)
1317113174
{
13175+
const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
13176+
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
13177+
13178+
if (simd == SIMDMode::SIMD16 && m_currShader->m_numberInstance > 1)
13179+
{
13180+
IGC_ASSERT(srcSecondHalf);
13181+
13182+
CVariable* temp = m_currShader->GetNewVariable(
13183+
numLanes(simd),
13184+
type,
13185+
EALIGN_GRF,
13186+
false,
13187+
CName("reduceDstSecondHalf"));
13188+
13189+
if (!int64EmulationNeeded)
13190+
{
13191+
m_encoder->SetNoMask();
13192+
m_encoder->SetSimdSize(simd);
13193+
m_encoder->GenericAlu(op, temp, src, srcSecondHalf);
13194+
m_encoder->Push();
13195+
}
13196+
else
13197+
{
13198+
if (isInt64Mul)
13199+
{
13200+
CVariable* tmpMulSrc[2] = { src, srcSecondHalf };
13201+
Mul64(temp, tmpMulSrc, simd, true /* noMask */);
13202+
}
13203+
else
13204+
{
13205+
IGC_ASSERT_MESSAGE(0, "Unsupported");
13206+
}
13207+
}
13208+
13209+
return temp;
13210+
}
13211+
1317213212
const bool is64bitType = ScanReduceIs64BitType(type);
1317313213
const auto alignment = is64bitType ? IGC::EALIGN_QWORD : IGC::EALIGN_DWORD;
1317413214
CVariable* temp = m_currShader->GetNewVariable(
@@ -13178,9 +13218,6 @@ CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode
1317813218
false,
1317913219
CName("reduceDst_SIMD", std::to_string(numLanes(simd)).c_str()));
1318013220

13181-
const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
13182-
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
13183-
1318413221
if (!int64EmulationNeeded)
1318513222
{
1318613223
m_encoder->SetNoMask();
@@ -13546,34 +13583,7 @@ void EmitPass::emitReductionAll(
1354613583
CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
1354713584
src, nullptr /* dst */);
1354813585

13549-
temp = m_currShader->GetNewVariable(
13550-
numLanes(simd),
13551-
type,
13552-
EALIGN_GRF,
13553-
false,
13554-
CName("reduceDstSecondHalf"));
13555-
13556-
const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
13557-
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
13558-
if (!int64EmulationNeeded)
13559-
{
13560-
m_encoder->SetNoMask();
13561-
m_encoder->SetSimdSize(simd);
13562-
m_encoder->GenericAlu(op, temp, srcH1, srcH2);
13563-
m_encoder->Push();
13564-
}
13565-
else
13566-
{
13567-
if (isInt64Mul)
13568-
{
13569-
CVariable* tmpMulSrc[2] = { srcH1, srcH2 };
13570-
Mul64(temp, tmpMulSrc, simd, true /* noMask */);
13571-
}
13572-
else
13573-
{
13574-
IGC_ASSERT_MESSAGE(0, "Unsupported");
13575-
}
13576-
}
13586+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp, srcH2);
1357713587
}
1357813588
}
1357913589
if (m_currShader->m_dispatchSize >= SIMDMode::SIMD16)
@@ -13723,6 +13733,54 @@ void EmitPass::emitReductionClustered(const e_opcode op, const uint64_t identity
1372313733
}
1372413734
}
1372513735

13736+
void EmitPass::emitReductionInterleave(const e_opcode op, const uint64_t identityValue, const VISA_Type type,
13737+
const bool negate, const unsigned int step, CVariable* const src, CVariable* const dst)
13738+
{
13739+
if (step == 1)
13740+
{
13741+
// TODO: consider if it is possible to detect and handle this case in frontends
13742+
// and emit GenISA_WaveAll there, to enable optimizations specific to the ReduceAll intrinsic.
13743+
return emitReductionAll(op, identityValue, type, negate, src, dst);
13744+
}
13745+
13746+
const uint16_t firstStep = numLanes(m_currShader->m_dispatchSize) / 2;
13747+
13748+
IGC_ASSERT_MESSAGE(!dst->IsUniform(), "Unsupported: dst must be non-uniform");
13749+
IGC_ASSERT_MESSAGE(step % 2 == 0 && step <= firstStep, "Invalid reduction interleave step");
13750+
13751+
CVariable* srcH1 = ScanReducePrepareSrc(type, identityValue, negate, false /* secondHalf */,
13752+
src, nullptr /* dst */);
13753+
CVariable* temp = srcH1;
13754+
13755+
// Implementation is similar to emitReductionAll(), but we stop reduction before reaching SIMD1.
13756+
for (unsigned int currentStep = firstStep; currentStep >= step; currentStep >>= 1)
13757+
{
13758+
if (currentStep == 16 && m_currShader->m_numberInstance > 1)
13759+
{
13760+
CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
13761+
src, nullptr /* dst */);
13762+
13763+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp, srcH2);
13764+
}
13765+
else
13766+
{
13767+
temp = ReductionReduceHelper(op, type, lanesToSIMDMode(currentStep), temp);
13768+
}
13769+
}
13770+
13771+
// Broadcast result
13772+
m_encoder->SetSimdSize(m_currShader->m_SIMDSize);
13773+
m_encoder->SetSrcRegion(0, 0, step, 1);
13774+
m_encoder->Copy(dst, temp);
13775+
if (m_currShader->m_numberInstance > 1)
13776+
{
13777+
m_encoder->SetSecondHalf(true);
13778+
m_encoder->Copy(dst, temp);
13779+
m_encoder->SetSecondHalf(false);
13780+
}
13781+
m_encoder->Push();
13782+
}
13783+
1372613784
// do prefix op across all activate channels
1372713785
void EmitPass::emitPreOrPostFixOp(
1372813786
e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
@@ -21141,6 +21199,29 @@ void EmitPass::emitWaveClustered(llvm::GenIntrinsicInst* inst)
2114121199
}
2114221200
}
2114321201

21202+
void EmitPass::emitWaveInterleave(llvm::GenIntrinsicInst* inst)
21203+
{
21204+
bool disableHelperLanes = int_cast<int>(cast<ConstantInt>(inst->getArgOperand(3))->getSExtValue()) == 2;
21205+
if (disableHelperLanes)
21206+
{
21207+
ForceDMask();
21208+
}
21209+
CVariable* src = GetSymbol(inst->getOperand(0));
21210+
const WaveOps op = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
21211+
const unsigned int step = int_cast<uint32_t>(cast<llvm::ConstantInt>(inst->getOperand(2))->getZExtValue());
21212+
VISA_Type type;
21213+
e_opcode opCode;
21214+
uint64_t identity = 0;
21215+
GetReductionOp(op, inst->getOperand(0)->getType(), identity, opCode, type);
21216+
CVariable* dst = m_destination;
21217+
m_encoder->SetSubSpanDestination(false);
21218+
emitReductionInterleave(opCode, identity, type, false, step, src, dst);
21219+
if (disableHelperLanes)
21220+
{
21221+
ResetVMask();
21222+
}
21223+
}
21224+
2114421225
void EmitPass::emitDP4A(GenIntrinsicInst* GII, const SSource* Sources, const DstModifier& modifier, bool isAccSigned) {
2114521226
GenISAIntrinsic::ID GIID = GII->getIntrinsicID();
2114621227
CVariable* dst = m_destination;

IGC/Compiler/CISACodeGen/EmitVISAPass.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class EmitPass : public llvm::FunctionPass
301301
bool ScanReduceIsInt64EmulationNeeded(e_opcode op, VISA_Type type);
302302
CVariable* ScanReducePrepareSrc(VISA_Type type, uint64_t identityValue, bool negate, bool secondHalf,
303303
CVariable* src, CVariable* dst, CVariable* flag = nullptr);
304-
CVariable* ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src);
304+
CVariable* ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src, CVariable* srcSecondHalf = nullptr);
305305
void ReductionExpandHelper(e_opcode op, VISA_Type type, CVariable* src, CVariable* dst);
306306
void ReductionClusteredSrcHelper(CVariable* (&pSrc)[2], CVariable* src, uint16_t numLanes,
307307
VISA_Type type, uint numInst, bool secondHalf);
@@ -325,6 +325,14 @@ class EmitPass : public llvm::FunctionPass
325325
const unsigned int clusterSize,
326326
CVariable* const src,
327327
CVariable* const dst);
328+
void emitReductionInterleave(
329+
const e_opcode op,
330+
const uint64_t identityValue,
331+
const VISA_Type type,
332+
const bool negate,
333+
const unsigned int step,
334+
CVariable* const src,
335+
CVariable* const dst);
328336
void emitPreOrPostFixOp(
329337
e_opcode op,
330338
uint64_t identityValue,
@@ -432,6 +440,7 @@ class EmitPass : public llvm::FunctionPass
432440
void emitQuadPrefix(llvm::QuadPrefixIntrinsic* I);
433441
void emitWaveAll(llvm::GenIntrinsicInst* inst);
434442
void emitWaveClustered(llvm::GenIntrinsicInst* inst);
443+
void emitWaveInterleave(llvm::GenIntrinsicInst* inst);
435444

436445
// Those three "vector" version shall be combined with
437446
// non-vector version.

IGC/Compiler/CISACodeGen/HalfPromotion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ void IGC::HalfPromotion::handleGenIntrinsic(llvm::GenIntrinsicInst& I)
113113
GenISAIntrinsic::ID id = I.getIntrinsicID();
114114
if (id == GenISAIntrinsic::GenISA_WaveAll ||
115115
id == GenISAIntrinsic::GenISA_WavePrefix ||
116-
id == GenISAIntrinsic::GenISA_WaveClustered)
116+
id == GenISAIntrinsic::GenISA_WaveClustered ||
117+
id == GenISAIntrinsic::GenISA_WaveInterleave)
117118
{
118119
Module* M = I.getParent()->getParent()->getParent();
119120
llvm::IGCIRBuilder<> builder(&I);

IGC/Compiler/CISACodeGen/PatternMatchPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,6 +1370,7 @@ namespace IGC
13701370
case GenISAIntrinsic::GenISA_WaveInverseBallot:
13711371
case GenISAIntrinsic::GenISA_WaveAll:
13721372
case GenISAIntrinsic::GenISA_WaveClustered:
1373+
case GenISAIntrinsic::GenISA_WaveInterleave:
13731374
case GenISAIntrinsic::GenISA_WavePrefix:
13741375
match = MatchWaveInstruction(*GII);
13751376
break;
@@ -5183,6 +5184,7 @@ namespace IGC
51835184
case GenISAIntrinsic::GenISA_WaveInverseBallot:
51845185
helperLaneIndex = 1;
51855186
break;
5187+
case GenISAIntrinsic::GenISA_WaveInterleave:
51865188
case GenISAIntrinsic::GenISA_WaveClustered:
51875189
helperLaneIndex = 3;
51885190
break;

IGC/Compiler/CISACodeGen/PromoteInt8Type.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,7 @@ void PromoteInt8Type::promoteIntrinsic()
11341134
else if (
11351135
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveAll) ||
11361136
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveClustered) ||
1137+
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveInterleave) ||
11371138
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WavePrefix) ||
11381139
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_QuadPrefix))
11391140
{
@@ -1158,6 +1159,7 @@ void PromoteInt8Type::promoteIntrinsic()
11581159
GenISAIntrinsic::ID gid = GII->getIntrinsicID();
11591160
if (gid == GenISAIntrinsic::GenISA_WaveAll ||
11601161
gid == GenISAIntrinsic::GenISA_WaveClustered ||
1162+
gid == GenISAIntrinsic::GenISA_WaveInterleave ||
11611163
gid == GenISAIntrinsic::GenISA_WavePrefix ||
11621164
gid == GenISAIntrinsic::GenISA_QuadPrefix ||
11631165
gid == GenISAIntrinsic::GenISA_WaveShuffleIndex ||
@@ -1199,9 +1201,11 @@ void PromoteInt8Type::promoteIntrinsic()
11991201
break;
12001202
}
12011203
case GenISAIntrinsic::GenISA_WaveClustered:
1204+
case GenISAIntrinsic::GenISA_WaveInterleave:
12021205
{
12031206
// prototype:
12041207
// Ty <clustered> (Ty, char, int, int)
1208+
// Ty <interleave> (Ty, char, int, int)
12051209
iArgs.push_back(GII->getArgOperand(1));
12061210
iArgs.push_back(GII->getArgOperand(2));
12071211
iArgs.push_back(GII->getArgOperand(3));

IGC/Compiler/CISACodeGen/WIAnalysis.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,7 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
14151415
intrinsic_name == llvm_waveBallot ||
14161416
intrinsic_name == llvm_waveAll ||
14171417
intrinsic_name == llvm_waveClustered ||
1418+
intrinsic_name == llvm_waveInterleave ||
14181419
intrinsic_name == llvm_ld_ptr ||
14191420
intrinsic_name == llvm_ldlptr ||
14201421
(IGC_IS_FLAG_DISABLED(DisableUniformTypedAccess) && intrinsic_name == llvm_typed_read) ||
@@ -1718,6 +1719,11 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
17181719
}
17191720
}
17201721

1722+
if (intrinsic_name == llvm_waveInterleave)
1723+
{
1724+
return WIAnalysis::RANDOM;
1725+
}
1726+
17211727
if (intrinsic_name == llvm_URBRead ||
17221728
intrinsic_name == llvm_URBReadOutput)
17231729
{

IGC/Compiler/CISACodeGen/helper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,6 +1782,7 @@ namespace IGC
17821782
{
17831783
return (opcode == llvm_waveAll ||
17841784
opcode == llvm_waveClustered ||
1785+
opcode == llvm_waveInterleave ||
17851786
opcode == llvm_wavePrefix ||
17861787
opcode == llvm_waveShuffleIndex ||
17871788
opcode == llvm_waveBroadcast ||

IGC/Compiler/CISACodeGen/opCode.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ DECLARE_OPCODE(GenISA_pair_to_ptr, GenISAIntrinsic, llvm_pair_to_ptr, false, fal
283283
DECLARE_OPCODE(GenISA_WaveBallot, GenISAIntrinsic, llvm_waveBallot, false, false, false, false, false, false, false)
284284
DECLARE_OPCODE(GenISA_WaveAll, GenISAIntrinsic, llvm_waveAll, false, false, false, false, false, false, false)
285285
DECLARE_OPCODE(GenISA_WaveClustered, GenISAIntrinsic, llvm_waveClustered, false, false, false, false, false, false, false)
286+
DECLARE_OPCODE(GenISA_WaveInterleave, GenISAIntrinsic, llvm_waveInterleave, false, false, false, false, false, false, false)
286287
DECLARE_OPCODE(GenISA_WavePrefix, GenISAIntrinsic, llvm_wavePrefix, false, false, false, false, false, false, false)
287288
DECLARE_OPCODE(GenISA_QuadPrefix, GenISAIntrinsic, llvm_quadPrefix, false, false, false, false, false, false, false)
288289
DECLARE_OPCODE(GenISA_WaveShuffleIndex, GenISAIntrinsic, llvm_waveShuffleIndex, false, false, false, false, false, false, false)

IGC/Compiler/CodeGenPublicEnums.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ namespace IGC
191191
{
192192
GroupOperationScan,
193193
GroupOperationReduce,
194-
GroupOperationClusteredReduce
194+
GroupOperationClusteredReduce,
195+
GroupOperationInterleaveReduce
195196
};
196197

197198
enum SGVUsage

0 commit comments

Comments
 (0)