Skip to content

Commit 42ee42e

Browse files
[UR] Add handles to opencl adapter (#17572)
Co-authored-by: omarahmed1111 <[email protected]>
1 parent 48a08bf commit 42ee42e

25 files changed

+2152
-1342
lines changed

sycl/source/detail/buffer_impl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,17 @@ buffer_impl::getNativeVector(backend BackendName) const {
8686

8787
auto Adapter = Platform->getAdapter();
8888

89-
if (Platform->getBackend() == backend::opencl) {
90-
__SYCL_OCL_CALL(clRetainMemObject, ur::cast<cl_mem>(NativeMem));
91-
}
92-
9389
ur_native_handle_t Handle = 0;
9490
// When doing buffer interop we don't know what device the memory should be
9591
// resident on, so pass nullptr for Device param. Buffer interop may not be
9692
// supported by all backends.
9793
Adapter->call<UrApiKind::urMemGetNativeHandle>(NativeMem, /*Dev*/ nullptr,
9894
&Handle);
9995
Handles.push_back(Handle);
96+
97+
if (Platform->getBackend() == backend::opencl) {
98+
__SYCL_OCL_CALL(clRetainMemObject, ur::cast<cl_mem>(Handle));
99+
}
100100
}
101101

102102
addInteropObject(Handles);

sycl/source/kernel.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ kernel::kernel(cl_kernel ClKernel, const context &SyclContext) {
2222
ur_kernel_handle_t hKernel = nullptr;
2323
ur_native_handle_t nativeHandle =
2424
reinterpret_cast<ur_native_handle_t>(ClKernel);
25-
Adapter->call<detail::UrApiKind::urKernelCreateWithNativeHandle>(
26-
nativeHandle, detail::getSyclObjImpl(SyclContext)->getHandleRef(),
27-
nullptr, nullptr, &hKernel);
25+
Adapter
26+
->call<errc::invalid, detail::UrApiKind::urKernelCreateWithNativeHandle>(
27+
nativeHandle, detail::getSyclObjImpl(SyclContext)->getHandleRef(),
28+
nullptr, nullptr, &hKernel);
2829
impl = std::make_shared<detail::kernel_impl>(
2930
hKernel, detail::getSyclObjImpl(SyclContext), nullptr, nullptr);
3031
// This is a special interop constructor for OpenCL, so the kernel must be

unified-runtime/source/adapters/opencl/adapter.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
88
//
99
//===----------------------------------------------------------------------===//
10+
#include "logger/ur_logger.hpp"
11+
#include "platform.hpp"
1012

1113
#include "CL/cl.h"
1214
#include "logger/ur_logger.hpp"
@@ -18,6 +20,9 @@ struct ur_adapter_handle_t_ {
1820
std::mutex Mutex;
1921
logger::Logger &log = logger::get_logger("opencl");
2022

23+
std::vector<std::unique_ptr<ur_platform_handle_t_>> URPlatforms;
24+
uint32_t NumPlatforms = 0;
25+
2126
// Function pointers to core OpenCL entry points which may not exist in older
2227
// versions of the OpenCL-ICD-Loader are tracked here and initialized by
2328
// dynamically loading the symbol by name.

unified-runtime/source/adapters/opencl/command_buffer.cpp

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,18 @@
1010

1111
#include "command_buffer.hpp"
1212
#include "common.hpp"
13+
#include "context.hpp"
14+
#include "event.hpp"
15+
#include "kernel.hpp"
16+
#include "memory.hpp"
17+
#include "queue.hpp"
1318

1419
/// The ur_exp_command_buffer_handle_t_ destructor calls CL release
1520
/// command-buffer to free the underlying object.
1621
ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_() {
1722
urQueueRelease(hInternalQueue);
1823

19-
cl_context CLContext = cl_adapter::cast<cl_context>(hContext);
24+
cl_context CLContext = hContext->CLContext;
2025
cl_ext::clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHR = nullptr;
2126
cl_int Res =
2227
cl_ext::getExtFuncFromContext<decltype(clReleaseCommandBufferKHR)>(
@@ -43,7 +48,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
4348
UR_RETURN_ON_FAILURE(
4449
urQueueCreate(hContext, hDevice, &QueueProperties, &Queue));
4550

46-
cl_context CLContext = cl_adapter::cast<cl_context>(hContext);
51+
cl_context CLContext = hContext->CLContext;
4752
cl_ext::clCreateCommandBufferKHR_fn clCreateCommandBufferKHR = nullptr;
4853
UR_RETURN_ON_FAILURE(
4954
cl_ext::getExtFuncFromContext<decltype(clCreateCommandBufferKHR)>(
@@ -53,7 +58,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
5358
const bool IsUpdatable = pCommandBufferDesc->isUpdatable;
5459

5560
ur_device_command_buffer_update_capability_flags_t UpdateCapabilities;
56-
cl_device_id CLDevice = cl_adapter::cast<cl_device_id>(hDevice);
61+
cl_device_id CLDevice = hDevice->CLDevice;
5762
CL_RETURN_ON_FAILURE(
5863
getDeviceCommandBufferUpdateCapabilities(CLDevice, UpdateCapabilities));
5964
bool DeviceSupportsUpdate = UpdateCapabilities > 0;
@@ -67,16 +72,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
6772
IsUpdatable ? CL_COMMAND_BUFFER_MUTABLE_KHR : 0u, 0};
6873

6974
cl_int Res = CL_SUCCESS;
70-
auto CLCommandBuffer = clCreateCommandBufferKHR(
71-
1, cl_adapter::cast<cl_command_queue *>(&Queue), Properties, &Res);
75+
const cl_command_queue CLQueue = Queue->CLQueue;
76+
auto CLCommandBuffer =
77+
clCreateCommandBufferKHR(1, &CLQueue, Properties, &Res);
7278
CL_RETURN_ON_FAILURE_AND_SET_NULL(Res, phCommandBuffer);
7379

7480
try {
7581
auto URCommandBuffer = std::make_unique<ur_exp_command_buffer_handle_t_>(
7682
Queue, hContext, hDevice, CLCommandBuffer, IsUpdatable, IsInOrder);
7783
*phCommandBuffer = URCommandBuffer.release();
78-
} catch (...) {
84+
} catch (std::bad_alloc &) {
7985
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
86+
} catch (...) {
87+
return UR_RESULT_ERROR_UNKNOWN;
8088
}
8189

8290
CL_RETURN_ON_FAILURE(Res);
@@ -101,7 +109,7 @@ urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
101109
UR_APIEXPORT ur_result_t UR_APICALL
102110
urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
103111
UR_ASSERT(!hCommandBuffer->IsFinalized, UR_RESULT_ERROR_INVALID_OPERATION);
104-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
112+
cl_context CLContext = hCommandBuffer->hContext->CLContext;
105113
cl_ext::clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHR = nullptr;
106114
UR_RETURN_ON_FAILURE(
107115
cl_ext::getExtFuncFromContext<decltype(clFinalizeCommandBufferKHR)>(
@@ -133,7 +141,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
133141
UR_ASSERT(!(phCommandHandle && !hCommandBuffer->IsUpdatable),
134142
UR_RESULT_ERROR_INVALID_OPERATION);
135143

136-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
144+
cl_context CLContext = hCommandBuffer->hContext->CLContext;
137145
cl_ext::clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHR = nullptr;
138146
UR_RETURN_ON_FAILURE(
139147
cl_ext::getExtFuncFromContext<decltype(clCommandNDRangeKernelKHR)>(
@@ -161,10 +169,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
161169
IsInOrder ? nullptr : pSyncPointWaitList;
162170
uint32_t WaitListSize = IsInOrder ? 0 : numSyncPointsInWaitList;
163171
CL_RETURN_ON_FAILURE(clCommandNDRangeKernelKHR(
164-
hCommandBuffer->CLCommandBuffer, nullptr, Properties,
165-
cl_adapter::cast<cl_kernel>(hKernel), workDim, pGlobalWorkOffset,
166-
pGlobalWorkSize, pLocalWorkSize, WaitListSize, SyncPointWaitList,
167-
RetSyncPoint, OutCommandHandle));
172+
hCommandBuffer->CLCommandBuffer, nullptr, Properties, hKernel->CLKernel,
173+
workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, WaitListSize,
174+
SyncPointWaitList, RetSyncPoint, OutCommandHandle));
168175

169176
try {
170177
auto Handle = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
@@ -224,7 +231,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
224231
(void)phEventWaitList;
225232
(void)phEvent;
226233
(void)phCommand;
227-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
234+
cl_context CLContext = hCommandBuffer->hContext->CLContext;
228235
cl_ext::clCommandCopyBufferKHR_fn clCommandCopyBufferKHR = nullptr;
229236
UR_RETURN_ON_FAILURE(
230237
cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferKHR)>(
@@ -237,10 +244,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
237244
IsInOrder ? nullptr : pSyncPointWaitList;
238245
uint32_t WaitListSize = IsInOrder ? 0 : numSyncPointsInWaitList;
239246
CL_RETURN_ON_FAILURE(clCommandCopyBufferKHR(
240-
hCommandBuffer->CLCommandBuffer, nullptr, nullptr,
241-
cl_adapter::cast<cl_mem>(hSrcMem), cl_adapter::cast<cl_mem>(hDstMem),
242-
srcOffset, dstOffset, size, WaitListSize, SyncPointWaitList, RetSyncPoint,
243-
nullptr));
247+
hCommandBuffer->CLCommandBuffer, nullptr, nullptr, hSrcMem->CLMemory,
248+
hDstMem->CLMemory, srcOffset, dstOffset, size, WaitListSize,
249+
SyncPointWaitList, RetSyncPoint, nullptr));
244250

245251
return UR_RESULT_SUCCESS;
246252
}
@@ -267,7 +273,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
267273
size_t OpenCLDstRect[3]{dstOrigin.x, dstOrigin.y, dstOrigin.z};
268274
size_t OpenCLRegion[3]{region.width, region.height, region.depth};
269275

270-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
276+
cl_context CLContext = hCommandBuffer->hContext->CLContext;
271277
cl_ext::clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHR = nullptr;
272278
UR_RETURN_ON_FAILURE(
273279
cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferRectKHR)>(
@@ -280,11 +286,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
280286
IsInOrder ? nullptr : pSyncPointWaitList;
281287
uint32_t WaitListSize = IsInOrder ? 0 : numSyncPointsInWaitList;
282288
CL_RETURN_ON_FAILURE(clCommandCopyBufferRectKHR(
283-
hCommandBuffer->CLCommandBuffer, nullptr, nullptr,
284-
cl_adapter::cast<cl_mem>(hSrcMem), cl_adapter::cast<cl_mem>(hDstMem),
285-
OpenCLOriginRect, OpenCLDstRect, OpenCLRegion, srcRowPitch, srcSlicePitch,
286-
dstRowPitch, dstSlicePitch, WaitListSize, SyncPointWaitList, RetSyncPoint,
287-
nullptr));
289+
hCommandBuffer->CLCommandBuffer, nullptr, nullptr, hSrcMem->CLMemory,
290+
hDstMem->CLMemory, OpenCLOriginRect, OpenCLDstRect, OpenCLRegion,
291+
srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, WaitListSize,
292+
SyncPointWaitList, RetSyncPoint, nullptr));
288293

289294
return UR_RESULT_SUCCESS;
290295
}
@@ -376,7 +381,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
376381
[[maybe_unused]] ur_event_handle_t *phEvent,
377382
[[maybe_unused]] ur_exp_command_buffer_command_handle_t *phCommand) {
378383

379-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
384+
cl_context CLContext = hCommandBuffer->hContext->CLContext;
380385
cl_ext::clCommandFillBufferKHR_fn clCommandFillBufferKHR = nullptr;
381386
UR_RETURN_ON_FAILURE(
382387
cl_ext::getExtFuncFromContext<decltype(clCommandFillBufferKHR)>(
@@ -389,9 +394,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
389394
IsInOrder ? nullptr : pSyncPointWaitList;
390395
uint32_t WaitListSize = IsInOrder ? 0 : numSyncPointsInWaitList;
391396
CL_RETURN_ON_FAILURE(clCommandFillBufferKHR(
392-
hCommandBuffer->CLCommandBuffer, nullptr, nullptr,
393-
cl_adapter::cast<cl_mem>(hBuffer), pPattern, patternSize, offset, size,
394-
WaitListSize, SyncPointWaitList, RetSyncPoint, nullptr));
397+
hCommandBuffer->CLCommandBuffer, nullptr, nullptr, hBuffer->CLMemory,
398+
pPattern, patternSize, offset, size, WaitListSize, SyncPointWaitList,
399+
RetSyncPoint, nullptr));
395400

396401
return UR_RESULT_SUCCESS;
397402
}
@@ -447,21 +452,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCommandBufferExp(
447452
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
448453
ur_event_handle_t *phEvent) {
449454

450-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
455+
cl_context CLContext = hCommandBuffer->hContext->CLContext;
451456
cl_ext::clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHR = nullptr;
452457
UR_RETURN_ON_FAILURE(
453458
cl_ext::getExtFuncFromContext<decltype(clEnqueueCommandBufferKHR)>(
454459
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueCommandBufferKHRCache,
455460
cl_ext::EnqueueCommandBufferName, &clEnqueueCommandBufferKHR));
456461

457462
const uint32_t NumberOfQueues = 1;
458-
463+
cl_event Event;
464+
std::vector<cl_event> CLWaitEvents(numEventsInWaitList);
465+
for (uint32_t i = 0; i < numEventsInWaitList; i++) {
466+
CLWaitEvents[i] = phEventWaitList[i]->CLEvent;
467+
}
468+
cl_command_queue CLQueue = hQueue->CLQueue;
459469
CL_RETURN_ON_FAILURE(clEnqueueCommandBufferKHR(
460-
NumberOfQueues, cl_adapter::cast<cl_command_queue *>(&hQueue),
461-
hCommandBuffer->CLCommandBuffer, numEventsInWaitList,
462-
cl_adapter::cast<const cl_event *>(phEventWaitList),
463-
cl_adapter::cast<cl_event *>(phEvent)));
470+
NumberOfQueues, &CLQueue, hCommandBuffer->CLCommandBuffer,
471+
numEventsInWaitList, CLWaitEvents.data(), ifUrEvent(phEvent, Event)));
464472

473+
UR_RETURN_ON_FAILURE(createUREvent(Event, hQueue->Context, hQueue, phEvent));
465474
return UR_RESULT_SUCCESS;
466475
}
467476

@@ -501,11 +510,11 @@ void updateKernelArgs(std::vector<cl_mutable_dispatch_arg_khr> &CLArgs,
501510
for (uint32_t i = 0; i < NumMemobjArgs; i++) {
502511
const ur_exp_command_buffer_update_memobj_arg_desc_t &URMemObjArg =
503512
ArgMemobjList[i];
513+
cl_mem arg_value = URMemObjArg.hNewMemObjArg->CLMemory;
504514
cl_mutable_dispatch_arg_khr CLArg{
505515
URMemObjArg.argIndex, // arg_index
506516
sizeof(cl_mem), // arg_size
507-
cl_adapter::cast<const cl_mem *>(
508-
&URMemObjArg.hNewMemObjArg) // arg_value
517+
&arg_value // arg_value
509518
};
510519

511520
CLArgs.push_back(CLArg);
@@ -549,7 +558,7 @@ ur_result_t validateCommandDesc(
549558
// Verify that the device supports updating the aspects of the kernel that
550559
// the user is requesting.
551560
ur_device_handle_t URDevice = CommandBuffer->hDevice;
552-
cl_device_id CLDevice = cl_adapter::cast<cl_device_id>(URDevice);
561+
cl_device_id CLDevice = URDevice->CLDevice;
553562

554563
ur_device_command_buffer_update_capability_flags_t UpdateCapabilities = 0;
555564
CL_RETURN_ON_FAILURE(
@@ -601,7 +610,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
601610
validateCommandDesc(hCommandBuffer, pUpdateKernelLaunch[i]));
602611
}
603612

604-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
613+
cl_context CLContext = hCommandBuffer->hContext->CLContext;
614+
605615
cl_ext::clUpdateMutableCommandsKHR_fn clUpdateMutableCommandsKHR = nullptr;
606616
UR_RETURN_ON_FAILURE(
607617
cl_ext::getExtFuncFromContext<decltype(clUpdateMutableCommandsKHR)>(
@@ -657,8 +667,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
657667
updateNDRange(CLLocalWorkSize, CommandWorkDim, LocalWorkSizePtr);
658668
}
659669

660-
cl_mutable_command_khr CLCommand =
661-
cl_adapter::cast<cl_mutable_command_khr>(Command->CLMutableCommand);
670+
cl_mutable_command_khr CLCommand = Command->CLMutableCommand;
662671
Config = cl_mutable_dispatch_config_khr{
663672
CLCommand,
664673
static_cast<cl_uint>(CLArgs.size()), // num_args
@@ -736,7 +745,7 @@ ur_result_t UR_APICALL urCommandBufferAppendNativeCommandExp(
736745
uint32_t numSyncPointsInWaitList,
737746
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
738747
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
739-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
748+
cl_context CLContext = hCommandBuffer->hContext->CLContext;
740749
cl_ext::clCommandBarrierWithWaitListKHR_fn clCommandBarrierWithWaitListKHR =
741750
nullptr;
742751
UR_RETURN_ON_FAILURE(

unified-runtime/source/adapters/opencl/common.hpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,6 @@ extern thread_local char ErrorMessage[MaxMessageSize];
156156
// Utility function for setting a message and warning
157157
[[maybe_unused]] void setErrorMessage(const char *Message,
158158
ur_result_t ErrorCode);
159-
160-
template <class To, class From> To cast(From Value) {
161-
162-
if constexpr (std::is_pointer_v<From>) {
163-
static_assert(std::is_pointer_v<From> == std::is_pointer_v<To>,
164-
"Cast failed pointer check");
165-
return reinterpret_cast<To>(Value);
166-
} else {
167-
static_assert(sizeof(From) == sizeof(To), "Cast failed size check");
168-
static_assert(std::is_signed_v<From> == std::is_signed_v<To>,
169-
"Cast failed sign check");
170-
return static_cast<To>(Value);
171-
}
172-
}
173159
} // namespace cl_adapter
174160

175161
namespace cl_ext {

0 commit comments

Comments
 (0)