Skip to content

[UR][L0] Fix issue with command-buffer local mem update #17069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions unified-runtime/source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,17 +910,18 @@ ur_result_t setKernelPendingArguments(
* @param[in] CommandBuffer The CommandBuffer associated with the new command.
* @param[in] Kernel The Kernel associated with the new command.
* @param[in] WorkDim Dimensions of the kernel associated with the new command.
* @param[in] GlobalWorkSize Global work size of the kernel associated with the
* new command.
* @param[in] LocalWorkSize Local work size of the kernel associated with the new
* command.
* @param[out] Command The handle to the new command.
* @return UR_RESULT_SUCCESS or an error code on failure
*/
ur_result_t
createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
ur_kernel_handle_t Kernel, uint32_t WorkDim,
const size_t *LocalWorkSize, uint32_t NumKernelAlternatives,
ur_kernel_handle_t *KernelAlternatives,
ur_exp_command_buffer_command_handle_t *Command) {
ur_result_t createCommandHandle(
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
uint32_t WorkDim, const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives,
ur_exp_command_buffer_command_handle_t *Command) {

assert(CommandBuffer->IsUpdatable);

Expand Down Expand Up @@ -992,6 +993,8 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr,
NumKernelAlternatives, KernelAlternatives);

NewCommand->setGlobalWorkSize(GlobalWorkSize);

*Command = NewCommand.get();

CommandBuffer->CommandHandles.push_back(std::move(NewCommand));
Expand Down Expand Up @@ -1066,9 +1069,9 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
}

if (Command) {
UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, LocalWorkSize,
NumKernelAlternatives, KernelAlternatives,
Command));
UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, GlobalWorkSize,
LocalWorkSize, NumKernelAlternatives,
KernelAlternatives, Command));
}
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
Expand Down Expand Up @@ -1922,10 +1925,16 @@ ur_result_t updateKernelCommand(
Descs.push_back(std::move(MutableGroupSizeDesc));
}

// Check if a new global size is provided and if we need to update the group
// count.
// Check if a new global or local size is provided and if so we need to update
// the group count.
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
if (NewGlobalWorkSize && Dim > 0) {
if ((NewGlobalWorkSize || NewLocalWorkSize) && Dim > 0) {
// If a new global work size is provided, store it in the command;
// otherwise the global work size already stored in the command is reused.
if (NewGlobalWorkSize) {
Command->WorkDim = Dim;
Command->setGlobalWorkSize(NewGlobalWorkSize);
}
// If a new global work size is provided but a new local work size is not
// then we still need to update local work size based on the size suggested
// by the driver for the kernel.
Expand All @@ -1935,9 +1944,9 @@ ur_result_t updateKernelCommand(
UR_CALL(getZeKernel(ZeDevice, Command->Kernel, &ZeKernel));

uint32_t WG[3];
UR_CALL(calculateKernelWorkDimensions(ZeKernel, CommandBuffer->Device,
ZeThreadGroupDimensions, WG, Dim,
NewGlobalWorkSize, NewLocalWorkSize));
UR_CALL(calculateKernelWorkDimensions(
ZeKernel, CommandBuffer->Device, ZeThreadGroupDimensions, WG, Dim,
Command->GlobalWorkSize, NewLocalWorkSize));

auto MutableGroupCountDesc =
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t>>();
Expand Down
11 changes: 11 additions & 0 deletions unified-runtime/source/adapters/level_zero/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,19 @@ struct kernel_command_handle : public ur_exp_command_buffer_command_handle_t_ {

~kernel_command_handle();

/**
 * Stores the global work size of this command.
 *
 * The first WorkDim entries of GlobalWorkSizePtr are copied into
 * GlobalWorkSize and every remaining entry of the member array is set to
 * zero. WorkDim must hold the intended dimension count before this is called.
 *
 * @param[in] GlobalWorkSizePtr Array holding at least WorkDim entries. It is
 * not read when WorkDim is zero.
 */
void setGlobalWorkSize(const size_t *GlobalWorkSizePtr) {
  // Take the capacity from the member array itself so that the loop can never
  // write past its end, even if WorkDim holds a value above the capacity.
  constexpr size_t MaxDims = sizeof(GlobalWorkSize) / sizeof(GlobalWorkSize[0]);
  for (size_t Dim = 0; Dim < MaxDims; ++Dim) {
    // Dimensions beyond WorkDim are unused and are zero-filled.
    GlobalWorkSize[Dim] = Dim < WorkDim ? GlobalWorkSizePtr[Dim] : 0;
  }
}

// Work-dimension of the command. Overwritten when an update provides a new
// global work size.
uint32_t WorkDim;
// Global work size of the kernel
size_t GlobalWorkSize[3];
// Set to true if the user set the local work size on command creation.
bool UserDefinedLocalSize;
// Currently active kernel handle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,6 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersEmptyLocalSize) {
// Test updating A,X,Y parameters to new values and local memory parameters
// to new smaller values.
TEST_P(LocalMemoryUpdateTest, UpdateParametersSmallerLocalSize) {
UUR_KNOWN_FAILURE_ON(uur::LevelZero{});

// Run command-buffer prior to update and verify output
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
Expand Down Expand Up @@ -1081,11 +1080,6 @@ struct LocalMemoryUpdateTestBaseOutOfOrder : LocalMemoryUpdateTestBase {
UUR_RETURN_ON_FATAL_FAILURE(
urUpdatableCommandBufferExpExecutionTest::SetUp());

if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) {
GTEST_SKIP()
<< "Local memory argument update not supported on Level Zero.";
}

// HIP has extra args for local memory so we define an offset for arg
// indices here for updating
hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0;
Expand Down