Skip to content

Commit aa31c1e

Browse files
authored
[SYCL] Fix corner case when using short or char with exclusive scan (#10270)
1 parent 4c6d8e0 commit aa31c1e

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed

sycl/include/sycl/group_algorithm.hpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,13 @@ template <template <typename> typename F> struct get_scalar_binary_op<F<void>> {
158158
using type = F<void>;
159159
};
160160

161+
// ---- is_max_or_min
162+
template <typename T> struct is_max_or_min : std::false_type {};
163+
template <typename T>
164+
struct is_max_or_min<sycl::maximum<T>> : std::true_type {};
165+
template <typename T>
166+
struct is_max_or_min<sycl::minimum<T>> : std::true_type {};
167+
161168
// ---- identity_for_ga_op
162169
// the group algorithms support std::complex, limited to sycl::plus operation
163170
// get the correct identity for group algorithm operation.
@@ -678,8 +685,34 @@ exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
678685
sycl::detail::ExtractMask(sycl::detail::GetMask(g))[0]);
679686
}
680687
#endif
681-
return sycl::detail::calc<__spv::GroupOperation::ExclusiveScan>(
688+
// For the first work item in the group, we cannot return the result
689+
// of calc when T is a signed char or short type and the
690+
// BinaryOperation is maximum or minimum. calc uses SPIRV group
691+
// collective instructions, which only operate on 32 or 64 bit
692+
// integers. So, when using calc with a short or char type, the
693+
// argument is converted to a 32 bit integer, the 32 bit group
694+
// operation is performed, and then converted back to the original
695+
// short or char type. For an exclusive scan, the first work item
696+
// returns the identity for the supplied operation. However, the
697+
// identity of a 32 bit signed integer maximum or minimum when
698+
// converted to a signed char or short does not correspond to the
699+
// identity of a signed char or short maximum or minimum. For
700+
// example, the identity of a signed 32 bit maximum is
701+
// INT_MIN=-2**31, and when converted to a signed char, results in
702+
// 0. However, the identity of a signed char maximum is
703+
// SCHAR_MIN=-2**7. Therefore, we need the following check to
704+
// circumvent this issue.
705+
auto res = sycl::detail::calc<__spv::GroupOperation::ExclusiveScan>(
682706
g, typename sycl::detail::GroupOpTag<T>::type(), x, binary_op);
707+
if constexpr ((std::is_same_v<signed char, T> ||
708+
std::is_same_v<signed short, T> ||
709+
(std::is_signed_v<char> && std::is_same_v<char, T>)) &&
710+
detail::is_max_or_min<BinaryOperation>::value) {
711+
auto local_id = sycl::detail::get_local_linear_id(g);
712+
if (local_id == 0)
713+
return sycl::known_identity_v<BinaryOperation, T>;
714+
}
715+
return res;
683716
#else
684717
(void)g;
685718
throw sycl::exception(make_error_code(errc::feature_not_supported),
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
// This test ensures that exclusive_scan_over_group returns the correct
5+
// result for the first work item when given a short or char argument
6+
// with the maximum or minimum operator.
7+
#include <numeric>
8+
#include <sycl/sycl.hpp>
9+
10+
using namespace sycl;
11+
queue q;
12+
int cur_test = 0;
13+
int n_fail = 0;
14+
15+
template <typename T, typename OpT> void test() {
16+
auto op = OpT();
17+
auto init = sycl::known_identity_v<decltype(op), T>;
18+
auto *p = malloc_shared<T>(1, q);
19+
*p = 0;
20+
T ref;
21+
std::exclusive_scan(p, p + 1, &ref, init, op);
22+
range r(1);
23+
q.parallel_for(nd_range(r, r), [=](nd_item<1> it) {
24+
auto g = it.get_group();
25+
*p = exclusive_scan_over_group(g, *p, op);
26+
}).wait();
27+
28+
if (*p != ref) {
29+
std::cout << "test " << cur_test << " fail\n";
30+
std::cout << "got: " << int(*p) << "\n";
31+
std::cout << "expected: " << int(ref) << "\n\n";
32+
++n_fail;
33+
}
34+
++cur_test;
35+
free(p, q);
36+
}
37+
38+
int main() {
39+
test<char, sycl::maximum<char>>();
40+
test<signed char, sycl::maximum<signed char>>();
41+
test<unsigned char, sycl::maximum<unsigned char>>();
42+
test<char, sycl::maximum<void>>();
43+
test<signed char, sycl::maximum<void>>();
44+
test<unsigned char, sycl::maximum<void>>();
45+
test<short, sycl::maximum<short>>();
46+
test<unsigned short, sycl::maximum<unsigned short>>();
47+
test<short, sycl::maximum<void>>();
48+
test<unsigned short, sycl::maximum<void>>();
49+
50+
test<char, sycl::minimum<char>>();
51+
test<signed char, sycl::minimum<signed char>>();
52+
test<unsigned char, sycl::minimum<unsigned char>>();
53+
test<char, sycl::minimum<void>>();
54+
test<signed char, sycl::minimum<void>>();
55+
test<unsigned char, sycl::minimum<void>>();
56+
test<short, sycl::minimum<short>>();
57+
test<unsigned short, sycl::minimum<unsigned short>>();
58+
test<short, sycl::minimum<void>>();
59+
test<unsigned short, sycl::minimum<void>>();
60+
return n_fail != 0;
61+
}

0 commit comments

Comments
 (0)