Skip to content

[SYCL][NFC] Pass adapter by reference instead of pointer #19105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

static const AdapterPtr &getAdapter(backend Backend) {
static const Adapter &getAdapter(backend Backend) {
switch (Backend) {
case backend::opencl:
return ur::getAdapter<backend::opencl>();
Expand Down Expand Up @@ -71,36 +71,36 @@ backend convertUrBackend(ur_backend_t UrBackend) {
}

platform make_platform(ur_native_handle_t NativeHandle, backend Backend) {
const auto &Adapter = getAdapter(Backend);
const auto &adapter = getAdapter(Backend);

// Create UR platform first.
ur_platform_handle_t UrPlatform = nullptr;
Adapter->call<UrApiKind::urPlatformCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrPlatform);
adapter.call<UrApiKind::urPlatformCreateWithNativeHandle>(
NativeHandle, adapter.getUrAdapter(), nullptr, &UrPlatform);

return detail::createSyclObjFromImpl<platform>(
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter));
platform_impl::getOrMakePlatformImpl(UrPlatform, adapter));
}

__SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle,
backend Backend) {
const auto &Adapter = getAdapter(Backend);
const auto &adapter = getAdapter(Backend);

ur_device_handle_t UrDevice = nullptr;
Adapter->call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);
adapter.call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, adapter.getUrAdapter(), nullptr, &UrDevice);

// Construct the SYCL device from UR device.
return detail::createSyclObjFromImpl<device>(
platform_impl::getPlatformFromUrDevice(UrDevice, Adapter)
platform_impl::getPlatformFromUrDevice(UrDevice, adapter)
.getOrMakeDeviceImpl(UrDevice));
}

__SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
const async_handler &Handler,
backend Backend, bool KeepOwnership,
const std::vector<device> &DeviceList) {
const auto &Adapter = getAdapter(Backend);
const auto &adapter = getAdapter(Backend);

ur_context_handle_t UrContext = nullptr;
ur_context_native_properties_t Properties{};
Expand All @@ -110,12 +110,12 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
for (const auto &Dev : DeviceList) {
DeviceHandles.push_back(detail::getSyclObjImpl(Dev)->getHandleRef());
}
Adapter->call<UrApiKind::urContextCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), DeviceHandles.size(),
adapter.call<UrApiKind::urContextCreateWithNativeHandle>(
NativeHandle, adapter.getUrAdapter(), DeviceHandles.size(),
DeviceHandles.data(), &Properties, &UrContext);
// Construct the SYCL context from UR context.
return detail::createSyclObjFromImpl<context>(context_impl::create(
UrContext, Handler, Adapter, DeviceList, !KeepOwnership));
UrContext, Handler, adapter, DeviceList, !KeepOwnership));
}

__SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
Expand All @@ -125,7 +125,8 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
const async_handler &Handler, backend Backend) {
ur_device_handle_t UrDevice =
Device ? getSyclObjImpl(*Device)->getHandleRef() : nullptr;
const auto &Adapter = getAdapter(Backend);

const auto &adapter = getAdapter(Backend);
context_impl &ContextImpl = *getSyclObjImpl(Context);

if (PropList.has_property<ext::intel::property::queue::compute_index>()) {
Expand Down Expand Up @@ -155,7 +156,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
// Create UR queue first.
ur_queue_handle_t UrQueue = nullptr;

Adapter->call<UrApiKind::urQueueCreateWithNativeHandle>(
adapter.call<UrApiKind::urQueueCreateWithNativeHandle>(
NativeHandle, ContextImpl.getHandleRef(), UrDevice, &NativeProperties,
&UrQueue);
// Construct the SYCL queue from UR queue.
Expand All @@ -171,15 +172,15 @@ __SYCL_EXPORT event make_event(ur_native_handle_t NativeHandle,
__SYCL_EXPORT event make_event(ur_native_handle_t NativeHandle,
const context &Context, bool KeepOwnership,
backend Backend) {
const auto &Adapter = getAdapter(Backend);
const auto &adapter = getAdapter(Backend);
const auto &ContextImpl = getSyclObjImpl(Context);

ur_event_handle_t UrEvent = nullptr;
ur_event_native_properties_t Properties{};
Properties.stype = UR_STRUCTURE_TYPE_EVENT_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;

Adapter->call<UrApiKind::urEventCreateWithNativeHandle>(
adapter.call<UrApiKind::urEventCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), &Properties, &UrEvent);
event Event = detail::createSyclObjFromImpl<event>(
event_impl::create_from_handle(UrEvent, Context));
Expand All @@ -193,15 +194,15 @@ std::shared_ptr<detail::kernel_bundle_impl>
make_kernel_bundle(ur_native_handle_t NativeHandle,
const context &TargetContext, bool KeepOwnership,
bundle_state State, backend Backend) {
const auto &Adapter = getAdapter(Backend);
const auto &adapter = getAdapter(Backend);
const auto &ContextImpl = getSyclObjImpl(TargetContext);

ur_program_handle_t UrProgram = nullptr;
ur_program_native_properties_t Properties{};
Properties.stype = UR_STRUCTURE_TYPE_PROGRAM_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;

Adapter->call<UrApiKind::urProgramCreateWithNativeHandle>(
adapter.call<UrApiKind::urProgramCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), &Properties, &UrProgram);
if (UrProgram == nullptr)
throw sycl::exception(
Expand All @@ -214,39 +215,39 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
std::vector<ur_device_handle_t> ProgramDevices;
uint32_t NumDevices = 0;

Adapter->call<UrApiKind::urProgramGetInfo>(
adapter.call<UrApiKind::urProgramGetInfo>(
UrProgram, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(NumDevices), &NumDevices,
nullptr);
ProgramDevices.resize(NumDevices);
Adapter->call<UrApiKind::urProgramGetInfo>(
adapter.call<UrApiKind::urProgramGetInfo>(
UrProgram, UR_PROGRAM_INFO_DEVICES,
sizeof(ur_device_handle_t) * NumDevices, ProgramDevices.data(), nullptr);

for (auto &Dev : ProgramDevices) {
ur_program_binary_type_t BinaryType;
Adapter->call<UrApiKind::urProgramGetBuildInfo>(
adapter.call<UrApiKind::urProgramGetBuildInfo>(
UrProgram, Dev, UR_PROGRAM_BUILD_INFO_BINARY_TYPE,
sizeof(ur_program_binary_type_t), &BinaryType, nullptr);
switch (BinaryType) {
case (UR_PROGRAM_BINARY_TYPE_NONE):
if (State == bundle_state::object) {
auto Res = Adapter->call_nocheck<UrApiKind::urProgramCompileExp>(
auto Res = adapter.call_nocheck<UrApiKind::urProgramCompileExp>(
UrProgram, 1, &Dev, nullptr);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Adapter->call_nocheck<UrApiKind::urProgramCompile>(
Res = adapter.call_nocheck<UrApiKind::urProgramCompile>(
ContextImpl->getHandleRef(), UrProgram, nullptr);
}
Adapter->checkUrResult<errc::build>(Res);
adapter.checkUrResult<errc::build>(Res);
}

else if (State == bundle_state::executable) {
auto Res = Adapter->call_nocheck<UrApiKind::urProgramBuildExp>(
auto Res = adapter.call_nocheck<UrApiKind::urProgramBuildExp>(
UrProgram, 1, &Dev, nullptr);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Adapter->call_nocheck<UrApiKind::urProgramBuild>(
Res = adapter.call_nocheck<UrApiKind::urProgramBuild>(
ContextImpl->getHandleRef(), UrProgram, nullptr);
}
Adapter->checkUrResult<errc::build>(Res);
adapter.checkUrResult<errc::build>(Res);
}

break;
Expand All @@ -259,15 +260,15 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
detail::codeToString(UR_RESULT_ERROR_INVALID_VALUE));
if (State == bundle_state::executable) {
ur_program_handle_t UrLinkedProgram = nullptr;
auto Res = Adapter->call_nocheck<UrApiKind::urProgramLinkExp>(
auto Res = adapter.call_nocheck<UrApiKind::urProgramLinkExp>(
ContextImpl->getHandleRef(), 1, &Dev, 1, &UrProgram, nullptr,
&UrLinkedProgram);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Adapter->call_nocheck<UrApiKind::urProgramLink>(
Res = adapter.call_nocheck<UrApiKind::urProgramLink>(
ContextImpl->getHandleRef(), 1, &UrProgram, nullptr,
&UrLinkedProgram);
}
Adapter->checkUrResult<errc::build>(Res);
adapter.checkUrResult<errc::build>(Res);
if (UrLinkedProgram != nullptr) {
UrProgram = UrLinkedProgram;
}
Expand All @@ -289,9 +290,9 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
Devices.reserve(ProgramDevices.size());
std::transform(
ProgramDevices.begin(), ProgramDevices.end(), std::back_inserter(Devices),
[&Adapter](const auto &Dev) {
[&adapter](const auto &Dev) {
return createSyclObjFromImpl<device>(
detail::platform_impl::getPlatformFromUrDevice(Dev, Adapter)
detail::platform_impl::getPlatformFromUrDevice(Dev, adapter)
.getOrMakeDeviceImpl(Dev));
});

Expand Down Expand Up @@ -321,7 +322,7 @@ kernel make_kernel(const context &TargetContext,
const kernel_bundle<bundle_state::executable> &KernelBundle,
ur_native_handle_t NativeHandle, bool KeepOwnership,
backend Backend) {
const auto &Adapter = getAdapter(Backend);
const auto &adapter = getAdapter(Backend);
const auto &ContextImpl = getSyclObjImpl(TargetContext);
const auto &KernelBundleImpl = getSyclObjImpl(KernelBundle);

Expand Down Expand Up @@ -351,7 +352,7 @@ kernel make_kernel(const context &TargetContext,
ur_kernel_native_properties_t Properties{};
Properties.stype = UR_STRUCTURE_TYPE_KERNEL_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;
Adapter->call<UrApiKind::urKernelCreateWithNativeHandle>(
adapter.call<UrApiKind::urKernelCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), UrProgram, &Properties,
&UrKernel);

Expand Down
6 changes: 3 additions & 3 deletions sycl/source/backend/level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ using namespace sycl::detail;

__SYCL_EXPORT device make_device(const platform &Platform,
ur_native_handle_t NativeHandle) {
const auto &Adapter = ur::getAdapter<backend::ext_oneapi_level_zero>();
const auto &adapter = ur::getAdapter<backend::ext_oneapi_level_zero>();
// Create UR device first.
ur_device_handle_t UrDevice;
Adapter->call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);
adapter.call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, adapter.getUrAdapter(), nullptr, &UrDevice);

return detail::createSyclObjFromImpl<device>(
getSyclObjImpl(Platform)->getOrMakeDeviceImpl(UrDevice));
Expand Down
8 changes: 4 additions & 4 deletions sycl/source/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ context::context(const std::vector<device> &DeviceList,
impl = detail::context_impl::create(DeviceList, AsyncHandler, PropList);
}
context::context(cl_context ClContext, async_handler AsyncHandler) {
const auto &Adapter = sycl::detail::ur::getAdapter<backend::opencl>();
const auto &adapter = sycl::detail::ur::getAdapter<backend::opencl>();

ur_context_handle_t hContext = nullptr;
ur_native_handle_t nativeHandle =
reinterpret_cast<ur_native_handle_t>(ClContext);
Adapter->call<detail::UrApiKind::urContextCreateWithNativeHandle>(
nativeHandle, Adapter->getUrAdapter(), 0, nullptr, nullptr, &hContext);
adapter.call<detail::UrApiKind::urContextCreateWithNativeHandle>(
nativeHandle, adapter.getUrAdapter(), 0, nullptr, nullptr, &hContext);

impl = detail::context_impl::create(hContext, AsyncHandler, Adapter);
impl = detail::context_impl::create(hContext, AsyncHandler, adapter);
}

template <typename Param>
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ class Adapter {
return UrPlatforms;
}

ur_adapter_handle_t getUrAdapter() { return MAdapter; }
ur_adapter_handle_t getUrAdapter() const { return MAdapter; }

/// Calls the UR Api, traces the call, and returns the result.
///
/// Usage:
/// \code{cpp}
/// ur_result_t Err = Adapter->call<UrApiKind::urEntryPoint>(Args);
/// Adapter->checkUrResult(Err); // Checks Result and throws a runtime_error
/// ur_result_t Err = adapter.call<UrApiKind::urEntryPoint>(Args);
/// adapter.checkUrResult(Err); // Checks Result and throws a runtime_error
/// // exception.
/// \endcode
///
Expand Down
7 changes: 3 additions & 4 deletions sycl/source/detail/allowlist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,7 @@ bool deviceIsAllowed(const DeviceDescT &DeviceDesc,
}

void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
ur_platform_handle_t UrPlatform,
const AdapterPtr &Adapter) {
ur_platform_handle_t UrPlatform, const Adapter &adapter) {

AllowListParsedT AllowListParsed =
parseAllowList(SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get());
Expand All @@ -375,7 +374,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
// Get platform's backend and put it to DeviceDesc
DeviceDescT DeviceDesc;
platform_impl &PlatformImpl =
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
platform_impl::getOrMakePlatformImpl(UrPlatform, adapter);
backend Backend = PlatformImpl.getBackend();

for (const auto &SyclBe : getSyclBeMap()) {
Expand All @@ -396,7 +395,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
device_impl &DeviceImpl = PlatformImpl.getOrMakeDeviceImpl(Device);
// get DeviceType value and put it to DeviceDesc
ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL;
Adapter->call<UrApiKind::urDeviceGetInfo>(
adapter.call<UrApiKind::urDeviceGetInfo>(
Device, UR_DEVICE_INFO_TYPE, sizeof(UrDevType), &UrDevType, nullptr);
// TODO need mechanism to do these casts, there's a bunch of this sort of
// thing
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/allowlist.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ bool deviceIsAllowed(const DeviceDescT &DeviceDesc,
const AllowListParsedT &AllowListParsed);

void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
ur_platform_handle_t UrPlatform, const AdapterPtr &Adapter);
ur_platform_handle_t UrPlatform, const Adapter &AAdapter);

} // namespace detail
} // namespace _V1
Expand Down
12 changes: 6 additions & 6 deletions sycl/source/detail/async_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
sycl::make_error_code(sycl::errc::feature_not_supported),
"Only device backed asynchronous allocations are supported!");

auto &Adapter = h.getContextImpl().getAdapter();
auto &adapter = h.getContextImpl().getAdapter();

// Get CG event dependencies for this allocation.
const auto &DepEvents = h.impl->CGData.MEvents;
Expand All @@ -84,8 +84,8 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
alloc = Graph->getMemPool().malloc(size, kind, DepNodes);
} else {
ur_queue_handle_t Q = h.impl->get_queue().getHandleRef();
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
adapter.call<sycl::errc::runtime,
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
Q, (ur_usm_pool_handle_t)0, size, nullptr, UREvents.size(),
UREvents.data(), &alloc, &Event);
}
Expand Down Expand Up @@ -118,7 +118,7 @@ __SYCL_EXPORT void *async_malloc(const sycl::queue &q, sycl::usm::alloc kind,
__SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
const memory_pool &pool) {

auto &Adapter = h.getContextImpl().getAdapter();
auto &adapter = h.getContextImpl().getAdapter();
auto &memPoolImpl = sycl::detail::getSyclObjImpl(pool);

// Get CG event dependencies for this allocation.
Expand All @@ -138,8 +138,8 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
sycl::detail::getSyclObjImpl(pool));
} else {
ur_queue_handle_t Q = h.impl->get_queue().getHandleRef();
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
adapter.call<sycl::errc::runtime,
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
Q, memPoolImpl.get()->get_handle(), size, nullptr, UREvents.size(),
UREvents.data(), &alloc, &Event);
}
Expand Down
Loading
Loading