Skip to content

Commit cc8c6b0

Browse files
authored
[OpenMP] [amdgpu] Added a synchronous version of data exchange. (#87032)
Similar to H2D and D2H, use synchronous mode for large data transfers beyond a certain size for D2D as well. As with H2D and D2H, this size is controlled by an env-var.
1 parent 3a106e5 commit cc8c6b0

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2402,6 +2402,27 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
24022402
AsyncInfoWrapperTy &AsyncInfoWrapper) override {
24032403
AMDGPUDeviceTy &DstDevice = static_cast<AMDGPUDeviceTy &>(DstGenericDevice);
24042404

2405+
// For large transfers use synchronous behavior.
2406+
if (Size >= OMPX_MaxAsyncCopyBytes) {
2407+
if (AsyncInfoWrapper.hasQueue())
2408+
if (auto Err = synchronize(AsyncInfoWrapper))
2409+
return Err;
2410+
2411+
AMDGPUSignalTy Signal;
2412+
if (auto Err = Signal.init())
2413+
return Err;
2414+
2415+
if (auto Err = utils::asyncMemCopy(
2416+
useMultipleSdmaEngines(), DstPtr, DstDevice.getAgent(), SrcPtr,
2417+
getAgent(), (uint64_t)Size, 0, nullptr, Signal.get()))
2418+
return Err;
2419+
2420+
if (auto Err = Signal.wait(getStreamBusyWaitMicroseconds()))
2421+
return Err;
2422+
2423+
return Signal.deinit();
2424+
}
2425+
24052426
AMDGPUStreamTy *Stream = nullptr;
24062427
if (auto Err = getStream(AsyncInfoWrapper, Stream))
24072428
return Err;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: %libomptarget-compile-generic && \
2+
// RUN: env LIBOMPTARGET_AMDGPU_MAX_ASYNC_COPY_BYTES=0 %libomptarget-run-generic | \
3+
// RUN: %fcheck-generic -allow-empty
4+
// REQUIRES: amdgcn-amd-amdhsa
5+
6+
#include <assert.h>
7+
#include <omp.h>
8+
#include <stdio.h>
9+
#include <stdlib.h>
10+
11+
const int magic_num = 7;
12+
13+
int main(int argc, char *argv[]) {
14+
const int N = 128;
15+
const int num_devices = omp_get_num_devices();
16+
17+
// No target device, just return
18+
if (num_devices == 0) {
19+
printf("PASS\n");
20+
return 0;
21+
}
22+
23+
const int src_device = 0;
24+
int dst_device = num_devices - 1;
25+
26+
int length = N * sizeof(int);
27+
int *src_ptr = omp_target_alloc(length, src_device);
28+
int *dst_ptr = omp_target_alloc(length, dst_device);
29+
30+
if (!src_ptr || !dst_ptr) {
31+
printf("FAIL\n");
32+
return 1;
33+
}
34+
35+
#pragma omp target teams distribute parallel for device(src_device) \
36+
is_device_ptr(src_ptr)
37+
for (int i = 0; i < N; ++i) {
38+
src_ptr[i] = magic_num;
39+
}
40+
41+
if (omp_target_memcpy(dst_ptr, src_ptr, length, 0, 0, dst_device,
42+
src_device)) {
43+
printf("FAIL\n");
44+
return 1;
45+
}
46+
47+
int *buffer = malloc(length);
48+
if (!buffer) {
49+
printf("FAIL\n");
50+
return 1;
51+
}
52+
53+
#pragma omp target teams distribute parallel for device(dst_device) \
54+
map(from : buffer[0 : N]) is_device_ptr(dst_ptr)
55+
for (int i = 0; i < N; ++i) {
56+
buffer[i] = dst_ptr[i] + magic_num;
57+
}
58+
59+
for (int i = 0; i < N; ++i)
60+
assert(buffer[i] == 2 * magic_num);
61+
62+
printf("PASS\n");
63+
64+
// Free host and device memory
65+
free(buffer);
66+
omp_target_free(src_ptr, src_device);
67+
omp_target_free(dst_ptr, dst_device);
68+
69+
return 0;
70+
}
71+
72+
// CHECK: PASS

0 commit comments

Comments
 (0)