Skip to content

Commit 3b8b434

Browse files
bowenxue-inteligcbot
authored andcommitted
Fix and enable WaveAllJointReduction by default
Fix bug with uniformity of destination register and enable WaveAllJointReduction by default
1 parent b55067a commit 3b8b434

File tree

4 files changed

+23
-26
lines changed

4 files changed

+23
-26
lines changed

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14188,10 +14188,19 @@ void EmitPass::emitReductionTree( e_opcode op, VISA_Type type, CVariable* src, C
1418814188
for( unsigned int i = 0; i < numIterations; i++ )
1418914189
{
1419014190
// Get alias for src0, src1, and dst based on offsets and SIMD size
14191-
auto* layerSrc0 = m_currShader->GetNewAlias( src, type, i * 2 * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );
14192-
auto* layerSrc1 = m_currShader->GetNewAlias( src, type, ( i * 2 * layerMaxSimdLanes + src1Offset ) * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );
14193-
auto* layerDst = m_currShader->GetNewAlias( src, type, i * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );
14194-
14191+
auto* layerSrc0 = m_currShader->GetNewAlias( src, type, i * 2 * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes, false );
14192+
auto* layerSrc1 = m_currShader->GetNewAlias( src, type, ( i * 2 * layerMaxSimdLanes + src1Offset ) * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes, false );
14193+
CVariable* layerDst;
14194+
if( (srcElementCount >> 1 <= dst->GetNumberElement()) && (i + 1 == numIterations ))
14195+
{
14196+
// Final layer, use destination of WaveAll vector intrinsic inst (passed in with correct offset)
14197+
layerDst = dst;
14198+
}
14199+
else
14200+
{
14201+
// Use src as workspace to store intermediate values
14202+
layerDst = m_currShader->GetNewAlias( src, type, i * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes, false );
14203+
}
1419514204
if( !int64EmulationNeeded )
1419614205
{
1419714206
m_encoder->SetNoMask();
@@ -14220,13 +14229,6 @@ void EmitPass::emitReductionTree( e_opcode op, VISA_Type type, CVariable* src, C
1422014229
srcElementCount >>= 1;
1422114230
reductionElementCount >>= 1;
1422214231
}
14223-
14224-
// copy fully reduced elements from src to dst
14225-
auto* finalLayerDst = m_currShader->GetNewAlias( src, type, 0, dst->GetNumberElement() );
14226-
m_encoder->SetNoMask();
14227-
m_encoder->SetSimdSize( lanesToSIMDMode( dst->GetNumberElement() ) );
14228-
m_encoder->Copy( dst, finalLayerDst );
14229-
m_encoder->Push();
1423014232
}
1423114233

1423214234
// Recursive function that emits one or more joint reduction trees based on the joint output width
@@ -14240,8 +14242,8 @@ void EmitPass::emitReductionTrees( e_opcode op, VISA_Type type, SIMDMode simdMod
1424014242
// Do full tree reduction
1424114243
unsigned int reductionElements = src->GetNumberElement() / dst->GetNumberElement();
1424214244
unsigned int groupReductionElementCount = reductionElements * simdLanes;
14243-
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, startIdx * reductionElements * m_encoder->GetCISADataTypeSize( type ), groupReductionElementCount );
14244-
CVariable* dstAlias = m_currShader->GetNewAlias( dst, type, startIdx * m_encoder->GetCISADataTypeSize( type ), simdLanes);
14245+
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, startIdx * reductionElements * m_encoder->GetCISADataTypeSize( type ), groupReductionElementCount, false );
14246+
CVariable* dstAlias = m_currShader->GetNewAlias( dst, type, startIdx * m_encoder->GetCISADataTypeSize( type ), simdLanes, false);
1424514247
emitReductionTree( op, type, srcAlias, dstAlias );
1424614248
// Start new recursive tree if any elements are left
1424714249
if ( numGroups > simdLanes )
@@ -22559,13 +22561,13 @@ void EmitPass::emitWaveAll(llvm::GenIntrinsicInst* inst)
2255922561
for( uint16_t i = 0; i < dst->GetNumberElement(); i++ )
2256022562
{
2256122563
// Prepare reduceSrc
22562-
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
22563-
CVariable* reduceSrcAlias = m_currShader->GetNewAlias( reduceSrc, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
22564+
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ), false);
22565+
CVariable* reduceSrcAlias = m_currShader->GetNewAlias( reduceSrc, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ), false );
2256422566
ScanReducePrepareSrc( type, identity, false, false, srcAlias, reduceSrcAlias );
2256522567

2256622568
// Prepare reduceSrcSecondHalf
22567-
CVariable* srcSecondHalfAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
22568-
CVariable* reduceSrcSecondHalfAlias = m_currShader->GetNewAlias( reduceSrcSecondHalf, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
22569+
CVariable* srcSecondHalfAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ), false );
22570+
CVariable* reduceSrcSecondHalfAlias = m_currShader->GetNewAlias( reduceSrcSecondHalf, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ), false);
2256922571
ScanReducePrepareSrc( type, identity, false, true, srcSecondHalfAlias, reduceSrcSecondHalfAlias );
2257022572

2257122573
// Emit correct operations

IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-dual-simd16-group4.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ define void @CSMain(i32 %runtime_value_0, i32 %runtime_value_1, i32 %runtime_val
8484
; layer 3
8585
; CHECK: add (M1_NM, 8) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<4;2,1> reduceSrc_waveAllSrc0(0,2)<4;2,1>
8686
; layer 4
87-
; CHECK: add (M1_NM, 4) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1>
88-
; copy to dest
89-
; CHECK: mov (M1_NM, 1) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<1;1,0>
87+
; CHECK: add (M1_NM, 4) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1>
9088
%waveAllJoint = call <4 x i32> @llvm.genx.GenISA.WaveAll.v4i32.i8.i32(<4 x i32> %waveAllSrc3, i8 0, i32 0)
9189
%res_a = extractelement <4 x i32> %waveAllJoint, i32 0
9290
%res_b = extractelement <4 x i32> %waveAllJoint, i32 1

IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-simd32-group17.ll

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,13 @@ define void @CSMain(i32 %runtime_value_0, i32 %runtime_value_1, i32 %runtime_val
144144
; layer 4
145145
; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<4;2,1> reduceSrc_waveAllSrc0(0,2)<4;2,1>
146146
; layer 5
147-
; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1>
148-
; copy to dest
149-
; CHECK: mov (M1_NM, 1) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<1;1,0>
147+
; CHECK: add (M1_NM, 16) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1>
150148
; Joint Reduction Tree (1-wide, leftover from splitting the 17-wide vector into 16 and 1, almost identical to existing non-joint reduction tree generated from scalar WaveAll intrinsic further below)
151149
; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<32;16,1> reduceSrc_waveAllSrc0(33,0)<32;16,1>
152150
; CHECK: add (M1_NM, 8) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<16;8,1> reduceSrc_waveAllSrc0(32,8)<16;8,1>
153151
; CHECK: add (M1_NM, 4) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<8;4,1> reduceSrc_waveAllSrc0(32,4)<8;4,1>
154152
; CHECK: add (M1_NM, 2) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<4;2,1> reduceSrc_waveAllSrc0(32,2)<4;2,1>
155-
; CHECK: add (M1_NM, 1) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<2;1,1> reduceSrc_waveAllSrc0(32,1)<2;1,1>
156-
; CHECK: mov (M1_NM, 1) waveAllJoint(1,0)<1> reduceSrc_waveAllSrc0(32,0)<1;1,0>
153+
; CHECK: add (M1_NM, 1) waveAllJoint(1,0)<1> reduceSrc_waveAllSrc0(32,0)<2;1,1> reduceSrc_waveAllSrc0(32,1)<2;1,1>
157154
%waveAllJoint = call <17 x i32> @llvm.genx.GenISA.WaveAll.v17i32.i8.i32(<17 x i32> %waveAllSrc16, i8 0, i32 0)
158155
%res_f = call i32 @llvm.genx.GenISA.WaveAll.i32.i8.i32(i32 %f, i8 0, i32 0)
159156
%res_add_0 = extractelement <17 x i32> %waveAllJoint, i32 0

IGC/common/igc_flags.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ DECLARE_IGC_REGKEY(bool, DisableLoopSplitWidePHIs, false, "Disable splitting of
331331
DECLARE_IGC_REGKEY(bool, EnableBarrierControlFlowOptimizationPass, false, "Enable barrier control flow optimization pass", false)
332332
DECLARE_IGC_REGKEY(bool, EnableWaveShuffleIndexSinking, true, "Hoist identical instructions operating on WaveShuffleIndex instructions with the same source and a constant lane/channel", false)
333333
DECLARE_IGC_REGKEY(DWORD, WaveShuffleIndexSinkingMaxIterations, 3, "Max number of iterations to run iterative WaveShuffleIndexSinking", false)
334-
DECLARE_IGC_REGKEY(bool, EnableWaveAllJointReduction, false, "Enable Joint Reduction Optimization.", false)
334+
DECLARE_IGC_REGKEY(bool, EnableWaveAllJointReduction, true, "Enable Joint Reduction Optimization.", false)
335335

336336
DECLARE_IGC_GROUP("Shader debugging")
337337
DECLARE_IGC_REGKEY(bool, CopyA0ToDBG0, false, " Copy a0 used for extended msg descriptor to dbg0 to help debug", false)

0 commit comments

Comments
 (0)