Skip to content

Commit d571af7

Browse files
committed
[OpenMP][FIX] Ensure thread states do not crash on the GPU
Nested parallelism creates thread states, which still do not work properly but at least no longer crash.
1 parent ca01f2a commit d571af7

File tree

4 files changed

+48
-3
lines changed

4 files changed

+48
-3
lines changed

openmp/libomptarget/DeviceRTL/include/LibC.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
// C-linkage declarations of the libc-style helpers available to the device
// runtime.
extern "C" {
1818

1919
/// Compare \p count bytes of \p lhs and \p rhs.
/// NOTE(review): the definition is only partially visible from here;
/// presumably it follows the C library contract -- confirm in LibC.cpp.
int memcmp(const void *lhs, const void *rhs, size_t count);
20+
/// Store \p C (converted to char) into each of the first \p count bytes at
/// \p dst. Unlike the C library function of the same name, this returns
/// nothing.
void memset(void *dst, int C, size_t count);
2021

2122
/// printf() calls are rewritten by CGGPUBuiltin to __llvm_omp_vprintf.
int printf(const char *format, ...);
2223
}

openmp/libomptarget/DeviceRTL/src/LibC.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ int memcmp(const void *lhs, const void *rhs, size_t count) {
4747
return 0;
4848
}
4949

50+
/// Fill the first \p count bytes starting at \p dst with the value \p C
/// (converted to char). Nothing is returned.
void memset(void *dst, int C, size_t count) {
  const char Fill = static_cast<char>(C);
  char *Cursor = static_cast<char *>(dst);
  // Walk the destination one byte at a time until \p count bytes are written.
  for (size_t Remaining = count; Remaining != 0; --Remaining)
    *Cursor++ = Fill;
}
55+
5056
/// printf() calls are rewritten by CGGPUBuiltin to __llvm_omp_vprintf
5157
int32_t __llvm_omp_vprintf(const char *Format, void *Arguments, uint32_t Size) {
5258
return impl::omp_vprintf(Format, Arguments, Size);

openmp/libomptarget/DeviceRTL/src/State.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "Debug.h"
1313
#include "Environment.h"
1414
#include "Interface.h"
15+
#include "LibC.h"
1516
#include "Mapping.h"
1617
#include "Synchronization.h"
1718
#include "Types.h"
@@ -263,13 +264,14 @@ void state::enterDataEnvironment(IdentTy *Ident) {
263264
return;
264265

265266
unsigned TId = mapping::getThreadIdInBlock();
266-
ThreadStateTy *NewThreadState =
267-
static_cast<ThreadStateTy *>(__kmpc_alloc_shared(sizeof(ThreadStateTy)));
267+
ThreadStateTy *NewThreadState = static_cast<ThreadStateTy *>(
268+
memory::allocGlobal(sizeof(ThreadStateTy), "ThreadStates alloc"));
268269
uintptr_t *ThreadStatesBitsPtr = reinterpret_cast<uintptr_t *>(&ThreadStates);
269270
if (!atomic::load(ThreadStatesBitsPtr, atomic::seq_cst)) {
270271
uint32_t Bytes = sizeof(ThreadStates[0]) * mapping::getMaxTeamThreads();
271272
void *ThreadStatesPtr =
272273
memory::allocGlobal(Bytes, "Thread state array allocation");
274+
// Zero-fill the freshly allocated array so that every per-thread slot starts
// out as a null pointer. The fill value must be the integer 0; the character
// literal '0' is 0x30 and would leave garbage (non-null) pointers behind.
memset(ThreadStatesPtr, 0, Bytes);
273275
if (!atomic::cas(ThreadStatesBitsPtr, uintptr_t(0),
274276
reinterpret_cast<uintptr_t>(ThreadStatesPtr),
275277
atomic::seq_cst, atomic::seq_cst))
@@ -298,7 +300,7 @@ void state::resetStateForThread(uint32_t TId) {
298300
return;
299301

300302
ThreadStateTy *PreviousThreadState = ThreadStates[TId]->PreviousThreadState;
301-
__kmpc_free_shared(ThreadStates[TId], sizeof(ThreadStateTy));
303+
memory::freeGlobal(ThreadStates[TId], "ThreadStates dealloc");
302304
ThreadStates[TId] = PreviousThreadState;
303305
}
304306

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %libomptarget-compile-run-and-check-generic
2+
// RUN: %libomptarget-compileopt-run-and-check-generic
3+
4+
// Nested parallel regions are supported and work, but we compute bogus results on the GPU. For
5+
// now we disable the CPU targets and will re-enable them once the GPU is fixed.
6+
//
7+
// UNSUPPORTED: aarch64-unknown-linux-gnu
8+
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
9+
// UNSUPPORTED: x86_64-pc-linux-gnu
10+
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
11+
12+
#include <omp.h>
13+
#include <stdio.h>
14+
15+
// Offloads a two-team target region. In team 0, thread 17 of a 128-thread
// parallel region opens a nested 64-thread parallel region, and the last
// thread of that nested region records the nesting level, its thread id and
// the thread count. The three values are then printed on the host.
int main() {
16+
// TODO: Test all ICVs
17+
// Sentinel values: they stay unchanged if the innermost region never stores
// to them, which makes a skipped update visible in the printed output.
int lvl = 333, tid = 666, nt = 999;
18+
#pragma omp target teams map(tofrom : lvl, tid, nt) num_teams(2)
19+
{
20+
// Only the first team takes part; the other team does nothing.
if (omp_get_team_num() == 0) {
21+
#pragma omp parallel num_threads(128)
22+
// A single thread of the outer region spawns the nested region.
if (omp_get_thread_num() == 17) {
23+
#pragma omp parallel num_threads(64)
24+
// Only the highest-numbered thread of the nested region writes the results.
if (omp_get_thread_num() == omp_get_num_threads() - 1) {
25+
lvl = omp_get_level();
26+
tid = omp_get_thread_num();
27+
nt = omp_get_num_threads();
28+
}
29+
}
30+
}
31+
}
32+
// The expectation below is that the sentinels come back untouched.
// TODO: This is wrong, but at least it doesn't crash
33+
// CHECK: lvl: 333, tid: 666, nt: 999
34+
printf("lvl: %i, tid: %i, nt: %i\n", lvl, tid, nt);
35+
return 0;
36+
}

0 commit comments

Comments
 (0)