[AMDGPU][GFX1250] Insert S_WAIT_XCNT for SMEM and VMEM load-stores #145566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 4 commits · Jun 25, 2025

137 changes: 121 additions & 16 deletions llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -73,6 +73,7 @@ enum InstCounterType {
SAMPLE_CNT = NUM_NORMAL_INST_CNTS, // gfx12+ only.
BVH_CNT, // gfx12+ only.
KM_CNT, // gfx12+ only.
X_CNT, // gfx1250.
NUM_EXTENDED_INST_CNTS,
NUM_INST_CNTS = NUM_EXTENDED_INST_CNTS
};
@@ -102,6 +103,7 @@ struct HardwareLimits {
unsigned SamplecntMax; // gfx12+ only.
unsigned BvhcntMax; // gfx12+ only.
unsigned KmcntMax; // gfx12+ only.
unsigned XcntMax; // gfx1250.
};

#define AMDGPU_DECLARE_WAIT_EVENTS(DECL) \
@@ -111,10 +113,12 @@ struct HardwareLimits {
DECL(VMEM_BVH_READ_ACCESS) /* vmem BVH read (gfx12+ only) */ \
DECL(VMEM_WRITE_ACCESS) /* vmem write that is not scratch */ \
DECL(SCRATCH_WRITE_ACCESS) /* vmem write that may be scratch */ \
DECL(VMEM_GROUP) /* vmem group */ \
DECL(LDS_ACCESS) /* lds read & write */ \
DECL(GDS_ACCESS) /* gds read & write */ \
DECL(SQ_MESSAGE) /* send message */ \
DECL(SMEM_ACCESS) /* scalar-memory read & write */ \
DECL(SMEM_GROUP) /* scalar-memory group */ \
DECL(EXP_GPR_LOCK) /* export holding on its data src */ \
DECL(GDS_GPR_LOCK) /* GDS holding on its data and addr src */ \
DECL(EXP_POS_ACCESS) /* write to export position */ \
@@ -178,7 +182,7 @@ enum VmemType {
static const unsigned instrsForExtendedCounterTypes[NUM_EXTENDED_INST_CNTS] = {
AMDGPU::S_WAIT_LOADCNT, AMDGPU::S_WAIT_DSCNT, AMDGPU::S_WAIT_EXPCNT,
AMDGPU::S_WAIT_STORECNT, AMDGPU::S_WAIT_SAMPLECNT, AMDGPU::S_WAIT_BVHCNT,
AMDGPU::S_WAIT_KMCNT};
AMDGPU::S_WAIT_KMCNT, AMDGPU::S_WAIT_XCNT};

static bool updateVMCntOnly(const MachineInstr &Inst) {
return (SIInstrInfo::isVMEM(Inst) && !SIInstrInfo::isFLAT(Inst)) ||
@@ -223,6 +227,8 @@ unsigned &getCounterRef(AMDGPU::Waitcnt &Wait, InstCounterType T) {
return Wait.BvhCnt;
case KM_CNT:
return Wait.KmCnt;
case X_CNT:
return Wait.XCnt;
default:
llvm_unreachable("bad InstCounterType");
}
@@ -283,12 +289,27 @@ class WaitcntBrackets {
return Limits.BvhcntMax;
case KM_CNT:
return Limits.KmcntMax;
case X_CNT:
return Limits.XcntMax;
default:
break;
}
return 0;
}

bool isSmemCounter(InstCounterType T) const {
return T == SmemAccessCounter || T == X_CNT;
}

unsigned getSgprScoresIdx(InstCounterType T) const {
if (T == SmemAccessCounter)
return 0;
if (T == X_CNT)
return 1;

llvm_unreachable("Invalid SMEM counter");
}

unsigned getScoreLB(InstCounterType T) const {
assert(T < NUM_INST_CNTS);
return ScoreLBs[T];
@@ -307,8 +328,8 @@ class WaitcntBrackets {
if (GprNo < NUM_ALL_VGPRS) {
return VgprScores[T][GprNo];
}
assert(T == SmemAccessCounter);
return SgprScores[GprNo - NUM_ALL_VGPRS];
assert(isSmemCounter(T));
return SgprScores[getSgprScoresIdx(T)][GprNo - NUM_ALL_VGPRS];
}

bool merge(const WaitcntBrackets &Other);
@@ -331,6 +352,7 @@ class WaitcntBrackets {

void applyWaitcnt(const AMDGPU::Waitcnt &Wait);
void applyWaitcnt(InstCounterType T, unsigned Count);
void applyXcnt(const AMDGPU::Waitcnt &Wait);
void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI, WaitEventType E,
MachineInstr &MI);
@@ -462,9 +484,11 @@ class WaitcntBrackets {
int VgprUB = -1;
int SgprUB = -1;
unsigned VgprScores[NUM_INST_CNTS][NUM_ALL_VGPRS] = {{0}};
// Wait cnt scores for every sgpr, only DS_CNT (corresponding to LGKMcnt
// pre-gfx12) or KM_CNT (gfx12+ only) are relevant.
unsigned SgprScores[SQ_MAX_PGM_SGPRS] = {0};
// Wait cnt scores for every sgpr; only DS_CNT (corresponding to LGKMcnt
// pre-gfx12) or KM_CNT (gfx12+ only) and X_CNT (gfx1250) are relevant.
// Row 0 represents the score for either DS_CNT or KM_CNT and row 1 keeps the
// X_CNT score.
unsigned SgprScores[2][SQ_MAX_PGM_SGPRS] = {{0}};
// Bitmask of the VmemTypes of VMEM instructions that might have a pending
// write to each vgpr.
unsigned char VgprVmemTypes[NUM_ALL_VGPRS] = {0};
@@ -572,6 +596,7 @@ class WaitcntGeneratorPreGFX12 : public WaitcntGenerator {
eventMask({VMEM_WRITE_ACCESS, SCRATCH_WRITE_ACCESS}),
0,
0,
0,
0};

return WaitEventMaskForInstPreGFX12;
@@ -607,7 +632,8 @@ class WaitcntGeneratorGFX12Plus : public WaitcntGenerator {
eventMask({VMEM_WRITE_ACCESS, SCRATCH_WRITE_ACCESS}),
eventMask({VMEM_SAMPLER_READ_ACCESS}),
eventMask({VMEM_BVH_READ_ACCESS}),
eventMask({SMEM_ACCESS, SQ_MESSAGE})};
eventMask({SMEM_ACCESS, SQ_MESSAGE}),
eventMask({VMEM_GROUP, SMEM_GROUP})};

return WaitEventMaskForInstGFX12Plus;
}
@@ -743,9 +769,12 @@ class SIInsertWaitcnts {
return VmemReadMapping[getVmemType(Inst)];
}

bool hasXcnt() const { return ST->hasWaitXCnt(); }

bool mayAccessVMEMThroughFlat(const MachineInstr &MI) const;
bool mayAccessLDSThroughFlat(const MachineInstr &MI) const;
bool mayAccessScratchThroughFlat(const MachineInstr &MI) const;
bool isVmemAccess(const MachineInstr &MI) const;
bool generateWaitcntInstBefore(MachineInstr &MI,
WaitcntBrackets &ScoreBrackets,
MachineInstr *OldWaitcntInstr,
@@ -837,9 +866,9 @@ void WaitcntBrackets::setScoreByInterval(RegInterval Interval,
VgprUB = std::max(VgprUB, RegNo);
VgprScores[CntTy][RegNo] = Score;
} else {
assert(CntTy == SmemAccessCounter);
assert(isSmemCounter(CntTy));
SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS);
SgprScores[RegNo - NUM_ALL_VGPRS] = Score;
SgprScores[getSgprScoresIdx(CntTy)][RegNo - NUM_ALL_VGPRS] = Score;
}
}
}
@@ -976,6 +1005,13 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
}
}
} else if (T == X_CNT) {
for (const MachineOperand &Op : Inst.all_uses()) {
RegInterval Interval = getRegInterval(&Inst, MRI, TRI, Op);
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
setRegScore(RegNo, T, CurrScore);
}
}
} else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ {
// Match the score to the destination registers.
//
@@ -1080,6 +1116,9 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
case KM_CNT:
OS << " KM_CNT(" << SR << "): ";
break;
case X_CNT:
OS << " X_CNT(" << SR << "): ";
break;
default:
OS << " UNKNOWN(" << SR << "): ";
break;
@@ -1100,8 +1139,8 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
OS << RelScore << ":ds ";
}
}
// Also need to print sgpr scores for lgkm_cnt.
if (T == SmemAccessCounter) {
// Also need to print sgpr scores for lgkm_cnt or xcnt.
if (isSmemCounter(T)) {
for (int J = 0; J <= SgprUB; J++) {
unsigned RegScore = getRegScore(J + NUM_ALL_VGPRS, T);
if (RegScore <= LB)
Expand Down Expand Up @@ -1140,6 +1179,7 @@ void WaitcntBrackets::simplifyWaitcnt(AMDGPU::Waitcnt &Wait) const {
simplifyWaitcnt(SAMPLE_CNT, Wait.SampleCnt);
simplifyWaitcnt(BVH_CNT, Wait.BvhCnt);
simplifyWaitcnt(KM_CNT, Wait.KmCnt);
simplifyWaitcnt(X_CNT, Wait.XCnt);
}

void WaitcntBrackets::simplifyWaitcnt(InstCounterType T,
@@ -1191,6 +1231,7 @@ void WaitcntBrackets::applyWaitcnt(const AMDGPU::Waitcnt &Wait) {
applyWaitcnt(SAMPLE_CNT, Wait.SampleCnt);
applyWaitcnt(BVH_CNT, Wait.BvhCnt);
applyWaitcnt(KM_CNT, Wait.KmCnt);
applyXcnt(Wait);
}

void WaitcntBrackets::applyWaitcnt(InstCounterType T, unsigned Count) {
@@ -1207,11 +1248,29 @@ void WaitcntBrackets::applyWaitcnt(InstCounterType T, unsigned Count) {
}
}

void WaitcntBrackets::applyXcnt(const AMDGPU::Waitcnt &Wait) {
// Wait on XCNT is redundant if we are already waiting for a load to complete.
// SMEM can return out of order, so only omit XCNT wait if we are waiting till
// zero.
if (Wait.KmCnt == 0 && hasPendingEvent(SMEM_GROUP))
return applyWaitcnt(X_CNT, 0);
Comment on lines +1255 to +1256 (Contributor):

I don't understand the logic here. What if there is also a pending VMEM? Then it is not safe to assume that waiting for SMEM to complete will also imply a wait for XCNT. See #145681 for examples that I think are miscompiled because of this.

Maybe you intended that there would never be pending SMEM_GROUP and VMEM_GROUP at the same time?
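
To make the scenario above concrete, here is a minimal standalone C++ sketch of the shortcut being questioned. The struct names and flags are simplified stand-ins for the pass's WaitcntBrackets state, not the LLVM code itself:

```cpp
// Standalone sketch (assumed names, not the LLVM sources): models the
// applyXcnt shortcut with both SMEM_GROUP and VMEM_GROUP pending.
#include <algorithm>
#include <cstdio>

constexpr unsigned NoWait = ~0u;

struct Waitcnt {
  unsigned LoadCnt = NoWait, KmCnt = NoWait, XCnt = NoWait;
};

struct Brackets {
  bool PendingSmemGroup = false, PendingVmemGroup = false, PendingStore = false;
  unsigned XcntOutstanding = 0; // transactions the hardware still counts

  unsigned appliedXcnt(const Waitcnt &W) const {
    // Mirrors applyXcnt: a kmcnt-0 wait with a pending SMEM_GROUP is treated
    // as covering XCNT entirely...
    if (W.KmCnt == 0 && PendingSmemGroup)
      return 0;
    // ...and a load wait covers VMEM-group transactions when no store pends.
    if (W.LoadCnt != NoWait && PendingVmemGroup && !PendingStore)
      return std::min(W.XCnt, W.LoadCnt);
    return W.XCnt;
  }
};

int main() {
  // The questioned case: SMEM and VMEM groups pending at the same time.
  Brackets B{/*Smem=*/true, /*Vmem=*/true, /*Store=*/false, /*Xcnt=*/1};
  Waitcnt W;
  W.KmCnt = 0; // s_wait_kmcnt 0 says nothing about the VMEM transaction
  std::printf("XCNT treated as %u with %u VMEM transaction still in flight\n",
              B.appliedXcnt(W), B.XcntOutstanding);
}
```

Under these assumptions the program prints `XCNT treated as 0 with 1 VMEM transaction still in flight`, which is exactly the concern raised above.
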
// If we have a pending store we cannot optimize XCnt because we do not wait
// for stores. VMEM loads return in order, so if we only have loads XCnt is
// decremented to the same number as LOADCnt.
if (Wait.LoadCnt != ~0u && hasPendingEvent(VMEM_GROUP) &&
!hasPendingEvent(STORE_CNT))
return applyWaitcnt(X_CNT, std::min(Wait.XCnt, Wait.LoadCnt));

applyWaitcnt(X_CNT, Wait.XCnt);
}

// Where there are multiple types of event in the bracket of a counter,
// the decrement may go out of order.
bool WaitcntBrackets::counterOutOfOrder(InstCounterType T) const {
// Scalar memory read always can go out of order.
if (T == SmemAccessCounter && hasPendingEvent(SMEM_ACCESS))
if ((T == SmemAccessCounter && hasPendingEvent(SMEM_ACCESS)) ||
(T == X_CNT && hasPendingEvent(SMEM_GROUP)))
return true;
return hasMixedPendingEvents(T);
}
@@ -1263,6 +1322,8 @@ static std::optional<InstCounterType> counterTypeForInstr(unsigned Opcode) {
return DS_CNT;
case AMDGPU::S_WAIT_KMCNT:
return KM_CNT;
case AMDGPU::S_WAIT_XCNT:
return X_CNT;
default:
return {};
}
@@ -1427,7 +1488,8 @@ WaitcntGeneratorPreGFX12::getAllZeroWaitcnt(bool IncludeVSCnt) const {

AMDGPU::Waitcnt
WaitcntGeneratorGFX12Plus::getAllZeroWaitcnt(bool IncludeVSCnt) const {
return AMDGPU::Waitcnt(0, 0, 0, IncludeVSCnt ? 0 : ~0u, 0, 0, 0);
return AMDGPU::Waitcnt(0, 0, 0, IncludeVSCnt ? 0 : ~0u, 0, 0, 0,
~0u /* XCNT */);
}

/// Combine consecutive S_WAIT_*CNT instructions that precede \p It and
@@ -1909,13 +1971,17 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
ScoreBrackets.determineWait(BVH_CNT, Interval, Wait);
ScoreBrackets.clearVgprVmemTypes(Interval);
}

if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
ScoreBrackets.determineWait(EXP_CNT, Interval, Wait);
}
ScoreBrackets.determineWait(DS_CNT, Interval, Wait);
} else {
ScoreBrackets.determineWait(SmemAccessCounter, Interval, Wait);
}

if (hasXcnt() && Op.isDef())
ScoreBrackets.determineWait(X_CNT, Interval, Wait);
}
}
}
Expand Down Expand Up @@ -1958,6 +2024,8 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
Wait.BvhCnt = 0;
if (ForceEmitWaitcnt[KM_CNT])
Wait.KmCnt = 0;
if (ForceEmitWaitcnt[X_CNT])
Wait.XCnt = 0;

if (FlushVmCnt) {
if (ScoreBrackets.hasPendingEvent(LOAD_CNT))
Expand Down Expand Up @@ -2007,6 +2075,21 @@ bool SIInsertWaitcnts::generateWaitcnt(AMDGPU::Waitcnt Wait,
<< "Update Instr: " << *It);
}

// XCnt may already be consumed by a KMCNT or LOADCNT wait.
if (Wait.KmCnt == 0 && Wait.XCnt != ~0u &&
!ScoreBrackets.hasPendingEvent(SMEM_GROUP))
Wait.XCnt = ~0u;

if (Wait.LoadCnt == 0 && Wait.XCnt != ~0u &&
!ScoreBrackets.hasPendingEvent(VMEM_GROUP))
Wait.XCnt = ~0u;

// Since the translation of VMEM addresses occurs in order, we can skip the
// XCnt if the current instruction is of VMEM type and has a memory
// dependency with another VMEM instruction in flight.
if (Wait.XCnt != ~0u && isVmemAccess(*It))
Wait.XCnt = ~0u;

if (WCG->createNewWaitcnt(Block, It, Wait))
Modified = true;

Expand Down Expand Up @@ -2096,6 +2179,11 @@ bool SIInsertWaitcnts::mayAccessScratchThroughFlat(
});
}

bool SIInsertWaitcnts::isVmemAccess(const MachineInstr &MI) const {
return (TII->isFLAT(MI) && mayAccessVMEMThroughFlat(MI)) ||
(TII->isVMEM(MI) && !AMDGPU::getMUBUFIsBufferInv(MI.getOpcode()));
Comment on lines +2183 to +2184 (Contributor):

isVMEM now includes isFLAT, so I think this simplifies to:

Suggested change:
-  return (TII->isFLAT(MI) && mayAccessVMEMThroughFlat(MI)) ||
-         (TII->isVMEM(MI) && !AMDGPU::getMUBUFIsBufferInv(MI.getOpcode()));
+  return TII->isVMEM(MI) && !AMDGPU::getMUBUFIsBufferInv(MI.getOpcode());

which is maybe not what you want because it has lost the mayAccessVMEMThroughFlat check. See #137148. Cc @ro-i.
}
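
For illustration, a small standalone sketch of why the two forms coincide once isVMEM covers FLAT. The boolean flags are assumed stand-ins for the TII queries, not the LLVM code:

```cpp
// Standalone sketch (simplified stand-ins for the TII/MI queries) comparing
// the predicate as written with the suggested simplification.
#include <cstdio>

struct Inst {
  bool IsFlat;       // stand-in for TII->isFLAT(MI)
  bool IsVmem;       // stand-in for TII->isVMEM(MI); now also true for FLAT
  bool MayReachVmem; // stand-in for mayAccessVMEMThroughFlat(MI)
  bool IsBufferInv;  // stand-in for AMDGPU::getMUBUFIsBufferInv(...)
};

static bool originalForm(const Inst &MI) {
  return (MI.IsFlat && MI.MayReachVmem) || (MI.IsVmem && !MI.IsBufferInv);
}

static bool simplifiedForm(const Inst &MI) {
  return MI.IsVmem && !MI.IsBufferInv;
}

int main() {
  // A flat access that provably never reaches VMEM (e.g. LDS-only):
  // mayAccessVMEMThroughFlat is false, yet both forms still classify it as
  // a VMEM access because the second clause fires once isVMEM covers FLAT.
  Inst FlatLdsOnly{/*IsFlat=*/true, /*IsVmem=*/true,
                   /*MayReachVmem=*/false, /*IsBufferInv=*/false};
  std::printf("original=%d simplified=%d\n", originalForm(FlatLdsOnly),
              simplifiedForm(FlatLdsOnly));
}
```

Both predicates print 1 for this instruction, which shows that the mayAccessVMEMThroughFlat clause no longer filters anything in the form as written.
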

static bool isGFX12CacheInvOrWBInst(MachineInstr &Inst) {
auto Opc = Inst.getOpcode();
return Opc == AMDGPU::GLOBAL_INV || Opc == AMDGPU::GLOBAL_WB ||
@@ -2167,6 +2255,8 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
// bracket and the destination operand scores.
// TODO: Use the (TSFlags & SIInstrFlags::DS_CNT) property everywhere.

bool IsVMEMAccess = false;
bool IsSMEMAccess = false;
if (TII->isDS(Inst) && TII->usesLGKM_CNT(Inst)) {
if (TII->isAlwaysGDS(Inst.getOpcode()) ||
TII->hasModifiersSet(Inst, AMDGPU::OpName::gds)) {
@@ -2189,6 +2279,7 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,

if (mayAccessVMEMThroughFlat(Inst)) {
++FlatASCount;
IsVMEMAccess = true;
ScoreBrackets->updateByEvent(TII, TRI, MRI, getVmemWaitEventType(Inst),
Inst);
}
@@ -2208,6 +2299,7 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
ScoreBrackets->setPendingFlat();
} else if (SIInstrInfo::isVMEM(Inst) &&
!llvm::AMDGPU::getMUBUFIsBufferInv(Inst.getOpcode())) {
IsVMEMAccess = true;
ScoreBrackets->updateByEvent(TII, TRI, MRI, getVmemWaitEventType(Inst),
Inst);

@@ -2216,6 +2308,7 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
ScoreBrackets->updateByEvent(TII, TRI, MRI, VMW_GPR_LOCK, Inst);
}
} else if (TII->isSMRD(Inst)) {
IsSMEMAccess = true;
ScoreBrackets->updateByEvent(TII, TRI, MRI, SMEM_ACCESS, Inst);
} else if (Inst.isCall()) {
if (callWaitsOnFunctionReturn(Inst)) {
@@ -2258,6 +2351,15 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst,
break;
}
}

if (!hasXcnt())
return;

if (IsVMEMAccess)
ScoreBrackets->updateByEvent(TII, TRI, MRI, VMEM_GROUP, Inst);

if (IsSMEMAccess)
ScoreBrackets->updateByEvent(TII, TRI, MRI, SMEM_GROUP, Inst);
}

bool WaitcntBrackets::mergeScore(const MergeInfo &M, unsigned &Score,
@@ -2311,9 +2413,11 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
for (int J = 0; J <= VgprUB; J++)
StrictDom |= mergeScore(M, VgprScores[T][J], Other.VgprScores[T][J]);

if (T == SmemAccessCounter) {
if (isSmemCounter(T)) {
unsigned Idx = getSgprScoresIdx(T);
for (int J = 0; J <= SgprUB; J++)
StrictDom |= mergeScore(M, SgprScores[J], Other.SgprScores[J]);
StrictDom |=
mergeScore(M, SgprScores[Idx][J], Other.SgprScores[Idx][J]);
}
}

@@ -2651,6 +2755,7 @@ bool SIInsertWaitcnts::run(MachineFunction &MF) {
Limits.SamplecntMax = AMDGPU::getSamplecntBitMask(IV);
Limits.BvhcntMax = AMDGPU::getBvhcntBitMask(IV);
Limits.KmcntMax = AMDGPU::getKmcntBitMask(IV);
Limits.XcntMax = AMDGPU::getXcntBitMask(IV);

[[maybe_unused]] unsigned NumVGPRsMax =
ST->getAddressableNumVGPRs(MFI->getDynamicVGPRBlockSize());
@@ -2679,7 +2784,7 @@ bool SIInsertWaitcnts::run(MachineFunction &MF) {
BuildMI(EntryBB, I, DebugLoc(), TII->get(AMDGPU::S_WAIT_LOADCNT_DSCNT))
.addImm(0);
for (auto CT : inst_counter_types(NUM_EXTENDED_INST_CNTS)) {
if (CT == LOAD_CNT || CT == DS_CNT || CT == STORE_CNT)
if (CT == LOAD_CNT || CT == DS_CNT || CT == STORE_CNT || CT == X_CNT)
continue;

if (!ST->hasImageInsts() &&