@@ -1139,8 +1139,7 @@ pi_result piPlatformsGet(pi_uint32 NumEntries, pi_platform *Platforms,
1139
1139
// We must only initialize the driver once, even if piPlatformsGet() is called
1140
1140
// multiple times. Declaring the return value as "static" ensures it's only
1141
1141
// called once.
1142
- static ze_result_t ZeResult =
1143
- ZE_CALL_NOCHECK (zeInit, (ZE_INIT_FLAG_GPU_ONLY));
1142
+ static ze_result_t ZeResult = ZE_CALL_NOCHECK (zeInit, (0 ));
1144
1143
1145
1144
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
1146
1145
if (ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
@@ -1319,25 +1318,47 @@ pi_result piDevicesGet(pi_platform Platform, pi_device_type DeviceType,
1319
1318
1320
1319
PI_ASSERT (Platform, PI_INVALID_PLATFORM);
1321
1320
1322
- // Get number of devices supporting Level Zero
1323
- uint32_t ZeDeviceCount = 0 ;
1324
1321
std::lock_guard<std::mutex> Lock (Platform->PiDevicesCacheMutex );
1325
-
1326
1322
pi_result Res = populateDeviceCacheIfNeeded (Platform);
1327
1323
if (Res != PI_SUCCESS) {
1328
1324
return Res;
1329
1325
}
1330
1326
1331
- ZeDeviceCount = Platform->PiDevicesCache .size ();
1332
- const bool AskingForGPU = (DeviceType & PI_DEVICE_TYPE_GPU);
1333
- const bool AskingForDefault = (DeviceType == PI_DEVICE_TYPE_DEFAULT);
1327
+ // Filter available devices based on input DeviceType
1328
+ std::vector<pi_device> MatchedDevices;
1329
+ for (auto &D : Platform->PiDevicesCache ) {
1330
+ // Only ever return root-devices from piDevicesGet, but the
1331
+ // devices cache also keeps sub-devices.
1332
+ if (D->IsSubDevice )
1333
+ continue ;
1334
1334
1335
- if (ZeDeviceCount == 0 || !(AskingForGPU || AskingForDefault)) {
1336
- if (NumDevices)
1337
- *NumDevices = 0 ;
1338
- return PI_SUCCESS;
1335
+ bool Matched = false ;
1336
+ switch (DeviceType) {
1337
+ case PI_DEVICE_TYPE_ALL:
1338
+ Matched = true ;
1339
+ break ;
1340
+ case PI_DEVICE_TYPE_GPU:
1341
+ case PI_DEVICE_TYPE_DEFAULT:
1342
+ Matched = (D->ZeDeviceProperties .type == ZE_DEVICE_TYPE_GPU);
1343
+ break ;
1344
+ case PI_DEVICE_TYPE_CPU:
1345
+ Matched = (D->ZeDeviceProperties .type == ZE_DEVICE_TYPE_CPU);
1346
+ break ;
1347
+ case PI_DEVICE_TYPE_ACC:
1348
+ Matched = (D->ZeDeviceProperties .type == ZE_DEVICE_TYPE_MCA ||
1349
+ D->ZeDeviceProperties .type == ZE_DEVICE_TYPE_FPGA);
1350
+ break ;
1351
+ default :
1352
+ Matched = false ;
1353
+ zePrint (" Unknown device type" );
1354
+ break ;
1355
+ }
1356
+ if (Matched)
1357
+ MatchedDevices.push_back (D.get ());
1339
1358
}
1340
1359
1360
+ uint32_t ZeDeviceCount = MatchedDevices.size ();
1361
+
1341
1362
if (NumDevices)
1342
1363
*NumDevices = ZeDeviceCount;
1343
1364
@@ -1348,15 +1369,9 @@ pi_result piDevicesGet(pi_platform Platform, pi_device_type DeviceType,
1348
1369
}
1349
1370
1350
1371
// Return the devices from the cache.
1351
- uint32_t I = 0 ;
1352
- for (const std::unique_ptr<_pi_device> &CachedDevice :
1353
- Platform->PiDevicesCache ) {
1354
- if (I < NumEntries) {
1355
- *Devices++ = CachedDevice.get ();
1356
- I++;
1357
- } else {
1358
- break ;
1359
- }
1372
+ if (Devices) {
1373
+ PI_ASSERT (NumEntries <= ZeDeviceCount, PI_INVALID_DEVICE);
1374
+ std::copy_n (MatchedDevices.begin (), NumEntries, Devices);
1360
1375
}
1361
1376
1362
1377
return PI_SUCCESS;
@@ -1476,11 +1491,18 @@ pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
1476
1491
1477
1492
switch (ParamName) {
1478
1493
case PI_DEVICE_INFO_TYPE: {
1479
- if (Device->ZeDeviceProperties .type != ZE_DEVICE_TYPE_GPU) {
1494
+ switch (Device->ZeDeviceProperties .type ) {
1495
+ case ZE_DEVICE_TYPE_GPU:
1496
+ return ReturnValue (PI_DEVICE_TYPE_GPU);
1497
+ case ZE_DEVICE_TYPE_CPU:
1498
+ return ReturnValue (PI_DEVICE_TYPE_CPU);
1499
+ case ZE_DEVICE_TYPE_MCA:
1500
+ case ZE_DEVICE_TYPE_FPGA:
1501
+ return ReturnValue (PI_DEVICE_TYPE_ACC);
1502
+ default :
1480
1503
zePrint (" This device type is not supported\n " );
1481
1504
return PI_INVALID_VALUE;
1482
1505
}
1483
- return ReturnValue (PI_DEVICE_TYPE_GPU);
1484
1506
}
1485
1507
case PI_DEVICE_INFO_PARENT_DEVICE:
1486
1508
// TODO: all Level Zero devices are parent ?
0 commit comments