@@ -206,20 +206,6 @@ G4_INST::G4_INST(const IR_Builder &irb, G4_Predicate *prd, G4_opcode o,
206
206
initOperands ();
207
207
}
208
208
209
- G4_INST::G4_INST (const IR_Builder &irb, G4_Predicate *prd, G4_opcode o,
210
- G4_CondMod *m, G4_Sat s, G4_ExecSize size, G4_DstRegRegion *d,
211
- G4_Operand *s0, G4_Operand *s1, G4_Operand *s2, G4_Operand *s3,
212
- G4_Operand *s4, G4_InstOpts opt)
213
- : op(o), dst(d), predicate(prd), mod(m), option(opt),
214
- useInstList(irb.getAllocator()), defInstList(irb.getAllocator()),
215
- sat(s ? true : false ), dead(false ), evenlySplitInst(false ),
216
- doPostRA(false ), canBeAcc(false ), doNotDelete(false ), execSize(size),
217
- builder(irb) {
218
- vISA_ASSERT (isDpas (), " Currently only dpas variants support 5 srcs" );
219
- srcs = {s0, s1, s2, s3, s4};
220
- initOperands ();
221
- }
222
-
223
209
G4_INST::G4_INST (const IR_Builder &irb, G4_Predicate *prd, G4_opcode o,
224
210
G4_CondMod *m, G4_Sat s, G4_ExecSize size, G4_DstRegRegion *d,
225
211
G4_Operand *s0, G4_Operand *s1, G4_Operand *s2, G4_Operand *s3,
@@ -2696,19 +2682,26 @@ bool G4_INST::goodTwoGRFDst(bool &evenSplitDst) const {
2696
2682
// propagation
2697
2683
bool G4_INST::isWARdep (G4_INST *inst) {
2698
2684
G4_Operand *msg0 = NULL ;
2685
+ G4_Operand *src0_0 = inst->getSrc (0 );
2686
+ G4_Operand *src0_1 = inst->getSrc (1 );
2687
+ G4_Operand *src0_2 = inst->getSrc (2 );
2688
+ G4_Operand *src0_3 = inst->getSrc (3 );
2699
2689
G4_Operand *implicitSrc0 = inst->getImplAccSrc ();
2700
2690
G4_Predicate *pred0 = inst->getPredicate ();
2701
2691
2702
2692
G4_Operand *dst1 = dst;
2703
2693
2704
2694
if (dst1 && !hasNULLDst ()) {
2705
2695
2706
- if (std::any_of (inst->src_begin (), inst->src_end (), [&](G4_Operand *src) {
2707
- return src->compareOperand (dst1, getBuilder ()) != Rel_disjoint;
2708
- }))
2709
- return true ;
2710
-
2711
- if ((msg0 && (msg0->compareOperand (dst1, getBuilder ()) != Rel_disjoint)) ||
2696
+ if ((src0_0 &&
2697
+ src0_0->compareOperand (dst1, getBuilder ()) != Rel_disjoint) ||
2698
+ (src0_1 &&
2699
+ src0_1->compareOperand (dst1, getBuilder ()) != Rel_disjoint) ||
2700
+ (src0_2 &&
2701
+ src0_2->compareOperand (dst1, getBuilder ()) != Rel_disjoint) ||
2702
+ (src0_3 &&
2703
+ src0_3->compareOperand (dst1, getBuilder ()) != Rel_disjoint) ||
2704
+ (msg0 && (msg0->compareOperand (dst1, getBuilder ()) != Rel_disjoint)) ||
2712
2705
(pred0 &&
2713
2706
(pred0->compareOperand (dst1, getBuilder ()) != Rel_disjoint)) ||
2714
2707
(implicitSrc0 &&
@@ -2718,27 +2711,29 @@ bool G4_INST::isWARdep(G4_INST *inst) {
2718
2711
}
2719
2712
2720
2713
if (mod) {
2721
- if (pred0 && pred0->compareOperand (mod, getBuilder ()) != Rel_disjoint)
2722
- return true ;
2723
-
2724
- if ( std::any_of (inst-> src_begin (), inst-> src_end (), [&](G4_Operand *src) {
2725
- return src-> isFlag () &&
2726
- src-> compareOperand (mod, getBuilder ()) != Rel_disjoint;
2727
- }))
2714
+ if (( pred0 && pred0->compareOperand (mod, getBuilder ()) != Rel_disjoint) ||
2715
+ (src0_0 && src0_0-> isFlag () &&
2716
+ src0_0-> compareOperand (mod, getBuilder ()) != Rel_disjoint) ||
2717
+ (src0_1 && src0_1-> isFlag () &&
2718
+ src0_1-> compareOperand (mod, getBuilder ()) != Rel_disjoint) ||
2719
+ (src0_2 && src0_2-> isFlag () &&
2720
+ src0_2-> compareOperand (mod, getBuilder ()) != Rel_disjoint)) {
2728
2721
return true ;
2722
+ }
2729
2723
}
2730
2724
2731
2725
auto implAccDst = getImplAccDst ();
2732
2726
if (implAccDst) {
2733
- if (implicitSrc0 && implicitSrc0->compareOperand (
2734
- implAccDst, getBuilder ()) != Rel_disjoint)
2735
- return true ;
2736
-
2737
- if ( std::any_of (inst-> src_begin (), inst-> src_end (), [&](G4_Operand *src) {
2738
- return src-> isAccReg () &&
2739
- src-> compareOperand (implAccDst, getBuilder ()) != Rel_disjoint;
2740
- }))
2727
+ if (( implicitSrc0 && implicitSrc0->compareOperand (
2728
+ implAccDst, getBuilder ()) != Rel_disjoint) ||
2729
+ (src0_0 && src0_0-> isAccReg () &&
2730
+ src0_0-> compareOperand (implAccDst, getBuilder ()) != Rel_disjoint) ||
2731
+ (src0_1 && src0_1-> isAccReg () &&
2732
+ src0_1-> compareOperand (implAccDst, getBuilder ()) != Rel_disjoint) ||
2733
+ (src0_2 && src0_2-> isAccReg () &&
2734
+ src0_2-> compareOperand (implAccDst, getBuilder ()) != Rel_disjoint)) {
2741
2735
return true ;
2736
+ }
2742
2737
}
2743
2738
return false ;
2744
2739
}
@@ -2788,46 +2783,52 @@ bool G4_INST::isRAWdep(G4_INST *inst) {
2788
2783
G4_CondMod *cMod0 = inst->getCondMod ();
2789
2784
G4_Operand *implicitDst0 = inst->getImplAccDst ();
2790
2785
G4_Predicate *pred1 = getPredicate ();
2786
+ G4_Operand *src1_0 = getSrc (0 );
2787
+ G4_Operand *src1_1 = getSrc (1 );
2788
+ G4_Operand *src1_2 = getSrc (2 );
2789
+ G4_Operand *src1_3 = getSrc (3 );
2791
2790
G4_Operand *implicitSrc1 = getImplAccSrc ();
2792
2791
2792
+ bool NULLSrc1 = (opcode () == G4_math && src1_1->isNullReg ());
2793
2793
if (dst0 && !inst->hasNULLDst ()) {
2794
- if (std::any_of (src_begin (), src_end (), [&](G4_Operand *src) {
2795
- // TODO: check if we can remove the null src1 check for math as
2796
- // compareOperand should handle NullReg already.
2797
- if (opcode () == G4_math && src == getSrc (1 ) && src->isNullReg ())
2798
- return false ;
2799
- return src->compareOperand (dst0, getBuilder ()) != Rel_disjoint;
2800
- }))
2801
- return true ;
2802
-
2803
- if ((pred1 && pred1->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2794
+ if ((src1_0 &&
2795
+ src1_0->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2796
+ (src1_1 && !NULLSrc1 &&
2797
+ src1_1->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2798
+ (src1_2 &&
2799
+ src1_2->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2800
+ (src1_3 &&
2801
+ src1_3->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2802
+ (pred1 && pred1->compareOperand (dst0, getBuilder ()) != Rel_disjoint) ||
2804
2803
(implicitSrc1 &&
2805
2804
implicitSrc1->compareOperand (dst0, getBuilder ()) != Rel_disjoint)) {
2806
2805
return true ;
2807
2806
}
2808
2807
}
2809
2808
2810
2809
if (cMod0 && cMod0->getBase ()) {
2811
- if (pred1 && pred1->compareOperand (cMod0, getBuilder ()) != Rel_disjoint)
2812
- return true ;
2813
-
2814
- if ( std::any_of ( src_begin (), src_end (), [&](G4_Operand *src) {
2815
- return src-> isFlag () &&
2816
- src-> compareOperand (cMod0, getBuilder ()) != Rel_disjoint;
2817
- }))
2810
+ if (( pred1 && pred1->compareOperand (cMod0, getBuilder ()) != Rel_disjoint) ||
2811
+ (src1_0 && src1_0-> isFlag () &&
2812
+ src1_0-> compareOperand (cMod0, getBuilder ()) != Rel_disjoint) ||
2813
+ (src1_2 && src1_2-> isFlag () &&
2814
+ src1_2-> compareOperand (cMod0, getBuilder ()) != Rel_disjoint) ||
2815
+ (src1_1 && src1_1-> isFlag () &&
2816
+ src1_1-> compareOperand (cMod0, getBuilder ()) != Rel_disjoint)) {
2818
2817
return true ;
2818
+ }
2819
2819
}
2820
2820
2821
2821
if (implicitDst0) {
2822
- if (implicitSrc1 && implicitSrc1->compareOperand (
2823
- implicitDst0, getBuilder ()) != Rel_disjoint)
2824
- return true ;
2825
-
2826
- if ( std::any_of ( src_begin (), src_end (), [&](G4_Operand *src) {
2827
- return src-> isAccReg () &&
2828
- src-> compareOperand (implicitDst0, getBuilder ()) != Rel_disjoint;
2829
- }))
2822
+ if (( implicitSrc1 && implicitSrc1->compareOperand (
2823
+ implicitDst0, getBuilder ()) != Rel_disjoint) ||
2824
+ (src1_0 && src1_0-> isAccReg () &&
2825
+ src1_0-> compareOperand (implicitDst0, getBuilder ()) != Rel_disjoint) ||
2826
+ (src1_2 && src1_2-> isAccReg () &&
2827
+ src1_2-> compareOperand (implicitDst0, getBuilder ()) != Rel_disjoint) ||
2828
+ (src1_1 && src1_1-> isAccReg () &&
2829
+ src1_1-> compareOperand (implicitDst0, getBuilder ()) != Rel_disjoint)) {
2830
2830
return true ;
2831
+ }
2831
2832
}
2832
2833
return false ;
2833
2834
}
@@ -7836,10 +7837,10 @@ G4_INST *G4_InstDpas::cloneInst(const IR_Builder *b) {
7836
7837
auto src1 = nonConstBuilder->duplicateOperand (getSrc (1 ));
7837
7838
auto src2 = nonConstBuilder->duplicateOperand (getSrc (2 ));
7838
7839
auto src3 = nonConstBuilder->duplicateOperand (getSrc (3 ));
7839
- auto src4 = nonConstBuilder->duplicateOperand (getSrc (4 ));
7840
7840
return nonConstBuilder->createInternalDpasInst (
7841
- op, getExecSize (), dst, src0, src1, src2, option, getSrc2Precision (),
7842
- getSrc1Precision (), getSystolicDepth (), getRepeatCount (), src3, src4);
7841
+ op, getExecSize (), dst, src0, src1, src2, src3, option,
7842
+ getSrc2Precision (), getSrc1Precision (), getSystolicDepth (),
7843
+ getRepeatCount ());
7843
7844
}
7844
7845
7845
7846
bool G4_InstDpas::isInt () const {
@@ -7893,9 +7894,9 @@ uint8_t G4_InstDpas::getOpsPerChan() const {
7893
7894
void G4_InstDpas::computeRightBound (G4_Operand *opnd) {
7894
7895
associateOpndWithInst (opnd, this );
7895
7896
if (opnd && !opnd->isImm () && !opnd->isNullReg ()) {
7896
- uint8_t D = getSystolicDepth ();
7897
- uint8_t C = getRepeatCount ();
7898
- G4_ExecSize ES = getExecSize ();
7897
+ G4_InstDpas *dpasInst = asDpasInst ();
7898
+ uint8_t D = dpasInst-> getSystolicDepth ();
7899
+ uint8_t C = dpasInst-> getRepeatCount ();
7899
7900
7900
7901
auto computeDpasOperandBound = [this ](G4_Operand *opnd, unsigned leftBound,
7901
7902
unsigned rightBound) {
@@ -7907,24 +7908,24 @@ void G4_InstDpas::computeRightBound(G4_Operand *opnd) {
7907
7908
if (opnd == dst || (opnd == srcs[0 ] && !opnd->isNullReg ())) {
7908
7909
// dst and src0 are always packed, and RB is exec_size * type_size *
7909
7910
// rep_count
7910
- auto opndSize = ES * opnd->getTypeSize () * C;
7911
+ auto opndSize = dpasInst-> getExecSize () * opnd->getTypeSize () * C;
7911
7912
computeDpasOperandBound (opnd, opnd->left_bound ,
7912
7913
opnd->left_bound + opndSize - 1 );
7913
7914
} else if (opnd == srcs[1 ]) {
7914
- uint32_t bytesPerLane = getSrc1SizePerLaneInByte ();
7915
+ uint32_t bytesPerLane = dpasInst-> getSrc1SizePerLaneInByte ();
7915
7916
uint8_t src1_D = D;
7916
7917
7917
7918
// Each lanes needs (src1_D * bytesPerLane) bytes, and it's multiple of
7918
7919
// DW!
7919
7920
uint32_t bytesPerLaneForAllDepth = bytesPerLane * src1_D;
7920
7921
bytesPerLaneForAllDepth = ((bytesPerLaneForAllDepth + 3 ) / 4 ) * 4 ;
7921
7922
7922
- uint32_t bytes = bytesPerLaneForAllDepth * ES ;
7923
+ uint32_t bytes = bytesPerLaneForAllDepth * dpasInst-> getExecSize () ;
7923
7924
computeDpasOperandBound (opnd, opnd->left_bound ,
7924
7925
opnd->left_bound + bytes - 1 );
7925
7926
} else if (opnd == srcs[2 ]) {
7926
7927
// src2 is uniform.
7927
- uint32_t bytesPerLane = getSrc2SizePerLaneInByte ();
7928
+ uint32_t bytesPerLane = dpasInst-> getSrc2SizePerLaneInByte ();
7928
7929
uint32_t bytes = bytesPerLane * D * C;
7929
7930
if (op == G4_dpasw) {
7930
7931
bytes = bytesPerLane * D * ((C + 1 ) / 2 );
@@ -7935,10 +7936,10 @@ void G4_InstDpas::computeRightBound(G4_Operand *opnd) {
7935
7936
7936
7937
else if (opnd && opnd == srcs[3 ]) {
7937
7938
uint32_t bytes;
7938
- if (isInt ())
7939
+ if (dpasInst-> isInt ())
7939
7940
{
7940
7941
bytes = 2 * getBuilder ().getGRFSize ();
7941
- } else if (isFP16 () || isBF16 ()) {
7942
+ } else if (dpasInst-> isFP16 () || dpasInst-> isBF16 ()) {
7942
7943
bytes = getBuilder ().getGRFSize ();
7943
7944
} else { // isTF32()
7944
7945
bytes = getBuilder ().getGRFSize () / 2 ;
0 commit comments