Skip to content

[SYCL] Implement sub_group_mask version 2 #11195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 2, 2023
78 changes: 57 additions & 21 deletions sycl/include/sycl/ext/oneapi/sub_group_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
#pragma once

#include <sycl/detail/helpers.hpp> // for Builder
#include <sycl/detail/type_traits.hpp> // for is_sub_group
#include <sycl/exception.hpp> // for errc, exception
#include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
#include <sycl/id.hpp> // for id
#include <sycl/marray.hpp> // for marray

#include <assert.h> // for assert
#include <climits> // for CHAR_BIT
#include <cstring> // for memcpy
#include <stddef.h> // for size_t
#include <stdint.h> // for uint32_t
#include <system_error> // for error_code
Expand All @@ -33,25 +34,26 @@ template <typename Group> struct group_scope;

} // namespace detail

// forward decalre sycl::sub_group
struct sub_group;

namespace ext::oneapi {

#if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
(__AMDGCN_WAVEFRONT_SIZE == 64)
#define BITS_TYPE uint64_t
#else
#define BITS_TYPE uint32_t
#endif
// forward decalre sycl::ext::oneapi::sub_group
struct sub_group;

// defining `group_ballot` here to make predicate default `true`
// need to forward declare sub_group_mask first
struct sub_group_mask;
template <typename Group>
std::enable_if_t<sycl::detail::is_sub_group<Group>::value, sub_group_mask>
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
sub_group_mask>
group_ballot(Group g, bool predicate = true);

struct sub_group_mask {
friend class sycl::detail::Builder;
using BitsType = BITS_TYPE;
using BitsType = uint64_t;

static constexpr size_t max_bits =
sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
Expand Down Expand Up @@ -79,7 +81,8 @@ struct sub_group_mask {
}

reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
BitsType one = 1;
RefBit = (pos < gmask.bits_num) ? (one << pos) : 0;
}

private:
Expand All @@ -89,8 +92,37 @@ struct sub_group_mask {
BitsType RefBit;
};

#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
sub_group_mask() : sub_group_mask(0, GetMaxLocalRangeSize()){};

sub_group_mask(unsigned long long val)
: sub_group_mask(0, GetMaxLocalRangeSize()) {
Bits = val;
};

template <typename T, std::size_t K,
typename = std::enable_if_t<std::is_integral_v<T>>>
sub_group_mask(const sycl::marray<T, K> &val)
: sub_group_mask(0, GetMaxLocalRangeSize()) {
for (size_t I = 0, BytesCopied = 0; I < K && BytesCopied < sizeof(Bits);
++I) {
size_t RemainingBytes = sizeof(Bits) - BytesCopied;
size_t BytesToCopy =
RemainingBytes < sizeof(T) ? RemainingBytes : sizeof(T);
// TODO: memcpy is not guaranteed to work in kernels. Find alternative.
std::memcpy(reinterpret_cast<char *>(&Bits) + BytesCopied, &val[I],
BytesToCopy);
BytesCopied += BytesToCopy;
}
}

sub_group_mask(const sub_group_mask &other) = default;
sub_group_mask &operator=(const sub_group_mask &other) = default;
#endif // SYCL_EXT_ONEAPI_SUB_GROUP_MASK

bool operator[](id<1> id) const {
return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0));
BitsType one = 1;
return (Bits & ((id.get(0) < bits_num) ? (one << id.get(0)) : 0));
}

reference operator[](id<1> id) { return {*this, id.get(0)}; }
Expand Down Expand Up @@ -252,10 +284,6 @@ struct sub_group_mask {
return Tmp;
}

sub_group_mask(const sub_group_mask &rhs) = default;

sub_group_mask &operator=(const sub_group_mask &rhs) = default;

template <typename Group>
friend std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group>,
sub_group_mask>
Expand Down Expand Up @@ -283,6 +311,14 @@ struct sub_group_mask {
}

private:
static size_t GetMaxLocalRangeSize() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_SubgroupMaxSize();
#else
return max_bits;
#endif
}

sub_group_mask(BitsType rhs, size_t bn)
: Bits(rhs & valuable_bits(bn)), bits_num(bn) {
assert(bits_num <= max_bits);
Expand All @@ -300,15 +336,17 @@ struct sub_group_mask {
};

template <typename Group>
std::enable_if_t<sycl::detail::is_sub_group<Group>::value, sub_group_mask>
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
sub_group_mask>
group_ballot(Group g, bool predicate) {
(void)g;
#ifdef __SYCL_DEVICE_ONLY__
auto res = __spirv_GroupNonUniformBallot(
sycl::detail::spirv::group_scope<Group>::value, predicate);
BITS_TYPE val = res[0];
if constexpr (sizeof(BITS_TYPE) == 8)
val |= ((BITS_TYPE)res[1]) << 32;
sub_group_mask::BitsType val = res[0];
if constexpr (sizeof(sub_group_mask::BitsType) == 8)
val |= ((sub_group_mask::BitsType)res[1]) << 32;
return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
val, g.get_max_local_range()[0]);
#else
Expand All @@ -318,8 +356,6 @@ group_ballot(Group g, bool predicate) {
#endif
}

#undef BITS_TYPE

} // namespace ext::oneapi
} // namespace _V1
} // namespace sycl
1 change: 1 addition & 0 deletions sycl/include/syclcompat/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <utility>

#include <sycl/builtins.hpp>
#include <sycl/ext/intel/experimental/usm_properties.hpp>
#include <sycl/ext/oneapi/group_local_memory.hpp>
#include <sycl/usm.hpp>

Expand Down
2 changes: 1 addition & 1 deletion sycl/source/feature_test.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ inline namespace _V1 {
// TODO: Move these feature-test macros to compiler driver.
#define SYCL_EXT_INTEL_DEVICE_INFO 6
#define SYCL_EXT_ONEAPI_DEVICE_ARCHITECTURE 1
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 2
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
#define SYCL_EXT_ONEAPI_MATRIX 1
#define SYCL_EXT_ONEAPI_ASSERT 1
Expand Down
119 changes: 119 additions & 0 deletions sycl/test-e2e/SubGroupMask/sub_group_mask_ver2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <iostream>
#include <sycl/sycl.hpp>

#define TEST_ON_DEVICE(TEST_BODY) \
{ \
sycl::queue queue; \
sycl::buffer<bool, 1> buf(&res, 1); \
queue.submit([&](sycl::handler &h) { \
auto acc = buf.get_access<sycl::access_mode::write>(h); \
h.parallel_for(sycl::nd_range<1>(sycl::range<1>(1), sycl::range<1>(1)), \
[=](sycl::nd_item<1> item) { TEST_BODY }); \
}); \
queue.wait(); \
} \
assert(res);

int main() {
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
using mask_type = sycl::ext::oneapi::sub_group_mask;

// sub_group_mask()
{
mask_type mask;
assert(mask.none() && mask_type::max_bits == mask.size());

bool res = false;
// clang-format off
TEST_ON_DEVICE(
mask_type mask;
auto sg = item.get_sub_group();
acc[0] = mask.none() && (sg.get_max_local_range().size() == mask.size());
)
// clang-format on
}
// sub_group_mask(unsigned long long val)
{
unsigned long long val = 4815162342;
mask_type mask(val);
std::bitset<sizeof(val) * CHAR_BIT> bs(val);
bool res = true;
for (size_t i = 0;
i < std::min(static_cast<size_t>(mask.size()), bs.size()); ++i)
res &= mask[i] == bs[i];
assert(res);

// clang-format off
TEST_ON_DEVICE(
mask_type mask(val);
auto sg = item.get_sub_group();
acc[0] = sg.get_max_local_range().size() == mask.size();
for (size_t i = 0;
i < sycl::min(static_cast<size_t>(mask.size()), bs.size());
++i)
acc[0] &= mask[i] == bs[i];
)
// clang-format on
}
// template <typename T, std::size_t K> sub_group_mask(const sycl::marray<T,
// K>& &val)
{
sycl::marray<char, 4> marr{1, 2, 3, 4};
mask_type mask(marr);
std::bitset<CHAR_BIT> bs[4] = {1, 2, 3, 4};
bool res = true;
for (size_t i = 0; i < mask.size() && (i / CHAR_BIT) < 4; ++i)
res &= mask[i] == bs[i / CHAR_BIT][i % CHAR_BIT];
assert(res);

// clang-format off
TEST_ON_DEVICE(
mask_type mask(marr);
auto sg = item.get_sub_group();
acc[0] = sg.get_max_local_range().size() == mask.size();
for (size_t i = 0; i < mask.size() && (i / CHAR_BIT) < 4; ++i)
acc[0] &= mask[i] == bs[i / CHAR_BIT][i % CHAR_BIT];
)
// clang-format on
}
{
// sub_group_mask(const sub_group_mask &other)
unsigned long long val = 4815162342;
mask_type mask1(val);
mask_type mask2(mask1);
assert(mask1 == mask2);

bool res = false;
// clang-format off
TEST_ON_DEVICE(
mask_type mask1(val);
mask_type mask2(mask1);
acc[0] = mask1 == mask2;
)
// clang-format on
}
{
// sub_group_mask& operator=(const sub_group_mask &other)
unsigned long long val = 4815162342;
mask_type mask1(val);
mask_type mask2 = mask1;
assert(mask1 == mask2);

bool res = false;
// clang-format off
TEST_ON_DEVICE(
mask_type mask1(val);
mask_type mask2 = mask1;
acc[0] = mask1 == mask2;
)
// clang-format on
}
#else
std::cout << "Test skipped due to unsupported extension." << std::endl;
#endif

return 0;
}
2 changes: 1 addition & 1 deletion sycl/test/extensions/macro.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ constexpr bool backend_opencl_macro_defined = true;
constexpr bool backend_opencl_macro_defined = false;
#endif

#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK == 1
#ifdef SYCL_EXT_ONEAPI_SUB_GROUP_MASK
constexpr bool sub_group_mask_macro_defined = true;
#else
constexpr bool sub_group_mask_macro_defined = false;
Expand Down
3 changes: 1 addition & 2 deletions sycl/test/extensions/sub_group_mask.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// RUN: %clangxx -fsycl -fsycl-device-only -fsyntax-only %s
//
// This test is intended to check sycl::ext::oneapi::sub_group_mask interface.
// There is a work in progress update to the spec: intel/llvm#8174
// TODO: udpate this test once revision 2 of the extension is supported
// test for spec ver.2: sycl/test-e2e/SubGroupMask/sub_group_mask_ver2.cpp

#include <sycl/sycl.hpp>

Expand Down