Skip to content

Commit becc391

Browse files
[SYCL] Fix variadic marray ctor (#8271)
Currently the variadic marray constructor only accepts values that are directly convertible to the element type of the marray. However, according to the SYCL 2020 specification the arguments of the constructor can be any combination of such values and marrays with suitable types, where the aggregate size of the arguments must match the number of arguments of the specified marray. This patch allows the use of marray arguments in this constructor. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent a7cc399 commit becc391

File tree

2 files changed

+154
-17
lines changed

2 files changed

+154
-17
lines changed

sycl/include/sycl/marray.hpp

Lines changed: 103 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,84 @@
1414
#include <sycl/detail/type_traits.hpp>
1515
#include <sycl/half_type.hpp>
1616

17+
#include <array>
18+
#include <type_traits>
19+
#include <utility>
20+
1721
namespace sycl {
1822
__SYCL_INLINE_VER_NAMESPACE(_V1) {
1923

24+
template <typename DataT, std::size_t N> class marray;
25+
26+
namespace detail {
27+
28+
// Helper trait for counting the aggregate number of arguments in a type list,
29+
// expanding marrays.
30+
template <typename... Ts> struct GetMArrayArgsSize;
31+
template <> struct GetMArrayArgsSize<> {
32+
static constexpr std::size_t value = 0;
33+
};
34+
template <typename T, std::size_t N, typename... Ts>
35+
struct GetMArrayArgsSize<marray<T, N>, Ts...> {
36+
static constexpr std::size_t value = N + GetMArrayArgsSize<Ts...>::value;
37+
};
38+
template <typename T, typename... Ts> struct GetMArrayArgsSize<T, Ts...> {
39+
static constexpr std::size_t value = 1 + GetMArrayArgsSize<Ts...>::value;
40+
};
41+
42+
// Helper function for concatenating two std::array.
43+
template <typename T, std::size_t... Is1, std::size_t... Is2>
44+
constexpr std::array<T, sizeof...(Is1) + sizeof...(Is2)>
45+
ConcatArrays(const std::array<T, sizeof...(Is1)> &A1,
46+
const std::array<T, sizeof...(Is2)> &A2,
47+
std::index_sequence<Is1...>, std::index_sequence<Is2...>) {
48+
return {A1[Is1]..., A2[Is2]...};
49+
}
50+
template <typename T, std::size_t N1, std::size_t N2>
51+
constexpr std::array<T, N1 + N2> ConcatArrays(const std::array<T, N1> &A1,
52+
const std::array<T, N2> &A2) {
53+
return ConcatArrays(A1, A2, std::make_index_sequence<N1>(),
54+
std::make_index_sequence<N2>());
55+
}
56+
57+
// Utility trait for creating an std::array from an marray.
58+
template <typename DataT, typename T, std::size_t... Is>
59+
constexpr std::array<T, sizeof...(Is)>
60+
MArrayToArray(const marray<T, sizeof...(Is)> &A, std::index_sequence<Is...>) {
61+
return {static_cast<DataT>(A.MData[Is])...};
62+
}
63+
template <typename DataT, typename T, std::size_t N>
64+
constexpr std::array<T, N> MArrayToArray(const marray<T, N> &A) {
65+
return MArrayToArray<DataT>(A, std::make_index_sequence<N>());
66+
}
67+
68+
// Utility for creating an std::array from a arguments of either types
69+
// convertible to DataT or marrays of a type convertible to DataT.
70+
template <typename DataT, typename... ArgTN> struct ArrayCreator;
71+
template <typename DataT, typename ArgT, typename... ArgTN>
72+
struct ArrayCreator<DataT, ArgT, ArgTN...> {
73+
static constexpr std::array<DataT, GetMArrayArgsSize<ArgT, ArgTN...>::value>
74+
Create(const ArgT &Arg, const ArgTN &...Args) {
75+
return ConcatArrays(std::array<DataT, 1>{static_cast<DataT>(Arg)},
76+
ArrayCreator<DataT, ArgTN...>::Create(Args...));
77+
}
78+
};
79+
template <typename DataT, typename T, std::size_t N, typename... ArgTN>
80+
struct ArrayCreator<DataT, marray<T, N>, ArgTN...> {
81+
static constexpr std::array<DataT,
82+
GetMArrayArgsSize<marray<T, N>, ArgTN...>::value>
83+
Create(const marray<T, N> &Arg, const ArgTN &...Args) {
84+
return ConcatArrays(MArrayToArray<DataT>(Arg),
85+
ArrayCreator<DataT, ArgTN...>::Create(Args...));
86+
}
87+
};
88+
template <typename DataT> struct ArrayCreator<DataT> {
89+
static constexpr std::array<DataT, 0> Create() {
90+
return std::array<DataT, 0>{};
91+
}
92+
};
93+
} // namespace detail
94+
2095
/// Provides a cross-platform math array class template that works on
2196
/// SYCL devices as well as in host C++ code.
2297
///
@@ -34,37 +109,50 @@ template <typename Type, std::size_t NumElements> class marray {
34109
private:
35110
value_type MData[NumElements];
36111

37-
template <class...> struct conjunction : std::true_type {};
38-
template <class B1, class... tail>
39-
struct conjunction<B1, tail...>
40-
: std::conditional<bool(B1::value), conjunction<tail...>, B1>::type {};
41-
42-
// TypeChecker is needed for (const ArgTN &... Args) ctor to validate Args.
43-
template <typename T, typename DataT_>
44-
struct TypeChecker : std::is_convertible<T, DataT_> {};
112+
// Trait for checking if an argument type is either convertible to the data
113+
// type or an array of types convertible to the data type.
114+
template <typename T>
115+
struct IsSuitableArgType : std::is_convertible<T, DataT> {};
116+
template <typename T, size_t N>
117+
struct IsSuitableArgType<marray<T, N>> : std::is_convertible<T, DataT> {};
45118

46-
// Shortcuts for Args validation in (const ArgTN &... Args) ctor.
119+
// Trait for computing the conjunction of of IsSuitableArgType. The empty type
120+
// list will trivially evaluate to true.
47121
template <typename... ArgTN>
48-
using EnableIfSuitableTypes = typename std::enable_if<
49-
conjunction<TypeChecker<ArgTN, DataT>...>::value>::type;
122+
struct AllSuitableArgTypes : std::conjunction<IsSuitableArgType<ArgTN>...> {};
123+
124+
// FIXME: MArrayToArray needs to be a friend to access MData. If the subscript
125+
// operator is made constexpr this can be removed.
126+
template <typename, typename T, std::size_t... Is>
127+
friend constexpr std::array<T, sizeof...(Is)>
128+
detail::MArrayToArray(const marray<T, sizeof...(Is)> &,
129+
std::index_sequence<Is...>);
50130

51131
constexpr void initialize_data(const Type &Arg) {
52132
for (size_t i = 0; i < NumElements; ++i) {
53133
MData[i] = Arg;
54134
}
55135
}
56136

137+
template <size_t... Is>
138+
constexpr marray(const std::array<DataT, NumElements> &Arr,
139+
std::index_sequence<Is...>)
140+
: MData{Arr[Is]...} {}
141+
57142
public:
58143
constexpr marray() : MData{} {}
59144

60145
explicit constexpr marray(const Type &Arg) : MData{Arg} {
61146
initialize_data(Arg);
62147
}
63148

64-
template <
65-
typename... ArgTN, typename = EnableIfSuitableTypes<ArgTN...>,
66-
typename = typename std::enable_if<sizeof...(ArgTN) == NumElements>::type>
67-
constexpr marray(const ArgTN &...Args) : MData{static_cast<Type>(Args)...} {}
149+
template <typename... ArgTN,
150+
typename = std::enable_if_t<
151+
AllSuitableArgTypes<ArgTN...>::value &&
152+
detail::GetMArrayArgsSize<ArgTN...>::value == NumElements>>
153+
constexpr marray(const ArgTN &...Args)
154+
: marray{detail::ArrayCreator<DataT, ArgTN...>::Create(Args...),
155+
std::make_index_sequence<NumElements>()} {}
68156

69157
constexpr marray(const marray<Type, NumElements> &Rhs) = default;
70158

sycl/test/basic_tests/marray/marray.cpp

100755100644
Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,42 @@
1212
#include <sycl/sycl.hpp>
1313
using namespace sycl;
1414

15+
struct NotDefaultConstructible {
16+
NotDefaultConstructible() = delete;
17+
constexpr NotDefaultConstructible(int){};
18+
};
19+
20+
template <typename DataT> void CheckConstexprVariadicCtors() {
21+
constexpr DataT default_val{1};
22+
23+
constexpr sycl::marray<DataT, 5> marray_with_5_elements(
24+
default_val, default_val, default_val, default_val, default_val);
25+
constexpr sycl::marray<DataT, 3> marray_with_3_elements(
26+
default_val, default_val, default_val);
27+
28+
constexpr sycl::marray<DataT, 6> m1(marray_with_5_elements, default_val);
29+
constexpr sycl::marray<DataT, 6> m2(default_val, marray_with_5_elements);
30+
constexpr sycl::marray<DataT, 7> m3(default_val, marray_with_5_elements,
31+
default_val);
32+
constexpr sycl::marray<DataT, 8> m4(marray_with_5_elements,
33+
marray_with_3_elements);
34+
constexpr sycl::marray<DataT, 9> m5(default_val, marray_with_5_elements,
35+
marray_with_3_elements);
36+
constexpr sycl::marray<DataT, 9> m6(marray_with_5_elements, default_val,
37+
marray_with_3_elements);
38+
constexpr sycl::marray<DataT, 9> m7(marray_with_5_elements,
39+
marray_with_3_elements, default_val);
40+
constexpr sycl::marray<DataT, 10> m8(default_val, marray_with_5_elements,
41+
default_val, marray_with_3_elements);
42+
constexpr sycl::marray<DataT, 10> m9(default_val, marray_with_5_elements,
43+
marray_with_3_elements, default_val);
44+
constexpr sycl::marray<DataT, 10> m10(marray_with_5_elements, default_val,
45+
marray_with_3_elements, default_val);
46+
constexpr sycl::marray<DataT, 11> m11(default_val, marray_with_5_elements,
47+
default_val, marray_with_3_elements,
48+
default_val);
49+
}
50+
1551
int main() {
1652
// Constructing vector from a scalar
1753
sycl::marray<int, 1> marray_from_one_elem(1);
@@ -99,6 +135,21 @@ int main() {
99135
constexpr sycl::marray<double, 5> mb(ma);
100136
constexpr sycl::marray<double, 5> mc = ma;
101137

138+
// check variadic ctor
139+
CheckConstexprVariadicCtors<bool>();
140+
CheckConstexprVariadicCtors<std::int8_t>();
141+
CheckConstexprVariadicCtors<std::uint8_t>();
142+
CheckConstexprVariadicCtors<std::int16_t>();
143+
CheckConstexprVariadicCtors<std::uint16_t>();
144+
CheckConstexprVariadicCtors<std::int32_t>();
145+
CheckConstexprVariadicCtors<std::uint32_t>();
146+
CheckConstexprVariadicCtors<std::int64_t>();
147+
CheckConstexprVariadicCtors<std::uint64_t>();
148+
CheckConstexprVariadicCtors<sycl::half>();
149+
CheckConstexprVariadicCtors<float>();
150+
CheckConstexprVariadicCtors<double>();
151+
CheckConstexprVariadicCtors<NotDefaultConstructible>();
152+
102153
// check trivially copyability
103154
struct Copyable {
104155
int a;
@@ -119,6 +170,4 @@ int main() {
119170
"sycl::marray<std::string, 5> is device copyable type");
120171

121172
return 0;
122-
123-
return 0;
124173
}

0 commit comments

Comments
 (0)