File tree Expand file tree Collapse file tree 3 files changed +431
-114
lines changed
exp_command_buffer/update Expand file tree Collapse file tree 3 files changed +431
-114
lines changed Original file line number Diff line number Diff line change @@ -1396,14 +1396,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1396
1396
1397
1397
CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params ;
1398
1398
1399
+ const auto LocalSize = KernelCommandHandle->Kernel ->getLocalSize ();
1400
+ if (LocalSize != 0 ) {
1401
+ // Clean the local size, otherwise calling updateKernelArguments() in
1402
+ // future updates with local arguments will incorrectly increase the
1403
+ // size further.
1404
+ KernelCommandHandle->Kernel ->clearLocalSize ();
1405
+ }
1406
+
1399
1407
Params.func = CuFunc;
1400
- Params.gridDimX = BlocksPerGrid[0 ];
1401
- Params.gridDimY = BlocksPerGrid[1 ];
1402
- Params.gridDimZ = BlocksPerGrid[2 ];
1403
- Params.blockDimX = ThreadsPerBlock[0 ];
1404
- Params.blockDimY = ThreadsPerBlock[1 ];
1405
- Params.blockDimZ = ThreadsPerBlock[2 ];
1406
- Params.sharedMemBytes = KernelCommandHandle-> Kernel -> getLocalSize () ;
1408
+ Params.gridDimX = static_cast < unsigned int >( BlocksPerGrid[0 ]) ;
1409
+ Params.gridDimY = static_cast < unsigned int >( BlocksPerGrid[1 ]) ;
1410
+ Params.gridDimZ = static_cast < unsigned int >( BlocksPerGrid[2 ]) ;
1411
+ Params.blockDimX = static_cast < unsigned int >( ThreadsPerBlock[0 ]) ;
1412
+ Params.blockDimY = static_cast < unsigned int >( ThreadsPerBlock[1 ]) ;
1413
+ Params.blockDimZ = static_cast < unsigned int >( ThreadsPerBlock[2 ]) ;
1414
+ Params.sharedMemBytes = LocalSize ;
1407
1415
Params.kernelParams =
1408
1416
const_cast <void **>(KernelCommandHandle->Kernel ->getArgIndices ().data ());
1409
1417
Original file line number Diff line number Diff line change @@ -15,15 +15,27 @@ int main() {
15
15
uint32_t A = 42 ;
16
16
17
17
sycl_queue.submit ([&](sycl::handler &cgh) {
18
- sycl::local_accessor<uint32_t , 1 > local_mem (local_size, cgh);
18
+ sycl::local_accessor<uint32_t , 1 > local_mem_A (local_size, cgh);
19
+ sycl::local_accessor<uint32_t , 1 > local_mem_B (1 , cgh);
20
+
19
21
cgh.parallel_for <class saxpy_usm_local_mem >(
20
22
sycl::nd_range<1 >{{array_size}, {local_size}},
21
23
[=](sycl::nd_item<1 > itemId) {
22
24
auto i = itemId.get_global_linear_id ();
23
25
auto local_id = itemId.get_local_linear_id ();
24
- local_mem[local_id] = i;
25
- Z[i] = A * X[i] + Y[i] + local_mem[local_id] +
26
+
27
+ local_mem_A[local_id] = i;
28
+ if (i == 0 ) {
29
+ local_mem_B[0 ] = 0xA ;
30
+ }
31
+
32
+ Z[i] = A * X[i] + Y[i] + local_mem_A[local_id] +
26
33
itemId.get_local_range (0 );
34
+
35
+ if (i == 0 ) {
36
+ Z[i] += local_mem_B[0 ];
37
+ }
38
+
27
39
});
28
40
});
29
41
return 0 ;
You can’t perform that action at this time.
0 commit comments