Skip to content

Commit a7653f3

Browse files
Petr Veselyveselypeta
authored andcommitted
[UR][CUDA][HIP] Fix Set Arg Local
1 parent 5c30815 commit a7653f3

File tree

6 files changed

+44
-19
lines changed

6 files changed

+44
-19
lines changed

sycl/plugins/unified_runtime/pi2ur.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,8 +2190,12 @@ inline pi_result piKernelSetArg(pi_kernel Kernel, pi_uint32 ArgIndex,
21902190

21912191
ur_kernel_handle_t UrKernel = reinterpret_cast<ur_kernel_handle_t>(Kernel);
21922192

2193-
HANDLE_ERRORS(
2194-
urKernelSetArgValue(UrKernel, ArgIndex, ArgSize, nullptr, ArgValue));
2193+
if (ArgValue) {
2194+
HANDLE_ERRORS(
2195+
urKernelSetArgValue(UrKernel, ArgIndex, ArgSize, nullptr, ArgValue));
2196+
} else {
2197+
HANDLE_ERRORS(urKernelSetArgLocal(UrKernel, ArgIndex, ArgSize, nullptr));
2198+
}
21952199
return PI_SUCCESS;
21962200
}
21972201

sycl/plugins/unified_runtime/ur/adapters/cuda/kernel.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
186186

187187
ur_result_t Result = UR_RESULT_SUCCESS;
188188
try {
189-
if (pArgValue) {
190-
hKernel->setKernelArg(argIndex, argSize, pArgValue);
191-
} else {
192-
hKernel->setKernelLocalArg(argIndex, argSize);
193-
}
189+
hKernel->setKernelArg(argIndex, argSize, pArgValue);
190+
} catch (ur_result_t Err) {
191+
Result = Err;
192+
}
193+
return Result;
194+
}
195+
196+
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgLocal(
197+
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
198+
const ur_kernel_arg_local_properties_t *pProperties) {
199+
std::ignore = pProperties;
200+
UR_ASSERT(argSize, UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE);
201+
202+
ur_result_t Result = UR_RESULT_SUCCESS;
203+
try {
204+
hKernel->setKernelLocalArg(argIndex, argSize);
194205
} catch (ur_result_t Err) {
195206
Result = Err;
196207
}

sycl/plugins/unified_runtime/ur/adapters/cuda/ur_interface_loader.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable(
115115
pDdiTable->pfnGetSubGroupInfo = urKernelGetSubGroupInfo;
116116
pDdiTable->pfnRelease = urKernelRelease;
117117
pDdiTable->pfnRetain = urKernelRetain;
118-
pDdiTable->pfnSetArgLocal = nullptr;
118+
pDdiTable->pfnSetArgLocal = urKernelSetArgLocal;
119119
pDdiTable->pfnSetArgMemObj = urKernelSetArgMemObj;
120120
pDdiTable->pfnSetArgPointer = urKernelSetArgPointer;
121121
pDdiTable->pfnSetArgSampler = urKernelSetArgSampler;

sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
168168
const ur_kernel_arg_value_properties_t *, const void *pArgValue) {
169169
ur_result_t Result = UR_RESULT_SUCCESS;
170170
try {
171-
if (pArgValue) {
172-
hKernel->setKernelArg(argIndex, argSize, pArgValue);
173-
} else {
174-
hKernel->setKernelLocalArg(argIndex, argSize);
175-
}
171+
hKernel->setKernelArg(argIndex, argSize, pArgValue);
172+
} catch (ur_result_t Err) {
173+
Result = Err;
174+
}
175+
return Result;
176+
}
177+
178+
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgLocal(
179+
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
180+
const ur_kernel_arg_local_properties_t *pProperties) {
181+
std::ignore = pProperties;
182+
UR_ASSERT(argSize, UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE);
183+
184+
ur_result_t Result = UR_RESULT_SUCCESS;
185+
try {
186+
hKernel->setKernelLocalArg(argIndex, argSize);
176187
} catch (ur_result_t Err) {
177188
Result = Err;
178189
}

sycl/plugins/unified_runtime/ur/adapters/hip/ur_interface_loader.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable(
115115
pDdiTable->pfnGetSubGroupInfo = urKernelGetSubGroupInfo;
116116
pDdiTable->pfnRelease = urKernelRelease;
117117
pDdiTable->pfnRetain = urKernelRetain;
118-
pDdiTable->pfnSetArgLocal = nullptr;
118+
pDdiTable->pfnSetArgLocal = urKernelSetArgLocal;
119119
pDdiTable->pfnSetArgMemObj = urKernelSetArgMemObj;
120120
pDdiTable->pfnSetArgPointer = urKernelSetArgPointer;
121121
pDdiTable->pfnSetArgSampler = urKernelSetArgSampler;

sycl/plugins/unified_runtime/ur/adapters/level_zero/kernel.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,12 +419,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgLocal(
419419
const ur_kernel_arg_local_properties_t
420420
*Properties ///< [in][optional] argument properties
421421
) {
422-
std::ignore = Kernel;
423-
std::ignore = ArgIndex;
424422
std::ignore = Properties;
425-
std::ignore = ArgSize;
426-
urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
427-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
423+
424+
UR_CALL(urKernelSetArgValue(Kernel, ArgIndex, ArgSize, nullptr, nullptr));
425+
426+
return UR_RESULT_SUCCESS;
428427
}
429428

430429
UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(

0 commit comments

Comments
 (0)