Skip to content

Commit b16aadf

Browse files
committed
[OpenMP] Introduce aligned synchronization into the new device RT
We will later use the fact that a barrier is aligned to reason about thread divergence. For now we introduce the assumption and some more documentation. Reviewed By: tianshilei1992 Differential Revision: https://reviews.llvm.org/D112153
1 parent e32b1ee commit b16aadf

File tree

5 files changed

+50
-8
lines changed

5 files changed

+50
-8
lines changed

openmp/libomptarget/DeviceRTL/include/Synchronization.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ void warp(LaneMaskTy Mask);
2727
/// Synchronize all threads in a block.
2828
void threads();
2929

30+
/// Synchronizing threads is allowed even if they all hit different instances of
31+
/// `synchronize::threads()`. However, `synchronize::threadsAligned()` is more
32+
/// restrictive in that it requires all threads to hit the same instance. The
33+
/// noinline is removed by the openmp-opt pass and helps to preserve the
34+
/// information till then.
35+
///{
36+
#pragma omp begin assumes ext_aligned_barrier
37+
38+
/// Synchronize all threads in a block, they are are reaching the same
39+
/// instruction (hence all threads in the block are "aligned").
40+
__attribute__((noinline)) void threadsAligned();
41+
42+
#pragma omp end assumes
43+
///}
44+
3045
} // namespace synchronize
3146

3247
namespace fence {

openmp/libomptarget/DeviceRTL/src/Kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ int32_t __kmpc_target_init(IdentTy *Ident, int8_t Mode,
6969
const bool IsSPMD = Mode & OMP_TGT_EXEC_MODE_SPMD;
7070
if (IsSPMD) {
7171
inititializeRuntime(/* IsSPMD */ true);
72-
synchronize::threads();
72+
synchronize::threadsAligned();
7373
} else {
7474
inititializeRuntime(/* IsSPMD */ false);
7575
// No need to wait since only the main threads will execute user

openmp/libomptarget/DeviceRTL/src/Parallelism.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,36 @@ void __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
9393

9494
uint32_t NumThreads = determineNumberOfThreads(num_threads);
9595
if (mapping::isSPMDMode()) {
96-
synchronize::threads();
96+
// Avoid the race between the read of the `icv::Level` above and the write
97+
// below by synchronizing all threads here.
98+
synchronize::threadsAligned();
9799
{
100+
// Note that the order here is important. `icv::Level` has to be updated
101+
// last or the other updates will cause a thread specific state to be
102+
// created.
98103
state::ValueRAII ParallelTeamSizeRAII(state::ParallelTeamSize, NumThreads,
99104
1u, TId == 0);
100105
state::ValueRAII ActiveLevelRAII(icv::ActiveLevel, 1u, 0u, TId == 0);
101106
state::ValueRAII LevelRAII(icv::Level, 1u, 0u, TId == 0);
102-
synchronize::threads();
107+
108+
// Synchronize all threads after the main thread (TId == 0) set up the
109+
// team state properly.
110+
synchronize::threadsAligned();
111+
112+
ASSERT(state::ParallelTeamSize == NumThreads);
113+
ASSERT(icv::ActiveLevel == 1u);
114+
ASSERT(icv::Level == 1u);
103115

104116
if (TId < NumThreads)
105117
invokeMicrotask(TId, 0, fn, args, nargs);
106-
synchronize::threads();
118+
119+
// Synchronize all threads at the end of a parallel region.
120+
synchronize::threadsAligned();
107121
}
122+
123+
ASSERT(state::ParallelTeamSize == 1u);
124+
ASSERT(icv::ActiveLevel == 0u);
125+
ASSERT(icv::Level == 0u);
108126
return;
109127
}
110128

@@ -130,6 +148,9 @@ void __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
130148
}
131149

132150
{
151+
// Note that the order here is important. `icv::Level` has to be updated
152+
// last or the other updates will cause a thread specific state to be
153+
// created.
133154
state::ValueRAII ParallelTeamSizeRAII(state::ParallelTeamSize, NumThreads,
134155
1u, true);
135156
state::ValueRAII ParallelRegionFnRAII(state::ParallelRegionFn, wrapper_fn,

openmp/libomptarget/DeviceRTL/src/State.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ namespace {
4141
///{
4242

4343
extern "C" {
44-
void *malloc(uint64_t Size);
45-
void free(void *Ptr);
44+
__attribute__((leaf)) void *malloc(uint64_t Size);
45+
__attribute__((leaf)) void free(void *Ptr);
4646
}
4747

4848
///}

openmp/libomptarget/DeviceRTL/src/Synchronization.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ void syncWarp(__kmpc_impl_lanemask_t) {
132132

133133
void syncThreads() { __builtin_amdgcn_s_barrier(); }
134134

135+
void syncThreadsAligned() { syncThreads(); }
136+
135137
void fenceTeam(int Ordering) { __builtin_amdgcn_fence(Ordering, "workgroup"); }
136138

137139
void fenceKernel(int Ordering) { __builtin_amdgcn_fence(Ordering, "agent"); }
@@ -179,6 +181,8 @@ void syncThreads() {
179181
asm volatile("barrier.sync %0;" : : "r"(BarrierNo) : "memory");
180182
}
181183

184+
void syncThreadsAligned() { __syncthreads(); }
185+
182186
constexpr uint32_t OMP_SPIN = 1000;
183187
constexpr uint32_t UNSET = 0;
184188
constexpr uint32_t SET = 1;
@@ -227,6 +231,8 @@ void synchronize::warp(LaneMaskTy Mask) { impl::syncWarp(Mask); }
227231

228232
void synchronize::threads() { impl::syncThreads(); }
229233

234+
void synchronize::threadsAligned() { impl::syncThreadsAligned(); }
235+
230236
void fence::team(int Ordering) { impl::fenceTeam(Ordering); }
231237

232238
void fence::kernel(int Ordering) { impl::fenceKernel(Ordering); }
@@ -238,7 +244,7 @@ uint32_t atomic::load(uint32_t *Addr, int Ordering) {
238244
}
239245

240246
void atomic::store(uint32_t *Addr, uint32_t V, int Ordering) {
241-
impl::atomicStore(Addr, V, Ordering);
247+
impl::atomicStore(Addr, V, Ordering);
242248
}
243249

244250
uint32_t atomic::inc(uint32_t *Addr, uint32_t V, int Ordering) {
@@ -275,7 +281,7 @@ void __kmpc_barrier(IdentTy *Loc, int32_t TId) {
275281

276282
__attribute__((noinline)) void __kmpc_barrier_simple_spmd(IdentTy *Loc,
277283
int32_t TId) {
278-
synchronize::threads();
284+
synchronize::threadsAligned();
279285
}
280286

281287
int32_t __kmpc_master(IdentTy *Loc, int32_t TId) {

0 commit comments

Comments
 (0)