
Commit 04b50f7

rampitec authored and searlmc1 committed
[AMDGPU] Use alias info to relax waitcounts for LDS DMA (llvm#74537)
LDS DMA loads increase VMCNT, and a subsequent load from the LDS location just stored to must wait on this counter so that it only reads the memory after it has been written. The waitcount insertion pass does not track memory dependencies, only register dependencies. To model the LDS dependency, a pseudo register is used in the scoreboard, as if the LDS DMA writes it and the LDS load reads it. This patch adds 8 more pseudo registers to use for independent LDS locations when alias analysis can prove they are disjoint.

Fixes: SWDEV-433427

Change-Id: I21e5931c3db0676a8778489b07f490976c6ef45e
1 parent 024ad3e commit 04b50f7

File tree (2 files changed: +171 / -21 lines)

  llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
  llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll
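To make the scheme described in the commit message concrete, here is a small standalone C++ sketch. It is an illustration under assumptions, not code from the patch: LDSSlotTracker and its integer "location key" stand in for the pass's use of alias-scope metadata to identify distinct LDS objects, with eight dedicated slots plus a shared catch-all slot 0. A read whose address cannot be tied to a tracked location ends up on slot 0, which every DMA store also updates.

#include <array>
#include <cstdint>
#include <optional>

// Hypothetical helper (not part of the patch): maps a known LDS "location
// key" to one of eight dedicated slots, and falls back to the shared slot 0
// when the location is unknown or all dedicated slots are taken.
class LDSSlotTracker {
  static constexpr unsigned NumSlots = 9; // 1 shared + 8 dedicated, as in the patch.
  std::array<std::optional<std::uint64_t>, NumSlots - 1> Keys;

public:
  unsigned getSlot(std::optional<std::uint64_t> LocationKey) {
    if (!LocationKey)
      return 0; // Unknown location: use the catch-all slot.
    for (unsigned I = 0; I < Keys.size(); ++I) {
      if (Keys[I] == LocationKey)
        return I + 1; // Location already has a dedicated slot.
      if (!Keys[I]) {
        Keys[I] = LocationKey; // Allocate a fresh dedicated slot.
        return I + 1;
      }
    }
    return 0; // Dedicated slots exhausted: conservatively share slot 0.
  }
};

int main() {
  LDSSlotTracker T;
  unsigned A = T.getSlot(0xA);          // first location   -> slot 1
  unsigned B = T.getSlot(0xB);          // second location  -> slot 2
  unsigned A2 = T.getSlot(0xA);         // same location    -> slot 1 again
  unsigned U = T.getSlot(std::nullopt); // unknown location -> shared slot 0
  return (A == 1 && B == 2 && A2 == 1 && U == 0) ? 0 : 1;
}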

llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

Lines changed: 74 additions & 5 deletions
@@ -31,6 +31,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/InitializePasses.h"

@@ -121,8 +122,13 @@ enum RegisterMapping {
   SQ_MAX_PGM_VGPRS = 512, // Maximum programmable VGPRs across all targets.
   AGPR_OFFSET = 256,      // Maximum programmable ArchVGPRs across all targets.
   SQ_MAX_PGM_SGPRS = 256, // Maximum programmable SGPRs across all targets.
-  NUM_EXTRA_VGPRS = 1,    // A reserved slot for DS.
-  EXTRA_VGPR_LDS = 0,     // An artificial register to track LDS writes.
+  NUM_EXTRA_VGPRS = 9,    // Reserved slots for DS.
+  // Artificial register slots to track LDS writes into specific LDS locations
+  // if a location is known. When slots are exhausted or the location is
+  // unknown, use the first slot. The first slot is also always updated, in
+  // addition to the known location's slot, to properly generate waits if a
+  // dependent instruction's location is unknown.
+  EXTRA_VGPR_LDS = 0,
   NUM_ALL_VGPRS = SQ_MAX_PGM_VGPRS + NUM_EXTRA_VGPRS, // Where SGPR starts.
 };
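For orientation, a tiny compile-time sketch of how a slot index turns into a scoreboard register number. The constants are copied from the enum above; the helper name ldsDmaRegNo is made up for this example and does not exist in the pass.

// Constants copied from the RegisterMapping enum above.
constexpr unsigned SQ_MAX_PGM_VGPRS = 512;
constexpr unsigned NUM_EXTRA_VGPRS = 9;
constexpr unsigned EXTRA_VGPR_LDS = 0;
constexpr unsigned NUM_ALL_VGPRS = SQ_MAX_PGM_VGPRS + NUM_EXTRA_VGPRS;

// Hypothetical helper: scoreboard entry used for LDS DMA slot N, where N == 0
// is the shared slot and 1..8 are per-location slots.
constexpr unsigned ldsDmaRegNo(unsigned Slot) {
  return SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS + Slot;
}

static_assert(ldsDmaRegNo(0) == 512, "shared LDS DMA slot");
static_assert(ldsDmaRegNo(NUM_EXTRA_VGPRS - 1) == 520, "last dedicated slot");
static_assert(ldsDmaRegNo(NUM_EXTRA_VGPRS - 1) < NUM_ALL_VGPRS,
              "all LDS DMA slots sit below where the SGPR range starts");

int main() { return 0; }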

@@ -292,6 +298,10 @@ class WaitcntBrackets {
     VgprVmemTypes[GprNo] = 0;
   }

+  ArrayRef<const MachineInstr *> getLDSDMAStores() const {
+    return LDSDMAStores;
+  }
+
   void print(raw_ostream &);
   void dump() { print(dbgs()); }

@@ -354,6 +364,9 @@ class WaitcntBrackets {
   // Bitmask of the VmemTypes of VMEM instructions that might have a pending
   // write to each vgpr.
   unsigned char VgprVmemTypes[NUM_ALL_VGPRS] = {0};
+  // Store representative LDS DMA operations. The only useful info here is
+  // alias info. One store is kept per unique AAInfo.
+  SmallVector<const MachineInstr *, NUM_EXTRA_VGPRS - 1> LDSDMAStores;
 };

 class SIInsertWaitcnts : public MachineFunctionPass {

@@ -369,6 +382,7 @@ class SIInsertWaitcnts : public MachineFunctionPass {
   DenseMap<MachineBasicBlock *, bool> PreheadersToFlush;
   MachineLoopInfo *MLI;
   MachinePostDominatorTree *PDT;
+  AliasAnalysis *AA = nullptr;

   struct BlockInfo {
     std::unique_ptr<WaitcntBrackets> Incoming;

@@ -411,6 +425,8 @@ class SIInsertWaitcnts : public MachineFunctionPass {
     AU.setPreservesCFG();
     AU.addRequired<MachineLoopInfo>();
     AU.addRequired<MachinePostDominatorTree>();
+    AU.addUsedIfAvailable<AAResultsWrapperPass>();
+    AU.addPreserved<AAResultsWrapperPass>();
     MachineFunctionPass::getAnalysisUsage(AU);
   }

@@ -701,7 +717,40 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
               (TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
      // MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS
      // written can be accessed. A load from LDS to VMEM does not need a wait.
-      setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
+      unsigned Slot = 0;
+      for (const auto *MemOp : Inst.memoperands()) {
+        if (!MemOp->isStore() ||
+            MemOp->getAddrSpace() != AMDGPUAS::LOCAL_ADDRESS)
+          continue;
+        // Comparing just AA info does not guarantee memoperands are equal
+        // in general, but this is so for LDS DMA in practice.
+        auto AAI = MemOp->getAAInfo();
+        // Alias scope information gives a way to definitively identify an
+        // original memory object and is in practice produced by the module
+        // LDS lowering pass. If there is no scope available we will not be
+        // able to disambiguate LDS aliasing, since after module lowering all
+        // LDS is squashed into a single big object. Do not attempt to use one
+        // of the limited LDSDMAStores slots for something we will not be able
+        // to use anyway.
+        if (!AAI || !AAI.Scope)
+          break;
+        for (unsigned I = 0, E = LDSDMAStores.size(); I != E && !Slot; ++I) {
+          for (const auto *MemOp : LDSDMAStores[I]->memoperands()) {
+            if (MemOp->isStore() && AAI == MemOp->getAAInfo()) {
+              Slot = I + 1;
+              break;
+            }
+          }
+        }
+        if (Slot || LDSDMAStores.size() == NUM_EXTRA_VGPRS - 1)
+          break;
+        LDSDMAStores.push_back(&Inst);
+        Slot = LDSDMAStores.size();
+        break;
+      }
+      setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS + Slot, T, CurrScore);
+      if (Slot)
+        setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
    }
  }
}
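To see how the two setRegScore calls in this hunk interact with the consumer side, here is a deliberately simplified toy model, assumed for illustration only: recordLDSDMA, requiredVmcnt and the vmcnt arithmetic are invented and are not the pass's real bookkeeping. Every LDS DMA bumps its dedicated slot and the shared slot 0; a reader's required vmcnt is then derived from the slot it waits on.

#include <array>
#include <cstdio>

// Toy scoreboard for the LDS DMA slots only: entry 0 is the shared slot,
// entries 1..8 are per-location slots. Each entry holds the score of the
// last DMA that wrote it.
static std::array<int, 9> LDSScore = {0};
static int CurrScore = 0;

// Producer side (cf. updateByEvent above): a DMA store bumps its dedicated
// slot and always bumps the shared slot 0 as well.
void recordLDSDMA(unsigned Slot) {
  ++CurrScore;
  LDSScore[Slot] = CurrScore;
  LDSScore[0] = CurrScore;
}

// Consumer side: the returned value plays the role of the vmcnt() operand,
// i.e. how many later DMAs may still be outstanding when the read executes.
int requiredVmcnt(unsigned Slot) { return CurrScore - LDSScore[Slot]; }

int main() {
  recordLDSDMA(1); // DMA into the first tracked LDS object
  recordLDSDMA(2); // DMA into the second tracked LDS object
  std::printf("read first object:  vmcnt(%d)\n", requiredVmcnt(1)); // vmcnt(1)
  std::printf("read second object: vmcnt(%d)\n", requiredVmcnt(2)); // vmcnt(0)
  std::printf("read unknown LDS:   vmcnt(%d)\n", requiredVmcnt(0)); // vmcnt(0)
  return 0;
}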
@@ -1186,9 +1235,27 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
        // No need to wait before load from VMEM to LDS.
        if (TII->mayWriteLDSThroughDMA(MI))
          continue;
-        unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
+
        // VM_CNT is only relevant to vgpr or LDS.
-        ScoreBrackets.determineWait(VM_CNT, RegNo, Wait);
+        unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
+        bool FoundAliasingStore = false;
+        // Only objects with alias scope info were added to the LDSDMAStores
+        // array. In the absence of scope info we will not be able to
+        // disambiguate LDS aliasing here, so there is no need to search for a
+        // corresponding store slot. This is conservatively correct because in
+        // that case we will produce a wait using the first (general) LDS DMA
+        // wait slot, which waits on all of them anyway.
+        if (Ptr && Memop->getAAInfo() && Memop->getAAInfo().Scope) {
+          const auto &LDSDMAStores = ScoreBrackets.getLDSDMAStores();
+          for (unsigned I = 0, E = LDSDMAStores.size(); I != E; ++I) {
+            if (MI.mayAlias(AA, *LDSDMAStores[I], true)) {
+              FoundAliasingStore = true;
+              ScoreBrackets.determineWait(VM_CNT, RegNo + I + 1, Wait);
+            }
+          }
+        }
+        if (!FoundAliasingStore)
+          ScoreBrackets.determineWait(VM_CNT, RegNo, Wait);
        if (Memop->isStore()) {
          ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
        }
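The fall-back at the end of this hunk (wait on the shared slot when there is no scope info, or when no tracked store aliases the load) is conservatively correct because every LDS DMA also updates the shared slot, so its score always dominates the dedicated slots. A minimal check of that invariant, in the same toy style as above and again an illustrative assumption rather than the pass's data structures:

#include <array>
#include <cassert>

int main() {
  std::array<int, 9> Score = {0};
  int Curr = 0;
  auto Record = [&](unsigned Slot) {
    ++Curr;
    Score[Slot] = Curr; // dedicated per-location slot
    Score[0] = Curr;    // shared slot is always updated as well
  };
  Record(1);
  Record(3);
  Record(1);
  // The shared slot's score dominates every dedicated slot, so a wait derived
  // from it is at least as strong as any per-slot wait.
  for (unsigned S = 1; S < Score.size(); ++S)
    assert(Score[0] >= Score[S]);
  return 0;
}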
@@ -1825,6 +1892,8 @@ bool SIInsertWaitcnts::runOnMachineFunction(MachineFunction &MF) {
   const SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
   MLI = &getAnalysis<MachineLoopInfo>();
   PDT = &getAnalysis<MachinePostDominatorTree>();
+  if (auto AAR = getAnalysisIfAvailable<AAResultsWrapperPass>())
+    AA = &AAR->getAAResults();

   ForceEmitZeroWaitcnts = ForceEmitZeroFlag;
   for (auto T : inst_counter_types())

llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll

Lines changed: 97 additions & 16 deletions
@@ -3,21 +3,23 @@

 @lds.0 = internal addrspace(3) global [64 x float] poison, align 16
 @lds.1 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.2 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.3 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.4 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.5 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.6 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.7 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.8 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.9 = internal addrspace(3) global [64 x float] poison, align 16

 declare void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) nocapture, i32 %size, i32 %voffset, i32 %soffset, i32 %offset, i32 %aux)
 declare void @llvm.amdgcn.global.load.lds(ptr addrspace(1) nocapture %gptr, ptr addrspace(3) nocapture %lptr, i32 %size, i32 %offset, i32 %aux)

-; FIXME: vmcnt(0) is too strong, it shall use vmcnt(2) before the first
-; ds_read_b32 and vmcnt(0) before the second.
-
 ; GCN-LABEL: {{^}}buffer_load_lds_dword_2_arrays:
 ; GCN-COUNT-4: buffer_load_dword
-; GCN: s_waitcnt vmcnt(0)
+; GCN: s_waitcnt vmcnt(2)
 ; GCN: ds_read_b32
-
-; FIXME:
-; GCN-NOT: s_waitcnt
-
+; GCN: s_waitcnt vmcnt(0)
 ; GCN: ds_read_b32
 define amdgpu_kernel void @buffer_load_lds_dword_2_arrays(<4 x i32> %rsrc, i32 %i1, i32 %i2, ptr addrspace(1) %out) {
 main_body:

@@ -43,15 +45,9 @@ main_body:
 ; GCN-COUNT-4: global_load_dword
 ; GFX9: s_waitcnt vmcnt(0)
 ; GFX9-COUNT-2: ds_read_b32
-
-; FIXME: can be vmcnt(2)
-
-; GFX10: s_waitcnt vmcnt(0)
+; GFX10: s_waitcnt vmcnt(2)
 ; GFX10: ds_read_b32
-
-; FIXME:
-; GFX10-NOT: s_waitcnt
-
+; GFX10: s_waitcnt vmcnt(0)
 ; GFX10: ds_read_b32
 define amdgpu_kernel void @global_load_lds_dword_2_arrays(ptr addrspace(1) nocapture %gptr, i32 %i1, i32 %i2, ptr addrspace(1) %out) {
 main_body:

@@ -70,4 +66,89 @@ main_body:
   ret void
 }

+; There are 8 pseudo registers defined to track LDS DMA dependencies.
+; When exhausted we default to vmcnt(0).
+
+; GCN-LABEL: {{^}}buffer_load_lds_dword_10_arrays:
+; GCN-COUNT-10: buffer_load_dword
+; GCN: s_waitcnt vmcnt(8)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(7)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(6)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(5)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(4)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(3)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(2)
+; GCN-NOT: s_waitcnt vmcnt
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(0)
+; GCN: ds_read_b32
+define amdgpu_kernel void @buffer_load_lds_dword_10_arrays(<4 x i32> %rsrc, i32 %i1, i32 %i2, i32 %i3, i32 %i4, i32 %i5, i32 %i6, i32 %i7, i32 %i8, i32 %i9, ptr addrspace(1) %out) {
+main_body:
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.0, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.1, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.2, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.3, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.4, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.5, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.6, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.7, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.8, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.9, i32 4, i32 0, i32 0, i32 0, i32 0)
+  %gep.0 = getelementptr float, ptr addrspace(3) @lds.0, i32 %i1
+  %gep.1 = getelementptr float, ptr addrspace(3) @lds.1, i32 %i2
+  %gep.2 = getelementptr float, ptr addrspace(3) @lds.2, i32 %i2
+  %gep.3 = getelementptr float, ptr addrspace(3) @lds.3, i32 %i2
+  %gep.4 = getelementptr float, ptr addrspace(3) @lds.4, i32 %i2
+  %gep.5 = getelementptr float, ptr addrspace(3) @lds.5, i32 %i2
+  %gep.6 = getelementptr float, ptr addrspace(3) @lds.6, i32 %i2
+  %gep.7 = getelementptr float, ptr addrspace(3) @lds.7, i32 %i2
+  %gep.8 = getelementptr float, ptr addrspace(3) @lds.8, i32 %i2
+  %gep.9 = getelementptr float, ptr addrspace(3) @lds.9, i32 %i2
+  %val.0 = load float, ptr addrspace(3) %gep.0, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.1 = load float, ptr addrspace(3) %gep.1, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.2 = load float, ptr addrspace(3) %gep.2, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.3 = load float, ptr addrspace(3) %gep.3, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.4 = load float, ptr addrspace(3) %gep.4, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.5 = load float, ptr addrspace(3) %gep.5, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.6 = load float, ptr addrspace(3) %gep.6, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.7 = load float, ptr addrspace(3) %gep.7, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.8 = load float, ptr addrspace(3) %gep.8, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.9 = load float, ptr addrspace(3) %gep.9, align 4
+  %out.gep.1 = getelementptr float, ptr addrspace(1) %out, i32 1
+  %out.gep.2 = getelementptr float, ptr addrspace(1) %out, i32 2
+  %out.gep.3 = getelementptr float, ptr addrspace(1) %out, i32 3
+  %out.gep.4 = getelementptr float, ptr addrspace(1) %out, i32 4
+  %out.gep.5 = getelementptr float, ptr addrspace(1) %out, i32 5
+  %out.gep.6 = getelementptr float, ptr addrspace(1) %out, i32 6
+  %out.gep.7 = getelementptr float, ptr addrspace(1) %out, i32 7
+  %out.gep.8 = getelementptr float, ptr addrspace(1) %out, i32 8
+  %out.gep.9 = getelementptr float, ptr addrspace(1) %out, i32 9
+  store float %val.0, ptr addrspace(1) %out
+  store float %val.1, ptr addrspace(1) %out.gep.1
+  store float %val.2, ptr addrspace(1) %out.gep.2
+  store float %val.3, ptr addrspace(1) %out.gep.3
+  store float %val.4, ptr addrspace(1) %out.gep.4
+  store float %val.5, ptr addrspace(1) %out.gep.5
+  store float %val.6, ptr addrspace(1) %out.gep.6
+  store float %val.7, ptr addrspace(1) %out.gep.7
+  store float %val.8, ptr addrspace(1) %out.gep.8
+  store float %val.9, ptr addrspace(1) %out.gep.9
+  ret void
+}
+
 declare void @llvm.amdgcn.wave.barrier()
