Skip to content

Commit c03fe73

Browse files
committed
[OPENMP][NVPTX]Correctly handle L2 parallelism in SPMD mode.
Summary: The parallelLevel counter must be maintained on a per-thread basis to fully support L2+ parallelism; otherwise we may end up with undefined behavior. Introduce parallelLevel on a per-warp basis using shared memory. This avoids the synchronization problems and allows full support of L2+ parallelism in SPMD mode with no runtime. Reviewers: gtbercea, grokos Subscribers: guansong, jdoerfert, caomhin, kkwli0, openmp-commits Tags: #openmp Differential Revision: https://reviews.llvm.org/D60918 llvm-svn: 359341
1 parent 5ddc6d1 commit c03fe73

File tree

8 files changed

+50
-26
lines changed

8 files changed

+50
-26
lines changed

openmp/libomptarget/deviceRTLs/nvptx/src/libcall.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ EXTERN int omp_get_level(void) {
165165
ASSERT0(LT_FUSSY, isSPMDMode(),
166166
"Expected SPMD mode only with uninitialized runtime.");
167167
// parallelLevel starts from 0, need to add 1 for correct level.
168-
return parallelLevel + 1;
168+
return parallelLevel[GetWarpId()] + 1;
169169
}
170170
int level = 0;
171171
omptarget_nvptx_TaskDescr *currTaskDescr =

openmp/libomptarget/deviceRTLs/nvptx/src/omp_data.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ __device__ omptarget_nvptx_SimpleMemoryManager
3131
__device__ __shared__ uint32_t usedMemIdx;
3232
__device__ __shared__ uint32_t usedSlotIdx;
3333

34-
__device__ __shared__ uint8_t parallelLevel;
34+
__device__ __shared__ uint8_t parallelLevel[MAX_THREADS_PER_TEAM / WARPSIZE];
3535

3636
// Pointer to this team's OpenMP state object
3737
__device__ __shared__

openmp/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,10 @@ EXTERN void __kmpc_spmd_kernel_init(int ThreadLimit, int16_t RequiresOMPRuntime,
9595
// If OMP runtime is not required don't initialize OMP state.
9696
setExecutionParameters(Spmd, RuntimeUninitialized);
9797
if (GetThreadIdInBlock() == 0) {
98-
parallelLevel = 0;
9998
usedSlotIdx = smid() % MAX_SM;
99+
parallelLevel[0] = 0;
100+
} else if (GetLaneId() == 0) {
101+
parallelLevel[GetWarpId()] = 0;
100102
}
101103
__SYNCTHREADS();
102104
return;

openmp/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,8 @@ extern __device__ omptarget_nvptx_SimpleMemoryManager
406406
omptarget_nvptx_simpleMemoryManager;
407407
extern __device__ __shared__ uint32_t usedMemIdx;
408408
extern __device__ __shared__ uint32_t usedSlotIdx;
409-
extern __device__ __shared__ uint8_t parallelLevel;
409+
extern __device__ __shared__ uint8_t
410+
parallelLevel[MAX_THREADS_PER_TEAM / WARPSIZE];
410411
extern __device__ __shared__
411412
omptarget_nvptx_ThreadPrivateContext *omptarget_nvptx_threadPrivateContext;
412413

openmp/libomptarget/deviceRTLs/nvptx/src/parallel.cu

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,12 @@ EXTERN void __kmpc_serialized_parallel(kmp_Ident *loc, uint32_t global_tid) {
339339
if (checkRuntimeUninitialized(loc)) {
340340
ASSERT0(LT_FUSSY, checkSPMDMode(loc),
341341
"Expected SPMD mode with uninitialized runtime.");
342-
__SYNCTHREADS();
343-
if (GetThreadIdInBlock() == 0)
344-
++parallelLevel;
345-
__SYNCTHREADS();
342+
unsigned tnum = __ACTIVEMASK();
343+
int leader = __ffs(tnum) - 1;
344+
__SHFL_SYNC(tnum, leader, leader);
345+
if (GetLaneId() == leader)
346+
++parallelLevel[GetWarpId()];
347+
__SHFL_SYNC(tnum, leader, leader);
346348

347349
return;
348350
}
@@ -382,10 +384,12 @@ EXTERN void __kmpc_end_serialized_parallel(kmp_Ident *loc,
382384
if (checkRuntimeUninitialized(loc)) {
383385
ASSERT0(LT_FUSSY, checkSPMDMode(loc),
384386
"Expected SPMD mode with uninitialized runtime.");
385-
__SYNCTHREADS();
386-
if (GetThreadIdInBlock() == 0)
387-
--parallelLevel;
388-
__SYNCTHREADS();
387+
unsigned tnum = __ACTIVEMASK();
388+
int leader = __ffs(tnum) - 1;
389+
__SHFL_SYNC(tnum, leader, leader);
390+
if (GetLaneId() == leader)
391+
--parallelLevel[GetWarpId()];
392+
__SHFL_SYNC(tnum, leader, leader);
389393
return;
390394
}
391395

@@ -407,7 +411,7 @@ EXTERN uint16_t __kmpc_parallel_level(kmp_Ident *loc, uint32_t global_tid) {
407411
if (checkRuntimeUninitialized(loc)) {
408412
ASSERT0(LT_FUSSY, checkSPMDMode(loc),
409413
"Expected SPMD mode with uninitialized runtime.");
410-
return parallelLevel + 1;
414+
return parallelLevel[GetWarpId()] + 1;
411415
}
412416

413417
int threadId = GetLogicalThreadIdInBlock(checkSPMDMode(loc));

openmp/libomptarget/deviceRTLs/nvptx/src/support.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ INLINE int GetThreadIdInBlock();
4040
INLINE int GetBlockIdInKernel();
4141
INLINE int GetNumberOfBlocksInKernel();
4242
INLINE int GetNumberOfThreadsInBlock();
43+
INLINE unsigned GetWarpId();
44+
INLINE unsigned GetLaneId();
4345

4446
// get global ids to locate thread/team info (constant regardless of OMP)
4547
INLINE int GetLogicalThreadIdInBlock(bool isSPMDExecutionMode);

openmp/libomptarget/deviceRTLs/nvptx/src/supporti.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ INLINE int GetNumberOfBlocksInKernel() { return gridDim.x; }
102102

103103
// Returns the x-dimension size of the calling thread's block (blockDim.x),
// converted to int. Only the x dimension is consulted.
INLINE int GetNumberOfThreadsInBlock() {
  const int threadsInBlock = blockDim.x;
  return threadsInBlock;
}
104104

105+
// Returns the index of the calling thread's warp within its block, computed
// as threadIdx.x divided by WARPSIZE. Only the x dimension is consulted.
INLINE unsigned GetWarpId() {
  const unsigned threadInBlock = threadIdx.x;
  return threadInBlock / WARPSIZE;
}
106+
107+
// Returns the calling thread's lane within its warp by masking threadIdx.x
// with (WARPSIZE - 1). This equals threadIdx.x % WARPSIZE only when WARPSIZE
// is a power of two -- NOTE(review): confirm against the WARPSIZE definition.
INLINE unsigned GetLaneId() {
  const unsigned laneMask = WARPSIZE - 1;
  return threadIdx.x & laneMask;
}
108+
105109
////////////////////////////////////////////////////////////////////////////////
106110
//
107111
// Calls to the Generic Scheme Implementation Layer (assuming 1D layout)
@@ -154,7 +158,7 @@ INLINE int GetOmpThreadId(int threadId, bool isSPMDExecutionMode,
154158
ASSERT0(LT_FUSSY, isSPMDExecutionMode,
155159
"Uninitialized runtime with non-SPMD mode.");
156160
// For level 2 parallelism all parallel regions are executed sequentially.
157-
if (parallelLevel > 0)
161+
if (parallelLevel[GetWarpId()] > 0)
158162
rc = 0;
159163
else
160164
rc = GetThreadIdInBlock();
@@ -175,7 +179,7 @@ INLINE int GetNumberOfOmpThreads(int threadId, bool isSPMDExecutionMode,
175179
ASSERT0(LT_FUSSY, isSPMDExecutionMode,
176180
"Uninitialized runtime with non-SPMD mode.");
177181
// For level 2 parallelism all parallel regions are executed sequentially.
178-
if (parallelLevel > 0)
182+
if (parallelLevel[GetWarpId()] > 0)
179183
rc = 1;
180184
else
181185
rc = GetNumberOfThreadsInBlock();

openmp/libomptarget/deviceRTLs/nvptx/test/parallel/spmd_parallel_regions.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,31 @@
66
int main(void) {
77
int isHost = -1;
88
int ParallelLevel1 = -1, ParallelLevel2 = -1;
9+
int Count = 0;
910

1011
#pragma omp target parallel for map(tofrom \
11-
: isHost, ParallelLevel1, ParallelLevel2)
12+
: isHost, ParallelLevel1, ParallelLevel2), reduction(+: Count) schedule(static, 1)
1213
for (int J = 0; J < 10; ++J) {
1314
#pragma omp critical
1415
{
15-
isHost = (isHost < 0 || isHost == omp_is_initial_device())
16-
? omp_is_initial_device()
17-
: 1;
18-
ParallelLevel1 =
19-
(ParallelLevel1 < 0 || ParallelLevel1 == 1) ? omp_get_level() : 2;
16+
isHost = (isHost < 0 || isHost == 0) ? omp_is_initial_device() : isHost;
17+
ParallelLevel1 = (ParallelLevel1 < 0 || ParallelLevel1 == 1)
18+
? omp_get_level()
19+
: ParallelLevel1;
2020
}
21-
int L2;
22-
#pragma omp parallel for schedule(dynamic) lastprivate(L2)
23-
for (int I = 0; I < 10; ++I)
24-
L2 = omp_get_level();
21+
if (omp_get_thread_num() > 5) {
22+
int L2;
23+
#pragma omp parallel for schedule(dynamic) lastprivate(L2) reduction(+: Count)
24+
for (int I = 0; I < 10; ++I) {
25+
L2 = omp_get_level();
26+
Count += omp_get_level(); // (10-6)*10*2 = 80
27+
}
2528
#pragma omp critical
26-
ParallelLevel2 = (ParallelLevel2 < 0 || ParallelLevel2 == 2) ? L2 : 1;
29+
ParallelLevel2 =
30+
(ParallelLevel2 < 0 || ParallelLevel2 == 2) ? L2 : ParallelLevel2;
31+
} else {
32+
Count += omp_get_level(); // 6 * 1 = 6
33+
}
2734
}
2835

2936
if (isHost < 0) {
@@ -35,6 +42,10 @@ int main(void) {
3542
// CHECK: Parallel level in SPMD mode: L1 is 1, L2 is 2
3643
printf("Parallel level in SPMD mode: L1 is %d, L2 is %d\n", ParallelLevel1,
3744
ParallelLevel2);
45+
// Final result of Count is (10-6)(num of loops)*10(num of iterations)*2(par
46+
// level) + 6(num of iterations) * 1(par level)
47+
// CHECK: Expected count = 86
48+
printf("Expected count = %d\n", Count);
3849

3950
return isHost;
4051
}

0 commit comments

Comments
 (0)