Skip to content

Commit df6d715

Browse files
authored
[SYCL] Extend broadcast to TriviallyCopyable types (#2160)
Uses the TriviallyCopyable shuffle approach for broadcasts. Note that this works for both work-groups and sub-groups, because OpGroupBroadcast is defined as supporting both groups in SPIR-V. Signed-off-by: John Pennycook <[email protected]>
1 parent ca2c5bb commit df6d715

File tree

3 files changed

+197
-58
lines changed

3 files changed

+197
-58
lines changed

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 142 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,32 @@ template <> struct group_scope<::cl::sycl::intel::sub_group> {
3333
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
3434
};
3535

36+
// Generic shuffles and broadcasts may require multiple calls to SPIR-V
37+
// intrinsics, and should use the fewest broadcasts possible
38+
// - Loop over 64-bit chunks until remaining bytes < 64-bit
39+
// - At most one 32-bit, 16-bit and 8-bit chunk left over
40+
template <typename T, typename Functor>
41+
void GenericCall(const Functor &ApplyToBytes) {
42+
if (sizeof(T) >= sizeof(uint64_t)) {
43+
#pragma unroll
44+
for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
45+
ApplyToBytes(Offset, sizeof(uint64_t));
46+
}
47+
}
48+
if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
49+
size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
50+
ApplyToBytes(Offset, sizeof(uint32_t));
51+
}
52+
if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
53+
size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
54+
ApplyToBytes(Offset, sizeof(uint16_t));
55+
}
56+
if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
57+
size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
58+
ApplyToBytes(Offset, sizeof(uint8_t));
59+
}
60+
}
61+
3662
template <typename Group> bool GroupAll(bool pred) {
3763
return __spirv_GroupAll(group_scope<Group>::value, pred);
3864
}
@@ -41,47 +67,137 @@ template <typename Group> bool GroupAny(bool pred) {
4167
return __spirv_GroupAny(group_scope<Group>::value, pred);
4268
}
4369

70+
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
71+
template <typename T>
72+
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value>;
73+
74+
template <typename T, typename IdT = size_t>
75+
using EnableIfNativeBroadcast = detail::enable_if_t<
76+
is_native_broadcast<T>::value && std::is_integral<IdT>::value, T>;
77+
78+
// Bitcast broadcasts can be implemented using a single SPIR-V GroupBroadcast
79+
// intrinsic, but require type-punning via an appropriate integer type
80+
template <typename T>
81+
using is_bitcast_broadcast = bool_constant<
82+
!is_native_broadcast<T>::value && std::is_trivially_copyable<T>::value &&
83+
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)>;
84+
85+
template <typename T, typename IdT = size_t>
86+
using EnableIfBitcastBroadcast = detail::enable_if_t<
87+
is_bitcast_broadcast<T>::value && std::is_integral<IdT>::value, T>;
88+
89+
template <typename T>
90+
using ConvertToNativeBroadcastType_t = select_cl_scalar_integral_unsigned_t<T>;
91+
92+
// Generic broadcasts may require multiple calls to SPIR-V GroupBroadcast
93+
// intrinsics, and should use the fewest broadcasts possible
94+
// - Loop over 64-bit chunks until remaining bytes < 64-bit
95+
// - At most one 32-bit, 16-bit and 8-bit chunk left over
96+
template <typename T>
97+
using is_generic_broadcast =
98+
bool_constant<!is_native_broadcast<T>::value &&
99+
!is_bitcast_broadcast<T>::value &&
100+
std::is_trivially_copyable<T>::value>;
101+
102+
template <typename T, typename IdT = size_t>
103+
using EnableIfGenericBroadcast = detail::enable_if_t<
104+
is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
105+
44106
// Broadcast with scalar local index
45107
// Work-group supports any integral type
46108
// Sub-group currently supports only uint32_t
109+
template <typename Group> struct GroupId { using type = size_t; };
110+
template <> struct GroupId<::cl::sycl::intel::sub_group> {
111+
using type = uint32_t;
112+
};
47113
template <typename Group, typename T, typename IdT>
48-
detail::enable_if_t<is_group<Group>::value && std::is_integral<IdT>::value, T>
49-
GroupBroadcast(T x, IdT local_id) {
114+
EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
115+
using GroupIdT = typename GroupId<Group>::type;
116+
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
50117
using OCLT = detail::ConvertToOpenCLType_t<T>;
51-
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
52-
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
53-
OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(local_id);
54-
return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
118+
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
119+
OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
120+
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
121+
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
55122
}
56123
template <typename Group, typename T, typename IdT>
57-
detail::enable_if_t<is_sub_group<Group>::value && std::is_integral<IdT>::value,
58-
T>
59-
GroupBroadcast(T x, IdT local_id) {
60-
using SGIdT = uint32_t;
61-
SGIdT sg_local_id = static_cast<SGIdT>(local_id);
62-
using OCLT = detail::ConvertToOpenCLType_t<T>;
63-
using OCLIdT = detail::ConvertToOpenCLType_t<SGIdT>;
64-
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
65-
OCLIdT ocl_id = detail::convertDataToType<SGIdT, OCLIdT>(sg_local_id);
66-
return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
124+
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
125+
using GroupIdT = typename GroupId<Group>::type;
126+
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
127+
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
128+
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
129+
auto BroadcastX = detail::bit_cast<BroadcastT>(x);
130+
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
131+
BroadcastT Result =
132+
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
133+
return detail::bit_cast<T>(Result);
134+
}
135+
template <typename Group, typename T, typename IdT>
136+
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
137+
T Result;
138+
char *XBytes = reinterpret_cast<char *>(&x);
139+
char *ResultBytes = reinterpret_cast<char *>(&Result);
140+
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
141+
uint64_t BroadcastX, BroadcastResult;
142+
detail::memcpy(&BroadcastX, XBytes + Offset, Size);
143+
BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
144+
detail::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
145+
};
146+
GenericCall<T>(BroadcastBytes);
147+
return Result;
67148
}
68149

69150
// Broadcast with vector local index
70151
template <typename Group, typename T, int Dimensions>
71-
T GroupBroadcast(T x, id<Dimensions> local_id) {
152+
EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
72153
if (Dimensions == 1) {
73154
return GroupBroadcast<Group>(x, local_id[0]);
74155
}
75156
using IdT = vec<size_t, Dimensions>;
76157
using OCLT = detail::ConvertToOpenCLType_t<T>;
77158
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
78-
IdT vec_id;
159+
IdT VecId;
160+
for (int i = 0; i < Dimensions; ++i) {
161+
VecId[i] = local_id[Dimensions - i - 1];
162+
}
163+
OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
164+
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
165+
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
166+
}
167+
template <typename Group, typename T, int Dimensions>
168+
EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
169+
if (Dimensions == 1) {
170+
return GroupBroadcast<Group>(x, local_id[0]);
171+
}
172+
using IdT = vec<size_t, Dimensions>;
173+
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
174+
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
175+
IdT VecId;
79176
for (int i = 0; i < Dimensions; ++i) {
80-
vec_id[i] = local_id[Dimensions - i - 1];
177+
VecId[i] = local_id[Dimensions - i - 1];
81178
}
82-
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
83-
OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(vec_id);
84-
return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
179+
auto BroadcastX = detail::bit_cast<BroadcastT>(x);
180+
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
181+
BroadcastT Result =
182+
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
183+
return detail::bit_cast<T>(Result);
184+
}
185+
template <typename Group, typename T, int Dimensions>
186+
EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
187+
if (Dimensions == 1) {
188+
return GroupBroadcast<Group>(x, local_id[0]);
189+
}
190+
T Result;
191+
char *XBytes = reinterpret_cast<char *>(&x);
192+
char *ResultBytes = reinterpret_cast<char *>(&Result);
193+
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
194+
uint64_t BroadcastX, BroadcastResult;
195+
detail::memcpy(&BroadcastX, XBytes + Offset, Size);
196+
BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
197+
detail::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
198+
};
199+
GenericCall<T>(BroadcastBytes);
200+
return Result;
85201
}
86202

87203
// Single happens-before means semantics should always apply to all spaces
@@ -400,28 +516,6 @@ using EnableIfGenericShuffle =
400516
sizeof(T) == 4 || sizeof(T) == 8)),
401517
T>;
402518

403-
template <typename T, typename ShuffleFunctor>
404-
void GenericShuffle(const ShuffleFunctor &ShuffleBytes) {
405-
if (sizeof(T) >= sizeof(uint64_t)) {
406-
#pragma unroll
407-
for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
408-
ShuffleBytes(Offset, sizeof(uint64_t));
409-
}
410-
}
411-
if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
412-
size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
413-
ShuffleBytes(Offset, sizeof(uint32_t));
414-
}
415-
if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
416-
size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
417-
ShuffleBytes(Offset, sizeof(uint16_t));
418-
}
419-
if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
420-
size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
421-
ShuffleBytes(Offset, sizeof(uint8_t));
422-
}
423-
}
424-
425519
template <typename T>
426520
EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
427521
T Result;
@@ -433,7 +527,7 @@ EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
433527
ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
434528
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
435529
};
436-
GenericShuffle<T>(ShuffleBytes);
530+
GenericCall<T>(ShuffleBytes);
437531
return Result;
438532
}
439533

@@ -448,7 +542,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
448542
ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
449543
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
450544
};
451-
GenericShuffle<T>(ShuffleBytes);
545+
GenericCall<T>(ShuffleBytes);
452546
return Result;
453547
}
454548

@@ -465,7 +559,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, T y, id<1> local_id) {
465559
ShuffleResult = SubgroupShuffleDown(ShuffleX, ShuffleY, local_id);
466560
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
467561
};
468-
GenericShuffle<T>(ShuffleBytes);
562+
GenericCall<T>(ShuffleBytes);
469563
return Result;
470564
}
471565

@@ -482,7 +576,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, T y, id<1> local_id) {
482576
ShuffleResult = SubgroupShuffleUp(ShuffleX, ShuffleY, local_id);
483577
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
484578
};
485-
GenericShuffle<T>(ShuffleBytes);
579+
GenericCall<T>(ShuffleBytes);
486580
return Result;
487581
}
488582

sycl/include/CL/sycl/intel/group_algorithm.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ template <typename Ptr, typename T>
138138
using EnableIfIsPointer =
139139
cl::sycl::detail::enable_if_t<cl::sycl::detail::is_pointer<Ptr>::value, T>;
140140

141+
template <typename T>
142+
using EnableIfIsTriviallyCopyable = cl::sycl::detail::enable_if_t<
143+
std::is_trivially_copyable<T>::value &&
144+
!cl::sycl::detail::is_vector_arithmetic<T>::value,
145+
T>;
146+
141147
// EnableIf shorthands for algorithms that depend on type and an operator
142148
template <typename T, typename BinaryOperation>
143149
using EnableIfIsScalarArithmeticNativeOp = cl::sycl::detail::enable_if_t<
@@ -286,8 +292,8 @@ EnableIfIsPointer<Ptr, bool> none_of(Group g, Ptr first, Ptr last,
286292
}
287293

288294
template <typename Group, typename T>
289-
EnableIfIsScalarArithmetic<T> broadcast(Group, T x,
290-
typename Group::id_type local_id) {
295+
EnableIfIsTriviallyCopyable<T> broadcast(Group, T x,
296+
typename Group::id_type local_id) {
291297
static_assert(sycl::detail::is_generic_group<Group>::value,
292298
"Group algorithms only support the sycl::group and "
293299
"intel::sub_group class.");
@@ -323,7 +329,7 @@ EnableIfIsVectorArithmetic<T> broadcast(Group g, T x,
323329
}
324330

325331
template <typename Group, typename T>
326-
EnableIfIsScalarArithmetic<T>
332+
EnableIfIsTriviallyCopyable<T>
327333
broadcast(Group g, T x, typename Group::linear_id_type linear_local_id) {
328334
static_assert(sycl::detail::is_generic_group<Group>::value,
329335
"Group algorithms only support the sycl::group and "
@@ -363,7 +369,7 @@ broadcast(Group g, T x, typename Group::linear_id_type linear_local_id) {
363369
}
364370

365371
template <typename Group, typename T>
366-
EnableIfIsScalarArithmetic<T> broadcast(Group g, T x) {
372+
EnableIfIsTriviallyCopyable<T> broadcast(Group g, T x) {
367373
static_assert(sycl::detail::is_generic_group<Group>::value,
368374
"Group algorithms only support the sycl::group and "
369375
"intel::sub_group class.");

sycl/test/group-algorithm/broadcast.cpp

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,19 @@
1010
#include <CL/sycl.hpp>
1111
#include <algorithm>
1212
#include <cassert>
13+
#include <complex>
1314
#include <numeric>
1415
using namespace sycl;
1516
using namespace sycl::intel;
1617

18+
template <typename InputContainer, typename OutputContainer>
1719
class broadcast_kernel;
1820

1921
template <typename InputContainer, typename OutputContainer>
2022
void test(queue q, InputContainer input, OutputContainer output) {
2123
typedef typename InputContainer::value_type InputT;
2224
typedef typename OutputContainer::value_type OutputT;
23-
typedef class broadcast_kernel kernel_name;
25+
typedef class broadcast_kernel<InputContainer, OutputContainer> kernel_name;
2426
size_t N = input.size();
2527
size_t G = 4;
2628
range<2> R(G, G);
@@ -54,12 +56,49 @@ int main() {
5456
}
5557

5658
constexpr int N = 16;
57-
std::array<int, N> input;
58-
std::array<int, N> output;
59-
std::iota(input.begin(), input.end(), 1);
60-
std::fill(output.begin(), output.end(), false);
6159

62-
test(q, input, output);
60+
// Test built-in scalar type
61+
{
62+
std::array<int, N> input;
63+
std::array<int, 3> output;
64+
std::iota(input.begin(), input.end(), 1);
65+
std::fill(output.begin(), output.end(), false);
66+
test(q, input, output);
67+
}
68+
69+
// Test pointer type
70+
{
71+
std::array<int *, N> input;
72+
std::array<int *, 3> output;
73+
for (int i = 0; i < N; ++i) {
74+
input[i] = static_cast<int *>(0x0) + i;
75+
}
76+
std::fill(output.begin(), output.end(), static_cast<int *>(0x0));
77+
test(q, input, output);
78+
}
6379

80+
// Test user-defined type
81+
// - Use complex as a proxy for this
82+
// - Test float and double to test 64-bit and 128-bit types
83+
{
84+
std::array<std::complex<float>, N> input;
85+
std::array<std::complex<float>, 3> output;
86+
for (int i = 0; i < N; ++i) {
87+
input[i] =
88+
std::complex<float>(0, 1) + (float)i * std::complex<float>(2, 2);
89+
}
90+
std::fill(output.begin(), output.end(), std::complex<float>(0, 0));
91+
test(q, input, output);
92+
}
93+
{
94+
std::array<std::complex<double>, N> input;
95+
std::array<std::complex<double>, 3> output;
96+
for (int i = 0; i < N; ++i) {
97+
input[i] =
98+
std::complex<double>(0, 1) + (double)i * std::complex<double>(2, 2);
99+
}
100+
std::fill(output.begin(), output.end(), std::complex<float>(0, 0));
101+
test(q, input, output);
102+
}
64103
std::cout << "Test passed." << std::endl;
65104
}

0 commit comments

Comments
 (0)