Skip to content

Commit 0408899

Browse files
[SYCL] Fix unexpected acceptance of id argument in nd_range parallel_for (#1348)
The kernel callable being invoked from an nd_range parallel_for is accepting an id argument, while it should be nd_item. After my analysis, I found we check arguments' type for kernel_parallel_for instead of parallel_for. But that check is useless, because the compiler can still find a candidate for kernel_parallel_for with nd_range and id which is a wrong combination. In my solution, parallel_for with nd_range calls kernel_parallel_for_nd_range(...) which is only available for nd_item. Signed-off-by: Bing1 Yu <[email protected]>
1 parent 6b44ebb commit 0408899

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

sycl/include/CL/sycl/handler.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,8 @@ class handler {
428428

429429
template <typename KernelName, typename KernelType, int Dims,
430430
EnableIfNDItem<KernelType, Dims> = 0>
431-
__attribute__((sycl_kernel)) void kernel_parallel_for(KernelType KernelFunc) {
431+
__attribute__((sycl_kernel)) void
432+
kernel_parallel_for_nd_range(KernelType KernelFunc) {
432433
KernelFunc(detail::Builder::getNDItem<Dims>());
433434
}
434435

@@ -618,7 +619,7 @@ class handler {
618619
using NameT =
619620
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
620621
#ifdef __SYCL_DEVICE_ONLY__
621-
kernel_parallel_for<NameT, KernelType, Dims>(KernelFunc);
622+
kernel_parallel_for_nd_range<NameT, KernelType, Dims>(KernelFunc);
622623
#else
623624
MNDRDesc.set(std::move(ExecutionRange));
624625
StoreLambda<NameT, KernelType, Dims>(std::move(KernelFunc));
@@ -855,7 +856,7 @@ class handler {
855856
using NameT =
856857
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
857858
#ifdef __SYCL_DEVICE_ONLY__
858-
kernel_parallel_for<NameT, KernelType, Dims>(KernelFunc);
859+
kernel_parallel_for_nd_range<NameT, KernelType, Dims>(KernelFunc);
859860
#else
860861
MNDRDesc.set(std::move(NDRange));
861862
MKernel = detail::getSyclObjImpl(std::move(Kernel));

sycl/test/basic_tests/parallel_for_range_host.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ int main() {
9898
// parallel_for, 30 global, 1(implicit) local -> pass.
9999
try {
100100
Q.submit([&](handler &CGH) {
101-
CGH.parallel_for<class g>(range<1>(30),
102-
[=](nd_item<1>) {});
101+
CGH.parallel_for<class g>(range<1>(30),
102+
[=](id<1>) {});
103103
});
104104
Q.wait_and_throw();
105105
} catch (nd_range_error) {

sycl/test/ordered_queue/oq_kernels.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ int main() {
5353
});
5454

5555
nd_range<1> NDR(range<1>{N}, range<1>{2});
56-
q.parallel_for<class NDFoo>(NDR, [=](id<1> ID) {
57-
auto i = ID[0];
56+
q.parallel_for<class NDFoo>(NDR, [=](nd_item<1> Item) {
57+
auto i = Item.get_global_id(0);
5858
A[i]++;
5959
});
6060

0 commit comments

Comments
 (0)