Skip to content

Commit 08b8d46

Browse files
authored
[AMDGPU][GFX1250] Insert S_WAIT_XCNT for SMEM and VMEM load-stores (#145566)
This patch tracks the register operands of both VMEM (FLAT, MUBUF, MTBUF) and SMEM load-store operations and inserts an S_WAIT_XCNT instruction with a sufficient wait-count before potentially redefining them. For VMEM instructions, XNACKs are returned in the same order as the instructions were issued, and hence non-zero counter values can be inserted. However, SMEM execution is out-of-order and so is its XNACK reception. Thus, only a zero counter value can be inserted to capture SMEM dependencies.
1 parent deb3464 commit 08b8d46

File tree

5 files changed

+1630
-20
lines changed

5 files changed

+1630
-20
lines changed

llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ enum InstCounterType {
7373
SAMPLE_CNT = NUM_NORMAL_INST_CNTS, // gfx12+ only.
7474
BVH_CNT, // gfx12+ only.
7575
KM_CNT, // gfx12+ only.
76+
X_CNT, // gfx1250.
7677
NUM_EXTENDED_INST_CNTS,
7778
NUM_INST_CNTS = NUM_EXTENDED_INST_CNTS
7879
};
@@ -102,6 +103,7 @@ struct HardwareLimits {
102103
unsigned SamplecntMax; // gfx12+ only.
103104
unsigned BvhcntMax; // gfx12+ only.
104105
unsigned KmcntMax; // gfx12+ only.
106+
unsigned XcntMax; // gfx1250.
105107
};
106108

107109
#define AMDGPU_DECLARE_WAIT_EVENTS(DECL) \
@@ -111,10 +113,12 @@ struct HardwareLimits {
111113
DECL(VMEM_BVH_READ_ACCESS) /* vmem BVH read (gfx12+ only) */ \
112114
DECL(VMEM_WRITE_ACCESS) /* vmem write that is not scratch */ \
113115
DECL(SCRATCH_WRITE_ACCESS) /* vmem write that may be scratch */ \
116+
DECL(VMEM_GROUP) /* vmem group */ \
114117
DECL(LDS_ACCESS) /* lds read & write */ \
115118
DECL(GDS_ACCESS) /* gds read & write */ \
116119
DECL(SQ_MESSAGE) /* send message */ \
117120
DECL(SMEM_ACCESS) /* scalar-memory read & write */ \
121+
DECL(SMEM_GROUP) /* scalar-memory group */ \
118122
DECL(EXP_GPR_LOCK) /* export holding on its data src */ \
119123
DECL(GDS_GPR_LOCK) /* GDS holding on its data and addr src */ \
120124
DECL(EXP_POS_ACCESS) /* write to export position */ \
@@ -178,7 +182,7 @@ enum VmemType {
178182
static const unsigned instrsForExtendedCounterTypes[NUM_EXTENDED_INST_CNTS] = {
179183
AMDGPU::S_WAIT_LOADCNT, AMDGPU::S_WAIT_DSCNT, AMDGPU::S_WAIT_EXPCNT,
180184
AMDGPU::S_WAIT_STORECNT, AMDGPU::S_WAIT_SAMPLECNT, AMDGPU::S_WAIT_BVHCNT,
181-
AMDGPU::S_WAIT_KMCNT};
185+
AMDGPU::S_WAIT_KMCNT, AMDGPU::S_WAIT_XCNT};
182186

183187
static bool updateVMCntOnly(const MachineInstr &Inst) {
184188
return (SIInstrInfo::isVMEM(Inst) && !SIInstrInfo::isFLAT(Inst)) ||
@@ -223,6 +227,8 @@ unsigned &getCounterRef(AMDGPU::Waitcnt &Wait, InstCounterType T) {
223227
return Wait.BvhCnt;
224228
case KM_CNT:
225229
return Wait.KmCnt;
230+
case X_CNT:
231+
return Wait.XCnt;
226232
default:
227233
llvm_unreachable("bad InstCounterType");
228234
}
@@ -283,12 +289,27 @@ class WaitcntBrackets {
283289
return Limits.BvhcntMax;
284290
case KM_CNT:
285291
return Limits.KmcntMax;
292+
case X_CNT:
293+
return Limits.XcntMax;
286294
default:
287295
break;
288296
}
289297
return 0;
290298
}
291299

300+
bool isSmemCounter(InstCounterType T) const {
301+
return T == SmemAccessCounter || T == X_CNT;
302+
}
303+
304+
unsigned getSgprScoresIdx(InstCounterType T) const {
305+
if (T == SmemAccessCounter)
306+
return 0;
307+
if (T == X_CNT)
308+
return 1;
309+
310+
llvm_unreachable("Invalid SMEM counter");
311+
}
312+
292313
unsigned getScoreLB(InstCounterType T) const {
293314
assert(T < NUM_INST_CNTS);
294315
return ScoreLBs[T];
@@ -307,8 +328,8 @@ class WaitcntBrackets {
307328
if (GprNo < NUM_ALL_VGPRS) {
308329
return VgprScores[T][GprNo];
309330
}
310-
assert(T == SmemAccessCounter);
311-
return SgprScores[GprNo - NUM_ALL_VGPRS];
331+
assert(isSmemCounter(T));
332+
return SgprScores[getSgprScoresIdx(T)][GprNo - NUM_ALL_VGPRS];
312333
}
313334

314335
bool merge(const WaitcntBrackets &Other);
@@ -331,6 +352,7 @@ class WaitcntBrackets {
331352

332353
void applyWaitcnt(const AMDGPU::Waitcnt &Wait);
333354
void applyWaitcnt(InstCounterType T, unsigned Count);
355+
void applyXcnt(const AMDGPU::Waitcnt &Wait);
334356
void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI,
335357
const MachineRegisterInfo *MRI, WaitEventType E,
336358
MachineInstr &MI);
@@ -462,9 +484,11 @@ class WaitcntBrackets {
462484
int VgprUB = -1;
463485
int SgprUB = -1;
464486
unsigned VgprScores[NUM_INST_CNTS][NUM_ALL_VGPRS] = {{0}};
465-
// Wait cnt scores for every sgpr, only DS_CNT (corresponding to LGKMcnt
466-
// pre-gfx12) or KM_CNT (gfx12+ only) are relevant.
467-
unsigned SgprScores[SQ_MAX_PGM_SGPRS] = {0};
487+
// Wait cnt scores for every sgpr, the DS_CNT (corresponding to LGKMcnt
488+
// pre-gfx12) or KM_CNT (gfx12+ only), and X_CNT (gfx1250) are relevant.
489+
// Row 0 represents the score for either DS_CNT or KM_CNT and row 1 keeps the
490+
// X_CNT score.
491+
unsigned SgprScores[2][SQ_MAX_PGM_SGPRS] = {{0}};
468492
// Bitmask of the VmemTypes of VMEM instructions that might have a pending
469493
// write to each vgpr.
470494
unsigned char VgprVmemTypes[NUM_ALL_VGPRS] = {0};
@@ -572,6 +596,7 @@ class WaitcntGeneratorPreGFX12 : public WaitcntGenerator {
572596
eventMask({VMEM_WRITE_ACCESS, SCRATCH_WRITE_ACCESS}),
573597
0,
574598
0,
599+
0,
575600
0};
576601

577602
return WaitEventMaskForInstPreGFX12;
@@ -607,7 +632,8 @@ class WaitcntGeneratorGFX12Plus : public WaitcntGenerator {
607632
eventMask({VMEM_WRITE_ACCESS, SCRATCH_WRITE_ACCESS}),
608633
eventMask({VMEM_SAMPLER_READ_ACCESS}),
609634
eventMask({VMEM_BVH_READ_ACCESS}),
610-
eventMask({SMEM_ACCESS, SQ_MESSAGE})};
635+
eventMask({SMEM_ACCESS, SQ_MESSAGE}),
636+
eventMask({VMEM_GROUP, SMEM_GROUP})};
611637

612638
return WaitEventMaskForInstGFX12Plus;
613639
}
@@ -743,9 +769,12 @@ class SIInsertWaitcnts {
743769
return VmemReadMapping[getVmemType(Inst)];
744770
}
745771

772+
bool hasXcnt() const { return ST->hasWaitXCnt(); }
773+
746774
bool mayAccessVMEMThroughFlat(const MachineInstr &MI) const;
747775
bool mayAccessLDSThroughFlat(const MachineInstr &MI) const;
748776
bool mayAccessScratchThroughFlat(const MachineInstr &MI) const;
777+
bool isVmemAccess(const MachineInstr &MI) const;
749778
bool generateWaitcntInstBefore(MachineInstr &MI,
750779
WaitcntBrackets &ScoreBrackets,
751780
MachineInstr *OldWaitcntInstr,
@@ -837,9 +866,9 @@ void WaitcntBrackets::setScoreByInterval(RegInterval Interval,
837866
VgprUB = std::max(VgprUB, RegNo);
838867
VgprScores[CntTy][RegNo] = Score;
839868
} else {
840-
assert(CntTy == SmemAccessCounter);
869+
assert(isSmemCounter(CntTy));
841870
SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS);
842-
SgprScores[RegNo - NUM_ALL_VGPRS] = Score;
871+
SgprScores[getSgprScoresIdx(CntTy)][RegNo - NUM_ALL_VGPRS] = Score;
843872
}
844873
}
845874
}
@@ -976,6 +1005,13 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
9761005
setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
9771006
}
9781007
}
1008+
} else if (T == X_CNT) {
1009+
for (const MachineOperand &Op : Inst.all_uses()) {
1010+
RegInterval Interval = getRegInterval(&Inst, MRI, TRI, Op);
1011+
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
1012+
setRegScore(RegNo, T, CurrScore);
1013+
}
1014+
}
9791015
} else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ {
9801016
// Match the score to the destination registers.
9811017
//
@@ -1080,6 +1116,9 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
10801116
case KM_CNT:
10811117
OS << " KM_CNT(" << SR << "): ";
10821118
break;
1119+
case X_CNT:
1120+
OS << " X_CNT(" << SR << "): ";
1121+
break;
10831122
default:
10841123
OS << " UNKNOWN(" << SR << "): ";
10851124
break;
@@ -1100,8 +1139,8 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11001139
OS << RelScore << ":ds ";
11011140
}
11021141
}
1103-
// Also need to print sgpr scores for lgkm_cnt.
1104-
if (T == SmemAccessCounter) {
1142+
// Also need to print sgpr scores for lgkm_cnt or xcnt.
1143+
if (isSmemCounter(T)) {
11051144
for (int J = 0; J <= SgprUB; J++) {
11061145
unsigned RegScore = getRegScore(J + NUM_ALL_VGPRS, T);
11071146
if (RegScore <= LB)
@@ -1140,6 +1179,7 @@ void WaitcntBrackets::simplifyWaitcnt(AMDGPU::Waitcnt &Wait) const {
11401179
simplifyWaitcnt(SAMPLE_CNT, Wait.SampleCnt);
11411180
simplifyWaitcnt(BVH_CNT, Wait.BvhCnt);
11421181
simplifyWaitcnt(KM_CNT, Wait.KmCnt);
1182+
simplifyWaitcnt(X_CNT, Wait.XCnt);
11431183
}
11441184

11451185
void WaitcntBrackets::simplifyWaitcnt(InstCounterType T,
@@ -1191,6 +1231,7 @@ void WaitcntBrackets::applyWaitcnt(const AMDGPU::Waitcnt &Wait) {
11911231
applyWaitcnt(SAMPLE_CNT, Wait.SampleCnt);
11921232
applyWaitcnt(BVH_CNT, Wait.BvhCnt);
11931233
applyWaitcnt(KM_CNT, Wait.KmCnt);
1234+
applyXcnt(Wait);
11941235
}
11951236

11961237
void WaitcntBrackets::applyWaitcnt(InstCounterType T, unsigned Count) {
@@ -1207,11 +1248,29 @@ void WaitcntBrackets::applyWaitcnt(InstCounterType T, unsigned Count) {
12071248
}
12081249
}
12091250

1251+
void WaitcntBrackets::applyXcnt(const AMDGPU::Waitcnt &Wait) {
1252+
// Wait on XCNT is redundant if we are already waiting for a load to complete.
1253+
// SMEM can return out of order, so only omit XCNT wait if we are waiting till
1254+
// zero.
1255+
if (Wait.KmCnt == 0 && hasPendingEvent(SMEM_GROUP))
1256+
return applyWaitcnt(X_CNT, 0);
1257+
1258+
// If we have pending store we cannot optimize XCnt because we do not wait for
1259+
// stores. VMEM loads retun in order, so if we only have loads XCnt is
1260+
// decremented to the same number as LOADCnt.
1261+
if (Wait.LoadCnt != ~0u && hasPendingEvent(VMEM_GROUP) &&
1262+
!hasPendingEvent(STORE_CNT))
1263+
return applyWaitcnt(X_CNT, std::min(Wait.XCnt, Wait.LoadCnt));
1264+
1265+
applyWaitcnt(X_CNT, Wait.XCnt);
1266+
}
1267+
12101268
// Where there are multiple types of event in the bracket of a counter,
12111269
// the decrement may go out of order.
12121270
bool WaitcntBrackets::counterOutOfOrder(InstCounterType T) const {
12131271
// Scalar memory read always can go out of order.
1214-
if (T == SmemAccessCounter && hasPendingEvent(SMEM_ACCESS))
1272+
if ((T == SmemAccessCounter && hasPendingEvent(SMEM_ACCESS)) ||
1273+
(T == X_CNT && hasPendingEvent(SMEM_GROUP)))
12151274
return true;
12161275
return hasMixedPendingEvents(T);
12171276
}
@@ -1263,6 +1322,8 @@ static std::optional<InstCounterType> counterTypeForInstr(unsigned Opcode) {
12631322
return DS_CNT;
12641323
case AMDGPU::S_WAIT_KMCNT:
12651324
return KM_CNT;
1325+
case AMDGPU::S_WAIT_XCNT:
1326+
return X_CNT;
12661327
default:
12671328
return {};
12681329
}
@@ -1427,7 +1488,8 @@ WaitcntGeneratorPreGFX12::getAllZeroWaitcnt(bool IncludeVSCnt) const {
14271488

14281489
AMDGPU::Waitcnt
14291490
WaitcntGeneratorGFX12Plus::getAllZeroWaitcnt(bool IncludeVSCnt) const {
1430-
return AMDGPU::Waitcnt(0, 0, 0, IncludeVSCnt ? 0 : ~0u, 0, 0, 0);
1491+
return AMDGPU::Waitcnt(0, 0, 0, IncludeVSCnt ? 0 : ~0u, 0, 0, 0,
1492+
~0u /* XCNT */);
14311493
}
14321494

14331495
/// Combine consecutive S_WAIT_*CNT instructions that precede \p It and
@@ -1909,13 +1971,17 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
19091971
ScoreBrackets.determineWait(BVH_CNT, Interval, Wait);
19101972
ScoreBrackets.clearVgprVmemTypes(Interval);
19111973
}
1974+
19121975
if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
19131976
ScoreBrackets.determineWait(EXP_CNT, Interval, Wait);
19141977
}
19151978
ScoreBrackets.determineWait(DS_CNT, Interval, Wait);
19161979
} else {
19171980
ScoreBrackets.determineWait(SmemAccessCounter, Interval, Wait);
19181981
}
1982+
1983+
if (hasXcnt() && Op.isDef())
1984+
ScoreBrackets.determineWait(X_CNT, Interval, Wait);
19191985
}
19201986
}
19211987
}
@@ -1958,6 +2024,8 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
19582024
Wait.BvhCnt = 0;
19592025
if (ForceEmitWaitcnt[KM_CNT])
19602026
Wait.KmCnt = 0;
2027+
if (ForceEmitWaitcnt[X_CNT])
2028+
Wait.XCnt = 0;
19612029

19622030
if (FlushVmCnt) {
19632031
if (ScoreBrackets.hasPendingEvent(LOAD_CNT))
@@ -2007,6 +2075,21 @@ bool SIInsertWaitcnts::generateWaitcnt(AMDGPU::Waitcnt Wait,
20072075
<< "Update Instr: " << *It);
20082076
}
20092077

2078+
// XCnt may be already consumed by a load wait.
2079+
if (Wait.KmCnt == 0 && Wait.XCnt != ~0u &&
2080+
!ScoreBrackets.hasPendingEvent(SMEM_GROUP))
2081+
Wait.XCnt = ~0u;
2082+
2083+
if (Wait.LoadCnt == 0 && Wait.XCnt != ~0u &&
2084+
!ScoreBrackets.hasPendingEvent(VMEM_GROUP))
2085+
Wait.XCnt = ~0u;
2086+
2087+
// Since the translation for VMEM addresses occurs in-order, we can skip the
2088+
// XCnt if the current instruction is of VMEM type and has a memory dependency
2089+
// with another VMEM instruction in flight.
2090+
if (Wait.XCnt != ~0u && isVmemAccess(*It))
2091+
Wait.XCnt = ~0u;
2092+
20102093
if (WCG->createNewWaitcnt(Block, It, Wait))
20112094
Modified = true;
20122095

@@ -2096,6 +2179,11 @@ bool SIInsertWaitcnts::mayAccessScratchThroughFlat(
20962179
});
20972180
}
20982181

2182+
bool SIInsertWaitcnts::isVmemAccess(const MachineInstr &MI) const {
2183+
return (TII->isFLAT(MI) && mayAccessVMEMThroughFlat(MI)) ||
2184+
(TII->isVMEM(MI) && !AMDGPU::getMUBUFIsBufferInv(MI.getOpcode()));
2185+
}
2186+
20992187
static bool isGFX12CacheInvOrWBInst(MachineInstr &Inst) {
21002188
auto Opc = Inst.getOpcode();
21012189
return Opc == AMDGPU::GLOBAL_INV || Opc == AMDGPU::GLOBAL_WB ||
@@ -2167,6 +2255,8 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
21672255
// bracket and the destination operand scores.
21682256
// TODO: Use the (TSFlags & SIInstrFlags::DS_CNT) property everywhere.
21692257

2258+
bool IsVMEMAccess = false;
2259+
bool IsSMEMAccess = false;
21702260
if (TII->isDS(Inst) && TII->usesLGKM_CNT(Inst)) {
21712261
if (TII->isAlwaysGDS(Inst.getOpcode()) ||
21722262
TII->hasModifiersSet(Inst, AMDGPU::OpName::gds)) {
@@ -2189,6 +2279,7 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
21892279

21902280
if (mayAccessVMEMThroughFlat(Inst)) {
21912281
++FlatASCount;
2282+
IsVMEMAccess = true;
21922283
ScoreBrackets->updateByEvent(TII, TRI, MRI, getVmemWaitEventType(Inst),
21932284
Inst);
21942285
}
@@ -2208,6 +2299,7 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
22082299
ScoreBrackets->setPendingFlat();
22092300
} else if (SIInstrInfo::isVMEM(Inst) &&
22102301
!llvm::AMDGPU::getMUBUFIsBufferInv(Inst.getOpcode())) {
2302+
IsVMEMAccess = true;
22112303
ScoreBrackets->updateByEvent(TII, TRI, MRI, getVmemWaitEventType(Inst),
22122304
Inst);
22132305

@@ -2216,6 +2308,7 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
22162308
ScoreBrackets->updateByEvent(TII, TRI, MRI, VMW_GPR_LOCK, Inst);
22172309
}
22182310
} else if (TII->isSMRD(Inst)) {
2311+
IsSMEMAccess = true;
22192312
ScoreBrackets->updateByEvent(TII, TRI, MRI, SMEM_ACCESS, Inst);
22202313
} else if (Inst.isCall()) {
22212314
if (callWaitsOnFunctionReturn(Inst)) {
@@ -2258,6 +2351,15 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
22582351
break;
22592352
}
22602353
}
2354+
2355+
if (!hasXcnt())
2356+
return;
2357+
2358+
if (IsVMEMAccess)
2359+
ScoreBrackets->updateByEvent(TII, TRI, MRI, VMEM_GROUP, Inst);
2360+
2361+
if (IsSMEMAccess)
2362+
ScoreBrackets->updateByEvent(TII, TRI, MRI, SMEM_GROUP, Inst);
22612363
}
22622364

22632365
bool WaitcntBrackets::mergeScore(const MergeInfo &M, unsigned &Score,
@@ -2311,9 +2413,11 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
23112413
for (int J = 0; J <= VgprUB; J++)
23122414
StrictDom |= mergeScore(M, VgprScores[T][J], Other.VgprScores[T][J]);
23132415

2314-
if (T == SmemAccessCounter) {
2416+
if (isSmemCounter(T)) {
2417+
unsigned Idx = getSgprScoresIdx(T);
23152418
for (int J = 0; J <= SgprUB; J++)
2316-
StrictDom |= mergeScore(M, SgprScores[J], Other.SgprScores[J]);
2419+
StrictDom |=
2420+
mergeScore(M, SgprScores[Idx][J], Other.SgprScores[Idx][J]);
23172421
}
23182422
}
23192423

@@ -2651,6 +2755,7 @@ bool SIInsertWaitcnts::run(MachineFunction &MF) {
26512755
Limits.SamplecntMax = AMDGPU::getSamplecntBitMask(IV);
26522756
Limits.BvhcntMax = AMDGPU::getBvhcntBitMask(IV);
26532757
Limits.KmcntMax = AMDGPU::getKmcntBitMask(IV);
2758+
Limits.XcntMax = AMDGPU::getXcntBitMask(IV);
26542759

26552760
[[maybe_unused]] unsigned NumVGPRsMax =
26562761
ST->getAddressableNumVGPRs(MFI->getDynamicVGPRBlockSize());
@@ -2679,7 +2784,7 @@ bool SIInsertWaitcnts::run(MachineFunction &MF) {
26792784
BuildMI(EntryBB, I, DebugLoc(), TII->get(AMDGPU::S_WAIT_LOADCNT_DSCNT))
26802785
.addImm(0);
26812786
for (auto CT : inst_counter_types(NUM_EXTENDED_INST_CNTS)) {
2682-
if (CT == LOAD_CNT || CT == DS_CNT || CT == STORE_CNT)
2787+
if (CT == LOAD_CNT || CT == DS_CNT || CT == STORE_CNT || CT == X_CNT)
26832788
continue;
26842789

26852790
if (!ST->hasImageInsts() &&

0 commit comments

Comments
 (0)