@@ -75,8 +75,9 @@ enum class SchedGroupMask {
75
75
DS = 1u << 7 ,
76
76
DS_READ = 1u << 8 ,
77
77
DS_WRITE = 1u << 9 ,
78
+ TRANS = 1u << 10 ,
78
79
ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
79
- DS_READ | DS_WRITE,
80
+ DS_READ | DS_WRITE | TRANS ,
80
81
LLVM_MARK_AS_BITMASK_ENUM (/* LargestFlag = */ ALL)
81
82
};
82
83
@@ -1435,11 +1436,12 @@ bool SchedGroup::canAddMI(const MachineInstr &MI) const {
1435
1436
Result = false ;
1436
1437
1437
1438
else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
1438
- (TII->isVALU (MI) || TII->isMFMAorWMMA (MI) || TII->isSALU (MI)))
1439
+ (TII->isVALU (MI) || TII->isMFMAorWMMA (MI) || TII->isSALU (MI) ||
1440
+ TII->isTRANS (MI)))
1439
1441
Result = true ;
1440
1442
1441
1443
else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
1442
- TII->isVALU (MI) && !TII->isMFMAorWMMA (MI))
1444
+ TII->isVALU (MI) && !TII->isMFMAorWMMA (MI) && !TII-> isTRANS (MI) )
1443
1445
Result = true ;
1444
1446
1445
1447
else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
@@ -1476,6 +1478,10 @@ bool SchedGroup::canAddMI(const MachineInstr &MI) const {
1476
1478
MI.mayStore () && TII->isDS (MI))
1477
1479
Result = true ;
1478
1480
1481
+ else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
1482
+ TII->isTRANS (MI))
1483
+ Result = true ;
1484
+
1479
1485
LLVM_DEBUG (
1480
1486
dbgs () << " For SchedGroup with mask " << format_hex ((int )SGMask, 10 , true )
1481
1487
<< (Result ? " could classify " : " unable to classify " ) << MI);
@@ -1635,10 +1641,13 @@ void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
1635
1641
// Remove all existing edges from the SCHED_BARRIER that were added due to the
1636
1642
// instruction having side effects.
1637
1643
resetEdges (SchedBarrier, DAG);
1644
+ LLVM_DEBUG (dbgs () << " Building SchedGroup for SchedBarrier with Mask: "
1645
+ << MI.getOperand (0 ).getImm () << " \n " );
1638
1646
auto InvertedMask =
1639
1647
invertSchedBarrierMask ((SchedGroupMask)MI.getOperand (0 ).getImm ());
1640
1648
SchedGroup SG (InvertedMask, std::nullopt, DAG, TII);
1641
1649
SG.initSchedGroup ();
1650
+
1642
1651
// Preserve original instruction ordering relative to the SCHED_BARRIER.
1643
1652
SG.link (
1644
1653
SchedBarrier,
@@ -1652,14 +1661,15 @@ IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
1652
1661
// allowed past the SCHED_BARRIER.
1653
1662
SchedGroupMask InvertedMask = ~Mask;
1654
1663
1655
- // ALU implies VALU, SALU, MFMA.
1664
+ // ALU implies VALU, SALU, MFMA, TRANS.
1656
1665
if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
1657
- InvertedMask &=
1658
- ~SchedGroupMask::VALU & ~SchedGroupMask::SALU & ~SchedGroupMask::MFMA ;
1659
- // VALU, SALU, MFMA implies ALU.
1666
+ InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
1667
+ ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS ;
1668
+ // VALU, SALU, MFMA, TRANS each imply ALU.
1660
1669
else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
1661
1670
(InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
1662
- (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE)
1671
+ (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
1672
+ (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
1663
1673
InvertedMask &= ~SchedGroupMask::ALU;
1664
1674
1665
1675
// VMEM implies VMEM_READ, VMEM_WRITE.
@@ -1678,6 +1688,9 @@ IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
1678
1688
(InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
1679
1689
InvertedMask &= ~SchedGroupMask::DS;
1680
1690
1691
+ LLVM_DEBUG (dbgs () << " After Inverting, SchedGroup Mask: " << (int )InvertedMask
1692
+ << " \n " );
1693
+
1681
1694
return InvertedMask;
1682
1695
}
1683
1696
0 commit comments