Skip to content

Commit dc65a88

Browse files
committed
[CUDA][HIP] Fix for command-buffer local argument update
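After setting kernel arguments during a command-buffer update, the kernel's accumulated local memory size must be reset; otherwise later updates that set local arguments add to the stale size and incorrectly increase the amount of local memory requested.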
1 parent a563456 commit dc65a88

File tree

3 files changed

+431
-114
lines changed

3 files changed

+431
-114
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,14 +1396,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13961396

13971397
CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params;
13981398

1399+
const auto LocalSize = KernelCommandHandle->Kernel->getLocalSize();
1400+
if (LocalSize != 0) {
1401+
// Clear the local size; otherwise calling updateKernelArguments() in
1402+
// future updates with local arguments will incorrectly increase the
1403+
// size further.
1404+
KernelCommandHandle->Kernel->clearLocalSize();
1405+
}
1406+
13991407
Params.func = CuFunc;
1400-
Params.gridDimX = BlocksPerGrid[0];
1401-
Params.gridDimY = BlocksPerGrid[1];
1402-
Params.gridDimZ = BlocksPerGrid[2];
1403-
Params.blockDimX = ThreadsPerBlock[0];
1404-
Params.blockDimY = ThreadsPerBlock[1];
1405-
Params.blockDimZ = ThreadsPerBlock[2];
1406-
Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize();
1408+
Params.gridDimX = static_cast<unsigned int>(BlocksPerGrid[0]);
1409+
Params.gridDimY = static_cast<unsigned int>(BlocksPerGrid[1]);
1410+
Params.gridDimZ = static_cast<unsigned int>(BlocksPerGrid[2]);
1411+
Params.blockDimX = static_cast<unsigned int>(ThreadsPerBlock[0]);
1412+
Params.blockDimY = static_cast<unsigned int>(ThreadsPerBlock[1]);
1413+
Params.blockDimZ = static_cast<unsigned int>(ThreadsPerBlock[2]);
1414+
Params.sharedMemBytes = LocalSize;
14071415
Params.kernelParams =
14081416
const_cast<void **>(KernelCommandHandle->Kernel->getArgIndices().data());
14091417

test/conformance/device_code/saxpy_usm_local_mem.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,27 @@ int main() {
1515
uint32_t A = 42;
1616

1717
sycl_queue.submit([&](sycl::handler &cgh) {
18-
sycl::local_accessor<uint32_t, 1> local_mem(local_size, cgh);
18+
sycl::local_accessor<uint32_t, 1> local_mem_A(local_size, cgh);
19+
sycl::local_accessor<uint32_t, 1> local_mem_B(1, cgh);
20+
1921
cgh.parallel_for<class saxpy_usm_local_mem>(
2022
sycl::nd_range<1>{{array_size}, {local_size}},
2123
[=](sycl::nd_item<1> itemId) {
2224
auto i = itemId.get_global_linear_id();
2325
auto local_id = itemId.get_local_linear_id();
24-
local_mem[local_id] = i;
25-
Z[i] = A * X[i] + Y[i] + local_mem[local_id] +
26+
27+
local_mem_A[local_id] = i;
28+
if (i == 0) {
29+
local_mem_B[0] = 0xA;
30+
}
31+
32+
Z[i] = A * X[i] + Y[i] + local_mem_A[local_id] +
2633
itemId.get_local_range(0);
34+
35+
if (i == 0) {
36+
Z[i] += local_mem_B[0];
37+
}
38+
2739
});
2840
});
2941
return 0;

0 commit comments

Comments
 (0)