Skip to content

Commit a884b13

Browse files
author
Steffen Larsen
committed
[SYCL][PI][OpenCL] Generalizing interop handler getters
This commit makes get_mem and get_queue of interop_handler return types based on a specified backend. The backend defaults to OpenCL to avoid breakages. Signed-off-by: Steffen Larsen <[email protected]>
1 parent 25a76c5 commit a884b13

File tree

7 files changed

+142
-31
lines changed

7 files changed

+142
-31
lines changed

sycl/include/CL/sycl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <CL/sycl/accessor.hpp>
1212
#include <CL/sycl/atomic.hpp>
1313
#include <CL/sycl/backend.hpp>
14+
#include <CL/sycl/backend/opencl.hpp>
1415
#include <CL/sycl/buffer.hpp>
1516
#include <CL/sycl/builtins.hpp>
1617
#include <CL/sycl/context.hpp>
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
//==---------------- opencl.hpp - SYCL OpenCL backend ----------------------==//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#pragma once
11+
12+
#include <CL/cl.h>
13+
#include <CL/sycl/accessor.hpp>
14+
#include <CL/sycl/backend_types.hpp>
15+
16+
__SYCL_INLINE_NAMESPACE(cl) {
17+
namespace sycl {
18+
19+
template <> struct interop<backend::opencl, queue> {
20+
using type = cl_command_queue;
21+
};
22+
23+
template <typename DataT, int Dimensions, access::mode AccessMode>
24+
struct interop<backend::opencl, accessor<DataT, Dimensions, AccessMode,
25+
access::target::global_buffer,
26+
access::placeholder::false_t>> {
27+
using type = cl_mem;
28+
};
29+
30+
} // namespace sycl
31+
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/detail/cg.hpp

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include <CL/sycl/backend_types.hpp>
1112
#include <CL/sycl/detail/accessor_impl.hpp>
1213
#include <CL/sycl/detail/common.hpp>
1314
#include <CL/sycl/detail/export.hpp>
@@ -29,6 +30,12 @@
2930
__SYCL_INLINE_NAMESPACE(cl) {
3031
namespace sycl {
3132

33+
// Forward declaration
34+
class queue;
35+
namespace detail {
36+
class queue_impl;
37+
} // namespace detail
38+
3239
// Interoperability handler
3340
//
3441
class interop_handler {
@@ -37,26 +44,51 @@ class interop_handler {
3744
access::target AccTarget, access::placeholder isPlaceholder>
3845
friend class accessor;
3946
public:
47+
using QueueImplPtr = std::shared_ptr<detail::queue_impl>;
4048
using ReqToMem = std::pair<detail::Requirement*, pi_mem>;
4149

42-
interop_handler(std::vector<ReqToMem> MemObjs, cl_command_queue PiQueue) :
43-
MQueue(PiQueue), MMemObjs(MemObjs) {}
50+
interop_handler(std::vector<ReqToMem> MemObjs, QueueImplPtr Queue)
51+
: MQueue(std::move(Queue)), MMemObjs(std::move(MemObjs)) {}
4452

45-
cl_command_queue get_queue() const noexcept { return MQueue; };
53+
template <backend BackendName = backend::opencl>
54+
auto get_queue() const -> typename interop<BackendName, queue>::type {
55+
return reinterpret_cast<typename interop<BackendName, queue>::type>(
56+
GetNativeQueue());
57+
}
4658

47-
template <typename DataT, int Dims, access::mode AccessMode,
48-
access::target AccessTarget,
59+
template <backend BackendName = backend::opencl, typename DataT, int Dims,
60+
access::mode AccessMode, access::target AccessTarget,
4961
access::placeholder IsPlaceholder = access::placeholder::false_t>
50-
cl_mem get_mem(accessor<DataT, Dims, AccessMode, AccessTarget,
51-
access::placeholder::false_t>
52-
Acc) const {
62+
auto get_mem(accessor<DataT, Dims, AccessMode, AccessTarget,
63+
access::placeholder::false_t>
64+
Acc) const ->
65+
typename interop<BackendName,
66+
accessor<DataT, Dims, AccessMode, AccessTarget,
67+
access::placeholder::false_t>>::type {
5368
detail::AccessorBaseHost *AccBase = (detail::AccessorBaseHost *)&Acc;
54-
return getMemImpl(detail::getSyclObjImpl(*AccBase).get());
69+
return getMemImpl<BackendName, DataT, Dims, AccessMode, AccessTarget,
70+
access::placeholder::false_t>(
71+
detail::getSyclObjImpl(*AccBase).get());
5572
}
73+
5674
private:
57-
cl_command_queue MQueue;
75+
QueueImplPtr MQueue;
5876
std::vector<ReqToMem> MMemObjs;
59-
__SYCL_EXPORT cl_mem getMemImpl(detail::Requirement *Req) const;
77+
78+
template <backend BackendName, typename DataT, int Dims,
79+
access::mode AccessMode, access::target AccessTarget,
80+
access::placeholder IsPlaceholder>
81+
__SYCL_EXPORT auto
82+
getMemImpl(detail::Requirement *Req) const -> typename interop<
83+
BackendName,
84+
accessor<DataT, Dims, AccessMode, AccessTarget, IsPlaceholder>>::type {
85+
return (typename interop<BackendName,
86+
accessor<DataT, Dims, AccessMode, AccessTarget,
87+
IsPlaceholder>>::type)GetNativeMem(Req);
88+
}
89+
90+
__SYCL_EXPORT pi_native_handle GetNativeMem(detail::Requirement *Req) const;
91+
__SYCL_EXPORT pi_native_handle GetNativeQueue() const;
6092
};
6193

6294
namespace detail {

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,23 @@ static pi_result OCL(piextProgramSetSpecializationConstantImpl)(
10431043
return cast<pi_result>(Res);
10441044
}
10451045

1046+
/// API to get the native handle of a PI object
1047+
///
1048+
/// \param handleType is an identifier representing the type of the handle
1049+
/// \param piObject is the PI object to get the handle of
1050+
/// \param nativeHandle is the native handle of piObject
1051+
pi_result OCL(piGetNativeHandle)(pi_handle_type handleType, void *piObject,
1052+
pi_native_handle *nativeHandle) {
1053+
switch (handleType) {
1054+
case pi_handle_type::PI_NATIVE_HANDLE_MEM:
1055+
case pi_handle_type::PI_NATIVE_HANDLE_QUEUE:
1056+
*nativeHandle = reinterpret_cast<pi_native_handle>(piObject);
1057+
return PI_SUCCESS;
1058+
default:
1059+
return PI_INVALID_VALUE;
1060+
}
1061+
}
1062+
10461063
pi_result piPluginInit(pi_plugin *PluginInit) {
10471064
int CompareVersions = strcmp(PluginInit->PiVersion, SupportedVersion);
10481065
if (CompareVersions < 0) {
@@ -1154,6 +1171,8 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
11541171
_PI_CL(piextUSMEnqueuePrefetch, OCL(piextUSMEnqueuePrefetch))
11551172
_PI_CL(piextUSMEnqueueMemAdvise, OCL(piextUSMEnqueueMemAdvise))
11561173
_PI_CL(piextUSMGetMemAllocInfo, OCL(piextUSMGetMemAllocInfo))
1174+
// Native
1175+
_PI_CL(piGetNativeHandle, OCL(piGetNativeHandle))
11571176

11581177
_PI_CL(piextKernelSetArgMemObj, OCL(piextKernelSetArgMemObj))
11591178

sycl/source/detail/cg.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
#include "CL/sycl/detail/cg.hpp"
1010
#include <CL/sycl/detail/memory_manager.hpp>
11+
#include <CL/sycl/detail/pi.hpp>
1112
#include <detail/queue_impl.hpp>
1213
#include <detail/scheduler/commands.hpp>
1314
#include <detail/scheduler/scheduler.hpp>
1415

15-
1616
#include <memory>
1717
#include <string>
1818
#include <type_traits>
@@ -21,17 +21,24 @@
2121
namespace cl {
2222
namespace sycl {
2323

24-
cl_mem interop_handler::getMemImpl(detail::Requirement* Req) const {
25-
auto Iter = std::find_if(std::begin(MMemObjs), std::end(MMemObjs),
26-
[=](ReqToMem Elem) {
27-
return (Elem.first == Req);
28-
});
24+
pi_native_handle interop_handler::GetNativeQueue() const {
25+
return MQueue->getNative();
26+
}
27+
28+
pi_native_handle interop_handler::GetNativeMem(detail::Requirement *Req) const {
29+
auto Iter = std::find_if(std::begin(MMemObjs), std::end(MMemObjs),
30+
[=](ReqToMem Elem) { return (Elem.first == Req); });
2931

30-
if (Iter == std::end(MMemObjs)) {
31-
throw("Invalid memory object used inside interop");
32-
}
33-
return detail::pi::cast<cl_mem>(Iter->second);
32+
if (Iter == std::end(MMemObjs)) {
33+
throw("Invalid memory object used inside interop");
3434
}
3535

36+
auto Plugin = MQueue->getPlugin();
37+
pi_native_handle Handle;
38+
Plugin.call<detail::PiApiKind::piGetNativeHandle>(
39+
pi_handle_type::PI_NATIVE_HANDLE_MEM, Iter->second, &Handle);
40+
return Handle;
41+
}
42+
3643
} // sycl
3744
} // cl

sycl/source/detail/scheduler/commands.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,12 +1772,13 @@ cl_int ExecCGCommand::enqueueImp() {
17721772
ReqMemObjs.emplace_back(ReqToMem);
17731773
});
17741774

1775-
auto interop_queue = MQueue->get();
17761775
std::sort(std::begin(ReqMemObjs), std::end(ReqMemObjs));
1777-
interop_handler InteropHandler(std::move(ReqMemObjs), interop_queue);
1776+
interop_handler InteropHandler(std::move(ReqMemObjs), MQueue);
17781777
ExecInterop->MInteropTask->call(InteropHandler);
1779-
Plugin.call<PiApiKind::piEnqueueEventsWait>(MQueue->getHandleRef(), 0, nullptr, &Event);
1780-
Plugin.call<PiApiKind::piQueueRelease>(reinterpret_cast<pi_queue>(interop_queue));
1778+
Plugin.call<PiApiKind::piEnqueueEventsWait>(MQueue->getHandleRef(), 0,
1779+
nullptr, &Event);
1780+
Plugin.call<PiApiKind::piQueueRelease>(
1781+
reinterpret_cast<pi_queue>(MQueue->get()));
17811782
return CL_SUCCESS;
17821783
}
17831784
case CG::CGTYPE::NONE:

sycl/unittests/pi/cuda/test_interop_get_native.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
// REQUIRES: cuda
2-
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -I%opencl_include_dir -I%cuda_toolkit_include -o %t.out -lcuda -lsycl
3-
// RUN: env SYCL_DEVICE_TYPE=GPU %t.out
4-
5-
//==---------- interop_get_native.cpp - SYCL cuda get_native tests ---------==//
1+
//==------- test_interop_get_native.cpp - SYCL CUDA get_native tests -------==//
62
//
73
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
84
// See https://llvm.org/LICENSE.txt for license information.
@@ -74,4 +70,28 @@ TEST_F(DISABLED_CudaInteropGetNativeTests, getNativeQueue) {
7470

7571
CUcontext cudaContext = get_native<backend::cuda>(syclContext_);
7672
ASSERT_EQ(streamContext, cudaContext);
77-
}
73+
}
74+
75+
TEST_F(DISABLED_CudaInteropGetNativeTests, interopTaskGetMem) {
76+
buffer<int, 1> syclBuffer(range<1>{1});
77+
syclQueue_.submit([&](cl::sycl::handler &cgh) {
78+
auto syclAccessor = syclBuffer.get_access<access::mode::read>(cgh);
79+
cgh.interop_task([=](sycl::interop_handler ih) {
80+
CUdeviceptr cudaPtr = ih.get_mem<backend::cuda>(syclAccessor);
81+
CUdeviceptr cudaPtrBase;
82+
size_t cudaPtrSize = 0;
83+
cuMemGetAddressRange(&cudaPtrBase, &cudaPtrSize, cudaPtr);
84+
ASSERT_EQ(cudaPtrSize, sizeof(int));
85+
});
86+
});
87+
}
88+
89+
TEST_F(DISABLED_CudaInteropGetNativeTests, interopTaskGetBufferMem) {
90+
CUstream cudaStream = get_native<backend::cuda>(syclQueue_);
91+
syclQueue_.submit([&](cl::sycl::handler &cgh) {
92+
cgh.interop_task([=](sycl::interop_handler ih) {
93+
CUstream cudaInteropStream = ih.get_queue<backend::cuda>();
94+
ASSERT_EQ(cudaInteropStream, cudaStream);
95+
});
96+
});
97+
}

0 commit comments

Comments
 (0)