Skip to content

Commit 928ed3e

Browse files
authored
[UR][L0] Fix issue with command-buffer local mem update (#17069)
- Fix group count not being recalculated when a user only passes a new local work size and no new global size - Remove CTS test skips for local update on L0
1 parent a739d34 commit 928ed3e

File tree

3 files changed

+35
-21
lines changed

3 files changed

+35
-21
lines changed

unified-runtime/source/adapters/level_zero/command_buffer.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -907,17 +907,18 @@ ur_result_t setKernelPendingArguments(
907907
* @param[in] CommandBuffer The CommandBuffer associated with the new command.
908908
* @param[in] Kernel The Kernel associated with the new command.
909909
* @param[in] WorkDim Dimensions of the kernel associated with the new command.
910+
* @param[in] GlobalWorkSize Global work size of the kernel associated with the
911+
* new command.
910912
* @param[in] LocalWorkSize LocalWorkSize of the kernel associated with the new
911913
* command.
912914
* @param[out] Command The handle to the new command.
913915
* @return UR_RESULT_SUCCESS or an error code on failure
914916
*/
915-
ur_result_t
916-
createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
917-
ur_kernel_handle_t Kernel, uint32_t WorkDim,
918-
const size_t *LocalWorkSize, uint32_t NumKernelAlternatives,
919-
ur_kernel_handle_t *KernelAlternatives,
920-
ur_exp_command_buffer_command_handle_t *Command) {
917+
ur_result_t createCommandHandle(
918+
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
919+
uint32_t WorkDim, const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
920+
uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives,
921+
ur_exp_command_buffer_command_handle_t *Command) {
921922

922923
assert(CommandBuffer->IsUpdatable);
923924

@@ -989,6 +990,8 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
989990
CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr,
990991
NumKernelAlternatives, KernelAlternatives);
991992

993+
NewCommand->setGlobalWorkSize(GlobalWorkSize);
994+
992995
*Command = NewCommand.get();
993996

994997
CommandBuffer->CommandHandles.push_back(std::move(NewCommand));
@@ -1063,9 +1066,9 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10631066
}
10641067

10651068
if (Command) {
1066-
UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, LocalWorkSize,
1067-
NumKernelAlternatives, KernelAlternatives,
1068-
Command));
1069+
UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, GlobalWorkSize,
1070+
LocalWorkSize, NumKernelAlternatives,
1071+
KernelAlternatives, Command));
10691072
}
10701073
std::vector<ze_event_handle_t> ZeEventList;
10711074
ze_event_handle_t ZeLaunchEvent = nullptr;
@@ -1919,10 +1922,16 @@ ur_result_t updateKernelCommand(
19191922
Descs.push_back(std::move(MutableGroupSizeDesc));
19201923
}
19211924

1922-
// Check if a new global size is provided and if we need to update the group
1923-
// count.
1925+
// Check if a new global or local size is provided and if so we need to update
1926+
// the group count.
19241927
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
1925-
if (NewGlobalWorkSize && Dim > 0) {
1928+
if ((NewGlobalWorkSize || NewLocalWorkSize) && Dim > 0) {
1929+
// If a new global work size is provided update that in the command,
1930+
// otherwise the previous work group size will be used
1931+
if (NewGlobalWorkSize) {
1932+
Command->WorkDim = Dim;
1933+
Command->setGlobalWorkSize(NewGlobalWorkSize);
1934+
}
19261935
// If a new global work size is provided but a new local work size is not
19271936
// then we still need to update local work size based on the size suggested
19281937
// by the driver for the kernel.
@@ -1932,9 +1941,9 @@ ur_result_t updateKernelCommand(
19321941
UR_CALL(getZeKernel(ZeDevice, Command->Kernel, &ZeKernel));
19331942

19341943
uint32_t WG[3];
1935-
UR_CALL(calculateKernelWorkDimensions(ZeKernel, CommandBuffer->Device,
1936-
ZeThreadGroupDimensions, WG, Dim,
1937-
NewGlobalWorkSize, NewLocalWorkSize));
1944+
UR_CALL(calculateKernelWorkDimensions(
1945+
ZeKernel, CommandBuffer->Device, ZeThreadGroupDimensions, WG, Dim,
1946+
Command->GlobalWorkSize, NewLocalWorkSize));
19381947

19391948
auto MutableGroupCountDesc =
19401949
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t>>();

unified-runtime/source/adapters/level_zero/command_buffer.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,19 @@ struct kernel_command_handle : public ur_exp_command_buffer_command_handle_t_ {
172172

173173
~kernel_command_handle();
174174

175+
void setGlobalWorkSize(const size_t *GlobalWorkSizePtr) {
176+
const size_t CopySize = sizeof(size_t) * WorkDim;
177+
std::memcpy(GlobalWorkSize, GlobalWorkSizePtr, CopySize);
178+
if (WorkDim < 3) {
179+
const size_t ZeroSize = sizeof(size_t) * (3 - WorkDim);
180+
std::memset(GlobalWorkSize + WorkDim, 0, ZeroSize);
181+
}
182+
}
183+
175184
// Work-dimension the command was originally created with.
176185
uint32_t WorkDim;
186+
// Global work size of the kernel
187+
size_t GlobalWorkSize[3];
177188
// Set to true if the user set the local work size on command creation.
178189
bool UserDefinedLocalSize;
179190
// Currently active kernel handle

unified-runtime/test/conformance/exp_command_buffer/update/local_memory_update.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,6 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersEmptyLocalSize) {
378378
// Test updating A,X,Y parameters to new values and local memory parameters
379379
// to new smaller values.
380380
TEST_P(LocalMemoryUpdateTest, UpdateParametersSmallerLocalSize) {
381-
UUR_KNOWN_FAILURE_ON(uur::LevelZero{});
382381

383382
// Run command-buffer prior to update an verify output
384383
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
@@ -1081,11 +1080,6 @@ struct LocalMemoryUpdateTestBaseOutOfOrder : LocalMemoryUpdateTestBase {
10811080
UUR_RETURN_ON_FATAL_FAILURE(
10821081
urUpdatableCommandBufferExpExecutionTest::SetUp());
10831082

1084-
if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) {
1085-
GTEST_SKIP()
1086-
<< "Local memory argument update not supported on Level Zero.";
1087-
}
1088-
10891083
// HIP has extra args for local memory so we define an offset for arg
10901084
// indices here for updating
10911085
hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0;

0 commit comments

Comments
 (0)