Skip to content

Commit 8427bd2

Browse files
authored
[SYCL][HIP][CUDA] Use new version of piMemGetNativeHandle and add test (#12297)
We want to change the signature of `piMemGetNativeHandle` for reasons explained here oneapi-src/unified-runtime#1199 Corresponding UR PR: oneapi-src/unified-runtime#1226 A previous PR added a new entry point #12199 but it was decided that it is better to modify the existing entry point
1 parent 4fdcb58 commit 8427bd2

File tree

14 files changed

+188
-31
lines changed

14 files changed

+188
-31
lines changed

sycl/include/sycl/detail/pi.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,11 @@
149149
// 14.40 Add HIP _pi_mem_advice alises to match the PI_MEM_ADVICE_CUDA* ones.
150150
// 14.41 Added piextCommandBufferMemBufferFill & piextCommandBufferFillUSM
151151
// 14.42 Added piextCommandBufferPrefetchUSM and piextCommandBufferAdviseUSM
152+
// 15.43 Changed the signature of piextMemGetNativeHandle to also take a
153+
// pi_device
152154

153-
#define _PI_H_VERSION_MAJOR 14
154-
#define _PI_H_VERSION_MINOR 42
155+
#define _PI_H_VERSION_MAJOR 15
156+
#define _PI_H_VERSION_MINOR 43
155157

156158
#define _PI_STRING_HELPER(a) #a
157159
#define _PI_CONCAT(a, b) _PI_STRING_HELPER(a.b)
@@ -1424,8 +1426,9 @@ __SYCL_EXPORT pi_result piMemBufferPartition(
14241426
/// Gets the native handle of a PI mem object.
14251427
///
14261428
/// \param mem is the PI mem to get the native handle of.
1429+
/// \param dev is the PI device that the native allocation will be resident on
14271430
/// \param nativeHandle is the native handle of mem.
1428-
__SYCL_EXPORT pi_result piextMemGetNativeHandle(pi_mem mem,
1431+
__SYCL_EXPORT pi_result piextMemGetNativeHandle(pi_mem mem, pi_device dev,
14291432
pi_native_handle *nativeHandle);
14301433

14311434
/// Creates PI mem object from a native handle.

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,9 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
233233
HostPtr, RetImage);
234234
}
235235

236-
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
237-
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
236+
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_device Dev,
237+
pi_native_handle *NativeHandle) {
238+
return pi2ur::piextMemGetNativeHandle(Mem, Dev, NativeHandle);
238239
}
239240

240241
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,

sycl/plugins/hip/pi_hip.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,9 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
236236
HostPtr, RetImage);
237237
}
238238

239-
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
240-
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
239+
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_device Dev,
240+
pi_native_handle *NativeHandle) {
241+
return pi2ur::piextMemGetNativeHandle(Mem, Dev, NativeHandle);
241242
}
242243

243244
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,9 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
243243
HostPtr, RetImage);
244244
}
245245

246-
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
247-
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
246+
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_device Dev,
247+
pi_native_handle *NativeHandle) {
248+
return pi2ur::piextMemGetNativeHandle(Mem, Dev, NativeHandle);
248249
}
249250

250251
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,

sycl/plugins/native_cpu/pi_native_cpu.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,9 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
240240
HostPtr, RetImage);
241241
}
242242

243-
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
244-
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
243+
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_device Dev,
244+
pi_native_handle *NativeHandle) {
245+
return pi2ur::piextMemGetNativeHandle(Mem, Dev, NativeHandle);
245246
}
246247

247248
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,9 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
222222
HostPtr, RetImage);
223223
}
224224

225-
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
226-
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
225+
pi_result piextMemGetNativeHandle(pi_mem Mem, pi_device Dev,
226+
pi_native_handle *NativeHandle) {
227+
return pi2ur::piextMemGetNativeHandle(Mem, Dev, NativeHandle);
227228
}
228229

229230
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,

sycl/plugins/unified_runtime/CMakeLists.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ endif()
5656
if(SYCL_PI_UR_USE_FETCH_CONTENT)
5757
include(FetchContent)
5858

59-
set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime")
60-
# commit 3225b822b5d8cbfa85d7fc1bd5a5bf96e5bb8c1a
61-
# Merge: edb281f3 5fc41099
59+
set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git")
60+
# commit d216eb44d5c9fe3433eecdd09b10e3e79ac25bd7
61+
# Merge: 40517d2b fc1f3066
6262
# Author: Kenneth Benzie (Benie) <[email protected]>
63-
# Date: Tue Jan 30 12:31:44 2024 +0000
64-
# Merge pull request #1168 from Seanst98/sean/unique-addr-mode-per-dim-adapters
65-
# [Bindless][CUDA] Unique addressing modes per dimension
66-
set(UNIFIED_RUNTIME_TAG 3225b822b5d8cbfa85d7fc1bd5a5bf96e5bb8c1a)
63+
# Date: Wed Jan 31 10:38:07 2024 +0000
64+
# Merge pull request #1226 from hdelan/get-native-mem-on-device2
65+
# [UR] Add extra param to urMemGetNativeHandle
66+
set(UNIFIED_RUNTIME_TAG d216eb44d5c9fe3433eecdd09b10e3e79ac25bd7)
6767

6868
if(SYCL_PI_UR_OVERRIDE_FETCH_CONTENT_REPO)
6969
set(UNIFIED_RUNTIME_REPO "${SYCL_PI_UR_OVERRIDE_FETCH_CONTENT_REPO}")

sycl/plugins/unified_runtime/pi2ur.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3081,13 +3081,14 @@ inline pi_result piMemBufferPartition(pi_mem Buffer, pi_mem_flags Flags,
30813081
return PI_SUCCESS;
30823082
}
30833083

3084-
inline pi_result piextMemGetNativeHandle(pi_mem Mem,
3084+
inline pi_result piextMemGetNativeHandle(pi_mem Mem, pi_device Dev,
30853085
pi_native_handle *NativeHandle) {
30863086
PI_ASSERT(Mem, PI_ERROR_INVALID_MEM_OBJECT);
30873087

30883088
ur_mem_handle_t UrMem = reinterpret_cast<ur_mem_handle_t>(Mem);
3089+
ur_device_handle_t UrDev = reinterpret_cast<ur_device_handle_t>(Dev);
30893090
ur_native_handle_t NativeMem{};
3090-
HANDLE_ERRORS(urMemGetNativeHandle(UrMem, &NativeMem));
3091+
HANDLE_ERRORS(urMemGetNativeHandle(UrMem, UrDev, &NativeMem));
30913092

30923093
*NativeHandle = reinterpret_cast<pi_native_handle>(NativeMem);
30933094

sycl/plugins/unified_runtime/pi_unified_runtime.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,9 @@ __SYCL_EXPORT pi_result piMemBufferPartition(
235235
BufferCreateInfo, RetMem);
236236
}
237237

238-
__SYCL_EXPORT pi_result
239-
piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
240-
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
238+
__SYCL_EXPORT pi_result piextMemGetNativeHandle(
239+
pi_mem Mem, pi_device Dev, pi_native_handle *NativeHandle) {
240+
return pi2ur::piextMemGetNativeHandle(Mem, Dev, NativeHandle);
241241
}
242242

243243
__SYCL_EXPORT pi_result

sycl/source/detail/buffer_impl.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,11 @@ buffer_impl::getNativeVector(backend BackendName) const {
8484
}
8585

8686
pi_native_handle Handle;
87-
Plugin->call<PiApiKind::piextMemGetNativeHandle>(NativeMem, &Handle);
87+
// When doing buffer interop we don't know what device the memory should be
88+
// resident on, so pass nullptr for Device param. Buffer interop may not be
89+
// supported by all backends.
90+
Plugin->call<PiApiKind::piextMemGetNativeHandle>(NativeMem, /*Dev*/ nullptr,
91+
&Handle);
8892
Handles.push_back(Handle);
8993
}
9094

sycl/source/detail/memory_manager.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ void memBufferCreateHelper(const PluginPtr &Plugin, pi_context Ctx,
144144
// Always use call_nocheck here, because call may throw an exception,
145145
// and this lambda will be called from destructor, which in combination
146146
// rewards us with UB.
147-
Plugin->call_nocheck<PiApiKind::piextMemGetNativeHandle>(*RetMem, &Ptr);
147+
// When doing buffer interop we don't know what device the memory should
148+
// be resident on, so pass nullptr for Device param. Buffer interop may
149+
// not be supported by all backends.
150+
Plugin->call_nocheck<PiApiKind::piextMemGetNativeHandle>(
151+
*RetMem, /*Dev*/ nullptr, &Ptr);
148152
emitMemAllocEndTrace(MemObjID, (uintptr_t)(Ptr), Size, 0 /* guard zone */,
149153
CorrID);
150154
}};
@@ -167,7 +171,11 @@ void memReleaseHelper(const PluginPtr &Plugin, pi_mem Mem) {
167171
// Do not make unnecessary PI calls without instrumentation enabled
168172
if (xptiTraceEnabled()) {
169173
pi_native_handle PtrHandle = 0;
170-
Plugin->call<PiApiKind::piextMemGetNativeHandle>(Mem, &PtrHandle);
174+
// When doing buffer interop we don't know what device the memory should be
175+
// resident on, so pass nullptr for Device param. Buffer interop may not be
176+
// supported by all backends.
177+
Plugin->call<PiApiKind::piextMemGetNativeHandle>(Mem, /*Dev*/ nullptr,
178+
&PtrHandle);
171179
Ptr = (uintptr_t)(PtrHandle);
172180
}
173181
#endif

sycl/source/interop_handle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ pi_native_handle interop_handle::getNativeMem(detail::Requirement *Req) const {
3434

3535
auto Plugin = MQueue->getPlugin();
3636
pi_native_handle Handle;
37-
Plugin->call<detail::PiApiKind::piextMemGetNativeHandle>(Iter->second,
38-
&Handle);
37+
Plugin->call<detail::PiApiKind::piextMemGetNativeHandle>(
38+
Iter->second, MDevice->getHandleRef(), &Handle);
3939
return Handle;
4040
}
4141

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// FIXME: the rocm include path and link path are highly platform dependent,
2+
// we should set this with some variable instead.
3+
// RUN: %{build} -o %t.out -I/opt/rocm/include -L/opt/rocm/lib -lamdhip64
4+
// RUN: %{run} %t.out
5+
// REQUIRES: hip
6+
7+
#include <iostream>
8+
#include <sycl/sycl.hpp>
9+
10+
#define __HIP_PLATFORM_AMD__
11+
12+
#include <hip/hip_runtime.h>
13+
14+
using namespace sycl;
15+
using namespace sycl::access;
16+
17+
static constexpr size_t BUFFER_SIZE = 1024;
18+
19+
template <typename T> class Modifier;
20+
21+
template <typename T> class Init;
22+
23+
template <typename BufferT, typename ValueT>
24+
void checkBufferValues(BufferT Buffer, ValueT Value) {
25+
auto Acc = Buffer.get_host_access();
26+
for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
27+
if (Acc[Idx] != Value) {
28+
std::cerr << "buffer[" << Idx << "] = " << Acc[Idx]
29+
<< ", expected val = " << Value << '\n';
30+
exit(1);
31+
}
32+
}
33+
}
34+
35+
template <typename DataT>
36+
void copy(buffer<DataT, 1> &Src, buffer<DataT, 1> &Dst, queue &Q) {
37+
Q.submit([&](handler &CGH) {
38+
auto SrcA = Src.template get_access<mode::read>(CGH);
39+
auto DstA = Dst.template get_access<mode::write>(CGH);
40+
41+
auto Func = [=](interop_handle IH) {
42+
auto HipStream = IH.get_native_queue<backend::ext_oneapi_hip>();
43+
auto SrcMem = IH.get_native_mem<backend::ext_oneapi_hip>(SrcA);
44+
auto DstMem = IH.get_native_mem<backend::ext_oneapi_hip>(DstA);
45+
46+
if (hipMemcpyWithStream(DstMem, SrcMem, sizeof(DataT) * SrcA.get_count(),
47+
hipMemcpyDefault, HipStream) != hipSuccess) {
48+
throw;
49+
}
50+
51+
if (hipStreamSynchronize(HipStream) != hipSuccess) {
52+
throw;
53+
}
54+
55+
if (Q.get_backend() != IH.get_backend())
56+
throw;
57+
};
58+
CGH.host_task(Func);
59+
});
60+
}
61+
62+
template <typename DataT> void modify(buffer<DataT, 1> &B, queue &Q) {
63+
Q.submit([&](handler &CGH) {
64+
auto Acc = B.template get_access<mode::read_write>(CGH);
65+
66+
auto Kernel = [=](item<1> Id) { Acc[Id] += 1; };
67+
68+
CGH.parallel_for<Modifier<DataT>>(Acc.get_count(), Kernel);
69+
});
70+
}
71+
72+
template <typename DataT, DataT B1Init, DataT B2Init>
73+
void init(buffer<DataT, 1> &B1, buffer<DataT, 1> &B2, queue &Q) {
74+
Q.submit([&](handler &CGH) {
75+
auto Acc1 = B1.template get_access<mode::write>(CGH);
76+
auto Acc2 = B2.template get_access<mode::write>(CGH);
77+
78+
CGH.parallel_for<Init<DataT>>(BUFFER_SIZE, [=](item<1> Id) {
79+
Acc1[Id] = B1Init;
80+
Acc2[Id] = B2Init;
81+
});
82+
});
83+
}
84+
85+
// Check that a single host-interop-task with a buffer will work.
86+
void test_ht_buffer(queue &Q) {
87+
buffer<int, 1> Buffer{BUFFER_SIZE};
88+
89+
Q.submit([&](handler &CGH) {
90+
auto Acc = Buffer.get_access<mode::write>(CGH);
91+
auto Func = [=](interop_handle IH) { /*A no-op */ };
92+
CGH.host_task(Func);
93+
});
94+
}
95+
96+
// A test that uses HIP interop to copy data from buffer A to buffer B, by
97+
// getting HIP ptrs and calling the hipMemcpyWithStream. Then run a SYCL
98+
// kernel that modifies the data in place for B, e.g. increment one, then copy
99+
// back to buffer A. Run it on a loop, to ensure the dependencies and the
100+
// reference counting of the objects is not leaked.
101+
void test_ht_kernel_dependencies(queue &Q) {
102+
static constexpr int COUNT = 4;
103+
buffer<int, 1> Buffer1{BUFFER_SIZE};
104+
buffer<int, 1> Buffer2{BUFFER_SIZE};
105+
106+
// Init the buffer with a'priori invalid data.
107+
init<int, -1, -2>(Buffer1, Buffer2, Q);
108+
109+
// Repeat a couple of times.
110+
for (size_t Idx = 0; Idx < COUNT; ++Idx) {
111+
copy(Buffer1, Buffer2, Q);
112+
modify(Buffer2, Q);
113+
copy(Buffer2, Buffer1, Q);
114+
}
115+
116+
checkBufferValues(Buffer1, COUNT - 1);
117+
checkBufferValues(Buffer2, COUNT - 1);
118+
}
119+
120+
void tests(queue &Q) {
121+
test_ht_buffer(Q);
122+
test_ht_kernel_dependencies(Q);
123+
}
124+
125+
int main() {
126+
queue Q([](sycl::exception_list ExceptionList) {
127+
if (ExceptionList.size() != 1) {
128+
std::cerr << "Should be one exception in exception list" << std::endl;
129+
std::abort();
130+
}
131+
std::rethrow_exception(*ExceptionList.begin());
132+
});
133+
tests(Q);
134+
std::cout << "Test PASSED" << std::endl;
135+
return 0;
136+
}

sycl/unittests/helpers/PiMockPlugin.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ mock_piMemBufferPartition(pi_mem buffer, pi_mem_flags flags,
613613
return PI_SUCCESS;
614614
}
615615

616-
inline pi_result mock_piextMemGetNativeHandle(pi_mem mem,
616+
inline pi_result mock_piextMemGetNativeHandle(pi_mem mem, pi_device dev,
617617
pi_native_handle *nativeHandle) {
618618
*nativeHandle = reinterpret_cast<pi_native_handle>(mem);
619619
return PI_SUCCESS;

0 commit comments

Comments
 (0)