Skip to content

Commit d37ea6e

Browse files
[SYCL] Add ext_oneapi_unified_runtime SYCL backend (#7816)
This PR adds a new backend to SYCL: ext_oneapi_unified_runtime. E2E test: intel/llvm-test-suite#1450. Signed-off-by: Sergey V Maslov <[email protected]>
1 parent 0f59628 commit d37ea6e

File tree

14 files changed

+331
-163
lines changed

14 files changed

+331
-163
lines changed

sycl/include/sycl/backend_types.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum class backend : char {
3232
ext_intel_esimd_emulator,
3333
ext_oneapi_hip = 6,
3434
hip __SYCL2020_DEPRECATED("use 'ext_oneapi_hip' instead") = ext_oneapi_hip,
35+
ext_oneapi_unified_runtime = 7,
3536
};
3637

3738
template <backend Backend> class backend_traits;
@@ -63,6 +64,9 @@ inline std::ostream &operator<<(std::ostream &Out, backend be) {
6364
case backend::ext_oneapi_hip:
6465
Out << "ext_oneapi_hip";
6566
break;
67+
case backend::ext_oneapi_unified_runtime:
68+
Out << "ext_oneapi_unified_runtime";
69+
break;
6670
case backend::all:
6771
Out << "all";
6872
}

sycl/include/sycl/detail/pi.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,21 @@ bool trace(TraceLevel level);
6767
#define __SYCL_CUDA_PLUGIN_NAME "pi_cuda.dll"
6868
#define __SYCL_ESIMD_EMULATOR_PLUGIN_NAME "pi_esimd_emulator.dll"
6969
#define __SYCL_HIP_PLUGIN_NAME "libpi_hip.dll"
70+
#define __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME "pi_unified_runtime.dll"
7071
#elif defined(__SYCL_RT_OS_LINUX)
7172
#define __SYCL_OPENCL_PLUGIN_NAME "libpi_opencl.so"
7273
#define __SYCL_LEVEL_ZERO_PLUGIN_NAME "libpi_level_zero.so"
7374
#define __SYCL_CUDA_PLUGIN_NAME "libpi_cuda.so"
7475
#define __SYCL_ESIMD_EMULATOR_PLUGIN_NAME "libpi_esimd_emulator.so"
76+
#define __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME "libpi_unified_runtime.so"
7577
#define __SYCL_HIP_PLUGIN_NAME "libpi_hip.so"
7678
#elif defined(__SYCL_RT_OS_DARWIN)
7779
#define __SYCL_OPENCL_PLUGIN_NAME "libpi_opencl.dylib"
7880
#define __SYCL_LEVEL_ZERO_PLUGIN_NAME "libpi_level_zero.dylib"
7981
#define __SYCL_CUDA_PLUGIN_NAME "libpi_cuda.dylib"
8082
#define __SYCL_ESIMD_EMULATOR_PLUGIN_NAME "libpi_esimd_emulator.dylib"
8183
#define __SYCL_HIP_PLUGIN_NAME "libpi_hip.dylib"
84+
#define __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME "libpi_unified_runtime.dylib"
8285
#else
8386
#error "Unsupported OS"
8487
#endif

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 11 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -198,87 +198,6 @@ template <> ze_result_t zeHostSynchronize(ze_command_queue_handle_t Handle) {
198198
return zeHostSynchronizeImpl(zeCommandQueueSynchronize, Handle);
199199
}
200200

201-
template <typename T, typename Assign>
202-
pi_result getInfoImpl(size_t param_value_size, void *param_value,
203-
size_t *param_value_size_ret, T value, size_t value_size,
204-
Assign &&assign_func) {
205-
206-
if (param_value != nullptr) {
207-
208-
if (param_value_size < value_size) {
209-
return PI_ERROR_INVALID_VALUE;
210-
}
211-
212-
assign_func(param_value, value, value_size);
213-
}
214-
215-
if (param_value_size_ret != nullptr) {
216-
*param_value_size_ret = value_size;
217-
}
218-
219-
return PI_SUCCESS;
220-
}
221-
222-
template <typename T>
223-
pi_result getInfo(size_t param_value_size, void *param_value,
224-
size_t *param_value_size_ret, T value) {
225-
226-
auto assignment = [](void *param_value, T value, size_t value_size) {
227-
(void)value_size;
228-
*static_cast<T *>(param_value) = value;
229-
};
230-
231-
return getInfoImpl(param_value_size, param_value, param_value_size_ret, value,
232-
sizeof(T), assignment);
233-
}
234-
235-
template <typename T>
236-
pi_result getInfoArray(size_t array_length, size_t param_value_size,
237-
void *param_value, size_t *param_value_size_ret,
238-
T *value) {
239-
return getInfoImpl(param_value_size, param_value, param_value_size_ret, value,
240-
array_length * sizeof(T), memcpy);
241-
}
242-
243-
template <typename T, typename RetType>
244-
pi_result getInfoArray(size_t array_length, size_t param_value_size,
245-
void *param_value, size_t *param_value_size_ret,
246-
T *value) {
247-
if (param_value) {
248-
memset(param_value, 0, param_value_size);
249-
for (uint32_t I = 0; I < array_length; I++)
250-
((RetType *)param_value)[I] = (RetType)value[I];
251-
}
252-
if (param_value_size_ret)
253-
*param_value_size_ret = array_length * sizeof(RetType);
254-
return PI_SUCCESS;
255-
}
256-
257-
template <>
258-
pi_result getInfo<const char *>(size_t param_value_size, void *param_value,
259-
size_t *param_value_size_ret,
260-
const char *value) {
261-
return getInfoArray(strlen(value) + 1, param_value_size, param_value,
262-
param_value_size_ret, value);
263-
}
264-
265-
class ReturnHelper {
266-
public:
267-
ReturnHelper(size_t param_value_size, void *param_value,
268-
size_t *param_value_size_ret)
269-
: param_value_size(param_value_size), param_value(param_value),
270-
param_value_size_ret(param_value_size_ret) {}
271-
272-
template <class T> pi_result operator()(const T &t) {
273-
return getInfo(param_value_size, param_value, param_value_size_ret, t);
274-
}
275-
276-
private:
277-
size_t param_value_size;
278-
void *param_value;
279-
size_t *param_value_size_ret;
280-
};
281-
282201
} // anonymous namespace
283202

284203
// SYCL_PI_LEVEL_ZERO_USE_COMPUTE_ENGINE can be set to an integer (>=0) in
@@ -439,11 +358,6 @@ pi_result _pi_context::decrementUnreleasedEventsInPool(pi_event Event) {
439358
return PI_SUCCESS;
440359
}
441360

442-
// Some opencl extensions we know are supported by all Level Zero devices.
443-
constexpr char ZE_SUPPORTED_EXTENSIONS[] =
444-
"cl_khr_il_program cl_khr_subgroups cl_intel_subgroups "
445-
"cl_intel_subgroups_short cl_intel_required_subgroup_size ";
446-
447361
// Forward declarations
448362
static pi_result
449363
enqueueMemCopyHelper(pi_command_type CommandType, pi_queue Queue, void *Dst,
@@ -2307,51 +2221,17 @@ pi_result piPlatformsGet(pi_uint32 NumEntries, pi_platform *Platforms,
23072221
// Returns the requested piece of information about a Level Zero platform.
//
// PI_PLATFORM_INFO_NAME is answered here so that this plugin's platform can be
// told apart from the one reported through Unified Runtime; every other query
// is forwarded to pi2ur::piPlatformGetInfo.
//
// Returns PI_ERROR_INVALID_PLATFORM (via PI_ASSERT) when Platform is null.
pi_result piPlatformGetInfo(pi_platform Platform, pi_platform_info ParamName,
                            size_t ParamValueSize, void *ParamValue,
                            size_t *ParamValueSizeRet) {
  // Platform is dereferenced unconditionally by the trace output below, so a
  // null handle has to be rejected before it is touched.
  PI_ASSERT(Platform, PI_ERROR_INVALID_PLATFORM);

  zePrint("==========================\n");
  zePrint("SYCL over Level-Zero %s\n", Platform->ZeDriverVersion.c_str());
  zePrint("==========================\n");

  // To distinguish this L0 platform from Unified Runtime one.
  if (ParamName == PI_PLATFORM_INFO_NAME) {
    ReturnHelper ReturnValue(ParamValueSize, ParamValue, ParamValueSizeRet);
    return ReturnValue("Intel(R) Level-Zero");
  }
  return pi2ur::piPlatformGetInfo(Platform, ParamName, ParamValueSize,
                                  ParamValue, ParamValueSizeRet);
}
23562236

23572237
pi_result piextPlatformGetNativeHandle(pi_platform Platform,
@@ -3068,10 +2948,9 @@ pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
30682948
case PI_DEVICE_INFO_SUB_GROUP_SIZES_INTEL: {
30692949
// ze_device_compute_properties.subGroupSizes is in uint32_t whereas the
30702950
// expected return is size_t datatype. size_t can be 8 bytes of data.
3071-
return getInfoArray<uint32_t, size_t>(
3072-
Device->ZeDeviceComputeProperties->numSubGroupSizes, ParamValueSize,
3073-
ParamValue, ParamValueSizeRet,
3074-
Device->ZeDeviceComputeProperties->subGroupSizes);
2951+
return ReturnValue.template operator()<size_t>(
2952+
Device->ZeDeviceComputeProperties->subGroupSizes,
2953+
Device->ZeDeviceComputeProperties->numSubGroupSizes);
30752954
}
30762955
case PI_DEVICE_INFO_IL_VERSION: {
30772956
// Set to a space separated list of IL version strings of the form
@@ -3463,8 +3342,7 @@ pi_result piContextGetInfo(pi_context Context, pi_context_info ParamName,
34633342
ReturnHelper ReturnValue(ParamValueSize, ParamValue, ParamValueSizeRet);
34643343
switch (ParamName) {
34653344
case PI_CONTEXT_INFO_DEVICES:
3466-
return getInfoArray(Context->Devices.size(), ParamValueSize, ParamValue,
3467-
ParamValueSizeRet, &Context->Devices[0]);
3345+
return ReturnValue(&Context->Devices[0], Context->Devices.size());
34683346
case PI_CONTEXT_INFO_NUM_DEVICES:
34693347
return ReturnValue(pi_uint32(Context->Devices.size()));
34703348
case PI_CONTEXT_INFO_REFERENCE_COUNT:
@@ -5375,7 +5253,8 @@ pi_result piKernelGetGroupInfo(pi_kernel Kernel, pi_device Device,
53755253
return ReturnValue(WorkSize);
53765254
}
53775255
case PI_KERNEL_GROUP_INFO_WORK_GROUP_SIZE: {
5378-
// As of right now, L0 is missing API to query kernel and device specific max work group size.
5256+
// As of right now, L0 is missing API to query kernel and device specific
5257+
// max work group size.
53795258
return ReturnValue(
53805259
pi_uint64{Device->ZeDeviceComputeProperties->maxTotalGroupSize});
53815260
}

sycl/plugins/unified_runtime/pi2ur.hpp

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@
1515

1616
// Map of UR error codes to PI error codes
1717
static pi_result ur2piResult(zer_result_t urResult) {
18-
19-
// TODO: replace "global lifetime" objects with a non-trivial d'tor with
20-
// either pointers to such objects (which would be allocated and dealocated
21-
// during init and teardown) or objects with trivial d'tor.
22-
// E.g. for this case we could have an std::array with sorted values.
23-
//
24-
static std::unordered_map<zer_result_t, pi_result> ErrorMapping = {
18+
std::unordered_map<zer_result_t, pi_result> ErrorMapping = {
2519
{ZER_RESULT_SUCCESS, PI_SUCCESS},
2620
{ZER_RESULT_ERROR_UNKNOWN, PI_ERROR_UNKNOWN},
2721
{ZER_RESULT_ERROR_DEVICE_LOST, PI_ERROR_DEVICE_NOT_FOUND},
@@ -50,6 +44,24 @@ static pi_result ur2piResult(zer_result_t urResult) {
5044
if (auto Result = urCall) \
5145
return ur2piResult(Result);
5246

47+
// A version of return helper that returns pi_result and not zer_result_t
48+
class ReturnHelper : public UrReturnHelper {
49+
public:
50+
using UrReturnHelper::UrReturnHelper;
51+
52+
template <class T> pi_result operator()(const T &t) {
53+
return ur2piResult(UrReturnHelper::operator()(t));
54+
}
55+
// Array return value
56+
template <class T> pi_result operator()(const T *t, size_t s) {
57+
return ur2piResult(UrReturnHelper::operator()(t, s));
58+
}
59+
// Array return value where element type is differrent from T
60+
template <class RetType, class T> pi_result operator()(const T *t, size_t s) {
61+
return ur2piResult(UrReturnHelper::operator()<RetType>(t, s));
62+
}
63+
};
64+
5365
namespace pi2ur {
5466
inline pi_result piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
5567
pi_uint32 *num_platforms) {
@@ -66,14 +78,31 @@ inline pi_result piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
6678
}
6779

6880
// Translates a PI platform-info query into the matching Unified Runtime query
// and forwards it to zerPlatformGetInfo.
//
// platform          - PI platform handle; reinterpreted as a UR platform handle.
// ParamName         - PI query selector; must be one of the keys of InfoMapping.
// ParamValueSize    - size passed to UR (by pointer) as the in/out size.
// ParamValue        - destination buffer handed to UR unchanged.
// ParamValueSizeRet - if non-null, receives the size reported back by UR.
//
// Returns PI_ERROR_UNKNOWN for a ParamName that has no UR counterpart, the
// translated UR error if zerPlatformGetInfo fails, and PI_SUCCESS otherwise.
inline pi_result piPlatformGetInfo(pi_platform platform,
                                   pi_platform_info ParamName,
                                   size_t ParamValueSize, void *ParamValue,
                                   size_t *ParamValueSizeRet) {

  // NOTE(review): the extensions query used to be mapped to
  // ZER_PLATFORM_INFO_NAME, which made it return the platform name. This
  // assumes ZER_PLATFORM_INFO_EXTENSIONS is declared next to the other
  // zer_platform_info_t enumerators - confirm against the UR header.
  static const std::unordered_map<pi_platform_info, zer_platform_info_t>
      InfoMapping = {
          {PI_PLATFORM_INFO_EXTENSIONS, ZER_PLATFORM_INFO_EXTENSIONS},
          {PI_PLATFORM_INFO_NAME, ZER_PLATFORM_INFO_NAME},
          {PI_PLATFORM_INFO_PROFILE, ZER_PLATFORM_INFO_PROFILE},
          {PI_PLATFORM_INFO_VENDOR, ZER_PLATFORM_INFO_VENDOR_NAME},
          {PI_PLATFORM_INFO_VERSION, ZER_PLATFORM_INFO_VERSION},
      };

  const auto InfoType = InfoMapping.find(ParamName);
  if (InfoType == InfoMapping.end()) {
    return PI_ERROR_UNKNOWN;
  }

  // UR takes the size as a single in/out argument, whereas PI splits it into
  // an input size and an optional output size.
  size_t SizeInOut = ParamValueSize;
  auto hPlatform = reinterpret_cast<zer_platform_handle_t>(platform);
  HANDLE_ERRORS(
      zerPlatformGetInfo(hPlatform, InfoType->second, &SizeInOut, ParamValue));
  if (ParamValueSizeRet) {
    *ParamValueSizeRet = SizeInOut;
  }
  return PI_SUCCESS;
}
79108
} // namespace pi2ur

sycl/plugins/unified_runtime/pi_unified_runtime.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,79 @@
66
//
77
//===------------------------------------------------------------------===//
88

9+
#include <cstring>
10+
911
#include <pi2ur.hpp>
12+
#include <pi_unified_runtime.hpp>
13+
14+
// Stub function to where all not yet supported PI API are bound
15+
static void DieUnsupported() {
16+
die("Unified Runtime: functionality is not supported");
17+
}
1018

19+
// All PI API interfaces are C interfaces
1120
extern "C" {
1221
// PI entry point: forwards the platform enumeration request to the pi2ur
// adapter and returns its result unchanged.
__SYCL_EXPORT pi_result piPlatformsGet(pi_uint32 num_entries,
                                       pi_platform *platforms,
                                       pi_uint32 *num_platforms) {
  const pi_result Result =
      pi2ur::piPlatformsGet(num_entries, platforms, num_platforms);
  return Result;
}
26+
27+
// PI entry point: forwards the platform-info query to the pi2ur adapter and
// returns its result unchanged.
__SYCL_EXPORT pi_result piPlatformGetInfo(pi_platform Platform,
                                          pi_platform_info ParamName,
                                          size_t ParamValueSize,
                                          void *ParamValue,
                                          size_t *ParamValueSizeRet) {
  const pi_result Result = pi2ur::piPlatformGetInfo(
      Platform, ParamName, ParamValueSize, ParamValue, ParamValueSizeRet);
  return Result;
}
35+
36+
// PI entry point, currently a stub: it reports zero devices so that a minimal
// SYCL test can run. Only NumDevices is written; the other arguments are not
// consulted.
__SYCL_EXPORT pi_result piDevicesGet(pi_platform Platform,
                                     pi_device_type DeviceType,
                                     pi_uint32 NumEntries, pi_device *Devices,
                                     pi_uint32 *NumDevices) {
  (void)Platform;
  (void)DeviceType;
  (void)NumEntries;
  (void)Devices;

  if (NumDevices != nullptr)
    *NumDevices = 0;

  return PI_SUCCESS;
}
46+
47+
// This interface is not in Unified Runtime currently
48+
__SYCL_EXPORT pi_result piTearDown(void *) { return PI_SUCCESS; }
49+
50+
// This interface is not in Unified Runtime currently.
//
// Fills in the plugin descriptor handed over by the caller:
//  * checks the PI version in PluginInit against the version this plugin
//    supports,
//  * copies this plugin's version string into PluginInit->PluginVersion,
//  * binds every PI API slot to DieUnsupported and then rebinds the entry
//    points this plugin actually implements.
//
// Returns PI_ERROR_INVALID_VALUE (via PI_ASSERT) when PluginInit is null or
// the version string does not fit into PluginVersion, PI_SUCCESS otherwise.
__SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
  PI_ASSERT(PluginInit, PI_ERROR_INVALID_VALUE);

  const char SupportedVersion[] = _PI_UNIFIED_RUNTIME_PLUGIN_VERSION_STRING;

  // Check that the major version matches in PiVersion and SupportedVersion
  _PI_PLUGIN_VERSION_CHECK(PluginInit->PiVersion, SupportedVersion);

  // TODO: handle versioning/targets properly.
  size_t PluginVersionSize = sizeof(PluginInit->PluginVersion);

  // The strict '<' leaves room for the terminator, so the strncpy below
  // always produces a NUL-terminated string.
  PI_ASSERT(strlen(_PI_UNIFIED_RUNTIME_PLUGIN_VERSION_STRING) <
                PluginVersionSize,
            PI_ERROR_INVALID_VALUE);

  strncpy(PluginInit->PluginVersion, SupportedVersion, PluginVersionSize);

  // Bind interfaces that are already supported and "die" for unsupported ones
#define _PI_API(api)                                                           \
  (PluginInit->PiFunctionTable).api = (decltype(&::api))(&DieUnsupported);
#include <sycl/detail/pi.def>

#define _PI_API(api)                                                           \
  (PluginInit->PiFunctionTable).api = (decltype(&::api))(&api);

  _PI_API(piPlatformsGet)
  _PI_API(piPlatformGetInfo)
  _PI_API(piDevicesGet)
  _PI_API(piTearDown)

  // Keep the helper macro from leaking into the rest of the translation unit.
#undef _PI_API

  return PI_SUCCESS;
}
83+
1784
} // extern "C"

0 commit comments

Comments
 (0)