Skip to content

Commit 253d7fa

Browse files
[SYCL] Cache sub-devices early so that they can be seen by the interop API (#3690)
Signed-off-by: Sergey V Maslov <[email protected]>
1 parent d08c21a commit 253d7fa

File tree

2 files changed

+74
-81
lines changed

2 files changed

+74
-81
lines changed

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 67 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,13 @@ static std::map<std::string, int> *ZeCallCount = nullptr;
124124

125125
// Trace an internal PI call; returns in case of an error.
126126
#define PI_CALL(Call) \
127-
if (PrintPiTrace) \
128-
fprintf(stderr, "PI ---> %s\n", #Call); \
129-
pi_result Result = (Call); \
130-
if (Result != PI_SUCCESS) \
131-
return Result;
127+
{ \
128+
if (PrintPiTrace) \
129+
fprintf(stderr, "PI ---> %s\n", #Call); \
130+
pi_result Result = (Call); \
131+
if (Result != PI_SUCCESS) \
132+
return Result; \
133+
}
132134

133135
enum DebugLevel {
134136
ZE_DEBUG_NONE = 0x0,
@@ -1074,8 +1076,6 @@ static pi_result copyModule(ze_context_handle_t ZeContext,
10741076

10751077
static bool setEnvVar(const char *var, const char *value);
10761078

1077-
static pi_result populateDeviceCacheIfNeeded(pi_platform Platform);
1078-
10791079
// Forward declarations for mock implementations of Level Zero APIs that
10801080
// do not yet work in the driver.
10811081
// TODO: Remove these mock definitions when they work in the driver.
@@ -1333,7 +1333,11 @@ pi_result piextPlatformCreateWithNativeHandle(pi_native_handle NativeHandle,
13331333
// Return NULL if no such PI device is found.
13341334
pi_device _pi_platform::getDeviceFromNativeHandle(ze_device_handle_t ZeDevice) {
13351335

1336-
std::lock_guard<std::mutex> Lock(this->PiDevicesCacheMutex);
1336+
pi_result Res = populateDeviceCacheIfNeeded();
1337+
if (Res != PI_SUCCESS) {
1338+
return nullptr;
1339+
}
1340+
13371341
auto it = std::find_if(PiDevicesCache.begin(), PiDevicesCache.end(),
13381342
[&](std::unique_ptr<_pi_device> &D) {
13391343
return D.get()->ZeDevice == ZeDevice;
@@ -1350,8 +1354,7 @@ pi_result piDevicesGet(pi_platform Platform, pi_device_type DeviceType,
13501354

13511355
PI_ASSERT(Platform, PI_INVALID_PLATFORM);
13521356

1353-
std::lock_guard<std::mutex> Lock(Platform->PiDevicesCacheMutex);
1354-
pi_result Res = populateDeviceCacheIfNeeded(Platform);
1357+
pi_result Res = Platform->populateDeviceCacheIfNeeded();
13551358
if (Res != PI_SUCCESS) {
13561359
return Res;
13571360
}
@@ -1409,15 +1412,14 @@ pi_result piDevicesGet(pi_platform Platform, pi_device_type DeviceType,
14091412
return PI_SUCCESS;
14101413
}
14111414

1412-
// Check the device cache and load it if necessary. The PiDevicesCacheMutex must
1413-
// be locked before calling this function to prevent any synchronization issues.
1414-
static pi_result populateDeviceCacheIfNeeded(pi_platform Platform) {
1415+
// Check the device cache and load it if necessary.
1416+
pi_result _pi_platform::populateDeviceCacheIfNeeded() {
1417+
std::lock_guard<std::mutex> Lock(PiDevicesCacheMutex);
14151418

1416-
if (Platform->DeviceCachePopulated) {
1419+
if (DeviceCachePopulated) {
14171420
return PI_SUCCESS;
14181421
}
14191422

1420-
ze_driver_handle_t ZeDriver = Platform->ZeDriver;
14211423
uint32_t ZeDeviceCount = 0;
14221424
ZE_CALL(zeDeviceGet, (ZeDriver, &ZeDeviceCount, nullptr));
14231425

@@ -1426,21 +1428,48 @@ static pi_result populateDeviceCacheIfNeeded(pi_platform Platform) {
14261428
ZE_CALL(zeDeviceGet, (ZeDriver, &ZeDeviceCount, ZeDevices.data()));
14271429

14281430
for (uint32_t I = 0; I < ZeDeviceCount; ++I) {
1429-
std::unique_ptr<_pi_device> Device(
1430-
new _pi_device(ZeDevices[I], Platform));
1431+
std::unique_ptr<_pi_device> Device(new _pi_device(ZeDevices[I], this));
14311432
pi_result Result = Device->initialize();
14321433
if (Result != PI_SUCCESS) {
14331434
return Result;
14341435
}
1435-
// save a copy in the cache for future uses.
1436-
Platform->PiDevicesCache.push_back(std::move(Device));
1436+
1437+
// We also need to cache all sub-devices, so that they
1438+
// are readily visible to piextDeviceCreateWithNativeHandle.
1439+
//
1440+
pi_uint32 SubDevicesCount = 0;
1441+
ZE_CALL(zeDeviceGetSubDevices,
1442+
(Device->ZeDevice, &SubDevicesCount, nullptr));
1443+
1444+
auto ZeSubdevices = new ze_device_handle_t[SubDevicesCount];
1445+
ZE_CALL(zeDeviceGetSubDevices,
1446+
(Device->ZeDevice, &SubDevicesCount, ZeSubdevices));
1447+
1448+
// Wrap the Level Zero sub-devices into PI sub-devices, and add them to
1449+
// cache.
1450+
for (uint32_t I = 0; I < SubDevicesCount; ++I) {
1451+
std::unique_ptr<_pi_device> PiSubDevice(
1452+
new _pi_device(ZeSubdevices[I], this, true));
1453+
pi_result Result = PiSubDevice->initialize();
1454+
if (Result != PI_SUCCESS) {
1455+
delete[] ZeSubdevices;
1456+
return Result;
1457+
}
1458+
// save pointers to sub-devices for quick retrieval in the future.
1459+
Device->SubDevices.push_back(PiSubDevice.get());
1460+
PiDevicesCache.push_back(std::move(PiSubDevice));
1461+
}
1462+
delete[] ZeSubdevices;
1463+
1464+
// Save the root device in the cache for future uses.
1465+
PiDevicesCache.push_back(std::move(Device));
14371466
}
14381467
} catch (const std::bad_alloc &) {
14391468
return PI_OUT_OF_HOST_MEMORY;
14401469
} catch (...) {
14411470
return PI_ERROR_UNKNOWN;
14421471
}
1443-
Platform->DeviceCachePopulated = true;
1472+
DeviceCachePopulated = true;
14441473
return PI_SUCCESS;
14451474
}
14461475

@@ -1986,66 +2015,30 @@ pi_result piDevicePartition(pi_device Device,
19862015

19872016
PI_ASSERT(Device, PI_INVALID_DEVICE);
19882017

1989-
// Check if Device was already partitioned into the same or bigger size
1990-
// before. If so, we can return immediately without searching the global
1991-
// device cache. Note that L0 driver always returns the same handles in the
1992-
// same order for the given number of sub-devices.
1993-
if (OutDevices && NumDevices <= Device->SubDevices.size()) {
1994-
for (uint32_t I = 0; I < NumDevices; I++) {
1995-
OutDevices[I] = Device->SubDevices[I];
1996-
// reusing the same pi_device needs to increment the reference count
1997-
piDeviceRetain(OutDevices[I]);
1998-
}
1999-
if (OutNumDevices)
2000-
*OutNumDevices = NumDevices;
2001-
return PI_SUCCESS;
2018+
// The device cache is normally created in piDevicesGet, but still make
2019+
// sure that the cache is populated.
2020+
//
2021+
pi_result Res = Device->Platform->populateDeviceCacheIfNeeded();
2022+
if (Res != PI_SUCCESS) {
2023+
return Res;
20022024
}
20032025

2004-
// Get the number of subdevices available.
2005-
// TODO: maybe add interface to create the specified # of subdevices.
2006-
uint32_t Count = 0;
2007-
ZE_CALL(zeDeviceGetSubDevices, (Device->ZeDevice, &Count, nullptr));
2008-
20092026
if (OutNumDevices) {
2010-
*OutNumDevices = Count;
2027+
*OutNumDevices = Device->SubDevices.size();
20112028
}
20122029

2013-
if (!OutDevices) {
2014-
// If we are not given the buffer, we are done.
2015-
return PI_SUCCESS;
2016-
}
2030+
if (OutDevices) {
2031+
// TODO: Consider support for partitioning to <= total sub-devices.
2032+
// Currently supported partitioning (by affinity domain/numa) would always
2033+
// partition to all sub-devices.
2034+
//
2035+
PI_ASSERT(NumDevices == Device->SubDevices.size(), PI_INVALID_VALUE);
20172036

2018-
try {
2019-
pi_platform Platform = Device->Platform;
2020-
auto ZeSubdevices = new ze_device_handle_t[Count];
2021-
ZE_CALL(zeDeviceGetSubDevices, (Device->ZeDevice, &Count, ZeSubdevices));
2022-
2023-
// Wrap the Level Zero sub-devices into PI sub-devices, and write them out.
2024-
for (uint32_t I = 0; I < Count; ++I) {
2025-
pi_device Dev = Platform->getDeviceFromNativeHandle(ZeSubdevices[I]);
2026-
if (Dev) {
2027-
OutDevices[I] = Dev;
2028-
// reusing the same pi_device needs to increment the reference count
2029-
piDeviceRetain(OutDevices[I]);
2030-
} else {
2031-
std::unique_ptr<_pi_device> PiSubDevice(
2032-
new _pi_device(ZeSubdevices[I], Platform));
2033-
pi_result Result = PiSubDevice->initialize();
2034-
if (Result != PI_SUCCESS) {
2035-
delete[] ZeSubdevices;
2036-
return Result;
2037-
}
2038-
OutDevices[I] = PiSubDevice.get();
2039-
// save pointers to sub-devices for quick retrieval in the future.
2040-
Device->SubDevices.push_back(PiSubDevice.get());
2041-
Platform->PiDevicesCache.push_back(std::move(PiSubDevice));
2042-
}
2037+
for (uint32_t I = 0; I < NumDevices; I++) {
2038+
OutDevices[I] = Device->SubDevices[I];
2039+
// reusing the same pi_device needs to increment the reference count
2040+
PI_CALL(piDeviceRetain(OutDevices[I]));
20432041
}
2044-
delete[] ZeSubdevices;
2045-
} catch (const std::bad_alloc &) {
2046-
return PI_OUT_OF_HOST_MEMORY;
2047-
} catch (...) {
2048-
return PI_ERROR_UNKNOWN;
20492042
}
20502043
return PI_SUCCESS;
20512044
}
@@ -2114,13 +2107,7 @@ pi_result piextDeviceCreateWithNativeHandle(pi_native_handle NativeHandle,
21142107
PI_ASSERT(Device, PI_INVALID_DEVICE);
21152108
PI_ASSERT(NativeHandle, PI_INVALID_VALUE);
21162109
PI_ASSERT(Platform, PI_INVALID_PLATFORM);
2117-
{
2118-
std::lock_guard<std::mutex> Lock(Platform->PiDevicesCacheMutex);
2119-
pi_result Res = populateDeviceCacheIfNeeded(Platform);
2120-
if (Res != PI_SUCCESS) {
2121-
return Res;
2122-
}
2123-
}
2110+
21242111
auto ZeDevice = pi_cast<ze_device_handle_t>(NativeHandle);
21252112

21262113
// The SYCL spec requires that the set of devices must remain fixed for the

sycl/plugins/level_zero/pi_level_zero.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,15 @@ struct _pi_platform {
8282
// Cache pi_devices for reuse
8383
std::vector<std::unique_ptr<_pi_device>> PiDevicesCache;
8484
std::mutex PiDevicesCacheMutex;
85-
pi_device getDeviceFromNativeHandle(ze_device_handle_t);
8685
bool DeviceCachePopulated = false;
8786

87+
// Check the device cache and load it if necessary.
88+
pi_result populateDeviceCacheIfNeeded();
89+
90+
// Return the cached PI device that represents the given native device.
91+
// If not found, then nullptr is returned.
92+
pi_device getDeviceFromNativeHandle(ze_device_handle_t);
93+
8894
// Current number of L0 Command Lists created on this platform.
8995
// This number must not exceed ZeMaxCommandListCache.
9096
std::atomic<int> ZeGlobalCommandListCount{0};

0 commit comments

Comments
 (0)