Skip to content

Commit 1757da8

Browse files
committed
[SYCL][CUDA] Fix for default selection of CUDA devices
Signed-off-by: Ruyman Reyes <[email protected]>
1 parent 8445ee8 commit 1757da8

File tree

4 files changed

+11
-31
lines changed

4 files changed

+11
-31
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ pi_result cuda_piPlatformGetInfo(pi_platform platform,
684684
switch (param_name) {
685685
case PI_PLATFORM_INFO_NAME:
686686
return getInfo(param_value_size, param_value, param_value_size_ret,
687-
"NVIDIA CUDA");
687+
"NVIDIA CUDA BACKEND");
688688
case PI_PLATFORM_INFO_VENDOR:
689689
return getInfo(param_value_size, param_value, param_value_size_ret,
690690
"NVIDIA Corporation");

sycl/source/detail/platform_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ class platform_impl {
7474
bool is_host() const { return MHostPlatform; };
7575

7676
bool is_cuda() const {
77-
const string_class CUDA_PLATFORM_STRING = "NVIDIA CUDA";
77+
const string_class CUDA_PLATFORM_STRING = "NVIDIA CUDA BACKEND";
7878
const string_class PlatformName =
79-
get_platform_info<string_class, info::platform::name>::get(MPlatform,
80-
getPlugin());
79+
get_platform_info<string_class, info::platform::version>::get(
80+
MPlatform, getPlugin());
8181
return PlatformName == CUDA_PLATFORM_STRING;
8282
}
8383

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,29 +84,7 @@ static RT::PiProgram createBinaryProgram(const ContextImplPtr Context,
8484

8585
RT::PiProgram Program;
8686

87-
bool IsCUDA = false;
88-
89-
// TODO: Implement `piProgramCreateWithBinary` to not require extra logic for
90-
// the CUDA backend.
91-
#if USE_PI_CUDA
92-
// All devices in a context are from the same platform.
93-
RT::PiDevice Device = getFirstDevice(Context);
94-
RT::PiPlatform Platform = nullptr;
95-
Plugin.call<PiApiKind::piDeviceGetInfo>(Device, PI_DEVICE_INFO_PLATFORM, sizeof(Platform),
96-
&Platform, nullptr);
97-
size_t PlatformNameSize = 0u;
98-
Plugin.call<PiApiKind::piPlatformGetInfo>(Platform, PI_PLATFORM_INFO_NAME, 0u, nullptr,
99-
&PlatformNameSize);
100-
std::vector<char> PlatformName(PlatformNameSize, '\0');
101-
Plugin.call<PiApiKind::piPlatformGetInfo>(Platform, PI_PLATFORM_INFO_NAME,
102-
PlatformName.size(), PlatformName.data(), nullptr);
103-
if (PlatformNameSize > 0u &&
104-
std::strncmp(PlatformName.data(), "NVIDIA CUDA", PlatformNameSize) == 0) {
105-
IsCUDA = true;
106-
}
107-
#endif // USE_PI_CUDA
108-
109-
if (IsCUDA) {
87+
if (Context->getPlatformImpl()->is_cuda()) {
11088
// TODO: Reemplace CreateWithSource with CreateWithBinary in CUDA backend
11189
const char *SignedData = reinterpret_cast<const char *>(Data);
11290
Plugin.call<PiApiKind::piclProgramCreateWithSource>(Context->getHandleRef(), 1 /*one binary*/, &SignedData,

sycl/source/device_selector.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,16 @@ int default_selector::operator()(const device &dev) const {
4343
const platform platform = dev.get_info<info::device::platform>();
4444
const std::string platformVersion =
4545
platform.get_info<info::platform::version>();;
46+
const bool HasCudaString =
47+
platformVersion.find("CUDA") != std::string::npos;
48+
const bool HasOpenCLString =
49+
platformVersion.find("OpenCL") != std::string::npos;
4650
// If using PI_CUDA, don't accept a non-CUDA device
47-
if (platformVersion.find("CUDA") == std::string::npos &&
48-
backend == "PI_CUDA") {
51+
if (HasCudaString && HasOpenCLString && backend == "PI_CUDA") {
4952
return -1;
5053
}
5154
// If using PI_OPENCL, don't accept a non-OpenCL device
52-
if (platformVersion.find("OpenCL") == std::string::npos &&
53-
backend == "PI_OPENCL") {
55+
if (HasCudaString && !HasOpenCLString && backend == "PI_OPENCL") {
5456
return -1;
5557
}
5658
}

0 commit comments

Comments
 (0)