-
Notifications
You must be signed in to change notification settings - Fork 787
[SYCL] Cache devices so they actually compare == when they should. #2092
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9c5ef98
0aa8627
bddfd12
e3ed628
3a76864
a8fd94d
a15c0fa
4bbad36
b4de638
f9cf973
cd8f3df
d4941b0
d962e6b
67ff7a6
c1f736f
a42dbf4
500d6ba
db9f1b4
b970c2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,46 @@ __SYCL_INLINE_NAMESPACE(cl) { | |
namespace sycl { | ||
namespace detail { | ||
|
||
using PlatformImplPtr = std::shared_ptr<platform_impl>; | ||
|
||
PlatformImplPtr platform_impl::getHostPlatformImpl() { | ||
static PlatformImplPtr HostImpl = std::make_shared<platform_impl>(); | ||
|
||
return HostImpl; | ||
} | ||
|
||
PlatformImplPtr platform_impl::getOrMakePlatformImpl(RT::PiPlatform PiPlatform, | ||
const plugin &Plugin) { | ||
static std::vector<PlatformImplPtr> PlatformCache; | ||
static std::mutex PlatformMapMutex; | ||
|
||
PlatformImplPtr Result; | ||
{ | ||
const std::lock_guard<std::mutex> Guard(PlatformMapMutex); | ||
|
||
// If we've already seen this platform, return the impl | ||
for (const auto &PlatImpl : PlatformCache) { | ||
if (PlatImpl->getHandleRef() == PiPlatform) | ||
return PlatImpl; | ||
} | ||
|
||
// Otherwise make the impl | ||
Result = std::make_shared<platform_impl>(PiPlatform, Plugin); | ||
PlatformCache.emplace_back(Result); | ||
} | ||
|
||
return Result; | ||
} | ||
|
||
PlatformImplPtr platform_impl::getPlatformFromPiDevice(RT::PiDevice PiDevice, | ||
const plugin &Plugin) { | ||
RT::PiPlatform Plt = nullptr; // TODO catch an exception and put it to list | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems this TODO doesn't make any sense in this context. Or I'm missing something? |
||
// of asynchronous exceptions | ||
Plugin.call<PiApiKind::piDeviceGetInfo>(PiDevice, PI_DEVICE_INFO_PLATFORM, | ||
sizeof(Plt), &Plt, nullptr); | ||
return getOrMakePlatformImpl(Plt, Plugin); | ||
} | ||
|
||
static bool IsBannedPlatform(platform Platform) { | ||
// The NVIDIA OpenCL platform is currently not compatible with DPC++ | ||
// since it is only 1.2 but gets selected by default in many systems | ||
|
@@ -65,7 +105,7 @@ vector_class<platform> platform_impl::get_platforms() { | |
|
||
for (const auto &PiPlatform : PiPlatforms) { | ||
platform Platform = detail::createSyclObjFromImpl<platform>( | ||
std::make_shared<platform_impl>(PiPlatform, Plugins[i])); | ||
getOrMakePlatformImpl(PiPlatform, Plugins[i])); | ||
// Skip platforms which do not contain requested device types | ||
if (!Platform.get_devices(ForcedType).empty() && | ||
!IsBannedPlatform(Platform)) | ||
|
@@ -83,7 +123,6 @@ vector_class<platform> platform_impl::get_platforms() { | |
struct DevDescT { | ||
const char *devName = nullptr; | ||
int devNameSize = 0; | ||
|
||
const char *devDriverVer = nullptr; | ||
int devDriverVerSize = 0; | ||
|
||
|
@@ -228,12 +267,30 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices, | |
PiDevices.resize(InsertIDx); | ||
} | ||
|
||
std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl( | ||
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) { | ||
const std::lock_guard<std::mutex> Guard(MDeviceMapMutex); | ||
|
||
// If we've already seen this device, return the impl | ||
for (const std::shared_ptr<device_impl> &Device : MDeviceCache) { | ||
if (Device->getHandleRef() == PiDevice) | ||
return Device; | ||
} | ||
|
||
// Otherwise make the impl | ||
std::shared_ptr<device_impl> Result = | ||
std::make_shared<device_impl>(PiDevice, PlatformImpl); | ||
MDeviceCache.emplace_back(Result); | ||
|
||
return Result; | ||
} | ||
|
||
vector_class<device> | ||
platform_impl::get_devices(info::device_type DeviceType) const { | ||
vector_class<device> Res; | ||
if (is_host() && (DeviceType == info::device_type::host || | ||
DeviceType == info::device_type::all)) { | ||
Res.resize(1); // default device constructor creates host device | ||
Res.push_back(device()); | ||
} | ||
|
||
// If any DeviceType other than host was requested for host platform, | ||
|
@@ -260,12 +317,13 @@ platform_impl::get_devices(info::device_type DeviceType) const { | |
if (SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get()) | ||
filterAllowList(PiDevices, MPlatform, this->getPlugin()); | ||
|
||
std::transform(PiDevices.begin(), PiDevices.end(), std::back_inserter(Res), | ||
[this](const RT::PiDevice &PiDevice) -> device { | ||
return detail::createSyclObjFromImpl<device>( | ||
std::make_shared<device_impl>( | ||
PiDevice, std::make_shared<platform_impl>(*this))); | ||
}); | ||
PlatformImplPtr PlatformImpl = getOrMakePlatformImpl(MPlatform, *MPlugin); | ||
std::transform( | ||
PiDevices.begin(), PiDevices.end(), std::back_inserter(Res), | ||
[this, PlatformImpl](const RT::PiDevice &PiDevice) -> device { | ||
return detail::createSyclObjFromImpl<device>( | ||
PlatformImpl->getOrMakeDeviceImpl(PiDevice, PlatformImpl)); | ||
}); | ||
|
||
return Res; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t1.out | ||
// RUN: env SYCL_DEVICE_TYPE=HOST %t1.out | ||
// RUN: %CPU_RUN_PLACEHOLDER %t1.out | ||
// RUN: %GPU_RUN_PLACEHOLDER %t1.out | ||
|
||
//==------- device_equality.cpp - SYCL device equality test ----------------==// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include <CL/sycl.hpp> | ||
#include <cassert> | ||
#include <iostream> | ||
#include <utility> | ||
|
||
using namespace cl::sycl; | ||
|
||
int main() { | ||
std::cout << "Creating q1" << std::endl; | ||
queue q1; | ||
std::cout << "Creating q2" << std::endl; | ||
queue q2; | ||
|
||
// Default selector picks the same device every time. | ||
// That device should compare equal to itself. | ||
// Its platform should too. | ||
|
||
auto dev1 = q1.get_device(); | ||
auto plat1 = dev1.get_platform(); | ||
|
||
auto dev2 = q2.get_device(); | ||
auto plat2 = dev2.get_platform(); | ||
|
||
assert((dev1 == dev2) && "Device 1 == Device 2"); | ||
assert((plat1 == plat2) && "Platform 1 == Platform 2"); | ||
|
||
device h1; | ||
device h2; | ||
|
||
assert(h1.is_host() && "Device h1 is host"); | ||
assert(h2.is_host() && "Device h2 is host"); | ||
assert(h1 == h2 && "Host devices equal each other"); | ||
|
||
platform hp1 = h1.get_platform(); | ||
platform hp2 = h2.get_platform(); | ||
assert(hp1.is_host() && "Platform hp1 is host"); | ||
assert(hp2.is_host() && "Platform hp2 is host"); | ||
assert(hp1 == hp2 && "Host platforms equal each other"); | ||
|
||
return 0; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,8 +39,8 @@ TEST(PiMockTest, ConstructFromQueue) { | |
detail::getSyclObjImpl(Mock.getPlatform())->getPlugin().getPiPlugin(); | ||
EXPECT_EQ(&MockedQueuePiPlugin, &PiMockPlugin) | ||
<< "The mocked object and the PiMock instance must share the same plugin"; | ||
ASSERT_FALSE(&NormalPiPlugin == &MockedQueuePiPlugin) | ||
<< "Normal and mock platforms must not share the same plugin"; | ||
EXPECT_EQ(&NormalPiPlugin, &MockedQueuePiPlugin) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it breaks design of the plugin mock. |
||
<< "Normal and mock platforms must share the same plugin"; | ||
} | ||
|
||
TEST(PiMockTest, ConstructFromPlatform) { | ||
|
@@ -60,8 +60,8 @@ TEST(PiMockTest, ConstructFromPlatform) { | |
detail::getSyclObjImpl(Mock.getPlatform())->getPlugin().getPiPlugin(); | ||
EXPECT_EQ(&MockedPlatformPiPlugin, &PiMockPlugin) | ||
<< "The mocked object and the PiMock instance must share the same plugin"; | ||
ASSERT_FALSE(&NormalPiPlugin == &MockedPlatformPiPlugin) | ||
<< "Normal and mock platforms must not share the same plugin"; | ||
EXPECT_EQ(&NormalPiPlugin, &MockedPlatformPiPlugin) | ||
<< "Normal and mock platforms must share the same plugin"; | ||
} | ||
|
||
TEST(PiMockTest, RedefineAPI) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.