Skip to content

Commit 608603a

Browse files
authored
Merge pull request #2152 from Bensuo/fabio/binary_update_fix
Fix command-buffer binary update implementation
2 parents 0c0c783 + a01ca73 commit 608603a

File tree

3 files changed

+51
-23
lines changed

3 files changed

+51
-23
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,18 +1237,19 @@ validateCommandDesc(kernel_command_handle *Command,
12371237
}
12381238

12391239
/**
1240-
* Updates the arguments of CommandDesc->hNewKernel
1241-
* @param[in] Device The device associated with the kernel being updated.
1240+
* Updates the arguments of a kernel command.
1241+
* @param[in] Command The command associated with the kernel node being updated.
12421242
* @param[in] UpdateCommandDesc The update command description that contains the
1243-
* new kernel and its arguments.
1243+
* new arguments.
12441244
* @return UR_RESULT_SUCCESS or an error code on failure
12451245
*/
12461246
ur_result_t
1247-
updateKernelArguments(ur_device_handle_t Device,
1247+
updateKernelArguments(kernel_command_handle *Command,
12481248
const ur_exp_command_buffer_update_kernel_launch_desc_t
12491249
*UpdateCommandDesc) {
12501250

1251-
ur_kernel_handle_t NewKernel = UpdateCommandDesc->hNewKernel;
1251+
ur_kernel_handle_t Kernel = Command->Kernel;
1252+
ur_device_handle_t Device = Command->CommandBuffer->Device;
12521253

12531254
// Update pointer arguments to the kernel
12541255
uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs;
@@ -1261,7 +1262,7 @@ updateKernelArguments(ur_device_handle_t Device,
12611262

12621263
ur_result_t Result = UR_RESULT_SUCCESS;
12631264
try {
1264-
NewKernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
1265+
Kernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
12651266
} catch (ur_result_t Err) {
12661267
Result = Err;
12671268
return Result;
@@ -1280,10 +1281,10 @@ updateKernelArguments(ur_device_handle_t Device,
12801281
ur_result_t Result = UR_RESULT_SUCCESS;
12811282
try {
12821283
if (ArgValue == nullptr) {
1283-
NewKernel->setKernelArg(ArgIndex, 0, nullptr);
1284+
Kernel->setKernelArg(ArgIndex, 0, nullptr);
12841285
} else {
12851286
CUdeviceptr CuPtr = std::get<BufferMem>(ArgValue->Mem).getPtr(Device);
1286-
NewKernel->setKernelArg(ArgIndex, sizeof(CUdeviceptr), (void *)&CuPtr);
1287+
Kernel->setKernelArg(ArgIndex, sizeof(CUdeviceptr), (void *)&CuPtr);
12871288
}
12881289
} catch (ur_result_t Err) {
12891290
Result = Err;
@@ -1303,7 +1304,7 @@ updateKernelArguments(ur_device_handle_t Device,
13031304

13041305
ur_result_t Result = UR_RESULT_SUCCESS;
13051306
try {
1306-
NewKernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
1307+
Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
13071308
} catch (ur_result_t Err) {
13081309
Result = Err;
13091310
return Result;
@@ -1364,9 +1365,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13641365
auto KernelCommandHandle = static_cast<kernel_command_handle *>(hCommand);
13651366

13661367
UR_CHECK_ERROR(validateCommandDesc(KernelCommandHandle, pUpdateKernelLaunch));
1367-
UR_CHECK_ERROR(
1368-
updateKernelArguments(CommandBuffer->Device, pUpdateKernelLaunch));
13691368
UR_CHECK_ERROR(updateCommand(KernelCommandHandle, pUpdateKernelLaunch));
1369+
UR_CHECK_ERROR(
1370+
updateKernelArguments(KernelCommandHandle, pUpdateKernelLaunch));
13701371

13711372
// If no work-size is provided make sure we pass nullptr to setKernelParams so
13721373
// it can guess the local work size.

source/adapters/hip/command_buffer.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -951,18 +951,19 @@ validateCommandDesc(ur_exp_command_buffer_command_handle_t Command,
951951
}
952952

953953
/**
954-
* Updates the arguments of CommandDesc->hNewKernel
955-
* @param[in] Device The device associated with the kernel being updated.
956-
* @param[in] UpdateCommandDesc The update command description that contains
957-
* the new kernel and its arguments.
954+
* Updates the arguments of a kernel command.
955+
* @param[in] Command The command associated with the kernel node being updated.
956+
* @param[in] UpdateCommandDesc The update command description that contains the
957+
* new arguments.
958958
* @return UR_RESULT_SUCCESS or an error code on failure
959959
*/
960960
ur_result_t
961-
updateKernelArguments(ur_device_handle_t Device,
961+
updateKernelArguments(ur_exp_command_buffer_command_handle_t Command,
962962
const ur_exp_command_buffer_update_kernel_launch_desc_t
963963
*UpdateCommandDesc) {
964964

965-
ur_kernel_handle_t NewKernel = UpdateCommandDesc->hNewKernel;
965+
ur_kernel_handle_t Kernel = Command->Kernel;
966+
ur_device_handle_t Device = Command->CommandBuffer->Device;
966967

967968
// Update pointer arguments to the kernel
968969
uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs;
@@ -974,7 +975,7 @@ updateKernelArguments(ur_device_handle_t Device,
974975
const void *ArgValue = PointerArgDesc.pNewPointerArg;
975976

976977
try {
977-
NewKernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
978+
Kernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
978979
} catch (ur_result_t Err) {
979980
return Err;
980981
}
@@ -991,10 +992,10 @@ updateKernelArguments(ur_device_handle_t Device,
991992

992993
try {
993994
if (ArgValue == nullptr) {
994-
NewKernel->setKernelArg(ArgIndex, 0, nullptr);
995+
Kernel->setKernelArg(ArgIndex, 0, nullptr);
995996
} else {
996997
void *HIPPtr = std::get<BufferMem>(ArgValue->Mem).getVoid(Device);
997-
NewKernel->setKernelArg(ArgIndex, sizeof(void *), (void *)&HIPPtr);
998+
Kernel->setKernelArg(ArgIndex, sizeof(void *), (void *)&HIPPtr);
998999
}
9991000
} catch (ur_result_t Err) {
10001001
return Err;
@@ -1012,7 +1013,7 @@ updateKernelArguments(ur_device_handle_t Device,
10121013
const void *ArgValue = ValueArgDesc.pNewValueArg;
10131014

10141015
try {
1015-
NewKernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
1016+
Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
10161017
} catch (ur_result_t Err) {
10171018
return Err;
10181019
}
@@ -1067,9 +1068,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
10671068
ur_exp_command_buffer_handle_t CommandBuffer = hCommand->CommandBuffer;
10681069

10691070
UR_CHECK_ERROR(validateCommandDesc(hCommand, pUpdateKernelLaunch));
1070-
UR_CHECK_ERROR(
1071-
updateKernelArguments(CommandBuffer->Device, pUpdateKernelLaunch));
10721071
UR_CHECK_ERROR(updateCommand(hCommand, pUpdateKernelLaunch));
1072+
UR_CHECK_ERROR(updateKernelArguments(hCommand, pUpdateKernelLaunch));
10731073

10741074
// If no worksize is provided make sure we pass nullptr to setKernelParams
10751075
// so it can guess the local work size.

test/conformance/exp_command_buffer/update/kernel_handle_update.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,3 +473,30 @@ TEST_P(urCommandBufferValidUpdateParametersTest, UpdateOnlyLocalWorkSize) {
473473

474474
ASSERT_NO_FATAL_FAILURE(SaxpyKernel->validate());
475475
}
476+
477+
// Tests that passing nullptr to hNewKernel works.
478+
TEST_P(urCommandBufferValidUpdateParametersTest, SuccessNullptrHandle) {
479+
480+
std::vector<ur_kernel_handle_t> KernelAlternatives = {
481+
FillUSM2DKernel->Kernel};
482+
483+
uur::raii::CommandBufferCommand CommandHandle;
484+
ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
485+
updatable_cmd_buf_handle, SaxpyKernel->Kernel, SaxpyKernel->NDimensions,
486+
&(SaxpyKernel->GlobalOffset), &(SaxpyKernel->GlobalSize),
487+
&(SaxpyKernel->LocalSize), KernelAlternatives.size(),
488+
KernelAlternatives.data(), 0, nullptr, 0, nullptr, nullptr, nullptr,
489+
CommandHandle.ptr()));
490+
ASSERT_NE(CommandHandle, nullptr);
491+
492+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
493+
494+
SaxpyKernel->UpdateDesc.hNewKernel = nullptr;
495+
ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(
496+
CommandHandle, &SaxpyKernel->UpdateDesc));
497+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
498+
nullptr, nullptr));
499+
ASSERT_SUCCESS(urQueueFinish(queue));
500+
501+
ASSERT_NO_FATAL_FAILURE(SaxpyKernel->validate());
502+
}

0 commit comments

Comments
 (0)