Skip to content

Commit d392b51

Browse files
authored
[SYCL] Cache devices and platforms (#2092)
Cache devices and platforms so they actually compare == when they should. Signed-off-by: James Brodman <[email protected]>
1 parent 6e9bf3b commit d392b51

File tree

10 files changed

+204
-30
lines changed

10 files changed

+204
-30
lines changed

sycl/source/detail/context_impl.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <CL/sycl/stl.hpp>
1818
#include <detail/context_impl.hpp>
1919
#include <detail/context_info.hpp>
20+
#include <detail/platform_impl.hpp>
2021

2122
__SYCL_INLINE_NAMESPACE(cl) {
2223
namespace sycl {
@@ -80,12 +81,15 @@ context_impl::context_impl(RT::PiContext PiContext, async_handler AsyncHandler,
8081
sizeof(RT::PiDevice) * DevicesNum,
8182
&DeviceIds[0], nullptr);
8283

83-
for (auto Dev : DeviceIds) {
84-
MDevices.emplace_back(createSyclObjFromImpl<device>(
85-
std::make_shared<device_impl>(Dev, Plugin)));
84+
if (!DeviceIds.empty()) {
85+
std::shared_ptr<detail::platform_impl> Platform =
86+
platform_impl::getPlatformFromPiDevice(DeviceIds[0], Plugin);
87+
for (RT::PiDevice Dev : DeviceIds) {
88+
MDevices.emplace_back(createSyclObjFromImpl<device>(
89+
Platform->getOrMakeDeviceImpl(Dev, Platform)));
90+
}
91+
MPlatform = Platform;
8692
}
87-
// TODO What if m_Devices if empty? m_Devices[0].get_platform()
88-
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
8993
// TODO catch an exception and put it to list of asynchronous exceptions
9094
// getPlugin() will be the same as the Plugin passed. This should be taken
9195
// care of when creating device object.

sycl/source/detail/device_impl.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <CL/sycl/device.hpp>
1010
#include <detail/device_impl.hpp>
11+
#include <detail/platform_impl.hpp>
1112

1213
#include <algorithm>
1314

@@ -16,8 +17,7 @@ namespace sycl {
1617
namespace detail {
1718

1819
device_impl::device_impl()
19-
: MIsHostDevice(true),
20-
MPlatform(std::make_shared<platform_impl>(platform_impl())) {}
20+
: MIsHostDevice(true), MPlatform(platform_impl::getHostPlatformImpl()) {}
2121

2222
device_impl::device_impl(pi_native_handle InteropDeviceHandle,
2323
const plugin &Plugin)
@@ -67,11 +67,7 @@ device_impl::device_impl(pi_native_handle InteropDeviceHandle,
6767

6868
// set MPlatform
6969
if (!Platform) {
70-
RT::PiPlatform plt = nullptr; // TODO catch an exception and put it to list
71-
// of asynchronous exceptions
72-
Plugin.call<PiApiKind::piDeviceGetInfo>(MDevice, PI_DEVICE_INFO_PLATFORM,
73-
sizeof(plt), &plt, nullptr);
74-
Platform = std::make_shared<platform_impl>(plt, Plugin);
70+
Platform = platform_impl::getPlatformFromPiDevice(MDevice, Plugin);
7571
}
7672
MPlatform = Platform;
7773
}
@@ -146,7 +142,7 @@ device_impl::create_sub_devices(const cl_device_partition_property *Properties,
146142
std::for_each(SubDevices.begin(), SubDevices.end(),
147143
[&res, this](const RT::PiDevice &a_pi_device) {
148144
device sycl_device = detail::createSyclObjFromImpl<device>(
149-
std::make_shared<device_impl>(a_pi_device, MPlatform));
145+
MPlatform->getOrMakeDeviceImpl(a_pi_device, MPlatform));
150146
res.push_back(sycl_device);
151147
});
152148
return res;
@@ -251,6 +247,13 @@ bool device_impl::has(aspect Aspect) const {
251247
}
252248
}
253249

250+
/// Returns the host device_impl that is shared by every caller.
///
/// \return a shared_ptr to the one host device_impl created by this function
std::shared_ptr<device_impl> device_impl::getHostDeviceImpl() {
251+
// Function-local static: the impl is built with the default device_impl
// constructor on the first call and the same shared_ptr is handed out on
// every later call, so all callers point at the same object.
static std::shared_ptr<device_impl> HostImpl =
252+
std::make_shared<device_impl>();
253+
254+
return HostImpl;
255+
}
256+
254257
} // namespace detail
255258
} // namespace sycl
256259
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/source/detail/device_impl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,11 @@ class device_impl {
212212
/// \return true if the SYCL device has the given feature.
213213
bool has(aspect Aspect) const;
214214

215+
/// Gets the single instance of the Host Device
216+
///
217+
/// \return the host device_impl singleton
218+
static std::shared_ptr<device_impl> getHostDeviceImpl();
219+
215220
private:
216221
explicit device_impl(pi_native_handle InteropDevice, RT::PiDevice Device,
217222
PlatformImplPtr Platform, const plugin &Plugin);

sycl/source/detail/device_info.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <CL/sycl/info/info_desc.hpp>
1616
#include <CL/sycl/platform.hpp>
1717
#include <detail/device_impl.hpp>
18+
#include <detail/platform_impl.hpp>
1819
#include <detail/platform_util.hpp>
1920
#include <detail/plugin.hpp>
2021

@@ -117,7 +118,7 @@ template <info::device param> struct get_device_info<platform, param> {
117118
// Use the Plugin from the device_impl class after plugin details
118119
// are added to the class.
119120
return createSyclObjFromImpl<platform>(
120-
std::make_shared<platform_impl>(result, Plugin));
121+
platform_impl::getOrMakePlatformImpl(result, Plugin));
121122
}
122123
};
123124

@@ -406,8 +407,11 @@ template <> struct get_device_info<device, info::device::parent_device> {
406407
"No parent for device because it is not a subdevice",
407408
PI_INVALID_DEVICE);
408409

410+
// Get the platform of this device
411+
std::shared_ptr<detail::platform_impl> Platform =
412+
platform_impl::getPlatformFromPiDevice(dev, Plugin);
409413
return createSyclObjFromImpl<device>(
410-
std::make_shared<device_impl>(result, Plugin));
414+
Platform->getOrMakeDeviceImpl(result, Platform));
411415
}
412416
};
413417

sycl/source/detail/platform_impl.cpp

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,46 @@ __SYCL_INLINE_NAMESPACE(cl) {
2121
namespace sycl {
2222
namespace detail {
2323

24+
using PlatformImplPtr = std::shared_ptr<platform_impl>;
25+
26+
/// Returns the host platform_impl that is shared by every caller.
///
/// \return a shared_ptr to the one host platform_impl created by this function
PlatformImplPtr platform_impl::getHostPlatformImpl() {
27+
// Function-local static: default-constructed on the first call, then the
// same shared_ptr is returned on every subsequent call.
static PlatformImplPtr HostImpl = std::make_shared<platform_impl>();
28+
29+
return HostImpl;
30+
}
31+
32+
/// Looks up the platform_impl for a PI platform handle in a cache local to
/// this function, creating and caching a new platform_impl if none matches.
///
/// \param PiPlatform is the PI platform handle to look up
/// \param Plugin is passed to the platform_impl constructor when a new impl
///        has to be created; it is not used for the lookup itself
/// \return the cached platform_impl whose handle equals PiPlatform, or the
///         newly created one
PlatformImplPtr platform_impl::getOrMakePlatformImpl(RT::PiPlatform PiPlatform,
33+
const plugin &Plugin) {
34+
// The cache and its mutex are function-local statics, so they are shared by
// all callers of this function. Nothing in this function removes entries.
static std::vector<PlatformImplPtr> PlatformCache;
35+
static std::mutex PlatformMapMutex;
36+
37+
PlatformImplPtr Result;
38+
{
39+
// The lock covers both the lookup and the insertion, so a handle cannot be
// inserted twice by two callers racing through this block.
const std::lock_guard<std::mutex> Guard(PlatformMapMutex);
40+
41+
// If we've already seen this platform, return the impl
42+
// (linear scan comparing the raw PI handles)
for (const auto &PlatImpl : PlatformCache) {
43+
if (PlatImpl->getHandleRef() == PiPlatform)
44+
return PlatImpl;
45+
}
46+
47+
// Otherwise make the impl
48+
Result = std::make_shared<platform_impl>(PiPlatform, Plugin);
49+
PlatformCache.emplace_back(Result);
50+
}
51+
52+
return Result;
53+
}
54+
55+
/// Asks the plugin which PI platform a PI device belongs to and returns the
/// platform_impl for that platform via getOrMakePlatformImpl.
///
/// \param PiDevice is the PI device handle whose platform is queried
/// \param Plugin is the plugin used for the query and forwarded to
///        getOrMakePlatformImpl
/// \return the platform_impl returned by getOrMakePlatformImpl for the
///         queried platform handle
PlatformImplPtr platform_impl::getPlatformFromPiDevice(RT::PiDevice PiDevice,
56+
const plugin &Plugin) {
57+
RT::PiPlatform Plt = nullptr; // TODO catch an exception and put it to list
58+
// of asynchronous exceptions
59+
// Query PI_DEVICE_INFO_PLATFORM; the result is written into Plt.
Plugin.call<PiApiKind::piDeviceGetInfo>(PiDevice, PI_DEVICE_INFO_PLATFORM,
60+
sizeof(Plt), &Plt, nullptr);
61+
return getOrMakePlatformImpl(Plt, Plugin);
62+
}
63+
2464
static bool IsBannedPlatform(platform Platform) {
2565
// The NVIDIA OpenCL platform is currently not compatible with DPC++
2666
// since it is only 1.2 but gets selected by default in many systems
@@ -65,7 +105,7 @@ vector_class<platform> platform_impl::get_platforms() {
65105

66106
for (const auto &PiPlatform : PiPlatforms) {
67107
platform Platform = detail::createSyclObjFromImpl<platform>(
68-
std::make_shared<platform_impl>(PiPlatform, Plugins[i]));
108+
getOrMakePlatformImpl(PiPlatform, Plugins[i]));
69109
// Skip platforms which do not contain requested device types
70110
if (!Platform.get_devices(ForcedType).empty() &&
71111
!IsBannedPlatform(Platform))
@@ -83,7 +123,6 @@ vector_class<platform> platform_impl::get_platforms() {
83123
struct DevDescT {
84124
const char *devName = nullptr;
85125
int devNameSize = 0;
86-
87126
const char *devDriverVer = nullptr;
88127
int devDriverVerSize = 0;
89128

@@ -228,12 +267,30 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
228267
PiDevices.resize(InsertIDx);
229268
}
230269

270+
/// Looks up the device_impl for a PI device handle in this platform's
/// MDeviceCache, creating and caching a new device_impl if none matches.
///
/// \param PiDevice is the PI device handle to look up
/// \param PlatformImpl is handed to the device_impl constructor when a new
///        impl is created; it plays no part in the lookup
/// \return the cached device_impl whose handle equals PiDevice, or the newly
///         created one
///
/// NOTE(review): the cache consulted is always this object's MDeviceCache,
/// and nothing here checks that PlatformImpl refers to this same object.
/// Presumably callers are expected to pass the platform they call this on —
/// confirm against call sites.
std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
271+
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
272+
// Held until the function returns, so lookup and insertion happen under the
// same lock.
const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
273+
274+
// If we've already seen this device, return the impl
275+
// (linear scan comparing the raw PI handles)
for (const std::shared_ptr<device_impl> &Device : MDeviceCache) {
276+
if (Device->getHandleRef() == PiDevice)
277+
return Device;
278+
}
279+
280+
// Otherwise make the impl
281+
std::shared_ptr<device_impl> Result =
282+
std::make_shared<device_impl>(PiDevice, PlatformImpl);
283+
MDeviceCache.emplace_back(Result);
284+
285+
return Result;
286+
}
287+
231288
vector_class<device>
232289
platform_impl::get_devices(info::device_type DeviceType) const {
233290
vector_class<device> Res;
234291
if (is_host() && (DeviceType == info::device_type::host ||
235292
DeviceType == info::device_type::all)) {
236-
Res.resize(1); // default device constructor creates host device
293+
Res.push_back(device());
237294
}
238295

239296
// If any DeviceType other than host was requested for host platform,
@@ -260,12 +317,13 @@ platform_impl::get_devices(info::device_type DeviceType) const {
260317
if (SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get())
261318
filterAllowList(PiDevices, MPlatform, this->getPlugin());
262319

263-
std::transform(PiDevices.begin(), PiDevices.end(), std::back_inserter(Res),
264-
[this](const RT::PiDevice &PiDevice) -> device {
265-
return detail::createSyclObjFromImpl<device>(
266-
std::make_shared<device_impl>(
267-
PiDevice, std::make_shared<platform_impl>(*this)));
268-
});
320+
PlatformImplPtr PlatformImpl = getOrMakePlatformImpl(MPlatform, *MPlugin);
321+
std::transform(
322+
PiDevices.begin(), PiDevices.end(), std::back_inserter(Res),
323+
[this, PlatformImpl](const RT::PiDevice &PiDevice) -> device {
324+
return detail::createSyclObjFromImpl<device>(
325+
PlatformImpl->getOrMakeDeviceImpl(PiDevice, PlatformImpl));
326+
});
269327

270328
return Res;
271329
}

sycl/source/detail/platform_impl.hpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,57 @@ class platform_impl {
137137
/// given feature.
138138
bool has(aspect Aspect) const;
139139

140+
/// Queries the device_impl cache to either return a shared_ptr
141+
/// for the device_impl corresponding to the PiDevice or add
142+
/// a new entry to the cache
143+
///
144+
/// \param PiDevice is the PiDevice whose impl is requested
145+
///
146+
/// \param PlatformImpl is the Platform for that Device
147+
///
148+
/// \return a shared_ptr<device_impl> corresponding to the device
149+
std::shared_ptr<device_impl>
150+
getOrMakeDeviceImpl(RT::PiDevice PiDevice,
151+
const std::shared_ptr<platform_impl> &PlatformImpl);
152+
153+
/// Static functions that help maintain platform uniqueness and
154+
/// equality of comparison
155+
156+
/// Returns the host platform impl
157+
///
158+
/// \return the host platform impl
159+
static std::shared_ptr<platform_impl> getHostPlatformImpl();
160+
161+
/// Queries the cache to see if the specified PiPlatform has been seen
162+
/// before. If so, return the cached platform_impl, otherwise create a new
163+
/// one and cache it.
164+
///
165+
/// \param PiPlatform is the PI Platform handle representing the platform
166+
/// \param Plugin is the PI plugin providing the backend for the platform
167+
/// \return the platform_impl representing the PI platform
168+
static std::shared_ptr<platform_impl>
169+
getOrMakePlatformImpl(RT::PiPlatform PiPlatform, const plugin &Plugin);
170+
171+
/// Queries the cache for the specified platform based on an input device.
172+
/// If found, returns the cached platform_impl, otherwise creates a new
173+
/// one and caches it.
174+
///
175+
/// \param PiDevice is the PI device handle for the device whose platform is
176+
/// desired
177+
/// \param Plugin is the PI plugin providing the backend for the device and
178+
/// platform
179+
/// \return the platform_impl that contains the input device
180+
static std::shared_ptr<platform_impl>
181+
getPlatformFromPiDevice(RT::PiDevice PiDevice, const plugin &Plugin);
182+
140183
private:
141184
bool MHostPlatform = false;
142185
RT::PiPlatform MPlatform = 0;
143186
std::shared_ptr<plugin> MPlugin;
187+
std::vector<std::shared_ptr<device_impl>> MDeviceCache;
188+
std::mutex MDeviceMapMutex;
144189
};
190+
145191
} // namespace detail
146192
} // namespace sycl
147193
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/source/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void force_type(info::device_type &t, const info::device_type &ft) {
2727
}
2828
} // namespace detail
2929

30-
device::device() : impl(std::make_shared<detail::device_impl>()) {}
30+
// Default constructor: takes its impl from device_impl::getHostDeviceImpl()
// instead of allocating a new device_impl for each default-constructed device.
device::device() : impl(detail::device_impl::getHostDeviceImpl()) {}
3131

3232
device::device(cl_device_id deviceId)
3333
: impl(std::make_shared<detail::device_impl>(

sycl/source/platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
__SYCL_INLINE_NAMESPACE(cl) {
1717
namespace sycl {
1818

19-
platform::platform() : impl(std::make_shared<detail::platform_impl>()) {}
19+
// Default constructor: takes its impl from
// platform_impl::getHostPlatformImpl() instead of allocating a new
// platform_impl for each default-constructed platform.
platform::platform() : impl(detail::platform_impl::getHostPlatformImpl()) {}
2020

2121
platform::platform(cl_platform_id PlatformId)
2222
: impl(std::make_shared<detail::platform_impl>(
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t1.out
2+
// RUN: env SYCL_DEVICE_TYPE=HOST %t1.out
3+
// RUN: %CPU_RUN_PLACEHOLDER %t1.out
4+
// RUN: %GPU_RUN_PLACEHOLDER %t1.out
5+
6+
//==------- device_equality.cpp - SYCL device equality test ----------------==//
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include <CL/sycl.hpp>
15+
#include <cassert>
16+
#include <iostream>
17+
#include <utility>
18+
19+
using namespace cl::sycl;
20+
21+
// Test driver: checks that devices and platforms obtained through separate
// queues, and separately default-constructed devices and their platforms,
// compare equal. Returns 0 if no assertion fires.
int main() {
22+
std::cout << "Creating q1" << std::endl;
23+
queue q1;
24+
std::cout << "Creating q2" << std::endl;
25+
queue q2;
26+
27+
// Default selector picks the same device every time.
28+
// That device should compare equal to itself.
29+
// Its platform should too.
30+
31+
auto dev1 = q1.get_device();
32+
auto plat1 = dev1.get_platform();
33+
34+
auto dev2 = q2.get_device();
35+
auto plat2 = dev2.get_platform();
36+
37+
assert((dev1 == dev2) && "Device 1 == Device 2");
38+
assert((plat1 == plat2) && "Platform 1 == Platform 2");
39+
40+
// Two independently default-constructed devices: each must report itself as
// a host device, and the two must compare equal.
device h1;
41+
device h2;
42+
43+
assert(h1.is_host() && "Device h1 is host");
44+
assert(h2.is_host() && "Device h2 is host");
45+
assert(h1 == h2 && "Host devices equal each other");
46+
47+
// Same checks for the platforms reported by those two devices.
platform hp1 = h1.get_platform();
48+
platform hp2 = h2.get_platform();
49+
assert(hp1.is_host() && "Platform hp1 is host");
50+
assert(hp2.is_host() && "Platform hp2 is host");
51+
assert(hp1 == hp2 && "Host platforms equal each other");
52+
53+
return 0;
54+
}

sycl/unittests/pi/PiMock.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ TEST(PiMockTest, ConstructFromQueue) {
3939
detail::getSyclObjImpl(Mock.getPlatform())->getPlugin().getPiPlugin();
4040
EXPECT_EQ(&MockedQueuePiPlugin, &PiMockPlugin)
4141
<< "The mocked object and the PiMock instance must share the same plugin";
42-
ASSERT_FALSE(&NormalPiPlugin == &MockedQueuePiPlugin)
43-
<< "Normal and mock platforms must not share the same plugin";
42+
EXPECT_EQ(&NormalPiPlugin, &MockedQueuePiPlugin)
43+
<< "Normal and mock platforms must share the same plugin";
4444
}
4545

4646
TEST(PiMockTest, ConstructFromPlatform) {
@@ -60,8 +60,8 @@ TEST(PiMockTest, ConstructFromPlatform) {
6060
detail::getSyclObjImpl(Mock.getPlatform())->getPlugin().getPiPlugin();
6161
EXPECT_EQ(&MockedPlatformPiPlugin, &PiMockPlugin)
6262
<< "The mocked object and the PiMock instance must share the same plugin";
63-
ASSERT_FALSE(&NormalPiPlugin == &MockedPlatformPiPlugin)
64-
<< "Normal and mock platforms must not share the same plugin";
63+
EXPECT_EQ(&NormalPiPlugin, &MockedPlatformPiPlugin)
64+
<< "Normal and mock platforms must share the same plugin";
6565
}
6666

6767
TEST(PiMockTest, RedefineAPI) {

0 commit comments

Comments
 (0)