Skip to content

Commit 26d5d98

Browse files
[SYCL] Fix get_pointer_device for cases with descendent devices (#6719)
Looking through context members alone when searching for a specific device isn't enough anymore since now descendent devices of context members can be used within that context as well. Change the logic to look for the device in the cache instead.
1 parent 1adfa06 commit 26d5d98

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,21 +213,23 @@ static void filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
213213
Plugin.setLastDeviceId(Platform, DeviceNum);
214214
}
215215

216-
std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
216+
std::shared_ptr<device_impl> platform_impl::getDeviceImpl(
217217
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
218218
const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
219+
return getDeviceImplHelper(PiDevice, PlatformImpl);
220+
}
219221

222+
std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
223+
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
224+
const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
220225
// If we've already seen this device, return the impl
221-
for (const std::weak_ptr<device_impl> &DeviceWP : MDeviceCache) {
222-
if (std::shared_ptr<device_impl> Device = DeviceWP.lock()) {
223-
if (Device->getHandleRef() == PiDevice)
224-
return Device;
225-
}
226-
}
226+
std::shared_ptr<device_impl> Result =
227+
getDeviceImplHelper(PiDevice, PlatformImpl);
228+
if (Result)
229+
return Result;
227230

228231
// Otherwise make the impl
229-
std::shared_ptr<device_impl> Result =
230-
std::make_shared<device_impl>(PiDevice, PlatformImpl);
232+
Result = std::make_shared<device_impl>(PiDevice, PlatformImpl);
231233
MDeviceCache.emplace_back(Result);
232234

233235
return Result;
@@ -334,6 +336,17 @@ bool platform_impl::has(aspect Aspect) const {
334336
return true;
335337
}
336338

339+
std::shared_ptr<device_impl> platform_impl::getDeviceImplHelper(
340+
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
341+
for (const std::weak_ptr<device_impl> &DeviceWP : MDeviceCache) {
342+
if (std::shared_ptr<device_impl> Device = DeviceWP.lock()) {
343+
if (Device->getHandleRef() == PiDevice)
344+
return Device;
345+
}
346+
}
347+
return nullptr;
348+
}
349+
337350
#define __SYCL_PARAM_TRAITS_SPEC(DescType, Desc, ReturnT, PiCode) \
338351
template ReturnT platform_impl::get_info<info::platform::Desc>() const;
339352

sycl/source/detail/platform_impl.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,18 @@ class platform_impl {
137137
/// given feature.
138138
bool has(aspect Aspect) const;
139139

140+
/// Queries the device_impl cache to return a shared_ptr for the
141+
/// device_impl corresponding to the PiDevice.
142+
///
143+
/// \param PiDevice is the PiDevice whose impl is requested
144+
///
145+
/// \param PlatormImpl is the Platform for that Device
146+
///
147+
/// \return a shared_ptr<device_impl> corresponding to the device
148+
std::shared_ptr<device_impl>
149+
getDeviceImpl(RT::PiDevice PiDevice,
150+
const std::shared_ptr<platform_impl> &PlatformImpl);
151+
140152
/// Queries the device_impl cache to either return a shared_ptr
141153
/// for the device_impl corresponding to the PiDevice or add
142154
/// a new entry to the cache
@@ -181,6 +193,10 @@ class platform_impl {
181193
getPlatformFromPiDevice(RT::PiDevice PiDevice, const plugin &Plugin);
182194

183195
private:
196+
std::shared_ptr<device_impl>
197+
getDeviceImplHelper(RT::PiDevice PiDevice,
198+
const std::shared_ptr<platform_impl> &PlatformImpl);
199+
184200
bool MHostPlatform = false;
185201
RT::PiPlatform MPlatform = 0;
186202
std::shared_ptr<plugin> MPlugin;

sycl/source/detail/usm/usm_impl.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -594,12 +594,13 @@ device get_pointer_device(const void *Ptr, const context &Ctxt) {
594594
Plugin.call<detail::PiApiKind::piextUSMGetMemAllocInfo>(
595595
PICtx, Ptr, PI_MEM_ALLOC_DEVICE, sizeof(pi_device), &DeviceId, nullptr);
596596

597-
for (const device &Dev : CtxImpl->getDevices()) {
598-
// Try to find the real sycl device used in the context
599-
if (detail::getSyclObjImpl(Dev)->getHandleRef() == DeviceId)
600-
return Dev;
601-
}
602-
597+
// The device is not necessarily a member of the context, it could be a
598+
// member's descendant instead. Fetch the corresponding device from the cache.
599+
std::shared_ptr<detail::platform_impl> PltImpl = CtxImpl->getPlatformImpl();
600+
std::shared_ptr<detail::device_impl> DevImpl =
601+
PltImpl->getDeviceImpl(DeviceId, PltImpl);
602+
if (DevImpl)
603+
return detail::createSyclObjFromImpl<device>(DevImpl);
603604
throw runtime_error("Cannot find device associated with USM allocation!",
604605
PI_ERROR_INVALID_OPERATION);
605606
}

0 commit comments

Comments
 (0)