Skip to content

Commit d9deb25

Browse files
authored
[SYCL][CUDA] Pass device from context in create queue. (#10491)
Recently in the switch to UR `urQueueCreateFromNativeHandle` changed the previous behaviour whereby a queue was created with a device taken as the default device from the context. It changed it so that the queue was created with the device argument instead. Since the sycl runtime always passes a nullptr for the device when programmers call `make_queue(nativeStream, context)`, this broke `make_queue`. This patch reverts to the previous behaviour before the switch from pi cuda to ur cuda. Note that this should also fix `make_queue` for l0 which I also guess was broken due to the asserts meaning that this line was never reached: https://github.com/intel/llvm/blob/sycl/sycl/plugins/unified_runtime/ur/adapters/level_zero/queue.cpp#L574. But I have not tested this. --------- Signed-off-by: JackAKirk <[email protected]>
1 parent 45fb7ae commit d9deb25

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

sycl/plugins/unified_runtime/pi2ur.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1620,7 +1620,6 @@ inline pi_result piextQueueCreateWithNativeHandle(
16201620
PI_ASSERT(Context, PI_ERROR_INVALID_CONTEXT);
16211621
PI_ASSERT(NativeHandle, PI_ERROR_INVALID_VALUE);
16221622
PI_ASSERT(Queue, PI_ERROR_INVALID_QUEUE);
1623-
PI_ASSERT(Device, PI_ERROR_INVALID_DEVICE);
16241623

16251624
ur_context_handle_t UrContext =
16261625
reinterpret_cast<ur_context_handle_t>(Context);

sycl/plugins/unified_runtime/ur/adapters/cuda/queue.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
243243
ur_device_handle_t hDevice, const ur_queue_native_properties_t *pProperties,
244244
ur_queue_handle_t *phQueue) {
245245
(void)pProperties;
246+
(void)hDevice;
246247

247248
unsigned int CuFlags;
248249
CUstream CuStream = reinterpret_cast<CUstream>(hNativeQueue);
249-
UR_ASSERT(hContext->getDevice() == hDevice, UR_RESULT_ERROR_INVALID_DEVICE);
250250

251251
auto Return = UR_CHECK_ERROR(cuStreamGetFlags(CuStream, &CuFlags));
252252

@@ -266,7 +266,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
266266
*phQueue = new ur_queue_handle_t_{std::move(ComputeCuStreams),
267267
std::move(TransferCuStreams),
268268
hContext,
269-
hDevice,
269+
hContext->getDevice(),
270270
CuFlags,
271271
Flags,
272272
/*backend_owns*/ false};

0 commit comments

Comments
 (0)