5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
-
9
8
// / \defgroup sycl_pi_ocl OpenCL Plugin
10
9
// / \ingroup sycl_pi
11
10
20
19
21
20
#include < cassert>
22
21
#include < cstring>
22
+ #include < iostream>
23
+ #include < limits>
23
24
#include < map>
24
25
#include < string>
25
26
#include < vector>
@@ -58,6 +59,8 @@ CONSTFIX char clEnqueueMemcpyName[] = "clEnqueueMemcpyINTEL";
58
59
CONSTFIX char clEnqueueMigrateMemName[] = " clEnqueueMigrateMemINTEL" ;
59
60
CONSTFIX char clEnqueueMemAdviseName[] = " clEnqueueMemAdviseINTEL" ;
60
61
CONSTFIX char clGetMemAllocInfoName[] = " clGetMemAllocInfoINTEL" ;
62
+ CONSTFIX char clSetProgramSpecializationConstantName[] =
63
+ " clSetProgramSpecializationConstant" ;
61
64
62
65
#undef CONSTFIX
63
66
@@ -215,7 +218,7 @@ pi_result OCL(piDevicesGet)(pi_platform platform, pi_device_type device_type,
215
218
pi_result OCL (piextDeviceSelectBinary)(pi_device device,
216
219
pi_device_binary *images,
217
220
pi_uint32 num_images,
218
- pi_device_binary *selected_image ) {
221
+ pi_uint32 *selected_image_ind ) {
219
222
220
223
// TODO: this is a bare-bones implementation for choosing a device image
221
224
// that would be compatible with the targeted device. An AOT-compiled
@@ -234,11 +237,12 @@ pi_result OCL(piextDeviceSelectBinary)(pi_device device,
234
237
const char *image_target = nullptr ;
235
238
// Get the type of the device
236
239
cl_device_type device_type;
240
+ constexpr pi_uint32 invalid_ind = std::numeric_limits<pi_uint32>::max ();
237
241
cl_int ret_err =
238
242
clGetDeviceInfo (cast<cl_device_id>(device), CL_DEVICE_TYPE,
239
243
sizeof (cl_device_type), &device_type, nullptr );
240
244
if (ret_err != CL_SUCCESS) {
241
- *selected_image = nullptr ;
245
+ *selected_image_ind = invalid_ind ;
242
246
return cast<pi_result>(ret_err);
243
247
}
244
248
@@ -266,18 +270,18 @@ pi_result OCL(piextDeviceSelectBinary)(pi_device device,
266
270
}
267
271
268
272
// Find the appropriate device image, fallback to spirv if not found
269
- pi_device_binary fallback = nullptr ;
270
- for (size_t i = 0 ; i < num_images; ++i) {
273
+ pi_uint32 fallback = invalid_ind ;
274
+ for (pi_uint32 i = 0 ; i < num_images; ++i) {
271
275
if (strcmp (images[i]->DeviceTargetSpec , image_target) == 0 ) {
272
- *selected_image = images[i] ;
276
+ *selected_image_ind = i ;
273
277
return PI_SUCCESS;
274
278
}
275
279
if (strcmp (images[i]->DeviceTargetSpec , PI_DEVICE_BINARY_TARGET_SPIRV64) ==
276
280
0 )
277
- fallback = images[i] ;
281
+ fallback = i ;
278
282
}
279
283
// Points to a spirv image, if such indeed was found
280
- if ((*selected_image = fallback))
284
+ if ((*selected_image_ind = fallback) != invalid_ind )
281
285
return PI_SUCCESS;
282
286
// No image can be loaded for the given device
283
287
return PI_INVALID_BINARY;
@@ -1013,6 +1017,32 @@ pi_result OCL(piKernelSetExecInfo)(pi_kernel kernel,
1013
1017
}
1014
1018
}
1015
1019
1020
+ typedef CL_API_ENTRY cl_int (CL_API_CALL *clSetProgramSpecializationConstant_fn)(
1021
+ cl_program program, cl_uint spec_id, size_t spec_size,
1022
+ const void *spec_value);
1023
+
1024
+ static pi_result OCL (piextProgramSetSpecializationConstantImpl)(
1025
+ pi_program prog, unsigned int spec_id, size_t spec_size,
1026
+ const void *spec_value) {
1027
+ cl_program ClProg = cast<cl_program>(prog);
1028
+ cl_context Ctx = nullptr ;
1029
+ size_t RetSize = 0 ;
1030
+ cl_int Res =
1031
+ clGetProgramInfo (ClProg, CL_PROGRAM_CONTEXT, sizeof (Ctx), &Ctx, &RetSize);
1032
+
1033
+ if (Res != CL_SUCCESS)
1034
+ return cast<pi_result>(Res);
1035
+
1036
+ clSetProgramSpecializationConstant_fn F = nullptr ;
1037
+ Res = getExtFuncFromContext<clSetProgramSpecializationConstantName,
1038
+ decltype (F)>(cast<pi_context>(Ctx), &F);
1039
+
1040
+ if (!F || Res != CL_SUCCESS)
1041
+ return PI_INVALID_OPERATION;
1042
+ Res = F (ClProg, spec_id, spec_size, spec_value);
1043
+ return cast<pi_result>(Res);
1044
+ }
1045
+
1016
1046
pi_result piPluginInit (pi_plugin *PluginInit) {
1017
1047
int CompareVersions = strcmp (PluginInit->PiVersion , SupportedVersion);
1018
1048
if (CompareVersions < 0 ) {
@@ -1070,6 +1100,9 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
1070
1100
_PI_CL (piProgramGetBuildInfo, clGetProgramBuildInfo)
1071
1101
_PI_CL (piProgramRetain, clRetainProgram)
1072
1102
_PI_CL (piProgramRelease, clReleaseProgram)
1103
+ _PI_CL (piextProgramSetSpecializationConstant,
1104
+ OCL (piextProgramSetSpecializationConstantImpl))
1105
+
1073
1106
// Kernel
1074
1107
_PI_CL (piKernelCreate, OCL (piKernelCreate))
1075
1108
_PI_CL (piKernelSetArg, clSetKernelArg)
0 commit comments