Skip to content

Commit 42641d2

Browse files
Revert "[SYCL] Remove deprecated shuffles from the sub-group class (#13236)"
This reverts commit e9b0e60.
1 parent 2336d02 commit 42641d2

File tree

5 files changed

+403
-0
lines changed

5 files changed

+403
-0
lines changed

sycl/include/sycl/sub_group.hpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,64 @@ struct sub_group {
210210
#endif
211211
}
212212

213+
template <typename T>
214+
using EnableIfIsScalarArithmetic =
215+
std::enable_if_t<sycl::detail::is_scalar_arithmetic<T>::value, T>;
216+
217+
/* --- one-input shuffles --- */
218+
/* indices in [0 , sub_group size) */
219+
template <typename T>
220+
__SYCL_DEPRECATED("Shuffles in the sub-group class are deprecated.")
221+
T shuffle(T x, id_type local_id) const {
222+
#ifdef __SYCL_DEVICE_ONLY__
223+
return sycl::detail::spirv::Shuffle(*this, x, local_id);
224+
#else
225+
(void)x;
226+
(void)local_id;
227+
throw sycl::exception(make_error_code(errc::feature_not_supported),
228+
"Sub-groups are not supported on host.");
229+
#endif
230+
}
231+
232+
template <typename T>
233+
__SYCL_DEPRECATED("Shuffles in the sub-group class are deprecated.")
234+
T shuffle_down(T x, uint32_t delta) const {
235+
#ifdef __SYCL_DEVICE_ONLY__
236+
return sycl::detail::spirv::ShuffleDown(*this, x, delta);
237+
#else
238+
(void)x;
239+
(void)delta;
240+
throw sycl::exception(make_error_code(errc::feature_not_supported),
241+
"Sub-groups are not supported on host.");
242+
#endif
243+
}
244+
245+
template <typename T>
246+
__SYCL_DEPRECATED("Shuffles in the sub-group class are deprecated.")
247+
T shuffle_up(T x, uint32_t delta) const {
248+
#ifdef __SYCL_DEVICE_ONLY__
249+
return sycl::detail::spirv::ShuffleUp(*this, x, delta);
250+
#else
251+
(void)x;
252+
(void)delta;
253+
throw sycl::exception(make_error_code(errc::feature_not_supported),
254+
"Sub-groups are not supported on host.");
255+
#endif
256+
}
257+
258+
template <typename T>
259+
__SYCL_DEPRECATED("Shuffles in the sub-group class are deprecated.")
260+
T shuffle_xor(T x, id_type value) const {
261+
#ifdef __SYCL_DEVICE_ONLY__
262+
return sycl::detail::spirv::ShuffleXor(*this, x, value);
263+
#else
264+
(void)x;
265+
(void)value;
266+
throw sycl::exception(make_error_code(errc::feature_not_supported),
267+
"Sub-groups are not supported on host.");
268+
#endif
269+
}
270+
213271
/* --- sub_group load/stores --- */
214272
/* these can map to SIMD or block read/write hardware where available */
215273
#ifdef __SYCL_DEVICE_ONLY__
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
2+
// RUN: %{run} %t.out
3+
//
4+
//==-- generic_shuffle.cpp - SYCL sub_group generic shuffle test *- C++ -*--==//
5+
//
6+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
7+
// See https://llvm.org/LICENSE.txt for license information.
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "helper.hpp"
13+
#include <algorithm>
14+
#include <complex>
15+
#include <sycl/sycl.hpp>
16+
#include <vector>
17+
template <typename T> class pointer_kernel;
18+
19+
using namespace sycl;
20+
21+
template <typename SpecializationKernelName, typename T>
22+
void check_pointer(queue &Queue, size_t G = 256, size_t L = 64) {
23+
try {
24+
nd_range<1> NdRange(G, L);
25+
buffer<T *> buf(G);
26+
buffer<T *> buf_up(G);
27+
buffer<T *> buf_down(G);
28+
buffer<T *> buf_xor(G);
29+
buffer<size_t> sgsizebuf(1);
30+
Queue.submit([&](handler &cgh) {
31+
auto acc = buf.template get_access<access::mode::read_write>(cgh);
32+
auto acc_up = buf_up.template get_access<access::mode::read_write>(cgh);
33+
auto acc_down =
34+
buf_down.template get_access<access::mode::read_write>(cgh);
35+
auto acc_xor = buf_xor.template get_access<access::mode::read_write>(cgh);
36+
auto sgsizeacc = sgsizebuf.get_access<access::mode::read_write>(cgh);
37+
38+
cgh.parallel_for<SpecializationKernelName>(
39+
NdRange, [=](nd_item<1> NdItem) {
40+
sycl::sub_group SG = NdItem.get_sub_group();
41+
uint32_t wggid = NdItem.get_global_id(0);
42+
uint32_t sgid = SG.get_group_id().get(0);
43+
if (wggid == 0)
44+
sgsizeacc[0] = SG.get_max_local_range()[0];
45+
46+
T *ptr = static_cast<T *>(0x0) + wggid;
47+
48+
/*GID of middle element in every subgroup*/
49+
acc[NdItem.get_global_id()] =
50+
SG.shuffle(ptr, SG.get_max_local_range()[0] / 2);
51+
52+
/* Save GID-SGID */
53+
acc_up[NdItem.get_global_id()] = SG.shuffle_up(ptr, sgid);
54+
55+
/* Save GID+SGID */
56+
acc_down[NdItem.get_global_id()] = SG.shuffle_down(ptr, sgid);
57+
58+
/* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
59+
acc_xor[NdItem.get_global_id()] =
60+
SG.shuffle_xor(ptr, sgid % SG.get_max_local_range()[0]);
61+
});
62+
});
63+
host_accessor acc(buf);
64+
host_accessor acc_up(buf_up);
65+
host_accessor acc_down(buf_down);
66+
host_accessor acc_xor(buf_xor);
67+
host_accessor sgsizeacc(sgsizebuf);
68+
69+
size_t sg_size = sgsizeacc[0];
70+
int SGid = 0;
71+
int SGLid = 0;
72+
int SGBeginGid = 0;
73+
for (int j = 0; j < G; j++) {
74+
if (j % L % sg_size == 0) {
75+
SGid++;
76+
SGLid = 0;
77+
SGBeginGid = j;
78+
}
79+
if (j % L == 0) {
80+
SGid = 0;
81+
SGLid = 0;
82+
SGBeginGid = j;
83+
}
84+
85+
/*GID of middle element in every subgroup*/
86+
exit_if_not_equal(acc[j],
87+
static_cast<T *>(0x0) +
88+
(j / L * L + SGid * sg_size + sg_size / 2),
89+
"shuffle");
90+
91+
/* Value GID+SGID for all element except last SGID in SG*/
92+
if (j % L % sg_size + SGid < sg_size && j % L + SGid < L) {
93+
exit_if_not_equal(acc_down[j], static_cast<T *>(0x0) + (j + SGid),
94+
"shuffle_down");
95+
}
96+
97+
/* Value GID-SGID for all element except first SGID in SG*/
98+
if (j % L % sg_size >= SGid) {
99+
exit_if_not_equal(acc_up[j], static_cast<T *>(0x0) + (j - SGid),
100+
"shuffle_up");
101+
}
102+
103+
/* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
104+
exit_if_not_equal(acc_xor[j],
105+
static_cast<T *>(0x0) +
106+
(SGBeginGid + (SGLid ^ (SGid % sg_size))),
107+
"shuffle_xor");
108+
SGLid++;
109+
}
110+
} catch (exception e) {
111+
std::cout << "SYCL exception caught: " << e.what();
112+
exit(1);
113+
}
114+
}
115+
116+
template <typename SpecializationKernelName, typename T, typename Generator>
117+
void check_struct(queue &Queue, Generator &Gen, size_t G = 256, size_t L = 64) {
118+
119+
// Fill a vector with values that will be shuffled
120+
std::vector<T> values(G);
121+
std::generate(values.begin(), values.end(), Gen);
122+
123+
try {
124+
nd_range<1> NdRange(G, L);
125+
buffer<T> buf(G);
126+
buffer<T> buf_up(G);
127+
buffer<T> buf_down(G);
128+
buffer<T> buf_xor(G);
129+
buffer<size_t> sgsizebuf(1);
130+
buffer<T> buf_in(values.data(), values.size());
131+
Queue.submit([&](handler &cgh) {
132+
auto acc = buf.template get_access<access::mode::read_write>(cgh);
133+
auto acc_up = buf_up.template get_access<access::mode::read_write>(cgh);
134+
auto acc_down =
135+
buf_down.template get_access<access::mode::read_write>(cgh);
136+
auto acc_xor = buf_xor.template get_access<access::mode::read_write>(cgh);
137+
auto sgsizeacc = sgsizebuf.get_access<access::mode::read_write>(cgh);
138+
auto in = buf_in.template get_access<access::mode::read>(cgh);
139+
140+
cgh.parallel_for<SpecializationKernelName>(
141+
NdRange, [=](nd_item<1> NdItem) {
142+
sycl::sub_group SG = NdItem.get_sub_group();
143+
uint32_t wggid = NdItem.get_global_id(0);
144+
uint32_t sgid = SG.get_group_id().get(0);
145+
if (wggid == 0)
146+
sgsizeacc[0] = SG.get_max_local_range()[0];
147+
148+
T val = in[wggid];
149+
150+
/*GID of middle element in every subgroup*/
151+
acc[NdItem.get_global_id()] =
152+
SG.shuffle(val, SG.get_max_local_range()[0] / 2);
153+
154+
/* Save GID-SGID */
155+
acc_up[NdItem.get_global_id()] = SG.shuffle_up(val, sgid);
156+
157+
/* Save GID+SGID */
158+
acc_down[NdItem.get_global_id()] = SG.shuffle_down(val, sgid);
159+
160+
/* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
161+
acc_xor[NdItem.get_global_id()] =
162+
SG.shuffle_xor(val, sgid % SG.get_max_local_range()[0]);
163+
});
164+
});
165+
host_accessor acc(buf);
166+
host_accessor acc_up(buf_up);
167+
host_accessor acc_down(buf_down);
168+
host_accessor acc_xor(buf_xor);
169+
host_accessor sgsizeacc(sgsizebuf);
170+
171+
size_t sg_size = sgsizeacc[0];
172+
int SGid = 0;
173+
int SGLid = 0;
174+
int SGBeginGid = 0;
175+
for (int j = 0; j < G; j++) {
176+
if (j % L % sg_size == 0) {
177+
SGid++;
178+
SGLid = 0;
179+
SGBeginGid = j;
180+
}
181+
if (j % L == 0) {
182+
SGid = 0;
183+
SGLid = 0;
184+
SGBeginGid = j;
185+
}
186+
187+
/*GID of middle element in every subgroup*/
188+
exit_if_not_equal(
189+
acc[j], values[j / L * L + SGid * sg_size + sg_size / 2], "shuffle");
190+
191+
/* Value GID+SGID for all element except last SGID in SG*/
192+
if (j % L % sg_size + SGid < sg_size && j % L + SGid < L) {
193+
exit_if_not_equal(acc_down[j], values[j + SGid], "shuffle_down");
194+
}
195+
196+
/* Value GID-SGID for all element except first SGID in SG*/
197+
if (j % L % sg_size >= SGid) {
198+
exit_if_not_equal(acc_up[j], values[j - SGid], "shuffle_up");
199+
}
200+
201+
/* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
202+
exit_if_not_equal(acc_xor[j],
203+
values[SGBeginGid + (SGLid ^ (SGid % sg_size))],
204+
"shuffle_xor");
205+
SGLid++;
206+
}
207+
} catch (exception e) {
208+
std::cout << "SYCL exception caught: " << e.what();
209+
exit(1);
210+
}
211+
}
212+
213+
int main() {
214+
queue Queue;
215+
216+
// Test shuffle of pointer types
217+
check_pointer<class KernelName_mNiN, int>(Queue);
218+
219+
// Test shuffle of non-native types
220+
auto ComplexFloatGenerator = [state = std::complex<float>(0, 1)]() mutable {
221+
return state += std::complex<float>(2, 2);
222+
};
223+
check_struct<class KernelName_zHfIPOLOFsXiZiCvG, std::complex<float>>(
224+
Queue, ComplexFloatGenerator);
225+
226+
if (Queue.get_device().has(sycl::aspect::fp64)) {
227+
auto ComplexDoubleGenerator = [state =
228+
std::complex<double>(0, 1)]() mutable {
229+
return state += std::complex<double>(2, 2);
230+
};
231+
check_struct<class KernelName_CjlHUmnuxWtyejZFD, std::complex<double>>(
232+
Queue, ComplexDoubleGenerator);
233+
} else {
234+
std::cout << "fp64 tests were skipped due to the device not supporting the "
235+
"aspect.";
236+
}
237+
238+
std::cout << "Test passed." << std::endl;
239+
return 0;
240+
}

sycl/test-e2e/SubGroup/shuffle.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
//==------------ shuffle.cpp - SYCL sub_group shuffle test -----*- C++ -*---==//
5+
//
6+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
7+
// See https://llvm.org/LICENSE.txt for license information.
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "shuffle.hpp"
13+
#include <iostream>
14+
15+
int main() {
16+
queue Queue;
17+
check<short>(Queue);
18+
check<unsigned short>(Queue);
19+
check<int>(Queue);
20+
check<int, 2>(Queue);
21+
check<int, 4>(Queue);
22+
check<int, 8>(Queue);
23+
check<int, 16>(Queue);
24+
check<unsigned int>(Queue);
25+
check<unsigned int, 2>(Queue);
26+
check<unsigned int, 4>(Queue);
27+
check<unsigned int, 8>(Queue);
28+
check<unsigned int, 16>(Queue);
29+
check<long>(Queue);
30+
check<long, 2>(Queue);
31+
check<long, 4>(Queue);
32+
check<long, 8>(Queue);
33+
check<long, 16>(Queue);
34+
check<unsigned long>(Queue);
35+
check<unsigned long, 2>(Queue);
36+
check<unsigned long, 4>(Queue);
37+
check<unsigned long, 8>(Queue);
38+
check<unsigned long, 16>(Queue);
39+
check<float>(Queue);
40+
check<float, 2>(Queue);
41+
check<float, 4>(Queue);
42+
check<float, 8>(Queue);
43+
check<float, 16>(Queue);
44+
45+
// Check long long and unsigned long long because they differ from
46+
// long and unsigned long according to C++ rules even if they have the same
47+
// size at some system.
48+
check<long long>(Queue);
49+
check<long long, 16>(Queue);
50+
check<unsigned long long>(Queue);
51+
check<unsigned long long, 16>(Queue);
52+
std::cout << "Test passed." << std::endl;
53+
return 0;
54+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// REQUIRES: aspect-fp16
2+
// REQUIRES: gpu
3+
4+
// RUN: %{build} -o %t.out
5+
// RUN: %{run} %t.out
6+
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "shuffle.hpp"
15+
#include <iostream>
16+
17+
// Test driver: invokes check<half[, N]> from shuffle.hpp without a second
// template argument and with 2, 4, 8 and 16, then prints "Test passed.".
int main() {
  queue Queue;

  // Plain half first.
  check<half>(Queue);

  // Then each explicit second template argument, in ascending order.
  check<half, 2>(Queue);
  check<half, 4>(Queue);
  check<half, 8>(Queue);
  check<half, 16>(Queue);

  // Newline followed by an explicit flush, i.e. what std::endl does.
  std::cout << "Test passed." << '\n' << std::flush;
  return 0;
}

0 commit comments

Comments
 (0)