Skip to content

Commit 17299ee

Browse files
authored
[SYCL] Implement braced-init-list or a number as range for queue::parallel_for (#1931)
Modification - Make three overloads for `queue::parallel_for` to support `range` implicit conversion from number or `braced-init-list` - Add tests for `queue::parallel_for` calls with generic lambda Implement the following `queue` extension - https://github.com/intel/llvm/tree/sycl/sycl/doc/extensions/QueueShortcuts/ Signed-off-by: Ruslan Arutyunyan <[email protected]>
1 parent 58fc414 commit 17299ee

File tree

4 files changed

+214
-9
lines changed

4 files changed

+214
-9
lines changed

sycl/include/CL/sycl/queue.hpp

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,9 @@ class __SYCL_EXPORT queue {
432432
/// \param NumWorkItems is a range that specifies the work space of the kernel
433433
/// \param KernelFunc is the Kernel functor or lambda
434434
/// \param CodeLoc contains the code location of user code
435-
template <typename KernelName = detail::auto_name, typename KernelType,
436-
int Dims>
435+
template <typename KernelName = detail::auto_name, typename KernelType>
437436
event parallel_for(
438-
range<Dims> NumWorkItems, KernelType KernelFunc
437+
range<1> NumWorkItems, KernelType KernelFunc
439438
#ifndef DISABLE_SYCL_INSTRUMENTATION_METADATA
440439
,
441440
const detail::code_location &CodeLoc = detail::code_location::current()
@@ -444,12 +443,47 @@ class __SYCL_EXPORT queue {
444443
#ifdef DISABLE_SYCL_INSTRUMENTATION_METADATA
445444
const detail::code_location &CodeLoc = {};
446445
#endif
447-
return submit(
448-
[&](handler &CGH) {
449-
CGH.template parallel_for<KernelName, KernelType>(NumWorkItems,
450-
KernelFunc);
451-
},
452-
CodeLoc);
446+
return parallel_for_impl<KernelName>(NumWorkItems, KernelFunc, CodeLoc);
447+
}
448+
449+
/// parallel_for version with a kernel represented as a lambda + range that
450+
/// specifies global size only.
451+
///
452+
/// \param NumWorkItems is a range that specifies the work space of the kernel
453+
/// \param KernelFunc is the Kernel functor or lambda
454+
/// \param CodeLoc contains the code location of user code
455+
template <typename KernelName = detail::auto_name, typename KernelType>
456+
event parallel_for(
457+
range<2> NumWorkItems, KernelType KernelFunc
458+
#ifndef DISABLE_SYCL_INSTRUMENTATION_METADATA
459+
,
460+
const detail::code_location &CodeLoc = detail::code_location::current()
461+
#endif
462+
) {
463+
#ifdef DISABLE_SYCL_INSTRUMENTATION_METADATA
464+
const detail::code_location &CodeLoc = {};
465+
#endif
466+
return parallel_for_impl<KernelName>(NumWorkItems, KernelFunc, CodeLoc);
467+
}
468+
469+
/// parallel_for version with a kernel represented as a lambda + range that
470+
/// specifies global size only.
471+
///
472+
/// \param NumWorkItems is a range that specifies the work space of the kernel
473+
/// \param KernelFunc is the Kernel functor or lambda
474+
/// \param CodeLoc contains the code location of user code
475+
template <typename KernelName = detail::auto_name, typename KernelType>
476+
event parallel_for(
477+
range<3> NumWorkItems, KernelType KernelFunc
478+
#ifndef DISABLE_SYCL_INSTRUMENTATION_METADATA
479+
,
480+
const detail::code_location &CodeLoc = detail::code_location::current()
481+
#endif
482+
) {
483+
#ifdef DISABLE_SYCL_INSTRUMENTATION_METADATA
484+
const detail::code_location &CodeLoc = {};
485+
#endif
486+
return parallel_for_impl<KernelName>(NumWorkItems, KernelFunc, CodeLoc);
453487
}
454488

455489
/// parallel_for version with a kernel represented as a lambda + range that
@@ -716,6 +750,25 @@ class __SYCL_EXPORT queue {
716750
/// A template-free version of submit.
717751
event submit_impl(function_class<void(handler &)> CGH, queue secondQueue,
718752
const detail::code_location &CodeLoc);
753+
754+
/// parallel_for_impl with a kernel represented as a lambda + range that
755+
/// specifies global size only.
756+
///
757+
/// \param NumWorkItems is a range that specifies the work space of the kernel
758+
/// \param KernelFunc is the Kernel functor or lambda
759+
/// \param CodeLoc contains the code location of user code
760+
template <typename KernelName = detail::auto_name, typename KernelType,
761+
int Dims>
762+
event parallel_for_impl(
763+
range<Dims> NumWorkItems, KernelType KernelFunc,
764+
const detail::code_location &CodeLoc = detail::code_location::current()) {
765+
return submit(
766+
[&](handler &CGH) {
767+
CGH.template parallel_for<KernelName, KernelType>(NumWorkItems,
768+
KernelFunc);
769+
},
770+
CodeLoc);
771+
}
719772
};
720773

721774
} // namespace sycl
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// XFAIL: cuda
2+
// piextUSM*Alloc functions for CUDA are not behaving as described in
3+
// https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/USM/USM.adoc
4+
// https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/USM/cl_intel_unified_shared_memory.asciidoc
5+
//
6+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
7+
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
8+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
9+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
10+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
11+
12+
//==- queue_parallel_for_generic.cpp - SYCL queue parallel_for generic lambda -=//
13+
//
14+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
15+
// See https://llvm.org/LICENSE.txt for license information.
16+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
17+
//
18+
//===------------------------------------------------------------------------===//
19+
20+
#include <CL/sycl.hpp>
21+
#include <iostream>
22+
#include <type_traits>
23+
24+
int main() {
25+
sycl::queue q{};
26+
auto dev = q.get_device();
27+
auto ctx = q.get_context();
28+
constexpr int N = 8;
29+
30+
if (!dev.get_info<sycl::info::device::usm_shared_allocations>()) {
31+
return 0;
32+
}
33+
34+
auto A = static_cast<int *>(sycl::malloc_shared(N * sizeof(int), dev, ctx));
35+
36+
for (int i = 0; i < N; i++) {
37+
A[i] = 1;
38+
}
39+
40+
q.parallel_for<class Bar>(N, [=](auto i) {
41+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
42+
"lambda arg type is unexpected");
43+
A[i]++;
44+
});
45+
46+
q.parallel_for<class Foo>({N}, [=](auto i) {
47+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
48+
"lambda arg type is unexpected");
49+
A[i]++;
50+
});
51+
52+
sycl::id<1> offset(0);
53+
q.parallel_for<class Baz>(sycl::range<1>{N}, offset, [=](auto i) {
54+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
55+
"lambda arg type is unexpected");
56+
A[i]++;
57+
});
58+
59+
sycl::nd_range<1> NDR(sycl::range<1>{N}, sycl::range<1>{2});
60+
q.parallel_for<class NDFoo>(NDR, [=](auto nd_i) {
61+
static_assert(std::is_same<decltype(nd_i), sycl::nd_item<1>>::value,
62+
"lambda arg type is unexpected");
63+
auto i = nd_i.get_global_id(0);
64+
A[i]++;
65+
});
66+
67+
q.wait();
68+
69+
for (int i = 0; i < N; i++) {
70+
if (A[i] != 5)
71+
return 1;
72+
}
73+
sycl::free(A, ctx);
74+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// RUN: %clangxx -fsycl -fsyntax-only %s -o %t.out
2+
3+
//==- queue_parallel_for_generic.cpp - SYCL queue parallel_for interface test -=//
4+
//
5+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6+
// See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===------------------------------------------------------------------------===//
10+
11+
#include <CL/sycl.hpp>
12+
#include <iostream>
13+
#include <type_traits>
14+
15+
template <typename KernelName, std::size_t... Is>
16+
void test_range_impl(sycl::queue q, std::index_sequence<Is...>,
17+
sycl::range<sizeof...(Is)> *) {
18+
constexpr auto dims = sizeof...(Is);
19+
20+
q.parallel_for<KernelName>(sycl::range<dims>{Is...}, [=](auto i) {
21+
static_assert(std::is_same<decltype(i), sycl::item<dims>>::value,
22+
"lambda arg type is unexpected");
23+
});
24+
}
25+
26+
template <typename KernelName, std::size_t... Is>
27+
void test_range_impl(sycl::queue q, std::index_sequence<Is...>,
28+
sycl::nd_range<sizeof...(Is)> *) {
29+
constexpr auto dims = sizeof...(Is);
30+
31+
sycl::nd_range<dims> ndr{sycl::range<dims>{Is...}, sycl::range<dims>{Is...}};
32+
q.parallel_for<KernelName>(ndr, [=](auto i) {
33+
static_assert(std::is_same<decltype(i), sycl::nd_item<dims>>::value,
34+
"lambda arg type is unexpected");
35+
});
36+
}
37+
38+
template <typename KernelName, template <int> class Range, std::size_t Dims>
39+
void test_range(sycl::queue q) {
40+
test_range_impl<KernelName>(q, std::make_index_sequence<Dims>{},
41+
static_cast<Range<Dims> *>(nullptr));
42+
}
43+
44+
void test_number_braced_init_list(sycl::queue q) {
45+
constexpr auto n = 1;
46+
q.parallel_for<class Number>(n, [=](auto i) {
47+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
48+
"lambda arg type is unexpected");
49+
});
50+
51+
q.parallel_for<class BracedInitList1>({n}, [=](auto i) {
52+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
53+
"lambda arg type is unexpected");
54+
});
55+
56+
q.parallel_for<class BracedInitList2>({n, n}, [=](auto i) {
57+
static_assert(std::is_same<decltype(i), sycl::item<2>>::value,
58+
"lambda arg type is unexpected");
59+
});
60+
61+
q.parallel_for<class BracedInitList3>({n, n, n}, [=](auto i) {
62+
static_assert(std::is_same<decltype(i), sycl::item<3>>::value,
63+
"lambda arg type is unexpected");
64+
});
65+
}
66+
67+
int main() {
68+
sycl::queue q{};
69+
70+
test_range<class test_range1, sycl::range, 1>(q);
71+
test_range<class test_range2, sycl::range, 2>(q);
72+
test_range<class test_range3, sycl::range, 3>(q);
73+
test_range<class test_nd_range1, sycl::nd_range, 1>(q);
74+
test_range<class test_nd_range2, sycl::nd_range, 2>(q);
75+
test_range<class test_nd_range3, sycl::nd_range, 3>(q);
76+
77+
test_number_braced_init_list(q);
78+
}

0 commit comments

Comments
 (0)