Skip to content

Commit 5f7b830

Browse files
author
Hugh Delaney
committed
Use new UR entry point to get native mem with device param
1 parent 7bd51c6 commit 5f7b830

File tree

12 files changed

+262
-10
lines changed

12 files changed

+262
-10
lines changed

sycl/include/sycl/detail/pi.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ _PI_API(piMemRetain)
6060
_PI_API(piMemRelease)
6161
_PI_API(piMemBufferPartition)
6262
_PI_API(piextMemGetNativeHandle)
63+
_PI_API(piextMemGetNativeHandleExp)
6364
_PI_API(piextMemCreateWithNativeHandle)
6465
_PI_API(piextMemImageCreateWithNativeHandle)
6566
// Program

sycl/include/sycl/detail/pi.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,13 @@ __SYCL_EXPORT pi_result piMemBufferPartition(
14021402
__SYCL_EXPORT pi_result piextMemGetNativeHandle(pi_mem mem,
14031403
pi_native_handle *nativeHandle);
14041404

1405+
/// Gets the native handle of a PI mem object on a particular device.
1406+
///
1407+
/// \param mem is the PI mem to get the native handle of.
/// \param dev is the PI device for which the native handle is requested.
1408+
/// \param nativeHandle is an output pointer that receives the native handle of mem.
1409+
__SYCL_EXPORT pi_result piextMemGetNativeHandleExp(
1410+
pi_mem mem, pi_device dev, pi_native_handle *nativeHandle);
1411+
14051412
/// Creates PI mem object from a native handle.
14061413
/// NOTE: The created PI object takes ownership of the native handle.
14071414
///

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,11 @@ pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
232232
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
233233
}
234234

235+
/// CUDA plugin entry point for the device-aware native mem handle query.
/// Forwards all arguments unchanged to pi2ur::piextMemGetNativeHandleExp and
/// returns its result.
pi_result piextMemGetNativeHandleExp(pi_mem Mem, pi_device Dev,
236+
pi_native_handle *NativeHandle) {
237+
return pi2ur::piextMemGetNativeHandleExp(Mem, Dev, NativeHandle);
238+
}
239+
235240
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
236241
pi_context Context,
237242
bool ownNativeHandle, pi_mem *Mem) {

sycl/plugins/hip/pi_hip.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
240240
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
241241
}
242242

243+
/// HIP plugin entry point for the device-aware native mem handle query.
/// Forwards all arguments unchanged to pi2ur::piextMemGetNativeHandleExp and
/// returns its result.
pi_result piextMemGetNativeHandleExp(pi_mem Mem, pi_device Dev,
244+
pi_native_handle *NativeHandle) {
245+
return pi2ur::piextMemGetNativeHandleExp(Mem, Dev, NativeHandle);
246+
}
247+
243248
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
244249
pi_context Context,
245250
bool ownNativeHandle, pi_mem *Mem) {

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
248248
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
249249
}
250250

251+
/// Level Zero plugin entry point for the device-aware native mem handle query.
/// Forwards all arguments unchanged to pi2ur::piextMemGetNativeHandleExp and
/// returns its result.
pi_result piextMemGetNativeHandleExp(pi_mem Mem, pi_device Dev,
252+
pi_native_handle *NativeHandle) {
253+
return pi2ur::piextMemGetNativeHandleExp(Mem, Dev, NativeHandle);
254+
}
255+
251256
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
252257
pi_context Context,
253258
bool ownNativeHandle, pi_mem *Mem) {

sycl/plugins/native_cpu/pi_native_cpu.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,11 @@ pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
244244
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
245245
}
246246

247+
/// Native CPU plugin entry point for the device-aware native mem handle query.
/// Forwards all arguments unchanged to pi2ur::piextMemGetNativeHandleExp and
/// returns its result.
pi_result piextMemGetNativeHandleExp(pi_mem Mem, pi_device Dev,
248+
pi_native_handle *NativeHandle) {
249+
return pi2ur::piextMemGetNativeHandleExp(Mem, Dev, NativeHandle);
250+
}
251+
247252
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
248253
pi_context Context,
249254
bool ownNativeHandle, pi_mem *Mem) {

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,11 @@ pi_result piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
226226
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
227227
}
228228

229+
/// OpenCL plugin entry point for the device-aware native mem handle query.
/// Forwards all arguments unchanged to pi2ur::piextMemGetNativeHandleExp and
/// returns its result.
pi_result piextMemGetNativeHandleExp(pi_mem Mem, pi_device Dev,
230+
pi_native_handle *NativeHandle) {
231+
return pi2ur::piextMemGetNativeHandleExp(Mem, Dev, NativeHandle);
232+
}
233+
229234
pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
230235
pi_context Context,
231236
bool ownNativeHandle, pi_mem *Mem) {

sycl/plugins/unified_runtime/CMakeLists.txt

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,9 @@ endif()
5656
if(SYCL_PI_UR_USE_FETCH_CONTENT)
5757
include(FetchContent)
5858

59-
set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git")
60-
# commit 69a56ea6d1369a6bde5fce97c85fc7dbda49252f
61-
# Merge: b25bb64d b78f541d
62-
# Author: Kenneth Benzie (Benie) <[email protected]>
63-
# Date: Mon Dec 11 12:30:24 2023 +0000
64-
# Merge pull request #1123 from aarongreig/aaron/usmLocationProps
65-
# [OpenCL] Add ur_usm_alloc_location_desc struct and handle it in the CL adapter.
66-
set(UNIFIED_RUNTIME_TAG 69a56ea6d1369a6bde5fce97c85fc7dbda49252f)
59+
# DO NOT MERGE
60+
set(UNIFIED_RUNTIME_REPO "https://github.com/hdelan/unified-runtime.git")
61+
set(UNIFIED_RUNTIME_TAG get-native-mem-on-device)
6762

6863
if(SYCL_PI_UR_OVERRIDE_FETCH_CONTENT_REPO)
6964
set(UNIFIED_RUNTIME_REPO "${SYCL_PI_UR_OVERRIDE_FETCH_CONTENT_REPO}")

sycl/plugins/unified_runtime/pi2ur.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3090,6 +3090,20 @@ inline pi_result piextMemGetNativeHandle(pi_mem Mem,
30903090
return PI_SUCCESS;
30913091
}
30923092

3093+
/// Retrieves the native (backend-specific) handle of a PI mem object for a
/// particular device by calling urMemGetNativeHandleExp.
///
/// \param Mem is the PI mem object to query; must not be null.
/// \param Device is the PI device the native handle is requested for; it is
///        cast to a UR device handle and passed through unchanged.
/// \param NativeHandle is the output location for the native handle; must not
///        be null.
/// \return PI_SUCCESS on success; otherwise whatever PI_ASSERT or
///         HANDLE_ERRORS produce for the failing check or UR call.
inline pi_result piextMemGetNativeHandleExp(pi_mem Mem, pi_device Device,
                                            pi_native_handle *NativeHandle) {
  PI_ASSERT(Mem, PI_ERROR_INVALID_MEM_OBJECT);
  // Reject a null output pointer up front instead of dereferencing it below.
  PI_ASSERT(NativeHandle, PI_ERROR_INVALID_VALUE);

  auto UrMem = reinterpret_cast<ur_mem_handle_t>(Mem);
  auto UrDevice = reinterpret_cast<ur_device_handle_t>(Device);

  ur_native_handle_t NativeMem{};
  HANDLE_ERRORS(urMemGetNativeHandleExp(UrMem, UrDevice, &NativeMem));

  *NativeHandle = reinterpret_cast<pi_native_handle>(NativeMem);
  return PI_SUCCESS;
}
3106+
30933107
inline pi_result
30943108
piEnqueueMemImageCopy(pi_queue Queue, pi_mem SrcImage, pi_mem DstImage,
30953109
pi_image_offset SrcOrigin, pi_image_offset DstOrigin,

sycl/plugins/unified_runtime/pi_unified_runtime.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ piextMemGetNativeHandle(pi_mem Mem, pi_native_handle *NativeHandle) {
240240
return pi2ur::piextMemGetNativeHandle(Mem, NativeHandle);
241241
}
242242

243+
/// Unified Runtime plugin entry point for the device-aware native mem handle
/// query. Forwards all arguments unchanged to
/// pi2ur::piextMemGetNativeHandleExp and returns its result.
__SYCL_EXPORT pi_result piextMemGetNativeHandleExp(
244+
pi_mem Mem, pi_device Device, pi_native_handle *NativeHandle) {
245+
return pi2ur::piextMemGetNativeHandleExp(Mem, Device, NativeHandle);
246+
}
247+
243248
__SYCL_EXPORT pi_result
244249
piEnqueueMemImageCopy(pi_queue Queue, pi_mem SrcImage, pi_mem DstImage,
245250
pi_image_offset SrcOrigin, pi_image_offset DstOrigin,
@@ -1390,6 +1395,7 @@ __SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
13901395
_PI_API(piMemBufferPartition)
13911396
_PI_API(piEnqueueMemImageCopy)
13921397
_PI_API(piextMemGetNativeHandle)
1398+
_PI_API(piextMemGetNativeHandleExp)
13931399
_PI_API(piextMemCreateWithNativeHandle)
13941400
_PI_API(piMemRetain)
13951401
_PI_API(piextUSMGetMemAllocInfo)

sycl/source/interop_handle.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@ pi_native_handle interop_handle::getNativeMem(detail::Requirement *Req) const {
3434

3535
auto Plugin = MQueue->getPlugin();
3636
pi_native_handle Handle;
37-
Plugin->call<detail::PiApiKind::piextMemGetNativeHandle>(Iter->second,
38-
&Handle);
37+
if (get_backend() == backend::ext_oneapi_cuda ||
38+
get_backend() == backend::ext_oneapi_hip)
39+
Plugin->call<detail::PiApiKind::piextMemGetNativeHandleExp>(
40+
Iter->second, MDevice->getHandleRef(), &Handle);
41+
else
42+
Plugin->call<detail::PiApiKind::piextMemGetNativeHandle>(Iter->second,
43+
&Handle);
3944
return Handle;
4045
}
4146

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
// FIXME: the rocm include path and link path are highly platform dependent,
2+
// we should set this with some variable instead
3+
// RUN: %{build} -o %t.out -I/opt/rocm/include -L/opt/rocm/lib -lamdhip64
4+
// RUN: %{run} %t.out
5+
// REQUIRES: hip
6+
7+
#include <iostream>
8+
#include <sycl/sycl.hpp>
9+
10+
#define __HIP_PLATFORM_AMD__
11+
12+
#include <hip/hip_runtime.h>
13+
14+
using namespace sycl;
15+
using namespace sycl::access;
16+
17+
static constexpr size_t BUFFER_SIZE = 1024;
18+
19+
template <typename T> class Modifier;
20+
21+
template <typename T> class Init;
22+
23+
/// Verifies that every element of \p Buffer equals \p Value.
///
/// The buffer is read through a host accessor. On the first mismatch the
/// offending index, the actual value and the expected value are printed to
/// stderr and the process is aborted. std::abort is used instead of assert()
/// so the check is not compiled out when NDEBUG is defined.
///
/// \param Buffer is the buffer-like object to inspect (taken by value, as
///        before).
/// \param Value is the value every element is expected to hold.
template <typename BufferT, typename ValueT>
void checkBufferValues(BufferT Buffer, ValueT Value) {
  auto Acc = Buffer.get_host_access();
  // size() is the SYCL 2020 spelling of the deprecated get_count().
  const size_t Count = Acc.size();
  for (size_t Idx = 0; Idx < Count; ++Idx) {
    if (Acc[Idx] != Value) {
      std::cerr << "buffer[" << Idx << "] = " << Acc[Idx]
                << ", expected val = " << Value << std::endl;
      std::cerr << "Invalid data in the buffer" << std::endl;
      std::abort();
    }
  }
}
34+
35+
template <typename DataT>
36+
void copy(buffer<DataT, 1> &Src, buffer<DataT, 1> &Dst, queue &Q) {
37+
Q.submit([&](handler &CGH) {
38+
auto SrcA = Src.template get_access<mode::read>(CGH);
39+
auto DstA = Dst.template get_access<mode::write>(CGH);
40+
41+
auto Func = [=](interop_handle IH) {
42+
auto HipStream = IH.get_native_queue<backend::ext_oneapi_hip>();
43+
auto SrcMem = IH.get_native_mem<backend::ext_oneapi_hip>(SrcA);
44+
auto DstMem = IH.get_native_mem<backend::ext_oneapi_hip>(DstA);
45+
cl_event Event;
46+
47+
if (hipMemcpyWithStream(DstMem, SrcMem, sizeof(DataT) * SrcA.get_count(),
48+
hipMemcpyDefault, HipStream) != hipSuccess) {
49+
throw;
50+
}
51+
52+
if (hipStreamSynchronize(HipStream) != hipSuccess) {
53+
throw;
54+
}
55+
56+
if (Q.get_backend() != IH.get_backend())
57+
throw;
58+
};
59+
CGH.host_task(Func);
60+
});
61+
}
62+
63+
/// Submits a kernel to \p Q that adds one to every element of \p B in place.
///
/// \param B is the buffer whose elements are incremented.
/// \param Q is the queue the kernel is submitted to.
template <typename DataT> void modify(buffer<DataT, 1> &B, queue &Q) {
  auto CommandGroup = [&](handler &CGH) {
    auto Data = B.template get_access<mode::read_write>(CGH);

    // One work-item per element; each increments its own slot.
    CGH.parallel_for<Modifier<DataT>>(Data.get_count(),
                                      [=](item<1> Id) { Data[Id] += 1; });
  };
  Q.submit(CommandGroup);
}
72+
73+
template <typename DataT, DataT B1Init, DataT B2Init>
74+
void init(buffer<DataT, 1> &B1, buffer<DataT, 1> &B2, queue &Q) {
75+
Q.submit([&](handler &CGH) {
76+
auto Acc1 = B1.template get_access<mode::write>(CGH);
77+
auto Acc2 = B2.template get_access<mode::write>(CGH);
78+
79+
CGH.parallel_for<Init<DataT>>(BUFFER_SIZE, [=](item<1> Id) {
80+
Acc1[Id] = -1;
81+
Acc2[Id] = -2;
82+
});
83+
});
84+
}
85+
86+
// A test that uses OpenCL interop to copy data from buffer A to buffer B, by
87+
// getting cl_mem objects and calling the clEnqueueBufferCopy. Then run a SYCL
88+
// kernel that modifies the data in place for B, e.g. increment one, then copy
89+
// back to buffer A. Run it on a loop, to ensure the dependencies and the
90+
// reference counting of the objects is not leaked.
91+
void test1(queue &Q) {
92+
static constexpr int COUNT = 4;
93+
buffer<int, 1> Buffer1{BUFFER_SIZE};
94+
buffer<int, 1> Buffer2{BUFFER_SIZE};
95+
96+
// init the buffer with a'priori invalid data
97+
init<int, -1, -2>(Buffer1, Buffer2, Q);
98+
99+
// Repeat a couple of times
100+
for (size_t Idx = 0; Idx < COUNT; ++Idx) {
101+
copy(Buffer1, Buffer2, Q);
102+
modify(Buffer2, Q);
103+
copy(Buffer2, Buffer1, Q);
104+
}
105+
106+
checkBufferValues(Buffer1, COUNT - 1);
107+
checkBufferValues(Buffer2, COUNT - 1);
108+
}
109+
110+
// Same as above, but performing each command group on a separate SYCL queue
111+
// (on the same or different devices). This ensures the dependency tracking
112+
// works well but also there is no accidental side effects on other queues.
113+
void test2(queue &Q) {
114+
static constexpr int COUNT = 4;
115+
buffer<int, 1> Buffer1{BUFFER_SIZE};
116+
buffer<int, 1> Buffer2{BUFFER_SIZE};
117+
118+
// init the buffer with a'priori invalid data
119+
init<int, -1, -2>(Buffer1, Buffer2, Q);
120+
121+
// Repeat a couple of times
122+
for (size_t Idx = 0; Idx < COUNT; ++Idx) {
123+
copy(Buffer1, Buffer2, Q);
124+
modify(Buffer2, Q);
125+
copy(Buffer2, Buffer1, Q);
126+
}
127+
checkBufferValues(Buffer1, COUNT - 1);
128+
checkBufferValues(Buffer2, COUNT - 1);
129+
}
130+
131+
// Same as above but with queue constructed out of context
132+
void test2_1(queue &Q) {
133+
static constexpr int COUNT = 4;
134+
buffer<int, 1> Buffer1{BUFFER_SIZE};
135+
buffer<int, 1> Buffer2{BUFFER_SIZE};
136+
137+
device Device;
138+
auto Context = context(Device);
139+
// init the buffer with a'priori invalid data
140+
init<int, -1, -2>(Buffer1, Buffer2, Q);
141+
142+
// Repeat a couple of times
143+
for (size_t Idx = 0; Idx < COUNT; ++Idx) {
144+
copy(Buffer1, Buffer2, Q);
145+
modify(Buffer2, Q);
146+
copy(Buffer2, Buffer1, Q);
147+
}
148+
checkBufferValues(Buffer1, COUNT - 1);
149+
checkBufferValues(Buffer2, COUNT - 1);
150+
}
151+
152+
// Check that a single host-interop-task with a buffer will work
153+
void test3(queue &Q) {
154+
buffer<int, 1> Buffer{BUFFER_SIZE};
155+
156+
Q.submit([&](handler &CGH) {
157+
auto Acc = Buffer.get_access<mode::write>(CGH);
158+
auto Func = [=](interop_handle IH) { /*A no-op */ };
159+
CGH.host_task(Func);
160+
});
161+
}
162+
163+
// Fill Buffer1 with 123 from a kernel, copy it into Buffer2 with copy() and
// verify that every element of Buffer2 holds 123.
void test4(queue &Q) {
  buffer<int, 1> Buffer1{BUFFER_SIZE};
  buffer<int, 1> Buffer2{BUFFER_SIZE};

  Q.submit([&](handler &CGH) {
    auto Out = Buffer1.template get_access<mode::write>(CGH);

    CGH.parallel_for<class Test5Init>(Out.get_count(),
                                      [=](item<1> Id) { Out[Id] = 123; });
  });

  copy(Buffer1, Buffer2, Q);

  const int Expected = 123;
  checkBufferValues(Buffer2, Expected);
}
178+
179+
// Runs every test case once on the given queue, in declaration order.
void tests(queue &Q) {
  using TestFn = void (*)(queue &);
  const TestFn Cases[] = {test1, test2, test2_1, test3, test4};

  for (TestFn Case : Cases)
    Case(Q);
}
186+
187+
int main() {
188+
queue Q([](sycl::exception_list ExceptionList) {
189+
if (ExceptionList.size() != 1) {
190+
std::cerr << "Should be one exception in exception list" << std::endl;
191+
std::abort();
192+
}
193+
std::rethrow_exception(*ExceptionList.begin());
194+
});
195+
tests(Q);
196+
tests(Q);
197+
std::cout << "Test PASSED" << std::endl;
198+
return 0;
199+
}

0 commit comments

Comments
 (0)