14
14
#include < sycl/detail/type_traits.hpp>
15
15
#include < sycl/half_type.hpp>
16
16
17
+ #include < array>
18
+ #include < type_traits>
19
+ #include < utility>
20
+
17
21
namespace sycl {
18
22
__SYCL_INLINE_VER_NAMESPACE (_V1) {
19
23
24
+ template <typename DataT, std::size_t N> class marray ;
25
+
26
+ namespace detail {
27
+
28
+ // Helper trait for counting the aggregate number of arguments in a type list,
29
+ // expanding marrays.
30
+ template <typename ... Ts> struct GetMArrayArgsSize ;
31
+ template <> struct GetMArrayArgsSize <> {
32
+ static constexpr std::size_t value = 0 ;
33
+ };
34
+ template <typename T, std::size_t N, typename ... Ts>
35
+ struct GetMArrayArgsSize <marray<T, N>, Ts...> {
36
+ static constexpr std::size_t value = N + GetMArrayArgsSize<Ts...>::value;
37
+ };
38
+ template <typename T, typename ... Ts> struct GetMArrayArgsSize <T, Ts...> {
39
+ static constexpr std::size_t value = 1 + GetMArrayArgsSize<Ts...>::value;
40
+ };
41
+
42
+ // Helper function for concatenating two std::array.
43
+ template <typename T, std::size_t ... Is1, std::size_t ... Is2>
44
+ constexpr std::array<T, sizeof ...(Is1) + sizeof ...(Is2)>
45
+ ConcatArrays (const std::array<T, sizeof ...(Is1)> &A1,
46
+ const std::array<T, sizeof ...(Is2)> &A2,
47
+ std::index_sequence<Is1...>, std::index_sequence<Is2...>) {
48
+ return {A1[Is1]..., A2[Is2]...};
49
+ }
50
+ template <typename T, std::size_t N1, std::size_t N2>
51
+ constexpr std::array<T, N1 + N2> ConcatArrays (const std::array<T, N1> &A1,
52
+ const std::array<T, N2> &A2) {
53
+ return ConcatArrays (A1, A2, std::make_index_sequence<N1>(),
54
+ std::make_index_sequence<N2>());
55
+ }
56
+
57
+ // Utility trait for creating an std::array from an marray.
58
+ template <typename DataT, typename T, std::size_t ... Is>
59
+ constexpr std::array<T, sizeof ...(Is)>
60
+ MArrayToArray (const marray<T, sizeof ...(Is)> &A, std::index_sequence<Is...>) {
61
+ return {static_cast <DataT>(A.MData [Is])...};
62
+ }
63
+ template <typename DataT, typename T, std::size_t N>
64
+ constexpr std::array<T, N> MArrayToArray (const marray<T, N> &A) {
65
+ return MArrayToArray<DataT>(A, std::make_index_sequence<N>());
66
+ }
67
+
68
+ // Utility for creating an std::array from a arguments of either types
69
+ // convertible to DataT or marrays of a type convertible to DataT.
70
+ template <typename DataT, typename ... ArgTN> struct ArrayCreator ;
71
+ template <typename DataT, typename ArgT, typename ... ArgTN>
72
+ struct ArrayCreator <DataT, ArgT, ArgTN...> {
73
+ static constexpr std::array<DataT, GetMArrayArgsSize<ArgT, ArgTN...>::value>
74
+ Create (const ArgT &Arg, const ArgTN &...Args) {
75
+ return ConcatArrays (std::array<DataT, 1 >{static_cast <DataT>(Arg)},
76
+ ArrayCreator<DataT, ArgTN...>::Create (Args...));
77
+ }
78
+ };
79
+ template <typename DataT, typename T, std::size_t N, typename ... ArgTN>
80
+ struct ArrayCreator <DataT, marray<T, N>, ArgTN...> {
81
+ static constexpr std::array<DataT,
82
+ GetMArrayArgsSize<marray<T, N>, ArgTN...>::value>
83
+ Create (const marray<T, N> &Arg, const ArgTN &...Args) {
84
+ return ConcatArrays (MArrayToArray<DataT>(Arg),
85
+ ArrayCreator<DataT, ArgTN...>::Create (Args...));
86
+ }
87
+ };
88
+ template <typename DataT> struct ArrayCreator <DataT> {
89
+ static constexpr std::array<DataT, 0 > Create () {
90
+ return std::array<DataT, 0 >{};
91
+ }
92
+ };
93
+ } // namespace detail
94
+
20
95
// / Provides a cross-platform math array class template that works on
21
96
// / SYCL devices as well as in host C++ code.
22
97
// /
@@ -34,37 +109,50 @@ template <typename Type, std::size_t NumElements> class marray {
34
109
private:
35
110
value_type MData[NumElements];
36
111
37
- template <class ...> struct conjunction : std::true_type {};
38
- template <class B1 , class ... tail>
39
- struct conjunction <B1, tail...>
40
- : std::conditional<bool (B1::value), conjunction<tail...>, B1>::type {};
41
-
42
- // TypeChecker is needed for (const ArgTN &... Args) ctor to validate Args.
43
- template <typename T, typename DataT_>
44
- struct TypeChecker : std::is_convertible<T, DataT_> {};
112
+ // Trait for checking if an argument type is either convertible to the data
113
+ // type or an array of types convertible to the data type.
114
+ template <typename T>
115
+ struct IsSuitableArgType : std::is_convertible<T, DataT> {};
116
+ template <typename T, size_t N>
117
+ struct IsSuitableArgType <marray<T, N>> : std::is_convertible<T, DataT> {};
45
118
46
- // Shortcuts for Args validation in (const ArgTN &... Args) ctor.
119
+ // Trait for computing the conjunction of of IsSuitableArgType. The empty type
120
+ // list will trivially evaluate to true.
47
121
template <typename ... ArgTN>
48
- using EnableIfSuitableTypes = typename std::enable_if<
49
- conjunction<TypeChecker<ArgTN, DataT>...>::value>::type;
122
+ struct AllSuitableArgTypes : std::conjunction<IsSuitableArgType<ArgTN>...> {};
123
+
124
+ // FIXME: MArrayToArray needs to be a friend to access MData. If the subscript
125
+ // operator is made constexpr this can be removed.
126
+ template <typename , typename T, std::size_t ... Is>
127
+ friend constexpr std::array<T, sizeof ...(Is)>
128
+ detail::MArrayToArray (const marray<T, sizeof ...(Is)> &,
129
+ std::index_sequence<Is...>);
50
130
51
131
constexpr void initialize_data (const Type &Arg) {
52
132
for (size_t i = 0 ; i < NumElements; ++i) {
53
133
MData[i] = Arg;
54
134
}
55
135
}
56
136
137
+ template <size_t ... Is>
138
+ constexpr marray (const std::array<DataT, NumElements> &Arr,
139
+ std::index_sequence<Is...>)
140
+ : MData{Arr[Is]...} {}
141
+
57
142
public:
58
143
constexpr marray () : MData{} {}
59
144
60
145
explicit constexpr marray (const Type &Arg) : MData{Arg} {
61
146
initialize_data (Arg);
62
147
}
63
148
64
- template <
65
- typename ... ArgTN, typename = EnableIfSuitableTypes<ArgTN...>,
66
- typename = typename std::enable_if<sizeof ...(ArgTN) == NumElements>::type>
67
- constexpr marray (const ArgTN &...Args) : MData{static_cast <Type>(Args)...} {}
149
+ template <typename ... ArgTN,
150
+ typename = std::enable_if_t <
151
+ AllSuitableArgTypes<ArgTN...>::value &&
152
+ detail::GetMArrayArgsSize<ArgTN...>::value == NumElements>>
153
+ constexpr marray (const ArgTN &...Args)
154
+ : marray{detail::ArrayCreator<DataT, ArgTN...>::Create (Args...),
155
+ std::make_index_sequence<NumElements>()} {}
68
156
69
157
constexpr marray (const marray<Type, NumElements> &Rhs) = default;
70
158
0 commit comments