Skip to content

Commit 9f89247

Browse files
author
Danilo Krummrich
authored
[SYCL][PI/CL] Check device version/extensions rather than platform version/extensions (#6795)
For OpenCL backends currently piProgramCreate() queries the platform version (CL_PLATFORM_VERSION) and platform extensions (CL_PLATFORM_EXTENSIONS) to check whether we're capable of running on top of a particular OpenCL backend. However, there might be platforms where the supported device version is lower than the platform version or where not all devices do support the same extensions and hence some extensions supported by a particular device are not reported in the platform extensions. In particular for CL_PLATFORM_EXTENSIONS the OpenCL specification says: "[...] Each extension that is supported by all devices associated with this platform must be reported here." In 3.4.1 Mixed Version Support the specification also says: "[...] The version returned corresponds to the highest version of the OpenCL specification for which the device is conformant, but is not higher than the platform version." Hence, check for the device version and extensions rather than the platform version and extensions in piProgramCreate(). Signed-off-by: Danilo Krummrich <[email protected]>
1 parent 1f8d90f commit 9f89247

File tree

2 files changed

+208
-47
lines changed

2 files changed

+208
-47
lines changed

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 117 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,71 @@ pi_result piPluginGetLastError(char **message) {
8989
return ErrorMessageCode;
9090
}
9191

92+
static cl_int getPlatformVersion(cl_platform_id plat,
93+
OCLV::OpenCLVersion &version) {
94+
cl_int ret_err = CL_INVALID_VALUE;
95+
96+
size_t platVerSize = 0;
97+
ret_err =
98+
clGetPlatformInfo(plat, CL_PLATFORM_VERSION, 0, nullptr, &platVerSize);
99+
100+
std::string platVer(platVerSize, '\0');
101+
ret_err = clGetPlatformInfo(plat, CL_PLATFORM_VERSION, platVerSize,
102+
platVer.data(), nullptr);
103+
104+
if (ret_err != CL_SUCCESS)
105+
return ret_err;
106+
107+
version = OCLV::OpenCLVersion(platVer);
108+
if (!version.isValid())
109+
return CL_INVALID_PLATFORM;
110+
111+
return ret_err;
112+
}
113+
114+
static cl_int getDeviceVersion(cl_device_id dev, OCLV::OpenCLVersion &version) {
115+
cl_int ret_err = CL_INVALID_VALUE;
116+
117+
size_t devVerSize = 0;
118+
ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, 0, nullptr, &devVerSize);
119+
120+
std::string devVer(devVerSize, '\0');
121+
ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, devVerSize, devVer.data(),
122+
nullptr);
123+
124+
if (ret_err != CL_SUCCESS)
125+
return ret_err;
126+
127+
version = OCLV::OpenCLVersion(devVer);
128+
if (!version.isValid())
129+
return CL_INVALID_DEVICE;
130+
131+
return ret_err;
132+
}
133+
134+
static cl_int checkDeviceExtensions(cl_device_id dev,
135+
const std::vector<std::string> &exts,
136+
bool &supported) {
137+
cl_int ret_err = CL_INVALID_VALUE;
138+
139+
size_t extSize = 0;
140+
ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize);
141+
142+
std::string extStr(extSize, '\0');
143+
ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, extSize, extStr.data(),
144+
nullptr);
145+
146+
if (ret_err != CL_SUCCESS)
147+
return ret_err;
148+
149+
supported = true;
150+
for (const std::string &ext : exts)
151+
if (!(supported = (extStr.find(ext) != std::string::npos)))
152+
break;
153+
154+
return ret_err;
155+
}
156+
92157
// USM helper function to get an extension function pointer
93158
template <const char *FuncName, typename T>
94159
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
@@ -215,17 +280,18 @@ pi_result piDeviceGetInfo(pi_device device, pi_device_info paramName,
215280
case PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES:
216281
return PI_ERROR_INVALID_VALUE;
217282
case PI_DEVICE_INFO_ATOMIC_64: {
218-
size_t extSize;
219-
cl_bool result = clGetDeviceInfo(
220-
cast<cl_device_id>(device), CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize);
221-
std::string extStr(extSize, '\0');
222-
result = clGetDeviceInfo(cast<cl_device_id>(device), CL_DEVICE_EXTENSIONS,
223-
extSize, &extStr.front(), nullptr);
224-
if (extStr.find("cl_khr_int64_base_atomics") == std::string::npos ||
225-
extStr.find("cl_khr_int64_extended_atomics") == std::string::npos)
226-
result = false;
227-
else
228-
result = true;
283+
cl_int ret_err = CL_SUCCESS;
284+
cl_bool result = CL_FALSE;
285+
bool supported = false;
286+
287+
ret_err = checkDeviceExtensions(
288+
cast<cl_device_id>(device),
289+
{"cl_khr_int64_base_atomics", "cl_khr_int64_extended_atomics"},
290+
supported);
291+
if (ret_err != CL_SUCCESS)
292+
return static_cast<pi_result>(ret_err);
293+
294+
result = supported;
229295
std::memcpy(paramValue, &result, sizeof(cl_bool));
230296
return PI_SUCCESS;
231297
}
@@ -402,18 +468,6 @@ pi_result piQueueCreate(pi_context context, pi_device device,
402468

403469
CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);
404470

405-
size_t platVerSize;
406-
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr,
407-
&platVerSize);
408-
409-
CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);
410-
411-
std::string platVer(platVerSize, '\0');
412-
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, platVerSize,
413-
&platVer.front(), nullptr);
414-
415-
CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);
416-
417471
// Check that unexpected bits are not set.
418472
assert(!(properties &
419473
~(PI_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE |
@@ -425,9 +479,12 @@ pi_result piQueueCreate(pi_context context, pi_device device,
425479
CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE | CL_QUEUE_PROFILING_ENABLE |
426480
CL_QUEUE_ON_DEVICE | CL_QUEUE_ON_DEVICE_DEFAULT;
427481

428-
if (platVer.find("OpenCL 1.0") != std::string::npos ||
429-
platVer.find("OpenCL 1.1") != std::string::npos ||
430-
platVer.find("OpenCL 1.2") != std::string::npos) {
482+
OCLV::OpenCLVersion version;
483+
ret_err = getPlatformVersion(curPlatform, version);
484+
485+
CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);
486+
487+
if (version >= OCLV::V2_0) {
431488
*queue = cast<pi_queue>(clCreateCommandQueue(
432489
cast<cl_context>(context), cast<cl_device_id>(device),
433490
cast<cl_command_queue_properties>(properties) & SupportByOpenCL,
@@ -482,38 +539,51 @@ pi_result piProgramCreate(pi_context context, const void *il, size_t length,
482539

483540
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);
484541

485-
size_t devVerSize;
486-
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr,
487-
&devVerSize);
488-
std::string devVer(devVerSize, '\0');
489-
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, devVerSize,
490-
&devVer.front(), nullptr);
542+
OCLV::OpenCLVersion platVer;
543+
ret_err = getPlatformVersion(curPlatform, platVer);
491544

492545
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);
493546

494547
pi_result err = PI_SUCCESS;
495-
if (devVer.find("OpenCL 1.0") == std::string::npos &&
496-
devVer.find("OpenCL 1.1") == std::string::npos &&
497-
devVer.find("OpenCL 1.2") == std::string::npos &&
498-
devVer.find("OpenCL 2.0") == std::string::npos) {
548+
if (platVer >= OCLV::V2_1) {
549+
550+
/* Make sure all devices support CL 2.1 or newer as well. */
551+
for (cl_device_id dev : devicesInCtx) {
552+
OCLV::OpenCLVersion devVer;
553+
554+
ret_err = getDeviceVersion(dev, devVer);
555+
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);
556+
557+
/* If the device does not support CL 2.1 or greater, we need to make sure
558+
* it supports the cl_khr_il_program extension.
559+
*/
560+
if (devVer < OCLV::V2_1) {
561+
bool supported = false;
562+
563+
ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported);
564+
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);
565+
566+
if (!supported)
567+
return cast<pi_result>(CL_INVALID_OPERATION);
568+
}
569+
}
499570
if (res_program != nullptr)
500571
*res_program = cast<pi_program>(clCreateProgramWithIL(
501572
cast<cl_context>(context), il, length, cast<cl_int *>(&err)));
502573
return err;
503574
}
504575

505-
size_t extSize;
506-
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, 0, nullptr,
507-
&extSize);
508-
std::string extStr(extSize, '\0');
509-
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, extSize,
510-
&extStr.front(), nullptr);
576+
/* If none of the devices conform with CL 2.1 or newer make sure they all
577+
* support the cl_khr_il_program extension.
578+
*/
579+
for (cl_device_id dev : devicesInCtx) {
580+
bool supported = false;
511581

512-
if (ret_err != CL_SUCCESS ||
513-
extStr.find("cl_khr_il_program") == std::string::npos) {
514-
if (res_program != nullptr)
515-
*res_program = nullptr;
516-
return cast<pi_result>(CL_INVALID_CONTEXT);
582+
ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported);
583+
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);
584+
585+
if (!supported)
586+
return cast<pi_result>(CL_INVALID_OPERATION);
517587
}
518588

519589
using apiFuncT =

sycl/plugins/opencl/pi_opencl.hpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,102 @@
1717
#ifndef PI_OPENCL_HPP
1818
#define PI_OPENCL_HPP
1919

20+
#include <climits>
21+
#include <regex>
22+
#include <string>
23+
2024
// This version should be incremented for any change made to this file or its
2125
// corresponding .cpp file.
2226
#define _PI_OPENCL_PLUGIN_VERSION 1
2327

2428
#define _PI_OPENCL_PLUGIN_VERSION_STRING \
2529
_PI_PLUGIN_VERSION_STRING(_PI_OPENCL_PLUGIN_VERSION)
2630

31+
namespace OCLV {
32+
class OpenCLVersion {
33+
protected:
34+
unsigned int major;
35+
unsigned int minor;
36+
37+
public:
38+
OpenCLVersion() : major(0), minor(0) {}
39+
40+
OpenCLVersion(unsigned int major, unsigned int minor)
41+
: major(major), minor(minor) {
42+
if (!isValid())
43+
major = minor = 0;
44+
}
45+
46+
OpenCLVersion(const char *version) : OpenCLVersion(std::string(version)) {}
47+
48+
OpenCLVersion(const std::string &version) : major(0), minor(0) {
49+
/* The OpenCL specification defines the full version string as
50+
* 'OpenCL<space><major_version.minor_version><space><platform-specific
51+
* information>' for platforms and as
52+
* 'OpenCL<space><major_version.minor_version><space><vendor-specific
53+
* information>' for devices.
54+
*/
55+
std::regex rx("OpenCL ([0-9]+)\\.([0-9]+)");
56+
std::smatch match;
57+
58+
if (std::regex_search(version, match, rx) && (match.size() == 3)) {
59+
major = strtoul(match[1].str().c_str(), nullptr, 10);
60+
minor = strtoul(match[2].str().c_str(), nullptr, 10);
61+
62+
if (!isValid())
63+
major = minor = 0;
64+
}
65+
}
66+
67+
bool operator==(const OpenCLVersion &v) const {
68+
return major == v.major && minor == v.minor;
69+
}
70+
71+
bool operator!=(const OpenCLVersion &v) const { return !(*this == v); }
72+
73+
bool operator<(const OpenCLVersion &v) const {
74+
if (major == v.major)
75+
return minor < v.minor;
76+
77+
return major < v.major;
78+
}
79+
80+
bool operator>(const OpenCLVersion &v) const { return v < *this; }
81+
82+
bool operator<=(const OpenCLVersion &v) const {
83+
return (*this < v) || (*this == v);
84+
}
85+
86+
bool operator>=(const OpenCLVersion &v) const {
87+
return (*this > v) || (*this == v);
88+
}
89+
90+
bool isValid() const {
91+
switch (major) {
92+
case 0:
93+
return false;
94+
case 1:
95+
case 2:
96+
return minor <= 2;
97+
case UINT_MAX:
98+
return false;
99+
default:
100+
return minor != UINT_MAX;
101+
}
102+
}
103+
104+
int getMajor() const { return major; }
105+
int getMinor() const { return minor; }
106+
};
107+
108+
inline const OpenCLVersion V1_0(1, 0);
109+
inline const OpenCLVersion V1_1(1, 1);
110+
inline const OpenCLVersion V1_2(1, 2);
111+
inline const OpenCLVersion V2_0(2, 0);
112+
inline const OpenCLVersion V2_1(2, 1);
113+
inline const OpenCLVersion V2_2(2, 2);
114+
inline const OpenCLVersion V3_0(3, 0);
115+
116+
} // namespace OCLV
117+
27118
#endif // PI_OPENCL_HPP

0 commit comments

Comments
 (0)