Skip to content

Commit 89460e8

Browse files
authored
[SYCL][CUDA] Add experimental cuda interop with queue (#6290)
This PR adds part of the CUDA backend spec interop proposed in KhronosGroup/SYCL-Docs#197. The changes work with the CUDA CTS interop checks in KhronosGroup/SYCL-CTS#336. This PR adds only the queue interop. llvm-test-suite: intel/llvm-test-suite#1054
1 parent 34dcf83 commit 89460e8

File tree

5 files changed

+63
-11
lines changed

5 files changed

+63
-11
lines changed

sycl/include/sycl/ext/oneapi/experimental/backend/backend_traits_cuda.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ template <> struct InteropFeatureSupportMap<backend::ext_oneapi_cuda> {
131131
static constexpr bool MakePlatform = false;
132132
static constexpr bool MakeDevice = true;
133133
static constexpr bool MakeContext = true;
134-
static constexpr bool MakeQueue = false;
134+
static constexpr bool MakeQueue = true;
135135
static constexpr bool MakeEvent = false;
136136
static constexpr bool MakeBuffer = false;
137137
static constexpr bool MakeKernel = false;

sycl/include/sycl/ext/oneapi/experimental/backend/cuda.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,15 @@ inline device make_device<backend::ext_oneapi_cuda>(
7171
return ext::oneapi::cuda::make_device(NativeHandle);
7272
}
7373

74+
// CUDA queue specialization
75+
template <>
76+
inline queue make_queue<backend::ext_oneapi_cuda>(
77+
const backend_input_t<backend::ext_oneapi_cuda, queue> &BackendObject,
78+
const context &TargetContext, const async_handler Handler) {
79+
return detail::make_queue(detail::pi::cast<pi_native_handle>(BackendObject),
80+
TargetContext, true, Handler,
81+
/*Backend*/ backend::ext_oneapi_cuda);
82+
}
83+
7484
} // namespace sycl
7585
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,6 +2448,9 @@ pi_result cuda_piQueueRelease(pi_queue command_queue) {
24482448
try {
24492449
std::unique_ptr<_pi_queue> queueImpl(command_queue);
24502450

2451+
if (!command_queue->backend_has_ownership())
2452+
return PI_SUCCESS;
2453+
24512454
ScopedContext active(command_queue->get_context());
24522455

24532456
command_queue->for_each_stream([](CUstream s) {
@@ -2511,8 +2514,7 @@ pi_result cuda_piextQueueGetNativeHandle(pi_queue queue,
25112514
}
25122515

25132516
/// Created a PI queue object from a CUDA queue handle.
2514-
/// TODO: Implement this.
2515-
/// NOTE: The created PI object takes ownership of the native handle.
2517+
/// NOTE: The created PI object does not take ownership of the native handle.
25162518
///
25172519
/// \param[in] nativeHandle The native handle to create PI queue object from.
25182520
/// \param[in] context is the PI context of the queue.
@@ -2521,13 +2523,43 @@ pi_result cuda_piextQueueGetNativeHandle(pi_queue queue,
25212523
/// the native handle, if it can.
25222524
///
25232525
/// \return TBD
2524-
pi_result cuda_piextQueueCreateWithNativeHandle(pi_native_handle, pi_context,
2525-
pi_device, bool ownNativeHandle,
2526-
pi_queue *) {
2526+
pi_result cuda_piextQueueCreateWithNativeHandle(pi_native_handle nativeHandle,
2527+
pi_context context,
2528+
pi_device device,
2529+
bool ownNativeHandle,
2530+
pi_queue *queue) {
2531+
(void)device;
25272532
(void)ownNativeHandle;
2528-
cl::sycl::detail::pi::die(
2529-
"Creation of PI queue from native handle not implemented");
2530-
return {};
2533+
assert(ownNativeHandle == false);
2534+
2535+
unsigned int flags;
2536+
CUstream cuStream = reinterpret_cast<CUstream>(nativeHandle);
2537+
2538+
auto retErr = PI_CHECK_ERROR(cuStreamGetFlags(cuStream, &flags));
2539+
2540+
pi_queue_properties properties = 0;
2541+
if (flags == CU_STREAM_DEFAULT)
2542+
properties = __SYCL_PI_CUDA_USE_DEFAULT_STREAM;
2543+
else if (flags == CU_STREAM_NON_BLOCKING)
2544+
properties = __SYCL_PI_CUDA_SYNC_WITH_DEFAULT;
2545+
else
2546+
cl::sycl::detail::pi::die("Unknown cuda stream");
2547+
2548+
std::vector<CUstream> computeCuStreams(1, cuStream);
2549+
std::vector<CUstream> transferCuStreams(0);
2550+
2551+
// Create queue and set num_compute_streams to 1, as computeCuStreams has
2552+
// valid stream
2553+
*queue = new _pi_queue{std::move(computeCuStreams),
2554+
std::move(transferCuStreams),
2555+
context,
2556+
context->get_device(),
2557+
properties,
2558+
flags,
2559+
/*backend_owns*/ false};
2560+
(*queue)->num_compute_streams_ = 1;
2561+
2562+
return retErr;
25312563
}
25322564

25332565
pi_result cuda_piEnqueueMemBufferWrite(pi_queue command_queue, pi_mem buffer,

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,18 +401,19 @@ struct _pi_queue {
401401
unsigned int flags_;
402402
std::mutex compute_stream_mutex_;
403403
std::mutex transfer_stream_mutex_;
404+
bool has_ownership_;
404405

405406
_pi_queue(std::vector<CUstream> &&compute_streams,
406407
std::vector<CUstream> &&transfer_streams, _pi_context *context,
407408
_pi_device *device, pi_queue_properties properties,
408-
unsigned int flags)
409+
unsigned int flags, bool backend_owns = true)
409410
: compute_streams_{std::move(compute_streams)},
410411
transfer_streams_{std::move(transfer_streams)}, context_{context},
411412
device_{device}, properties_{properties}, refCount_{1}, eventCount_{0},
412413
compute_stream_idx_{0}, transfer_stream_idx_{0},
413414
num_compute_streams_{0}, num_transfer_streams_{0},
414415
last_sync_compute_streams_{0}, last_sync_transfer_streams_{0},
415-
flags_(flags) {
416+
flags_(flags), has_ownership_{backend_owns} {
416417
cuda_piContextRetain(context_);
417418
cuda_piDeviceRetain(device_);
418419
}
@@ -513,6 +514,8 @@ struct _pi_queue {
513514
pi_uint32 get_reference_count() const noexcept { return refCount_; }
514515

515516
pi_uint32 get_next_event_id() noexcept { return ++eventCount_; }
517+
518+
bool backend_has_ownership() const noexcept { return has_ownership_; }
516519
};
517520

518521
typedef void (*pfn_notify)(pi_event event, pi_int32 eventCommandStatus,

sycl/test/basic_tests/interop-cuda-experimental.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ int main() {
4040

4141
backend_traits<backend::ext_oneapi_cuda>::return_type<device> cu_device;
4242
backend_traits<backend::ext_oneapi_cuda>::return_type<context> cu_context;
43+
backend_traits<backend::ext_oneapi_cuda>::return_type<queue> cu_queue;
4344

4445
// 4.5.1.2 For each SYCL runtime class T which supports SYCL application
4546
// interoperability, a specialization of get_native must be defined, which
@@ -50,6 +51,7 @@ int main() {
5051

5152
cu_device = get_native<backend::ext_oneapi_cuda>(Device);
5253
cu_context = get_native<backend::ext_oneapi_cuda>(Context);
54+
cu_queue = get_native<backend::ext_oneapi_cuda>(Queue);
5355

5456
// Check deprecated
5557
// expected-warning@+2 {{'get_native' is deprecated: Use SYCL 2020 sycl::get_native free function}}
@@ -58,6 +60,9 @@ int main() {
5860
// expected-warning@+2 {{'get_native' is deprecated: Use SYCL 2020 sycl::get_native free function}}
5961
// expected-warning@+1 {{'get_native<sycl::backend::ext_oneapi_cuda>' is deprecated: Use SYCL 2020 sycl::get_native free function}}
6062
cu_context = Context.get_native<backend::ext_oneapi_cuda>();
63+
// expected-warning@+2 {{'get_native' is deprecated: Use SYCL 2020 sycl::get_native free function}}
64+
// expected-warning@+1 {{'get_native<sycl::backend::ext_oneapi_cuda>' is deprecated: Use SYCL 2020 sycl::get_native free function}}
65+
cu_queue = Queue.get_native<backend::ext_oneapi_cuda>();
6166

6267
// 4.5.1.1 For each SYCL runtime class T which supports SYCL application
6368
// interoperability with the SYCL backend, a specialization of input_type must
@@ -85,5 +90,7 @@ int main() {
8590
context InteropContext =
8691
make_context<backend::ext_oneapi_cuda>(InteropContextInput);
8792

93+
queue InteropQueue = make_queue<backend::ext_oneapi_cuda>(cu_queue, Context);
94+
8895
return 0;
8996
}

0 commit comments

Comments
 (0)