@@ -206,6 +206,20 @@ G4_INST::G4_INST(const IR_Builder &irb, G4_Predicate *prd, G4_opcode o,
206
206
initOperands ();
207
207
}
208
208
209
+ G4_INST::G4_INST (const IR_Builder &irb, G4_Predicate *prd, G4_opcode o,
210
+ G4_CondMod *m, G4_Sat s, G4_ExecSize size, G4_DstRegRegion *d,
211
+ G4_Operand *s0, G4_Operand *s1, G4_Operand *s2, G4_Operand *s3,
212
+ G4_Operand *s4, G4_InstOpts opt)
213
+ : op(o), dst(d), predicate(prd), mod(m), option(opt),
214
+ useInstList(irb.getAllocator()), defInstList(irb.getAllocator()),
215
+ sat(s ? true : false ), dead(false ), evenlySplitInst(false ),
216
+ doPostRA(false ), canBeAcc(false ), doNotDelete(false ), execSize(size),
217
+ builder(irb) {
218
+ vISA_ASSERT (isDpas (), " Currently only dpas variants support 5 srcs" );
219
+ srcs = {s0, s1, s2, s3, s4};
220
+ initOperands ();
221
+ }
222
+
209
223
G4_INST::G4_INST (const IR_Builder &irb, G4_Predicate *prd, G4_opcode o,
210
224
G4_CondMod *m, G4_Sat s, G4_ExecSize size, G4_DstRegRegion *d,
211
225
G4_Operand *s0, G4_Operand *s1, G4_Operand *s2, G4_Operand *s3,
@@ -2682,26 +2696,19 @@ bool G4_INST::goodTwoGRFDst(bool &evenSplitDst) const {
2682
2696
// propagation
2683
2697
bool G4_INST::isWARdep (G4_INST *inst) {
2684
2698
G4_Operand *msg0 = NULL ;
2685
- G4_Operand *src0_0 = inst->getSrc (0 );
2686
- G4_Operand *src0_1 = inst->getSrc (1 );
2687
- G4_Operand *src0_2 = inst->getSrc (2 );
2688
- G4_Operand *src0_3 = inst->getSrc (3 );
2689
2699
G4_Operand *implicitSrc0 = inst->getImplAccSrc ();
2690
2700
G4_Predicate *pred0 = inst->getPredicate ();
2691
2701
2692
2702
G4_Operand *dst1 = dst;
2693
2703
2694
2704
if (dst1 && !hasNULLDst ()) {
2695
2705
2696
- if ((src0_0 &&
2697
- src0_0->compareOperand (dst1, getBuilder ()) != Rel_disjoint) ||
2698
- (src0_1 &&
2699
- src0_1->compareOperand (dst1, getBuilder ()) != Rel_disjoint) ||
2700
- (src0_2 &&
2701
- src0_2->compareOperand (dst1, getBuilder ()) != Rel_disjoint) ||
2702
- (src0_3 &&
2703
- src0_3->compareOperand (dst1, getBuilder ()) != Rel_disjoint) ||
2704
- (msg0 && (msg0->compareOperand (dst1, getBuilder ()) != Rel_disjoint)) ||
2706
+ if (std::any_of (inst->src_begin (), inst->src_end (), [&](G4_Operand *src) {
2707
+ return src->compareOperand (dst1, getBuilder ()) != Rel_disjoint;
2708
+ }))
2709
+ return true ;
2710
+
2711
+ if ((msg0 && (msg0->compareOperand (dst1, getBuilder ()) != Rel_disjoint)) ||
2705
2712
(pred0 &&
2706
2713
(pred0->compareOperand (dst1, getBuilder ()) != Rel_disjoint)) ||
2707
2714
(implicitSrc0 &&
@@ -2711,29 +2718,27 @@ bool G4_INST::isWARdep(G4_INST *inst) {
2711
2718
}
2712
2719
2713
2720
if (mod) {
2714
- if (( pred0 && pred0->compareOperand (mod, getBuilder ()) != Rel_disjoint) ||
2715
- (src0_0 && src0_0-> isFlag () &&
2716
- src0_0-> compareOperand (mod, getBuilder ()) != Rel_disjoint) ||
2717
- (src0_1 && src0_1-> isFlag () &&
2718
- src0_1-> compareOperand (mod, getBuilder ()) != Rel_disjoint) ||
2719
- (src0_2 && src0_2-> isFlag () &&
2720
- src0_2-> compareOperand (mod, getBuilder ()) != Rel_disjoint)) {
2721
+ if (pred0 && pred0->compareOperand (mod, getBuilder ()) != Rel_disjoint)
2722
+ return true ;
2723
+
2724
+ if ( std::any_of (inst-> src_begin (), inst-> src_end (), [&](G4_Operand *src) {
2725
+ return src-> isFlag () &&
2726
+ src-> compareOperand (mod, getBuilder ()) != Rel_disjoint;
2727
+ }))
2721
2728
return true ;
2722
- }
2723
2729
}
2724
2730
2725
2731
auto implAccDst = getImplAccDst ();
2726
2732
if (implAccDst) {
2727
- if (( implicitSrc0 && implicitSrc0->compareOperand (
2728
- implAccDst, getBuilder ()) != Rel_disjoint) ||
2729
- (src0_0 && src0_0-> isAccReg () &&
2730
- src0_0-> compareOperand (implAccDst, getBuilder ()) != Rel_disjoint) ||
2731
- (src0_1 && src0_1-> isAccReg () &&
2732
- src0_1-> compareOperand (implAccDst, getBuilder ()) != Rel_disjoint) ||
2733
- (src0_2 && src0_2-> isAccReg () &&
2734
- src0_2-> compareOperand (implAccDst, getBuilder ()) != Rel_disjoint)) {
2733
+ if (implicitSrc0 && implicitSrc0->compareOperand (
2734
+ implAccDst, getBuilder ()) != Rel_disjoint)
2735
+ return true ;
2736
+
2737
+ if ( std::any_of (inst-> src_begin (), inst-> src_end (), [&](G4_Operand *src) {
2738
+ return src-> isAccReg () &&
2739
+ src-> compareOperand (implAccDst, getBuilder ()) != Rel_disjoint;
2740
+ }))
2735
2741
return true ;
2736
- }
2737
2742
}
2738
2743
return false ;
2739
2744
}
@@ -2783,52 +2788,46 @@ bool G4_INST::isRAWdep(G4_INST *inst) {
2783
2788
G4_CondMod *cMod0 = inst->getCondMod ();
2784
2789
G4_Operand *implicitDst0 = inst->getImplAccDst ();
2785
2790
G4_Predicate *pred1 = getPredicate ();
2786
- G4_Operand *src1_0 = getSrc (0 );
2787
- G4_Operand *src1_1 = getSrc (1 );
2788
- G4_Operand *src1_2 = getSrc (2 );
2789
- G4_Operand *src1_3 = getSrc (3 );
2790
2791
G4_Operand *implicitSrc1 = getImplAccSrc ();
2791
2792
2792
- bool NULLSrc1 = (opcode () == G4_math && src1_1->isNullReg ());
2793
2793
if (dst0 && !inst->hasNULLDst ()) {
2794
- if ((src1_0 &&
2795
- src1_0->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2796
- (src1_1 && !NULLSrc1 &&
2797
- src1_1->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2798
- (src1_2 &&
2799
- src1_2->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2800
- (src1_3 &&
2801
- src1_3->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2802
- (pred1 && pred1->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2794
+ if (std::any_of (src_begin (), src_end (), [&](G4_Operand *src) {
2795
+ // TODO: check if we can remove the null src1 check for math as
2796
+ // compareOperand should handle NullReg already.
2797
+ if (opcode () == G4_math && src == getSrc (1 ) && src->isNullReg ())
2798
+ return false ;
2799
+ return src->compareOperand (dst0, getBuilder ()) != Rel_disjoint;
2800
+ }))
2801
+ return true ;
2802
+
2803
+ if ((pred1 && pred1->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2803
2804
(implicitSrc1 &&
2804
2805
implicitSrc1->compareOperand (dst0, getBuilder ()) != Rel_disjoint)) {
2805
2806
return true ;
2806
2807
}
2807
2808
}
2808
2809
2809
2810
if (cMod0 && cMod0->getBase ()) {
2810
- if (( pred1 && pred1->compareOperand (cMod0, getBuilder ()) != Rel_disjoint) ||
2811
- (src1_0 && src1_0-> isFlag () &&
2812
- src1_0-> compareOperand (cMod0, getBuilder ()) != Rel_disjoint) ||
2813
- (src1_2 && src1_2-> isFlag () &&
2814
- src1_2-> compareOperand (cMod0, getBuilder ()) != Rel_disjoint) ||
2815
- (src1_1 && src1_1-> isFlag () &&
2816
- src1_1-> compareOperand (cMod0, getBuilder ()) != Rel_disjoint)) {
2811
+ if (pred1 && pred1->compareOperand (cMod0, getBuilder ()) != Rel_disjoint)
2812
+ return true ;
2813
+
2814
+ if ( std::any_of ( src_begin (), src_end (), [&](G4_Operand *src) {
2815
+ return src-> isFlag () &&
2816
+ src-> compareOperand (cMod0, getBuilder ()) != Rel_disjoint;
2817
+ }))
2817
2818
return true ;
2818
- }
2819
2819
}
2820
2820
2821
2821
if (implicitDst0) {
2822
- if (( implicitSrc1 && implicitSrc1->compareOperand (
2823
- implicitDst0, getBuilder ()) != Rel_disjoint) ||
2824
- (src1_0 && src1_0-> isAccReg () &&
2825
- src1_0-> compareOperand (implicitDst0, getBuilder ()) != Rel_disjoint) ||
2826
- (src1_2 && src1_2-> isAccReg () &&
2827
- src1_2-> compareOperand (implicitDst0, getBuilder ()) != Rel_disjoint) ||
2828
- (src1_1 && src1_1-> isAccReg () &&
2829
- src1_1-> compareOperand (implicitDst0, getBuilder ()) != Rel_disjoint)) {
2822
+ if (implicitSrc1 && implicitSrc1->compareOperand (
2823
+ implicitDst0, getBuilder ()) != Rel_disjoint)
2824
+ return true ;
2825
+
2826
+ if ( std::any_of ( src_begin (), src_end (), [&](G4_Operand *src) {
2827
+ return src-> isAccReg () &&
2828
+ src-> compareOperand (implicitDst0, getBuilder ()) != Rel_disjoint;
2829
+ }))
2830
2830
return true ;
2831
- }
2832
2831
}
2833
2832
return false ;
2834
2833
}
@@ -7837,10 +7836,10 @@ G4_INST *G4_InstDpas::cloneInst(const IR_Builder *b) {
7837
7836
auto src1 = nonConstBuilder->duplicateOperand (getSrc (1 ));
7838
7837
auto src2 = nonConstBuilder->duplicateOperand (getSrc (2 ));
7839
7838
auto src3 = nonConstBuilder->duplicateOperand (getSrc (3 ));
7839
+ auto src4 = nonConstBuilder->duplicateOperand (getSrc (4 ));
7840
7840
return nonConstBuilder->createInternalDpasInst (
7841
- op, getExecSize (), dst, src0, src1, src2, src3, option,
7842
- getSrc2Precision (), getSrc1Precision (), getSystolicDepth (),
7843
- getRepeatCount ());
7841
+ op, getExecSize (), dst, src0, src1, src2, option, getSrc2Precision (),
7842
+ getSrc1Precision (), getSystolicDepth (), getRepeatCount (), src3, src4);
7844
7843
}
7845
7844
7846
7845
bool G4_InstDpas::isInt () const {
@@ -7894,9 +7893,9 @@ uint8_t G4_InstDpas::getOpsPerChan() const {
7894
7893
void G4_InstDpas::computeRightBound (G4_Operand *opnd) {
7895
7894
associateOpndWithInst (opnd, this );
7896
7895
if (opnd && !opnd->isImm () && !opnd->isNullReg ()) {
7897
- G4_InstDpas *dpasInst = asDpasInst ();
7898
- uint8_t D = dpasInst-> getSystolicDepth ();
7899
- uint8_t C = dpasInst-> getRepeatCount ();
7896
+ uint8_t D = getSystolicDepth ();
7897
+ uint8_t C = getRepeatCount ();
7898
+ G4_ExecSize ES = getExecSize ();
7900
7899
7901
7900
auto computeDpasOperandBound = [this ](G4_Operand *opnd, unsigned leftBound,
7902
7901
unsigned rightBound) {
@@ -7908,24 +7907,24 @@ void G4_InstDpas::computeRightBound(G4_Operand *opnd) {
7908
7907
if (opnd == dst || (opnd == srcs[0 ] && !opnd->isNullReg ())) {
7909
7908
// dst and src0 are always packed, and RB is exec_size * type_size *
7910
7909
// rep_count
7911
- auto opndSize = dpasInst-> getExecSize () * opnd->getTypeSize () * C;
7910
+ auto opndSize = ES * opnd->getTypeSize () * C;
7912
7911
computeDpasOperandBound (opnd, opnd->left_bound ,
7913
7912
opnd->left_bound + opndSize - 1 );
7914
7913
} else if (opnd == srcs[1 ]) {
7915
- uint32_t bytesPerLane = dpasInst-> getSrc1SizePerLaneInByte ();
7914
+ uint32_t bytesPerLane = getSrc1SizePerLaneInByte ();
7916
7915
uint8_t src1_D = D;
7917
7916
7918
7917
// Each lanes needs (src1_D * bytesPerLane) bytes, and it's multiple of
7919
7918
// DW!
7920
7919
uint32_t bytesPerLaneForAllDepth = bytesPerLane * src1_D;
7921
7920
bytesPerLaneForAllDepth = ((bytesPerLaneForAllDepth + 3 ) / 4 ) * 4 ;
7922
7921
7923
- uint32_t bytes = bytesPerLaneForAllDepth * dpasInst-> getExecSize () ;
7922
+ uint32_t bytes = bytesPerLaneForAllDepth * ES ;
7924
7923
computeDpasOperandBound (opnd, opnd->left_bound ,
7925
7924
opnd->left_bound + bytes - 1 );
7926
7925
} else if (opnd == srcs[2 ]) {
7927
7926
// src2 is uniform.
7928
- uint32_t bytesPerLane = dpasInst-> getSrc2SizePerLaneInByte ();
7927
+ uint32_t bytesPerLane = getSrc2SizePerLaneInByte ();
7929
7928
uint32_t bytes = bytesPerLane * D * C;
7930
7929
if (op == G4_dpasw) {
7931
7930
bytes = bytesPerLane * D * ((C + 1 ) / 2 );
@@ -7936,10 +7935,10 @@ void G4_InstDpas::computeRightBound(G4_Operand *opnd) {
7936
7935
7937
7936
else if (opnd && opnd == srcs[3 ]) {
7938
7937
uint32_t bytes;
7939
- if (dpasInst-> isInt ())
7938
+ if (isInt ())
7940
7939
{
7941
7940
bytes = 2 * getBuilder ().getGRFSize ();
7942
- } else if (dpasInst-> isFP16 () || dpasInst-> isBF16 ()) {
7941
+ } else if (isFP16 () || isBF16 ()) {
7943
7942
bytes = getBuilder ().getGRFSize ();
7944
7943
} else { // isTF32()
7945
7944
bytes = getBuilder ().getGRFSize () / 2 ;
0 commit comments