Skip to content

Commit ccac682

Browse files
committed
[SYCL] Add specialization constant API to the SYCL RT Plugin Interface.
New PI API added: pi_result piProgramSetSpecializationConstant(pi_program prog, pi_uint32 spec_id, size_t spec_size, const void *spec_value); Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent abb37d7 commit ccac682

File tree

3 files changed

+54
-9
lines changed

3 files changed

+54
-9
lines changed

sycl/include/CL/sycl/detail/pi.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ _PI_API(piProgramLink)
5757
_PI_API(piProgramGetBuildInfo)
5858
_PI_API(piProgramRetain)
5959
_PI_API(piProgramRelease)
60+
_PI_API(piextProgramSetSpecializationConstant)
6061
// Kernel
6162
_PI_API(piKernelCreate)
6263
_PI_API(piKernelSetArg)

sycl/include/CL/sycl/detail/pi.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ pi_result piDevicePartition(pi_device device,
743743
///
744744
pi_result piextDeviceSelectBinary(pi_device device, pi_device_binary *binaries,
745745
pi_uint32 num_binaries,
746-
pi_device_binary *selected_binary);
746+
pi_uint32 *selected_binary_ind);
747747

748748
/// Retrieves a device function pointer to a user-defined function
749749
/// \arg \c function_name. \arg \c function_pointer_ret is set to 0 if query
@@ -879,6 +879,17 @@ pi_result piProgramRetain(pi_program program);
879879

880880
pi_result piProgramRelease(pi_program program);
881881

882+
/// Sets a specialization constant to a specific value.
883+
///
884+
/// \param prog the program object which will use the value
885+
/// \param spec_id integer ID of the constant
886+
/// \param spec_size size of the value
887+
/// \param spec_value bytes of the value
888+
pi_result piextProgramSetSpecializationConstant(pi_program prog,
889+
pi_uint32 spec_id,
890+
size_t spec_size,
891+
const void *spec_value);
892+
882893
//
883894
// Kernel
884895
//

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
98
/// \defgroup sycl_pi_ocl OpenCL Plugin
109
/// \ingroup sycl_pi
1110

@@ -20,6 +19,8 @@
2019

2120
#include <cassert>
2221
#include <cstring>
22+
#include <iostream>
23+
#include <limits>
2324
#include <map>
2425
#include <string>
2526
#include <vector>
@@ -58,6 +59,8 @@ CONSTFIX char clEnqueueMemcpyName[] = "clEnqueueMemcpyINTEL";
5859
CONSTFIX char clEnqueueMigrateMemName[] = "clEnqueueMigrateMemINTEL";
5960
CONSTFIX char clEnqueueMemAdviseName[] = "clEnqueueMemAdviseINTEL";
6061
CONSTFIX char clGetMemAllocInfoName[] = "clGetMemAllocInfoINTEL";
62+
CONSTFIX char clSetProgramSpecializationConstantName[] =
63+
"clSetProgramSpecializationConstant";
6164

6265
#undef CONSTFIX
6366

@@ -215,7 +218,7 @@ pi_result OCL(piDevicesGet)(pi_platform platform, pi_device_type device_type,
215218
pi_result OCL(piextDeviceSelectBinary)(pi_device device,
216219
pi_device_binary *images,
217220
pi_uint32 num_images,
218-
pi_device_binary *selected_image) {
221+
pi_uint32 *selected_image_ind) {
219222

220223
// TODO: this is a bare-bones implementation for choosing a device image
221224
// that would be compatible with the targeted device. An AOT-compiled
@@ -234,11 +237,12 @@ pi_result OCL(piextDeviceSelectBinary)(pi_device device,
234237
const char *image_target = nullptr;
235238
// Get the type of the device
236239
cl_device_type device_type;
240+
constexpr pi_uint32 invalid_ind = std::numeric_limits<pi_uint32>::max();
237241
cl_int ret_err =
238242
clGetDeviceInfo(cast<cl_device_id>(device), CL_DEVICE_TYPE,
239243
sizeof(cl_device_type), &device_type, nullptr);
240244
if (ret_err != CL_SUCCESS) {
241-
*selected_image = nullptr;
245+
*selected_image_ind = invalid_ind;
242246
return cast<pi_result>(ret_err);
243247
}
244248

@@ -266,18 +270,18 @@ pi_result OCL(piextDeviceSelectBinary)(pi_device device,
266270
}
267271

268272
// Find the appropriate device image, fallback to spirv if not found
269-
pi_device_binary fallback = nullptr;
270-
for (size_t i = 0; i < num_images; ++i) {
273+
pi_uint32 fallback = invalid_ind;
274+
for (pi_uint32 i = 0; i < num_images; ++i) {
271275
if (strcmp(images[i]->DeviceTargetSpec, image_target) == 0) {
272-
*selected_image = images[i];
276+
*selected_image_ind = i;
273277
return PI_SUCCESS;
274278
}
275279
if (strcmp(images[i]->DeviceTargetSpec, PI_DEVICE_BINARY_TARGET_SPIRV64) ==
276280
0)
277-
fallback = images[i];
281+
fallback = i;
278282
}
279283
// Points to a spirv image, if such indeed was found
280-
if ((*selected_image = fallback))
284+
if ((*selected_image_ind = fallback) != invalid_ind)
281285
return PI_SUCCESS;
282286
// No image can be loaded for the given device
283287
return PI_INVALID_BINARY;
@@ -1013,6 +1017,32 @@ pi_result OCL(piKernelSetExecInfo)(pi_kernel kernel,
10131017
}
10141018
}
10151019

1020+
typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
1021+
cl_program program, cl_uint spec_id, size_t spec_size,
1022+
const void *spec_value);
1023+
1024+
static pi_result OCL(piextProgramSetSpecializationConstantImpl)(
1025+
pi_program prog, unsigned int spec_id, size_t spec_size,
1026+
const void *spec_value) {
1027+
cl_program ClProg = cast<cl_program>(prog);
1028+
cl_context Ctx = nullptr;
1029+
size_t RetSize = 0;
1030+
cl_int Res =
1031+
clGetProgramInfo(ClProg, CL_PROGRAM_CONTEXT, sizeof(Ctx), &Ctx, &RetSize);
1032+
1033+
if (Res != CL_SUCCESS)
1034+
return cast<pi_result>(Res);
1035+
1036+
clSetProgramSpecializationConstant_fn F = nullptr;
1037+
Res = getExtFuncFromContext<clSetProgramSpecializationConstantName,
1038+
decltype(F)>(cast<pi_context>(Ctx), &F);
1039+
1040+
if (!F || Res != CL_SUCCESS)
1041+
return PI_INVALID_OPERATION;
1042+
Res = F(ClProg, spec_id, spec_size, spec_value);
1043+
return cast<pi_result>(Res);
1044+
}
1045+
10161046
pi_result piPluginInit(pi_plugin *PluginInit) {
10171047
int CompareVersions = strcmp(PluginInit->PiVersion, SupportedVersion);
10181048
if (CompareVersions < 0) {
@@ -1070,6 +1100,9 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
10701100
_PI_CL(piProgramGetBuildInfo, clGetProgramBuildInfo)
10711101
_PI_CL(piProgramRetain, clRetainProgram)
10721102
_PI_CL(piProgramRelease, clReleaseProgram)
1103+
_PI_CL(piextProgramSetSpecializationConstant,
1104+
OCL(piextProgramSetSpecializationConstantImpl))
1105+
10731106
// Kernel
10741107
_PI_CL(piKernelCreate, OCL(piKernelCreate))
10751108
_PI_CL(piKernelSetArg, clSetKernelArg)

0 commit comments

Comments
 (0)