Skip to content

Commit 09ae643

Browse files
authored
[SYCL] Update group_broadcast to support vec types (#8886)
Update the `OpGroupBroadcast` definition to allow vector value and result types. Add special case for `vec` types to enable `GenericBroadcast` since certain backends don't support vector types in `OpGroupBroadcast`. Signed-off-by: Michael Aziz <[email protected]>
1 parent e80b888 commit 09ae643

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

clang/lib/Sema/SPIRVBuiltins.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -926,9 +926,9 @@ foreach name = ["GroupAll", "GroupAny"] in {
926926

927927
foreach name = ["GroupBroadcast"] in {
928928
foreach IDType = TLAllInts.List in {
929-
def : SPVBuiltin<name, [AGenType1, UInt, AGenType1, IDType], Attr.Convergent>;
930-
def : SPVBuiltin<name, [AGenType1, UInt, AGenType1, VectorType<IDType, 2>], Attr.Convergent>;
931-
def : SPVBuiltin<name, [AGenType1, UInt, AGenType1, VectorType<IDType, 3>], Attr.Convergent>;
929+
def : SPVBuiltin<name, [AGenTypeN, UInt, AGenTypeN, IDType], Attr.Convergent>;
930+
def : SPVBuiltin<name, [AGenTypeN, UInt, AGenTypeN, VectorType<IDType, 2>], Attr.Convergent>;
931+
def : SPVBuiltin<name, [AGenTypeN, UInt, AGenTypeN, VectorType<IDType, 3>], Attr.Convergent>;
932932
def : SPVBuiltin<name, [Bool, UInt, Bool, IDType], Attr.Convergent>;
933933
def : SPVBuiltin<name, [Bool, UInt, Bool, VectorType<IDType, 2>], Attr.Convergent>;
934934
def : SPVBuiltin<name, [Bool, UInt, Bool, VectorType<IDType, 3>], Attr.Convergent>;

sycl/include/sycl/detail/spirv.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@ template <typename Group> bool GroupAny(bool pred) {
103103
}
104104

105105
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
106-
// FIXME: Do not special-case for half once all backends support all data types.
106+
// FIXME: Do not special-case for half or vec once all backends support all data
107+
// types.
107108
template <typename T>
108-
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
109-
!std::is_same<T, half>::value>;
109+
using is_native_broadcast =
110+
bool_constant<detail::is_arithmetic<T>::value &&
111+
!std::is_same<T, half>::value && !detail::is_vec<T>::value>;
110112

111113
template <typename T, typename IdT = size_t>
112114
using EnableIfNativeBroadcast = detail::enable_if_t<

sycl/test-e2e/GroupAlgorithm/SYCL2020/group_broadcast.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212
#include <sycl/sycl.hpp>
1313
using namespace sycl;
1414

15+
// Generic comparison helper: evaluates `a == b` with the type's own equality
// operator and returns the outcome as a bool.
template <typename T> bool equal(const T &a, const T &b) {
  const bool same = a == b;
  return same;
}
16+
17+
template <typename T, int N>
18+
bool equal(const vec<T, N> &a, const vec<T, N> &b) {
19+
for (int i = 0; i < N; i++) {
20+
if (a[i] != b[i]) {
21+
return false;
22+
}
23+
}
24+
return true;
25+
}
26+
1527
template <typename kernel_name, typename InputContainer,
1628
typename OutputContainer>
1729
void test(queue q, InputContainer input, OutputContainer output) {
@@ -37,9 +49,9 @@ void test(queue q, InputContainer input, OutputContainer output) {
3749
});
3850
});
3951
}
40-
assert(output[0] == input[0]);
41-
assert(output[1] == input[1 * G + 2]);
42-
assert(output[2] == input[2 * G + 1]);
52+
assert(equal(output[0], input[0]));
53+
assert(equal(output[1], input[1 * G + 2]));
54+
assert(equal(output[2], input[2 * G + 1]));
4355
}
4456

4557
int main() {
@@ -71,6 +83,17 @@ int main() {
7183
test<class KernelName_NrqELzFQToOSPsRNMi>(q, input, output);
7284
}
7385

86+
// Test vector types
87+
{
88+
std::array<vec<int, 4>, N> input;
89+
std::array<vec<int, 4>, 3> output;
90+
for (int i = 0; i < N; ++i) {
91+
input[i] = vec<int, 4>{i, i, i, i};
92+
}
93+
std::fill(output.begin(), output.end(), vec<int, 4>{0, 0, 0, 0});
94+
test<class KernelName_VectorGroupBroadcast>(q, input, output);
95+
}
96+
7497
// Test user-defined type
7598
// - Use complex as a proxy for this
7699
// - Test float and double to test 64-bit and 128-bit types

0 commit comments

Comments
 (0)