@@ -542,9 +542,28 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
542
542
543
543
static std::once_flag initFlag;
544
544
static _pi_platform platformId;
545
- std::call_once (initFlag,
546
- [](pi_result &err) { err = PI_CHECK_ERROR (cuInit (0 )); },
547
- err);
545
+ std::call_once (
546
+ initFlag,
547
+ [](pi_result &err) {
548
+ err = PI_CHECK_ERROR (cuInit (0 ));
549
+
550
+ int numDevices = 0 ;
551
+ err = PI_CHECK_ERROR (cuDeviceGetCount (&numDevices));
552
+ platformId.devices_ .reserve (numDevices);
553
+ try {
554
+ for (int i = 0 ; i < numDevices; ++i) {
555
+ CUdevice device;
556
+ err = PI_CHECK_ERROR (cuDeviceGet (&device, i));
557
+ platformId.devices_ .emplace_back (
558
+ new _pi_device{device, &platformId});
559
+ }
560
+ } catch (...) {
561
+ // Clear and rethrow to allow retry
562
+ platformId.devices_ .clear ();
563
+ throw ;
564
+ }
565
+ },
566
+ err);
548
567
549
568
*platforms = &platformId;
550
569
}
@@ -594,22 +613,16 @@ pi_result cuda_piDevicesGet(pi_platform platform, pi_device_type device_type,
594
613
595
614
pi_result err = PI_SUCCESS;
596
615
const bool askingForGPU = (device_type & PI_DEVICE_TYPE_GPU);
597
- size_t numDevices = askingForGPU ? 1 : 0 ;
616
+ size_t numDevices = askingForGPU ? platform-> devices_ . size () : 0 ;
598
617
599
618
try {
600
619
if (num_devices) {
601
620
*num_devices = numDevices;
602
621
}
603
622
604
- if (askingForGPU) {
605
- if (devices) {
606
- CUdevice device;
607
- err = PI_CHECK_ERROR (cuDeviceGet (&device, 0 ));
608
- *devices = new _pi_device{device, platform};
609
- }
610
- } else {
611
- if (devices) {
612
- *devices = nullptr ;
623
+ if (askingForGPU && devices) {
624
+ for (size_t i = 0 ; i < std::min (size_t (num_entries), numDevices); ++i) {
625
+ devices[i] = platform->devices_ [i].get ();
613
626
}
614
627
}
615
628
0 commit comments