Skip to content

Commit 0b18b49

Browse files
authored
[SYCL][L0] Add SubDevices into the device cache (#3314)
This PR adds the missing sub-devices in the device cache. This fixes the missing cache loopup for sub-devices. Signed-off-by: Byoungro So <[email protected]>
1 parent e8de2b0 commit 0b18b49

File tree

2 files changed

+50
-29
lines changed

2 files changed

+50
-29
lines changed

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,19 +1802,26 @@ pi_result piDevicePartition(pi_device Device,
18021802

18031803
PI_ASSERT(Device, PI_INVALID_DEVICE);
18041804

1805+
// Check if Device was already partitioned into the same or bigger size
1806+
// before. If so, we can return immediately without searching the global
1807+
// device cache. Note that L0 driver always returns the same handles in the
1808+
// same order for the given number of sub-devices.
1809+
if (OutDevices && NumDevices <= Device->SubDevices.size()) {
1810+
for (uint32_t I = 0; I < NumDevices; I++) {
1811+
OutDevices[I] = Device->SubDevices[I];
1812+
// reusing the same pi_device needs to increment the reference count
1813+
piDeviceRetain(OutDevices[I]);
1814+
}
1815+
if (OutNumDevices)
1816+
*OutNumDevices = NumDevices;
1817+
return PI_SUCCESS;
1818+
}
1819+
18051820
// Get the number of subdevices available.
18061821
// TODO: maybe add interface to create the specified # of subdevices.
18071822
uint32_t Count = 0;
18081823
ZE_CALL(zeDeviceGetSubDevices(Device->ZeDevice, &Count, nullptr));
18091824

1810-
// Check that the requested/allocated # of sub-devices is the same
1811-
// as was reported by the above call.
1812-
// TODO: we may want to support smaller/larger # devices too.
1813-
if (Count != NumDevices) {
1814-
zePrint("piDevicePartition: unsupported # of sub-devices requested\n");
1815-
return PI_INVALID_OPERATION;
1816-
}
1817-
18181825
if (OutNumDevices) {
18191826
*OutNumDevices = Count;
18201827
}
@@ -1825,17 +1832,29 @@ pi_result piDevicePartition(pi_device Device,
18251832
}
18261833

18271834
try {
1835+
pi_platform Platform = Device->Platform;
18281836
auto ZeSubdevices = new ze_device_handle_t[Count];
18291837
ZE_CALL(zeDeviceGetSubDevices(Device->ZeDevice, &Count, ZeSubdevices));
18301838

18311839
// Wrap the Level Zero sub-devices into PI sub-devices, and write them out.
18321840
for (uint32_t I = 0; I < Count; ++I) {
1833-
OutDevices[I] = new _pi_device(ZeSubdevices[I], Device->Platform,
1834-
true /* isSubDevice */);
1835-
pi_result Result = OutDevices[I]->initialize();
1836-
if (Result != PI_SUCCESS) {
1837-
delete[] ZeSubdevices;
1838-
return Result;
1841+
pi_device Dev = Platform->getDeviceFromNativeHandle(ZeSubdevices[I]);
1842+
if (Dev) {
1843+
OutDevices[I] = Dev;
1844+
// reusing the same pi_device needs to increment the reference count
1845+
piDeviceRetain(OutDevices[I]);
1846+
} else {
1847+
std::unique_ptr<_pi_device> PiSubDevice(
1848+
new _pi_device(ZeSubdevices[I], Platform));
1849+
pi_result Result = PiSubDevice->initialize();
1850+
if (Result != PI_SUCCESS) {
1851+
delete[] ZeSubdevices;
1852+
return Result;
1853+
}
1854+
OutDevices[I] = PiSubDevice.get();
1855+
Platform->PiDevicesCache.push_back(std::move(PiSubDevice));
1856+
// save pointers to sub-devices for quick retrieval in the future.
1857+
Device->SubDevices.push_back(Dev);
18391858
}
18401859
}
18411860
delete[] ZeSubdevices;
@@ -1911,29 +1930,25 @@ pi_result piextDeviceCreateWithNativeHandle(pi_native_handle NativeHandle,
19111930
PI_ASSERT(Device, PI_INVALID_DEVICE);
19121931
PI_ASSERT(NativeHandle, PI_INVALID_VALUE);
19131932
PI_ASSERT(Platform, PI_INVALID_PLATFORM);
1914-
1915-
std::lock_guard<std::mutex> Lock(Platform->PiDevicesCacheMutex);
1916-
pi_result Res = populateDeviceCacheIfNeeded(Platform);
1917-
if (Res != PI_SUCCESS) {
1918-
return Res;
1933+
{
1934+
std::lock_guard<std::mutex> Lock(Platform->PiDevicesCacheMutex);
1935+
pi_result Res = populateDeviceCacheIfNeeded(Platform);
1936+
if (Res != PI_SUCCESS) {
1937+
return Res;
1938+
}
19191939
}
1920-
19211940
auto ZeDevice = pi_cast<ze_device_handle_t>(NativeHandle);
19221941

19231942
// The SYCL spec requires that the set of devices must remain fixed for the
19241943
// duration of the application's execution. We assume that we found all of the
19251944
// Level Zero devices when we initialized the device cache, so the
19261945
// "NativeHandle" must already be in the cache. If it is not, this must not be
19271946
// a valid Level Zero device.
1928-
for (const std::unique_ptr<_pi_device> &CachedDevice :
1929-
Platform->PiDevicesCache) {
1930-
if (CachedDevice->ZeDevice == ZeDevice) {
1931-
*Device = CachedDevice.get();
1932-
return PI_SUCCESS;
1933-
}
1934-
}
1935-
1936-
return PI_INVALID_VALUE;
1947+
pi_device Dev = Platform->getDeviceFromNativeHandle(ZeDevice);
1948+
if (Dev == nullptr)
1949+
return PI_INVALID_VALUE;
1950+
*Device = Dev;
1951+
return PI_SUCCESS;
19371952
}
19381953

19391954
pi_result piContextCreate(const pi_context_properties *Properties,

sycl/plugins/level_zero/pi_level_zero.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ struct _pi_device : _pi_object {
156156
// Level Zero device handle.
157157
ze_device_handle_t ZeDevice;
158158

159+
// Keep the subdevices that are partitioned from this pi_device for reuse
160+
// The order of sub-devices in this vector is repeated from the
161+
// ze_device_handle_t array that are returned from zeDeviceGetSubDevices()
162+
// call, which will always return sub-devices in the fixed same order.
163+
std::vector<pi_device> SubDevices;
164+
159165
// PI platform to which this device belongs.
160166
pi_platform Platform;
161167

0 commit comments

Comments
 (0)