Skip to content

Commit 3024161

Browse files
[SYCL][Reduction] Fix range deductions in reduction parallel_for (#9813)
Currently, the `parallel_for` overloads taking reduction variables take the dimensionality as a template argument. This means the range deduction guides cannot be used. This commit fixes that by splitting each overload into separate function definitions, one per valid dimensionality. This adheres to [4.9.4.2.2. parallel_for invoke](https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_parallel_for_invoke): > The parallel_for overload without an offset can be called with either a number or a braced-init-list with 1-3 elements. In that case the following calls are equivalent: > - parallel_for(N, some_kernel) has same effect as parallel_for(range<1>(N), some_kernel) > - parallel_for({N}, some_kernel) has same effect as parallel_for(range<1>(N), some_kernel) > - parallel_for({N1, N2}, some_kernel) has same effect as parallel_for(range<2>(N1, N2), some_kernel) > - parallel_for({N1, N2, N3}, some_kernel) has same effect as parallel_for(range<3>(N1, N2, N3), some_kernel) --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 6d2a46a commit 3024161

File tree

2 files changed

+189
-5
lines changed

2 files changed

+189
-5
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2125,21 +2125,58 @@ class __SYCL_EXPORT handler {
21252125

21262126
/// Reductions @{
21272127

2128-
template <typename KernelName = detail::auto_name, int Dims,
2129-
typename PropertiesT, typename... RestT>
2128+
template <typename KernelName = detail::auto_name, typename PropertiesT,
2129+
typename... RestT>
21302130
std::enable_if_t<
21312131
(sizeof...(RestT) > 1) &&
21322132
detail::AreAllButLastReductions<RestT...>::value &&
21332133
ext::oneapi::experimental::is_property_list<PropertiesT>::value>
2134-
parallel_for(range<Dims> Range, PropertiesT Properties, RestT &&...Rest) {
2134+
parallel_for(range<1> Range, PropertiesT Properties, RestT &&...Rest) {
21352135
detail::reduction_parallel_for<KernelName>(*this, Range, Properties,
21362136
std::forward<RestT>(Rest)...);
21372137
}
21382138

2139-
template <typename KernelName = detail::auto_name, int Dims,
2139+
/// Reduction parallel_for over a two-dimensional range with an explicit
/// property list.
///
/// Participates in overload resolution only when there are at least two
/// trailing arguments, detail::AreAllButLastReductions holds for them, and
/// PropertiesT is a property list. The range parameter is the concrete type
/// range<2> rather than range<Dims> with a deduced Dims, so a two-element
/// braced-init-list can be passed as Range.
///
/// \param Range is the iteration range.
/// \param Properties is the property list handed on to the implementation.
/// \param Rest is forwarded unchanged to detail::reduction_parallel_for
///        (going by the trait name, reductions followed by the kernel
///        function last).
template <typename KernelName = detail::auto_name, typename PropertiesT,
21402140
typename... RestT>
2141+
std::enable_if_t<
2142+
(sizeof...(RestT) > 1) &&
2143+
detail::AreAllButLastReductions<RestT...>::value &&
2144+
ext::oneapi::experimental::is_property_list<PropertiesT>::value>
2145+
parallel_for(range<2> Range, PropertiesT Properties, RestT &&...Rest) {
2146+
detail::reduction_parallel_for<KernelName>(*this, Range, Properties,
2147+
std::forward<RestT>(Rest)...);
2148+
}
2149+
2150+
/// Reduction parallel_for over a three-dimensional range with an explicit
/// property list.
///
/// Participates in overload resolution only when there are at least two
/// trailing arguments, detail::AreAllButLastReductions holds for them, and
/// PropertiesT is a property list. The range parameter is the concrete type
/// range<3> rather than range<Dims> with a deduced Dims, so a three-element
/// braced-init-list can be passed as Range.
///
/// \param Range is the iteration range.
/// \param Properties is the property list handed on to the implementation.
/// \param Rest is forwarded unchanged to detail::reduction_parallel_for
///        (going by the trait name, reductions followed by the kernel
///        function last).
template <typename KernelName = detail::auto_name, typename PropertiesT,
2151+
typename... RestT>
2152+
std::enable_if_t<
2153+
(sizeof...(RestT) > 1) &&
2154+
detail::AreAllButLastReductions<RestT...>::value &&
2155+
ext::oneapi::experimental::is_property_list<PropertiesT>::value>
2156+
parallel_for(range<3> Range, PropertiesT Properties, RestT &&...Rest) {
2157+
detail::reduction_parallel_for<KernelName>(*this, Range, Properties,
2158+
std::forward<RestT>(Rest)...);
2159+
}
2160+
2161+
/// Reduction parallel_for over a one-dimensional range without a property
/// list.
///
/// Enabled when detail::AreAllButLastReductions holds for the trailing
/// arguments. Delegates to the parallel_for overload that takes properties,
/// supplying a default-constructed empty_properties_t. The range parameter is
/// the concrete type range<1> rather than range<Dims> with a deduced Dims, so
/// an integer or a one-element braced-init-list can be passed as Range.
///
/// \param Range is the iteration range.
/// \param Rest is forwarded unchanged to the delegated overload.
template <typename KernelName = detail::auto_name, typename... RestT>
2162+
std::enable_if_t<detail::AreAllButLastReductions<RestT...>::value>
2163+
parallel_for(range<1> Range, RestT &&...Rest) {
2164+
parallel_for<KernelName>(
2165+
Range, ext::oneapi::experimental::detail::empty_properties_t{},
2166+
std::forward<RestT>(Rest)...);
2167+
}
2168+
2169+
/// Reduction parallel_for over a two-dimensional range without a property
/// list.
///
/// Enabled when detail::AreAllButLastReductions holds for the trailing
/// arguments. Delegates to the parallel_for overload that takes properties,
/// supplying a default-constructed empty_properties_t. The range parameter is
/// the concrete type range<2> rather than range<Dims> with a deduced Dims, so
/// a two-element braced-init-list can be passed as Range.
///
/// \param Range is the iteration range.
/// \param Rest is forwarded unchanged to the delegated overload.
template <typename KernelName = detail::auto_name, typename... RestT>
2170+
std::enable_if_t<detail::AreAllButLastReductions<RestT...>::value>
2171+
parallel_for(range<2> Range, RestT &&...Rest) {
2172+
parallel_for<KernelName>(
2173+
Range, ext::oneapi::experimental::detail::empty_properties_t{},
2174+
std::forward<RestT>(Rest)...);
2175+
}
2176+
2177+
template <typename KernelName = detail::auto_name, typename... RestT>
21412178
std::enable_if_t<detail::AreAllButLastReductions<RestT...>::value>
2142-
parallel_for(range<Dims> Range, RestT &&...Rest) {
2179+
parallel_for(range<3> Range, RestT &&...Rest) {
21432180
parallel_for<KernelName>(
21442181
Range, ext::oneapi::experimental::detail::empty_properties_t{},
21452182
std::forward<RestT>(Rest)...);
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// RUN: %clangxx -fsycl -fsyntax-only %s
2+
3+
// Tests reduction parallel_for can use SYCL 2020 range deduction guides.
4+
5+
#include <sycl/sycl.hpp>
6+
7+
// User-defined binary addition functor: operator() returns A + B.
// Used in this file as the combiner for ScalarRed2 and SpanRed2, alongside the
// std::plus-based reductions.
// NOTE(review): going by the name, this is meant to be an operator the
// implementation has no known identity value for -- confirm.
template <class T> struct PlusWithoutIdentity {
8+
T operator()(const T &A, const T &B) const { return A + B; }
9+
};
10+
11+
// Calls reduction parallel_for with the range given as a plain integer or as
// a braced-init-list of one, two or three elements. Each form is exercised
// through the queue shortcut (Q.parallel_for) and through a handler inside
// Q.submit (CGH.parallel_for), with one, two and four reductions built over
// scalar and span USM allocations. All kernel bodies are empty.
// NOTE(review): ScalarMem and SpanMem are never freed; this is harmless only
// if the file is compiled with -fsyntax-only, as its RUN line indicates.
int main() {
12+
sycl::queue Q;
13+
14+
int *ScalarMem = sycl::malloc_shared<int>(1, Q);
15+
int *SpanMem = sycl::malloc_shared<int>(8, Q);
16+
auto ScalarRed1 = sycl::reduction(ScalarMem, std::plus<int>{});
17+
auto ScalarRed2 = sycl::reduction(ScalarMem, PlusWithoutIdentity<int>{});
18+
auto SpanRed1 =
19+
sycl::reduction(sycl::span<int, 8>{SpanMem, 8}, std::plus<int>{});
20+
auto SpanRed2 = sycl::reduction(sycl::span<int, 8>{SpanMem, 8},
21+
PlusWithoutIdentity<int>{});
22+
23+
// Shortcut and range<1> deduction from integer.
24+
Q.parallel_for(1024, ScalarRed1, [=](sycl::item<1>, auto &) {});
25+
Q.parallel_for(1024, SpanRed1, [=](sycl::item<1>, auto &) {});
26+
Q.parallel_for(1024, ScalarRed1, ScalarRed2,
27+
[=](sycl::item<1>, auto &, auto &) {});
28+
Q.parallel_for(1024, SpanRed1, SpanRed2,
29+
[=](sycl::item<1>, auto &, auto &) {});
30+
Q.parallel_for(1024, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2,
31+
[=](sycl::item<1>, auto &, auto &, auto &, auto &) {});
32+
33+
// Shortcut and range<1> deduction from initializer.
34+
Q.parallel_for({1024}, ScalarRed1, [=](sycl::item<1>, auto &) {});
35+
Q.parallel_for({1024}, SpanRed1, [=](sycl::item<1>, auto &) {});
36+
Q.parallel_for({1024}, ScalarRed1, ScalarRed2,
37+
[=](sycl::item<1>, auto &, auto &) {});
38+
Q.parallel_for({1024}, SpanRed1, SpanRed2,
39+
[=](sycl::item<1>, auto &, auto &) {});
40+
Q.parallel_for({1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2,
41+
[=](sycl::item<1>, auto &, auto &, auto &, auto &) {});
42+
43+
// Shortcut and range<2> deduction from initializer.
44+
Q.parallel_for({1024, 1024}, ScalarRed1, [=](sycl::item<2>, auto &) {});
45+
Q.parallel_for({1024, 1024}, SpanRed1, [=](sycl::item<2>, auto &) {});
46+
Q.parallel_for({1024, 1024}, ScalarRed1, ScalarRed2,
47+
[=](sycl::item<2>, auto &, auto &) {});
48+
Q.parallel_for({1024, 1024}, SpanRed1, SpanRed2,
49+
[=](sycl::item<2>, auto &, auto &) {});
50+
Q.parallel_for({1024, 1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2,
51+
[=](sycl::item<2>, auto &, auto &, auto &, auto &) {});
52+
53+
// Shortcut and range<3> deduction from initializer.
54+
Q.parallel_for({1024, 1024, 1024}, ScalarRed1, [=](sycl::item<3>, auto &) {});
55+
Q.parallel_for({1024, 1024, 1024}, SpanRed1, [=](sycl::item<3>, auto &) {});
56+
Q.parallel_for({1024, 1024, 1024}, ScalarRed1, ScalarRed2,
57+
[=](sycl::item<3>, auto &, auto &) {});
58+
Q.parallel_for({1024, 1024, 1024}, SpanRed1, SpanRed2,
59+
[=](sycl::item<3>, auto &, auto &) {});
60+
Q.parallel_for({1024, 1024, 1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2,
61+
[=](sycl::item<3>, auto &, auto &, auto &, auto &) {});
62+
63+
// Submission and range<1> deduction from integer.
64+
Q.submit([&](sycl::handler &CGH) {
65+
CGH.parallel_for(1024, ScalarRed1, [=](sycl::item<1>, auto &) {});
66+
});
67+
Q.submit([&](sycl::handler &CGH) {
68+
CGH.parallel_for(1024, SpanRed1, [=](sycl::item<1>, auto &) {});
69+
});
70+
Q.submit([&](sycl::handler &CGH) {
71+
CGH.parallel_for(1024, ScalarRed1, ScalarRed2,
72+
[=](sycl::item<1>, auto &, auto &) {});
73+
});
74+
Q.submit([&](sycl::handler &CGH) {
75+
CGH.parallel_for(1024, SpanRed1, SpanRed2,
76+
[=](sycl::item<1>, auto &, auto &) {});
77+
});
78+
Q.submit([&](sycl::handler &CGH) {
79+
CGH.parallel_for(1024, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2,
80+
[=](sycl::item<1>, auto &, auto &, auto &, auto &) {});
81+
});
82+
83+
// Submission and range<1> deduction from initializer.
84+
Q.submit([&](sycl::handler &CGH) {
85+
CGH.parallel_for({1024}, ScalarRed1, [=](sycl::item<1>, auto &) {});
86+
});
87+
Q.submit([&](sycl::handler &CGH) {
88+
CGH.parallel_for({1024}, SpanRed1, [=](sycl::item<1>, auto &) {});
89+
});
90+
Q.submit([&](sycl::handler &CGH) {
91+
CGH.parallel_for({1024}, ScalarRed1, ScalarRed2,
92+
[=](sycl::item<1>, auto &, auto &) {});
93+
});
94+
Q.submit([&](sycl::handler &CGH) {
95+
CGH.parallel_for({1024}, SpanRed1, SpanRed2,
96+
[=](sycl::item<1>, auto &, auto &) {});
97+
});
98+
Q.submit([&](sycl::handler &CGH) {
99+
CGH.parallel_for({1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2,
100+
[=](sycl::item<1>, auto &, auto &, auto &, auto &) {});
101+
});
102+
103+
// Submission and range<2> deduction from initializer.
104+
Q.submit([&](sycl::handler &CGH) {
105+
CGH.parallel_for({1024, 1024}, ScalarRed1, [=](sycl::item<2>, auto &) {});
106+
});
107+
Q.submit([&](sycl::handler &CGH) {
108+
CGH.parallel_for({1024, 1024}, SpanRed1, [=](sycl::item<2>, auto &) {});
109+
});
110+
Q.submit([&](sycl::handler &CGH) {
111+
CGH.parallel_for({1024, 1024}, ScalarRed1, ScalarRed2,
112+
[=](sycl::item<2>, auto &, auto &) {});
113+
});
114+
Q.submit([&](sycl::handler &CGH) {
115+
CGH.parallel_for({1024, 1024}, SpanRed1, SpanRed2,
116+
[=](sycl::item<2>, auto &, auto &) {});
117+
});
118+
Q.submit([&](sycl::handler &CGH) {
119+
CGH.parallel_for({1024, 1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2,
120+
[=](sycl::item<2>, auto &, auto &, auto &, auto &) {});
121+
});
122+
123+
// Submission and range<3> deduction from initializer.
124+
Q.submit([&](sycl::handler &CGH) {
125+
CGH.parallel_for({1024, 1024, 1024}, ScalarRed1,
126+
[=](sycl::item<3>, auto &) {});
127+
});
128+
Q.submit([&](sycl::handler &CGH) {
129+
CGH.parallel_for({1024, 1024, 1024}, SpanRed1,
130+
[=](sycl::item<3>, auto &) {});
131+
});
132+
Q.submit([&](sycl::handler &CGH) {
133+
CGH.parallel_for({1024, 1024, 1024}, ScalarRed1, ScalarRed2,
134+
[=](sycl::item<3>, auto &, auto &) {});
135+
});
136+
Q.submit([&](sycl::handler &CGH) {
137+
CGH.parallel_for({1024, 1024, 1024}, SpanRed1, SpanRed2,
138+
[=](sycl::item<3>, auto &, auto &) {});
139+
});
140+
Q.submit([&](sycl::handler &CGH) {
141+
CGH.parallel_for({1024, 1024, 1024}, ScalarRed1, SpanRed1, ScalarRed2,
142+
SpanRed2,
143+
[=](sycl::item<3>, auto &, auto &, auto &, auto &) {});
144+
});
145+
146+
return 0;
147+
}

0 commit comments

Comments
 (0)