Skip to content

[SYCL] Reintroduce sub_group_mask version 2 #12404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,7 @@ get_tangle_group(Group group) {
// TODO: Construct from compiler-generated mask. Return an invalid group in
// in the meantime. CUDA devices will report false for the tangle_group
// support aspect so kernels launch should ensure this is never run.
return tangle_group<sycl::sub_group>(
sycl::detail::Builder::createSubGroupMask<
sycl::ext::oneapi::sub_group_mask>(0, 0));
return tangle_group<sycl::sub_group>(0);
#endif
#else
throw runtime_error("Non-uniform groups are not supported on host device.",
Expand Down
76 changes: 55 additions & 21 deletions sycl/include/sycl/ext/oneapi/sub_group_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

#include <sycl/detail/helpers.hpp> // for Builder
#include <sycl/detail/memcpy.hpp> // detail::memcpy
#include <sycl/detail/type_traits.hpp> // for is_sub_group
#include <sycl/exception.hpp> // for errc, exception
#include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
#include <sycl/id.hpp> // for id
#include <sycl/marray.hpp> // for marray
#include <sycl/types.hpp> // for vec
Expand All @@ -35,25 +35,26 @@ template <typename Group> struct group_scope;

} // namespace detail

// forward decalre sycl::sub_group
struct sub_group;

namespace ext::oneapi {

#if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
(__AMDGCN_WAVEFRONT_SIZE == 64)
#define BITS_TYPE uint64_t
#else
#define BITS_TYPE uint32_t
#endif
// forward decalre sycl::ext::oneapi::sub_group
struct sub_group;

// defining `group_ballot` here to make predicate default `true`
// need to forward declare sub_group_mask first
struct sub_group_mask;
template <typename Group>
std::enable_if_t<sycl::detail::is_sub_group<Group>::value, sub_group_mask>
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
sub_group_mask>
group_ballot(Group g, bool predicate = true);

struct sub_group_mask {
friend class sycl::detail::Builder;
using BitsType = BITS_TYPE;
using BitsType = uint64_t;

static constexpr size_t max_bits =
sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
Expand Down Expand Up @@ -81,7 +82,8 @@ struct sub_group_mask {
}

reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
BitsType one = 1;
RefBit = (pos < gmask.bits_num) ? (one << pos) : 0;
}

private:
Expand All @@ -91,8 +93,36 @@ struct sub_group_mask {
BitsType RefBit;
};

#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
sub_group_mask() : sub_group_mask(0, GetMaxLocalRangeSize()){};

sub_group_mask(unsigned long long val)
: sub_group_mask(0, GetMaxLocalRangeSize()) {
Bits = val;
};

template <typename T, std::size_t K,
typename = std::enable_if_t<std::is_integral_v<T>>>
sub_group_mask(const sycl::marray<T, K> &val)
: sub_group_mask(0, GetMaxLocalRangeSize()) {
for (size_t I = 0, BytesCopied = 0; I < K && BytesCopied < sizeof(Bits);
++I) {
size_t RemainingBytes = sizeof(Bits) - BytesCopied;
size_t BytesToCopy =
RemainingBytes < sizeof(T) ? RemainingBytes : sizeof(T);
sycl::detail::memcpy(reinterpret_cast<char *>(&Bits) + BytesCopied,
&val[I], BytesToCopy);
BytesCopied += BytesToCopy;
}
}

sub_group_mask(const sub_group_mask &other) = default;
sub_group_mask &operator=(const sub_group_mask &other) = default;
#endif // SYCL_EXT_ONEAPI_SUB_GROUP_MASK

bool operator[](id<1> id) const {
return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0));
BitsType one = 1;
return (Bits & ((id.get(0) < bits_num) ? (one << id.get(0)) : 0));
}

reference operator[](id<1> id) { return {*this, id.get(0)}; }
Expand Down Expand Up @@ -254,10 +284,6 @@ struct sub_group_mask {
return Tmp;
}

sub_group_mask(const sub_group_mask &rhs) = default;

sub_group_mask &operator=(const sub_group_mask &rhs) = default;

template <typename Group>
friend std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group>,
sub_group_mask>
Expand Down Expand Up @@ -285,6 +311,14 @@ struct sub_group_mask {
}

private:
static size_t GetMaxLocalRangeSize() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_SubgroupMaxSize();
#else
return max_bits;
#endif
}

sub_group_mask(BitsType rhs, size_t bn)
: Bits(rhs & valuable_bits(bn)), bits_num(bn) {
assert(bits_num <= max_bits);
Expand All @@ -302,15 +336,17 @@ struct sub_group_mask {
};

template <typename Group>
std::enable_if_t<sycl::detail::is_sub_group<Group>::value, sub_group_mask>
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
sub_group_mask>
group_ballot(Group g, bool predicate) {
(void)g;
#ifdef __SYCL_DEVICE_ONLY__
auto res = __spirv_GroupNonUniformBallot(
sycl::detail::spirv::group_scope<Group>::value, predicate);
BITS_TYPE val = res[0];
if constexpr (sizeof(BITS_TYPE) == 8)
val |= ((BITS_TYPE)res[1]) << 32;
sub_group_mask::BitsType val = res[0];
if constexpr (sizeof(sub_group_mask::BitsType) == 8)
val |= ((sub_group_mask::BitsType)res[1]) << 32;
return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
val, g.get_max_local_range()[0]);
#else
Expand All @@ -320,8 +356,6 @@ group_ballot(Group g, bool predicate) {
#endif
}

#undef BITS_TYPE

} // namespace ext::oneapi
} // namespace _V1
} // namespace sycl
1 change: 1 addition & 0 deletions sycl/include/syclcompat/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <utility>

#include <sycl/builtins.hpp>
#include <sycl/ext/intel/experimental/usm_properties.hpp>
#include <sycl/ext/oneapi/group_local_memory.hpp>
#include <sycl/usm.hpp>

Expand Down
2 changes: 1 addition & 1 deletion sycl/source/feature_test.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ inline namespace _V1 {
// TODO: Move these feature-test macros to compiler driver.
#define SYCL_EXT_INTEL_DEVICE_INFO 6
#define SYCL_EXT_ONEAPI_DEVICE_ARCHITECTURE 1
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 2
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
#define SYCL_EXT_ONEAPI_MATRIX 1
#define SYCL_EXT_ONEAPI_ASSERT 1
Expand Down
119 changes: 119 additions & 0 deletions sycl/test-e2e/SubGroupMask/sub_group_mask_ver2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <iostream>
#include <sycl/sycl.hpp>

#define TEST_ON_DEVICE(TEST_BODY) \
{ \
sycl::queue queue; \
sycl::buffer<bool, 1> buf(&res, 1); \
queue.submit([&](sycl::handler &h) { \
auto acc = buf.get_access<sycl::access_mode::write>(h); \
h.parallel_for(sycl::nd_range<1>(sycl::range<1>(1), sycl::range<1>(1)), \
[=](sycl::nd_item<1> item) { TEST_BODY }); \
}); \
queue.wait(); \
} \
assert(res);

int main() {
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
using mask_type = sycl::ext::oneapi::sub_group_mask;

// sub_group_mask()
{
mask_type mask;
assert(mask.none() && mask_type::max_bits == mask.size());

bool res = false;
// clang-format off
TEST_ON_DEVICE(
mask_type mask;
auto sg = item.get_sub_group();
acc[0] = mask.none() && (sg.get_max_local_range().size() == mask.size());
)
// clang-format on
}
// sub_group_mask(unsigned long long val)
{
unsigned long long val = 4815162342;
mask_type mask(val);
std::bitset<sizeof(val) * CHAR_BIT> bs(val);
bool res = true;
for (size_t i = 0;
i < std::min(static_cast<size_t>(mask.size()), bs.size()); ++i)
res &= mask[i] == bs[i];
assert(res);

// clang-format off
TEST_ON_DEVICE(
mask_type mask(val);
auto sg = item.get_sub_group();
acc[0] = sg.get_max_local_range().size() == mask.size();
for (size_t i = 0;
i < sycl::min(static_cast<size_t>(mask.size()), bs.size());
++i)
acc[0] &= mask[i] == bs[i];
)
// clang-format on
}
// template <typename T, std::size_t K> sub_group_mask(const sycl::marray<T,
// K>& &val)
{
sycl::marray<char, 4> marr{1, 2, 3, 4};
mask_type mask(marr);
std::bitset<CHAR_BIT> bs[4] = {1, 2, 3, 4};
bool res = true;
for (size_t i = 0; i < mask.size() && (i / CHAR_BIT) < 4; ++i)
res &= mask[i] == bs[i / CHAR_BIT][i % CHAR_BIT];
assert(res);

// clang-format off
TEST_ON_DEVICE(
mask_type mask(marr);
auto sg = item.get_sub_group();
acc[0] = sg.get_max_local_range().size() == mask.size();
for (size_t i = 0; i < mask.size() && (i / CHAR_BIT) < 4; ++i)
acc[0] &= mask[i] == bs[i / CHAR_BIT][i % CHAR_BIT];
)
// clang-format on
}
{
// sub_group_mask(const sub_group_mask &other)
unsigned long long val = 4815162342;
mask_type mask1(val);
mask_type mask2(mask1);
assert(mask1 == mask2);

bool res = false;
// clang-format off
TEST_ON_DEVICE(
mask_type mask1(val);
mask_type mask2(mask1);
acc[0] = mask1 == mask2;
)
// clang-format on
}
{
// sub_group_mask& operator=(const sub_group_mask &other)
unsigned long long val = 4815162342;
mask_type mask1(val);
mask_type mask2 = mask1;
assert(mask1 == mask2);

bool res = false;
// clang-format off
TEST_ON_DEVICE(
mask_type mask1(val);
mask_type mask2 = mask1;
acc[0] = mask1 == mask2;
)
// clang-format on
}
#else
std::cout << "Test skipped due to unsupported extension." << std::endl;
#endif

return 0;
}
2 changes: 1 addition & 1 deletion sycl/test/extensions/macro.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ constexpr bool backend_opencl_macro_defined = true;
constexpr bool backend_opencl_macro_defined = false;
#endif

#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK == 1
#ifdef SYCL_EXT_ONEAPI_SUB_GROUP_MASK
constexpr bool sub_group_mask_macro_defined = true;
#else
constexpr bool sub_group_mask_macro_defined = false;
Expand Down
3 changes: 1 addition & 2 deletions sycl/test/extensions/sub_group_mask.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// RUN: %clangxx -fsycl -fsycl-device-only -fsyntax-only %s
//
// This test is intended to check sycl::ext::oneapi::sub_group_mask interface.
// There is a work in progress update to the spec: intel/llvm#8174
// TODO: udpate this test once revision 2 of the extension is supported
// test for spec ver.2: sycl/test-e2e/SubGroupMask/sub_group_mask_ver2.cpp

#include <sycl/sycl.hpp>

Expand Down