Skip to content

[SYCL] Cache devices so they actually compare == when they should. #2092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <CL/sycl/stl.hpp>
#include <detail/context_impl.hpp>
#include <detail/context_info.hpp>
#include <detail/platform_impl.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
Expand Down Expand Up @@ -80,12 +81,15 @@ context_impl::context_impl(RT::PiContext PiContext, async_handler AsyncHandler,
sizeof(RT::PiDevice) * DevicesNum,
&DeviceIds[0], nullptr);

for (auto Dev : DeviceIds) {
MDevices.emplace_back(createSyclObjFromImpl<device>(
std::make_shared<device_impl>(Dev, Plugin)));
if (!DeviceIds.empty()) {
std::shared_ptr<detail::platform_impl> Platform =
platform_impl::getPlatformFromPiDevice(DeviceIds[0], Plugin);
for (RT::PiDevice Dev : DeviceIds) {
MDevices.emplace_back(createSyclObjFromImpl<device>(
Platform->getOrMakeDeviceImpl(Dev, Platform)));
}
MPlatform = Platform;
}
// TODO What if m_Devices is empty? m_Devices[0].get_platform()
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
// TODO catch an exception and put it to list of asynchronous exceptions
// getPlugin() will be the same as the Plugin passed. This should be taken
// care of when creating device object.
Expand Down
19 changes: 11 additions & 8 deletions sycl/source/detail/device_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <CL/sycl/device.hpp>
#include <detail/device_impl.hpp>
#include <detail/platform_impl.hpp>

#include <algorithm>

Expand All @@ -16,8 +17,7 @@ namespace sycl {
namespace detail {

device_impl::device_impl()
: MIsHostDevice(true),
MPlatform(std::make_shared<platform_impl>(platform_impl())) {}
: MIsHostDevice(true), MPlatform(platform_impl::getHostPlatformImpl()) {}

device_impl::device_impl(pi_native_handle InteropDeviceHandle,
const plugin &Plugin)
Expand Down Expand Up @@ -67,11 +67,7 @@ device_impl::device_impl(pi_native_handle InteropDeviceHandle,

// set MPlatform
if (!Platform) {
RT::PiPlatform plt = nullptr; // TODO catch an exception and put it to list
// of asynchronous exceptions
Plugin.call<PiApiKind::piDeviceGetInfo>(MDevice, PI_DEVICE_INFO_PLATFORM,
sizeof(plt), &plt, nullptr);
Platform = std::make_shared<platform_impl>(plt, Plugin);
Platform = platform_impl::getPlatformFromPiDevice(MDevice, Plugin);
}
MPlatform = Platform;
}
Expand Down Expand Up @@ -146,7 +142,7 @@ device_impl::create_sub_devices(const cl_device_partition_property *Properties,
std::for_each(SubDevices.begin(), SubDevices.end(),
[&res, this](const RT::PiDevice &a_pi_device) {
device sycl_device = detail::createSyclObjFromImpl<device>(
std::make_shared<device_impl>(a_pi_device, MPlatform));
MPlatform->getOrMakeDeviceImpl(a_pi_device, MPlatform));
res.push_back(sycl_device);
});
return res;
Expand Down Expand Up @@ -251,6 +247,13 @@ bool device_impl::has(aspect Aspect) const {
}
}

/// Hands out the shared host device implementation.
///
/// Every call returns a pointer to the same default-constructed device_impl,
/// so callers that obtain it through this function share one object.
///
/// \return the host device_impl singleton
std::shared_ptr<device_impl> device_impl::getHostDeviceImpl() {
  // Function-local static: built on the first call, reused on every later one.
  static const std::shared_ptr<device_impl> HostDeviceSingleton{
      std::make_shared<device_impl>()};
  return HostDeviceSingleton;
}

} // namespace detail
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
5 changes: 5 additions & 0 deletions sycl/source/detail/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ class device_impl {
/// \return true if the SYCL device has the given feature.
bool has(aspect Aspect) const;

/// Gets the single instance of the Host Device
///
/// \return the host device_impl singleton
static std::shared_ptr<device_impl> getHostDeviceImpl();

private:
explicit device_impl(pi_native_handle InteropDevice, RT::PiDevice Device,
PlatformImplPtr Platform, const plugin &Plugin);
Expand Down
8 changes: 6 additions & 2 deletions sycl/source/detail/device_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <CL/sycl/info/info_desc.hpp>
#include <CL/sycl/platform.hpp>
#include <detail/device_impl.hpp>
#include <detail/platform_impl.hpp>
#include <detail/platform_util.hpp>
#include <detail/plugin.hpp>

Expand Down Expand Up @@ -117,7 +118,7 @@ template <info::device param> struct get_device_info<platform, param> {
// Use the Plugin from the device_impl class after plugin details
// are added to the class.
return createSyclObjFromImpl<platform>(
std::make_shared<platform_impl>(result, Plugin));
platform_impl::getOrMakePlatformImpl(result, Plugin));
}
};

Expand Down Expand Up @@ -406,8 +407,11 @@ template <> struct get_device_info<device, info::device::parent_device> {
"No parent for device because it is not a subdevice",
PI_INVALID_DEVICE);

// Get the platform of this device
std::shared_ptr<detail::platform_impl> Platform =
platform_impl::getPlatformFromPiDevice(dev, Plugin);
return createSyclObjFromImpl<device>(
std::make_shared<device_impl>(result, Plugin));
Platform->getOrMakeDeviceImpl(result, Platform));
}
};

Expand Down
76 changes: 67 additions & 9 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,46 @@ __SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace detail {

using PlatformImplPtr = std::shared_ptr<platform_impl>;

/// Hands out the shared host platform implementation.
///
/// Every call returns a pointer to the same default-constructed
/// platform_impl.
///
/// \return the host platform_impl singleton
PlatformImplPtr platform_impl::getHostPlatformImpl() {
  // Function-local static: built on the first call, reused on every later one.
  static const PlatformImplPtr HostPlatformSingleton{
      std::make_shared<platform_impl>()};
  return HostPlatformSingleton;
}

PlatformImplPtr platform_impl::getOrMakePlatformImpl(RT::PiPlatform PiPlatform,
const plugin &Plugin) {
static std::vector<PlatformImplPtr> PlatformCache;
static std::mutex PlatformMapMutex;

PlatformImplPtr Result;
{
const std::lock_guard<std::mutex> Guard(PlatformMapMutex);

// If we've already seen this platform, return the impl
for (const auto &PlatImpl : PlatformCache) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (const auto &PlatImpl : PlatformCache) {
for (const PlatformImplPtr &PlatImpl : PlatformCache) {

if (PlatImpl->getHandleRef() == PiPlatform)
return PlatImpl;
}

// Otherwise make the impl
Result = std::make_shared<platform_impl>(PiPlatform, Plugin);
PlatformCache.emplace_back(Result);
}

return Result;
}

PlatformImplPtr platform_impl::getPlatformFromPiDevice(RT::PiDevice PiDevice,
const plugin &Plugin) {
RT::PiPlatform Plt = nullptr; // TODO catch an exception and put it to list
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this TODO doesn't make any sense in this context. Or I'm missing something?

// of asynchronous exceptions
Plugin.call<PiApiKind::piDeviceGetInfo>(PiDevice, PI_DEVICE_INFO_PLATFORM,
sizeof(Plt), &Plt, nullptr);
return getOrMakePlatformImpl(Plt, Plugin);
}

static bool IsBannedPlatform(platform Platform) {
// The NVIDIA OpenCL platform is currently not compatible with DPC++
// since it is only 1.2 but gets selected by default in many systems
Expand Down Expand Up @@ -65,7 +105,7 @@ vector_class<platform> platform_impl::get_platforms() {

for (const auto &PiPlatform : PiPlatforms) {
platform Platform = detail::createSyclObjFromImpl<platform>(
std::make_shared<platform_impl>(PiPlatform, Plugins[i]));
getOrMakePlatformImpl(PiPlatform, Plugins[i]));
// Skip platforms which do not contain requested device types
if (!Platform.get_devices(ForcedType).empty() &&
!IsBannedPlatform(Platform))
Expand All @@ -83,7 +123,6 @@ vector_class<platform> platform_impl::get_platforms() {
struct DevDescT {
const char *devName = nullptr;
int devNameSize = 0;

const char *devDriverVer = nullptr;
int devDriverVerSize = 0;

Expand Down Expand Up @@ -228,12 +267,30 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
PiDevices.resize(InsertIDx);
}

/// Looks up the device_impl cached for \p PiDevice on this platform, creating
/// and caching a new one when no cached entry has a matching handle.
///
/// \param PiDevice is the PI device handle whose impl is requested
/// \param PlatformImpl is the platform_impl passed to the device_impl
///        constructor when a new entry has to be created
/// \return the cached (or freshly cached) device_impl for \p PiDevice
std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
    RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
  // The lock covers both the lookup and the insertion below.
  const std::lock_guard<std::mutex> CacheLock(MDeviceMapMutex);

  // Search the cache for an entry wrapping the same PI handle.
  const auto Cached =
      std::find_if(MDeviceCache.begin(), MDeviceCache.end(),
                   [PiDevice](const std::shared_ptr<device_impl> &Known) {
                     return Known->getHandleRef() == PiDevice;
                   });
  if (Cached != MDeviceCache.end())
    return *Cached;

  // Cache miss: build the impl, remember it, and hand back the stored pointer.
  MDeviceCache.emplace_back(
      std::make_shared<device_impl>(PiDevice, PlatformImpl));
  return MDeviceCache.back();
}

vector_class<device>
platform_impl::get_devices(info::device_type DeviceType) const {
vector_class<device> Res;
if (is_host() && (DeviceType == info::device_type::host ||
DeviceType == info::device_type::all)) {
Res.resize(1); // default device constructor creates host device
Res.push_back(device());
}

// If any DeviceType other than host was requested for host platform,
Expand All @@ -260,12 +317,13 @@ platform_impl::get_devices(info::device_type DeviceType) const {
if (SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get())
filterAllowList(PiDevices, MPlatform, this->getPlugin());

std::transform(PiDevices.begin(), PiDevices.end(), std::back_inserter(Res),
[this](const RT::PiDevice &PiDevice) -> device {
return detail::createSyclObjFromImpl<device>(
std::make_shared<device_impl>(
PiDevice, std::make_shared<platform_impl>(*this)));
});
PlatformImplPtr PlatformImpl = getOrMakePlatformImpl(MPlatform, *MPlugin);
std::transform(
PiDevices.begin(), PiDevices.end(), std::back_inserter(Res),
[this, PlatformImpl](const RT::PiDevice &PiDevice) -> device {
return detail::createSyclObjFromImpl<device>(
PlatformImpl->getOrMakeDeviceImpl(PiDevice, PlatformImpl));
});

return Res;
}
Expand Down
46 changes: 46 additions & 0 deletions sycl/source/detail/platform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,57 @@ class platform_impl {
/// given feature.
bool has(aspect Aspect) const;

/// Queries the device_impl cache to either return a shared_ptr
/// for the device_impl corresponding to the PiDevice or add
/// a new entry to the cache
///
/// \param PiDevice is the PiDevice whose impl is requested
///
/// \param PlatformImpl is the Platform for that Device
///
/// \return a shared_ptr<device_impl> corresponding to the device
std::shared_ptr<device_impl>
getOrMakeDeviceImpl(RT::PiDevice PiDevice,
const std::shared_ptr<platform_impl> &PlatformImpl);

/// Static functions that help maintain platform uniqueness and
/// equality of comparison

/// Returns the host platform impl
///
/// \return the host platform impl
static std::shared_ptr<platform_impl> getHostPlatformImpl();

/// Queries the cache to see if the specified PiPlatform has been seen
/// before. If so, return the cached platform_impl, otherwise create a new
/// one and cache it.
///
/// \param PiPlatform is the PI Platform handle representing the platform
/// \param Plugin is the PI plugin providing the backend for the platform
/// \return the platform_impl representing the PI platform
static std::shared_ptr<platform_impl>
getOrMakePlatformImpl(RT::PiPlatform PiPlatform, const plugin &Plugin);

/// Queries the cache for the specified platform based on an input device.
/// If found, returns the cached platform_impl, otherwise creates a new
/// one and caches it.
///
/// \param PiDevice is the PI device handle for the device whose platform is
/// desired
/// \param Plugin is the PI plugin providing the backend for the device and
/// platform
/// \return the platform_impl that contains the input device
static std::shared_ptr<platform_impl>
getPlatformFromPiDevice(RT::PiDevice PiDevice, const plugin &Plugin);

private:
bool MHostPlatform = false;
RT::PiPlatform MPlatform = 0;
std::shared_ptr<plugin> MPlugin;
std::vector<std::shared_ptr<device_impl>> MDeviceCache;
std::mutex MDeviceMapMutex;
};

} // namespace detail
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
2 changes: 1 addition & 1 deletion sycl/source/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void force_type(info::device_type &t, const info::device_type &ft) {
}
} // namespace detail

device::device() : impl(std::make_shared<detail::device_impl>()) {}
device::device() : impl(detail::device_impl::getHostDeviceImpl()) {}

device::device(cl_device_id deviceId)
: impl(std::make_shared<detail::device_impl>(
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {

platform::platform() : impl(std::make_shared<detail::platform_impl>()) {}
platform::platform() : impl(detail::platform_impl::getHostPlatformImpl()) {}

platform::platform(cl_platform_id PlatformId)
: impl(std::make_shared<detail::platform_impl>(
Expand Down
54 changes: 54 additions & 0 deletions sycl/test/basic_tests/device_equality.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t1.out
// RUN: env SYCL_DEVICE_TYPE=HOST %t1.out
// RUN: %CPU_RUN_PLACEHOLDER %t1.out
// RUN: %GPU_RUN_PLACEHOLDER %t1.out

//==------- device_equality.cpp - SYCL device equality test ----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <CL/sycl.hpp>
#include <cassert>
#include <iostream>
#include <utility>

using namespace cl::sycl;

int main() {
std::cout << "Creating q1" << std::endl;
queue q1;
std::cout << "Creating q2" << std::endl;
queue q2;

// Default selector picks the same device every time.
// That device should compare equal to itself.
// Its platform should too.

auto dev1 = q1.get_device();
auto plat1 = dev1.get_platform();

auto dev2 = q2.get_device();
auto plat2 = dev2.get_platform();

assert((dev1 == dev2) && "Device 1 == Device 2");
assert((plat1 == plat2) && "Platform 1 == Platform 2");

device h1;
device h2;

assert(h1.is_host() && "Device h1 is host");
assert(h2.is_host() && "Device h2 is host");
assert(h1 == h2 && "Host devices equal each other");

platform hp1 = h1.get_platform();
platform hp2 = h2.get_platform();
assert(hp1.is_host() && "Platform hp1 is host");
assert(hp2.is_host() && "Platform hp2 is host");
assert(hp1 == hp2 && "Host platforms equal each other");

return 0;
}
8 changes: 4 additions & 4 deletions sycl/unittests/pi/PiMock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ TEST(PiMockTest, ConstructFromQueue) {
detail::getSyclObjImpl(Mock.getPlatform())->getPlugin().getPiPlugin();
EXPECT_EQ(&MockedQueuePiPlugin, &PiMockPlugin)
<< "The mocked object and the PiMock instance must share the same plugin";
ASSERT_FALSE(&NormalPiPlugin == &MockedQueuePiPlugin)
<< "Normal and mock platforms must not share the same plugin";
EXPECT_EQ(&NormalPiPlugin, &MockedQueuePiPlugin)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it breaks design of the plugin mock.
@AGindinson Could you please comment?

<< "Normal and mock platforms must share the same plugin";
}

TEST(PiMockTest, ConstructFromPlatform) {
Expand All @@ -60,8 +60,8 @@ TEST(PiMockTest, ConstructFromPlatform) {
detail::getSyclObjImpl(Mock.getPlatform())->getPlugin().getPiPlugin();
EXPECT_EQ(&MockedPlatformPiPlugin, &PiMockPlugin)
<< "The mocked object and the PiMock instance must share the same plugin";
ASSERT_FALSE(&NormalPiPlugin == &MockedPlatformPiPlugin)
<< "Normal and mock platforms must not share the same plugin";
EXPECT_EQ(&NormalPiPlugin, &MockedPlatformPiPlugin)
<< "Normal and mock platforms must share the same plugin";
}

TEST(PiMockTest, RedefineAPI) {
Expand Down