@@ -469,6 +469,7 @@ size_t DepSetBuilder::DpasMacroBuilder::getNumberOfSuppresionGroups(uint32_t src
469
469
return 0 ;
470
470
}
471
471
472
+
472
473
size_t DepSetBuilder::DpasMacroBuilder::formSrcSuppressionBlock (
473
474
InstListIterator startIt, uint32_t srcIdx) {
474
475
// get the candidate block
@@ -553,9 +554,10 @@ DepSetBuilder::DpasMacroBuilder::SuppressBlockPtrTy
553
554
DepSetBuilder::DpasMacroBuilder::getSuppressionBlockCandidate (
554
555
InstListIterator startIt, uint32_t srcIdx,
555
556
BitSet<>& allDstBits, BitSet<>& allSrcBits,
556
- BitSet<>& allDstNoLastBits, BitSet<>& allSrcNoLastBits) const {
557
+ BitSet<>& allDstNoLastBits, BitSet<>& allSrcNoLastBits,
558
+ int forceGroupNum) const {
557
559
assert (srcIdx == 1 || srcIdx == 2 );
558
- size_t maxGroupNum = getNumberOfSuppresionGroups (srcIdx);
560
+ size_t maxGroupNum = forceGroupNum < 0 ? getNumberOfSuppresionGroups (srcIdx) : forceGroupNum ;
559
561
// return null if the given src can't be suppressed
560
562
if (!maxGroupNum)
561
563
return nullptr ;
@@ -612,13 +614,12 @@ bool DepSetBuilder::DpasMacroBuilder::srcIsSuppressCandidate(const Instruction&
612
614
if (srcIdx == 1 )
613
615
return true ;
614
616
if (srcIdx == 2 ) {
615
- // can't be DP dpas
617
+ // DP dpas must have rep count 4
616
618
if (inst.isDF ())
617
- return false ;
619
+ return GetDpasRepeatCount (inst. getDpasFc ()) == 4 ;
618
620
619
- if (GetDpasRepeatCount (inst.getDpasFc ()) != 8 )
620
- return false ;
621
- return true ;
621
+ // allow only rep count 8 for non-DP dpas
622
+ return GetDpasRepeatCount (inst.getDpasFc ()) == 8 ;
622
623
}
623
624
return false ;
624
625
}
@@ -631,10 +632,8 @@ bool DepSetBuilder::DpasMacroBuilder::hasProducerConsumerDep(
631
632
BitSet<> new_dstbits (m_dsBuilder.getGRF_LEN ());
632
633
setDstSrcBits (src_range, dst_range, new_srcbits, new_dstbits);
633
634
634
- // check if there is WAR/RAW/WAW dependency
635
- if (target_src_bits.intersects (new_dstbits) ||
636
- target_dst_bits.intersects (new_srcbits) ||
637
- target_dst_bits.intersects (new_dstbits))
635
+ // check if there is RAW dependency
636
+ if (target_dst_bits.intersects (new_srcbits))
638
637
return true ;
639
638
return false ;
640
639
}
@@ -663,13 +662,16 @@ const Instruction& DepSetBuilder::DpasMacroBuilder::formMacro(size_t& dpasCnt) {
663
662
m_inps.getDpasDstDependency (**cur, dst_range);
664
663
InstListIterator next = cur;
665
664
next++;
666
- // early exit if there is no following instructions
667
- if (next == m_instList.end ()) {
665
+ // early exit if there are no following instructions or dpas depth is not 8
666
+ if (next == m_instList.end () || GetDpasSystolicDepth ((*cur)-> getDpasFc ()) != 8 ) {
668
667
updateRegFootprintsToDepSets (src_range, src_extra_range, dst_range);
669
668
return **cur;
670
669
}
671
670
672
- dpasCnt = std::max (dpasCnt, formSrcSuppressionBlock (m_firstDpasIt, 1 ));
671
+ bool formMacroForSrc1 = false ;
672
+
673
+ if (!formMacroForSrc1)
674
+ dpasCnt = std::max (dpasCnt, formSrcSuppressionBlock (m_firstDpasIt, 1 ));
673
675
dpasCnt = std::max (dpasCnt, formSrcSuppressionBlock (m_firstDpasIt, 2 ));
674
676
675
677
if (dpasCnt == 1 ) {
0 commit comments