Commit 965ae32

[NativeCPU] Handle local args.
Depending on the number of available threads, NativeCPU takes different code paths when launching kernels. Some of these paths were missing the call to kernel.handleLocalArgs, which left local pointers as nullptr. Skip the affected code path for kernels that use local pointers.
1 parent e10146e commit 965ae32

2 files changed: +7 -5 lines changed

source/adapters/native_cpu/enqueue.cpp

Lines changed: 4 additions & 4 deletions
@@ -138,12 +138,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
 #else
   bool isLocalSizeOne =
       ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
-  if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads) {
+  if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads &&
+      !hKernel->hasLocalArgs()) {
     // If the local size is one, we make the assumption that we are running a
     // parallel_for over a sycl::range.
-    // Todo: we could add compiler checks and
-    // kernel properties for this (e.g. check that no barriers are called, no
-    // local memory args).
+    // Todo: we could add more compiler checks and
+    // kernel properties for this (e.g. check that no barriers are called).
 
     // Todo: this assumes that dim 0 is the best dimension over which we want to
     // parallelize

source/adapters/native_cpu/kernel.hpp

Lines changed: 3 additions & 1 deletion
@@ -142,7 +142,9 @@ struct ur_kernel_handle_t_ : RefCounted {
     _localMemPoolSize = reqSize;
   }
 
-  // To be called before executing a work group
+  bool hasLocalArgs() const { return !_localArgInfo.empty(); }
+
+  // To be called before executing a work group if local args are present
   void handleLocalArgs(size_t numParallelThread, size_t threadId) {
     // For each local argument we have size*numthreads
     size_t offset = 0;
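
The guard added in enqueue.cpp can be illustrated with a small standalone sketch. This is not the adapter's code: MockKernel, LaunchPath, chooseLaunchPath and setLocalArg are hypothetical names, and the pool layout (size * numParallelThreads bytes per local argument, one slice per thread) is an assumption based only on the comment visible in the diff. It also assumes the other launch path is the one that calls handleLocalArgs, as the commit message implies.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ur_kernel_handle_t_; only what the guard needs.
struct MockKernel {
  struct LocalArg {
    std::size_t argIndex; // position of the local pointer in the argument list
    std::size_t size;     // bytes requested per work-group
  };
  std::vector<LocalArg> localArgInfo;
  std::vector<char> localMemPool; // assumed layout: size * numThreads per arg
  std::vector<char *> args;       // local pointers start out as nullptr

  bool hasLocalArgs() const { return !localArgInfo.empty(); }

  void setLocalArg(std::size_t index, std::size_t size) {
    if (args.size() <= index)
      args.resize(index + 1, nullptr);
    localArgInfo.push_back({index, size});
  }

  // Points every local argument at this thread's slice of the pool.
  // Assumes numParallelThreads stays constant across calls, so the pool is
  // sized once and earlier pointers are not invalidated.
  void handleLocalArgs(std::size_t numParallelThreads, std::size_t threadId) {
    assert(threadId < numParallelThreads);
    std::size_t total = 0;
    for (const auto &a : localArgInfo)
      total += a.size * numParallelThreads;
    if (localMemPool.size() < total)
      localMemPool.resize(total);
    std::size_t offset = 0;
    for (const auto &a : localArgInfo) {
      args[a.argIndex] = localMemPool.data() + offset + a.size * threadId;
      offset += a.size * numParallelThreads;
    }
  }
};

enum class LaunchPath { RangeFastPath, PerWorkGroup };

// Mirrors the condition from the diff: the range fast path is taken only when
// the kernel has no local arguments that would otherwise be left unset.
LaunchPath chooseLaunchPath(const MockKernel &kernel,
                            const std::array<std::size_t, 3> &localSize,
                            std::size_t globalSize0,
                            std::size_t numParallelThreads) {
  const bool isLocalSizeOne =
      localSize[0] == 1 && localSize[1] == 1 && localSize[2] == 1;
  if (isLocalSizeOne && globalSize0 > numParallelThreads &&
      !kernel.hasLocalArgs())
    return LaunchPath::RangeFastPath;
  return LaunchPath::PerWorkGroup;
}
```

In this sketch a kernel with at least one local argument always falls through to the per-work-group path, where handleLocalArgs is expected to run before the kernel body, so its local pointers are never left as nullptr.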
