Skip to content

[SYCL][PI/CL] Check device version/extensions rather than platform version/extensions #6795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 117 additions & 47 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,71 @@ pi_result piPluginGetLastError(char **message) {
return ErrorMessageCode;
}

static cl_int getPlatformVersion(cl_platform_id plat,
OCLV::OpenCLVersion &version) {
cl_int ret_err = CL_INVALID_VALUE;

size_t platVerSize = 0;
ret_err =
clGetPlatformInfo(plat, CL_PLATFORM_VERSION, 0, nullptr, &platVerSize);

std::string platVer(platVerSize, '\0');
ret_err = clGetPlatformInfo(plat, CL_PLATFORM_VERSION, platVerSize,
platVer.data(), nullptr);

if (ret_err != CL_SUCCESS)
return ret_err;

version = OCLV::OpenCLVersion(platVer);
if (!version.isValid())
return CL_INVALID_PLATFORM;

return ret_err;
}

static cl_int getDeviceVersion(cl_device_id dev, OCLV::OpenCLVersion &version) {
cl_int ret_err = CL_INVALID_VALUE;

size_t devVerSize = 0;
ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, 0, nullptr, &devVerSize);

std::string devVer(devVerSize, '\0');
ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, devVerSize, devVer.data(),
nullptr);

if (ret_err != CL_SUCCESS)
return ret_err;

version = OCLV::OpenCLVersion(devVer);
if (!version.isValid())
return CL_INVALID_DEVICE;

return ret_err;
}

static cl_int checkDeviceExtensions(cl_device_id dev,
const std::vector<std::string> &exts,
bool &supported) {
cl_int ret_err = CL_INVALID_VALUE;

size_t extSize = 0;
ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize);

std::string extStr(extSize, '\0');
ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, extSize, extStr.data(),
nullptr);

if (ret_err != CL_SUCCESS)
return ret_err;

supported = true;
for (const std::string &ext : exts)
if (!(supported = (extStr.find(ext) != std::string::npos)))
break;

return ret_err;
}

// USM helper function to get an extension function pointer
template <const char *FuncName, typename T>
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
Expand Down Expand Up @@ -215,17 +280,18 @@ pi_result piDeviceGetInfo(pi_device device, pi_device_info paramName,
case PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES:
return PI_ERROR_INVALID_VALUE;
case PI_DEVICE_INFO_ATOMIC_64: {
size_t extSize;
cl_bool result = clGetDeviceInfo(
cast<cl_device_id>(device), CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize);
std::string extStr(extSize, '\0');
result = clGetDeviceInfo(cast<cl_device_id>(device), CL_DEVICE_EXTENSIONS,
extSize, &extStr.front(), nullptr);
if (extStr.find("cl_khr_int64_base_atomics") == std::string::npos ||
extStr.find("cl_khr_int64_extended_atomics") == std::string::npos)
result = false;
else
result = true;
cl_int ret_err = CL_SUCCESS;
cl_bool result = CL_FALSE;
bool supported = false;

ret_err = checkDeviceExtensions(
cast<cl_device_id>(device),
{"cl_khr_int64_base_atomics", "cl_khr_int64_extended_atomics"},
supported);
if (ret_err != CL_SUCCESS)
return static_cast<pi_result>(ret_err);

result = supported;
std::memcpy(paramValue, &result, sizeof(cl_bool));
return PI_SUCCESS;
}
Expand Down Expand Up @@ -402,18 +468,6 @@ pi_result piQueueCreate(pi_context context, pi_device device,

CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);

size_t platVerSize;
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr,
&platVerSize);

CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);

std::string platVer(platVerSize, '\0');
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, platVerSize,
&platVer.front(), nullptr);

CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);

// Check that unexpected bits are not set.
assert(!(properties &
~(PI_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE |
Expand All @@ -425,9 +479,12 @@ pi_result piQueueCreate(pi_context context, pi_device device,
CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE | CL_QUEUE_PROFILING_ENABLE |
CL_QUEUE_ON_DEVICE | CL_QUEUE_ON_DEVICE_DEFAULT;

if (platVer.find("OpenCL 1.0") != std::string::npos ||
platVer.find("OpenCL 1.1") != std::string::npos ||
platVer.find("OpenCL 1.2") != std::string::npos) {
OCLV::OpenCLVersion version;
ret_err = getPlatformVersion(curPlatform, version);

CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);

if (version >= OCLV::V2_0) {
*queue = cast<pi_queue>(clCreateCommandQueue(
cast<cl_context>(context), cast<cl_device_id>(device),
cast<cl_command_queue_properties>(properties) & SupportByOpenCL,
Expand Down Expand Up @@ -482,38 +539,51 @@ pi_result piProgramCreate(pi_context context, const void *il, size_t length,

CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

size_t devVerSize;
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr,
&devVerSize);
std::string devVer(devVerSize, '\0');
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, devVerSize,
&devVer.front(), nullptr);
OCLV::OpenCLVersion platVer;
ret_err = getPlatformVersion(curPlatform, platVer);

CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

pi_result err = PI_SUCCESS;
if (devVer.find("OpenCL 1.0") == std::string::npos &&
devVer.find("OpenCL 1.1") == std::string::npos &&
devVer.find("OpenCL 1.2") == std::string::npos &&
devVer.find("OpenCL 2.0") == std::string::npos) {
if (platVer >= OCLV::V2_1) {

/* Make sure all devices support CL 2.1 or newer as well. */
for (cl_device_id dev : devicesInCtx) {
OCLV::OpenCLVersion devVer;

ret_err = getDeviceVersion(dev, devVer);
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

/* If the device does not support CL 2.1 or greater, we need to make sure
* it supports the cl_khr_il_program extension.
*/
if (devVer < OCLV::V2_1) {
bool supported = false;

ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported);
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

if (!supported)
return cast<pi_result>(CL_INVALID_OPERATION);
}
}
if (res_program != nullptr)
*res_program = cast<pi_program>(clCreateProgramWithIL(
cast<cl_context>(context), il, length, cast<cl_int *>(&err)));
return err;
}

size_t extSize;
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, 0, nullptr,
&extSize);
std::string extStr(extSize, '\0');
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, extSize,
&extStr.front(), nullptr);
/* If none of the devices conform with CL 2.1 or newer make sure they all
* support the cl_khr_il_program extension.
*/
for (cl_device_id dev : devicesInCtx) {
bool supported = false;

if (ret_err != CL_SUCCESS ||
extStr.find("cl_khr_il_program") == std::string::npos) {
if (res_program != nullptr)
*res_program = nullptr;
return cast<pi_result>(CL_INVALID_CONTEXT);
ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported);
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

if (!supported)
return cast<pi_result>(CL_INVALID_OPERATION);
}

using apiFuncT =
Expand Down
91 changes: 91 additions & 0 deletions sycl/plugins/opencl/pi_opencl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,102 @@
#ifndef PI_OPENCL_HPP
#define PI_OPENCL_HPP

#include <climits>
#include <regex>
#include <string>

// This version should be incremented for any change made to this file or its
// corresponding .cpp file.
#define _PI_OPENCL_PLUGIN_VERSION 1

#define _PI_OPENCL_PLUGIN_VERSION_STRING \
_PI_PLUGIN_VERSION_STRING(_PI_OPENCL_PLUGIN_VERSION)

namespace OCLV {
class OpenCLVersion {
protected:
unsigned int major;
unsigned int minor;

public:
OpenCLVersion() : major(0), minor(0) {}

OpenCLVersion(unsigned int major, unsigned int minor)
: major(major), minor(minor) {
if (!isValid())
major = minor = 0;
}

OpenCLVersion(const char *version) : OpenCLVersion(std::string(version)) {}

OpenCLVersion(const std::string &version) : major(0), minor(0) {
/* The OpenCL specification defines the full version string as
* 'OpenCL<space><major_version.minor_version><space><platform-specific
* information>' for platforms and as
* 'OpenCL<space><major_version.minor_version><space><vendor-specific
* information>' for devices.
*/
std::regex rx("OpenCL ([0-9]+)\\.([0-9]+)");
std::smatch match;

if (std::regex_search(version, match, rx) && (match.size() == 3)) {
major = strtoul(match[1].str().c_str(), nullptr, 10);
minor = strtoul(match[2].str().c_str(), nullptr, 10);

if (!isValid())
major = minor = 0;
}
}

bool operator==(const OpenCLVersion &v) const {
return major == v.major && minor == v.minor;
}

bool operator!=(const OpenCLVersion &v) const { return !(*this == v); }

bool operator<(const OpenCLVersion &v) const {
if (major == v.major)
return minor < v.minor;

return major < v.major;
}

bool operator>(const OpenCLVersion &v) const { return v < *this; }

bool operator<=(const OpenCLVersion &v) const {
return (*this < v) || (*this == v);
}

bool operator>=(const OpenCLVersion &v) const {
return (*this > v) || (*this == v);
}

bool isValid() const {
switch (major) {
case 0:
return false;
case 1:
case 2:
return minor <= 2;
case UINT_MAX:
return false;
default:
return minor != UINT_MAX;
}
}

int getMajor() const { return major; }
int getMinor() const { return minor; }
};

inline const OpenCLVersion V1_0(1, 0);
inline const OpenCLVersion V1_1(1, 1);
inline const OpenCLVersion V1_2(1, 2);
inline const OpenCLVersion V2_0(2, 0);
inline const OpenCLVersion V2_1(2, 1);
inline const OpenCLVersion V2_2(2, 2);
inline const OpenCLVersion V3_0(3, 0);

} // namespace OCLV

#endif // PI_OPENCL_HPP