@@ -1802,19 +1802,26 @@ pi_result piDevicePartition(pi_device Device,
1802
1802
1803
1803
PI_ASSERT (Device, PI_INVALID_DEVICE);
1804
1804
1805
+ // Check if Device was already partitioned into the same or bigger size
1806
+ // before. If so, we can return immediately without searching the global
1807
+ // device cache. Note that L0 driver always returns the same handles in the
1808
+ // same order for the given number of sub-devices.
1809
+ if (OutDevices && NumDevices <= Device->SubDevices .size ()) {
1810
+ for (uint32_t I = 0 ; I < NumDevices; I++) {
1811
+ OutDevices[I] = Device->SubDevices [I];
1812
+ // reusing the same pi_device needs to increment the reference count
1813
+ piDeviceRetain (OutDevices[I]);
1814
+ }
1815
+ if (OutNumDevices)
1816
+ *OutNumDevices = NumDevices;
1817
+ return PI_SUCCESS;
1818
+ }
1819
+
1805
1820
// Get the number of subdevices available.
1806
1821
// TODO: maybe add interface to create the specified # of subdevices.
1807
1822
uint32_t Count = 0 ;
1808
1823
ZE_CALL (zeDeviceGetSubDevices (Device->ZeDevice , &Count, nullptr ));
1809
1824
1810
- // Check that the requested/allocated # of sub-devices is the same
1811
- // as was reported by the above call.
1812
- // TODO: we may want to support smaller/larger # devices too.
1813
- if (Count != NumDevices) {
1814
- zePrint (" piDevicePartition: unsupported # of sub-devices requested\n " );
1815
- return PI_INVALID_OPERATION;
1816
- }
1817
-
1818
1825
if (OutNumDevices) {
1819
1826
*OutNumDevices = Count;
1820
1827
}
@@ -1825,17 +1832,29 @@ pi_result piDevicePartition(pi_device Device,
1825
1832
}
1826
1833
1827
1834
try {
1835
+ pi_platform Platform = Device->Platform ;
1828
1836
auto ZeSubdevices = new ze_device_handle_t [Count];
1829
1837
ZE_CALL (zeDeviceGetSubDevices (Device->ZeDevice , &Count, ZeSubdevices));
1830
1838
1831
1839
// Wrap the Level Zero sub-devices into PI sub-devices, and write them out.
1832
1840
for (uint32_t I = 0 ; I < Count; ++I) {
1833
- OutDevices[I] = new _pi_device (ZeSubdevices[I], Device->Platform ,
1834
- true /* isSubDevice */ );
1835
- pi_result Result = OutDevices[I]->initialize ();
1836
- if (Result != PI_SUCCESS) {
1837
- delete[] ZeSubdevices;
1838
- return Result;
1841
+ pi_device Dev = Platform->getDeviceFromNativeHandle (ZeSubdevices[I]);
1842
+ if (Dev) {
1843
+ OutDevices[I] = Dev;
1844
+ // reusing the same pi_device needs to increment the reference count
1845
+ piDeviceRetain (OutDevices[I]);
1846
+ } else {
1847
+ std::unique_ptr<_pi_device> PiSubDevice (
1848
+ new _pi_device (ZeSubdevices[I], Platform));
1849
+ pi_result Result = PiSubDevice->initialize ();
1850
+ if (Result != PI_SUCCESS) {
1851
+ delete[] ZeSubdevices;
1852
+ return Result;
1853
+ }
1854
+ OutDevices[I] = PiSubDevice.get ();
1855
+ Platform->PiDevicesCache .push_back (std::move (PiSubDevice));
1856
+ // save pointers to sub-devices for quick retrieval in the future.
1857
+ Device->SubDevices .push_back (Dev);
1839
1858
}
1840
1859
}
1841
1860
delete[] ZeSubdevices;
@@ -1911,29 +1930,25 @@ pi_result piextDeviceCreateWithNativeHandle(pi_native_handle NativeHandle,
1911
1930
PI_ASSERT (Device, PI_INVALID_DEVICE);
1912
1931
PI_ASSERT (NativeHandle, PI_INVALID_VALUE);
1913
1932
PI_ASSERT (Platform, PI_INVALID_PLATFORM);
1914
-
1915
- std::lock_guard<std::mutex> Lock (Platform->PiDevicesCacheMutex );
1916
- pi_result Res = populateDeviceCacheIfNeeded (Platform);
1917
- if (Res != PI_SUCCESS) {
1918
- return Res;
1933
+ {
1934
+ std::lock_guard<std::mutex> Lock (Platform->PiDevicesCacheMutex );
1935
+ pi_result Res = populateDeviceCacheIfNeeded (Platform);
1936
+ if (Res != PI_SUCCESS) {
1937
+ return Res;
1938
+ }
1919
1939
}
1920
-
1921
1940
auto ZeDevice = pi_cast<ze_device_handle_t >(NativeHandle);
1922
1941
1923
1942
// The SYCL spec requires that the set of devices must remain fixed for the
1924
1943
// duration of the application's execution. We assume that we found all of the
1925
1944
// Level Zero devices when we initialized the device cache, so the
1926
1945
// "NativeHandle" must already be in the cache. If it is not, this must not be
1927
1946
// a valid Level Zero device.
1928
- for (const std::unique_ptr<_pi_device> &CachedDevice :
1929
- Platform->PiDevicesCache ) {
1930
- if (CachedDevice->ZeDevice == ZeDevice) {
1931
- *Device = CachedDevice.get ();
1932
- return PI_SUCCESS;
1933
- }
1934
- }
1935
-
1936
- return PI_INVALID_VALUE;
1947
+ pi_device Dev = Platform->getDeviceFromNativeHandle (ZeDevice);
1948
+ if (Dev == nullptr )
1949
+ return PI_INVALID_VALUE;
1950
+ *Device = Dev;
1951
+ return PI_SUCCESS;
1937
1952
}
1938
1953
1939
1954
pi_result piContextCreate (const pi_context_properties *Properties,
0 commit comments