Skip to content

Commit 021def6

Browse files
authored
[AMDGPU] Use alias info to relax waitcounts for LDS DMA (#74537)
LDS DMA loads increase VMCNT, and a load from the LDS location stored to must wait on this counter so that it only reads memory after it is written. The wait count insertion pass does not track memory dependencies; it tracks register dependencies. To model the LDS dependency, a pseudo register is used in the scoreboard, acting as if LDS DMA writes it and the LDS load reads it. This patch adds 8 more pseudo registers to use for independent LDS locations if we can prove they are disjoint using alias analysis. Fixes: SWDEV-433427
1 parent c3cc09b commit 021def6

File tree

2 files changed

+171
-21
lines changed

2 files changed

+171
-21
lines changed

llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "llvm/ADT/MapVector.h"
3232
#include "llvm/ADT/PostOrderIterator.h"
3333
#include "llvm/ADT/Sequence.h"
34+
#include "llvm/Analysis/AliasAnalysis.h"
3435
#include "llvm/CodeGen/MachineLoopInfo.h"
3536
#include "llvm/CodeGen/MachinePostDominators.h"
3637
#include "llvm/InitializePasses.h"
@@ -121,8 +122,13 @@ enum RegisterMapping {
121122
SQ_MAX_PGM_VGPRS = 512, // Maximum programmable VGPRs across all targets.
122123
AGPR_OFFSET = 256, // Maximum programmable ArchVGPRs across all targets.
123124
SQ_MAX_PGM_SGPRS = 256, // Maximum programmable SGPRs across all targets.
124-
NUM_EXTRA_VGPRS = 1, // A reserved slot for DS.
125-
EXTRA_VGPR_LDS = 0, // An artificial register to track LDS writes.
125+
NUM_EXTRA_VGPRS = 9, // Reserved slots for DS.
126+
// Artificial register slots to track LDS writes into specific LDS locations
127+
// if a location is known. When slots are exhausted or location is
128+
// unknown use the first slot. The first slot is also always updated in
129+
// addition to known location's slot to properly generate waits if dependent
130+
// instruction's location is unknown.
131+
EXTRA_VGPR_LDS = 0,
126132
NUM_ALL_VGPRS = SQ_MAX_PGM_VGPRS + NUM_EXTRA_VGPRS, // Where SGPR starts.
127133
};
128134

@@ -297,6 +303,10 @@ class WaitcntBrackets {
297303
PendingEvents |= WaitEventMaskForInst[VS_CNT];
298304
}
299305

306+
ArrayRef<const MachineInstr *> getLDSDMAStores() const {
307+
return LDSDMAStores;
308+
}
309+
300310
void print(raw_ostream &);
301311
void dump() { print(dbgs()); }
302312

@@ -359,6 +369,9 @@ class WaitcntBrackets {
359369
// Bitmask of the VmemTypes of VMEM instructions that might have a pending
360370
// write to each vgpr.
361371
unsigned char VgprVmemTypes[NUM_ALL_VGPRS] = {0};
372+
// Store representative LDS DMA operations. The only useful info here is
373+
// alias info. One store is kept per unique AAInfo.
374+
SmallVector<const MachineInstr *, NUM_EXTRA_VGPRS - 1> LDSDMAStores;
362375
};
363376

364377
class SIInsertWaitcnts : public MachineFunctionPass {
@@ -373,6 +386,7 @@ class SIInsertWaitcnts : public MachineFunctionPass {
373386
DenseMap<MachineBasicBlock *, bool> PreheadersToFlush;
374387
MachineLoopInfo *MLI;
375388
MachinePostDominatorTree *PDT;
389+
AliasAnalysis *AA = nullptr;
376390

377391
struct BlockInfo {
378392
std::unique_ptr<WaitcntBrackets> Incoming;
@@ -415,6 +429,8 @@ class SIInsertWaitcnts : public MachineFunctionPass {
415429
AU.setPreservesCFG();
416430
AU.addRequired<MachineLoopInfo>();
417431
AU.addRequired<MachinePostDominatorTree>();
432+
AU.addUsedIfAvailable<AAResultsWrapperPass>();
433+
AU.addPreserved<AAResultsWrapperPass>();
418434
MachineFunctionPass::getAnalysisUsage(AU);
419435
}
420436

@@ -707,7 +723,40 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
707723
(TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
708724
// MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS
709725
// written can be accessed. A load from LDS to VMEM does not need a wait.
710-
setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
726+
unsigned Slot = 0;
727+
for (const auto *MemOp : Inst.memoperands()) {
728+
if (!MemOp->isStore() ||
729+
MemOp->getAddrSpace() != AMDGPUAS::LOCAL_ADDRESS)
730+
continue;
731+
// Comparing just AA info does not guarantee memoperands are equal
732+
// in general, but this is so for LDS DMA in practice.
733+
auto AAI = MemOp->getAAInfo();
734+
// Alias scope information gives a way to definitely identify an
735+
// original memory object and practically produced in the module LDS
736+
// lowering pass. If there is no scope available we will not be able
737+
// to disambiguate LDS aliasing as after the module lowering all LDS
738+
// is squashed into a single big object. Do not attempt to use one of
739+
// the limited LDSDMAStores for something we will not be able to use
740+
// anyway.
741+
if (!AAI || !AAI.Scope)
742+
break;
743+
for (unsigned I = 0, E = LDSDMAStores.size(); I != E && !Slot; ++I) {
744+
for (const auto *MemOp : LDSDMAStores[I]->memoperands()) {
745+
if (MemOp->isStore() && AAI == MemOp->getAAInfo()) {
746+
Slot = I + 1;
747+
break;
748+
}
749+
}
750+
}
751+
if (Slot || LDSDMAStores.size() == NUM_EXTRA_VGPRS - 1)
752+
break;
753+
LDSDMAStores.push_back(&Inst);
754+
Slot = LDSDMAStores.size();
755+
break;
756+
}
757+
setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS + Slot, T, CurrScore);
758+
if (Slot)
759+
setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
711760
}
712761
}
713762
}
@@ -1183,9 +1232,27 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
11831232
// No need to wait before load from VMEM to LDS.
11841233
if (TII->mayWriteLDSThroughDMA(MI))
11851234
continue;
1186-
unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
1235+
11871236
// VM_CNT is only relevant to vgpr or LDS.
1188-
ScoreBrackets.determineWait(VM_CNT, RegNo, Wait);
1237+
unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
1238+
bool FoundAliasingStore = false;
1239+
// Only objects with alias scope info were added to LDSDMAScopes array.
1240+
// In the absense of the scope info we will not be able to disambiguate
1241+
// aliasing here. There is no need to try searching for a corresponding
1242+
// store slot. This is conservatively correct because in that case we
1243+
// will produce a wait using the first (general) LDS DMA wait slot which
1244+
// will wait on all of them anyway.
1245+
if (Ptr && Memop->getAAInfo() && Memop->getAAInfo().Scope) {
1246+
const auto &LDSDMAStores = ScoreBrackets.getLDSDMAStores();
1247+
for (unsigned I = 0, E = LDSDMAStores.size(); I != E; ++I) {
1248+
if (MI.mayAlias(AA, *LDSDMAStores[I], true)) {
1249+
FoundAliasingStore = true;
1250+
ScoreBrackets.determineWait(VM_CNT, RegNo + I + 1, Wait);
1251+
}
1252+
}
1253+
}
1254+
if (!FoundAliasingStore)
1255+
ScoreBrackets.determineWait(VM_CNT, RegNo, Wait);
11891256
if (Memop->isStore()) {
11901257
ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
11911258
}
@@ -1834,6 +1901,8 @@ bool SIInsertWaitcnts::runOnMachineFunction(MachineFunction &MF) {
18341901
const SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
18351902
MLI = &getAnalysis<MachineLoopInfo>();
18361903
PDT = &getAnalysis<MachinePostDominatorTree>();
1904+
if (auto AAR = getAnalysisIfAvailable<AAResultsWrapperPass>())
1905+
AA = &AAR->getAAResults();
18371906

18381907
ForceEmitZeroWaitcnts = ForceEmitZeroFlag;
18391908
for (auto T : inst_counter_types())

llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll

Lines changed: 97 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@
33

44
@lds.0 = internal addrspace(3) global [64 x float] poison, align 16
55
@lds.1 = internal addrspace(3) global [64 x float] poison, align 16
6+
@lds.2 = internal addrspace(3) global [64 x float] poison, align 16
7+
@lds.3 = internal addrspace(3) global [64 x float] poison, align 16
8+
@lds.4 = internal addrspace(3) global [64 x float] poison, align 16
9+
@lds.5 = internal addrspace(3) global [64 x float] poison, align 16
10+
@lds.6 = internal addrspace(3) global [64 x float] poison, align 16
11+
@lds.7 = internal addrspace(3) global [64 x float] poison, align 16
12+
@lds.8 = internal addrspace(3) global [64 x float] poison, align 16
13+
@lds.9 = internal addrspace(3) global [64 x float] poison, align 16
614

715
declare void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) nocapture, i32 %size, i32 %voffset, i32 %soffset, i32 %offset, i32 %aux)
816
declare void @llvm.amdgcn.global.load.lds(ptr addrspace(1) nocapture %gptr, ptr addrspace(3) nocapture %lptr, i32 %size, i32 %offset, i32 %aux)
917

10-
; FIXME: vmcnt(0) is too strong, it shall use vmcnt(2) before the first
11-
; ds_read_b32 and vmcnt(0) before the second.
12-
1318
; GCN-LABEL: {{^}}buffer_load_lds_dword_2_arrays:
1419
; GCN-COUNT-4: buffer_load_dword
15-
; GCN: s_waitcnt vmcnt(0)
20+
; GCN: s_waitcnt vmcnt(2)
1621
; GCN: ds_read_b32
17-
18-
; FIXME:
19-
; GCN-NOT: s_waitcnt
20-
22+
; GCN: s_waitcnt vmcnt(0)
2123
; GCN: ds_read_b32
2224
define amdgpu_kernel void @buffer_load_lds_dword_2_arrays(<4 x i32> %rsrc, i32 %i1, i32 %i2, ptr addrspace(1) %out) {
2325
main_body:
@@ -43,15 +45,9 @@ main_body:
4345
; GCN-COUNT-4: global_load_dword
4446
; GFX9: s_waitcnt vmcnt(0)
4547
; GFX9-COUNT-2: ds_read_b32
46-
47-
; FIXME: can be vmcnt(2)
48-
49-
; GFX10: s_waitcnt vmcnt(0)
48+
; GFX10: s_waitcnt vmcnt(2)
5049
; GFX10: ds_read_b32
51-
52-
; FIXME:
53-
; GFX10-NOT: s_waitcnt
54-
50+
; GFX10: s_waitcnt vmcnt(0)
5551
; GFX10: ds_read_b32
5652
define amdgpu_kernel void @global_load_lds_dword_2_arrays(ptr addrspace(1) nocapture %gptr, i32 %i1, i32 %i2, ptr addrspace(1) %out) {
5753
main_body:
@@ -70,4 +66,89 @@ main_body:
7066
ret void
7167
}
7268

69+
; There are 8 pseudo registers defined to track LDS DMA dependencies.
70+
; When exhausted we default to vmcnt(0).
71+
72+
; GCN-LABEL: {{^}}buffer_load_lds_dword_10_arrays:
73+
; GCN-COUNT-10: buffer_load_dword
74+
; GCN: s_waitcnt vmcnt(8)
75+
; GCN: ds_read_b32
76+
; GCN: s_waitcnt vmcnt(7)
77+
; GCN: ds_read_b32
78+
; GCN: s_waitcnt vmcnt(6)
79+
; GCN: ds_read_b32
80+
; GCN: s_waitcnt vmcnt(5)
81+
; GCN: ds_read_b32
82+
; GCN: s_waitcnt vmcnt(4)
83+
; GCN: ds_read_b32
84+
; GCN: s_waitcnt vmcnt(3)
85+
; GCN: ds_read_b32
86+
; GCN: s_waitcnt vmcnt(2)
87+
; GCN-NOT: s_waitcnt vmcnt
88+
; GCN: ds_read_b32
89+
; GCN: s_waitcnt vmcnt(0)
90+
; GCN: ds_read_b32
91+
define amdgpu_kernel void @buffer_load_lds_dword_10_arrays(<4 x i32> %rsrc, i32 %i1, i32 %i2, i32 %i3, i32 %i4, i32 %i5, i32 %i6, i32 %i7, i32 %i8, i32 %i9, ptr addrspace(1) %out) {
92+
main_body:
93+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.0, i32 4, i32 0, i32 0, i32 0, i32 0)
94+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.1, i32 4, i32 0, i32 0, i32 0, i32 0)
95+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.2, i32 4, i32 0, i32 0, i32 0, i32 0)
96+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.3, i32 4, i32 0, i32 0, i32 0, i32 0)
97+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.4, i32 4, i32 0, i32 0, i32 0, i32 0)
98+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.5, i32 4, i32 0, i32 0, i32 0, i32 0)
99+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.6, i32 4, i32 0, i32 0, i32 0, i32 0)
100+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.7, i32 4, i32 0, i32 0, i32 0, i32 0)
101+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.8, i32 4, i32 0, i32 0, i32 0, i32 0)
102+
call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.9, i32 4, i32 0, i32 0, i32 0, i32 0)
103+
%gep.0 = getelementptr float, ptr addrspace(3) @lds.0, i32 %i1
104+
%gep.1 = getelementptr float, ptr addrspace(3) @lds.1, i32 %i2
105+
%gep.2 = getelementptr float, ptr addrspace(3) @lds.2, i32 %i2
106+
%gep.3 = getelementptr float, ptr addrspace(3) @lds.3, i32 %i2
107+
%gep.4 = getelementptr float, ptr addrspace(3) @lds.4, i32 %i2
108+
%gep.5 = getelementptr float, ptr addrspace(3) @lds.5, i32 %i2
109+
%gep.6 = getelementptr float, ptr addrspace(3) @lds.6, i32 %i2
110+
%gep.7 = getelementptr float, ptr addrspace(3) @lds.7, i32 %i2
111+
%gep.8 = getelementptr float, ptr addrspace(3) @lds.8, i32 %i2
112+
%gep.9 = getelementptr float, ptr addrspace(3) @lds.9, i32 %i2
113+
%val.0 = load float, ptr addrspace(3) %gep.0, align 4
114+
call void @llvm.amdgcn.wave.barrier()
115+
%val.1 = load float, ptr addrspace(3) %gep.1, align 4
116+
call void @llvm.amdgcn.wave.barrier()
117+
%val.2 = load float, ptr addrspace(3) %gep.2, align 4
118+
call void @llvm.amdgcn.wave.barrier()
119+
%val.3 = load float, ptr addrspace(3) %gep.3, align 4
120+
call void @llvm.amdgcn.wave.barrier()
121+
%val.4 = load float, ptr addrspace(3) %gep.4, align 4
122+
call void @llvm.amdgcn.wave.barrier()
123+
%val.5 = load float, ptr addrspace(3) %gep.5, align 4
124+
call void @llvm.amdgcn.wave.barrier()
125+
%val.6 = load float, ptr addrspace(3) %gep.6, align 4
126+
call void @llvm.amdgcn.wave.barrier()
127+
%val.7 = load float, ptr addrspace(3) %gep.7, align 4
128+
call void @llvm.amdgcn.wave.barrier()
129+
%val.8 = load float, ptr addrspace(3) %gep.8, align 4
130+
call void @llvm.amdgcn.wave.barrier()
131+
%val.9 = load float, ptr addrspace(3) %gep.9, align 4
132+
%out.gep.1 = getelementptr float, ptr addrspace(1) %out, i32 1
133+
%out.gep.2 = getelementptr float, ptr addrspace(1) %out, i32 2
134+
%out.gep.3 = getelementptr float, ptr addrspace(1) %out, i32 3
135+
%out.gep.4 = getelementptr float, ptr addrspace(1) %out, i32 4
136+
%out.gep.5 = getelementptr float, ptr addrspace(1) %out, i32 5
137+
%out.gep.6 = getelementptr float, ptr addrspace(1) %out, i32 6
138+
%out.gep.7 = getelementptr float, ptr addrspace(1) %out, i32 7
139+
%out.gep.8 = getelementptr float, ptr addrspace(1) %out, i32 8
140+
%out.gep.9 = getelementptr float, ptr addrspace(1) %out, i32 9
141+
store float %val.0, ptr addrspace(1) %out
142+
store float %val.1, ptr addrspace(1) %out.gep.1
143+
store float %val.2, ptr addrspace(1) %out.gep.2
144+
store float %val.3, ptr addrspace(1) %out.gep.3
145+
store float %val.4, ptr addrspace(1) %out.gep.4
146+
store float %val.5, ptr addrspace(1) %out.gep.5
147+
store float %val.6, ptr addrspace(1) %out.gep.6
148+
store float %val.7, ptr addrspace(1) %out.gep.7
149+
store float %val.8, ptr addrspace(1) %out.gep.8
150+
store float %val.9, ptr addrspace(1) %out.gep.9
151+
ret void
152+
}
153+
73154
declare void @llvm.amdgcn.wave.barrier()

0 commit comments

Comments
 (0)