@@ -854,7 +854,8 @@ class IGLPStrategy {
854
854
// Add SchedGroups to \p Pipeline to implement this Strategy.
855
855
virtual void applyIGLPStrategy (
856
856
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
857
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) = 0;
857
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
858
+ bool IsPostRA) = 0;
858
859
859
860
// Returns true if this strategy should be applied to a ScheduleDAG.
860
861
virtual bool shouldApplyStrategy (ScheduleDAGInstrs *DAG) = 0;
@@ -872,7 +873,8 @@ class MFMASmallGemmOpt final : public IGLPStrategy {
872
873
public:
873
874
void applyIGLPStrategy (
874
875
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
875
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) override ;
876
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
877
+ bool IsPostRA) override ;
876
878
877
879
bool shouldApplyStrategy (ScheduleDAGInstrs *DAG) override { return true ; }
878
880
@@ -884,7 +886,8 @@ class MFMASmallGemmOpt final : public IGLPStrategy {
884
886
885
887
void MFMASmallGemmOpt::applyIGLPStrategy (
886
888
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
887
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) {
889
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
890
+ bool IsPostRA) {
888
891
// Count the number of MFMA instructions.
889
892
unsigned MFMACount = 0 ;
890
893
for (const MachineInstr &I : *DAG)
@@ -1080,9 +1083,12 @@ class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1080
1083
Cache->push_back (Pred.getSUnit ());
1081
1084
}
1082
1085
}
1086
+
1087
+ // If the other group has no PERM preds, then this group won't share any
1088
+ if (!Cache->size ())
1089
+ return false ;
1083
1090
}
1084
1091
1085
- assert (Cache->size ());
1086
1092
auto DAG = SyncPipe[0 ].DAG ;
1087
1093
// Does the previous DS_WRITE share a V_PERM predecessor with this
1088
1094
// VMEM_READ
@@ -1099,7 +1105,8 @@ class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1099
1105
public:
1100
1106
void applyIGLPStrategy (
1101
1107
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
1102
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) override ;
1108
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
1109
+ bool IsPostRA) override ;
1103
1110
1104
1111
bool shouldApplyStrategy (ScheduleDAGInstrs *DAG) override { return true ; }
1105
1112
@@ -1109,14 +1116,20 @@ class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1109
1116
}
1110
1117
};
1111
1118
1119
+ static unsigned DSWCount = 0 ;
1120
+ static unsigned DSWWithPermCount = 0 ;
1121
+ static unsigned DSWWithSharedVMEMCount = 0 ;
1122
+
1112
1123
void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy (
1113
1124
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
1114
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) {
1125
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
1126
+ bool IsPostRA) {
1115
1127
unsigned MFMACount = 0 ;
1116
- unsigned DSWCount = 0 ;
1117
- unsigned DSWWithPermCount = 0 ;
1118
- unsigned DSWWithSharedVMEMCount = 0 ;
1119
1128
unsigned DSRCount = 0 ;
1129
+
1130
+ assert ((IsPostRA ||
1131
+ DSWCount == DSWWithPermCount == DSWWithSharedVMEMCount == 0 ) &&
1132
+ " DSWCounters should be zero in pre-RA scheduling!" );
1120
1133
SmallVector<SUnit *, 6 > DSWithPerms;
1121
1134
for (auto &SU : DAG->SUnits ) {
1122
1135
auto I = SU.getInstr ();
@@ -1125,7 +1138,7 @@ void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1125
1138
else if (TII->isDS (*I)) {
1126
1139
if (I->mayLoad ())
1127
1140
++DSRCount;
1128
- else if (I->mayStore ()) {
1141
+ else if (I->mayStore () && !IsPostRA ) {
1129
1142
++DSWCount;
1130
1143
for (auto Pred : SU.Preds ) {
1131
1144
if (Pred.getSUnit ()->getInstr ()->getOpcode () ==
@@ -1137,57 +1150,59 @@ void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1137
1150
}
1138
1151
}
1139
1152
}
1140
- DSWWithPermCount = DSWithPerms.size ();
1141
- auto I = DSWithPerms.begin ();
1142
- auto E = DSWithPerms.end ();
1143
-
1144
- // Get the count of DS_WRITES with V_PERM predecessors which
1145
- // have loop carried dependencies (WAR) on the same VMEM_READs.
1146
- // We consider partial overlap as a miss -- in other words,
1147
- // for a given DS_W, we only consider another DS_W as matching
1148
- // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1149
- // for every V_PERM pred of this DS_W.
1150
- DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1151
- SmallVector<SUnit *, 6 > Counted;
1152
- for (; I != E; I++) {
1153
- SUnit *Cand = nullptr ;
1154
- bool MissedAny = false ;
1155
- for (auto &Pred : (*I)->Preds ) {
1156
- if (Pred.getSUnit ()->getInstr ()->getOpcode () != AMDGPU::V_PERM_B32_e64)
1157
- continue ;
1158
1153
1159
- if (Cand &&
1160
- std::find (Counted.begin (), Counted.end (), Cand) != Counted.end ())
1161
- break ;
1162
-
1163
- for (auto &Succ : Pred.getSUnit ()->Succs ) {
1164
- auto MI = Succ.getSUnit ()->getInstr ();
1165
- if (!TII->isVMEM (*MI) || !MI->mayLoad ())
1154
+ if (!IsPostRA) {
1155
+ DSWWithPermCount = DSWithPerms.size ();
1156
+ auto I = DSWithPerms.begin ();
1157
+ auto E = DSWithPerms.end ();
1158
+
1159
+ // Get the count of DS_WRITES with V_PERM predecessors which
1160
+ // have loop carried dependencies (WAR) on the same VMEM_READs.
1161
+ // We consider partial overlap as a miss -- in other words,
1162
+ // for a given DS_W, we only consider another DS_W as matching
1163
+ // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1164
+ // for every V_PERM pred of this DS_W.
1165
+ DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1166
+ SmallVector<SUnit *, 6 > Counted;
1167
+ for (; I != E; I++) {
1168
+ SUnit *Cand = nullptr ;
1169
+ bool MissedAny = false ;
1170
+ for (auto &Pred : (*I)->Preds ) {
1171
+ if (Pred.getSUnit ()->getInstr ()->getOpcode () != AMDGPU::V_PERM_B32_e64)
1166
1172
continue ;
1167
1173
1168
- if (MissedAny || !VMEMLookup.size ()) {
1169
- MissedAny = true ;
1170
- VMEMLookup[MI] = *I;
1171
- continue ;
1172
- }
1174
+ if (Cand && llvm::is_contained (Counted, Cand))
1175
+ break ;
1173
1176
1174
- if (!VMEMLookup.contains (MI)) {
1175
- MissedAny = true ;
1176
- VMEMLookup[MI] = *I;
1177
- continue ;
1178
- }
1177
+ for (auto &Succ : Pred.getSUnit ()->Succs ) {
1178
+ auto MI = Succ.getSUnit ()->getInstr ();
1179
+ if (!TII->isVMEM (*MI) || !MI->mayLoad ())
1180
+ continue ;
1179
1181
1180
- Cand = VMEMLookup[MI];
1181
- if (std::find (Counted.begin (), Counted.end (), Cand) != Counted.end ()) {
1182
- MissedAny = true ;
1183
- break ;
1182
+ if (MissedAny || !VMEMLookup.size ()) {
1183
+ MissedAny = true ;
1184
+ VMEMLookup[MI] = *I;
1185
+ continue ;
1186
+ }
1187
+
1188
+ if (!VMEMLookup.contains (MI)) {
1189
+ MissedAny = true ;
1190
+ VMEMLookup[MI] = *I;
1191
+ continue ;
1192
+ }
1193
+
1194
+ Cand = VMEMLookup[MI];
1195
+ if (llvm::is_contained (Counted, Cand)) {
1196
+ MissedAny = true ;
1197
+ break ;
1198
+ }
1184
1199
}
1185
1200
}
1186
- }
1187
- if (!MissedAny && Cand) {
1188
- DSWWithSharedVMEMCount += 2 ;
1189
- Counted.push_back (Cand );
1190
- Counted. push_back (*I);
1201
+ if (!MissedAny && Cand) {
1202
+ DSWWithSharedVMEMCount += 2 ;
1203
+ Counted. push_back (Cand) ;
1204
+ Counted.push_back (*I );
1205
+ }
1191
1206
}
1192
1207
}
1193
1208
@@ -1403,7 +1418,11 @@ class IGroupLPDAGMutation : public ScheduleDAGMutation {
1403
1418
// first created SchedGroup first.
1404
1419
bool IsBottomUp = 1 ;
1405
1420
1421
+ // Whether the mutation is being applied to post RA scheduling
1422
+ bool IsPostRA = false ;
1423
+
1406
1424
IGroupLPDAGMutation () = default ;
1425
+ IGroupLPDAGMutation (bool IsPostRA) : IsPostRA(IsPostRA) {}
1407
1426
};
1408
1427
1409
1428
unsigned SchedGroup::NumSchedGroups = 0 ;
@@ -1691,16 +1710,16 @@ void IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
1691
1710
auto S = createIGLPStrategy (StrategyID, DAG, TII);
1692
1711
if (S->shouldApplyStrategy (DAG)) {
1693
1712
IsBottomUp = S->IsBottomUp ;
1694
- S->applyIGLPStrategy (SyncedInstrs, SyncedSchedGroups);
1713
+ S->applyIGLPStrategy (SyncedInstrs, SyncedSchedGroups, IsPostRA );
1695
1714
}
1696
1715
}
1697
1716
1698
1717
} // namespace
1699
1718
1700
1719
namespace llvm {
1701
1720
1702
- std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation () {
1703
- return std::make_unique<IGroupLPDAGMutation>();
1721
+ std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation (bool IsPostRA ) {
1722
+ return std::make_unique<IGroupLPDAGMutation>(IsPostRA );
1704
1723
}
1705
1724
1706
1725
} // end namespace llvm
0 commit comments