@@ -75,8 +75,9 @@ enum class SchedGroupMask {
75
75
DS = 1u << 7 ,
76
76
DS_READ = 1u << 8 ,
77
77
DS_WRITE = 1u << 9 ,
78
+ TRANS = 1u << 10 ,
78
79
ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
79
- DS_READ | DS_WRITE,
80
+ DS_READ | DS_WRITE | TRANS ,
80
81
LLVM_MARK_AS_BITMASK_ENUM (/* LargestFlag = */ ALL)
81
82
};
82
83
@@ -1437,11 +1438,12 @@ bool SchedGroup::canAddMI(const MachineInstr &MI) const {
1437
1438
Result = false ;
1438
1439
1439
1440
else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
1440
- (TII->isVALU (MI) || TII->isMFMAorWMMA (MI) || TII->isSALU (MI)))
1441
+ (TII->isVALU (MI) || TII->isMFMAorWMMA (MI) || TII->isSALU (MI) ||
1442
+ TII->isTRANS (MI)))
1441
1443
Result = true ;
1442
1444
1443
1445
else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
1444
- TII->isVALU (MI) && !TII->isMFMAorWMMA (MI))
1446
+ TII->isVALU (MI) && !TII->isMFMAorWMMA (MI) && !TII-> isTRANS (MI) )
1445
1447
Result = true ;
1446
1448
1447
1449
else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
@@ -1478,6 +1480,10 @@ bool SchedGroup::canAddMI(const MachineInstr &MI) const {
1478
1480
MI.mayStore () && TII->isDS (MI))
1479
1481
Result = true ;
1480
1482
1483
+ else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
1484
+ TII->isTRANS (MI))
1485
+ Result = true ;
1486
+
1481
1487
LLVM_DEBUG (
1482
1488
dbgs () << " For SchedGroup with mask " << format_hex ((int )SGMask, 10 , true )
1483
1489
<< (Result ? " could classify " : " unable to classify " ) << MI);
@@ -1637,10 +1643,13 @@ void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
1637
1643
// Remove all existing edges from the SCHED_BARRIER that were added due to the
1638
1644
// instruction having side effects.
1639
1645
resetEdges (SchedBarrier, DAG);
1646
+ LLVM_DEBUG (dbgs () << " Building SchedGroup for SchedBarrier with Mask: "
1647
+ << MI.getOperand (0 ).getImm () << " \n " );
1640
1648
auto InvertedMask =
1641
1649
invertSchedBarrierMask ((SchedGroupMask)MI.getOperand (0 ).getImm ());
1642
1650
SchedGroup SG (InvertedMask, std::nullopt, DAG, TII);
1643
1651
SG.initSchedGroup ();
1652
+
1644
1653
// Preserve original instruction ordering relative to the SCHED_BARRIER.
1645
1654
SG.link (
1646
1655
SchedBarrier,
@@ -1654,14 +1663,15 @@ IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
1654
1663
// allowed past the SCHED_BARRIER.
1655
1664
SchedGroupMask InvertedMask = ~Mask;
1656
1665
1657
- // ALU implies VALU, SALU, MFMA.
1666
+ // ALU implies VALU, SALU, MFMA, TRANS.
1658
1667
if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
1659
- InvertedMask &=
1660
- ~SchedGroupMask::VALU & ~SchedGroupMask::SALU & ~SchedGroupMask::MFMA ;
1661
- // VALU, SALU, MFMA implies ALU.
1668
+ InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
1669
+ ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS ;
1670
+ // Any one of VALU, SALU, MFMA, or TRANS implies ALU.
1662
1671
else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
1663
1672
(InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
1664
- (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE)
1673
+ (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
1674
+ (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
1665
1675
InvertedMask &= ~SchedGroupMask::ALU;
1666
1676
1667
1677
// VMEM implies VMEM_READ, VMEM_WRITE.
@@ -1680,6 +1690,9 @@ IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
1680
1690
(InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
1681
1691
InvertedMask &= ~SchedGroupMask::DS;
1682
1692
1693
+ LLVM_DEBUG (dbgs () << " After Inverting, SchedGroup Mask: " << (int )InvertedMask
1694
+ << " \n " );
1695
+
1683
1696
return InvertedMask;
1684
1697
}
1685
1698
0 commit comments