Skip to content

Revert "[SYCL] Remove deprecated shuffles from the sub-group class" #13463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions sycl/include/sycl/sub_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,64 @@ struct sub_group {
#endif
}

template <typename T>
using EnableIfIsScalarArithmetic =
std::enable_if_t<sycl::detail::is_scalar_arithmetic<T>::value, T>;

/* --- one-input shuffles --- */
/* indices in [0 , sub_group size) */
/// Deprecated one-input sub-group shuffle.
///
/// In device compilation (__SYCL_DEVICE_ONLY__ defined) this forwards x and
/// local_id to sycl::detail::spirv::Shuffle and returns its result. In host
/// compilation it throws sycl::exception with errc::feature_not_supported.
template <typename T>
__SYCL_DEPRECATED("Shuffles in the sub-group class are deprecated.")
T shuffle(T x, id_type local_id) const {
#ifndef __SYCL_DEVICE_ONLY__
  // Host path: the arguments are deliberately unused.
  (void)x;
  (void)local_id;
  throw sycl::exception(make_error_code(errc::feature_not_supported),
                        "Sub-groups are not supported on host.");
#else
  return sycl::detail::spirv::Shuffle(*this, x, local_id);
#endif
}

/// Deprecated sub-group shuffle-down.
///
/// In device compilation (__SYCL_DEVICE_ONLY__ defined) this forwards x and
/// delta to sycl::detail::spirv::ShuffleDown and returns its result. In host
/// compilation it throws sycl::exception with errc::feature_not_supported.
template <typename T>
__SYCL_DEPRECATED("Shuffles in the sub-group class are deprecated.")
T shuffle_down(T x, uint32_t delta) const {
#ifndef __SYCL_DEVICE_ONLY__
  // Host path: the arguments are deliberately unused.
  (void)x;
  (void)delta;
  throw sycl::exception(make_error_code(errc::feature_not_supported),
                        "Sub-groups are not supported on host.");
#else
  return sycl::detail::spirv::ShuffleDown(*this, x, delta);
#endif
}

/// Deprecated sub-group shuffle-up.
///
/// In device compilation (__SYCL_DEVICE_ONLY__ defined) this forwards x and
/// delta to sycl::detail::spirv::ShuffleUp and returns its result. In host
/// compilation it throws sycl::exception with errc::feature_not_supported.
template <typename T>
__SYCL_DEPRECATED("Shuffles in the sub-group class are deprecated.")
T shuffle_up(T x, uint32_t delta) const {
#ifndef __SYCL_DEVICE_ONLY__
  // Host path: the arguments are deliberately unused.
  (void)x;
  (void)delta;
  throw sycl::exception(make_error_code(errc::feature_not_supported),
                        "Sub-groups are not supported on host.");
#else
  return sycl::detail::spirv::ShuffleUp(*this, x, delta);
#endif
}

/// Deprecated sub-group shuffle-xor.
///
/// In device compilation (__SYCL_DEVICE_ONLY__ defined) this forwards x and
/// value to sycl::detail::spirv::ShuffleXor and returns its result. In host
/// compilation it throws sycl::exception with errc::feature_not_supported.
template <typename T>
__SYCL_DEPRECATED("Shuffles in the sub-group class are deprecated.")
T shuffle_xor(T x, id_type value) const {
#ifndef __SYCL_DEVICE_ONLY__
  // Host path: the arguments are deliberately unused.
  (void)x;
  (void)value;
  throw sycl::exception(make_error_code(errc::feature_not_supported),
                        "Sub-groups are not supported on host.");
#else
  return sycl::detail::spirv::ShuffleXor(*this, x, value);
#endif
}

/* --- sub_group load/stores --- */
/* these can map to SIMD or block read/write hardware where available */
#ifdef __SYCL_DEVICE_ONLY__
Expand Down
240 changes: 240 additions & 0 deletions sycl/test-e2e/SubGroup/generic-shuffle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{run} %t.out
//
//==-- generic-shuffle.cpp - SYCL sub_group generic shuffle test *- C++ -*--==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "helper.hpp"
#include <algorithm>
#include <complex>
#include <sycl/sycl.hpp>
#include <vector>
template <typename T> class pointer_kernel;

using namespace sycl;

/// Exercises the sub_group shuffle, shuffle_up, shuffle_down and shuffle_xor
/// members with pointer-typed values and checks the results on the host.
///
/// A 1-D kernel is launched over nd_range(G, L). Each work-item forms the
/// pointer value `null + global id`, shuffles it four ways and stores the
/// results; work-item 0 additionally records SG.get_max_local_range()[0].
/// The host then walks all G entries and compares them against the expected
/// pointers through exit_if_not_equal (declared in helper.hpp; presumably it
/// terminates the test on a mismatch -- confirm there).
///
/// \tparam SpecializationKernelName  Kernel name type for parallel_for.
/// \tparam T                         Pointee type of the shuffled pointers.
/// \param Queue  Queue the kernel is submitted to.
/// \param G      Global range (number of work-items).
/// \param L      Local range (work-group size).
///
/// A SYCL exception is reported on std::cout and the process exits with
/// status 1.
template <typename SpecializationKernelName, typename T>
void check_pointer(queue &Queue, size_t G = 256, size_t L = 64) {
  try {
    nd_range<1> NdRange(G, L);
    buffer<T *> buf(G);
    buffer<T *> buf_up(G);
    buffer<T *> buf_down(G);
    buffer<T *> buf_xor(G);
    buffer<size_t> sgsizebuf(1);
    Queue.submit([&](handler &cgh) {
      auto acc = buf.template get_access<access::mode::read_write>(cgh);
      auto acc_up = buf_up.template get_access<access::mode::read_write>(cgh);
      auto acc_down =
          buf_down.template get_access<access::mode::read_write>(cgh);
      auto acc_xor = buf_xor.template get_access<access::mode::read_write>(cgh);
      auto sgsizeacc = sgsizebuf.get_access<access::mode::read_write>(cgh);

      cgh.parallel_for<SpecializationKernelName>(
          NdRange, [=](nd_item<1> NdItem) {
            sycl::sub_group SG = NdItem.get_sub_group();
            uint32_t wggid = NdItem.get_global_id(0);
            uint32_t sgid = SG.get_group_id().get(0);
            // Only one work-item needs to report the sub-group size.
            if (wggid == 0)
              sgsizeacc[0] = SG.get_max_local_range()[0];

            // Encode the global id as a pointer value; it is never
            // dereferenced.
            T *ptr = static_cast<T *>(0x0) + wggid;

            /* GID of middle element in every subgroup */
            acc[NdItem.get_global_id()] =
                SG.shuffle(ptr, SG.get_max_local_range()[0] / 2);

            /* Save GID-SGID */
            acc_up[NdItem.get_global_id()] = SG.shuffle_up(ptr, sgid);

            /* Save GID+SGID */
            acc_down[NdItem.get_global_id()] = SG.shuffle_down(ptr, sgid);

            /* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
            acc_xor[NdItem.get_global_id()] =
                SG.shuffle_xor(ptr, sgid % SG.get_max_local_range()[0]);
          });
    });
    host_accessor acc(buf);
    host_accessor acc_up(buf_up);
    host_accessor acc_down(buf_down);
    host_accessor acc_xor(buf_xor);
    host_accessor sgsizeacc(sgsizebuf);

    const size_t sg_size = sgsizeacc[0];
    // All counters are unsigned so that they compare and combine with the
    // size_t quantities G, L and sg_size without sign conversions. They never
    // go negative: the only subtraction below is guarded by its condition.
    size_t SGid = 0;       // sub-group index inside the current work-group
    size_t SGLid = 0;      // local index inside the current sub-group
    size_t SGBeginGid = 0; // global id of the first item of the sub-group
    for (size_t j = 0; j < G; j++) {
      // A new sub-group starts here ...
      if (j % L % sg_size == 0) {
        SGid++;
        SGLid = 0;
        SGBeginGid = j;
      }
      // ... and a new work-group restarts sub-group numbering at zero.
      if (j % L == 0) {
        SGid = 0;
        SGLid = 0;
        SGBeginGid = j;
      }

      /* GID of middle element in every subgroup */
      exit_if_not_equal(acc[j],
                        static_cast<T *>(0x0) +
                            (j / L * L + SGid * sg_size + sg_size / 2),
                        "shuffle");

      /* Value GID+SGID for all element except last SGID in SG */
      if (j % L % sg_size + SGid < sg_size && j % L + SGid < L) {
        exit_if_not_equal(acc_down[j], static_cast<T *>(0x0) + (j + SGid),
                          "shuffle_down");
      }

      /* Value GID-SGID for all element except first SGID in SG */
      if (j % L % sg_size >= SGid) {
        exit_if_not_equal(acc_up[j], static_cast<T *>(0x0) + (j - SGid),
                          "shuffle_up");
      }

      /* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
      exit_if_not_equal(acc_xor[j],
                        static_cast<T *>(0x0) +
                            (SGBeginGid + (SGLid ^ (SGid % sg_size))),
                        "shuffle_xor");
      SGLid++;
    }
  } catch (const exception &e) {
    // Catch by const reference: no copy of the exception object is made.
    std::cout << "SYCL exception caught: " << e.what() << '\n';
    exit(1);
  }
}

/// Exercises the sub_group shuffle, shuffle_up, shuffle_down and shuffle_xor
/// members with values of type T and checks the results on the host.
///
/// G input values are produced by calling Gen repeatedly (std::generate). A
/// 1-D kernel over nd_range(G, L) reads each work-item's value, shuffles it
/// four ways and stores the results; work-item 0 additionally records
/// SG.get_max_local_range()[0]. The host then compares every entry with the
/// expected element of the generated input through exit_if_not_equal
/// (declared in helper.hpp; presumably it terminates the test on a mismatch
/// -- confirm there).
///
/// \tparam SpecializationKernelName  Kernel name type for parallel_for.
/// \tparam T                         Type of the shuffled values.
/// \tparam Generator                 Callable returning the next input value.
/// \param Queue  Queue the kernel is submitted to.
/// \param Gen    Generator invoked G times to fill the input.
/// \param G      Global range (number of work-items).
/// \param L      Local range (work-group size).
///
/// A SYCL exception is reported on std::cout and the process exits with
/// status 1.
template <typename SpecializationKernelName, typename T, typename Generator>
void check_struct(queue &Queue, Generator &Gen, size_t G = 256, size_t L = 64) {

  // Fill a vector with values that will be shuffled
  std::vector<T> values(G);
  std::generate(values.begin(), values.end(), Gen);

  try {
    nd_range<1> NdRange(G, L);
    buffer<T> buf(G);
    buffer<T> buf_up(G);
    buffer<T> buf_down(G);
    buffer<T> buf_xor(G);
    buffer<size_t> sgsizebuf(1);
    buffer<T> buf_in(values.data(), values.size());
    Queue.submit([&](handler &cgh) {
      auto acc = buf.template get_access<access::mode::read_write>(cgh);
      auto acc_up = buf_up.template get_access<access::mode::read_write>(cgh);
      auto acc_down =
          buf_down.template get_access<access::mode::read_write>(cgh);
      auto acc_xor = buf_xor.template get_access<access::mode::read_write>(cgh);
      auto sgsizeacc = sgsizebuf.get_access<access::mode::read_write>(cgh);
      auto in = buf_in.template get_access<access::mode::read>(cgh);

      cgh.parallel_for<SpecializationKernelName>(
          NdRange, [=](nd_item<1> NdItem) {
            sycl::sub_group SG = NdItem.get_sub_group();
            uint32_t wggid = NdItem.get_global_id(0);
            uint32_t sgid = SG.get_group_id().get(0);
            // Only one work-item needs to report the sub-group size.
            if (wggid == 0)
              sgsizeacc[0] = SG.get_max_local_range()[0];

            T val = in[wggid];

            /* Value of the middle element in every subgroup */
            acc[NdItem.get_global_id()] =
                SG.shuffle(val, SG.get_max_local_range()[0] / 2);

            /* Save value at GID-SGID */
            acc_up[NdItem.get_global_id()] = SG.shuffle_up(val, sgid);

            /* Save value at GID+SGID */
            acc_down[NdItem.get_global_id()] = SG.shuffle_down(val, sgid);

            /* Save value with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
            acc_xor[NdItem.get_global_id()] =
                SG.shuffle_xor(val, sgid % SG.get_max_local_range()[0]);
          });
    });
    host_accessor acc(buf);
    host_accessor acc_up(buf_up);
    host_accessor acc_down(buf_down);
    host_accessor acc_xor(buf_xor);
    host_accessor sgsizeacc(sgsizebuf);

    const size_t sg_size = sgsizeacc[0];
    // All counters are unsigned so that they compare and combine with the
    // size_t quantities G, L and sg_size without sign conversions. They never
    // go negative: the only subtraction below is guarded by its condition.
    size_t SGid = 0;       // sub-group index inside the current work-group
    size_t SGLid = 0;      // local index inside the current sub-group
    size_t SGBeginGid = 0; // global id of the first item of the sub-group
    for (size_t j = 0; j < G; j++) {
      // A new sub-group starts here ...
      if (j % L % sg_size == 0) {
        SGid++;
        SGLid = 0;
        SGBeginGid = j;
      }
      // ... and a new work-group restarts sub-group numbering at zero.
      if (j % L == 0) {
        SGid = 0;
        SGLid = 0;
        SGBeginGid = j;
      }

      /* Value of the middle element in every subgroup */
      exit_if_not_equal(
          acc[j], values[j / L * L + SGid * sg_size + sg_size / 2], "shuffle");

      /* Value GID+SGID for all element except last SGID in SG */
      if (j % L % sg_size + SGid < sg_size && j % L + SGid < L) {
        exit_if_not_equal(acc_down[j], values[j + SGid], "shuffle_down");
      }

      /* Value GID-SGID for all element except first SGID in SG */
      if (j % L % sg_size >= SGid) {
        exit_if_not_equal(acc_up[j], values[j - SGid], "shuffle_up");
      }

      /* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
      exit_if_not_equal(acc_xor[j],
                        values[SGBeginGid + (SGLid ^ (SGid % sg_size))],
                        "shuffle_xor");
      SGLid++;
    }
  } catch (const exception &e) {
    // Catch by const reference: no copy of the exception object is made.
    std::cout << "SYCL exception caught: " << e.what() << '\n';
    exit(1);
  }
}

// Test driver: runs the pointer shuffle check, then the std::complex<float>
// check, and the std::complex<double> check only when the device reports the
// fp64 aspect.
int main() {
  queue Queue;

  // Test shuffle of pointer types
  check_pointer<class KernelName_mNiN, int>(Queue);

  // Test shuffle of non-native types
  auto ComplexFloatGenerator = [state = std::complex<float>(0, 1)]() mutable {
    return state += std::complex<float>(2, 2);
  };
  check_struct<class KernelName_zHfIPOLOFsXiZiCvG, std::complex<float>>(
      Queue, ComplexFloatGenerator);

  if (Queue.get_device().has(sycl::aspect::fp64)) {
    auto ComplexDoubleGenerator = [state =
                                       std::complex<double>(0, 1)]() mutable {
      return state += std::complex<double>(2, 2);
    };
    check_struct<class KernelName_CjlHUmnuxWtyejZFD, std::complex<double>>(
        Queue, ComplexDoubleGenerator);
  } else {
    // Terminate the line so the final status message below is not glued onto
    // this one.
    std::cout << "fp64 tests were skipped due to the device not supporting the "
                 "aspect.\n";
  }

  std::cout << "Test passed." << std::endl;
  return 0;
}
54 changes: 54 additions & 0 deletions sycl/test-e2e/SubGroup/shuffle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

//==------------ shuffle.cpp - SYCL sub_group shuffle test -----*- C++ -*---==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "shuffle.hpp"
#include <iostream>

int main() {
queue Queue;
check<short>(Queue);
check<unsigned short>(Queue);
check<int>(Queue);
check<int, 2>(Queue);
check<int, 4>(Queue);
check<int, 8>(Queue);
check<int, 16>(Queue);
check<unsigned int>(Queue);
check<unsigned int, 2>(Queue);
check<unsigned int, 4>(Queue);
check<unsigned int, 8>(Queue);
check<unsigned int, 16>(Queue);
check<long>(Queue);
check<long, 2>(Queue);
check<long, 4>(Queue);
check<long, 8>(Queue);
check<long, 16>(Queue);
check<unsigned long>(Queue);
check<unsigned long, 2>(Queue);
check<unsigned long, 4>(Queue);
check<unsigned long, 8>(Queue);
check<unsigned long, 16>(Queue);
check<float>(Queue);
check<float, 2>(Queue);
check<float, 4>(Queue);
check<float, 8>(Queue);
check<float, 16>(Queue);

// Check long long and unsigned long long because they differ from
// long and unsigned long according to C++ rules even if they have the same
// size at some system.
check<long long>(Queue);
check<long long, 16>(Queue);
check<unsigned long long>(Queue);
check<unsigned long long, 16>(Queue);
std::cout << "Test passed." << std::endl;
return 0;
}
26 changes: 26 additions & 0 deletions sycl/test-e2e/SubGroup/shuffle_fp16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// REQUIRES: aspect-fp16
// REQUIRES: gpu

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "shuffle.hpp"
#include <iostream>

int main() {
queue Queue;
check<half>(Queue);
check<half, 2>(Queue);
check<half, 4>(Queue);
check<half, 8>(Queue);
check<half, 16>(Queue);
std::cout << "Test passed." << std::endl;
return 0;
}
Loading