Skip to content

Commit 8d77a38

Browse files
[SYCL][CUDA] Fixes device leak (#1175)
This also makes `cuda_piDevicesGet` return all available CUDA devices. Signed-off-by: Steffen Larsen <[email protected]>
1 parent e94cbd3 commit 8d77a38

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,28 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
542542

543543
static std::once_flag initFlag;
544544
static _pi_platform platformId;
545-
std::call_once(initFlag,
546-
[](pi_result &err) { err = PI_CHECK_ERROR(cuInit(0)); },
547-
err);
545+
std::call_once(
546+
initFlag,
547+
[](pi_result &err) {
548+
err = PI_CHECK_ERROR(cuInit(0));
549+
550+
int numDevices = 0;
551+
err = PI_CHECK_ERROR(cuDeviceGetCount(&numDevices));
552+
platformId.devices_.reserve(numDevices);
553+
try {
554+
for (int i = 0; i < numDevices; ++i) {
555+
CUdevice device;
556+
err = PI_CHECK_ERROR(cuDeviceGet(&device, i));
557+
platformId.devices_.emplace_back(
558+
new _pi_device{device, &platformId});
559+
}
560+
} catch (...) {
561+
// Clear and rethrow to allow retry
562+
platformId.devices_.clear();
563+
throw;
564+
}
565+
},
566+
err);
548567

549568
*platforms = &platformId;
550569
}
@@ -594,22 +613,16 @@ pi_result cuda_piDevicesGet(pi_platform platform, pi_device_type device_type,
594613

595614
pi_result err = PI_SUCCESS;
596615
const bool askingForGPU = (device_type & PI_DEVICE_TYPE_GPU);
597-
size_t numDevices = askingForGPU ? 1 : 0;
616+
size_t numDevices = askingForGPU ? platform->devices_.size() : 0;
598617

599618
try {
600619
if (num_devices) {
601620
*num_devices = numDevices;
602621
}
603622

604-
if (askingForGPU) {
605-
if (devices) {
606-
CUdevice device;
607-
err = PI_CHECK_ERROR(cuDeviceGet(&device, 0));
608-
*devices = new _pi_device{device, platform};
609-
}
610-
} else {
611-
if (devices) {
612-
*devices = nullptr;
623+
if (askingForGPU && devices) {
624+
for (size_t i = 0; i < std::min(size_t(num_entries), numDevices); ++i) {
625+
devices[i] = platform->devices_[i].get();
613626
}
614627
}
615628

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pi_result cuda_piKernelRelease(pi_kernel);
4646
}
4747

4848
struct _pi_platform {
49+
std::vector<std::unique_ptr<_pi_device>> devices_;
4950
};
5051

5152
struct _pi_device {

0 commit comments

Comments
 (0)