Skip to content

Commit 88e459f

Browse files
authored
[SYCL][CUDA][HIP] Fix PI version reporting (#5509)
Report the actual PI version rather than `0.0`.
1 parent 5645c88 commit 88e459f

File tree

8 files changed

+57
-69
lines changed

8 files changed

+57
-69
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1708,8 +1708,22 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name,
17081708
device->get_reference_count());
17091709
}
17101710
case PI_DEVICE_INFO_VERSION: {
1711+
std::stringstream s;
1712+
int major;
1713+
sycl::detail::pi::assertion(
1714+
cuDeviceGetAttribute(&major,
1715+
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
1716+
device->get()) == CUDA_SUCCESS);
1717+
s << major;
1718+
1719+
int minor;
1720+
sycl::detail::pi::assertion(
1721+
cuDeviceGetAttribute(&minor,
1722+
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
1723+
device->get()) == CUDA_SUCCESS);
1724+
s << "." << minor;
17111725
return getInfo(param_value_size, param_value, param_value_size_ret,
1712-
"PI 0.0");
1726+
s.str().c_str());
17131727
}
17141728
case PI_DEVICE_INFO_OPENCL_C_VERSION: {
17151729
return getInfo(param_value_size, param_value, param_value_size_ret, "");

sycl/plugins/hip/pi_hip.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1663,8 +1663,21 @@ pi_result hip_piDeviceGetInfo(pi_device device, pi_device_info param_name,
16631663
device->get_reference_count());
16641664
}
16651665
case PI_DEVICE_INFO_VERSION: {
1666+
std::stringstream s;
1667+
1668+
hipDeviceProp_t props;
1669+
sycl::detail::pi::assertion(hipGetDeviceProperties(&props, device->get()) ==
1670+
hipSuccess);
1671+
#if defined(__HIP_PLATFORM_NVIDIA__)
1672+
s << props.major << "." << props.minor;
1673+
#elif defined(__HIP_PLATFORM_AMD__)
1674+
s << props.gcnArchName;
1675+
#else
1676+
#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__");
1677+
#endif
1678+
16661679
return getInfo(param_value_size, param_value, param_value_size_ret,
1667-
"PI 0.0");
1680+
s.str().c_str());
16681681
}
16691682
case PI_DEVICE_INFO_OPENCL_C_VERSION: {
16701683
return getInfo(param_value_size, param_value, param_value_size_ret, "");

sycl/source/detail/device_info.hpp

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -205,27 +205,11 @@ struct get_device_info_impl<std::vector<info::fp_config>, Param> {
205205
}
206206
};
207207

208-
// Specialization for OpenCL version, splits the string returned by OpenCL
208+
// Specialization for device version
209209
template <> struct get_device_info_impl<std::string, info::device::version> {
210210
static std::string get(RT::PiDevice dev, const plugin &Plugin) {
211-
std::string result = get_device_info_string(
212-
dev, PiInfoCode<info::device::version>::value, Plugin);
213-
214-
// Extract OpenCL version from the returned string.
215-
// For example, for the string "OpenCL 2.1 (Build 0)"
216-
// return '2.1'.
217-
auto dotPos = result.find('.');
218-
if (dotPos == std::string::npos)
219-
return result;
220-
221-
auto leftPos = result.rfind(' ', dotPos);
222-
if (leftPos == std::string::npos)
223-
leftPos = 0;
224-
else
225-
leftPos++;
226-
227-
auto rightPos = result.find(' ', dotPos);
228-
return result.substr(leftPos, rightPos - leftPos);
211+
return get_device_info_string(dev, PiInfoCode<info::device::version>::value,
212+
Plugin);
229213
}
230214
};
231215

sycl/source/detail/error_handling/error_handling.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel,
5555
bool IsLevelZero = false; // Backend is any OneAPI Level 0 version
5656
auto Backend = Platform.get_backend();
5757
if (Backend == sycl::backend::opencl) {
58-
std::string VersionString = DeviceImpl.get_info<info::device::version>();
58+
std::string VersionString =
59+
DeviceImpl.get_info<info::device::version>().substr(7, 3);
5960
IsOpenCL = true;
6061
IsOpenCLV1x = (VersionString.find("1.") == 0);
6162
IsOpenCLVGE20 =

sycl/test-e2e/Basic/info_ocl_version.cpp

Lines changed: 0 additions & 36 deletions
This file was deleted.

sycl/test-e2e/GroupAlgorithm/SYCL2020/support.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ bool isSupportedDevice(device D) {
1111

1212
if (PlatformName.find("OpenCL") != std::string::npos) {
1313
std::string Version = D.get_info<info::device::version>();
14-
size_t Offset = Version.find("OpenCL");
15-
if (Offset == std::string::npos)
16-
return false;
17-
Version = Version.substr(Offset + 7, 3);
18-
if (Version >= std::string("2.0"))
14+
15+
// Group collectives are mandatory in OpenCL 2.0 but optional in 3.0.
16+
Version = Version.substr(7, 3);
17+
if (Version >= "2.0" && Version < "3.0")
1918
return true;
2019
}
2120

sycl/test-e2e/GroupAlgorithm/support.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@ bool isSupportedDevice(device D) {
1515

1616
if (PlatformName.find("OpenCL") != std::string::npos) {
1717
std::string Version = D.get_info<sycl::info::device::version>();
18-
size_t Offset = Version.find("OpenCL");
19-
if (Offset == std::string::npos)
20-
return false;
21-
Version = Version.substr(Offset + 7, 3);
22-
if (Version >= std::string("2.0"))
18+
19+
// Group collectives are mandatory in OpenCL 2.0 but optional in 3.0.
20+
Version = Version.substr(7, 3);
21+
if (Version >= "2.0" && Version < "3.0")
2322
return true;
2423
}
2524

sycl/test-e2e/SubGroup/helper.hpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,5 +169,19 @@ bool core_sg_supported(const device &Device) {
169169
auto Vec = Device.get_info<info::device::extensions>();
170170
if (std::find(Vec.begin(), Vec.end(), "cl_khr_subgroups") != std::end(Vec))
171171
return true;
172-
return Device.get_info<info::device::version>() >= "2.1";
172+
173+
if (std::find(Vec.begin(), Vec.end(), "cl_intel_subgroups") != std::end(Vec))
174+
return true;
175+
176+
if (Device.get_backend() == sycl::backend::opencl) {
177+
// Extract the numerical version from the version string, OpenCL version
178+
// string have the format "OpenCL <major>.<minor> <vendor specific data>".
179+
std::string ver = Device.get_info<info::device::version>().substr(7, 3);
180+
181+
// cl_khr_subgroups was core in OpenCL 2.1 and 2.2, but went back to
182+
// optional in 3.0
183+
return ver >= "2.1" && ver < "3.0";
184+
}
185+
186+
return false;
173187
}

0 commit comments

Comments
 (0)