Skip to content

Commit 86502fa

Browse files
pkwasnie-intel, igcbot
authored and committed
New GenISA intrinsic: WaveInterleave
Adds a new GenISA intrinsic, WaveInterleave, that performs a subgroup reduction over every n-th work item. For example, for SIMD8 and an interleave step of 2, the result is one reduction over work items 0, 2, 4, 6 and a separate reduction over work items 1, 3, 5, 7. The change also includes a pattern match for interleave reductions implemented with subgroup shuffles.
1 parent dff1024 commit 86502fa

File tree

13 files changed

+460
-38
lines changed

13 files changed

+460
-38
lines changed

IGC/Compiler/CISACodeGen/CheckInstrTypes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ void CheckInstrTypes::visitCallInst(CallInst& C)
334334
case GenISAIntrinsic::GenISA_WaveInverseBallot:
335335
case GenISAIntrinsic::GenISA_WavePrefix:
336336
case GenISAIntrinsic::GenISA_WaveClustered:
337+
case GenISAIntrinsic::GenISA_WaveInterleave:
337338
case GenISAIntrinsic::GenISA_QuadPrefix:
338339
case GenISAIntrinsic::GenISA_simdShuffleDown:
339340
case GenISAIntrinsic::GenISA_simdShuffleXor:

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 113 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8905,6 +8905,9 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
89058905
case GenISAIntrinsic::GenISA_WaveAll:
89068906
emitWaveAll(inst);
89078907
break;
8908+
case GenISAIntrinsic::GenISA_WaveInterleave:
8909+
emitWaveInterleave(inst);
8910+
break;
89088911
case GenISAIntrinsic::GenISA_WaveClustered:
89098912
emitWaveClustered(inst);
89108913
break;
@@ -13167,8 +13170,45 @@ CVariable* EmitPass::ScanReducePrepareSrc(VISA_Type type, uint64_t identityValue
1316713170
}
1316813171

1316913172
// Reduction all reduce helper: dst_lane{k} = src_lane{simd + k} OP src_lane{k}, k = 0..(simd-1)
13170-
CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src)
13173+
CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src, CVariable* srcSecondHalf)
1317113174
{
13175+
const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
13176+
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
13177+
13178+
if (simd == SIMDMode::SIMD16 && m_currShader->m_numberInstance > 1)
13179+
{
13180+
IGC_ASSERT(srcSecondHalf);
13181+
13182+
CVariable* temp = m_currShader->GetNewVariable(
13183+
numLanes(simd),
13184+
type,
13185+
EALIGN_GRF,
13186+
false,
13187+
CName("reduceDstSecondHalf"));
13188+
13189+
if (!int64EmulationNeeded)
13190+
{
13191+
m_encoder->SetNoMask();
13192+
m_encoder->SetSimdSize(simd);
13193+
m_encoder->GenericAlu(op, temp, src, srcSecondHalf);
13194+
m_encoder->Push();
13195+
}
13196+
else
13197+
{
13198+
if (isInt64Mul)
13199+
{
13200+
CVariable* tmpMulSrc[2] = { src, srcSecondHalf };
13201+
Mul64(temp, tmpMulSrc, simd, true /* noMask */);
13202+
}
13203+
else
13204+
{
13205+
IGC_ASSERT_MESSAGE(0, "Unsupported");
13206+
}
13207+
}
13208+
13209+
return temp;
13210+
}
13211+
1317213212
const bool is64bitType = ScanReduceIs64BitType(type);
1317313213
const auto alignment = is64bitType ? IGC::EALIGN_QWORD : IGC::EALIGN_DWORD;
1317413214
CVariable* temp = m_currShader->GetNewVariable(
@@ -13178,9 +13218,6 @@ CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode
1317813218
false,
1317913219
CName("reduceDst_SIMD", std::to_string(numLanes(simd)).c_str()));
1318013220

13181-
const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
13182-
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
13183-
1318413221
if (!int64EmulationNeeded)
1318513222
{
1318613223
m_encoder->SetNoMask();
@@ -13546,34 +13583,7 @@ void EmitPass::emitReductionAll(
1354613583
CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
1354713584
src, nullptr /* dst */);
1354813585

13549-
temp = m_currShader->GetNewVariable(
13550-
numLanes(simd),
13551-
type,
13552-
EALIGN_GRF,
13553-
false,
13554-
CName("reduceDstSecondHalf"));
13555-
13556-
const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
13557-
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
13558-
if (!int64EmulationNeeded)
13559-
{
13560-
m_encoder->SetNoMask();
13561-
m_encoder->SetSimdSize(simd);
13562-
m_encoder->GenericAlu(op, temp, srcH1, srcH2);
13563-
m_encoder->Push();
13564-
}
13565-
else
13566-
{
13567-
if (isInt64Mul)
13568-
{
13569-
CVariable* tmpMulSrc[2] = { srcH1, srcH2 };
13570-
Mul64(temp, tmpMulSrc, simd, true /* noMask */);
13571-
}
13572-
else
13573-
{
13574-
IGC_ASSERT_MESSAGE(0, "Unsupported");
13575-
}
13576-
}
13586+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp, srcH2);
1357713587
}
1357813588
}
1357913589
if (m_currShader->m_dispatchSize >= SIMDMode::SIMD16)
@@ -13723,6 +13733,54 @@ void EmitPass::emitReductionClustered(const e_opcode op, const uint64_t identity
1372313733
}
1372413734
}
1372513735

13736+
void EmitPass::emitReductionInterleave(const e_opcode op, const uint64_t identityValue, const VISA_Type type,
13737+
const bool negate, const unsigned int step, CVariable* const src, CVariable* const dst)
13738+
{
13739+
if (step == 1)
13740+
{
13741+
// TODO: consider if it is possible to detect and handle this case in frontends
13742+
// and emit GenISA_WaveAll there, to enable optimizations specific to the ReduceAll intrinsic.
13743+
return emitReductionAll(op, identityValue, type, negate, src, dst);
13744+
}
13745+
13746+
const uint16_t firstStep = numLanes(m_currShader->m_dispatchSize) / 2;
13747+
13748+
IGC_ASSERT_MESSAGE(!dst->IsUniform(), "Unsupported: dst must be non-uniform");
13749+
IGC_ASSERT_MESSAGE(step % 2 == 0 && step <= firstStep, "Invalid reduction interleave step");
13750+
13751+
CVariable* srcH1 = ScanReducePrepareSrc(type, identityValue, negate, false /* secondHalf */,
13752+
src, nullptr /* dst */);
13753+
CVariable* temp = srcH1;
13754+
13755+
// Implementation is similar to emitReductionAll(), but we stop reduction before reaching SIMD1.
13756+
for (unsigned int currentStep = firstStep; currentStep >= step; currentStep >>= 1)
13757+
{
13758+
if (currentStep == 16 && m_currShader->m_numberInstance > 1)
13759+
{
13760+
CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
13761+
src, nullptr /* dst */);
13762+
13763+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp, srcH2);
13764+
}
13765+
else
13766+
{
13767+
temp = ReductionReduceHelper(op, type, lanesToSIMDMode(currentStep), temp);
13768+
}
13769+
}
13770+
13771+
// Broadcast result
13772+
m_encoder->SetSimdSize(m_currShader->m_SIMDSize);
13773+
m_encoder->SetSrcRegion(0, 0, step, 1);
13774+
m_encoder->Copy(dst, temp);
13775+
if (m_currShader->m_numberInstance > 1)
13776+
{
13777+
m_encoder->SetSecondHalf(true);
13778+
m_encoder->Copy(dst, temp);
13779+
m_encoder->SetSecondHalf(false);
13780+
}
13781+
m_encoder->Push();
13782+
}
13783+
1372613784
// do prefix op across all activate channels
1372713785
void EmitPass::emitPreOrPostFixOp(
1372813786
e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
@@ -21141,6 +21199,29 @@ void EmitPass::emitWaveClustered(llvm::GenIntrinsicInst* inst)
2114121199
}
2114221200
}
2114321201

21202+
void EmitPass::emitWaveInterleave(llvm::GenIntrinsicInst* inst)
21203+
{
21204+
bool disableHelperLanes = int_cast<int>(cast<ConstantInt>(inst->getArgOperand(3))->getSExtValue()) == 2;
21205+
if (disableHelperLanes)
21206+
{
21207+
ForceDMask();
21208+
}
21209+
CVariable* src = GetSymbol(inst->getOperand(0));
21210+
const WaveOps op = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
21211+
const unsigned int step = int_cast<uint32_t>(cast<llvm::ConstantInt>(inst->getOperand(2))->getZExtValue());
21212+
VISA_Type type;
21213+
e_opcode opCode;
21214+
uint64_t identity = 0;
21215+
GetReductionOp(op, inst->getOperand(0)->getType(), identity, opCode, type);
21216+
CVariable* dst = m_destination;
21217+
m_encoder->SetSubSpanDestination(false);
21218+
emitReductionInterleave(opCode, identity, type, false, step, src, dst);
21219+
if (disableHelperLanes)
21220+
{
21221+
ResetVMask();
21222+
}
21223+
}
21224+
2114421225
void EmitPass::emitDP4A(GenIntrinsicInst* GII, const SSource* Sources, const DstModifier& modifier, bool isAccSigned) {
2114521226
GenISAIntrinsic::ID GIID = GII->getIntrinsicID();
2114621227
CVariable* dst = m_destination;

IGC/Compiler/CISACodeGen/EmitVISAPass.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class EmitPass : public llvm::FunctionPass
301301
bool ScanReduceIsInt64EmulationNeeded(e_opcode op, VISA_Type type);
302302
CVariable* ScanReducePrepareSrc(VISA_Type type, uint64_t identityValue, bool negate, bool secondHalf,
303303
CVariable* src, CVariable* dst, CVariable* flag = nullptr);
304-
CVariable* ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src);
304+
CVariable* ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src, CVariable* srcSecondHalf = nullptr);
305305
void ReductionExpandHelper(e_opcode op, VISA_Type type, CVariable* src, CVariable* dst);
306306
void ReductionClusteredSrcHelper(CVariable* (&pSrc)[2], CVariable* src, uint16_t numLanes,
307307
VISA_Type type, uint numInst, bool secondHalf);
@@ -325,6 +325,14 @@ class EmitPass : public llvm::FunctionPass
325325
const unsigned int clusterSize,
326326
CVariable* const src,
327327
CVariable* const dst);
328+
void emitReductionInterleave(
329+
const e_opcode op,
330+
const uint64_t identityValue,
331+
const VISA_Type type,
332+
const bool negate,
333+
const unsigned int step,
334+
CVariable* const src,
335+
CVariable* const dst);
328336
void emitPreOrPostFixOp(
329337
e_opcode op,
330338
uint64_t identityValue,
@@ -432,6 +440,7 @@ class EmitPass : public llvm::FunctionPass
432440
void emitQuadPrefix(llvm::QuadPrefixIntrinsic* I);
433441
void emitWaveAll(llvm::GenIntrinsicInst* inst);
434442
void emitWaveClustered(llvm::GenIntrinsicInst* inst);
443+
void emitWaveInterleave(llvm::GenIntrinsicInst* inst);
435444

436445
// Those three "vector" version shall be combined with
437446
// non-vector version.

IGC/Compiler/CISACodeGen/HalfPromotion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ void IGC::HalfPromotion::handleGenIntrinsic(llvm::GenIntrinsicInst& I)
113113
GenISAIntrinsic::ID id = I.getIntrinsicID();
114114
if (id == GenISAIntrinsic::GenISA_WaveAll ||
115115
id == GenISAIntrinsic::GenISA_WavePrefix ||
116-
id == GenISAIntrinsic::GenISA_WaveClustered)
116+
id == GenISAIntrinsic::GenISA_WaveClustered ||
117+
id == GenISAIntrinsic::GenISA_WaveInterleave)
117118
{
118119
Module* M = I.getParent()->getParent()->getParent();
119120
llvm::IGCIRBuilder<> builder(&I);

IGC/Compiler/CISACodeGen/PatternMatchPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,6 +1370,7 @@ namespace IGC
13701370
case GenISAIntrinsic::GenISA_WaveInverseBallot:
13711371
case GenISAIntrinsic::GenISA_WaveAll:
13721372
case GenISAIntrinsic::GenISA_WaveClustered:
1373+
case GenISAIntrinsic::GenISA_WaveInterleave:
13731374
case GenISAIntrinsic::GenISA_WavePrefix:
13741375
match = MatchWaveInstruction(*GII);
13751376
break;
@@ -5183,6 +5184,7 @@ namespace IGC
51835184
case GenISAIntrinsic::GenISA_WaveInverseBallot:
51845185
helperLaneIndex = 1;
51855186
break;
5187+
case GenISAIntrinsic::GenISA_WaveInterleave:
51865188
case GenISAIntrinsic::GenISA_WaveClustered:
51875189
helperLaneIndex = 3;
51885190
break;

IGC/Compiler/CISACodeGen/PromoteInt8Type.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,7 @@ void PromoteInt8Type::promoteIntrinsic()
11341134
else if (
11351135
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveAll) ||
11361136
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveClustered) ||
1137+
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveInterleave) ||
11371138
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WavePrefix) ||
11381139
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_QuadPrefix))
11391140
{
@@ -1158,6 +1159,7 @@ void PromoteInt8Type::promoteIntrinsic()
11581159
GenISAIntrinsic::ID gid = GII->getIntrinsicID();
11591160
if (gid == GenISAIntrinsic::GenISA_WaveAll ||
11601161
gid == GenISAIntrinsic::GenISA_WaveClustered ||
1162+
gid == GenISAIntrinsic::GenISA_WaveInterleave ||
11611163
gid == GenISAIntrinsic::GenISA_WavePrefix ||
11621164
gid == GenISAIntrinsic::GenISA_QuadPrefix ||
11631165
gid == GenISAIntrinsic::GenISA_WaveShuffleIndex ||
@@ -1199,9 +1201,11 @@ void PromoteInt8Type::promoteIntrinsic()
11991201
break;
12001202
}
12011203
case GenISAIntrinsic::GenISA_WaveClustered:
1204+
case GenISAIntrinsic::GenISA_WaveInterleave:
12021205
{
12031206
// prototype:
12041207
// Ty <clustered> (Ty, char, int, int)
1208+
// Ty <interleave> (Ty, char, int, int)
12051209
iArgs.push_back(GII->getArgOperand(1));
12061210
iArgs.push_back(GII->getArgOperand(2));
12071211
iArgs.push_back(GII->getArgOperand(3));

IGC/Compiler/CISACodeGen/WIAnalysis.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,7 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
14151415
intrinsic_name == llvm_waveBallot ||
14161416
intrinsic_name == llvm_waveAll ||
14171417
intrinsic_name == llvm_waveClustered ||
1418+
intrinsic_name == llvm_waveInterleave ||
14181419
intrinsic_name == llvm_ld_ptr ||
14191420
intrinsic_name == llvm_ldlptr ||
14201421
(IGC_IS_FLAG_DISABLED(DisableUniformTypedAccess) && intrinsic_name == llvm_typed_read) ||
@@ -1718,6 +1719,11 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
17181719
}
17191720
}
17201721

1722+
if (intrinsic_name == llvm_waveInterleave)
1723+
{
1724+
return WIAnalysis::RANDOM;
1725+
}
1726+
17211727
if (intrinsic_name == llvm_URBRead ||
17221728
intrinsic_name == llvm_URBReadOutput)
17231729
{

IGC/Compiler/CISACodeGen/helper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,6 +1782,7 @@ namespace IGC
17821782
{
17831783
return (opcode == llvm_waveAll ||
17841784
opcode == llvm_waveClustered ||
1785+
opcode == llvm_waveInterleave ||
17851786
opcode == llvm_wavePrefix ||
17861787
opcode == llvm_waveShuffleIndex ||
17871788
opcode == llvm_waveBroadcast ||

IGC/Compiler/CISACodeGen/opCode.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ DECLARE_OPCODE(GenISA_pair_to_ptr, GenISAIntrinsic, llvm_pair_to_ptr, false, fal
283283
DECLARE_OPCODE(GenISA_WaveBallot, GenISAIntrinsic, llvm_waveBallot, false, false, false, false, false, false, false)
284284
DECLARE_OPCODE(GenISA_WaveAll, GenISAIntrinsic, llvm_waveAll, false, false, false, false, false, false, false)
285285
DECLARE_OPCODE(GenISA_WaveClustered, GenISAIntrinsic, llvm_waveClustered, false, false, false, false, false, false, false)
286+
DECLARE_OPCODE(GenISA_WaveInterleave, GenISAIntrinsic, llvm_waveInterleave, false, false, false, false, false, false, false)
286287
DECLARE_OPCODE(GenISA_WavePrefix, GenISAIntrinsic, llvm_wavePrefix, false, false, false, false, false, false, false)
287288
DECLARE_OPCODE(GenISA_QuadPrefix, GenISAIntrinsic, llvm_quadPrefix, false, false, false, false, false, false, false)
288289
DECLARE_OPCODE(GenISA_WaveShuffleIndex, GenISAIntrinsic, llvm_waveShuffleIndex, false, false, false, false, false, false, false)

IGC/Compiler/CodeGenPublicEnums.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ namespace IGC
191191
{
192192
GroupOperationScan,
193193
GroupOperationReduce,
194-
GroupOperationClusteredReduce
194+
GroupOperationClusteredReduce,
195+
GroupOperationInterleaveReduce
195196
};
196197

197198
enum SGVUsage

0 commit comments

Comments
 (0)