Skip to content

Commit fc97c36

Browse files
committed
UR updates
1 parent 22ad6c8 commit fc97c36

File tree

10 files changed

+56
-14
lines changed

10 files changed

+56
-14
lines changed

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,10 +1061,11 @@ pi_result piKernelSetExecInfo(pi_kernel Kernel, pi_kernel_exec_info ParamName,
10611061
}
10621062

10631063
pi_result piextProgramSetSpecializationConstant(pi_program Prog,
1064+
pi_kernel Kernel,
10641065
pi_uint32 SpecID, size_t Size,
10651066
const void *SpecValue) {
1066-
return pi2ur::piextProgramSetSpecializationConstant(Prog, SpecID, Size,
1067-
SpecValue);
1067+
return pi2ur::piextProgramSetSpecializationConstant(Prog, Kernel, SpecID,
1068+
Size, SpecValue);
10681069
}
10691070

10701071
// Command buffer extension

sycl/plugins/native_cpu/pi_native_cpu.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1183,7 +1183,8 @@ pi_result piKernelSetExecInfo(pi_kernel, pi_kernel_exec_info, size_t,
11831183
return PI_SUCCESS;
11841184
}
11851185

1186-
pi_result piextProgramSetSpecializationConstant(pi_program, pi_uint32, size_t,
1186+
pi_result piextProgramSetSpecializationConstant(pi_program, pi_kernel,
1187+
pi_uint32, size_t,
11871188
const void *) {
11881189
DIE_NO_IMPLEMENTATION;
11891190
}

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2212,9 +2212,11 @@ pi_result piKernelSetExecInfo(pi_kernel kernel, pi_kernel_exec_info param_name,
22122212
}
22132213

22142214
pi_result piextProgramSetSpecializationConstant(pi_program prog,
2215+
pi_kernel kernel,
22152216
pi_uint32 spec_id,
22162217
size_t spec_size,
22172218
const void *spec_value) {
2219+
std::ignore = kernel;
22182220
cl_program ClProg = cast<cl_program>(prog);
22192221
cl_context Ctx = nullptr;
22202222
size_t RetSize = 0;

sycl/plugins/unified_runtime/pi2ur.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1912,18 +1912,22 @@ piProgramBuild(pi_program Program, pi_uint32 NumDevices,
19121912
}
19131913

19141914
inline pi_result piextProgramSetSpecializationConstant(pi_program Program,
1915+
pi_kernel Kernel,
19151916
pi_uint32 SpecID,
19161917
size_t Size,
19171918
const void *SpecValue) {
19181919
ur_program_handle_t UrProgram =
19191920
reinterpret_cast<ur_program_handle_t>(Program);
1921+
ur_kernel_handle_t UrKernel = reinterpret_cast<ur_kernel_handle_t>(Kernel);
19201922
uint32_t Count = 1;
19211923
ur_specialization_constant_info_t SpecConstant{};
19221924
SpecConstant.id = SpecID;
19231925
SpecConstant.size = Size;
19241926
SpecConstant.pValue = SpecValue;
1925-
HANDLE_ERRORS(
1926-
urProgramSetSpecializationConstants(UrProgram, Count, &SpecConstant));
1927+
HANDLE_ERRORS(Kernel ? urKernelSetSpecializationConstants(UrKernel, Count,
1928+
&SpecConstant)
1929+
: urProgramSetSpecializationConstants(UrProgram, Count,
1930+
&SpecConstant));
19271931

19281932
return PI_SUCCESS;
19291933
}

sycl/plugins/unified_runtime/pi_unified_runtime.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,10 @@ __SYCL_EXPORT pi_result piProgramBuild(
127127
}
128128

129129
__SYCL_EXPORT pi_result piextProgramSetSpecializationConstant(
130-
pi_program Prog, pi_uint32 SpecID, size_t Size, const void *SpecValue) {
131-
return pi2ur::piextProgramSetSpecializationConstant(Prog, SpecID, Size,
132-
SpecValue);
130+
pi_program Prog, pi_kernel Kernel, pi_uint32 SpecID, size_t Size,
131+
const void *SpecValue) {
132+
return pi2ur::piextProgramSetSpecializationConstant(Prog, Kernel, SpecID,
133+
Size, SpecValue);
133134
}
134135

135136
__SYCL_EXPORT pi_result

sycl/plugins/unified_runtime/ur/adapters/cuda/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
993993
strnlen(AddressBuffer, AddressBufferSize - 1) + 1);
994994
}
995995
case UR_DEVICE_INFO_KERNEL_SET_SPECIALIZATION_CONSTANTS:
996-
return ReturnValue(false);
996+
return ReturnValue(true);
997997
// TODO: Investigate if this information is available on CUDA.
998998
case UR_DEVICE_INFO_GPU_EU_COUNT:
999999
case UR_DEVICE_INFO_GPU_EU_SIMD_WIDTH:

sycl/plugins/unified_runtime/ur/adapters/cuda/kernel.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "kernel.hpp"
1010
#include "memory.hpp"
1111
#include "sampler.hpp"
12+
#include "ur_api.h"
1213

1314
UR_APIEXPORT ur_result_t UR_APICALL
1415
urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
@@ -374,3 +375,35 @@ urKernelSetArgSampler(ur_kernel_handle_t hKernel, uint32_t argIndex,
374375
}
375376
return Result;
376377
}
378+
379+
// Sets the specialization constants of a CUDA kernel by copying the caller's
// host bytes into a device global that belongs to that kernel.
//
// The global is looked up in the module of the kernel's program under the
// name "sycl_specialization_constants_kernel_" + Kernel->getName().
//
// Parameters:
//   Kernel         - kernel whose specialization-constant global is written.
//   (unnamed)      - number of entries in SpecConstants; it is not read here,
//                    only the first entry of SpecConstants is used.
//   SpecConstants  - its pValue/size describe the host bytes that are copied.
//
// Returns UR_RESULT_SUCCESS without copying when the global has size 1 (the
// placeholder case described below); otherwise returns what UR_CHECK_ERROR
// yields for the cuMemcpyHtoD result.
//
// NOTE(review): SpecConstants is dereferenced without a null check, and
// SpecConstants->size is not compared against Bytes — confirm that callers
// always pass a valid entry no larger than the device global.
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetSpecializationConstants(
380+
ur_kernel_handle_t Kernel, uint32_t,
381+
const ur_specialization_constant_info_t *SpecConstants) {
382+
CUdeviceptr DPtr = 0;
383+
size_t Bytes = 0;
384+
// NOTE: GlobalNamePrefix is generated by CUDANativeSpecConstantsPass and
385+
// must not be changed.
386+
constexpr char GlobalNamePrefix[] = "sycl_specialization_constants_kernel_";
387+
std::string Name{GlobalNamePrefix};
388+
Name.append(Kernel->getName());
// Resolve the per-kernel global; DPtr and Bytes are filled in by the driver.
389+
const auto ResGetGlobal = cuModuleGetGlobal(
390+
&DPtr, &Bytes, Kernel->get_program()->get(), Name.c_str());
391+
// NOTE(review): the result of UR_CHECK_ERROR is discarded here; presumably
// the macro itself reports or raises on failure — confirm, otherwise a failed
// lookup falls through to the checks below.
UR_CHECK_ERROR(ResGetGlobal);
392+
// NOTE: A symbol size of exactly 1 is a sentinel rather than a real size
393+
// (used as a copy size it would make `cuMemcpyHtoD` fail). Instead, it is a
394+
// flag that tells the plugin that even though the implicit kernel
395+
// argument is present (hence the call to
396+
// `cuda_piextProgramSetSpecializationConstant`), it has no uses and there is
397+
// no need to set up the symbol.
398+
// See: `CUDANativeSpecConstantsPass::setUpPlaceholderEntries` for how pi
399+
// handles it.
// NOTE(review): `cuda_piextProgramSetSpecializationConstant` looks like the
// name of an earlier PI entry point; this function is presumably its
// replacement — confirm and update the reference.
400+
if (Bytes == 1) {
401+
return UR_RESULT_SUCCESS;
402+
}
403+
404+
// Presumably returns UR_RESULT_ERROR_INVALID_VALUE when the lookup produced
// a null pointer or a zero size — verify against the UR_ASSERT definition.
UR_ASSERT(DPtr && Bytes, UR_RESULT_ERROR_INVALID_VALUE);
405+
406+
// Copy SpecConstants->size bytes from the host value into the device global.
const auto ResMemcpy =
407+
cuMemcpyHtoD(DPtr, SpecConstants->pValue, SpecConstants->size);
408+
return UR_CHECK_ERROR(ResMemcpy);
409+
}

sycl/source/detail/program_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ void program_impl::flush_spec_constants(
549549
Descriptors.consume<uint32_t, uint32_t, uint32_t>();
550550

551551
Ctx->getPlugin()->call<PiApiKind::piextProgramSetSpecializationConstant>(
552-
NativePrg, Id, Size, SC.getValuePtr() + Offset);
552+
NativePrg, nullptr, Id, Size, SC.getValuePtr() + Offset);
553553
}
554554
}
555555
}

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ enableITTAnnotationsIfNeeded(const sycl::detail::pi::PiProgram &Prog,
6060
if (SYCLConfig<INTEL_ENABLE_OFFLOAD_ANNOTATIONS>::get() != nullptr) {
6161
constexpr char SpecValue = 1;
6262
Plugin->call<PiApiKind::piextProgramSetSpecializationConstant>(
63-
Prog, ITTSpecConstId, sizeof(char), &SpecValue);
63+
Prog, nullptr, ITTSpecConstId, sizeof(char), &SpecValue);
6464
}
6565
}
6666

@@ -2095,7 +2095,7 @@ setSpecializationConstants(const std::shared_ptr<device_image_impl> &InputImpl,
20952095
for (const device_image_impl::SpecConstDescT &SpecIDDesc : SpecConstDescs) {
20962096
if (SpecIDDesc.IsSet) {
20972097
Plugin->call<PiApiKind::piextProgramSetSpecializationConstant>(
2098-
Prog, SpecIDDesc.ID, SpecIDDesc.Size,
2098+
Prog, nullptr, SpecIDDesc.ID, SpecIDDesc.Size,
20992099
SpecConsts.data() + SpecIDDesc.BlobOffset);
21002100
}
21012101
}

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,10 +2271,10 @@ static pi_result SetKernelParamsAndLaunch(
22712271
SpecConstsBuffer ? &SpecConstsBuffer : nullptr;
22722272

22732273
// Call into set spec constant for pi cuda kernels.
2274-
if (Queue->getPlugin().getBackend() == backend::ext_oneapi_cuda) {
2274+
if (Queue->getDeviceImplPtr()->getBackend() == backend::ext_oneapi_cuda) {
22752275
static unsigned SpecID = 0;
22762276
auto &Blob = DeviceImageImpl->get_spec_const_blob_ref();
2277-
Plugin.call<PiApiKind::piextProgramSetSpecializationConstant>(
2277+
Plugin->call<PiApiKind::piextProgramSetSpecializationConstant>(
22782278
DeviceImageImpl->get_program_ref(), Kernel, SpecID++, Blob.size(),
22792279
Blob.data());
22802280
}

0 commit comments

Comments
 (0)