Skip to content

Commit 5289d50

Browse files
committed
[SYCL] Make joint_reduce work with sub_group
Note: the unqualified name lookup of joint_reduce in the overload of joint_reduce without an init param was not finding the overload of joint_reduce with an init param (because that declaration was located after it), so it searched for joint_reduce via ADL. With sycl::group, ADL can find both overloads of joint_reduce, but sycl::sub_group = sycl::ext::oneapi::sub_group, ADL finds no joint_reduce in sycl::ext::oneapi. Signed-off-by: Cai, Justin <[email protected]>
1 parent afebb25 commit 5289d50

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

sycl/include/sycl/group_algorithm.hpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -315,29 +315,6 @@ reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
315315
}
316316

317317
// ---- joint_reduce
318-
template <typename Group, typename Ptr, class BinaryOperation>
319-
detail::enable_if_t<
320-
(is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
321-
detail::is_arithmetic_or_complex<
322-
typename detail::remove_pointer<Ptr>::type>::value &&
323-
detail::is_plus_or_multiplies_if_complex<
324-
typename detail::remove_pointer<Ptr>::type, BinaryOperation>::value),
325-
typename detail::remove_pointer<Ptr>::type>
326-
joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) {
327-
#ifdef __SYCL_DEVICE_ONLY__
328-
using T = typename detail::remove_pointer<Ptr>::type;
329-
T init = detail::identity_for_ga_op<T, BinaryOperation>();
330-
return joint_reduce(g, first, last, init, binary_op);
331-
#else
332-
(void)g;
333-
(void)first;
334-
(void)last;
335-
(void)binary_op;
336-
throw runtime_error("Group algorithms are not supported on host.",
337-
PI_ERROR_INVALID_DEVICE);
338-
#endif
339-
}
340-
341318
template <typename Group, typename Ptr, typename T, class BinaryOperation>
342319
detail::enable_if_t<
343320
(is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
@@ -373,6 +350,29 @@ joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
373350
#endif
374351
}
375352

353+
template <typename Group, typename Ptr, class BinaryOperation>
354+
detail::enable_if_t<
355+
(is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
356+
detail::is_arithmetic_or_complex<
357+
typename detail::remove_pointer<Ptr>::type>::value &&
358+
detail::is_plus_or_multiplies_if_complex<
359+
typename detail::remove_pointer<Ptr>::type, BinaryOperation>::value),
360+
typename detail::remove_pointer<Ptr>::type>
361+
joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) {
362+
#ifdef __SYCL_DEVICE_ONLY__
363+
using T = typename detail::remove_pointer<Ptr>::type;
364+
T init = detail::identity_for_ga_op<T, BinaryOperation>();
365+
return joint_reduce(g, first, last, init, binary_op);
366+
#else
367+
(void)g;
368+
(void)first;
369+
(void)last;
370+
(void)binary_op;
371+
throw runtime_error("Group algorithms are not supported on host.",
372+
PI_ERROR_INVALID_DEVICE);
373+
#endif
374+
}
375+
376376
// ---- any_of_group
377377
template <typename Group>
378378
detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>

0 commit comments

Comments
 (0)