Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 0d6e673

Browse files
authored
[SYCL] Specified ONEAPI aliases in functions and function objects (#300)
* [SYCL] Added test for constexpr half operations Signed-off-by: mdimakov <[email protected]>
1 parent 907cce4 commit 0d6e673

File tree

8 files changed

+52
-46
lines changed

8 files changed

+52
-46
lines changed

SYCL/GroupAlgorithm/exclusive_scan.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,24 +140,26 @@ int main() {
140140
std::iota(input.begin(), input.end(), 0);
141141
std::fill(output.begin(), output.end(), 0);
142142

143-
test<class KernelNamePlusV>(q, input, output, plus<>(), 0);
144-
test<class KernelNameMinimumV>(q, input, output, minimum<>(),
143+
test<class KernelNamePlusV>(q, input, output, ONEAPI::plus<>(), 0);
144+
test<class KernelNameMinimumV>(q, input, output, ONEAPI::minimum<>(),
145145
std::numeric_limits<int>::max());
146-
test<class KernelNameMaximumV>(q, input, output, maximum<>(),
146+
test<class KernelNameMaximumV>(q, input, output, ONEAPI::maximum<>(),
147147
std::numeric_limits<int>::lowest());
148148

149-
test<class KernelNamePlusI>(q, input, output, plus<int>(), 0);
150-
test<class KernelNameMinimumI>(q, input, output, minimum<int>(),
149+
test<class KernelNamePlusI>(q, input, output, ONEAPI::plus<int>(), 0);
150+
test<class KernelNameMinimumI>(q, input, output, ONEAPI::minimum<int>(),
151151
std::numeric_limits<int>::max());
152-
test<class KernelNameMaximumI>(q, input, output, maximum<int>(),
152+
test<class KernelNameMaximumI>(q, input, output, ONEAPI::maximum<int>(),
153153
std::numeric_limits<int>::lowest());
154154

155155
#ifdef SPIRV_1_3
156-
test<class KernelName_VzAPutpBRRJrQPB>(q, input, output, multiplies<int>(),
157-
1);
158-
test<class KernelName_UXdGbr>(q, input, output, bit_or<int>(), 0);
159-
test<class KernelName_saYaodNyJknrPW>(q, input, output, bit_xor<int>(), 0);
160-
test<class KernelName_GPcuAlvAOjrDyP>(q, input, output, bit_and<int>(), ~0);
156+
test<class KernelName_VzAPutpBRRJrQPB>(q, input, output,
157+
ONEAPI::multiplies<int>(), 1);
158+
test<class KernelName_UXdGbr>(q, input, output, ONEAPI::bit_or<int>(), 0);
159+
test<class KernelName_saYaodNyJknrPW>(q, input, output,
160+
ONEAPI::bit_xor<int>(), 0);
161+
test<class KernelName_GPcuAlvAOjrDyP>(q, input, output,
162+
ONEAPI::bit_and<int>(), ~0);
161163
#endif // SPIRV_1_3
162164

163165
std::cout << "Test passed." << std::endl;

SYCL/GroupAlgorithm/inclusive_scan.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,25 +140,27 @@ int main() {
140140
std::iota(input.begin(), input.end(), 0);
141141
std::fill(output.begin(), output.end(), 0);
142142

143-
test<class KernelNamePlusV>(q, input, output, plus<>(), 0);
144-
test<class KernelNameMinimumV>(q, input, output, minimum<>(),
143+
test<class KernelNamePlusV>(q, input, output, ONEAPI::plus<>(), 0);
144+
test<class KernelNameMinimumV>(q, input, output, ONEAPI::minimum<>(),
145145
std::numeric_limits<int>::max());
146-
test<class KernelNameMaximumV>(q, input, output, maximum<>(),
146+
test<class KernelNameMaximumV>(q, input, output, ONEAPI::maximum<>(),
147147
std::numeric_limits<int>::lowest());
148148

149-
test<class KernelNamePlusI>(q, input, output, plus<int>(), 0);
150-
test<class KernelNameMinimumI>(q, input, output, minimum<int>(),
149+
test<class KernelNamePlusI>(q, input, output, ONEAPI::plus<int>(), 0);
150+
test<class KernelNameMinimumI>(q, input, output, ONEAPI::minimum<int>(),
151151
std::numeric_limits<int>::max());
152-
test<class KernelNameMaximumI>(q, input, output, maximum<int>(),
152+
test<class KernelNameMaximumI>(q, input, output, ONEAPI::maximum<int>(),
153153
std::numeric_limits<int>::lowest());
154154

155155
#ifdef SPIRV_1_3
156156
test<class KernelName_zMyjxUrBgeUGoxmDwhvJ>(q, input, output,
157-
multiplies<int>(), 1);
158-
test<class KernelName_SljjtroxNRaAXoVnT>(q, input, output, bit_or<int>(), 0);
159-
test<class KernelName_yXIZfjwjxQGiPeQAnc>(q, input, output, bit_xor<int>(),
160-
0);
161-
test<class KernelName_xGnAnMYHvqekCk>(q, input, output, bit_and<int>(), ~0);
157+
ONEAPI::multiplies<int>(), 1);
158+
test<class KernelName_SljjtroxNRaAXoVnT>(q, input, output,
159+
ONEAPI::bit_or<int>(), 0);
160+
test<class KernelName_yXIZfjwjxQGiPeQAnc>(q, input, output,
161+
ONEAPI::bit_xor<int>(), 0);
162+
test<class KernelName_xGnAnMYHvqekCk>(q, input, output,
163+
ONEAPI::bit_and<int>(), ~0);
162164
#endif // SPIRV_1_3
163165

164166
std::cout << "Test passed." << std::endl;

SYCL/GroupAlgorithm/reduce.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,24 +76,26 @@ int main() {
7676
std::iota(input.begin(), input.end(), 0);
7777
std::fill(output.begin(), output.end(), 0);
7878

79-
test<class KernelNamePlusV>(q, input, output, plus<>(), 0);
80-
test<class KernelNameMinimumV>(q, input, output, minimum<>(),
79+
test<class KernelNamePlusV>(q, input, output, ONEAPI::plus<>(), 0);
80+
test<class KernelNameMinimumV>(q, input, output, ONEAPI::minimum<>(),
8181
std::numeric_limits<int>::max());
82-
test<class KernelNameMaximumV>(q, input, output, maximum<>(),
82+
test<class KernelNameMaximumV>(q, input, output, ONEAPI::maximum<>(),
8383
std::numeric_limits<int>::lowest());
8484

85-
test<class KernelNamePlusI>(q, input, output, plus<int>(), 0);
86-
test<class KernelNameMinimumI>(q, input, output, minimum<int>(),
85+
test<class KernelNamePlusI>(q, input, output, ONEAPI::plus<int>(), 0);
86+
test<class KernelNameMinimumI>(q, input, output, ONEAPI::minimum<int>(),
8787
std::numeric_limits<int>::max());
88-
test<class KernelNameMaximumI>(q, input, output, maximum<int>(),
88+
test<class KernelNameMaximumI>(q, input, output, ONEAPI::maximum<int>(),
8989
std::numeric_limits<int>::lowest());
9090

9191
#ifdef SPIRV_1_3
9292
test<class KernelName_WonwuUVPUPOTKRKIBtT>(q, input, output,
93-
multiplies<int>(), 1);
94-
test<class KernelName_qYBaJDZTMGkdIwD>(q, input, output, bit_or<int>(), 0);
95-
test<class KernelName_eLSFt>(q, input, output, bit_xor<int>(), 0);
96-
test<class KernelName_uFhJnxSVhNAiFPTG>(q, input, output, bit_and<int>(), ~0);
93+
ONEAPI::multiplies<int>(), 1);
94+
test<class KernelName_qYBaJDZTMGkdIwD>(q, input, output,
95+
ONEAPI::bit_or<int>(), 0);
96+
test<class KernelName_eLSFt>(q, input, output, ONEAPI::bit_xor<int>(), 0);
97+
test<class KernelName_uFhJnxSVhNAiFPTG>(q, input, output,
98+
ONEAPI::bit_and<int>(), ~0);
9799
#endif // SPIRV_1_3
98100

99101
std::cout << "Test passed." << std::endl;

SYCL/SubGroup/broadcast.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ template <typename T> void check(queue &Queue) {
2323
ONEAPI::sub_group SG = NdItem.get_sub_group();
2424
/*Broadcast GID of element with SGLID == SGID % SGMLR*/
2525
syclacc[NdItem.get_global_id()] =
26-
broadcast(SG, T(NdItem.get_global_id(0)),
27-
SG.get_group_id() % SG.get_max_local_range()[0]);
26+
ONEAPI::broadcast(SG, T(NdItem.get_global_id(0)),
27+
SG.get_group_id() % SG.get_max_local_range()[0]);
2828
if (NdItem.get_global_id(0) == 0)
2929
sgsizeacc[0] = SG.get_max_local_range()[0];
3030
});

SYCL/SubGroup/generic_reduce.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ void check_op(queue &Queue, T init, BinaryOperation op, bool skip_init = false,
2525
ONEAPI::sub_group sg = NdItem.get_sub_group();
2626
if (skip_init) {
2727
acc[NdItem.get_global_id(0)] =
28-
reduce(sg, T(NdItem.get_global_id(0)), op);
28+
ONEAPI::reduce(sg, T(NdItem.get_global_id(0)), op);
2929
} else {
3030
acc[NdItem.get_global_id(0)] =
31-
reduce(sg, T(NdItem.get_global_id(0)), init, op);
31+
ONEAPI::reduce(sg, T(NdItem.get_global_id(0)), init, op);
3232
}
3333
if (NdItem.get_global_id(0) == 0)
3434
sgsizeacc[0] = sg.get_max_local_range()[0];

SYCL/SubGroup/reduce.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ void check_op(queue &Queue, T init, BinaryOperation op, bool skip_init = false,
2828
ONEAPI::sub_group sg = NdItem.get_sub_group();
2929
if (skip_init) {
3030
acc[NdItem.get_global_id(0)] =
31-
reduce(sg, T(NdItem.get_global_id(0)), op);
31+
ONEAPI::reduce(sg, T(NdItem.get_global_id(0)), op);
3232
} else {
3333
acc[NdItem.get_global_id(0)] =
34-
reduce(sg, T(NdItem.get_global_id(0)), init, op);
34+
ONEAPI::reduce(sg, T(NdItem.get_global_id(0)), init, op);
3535
}
3636
if (NdItem.get_global_id(0) == 0)
3737
sgsizeacc[0] = sg.get_max_local_range()[0];

SYCL/SubGroup/scan.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ void check_op(queue &Queue, T init, BinaryOperation op, bool skip_init = false,
3030
ONEAPI::sub_group sg = NdItem.get_sub_group();
3131
if (skip_init) {
3232
exacc[NdItem.get_global_id(0)] =
33-
exclusive_scan(sg, T(NdItem.get_global_id(0)), op);
33+
ONEAPI::exclusive_scan(sg, T(NdItem.get_global_id(0)), op);
3434
inacc[NdItem.get_global_id(0)] =
35-
inclusive_scan(sg, T(NdItem.get_global_id(0)), op);
35+
ONEAPI::inclusive_scan(sg, T(NdItem.get_global_id(0)), op);
3636
} else {
37-
exacc[NdItem.get_global_id(0)] =
38-
exclusive_scan(sg, T(NdItem.get_global_id(0)), init, op);
39-
inacc[NdItem.get_global_id(0)] =
40-
inclusive_scan(sg, T(NdItem.get_global_id(0)), op, init);
37+
exacc[NdItem.get_global_id(0)] = ONEAPI::exclusive_scan(
38+
sg, T(NdItem.get_global_id(0)), init, op);
39+
inacc[NdItem.get_global_id(0)] = ONEAPI::inclusive_scan(
40+
sg, T(NdItem.get_global_id(0)), op, init);
4141
}
4242
if (NdItem.get_global_id(0) == 0)
4343
sgsizeacc[0] = sg.get_max_local_range()[0];

SYCL/SubGroup/vote.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ void check(queue Queue, const int G, const int L, const int D, const int R) {
4848
cgh.parallel_for<class subgr>(NdRange, [=](nd_item<1> NdItem) {
4949
ONEAPI::sub_group SG = NdItem.get_sub_group();
5050
/* Set to 1 if any local ID in subgroup devided by D has remainder R */
51-
if (any_of(SG, SG.get_local_id().get(0) % D == R)) {
51+
if (ONEAPI::any_of(SG, SG.get_local_id().get(0) % D == R)) {
5252
sganyacc[NdItem.get_global_id()] = 1;
5353
}
5454
/* Set to 1 if remainder of division of subgroup local ID by D is less
5555
* than R for all work items in subgroup */
56-
if (all_of(SG, SG.get_local_id().get(0) % D < R)) {
56+
if (ONEAPI::all_of(SG, SG.get_local_id().get(0) % D < R)) {
5757
sgallacc[NdItem.get_global_id()] = 1;
5858
}
5959
});

0 commit comments

Comments
 (0)