@@ -18,17 +18,18 @@ using namespace sycl;
18
18
template <class SpecializationKernelName , int TestNumber>
19
19
class exclusive_scan_kernel ;
20
20
21
- template <typename SpecializationKernelName, typename InputContainer,
22
- typename OutputContainer, class BinaryOperation >
21
+ template <typename BinaryOperation> class K0 ;
22
+ template <typename BinaryOperation> class K1 ;
23
+ template <typename BinaryOperation> class K2 ;
24
+ template <typename BinaryOperation> class K3 ;
25
+
26
+ template <typename InputContainer, typename OutputContainer,
27
+ class BinaryOperation >
23
28
void test (queue q, InputContainer input, OutputContainer output,
24
29
BinaryOperation binary_op,
25
30
typename OutputContainer::value_type identity) {
26
31
typedef typename InputContainer::value_type InputT;
27
32
typedef typename OutputContainer::value_type OutputT;
28
- typedef class exclusive_scan_kernel <SpecializationKernelName, 0 > kernel_name0;
29
- typedef class exclusive_scan_kernel <SpecializationKernelName, 1 > kernel_name1;
30
- typedef class exclusive_scan_kernel <SpecializationKernelName, 2 > kernel_name2;
31
- typedef class exclusive_scan_kernel <SpecializationKernelName, 3 > kernel_name3;
32
33
OutputT init = 42 ;
33
34
size_t N = input.size ();
34
35
size_t G = 64 ;
@@ -44,11 +45,12 @@ void test(queue q, InputContainer input, OutputContainer output,
44
45
q.submit ([&](handler &cgh) {
45
46
accessor in{in_buf, cgh, sycl::read_only};
46
47
accessor out{out_buf, cgh, sycl::write_only, sycl::no_init};
47
- cgh.parallel_for <kernel_name0>(nd_range<1 >(G, G), [=](nd_item<1 > it) {
48
- group<1 > g = it.get_group ();
49
- int lid = it.get_local_id (0 );
50
- out[lid] = exclusive_scan_over_group (g, in[lid], binary_op);
51
- });
48
+ cgh.parallel_for <K0<BinaryOperation>>(
49
+ nd_range<1 >(G, G), [=](nd_item<1 > it) {
50
+ group<1 > g = it.get_group ();
51
+ int lid = it.get_local_id (0 );
52
+ out[lid] = exclusive_scan_over_group (g, in[lid], binary_op);
53
+ });
52
54
});
53
55
54
56
complete_fusion_with_check (
@@ -73,11 +75,12 @@ void test(queue q, InputContainer input, OutputContainer output,
73
75
q.submit ([&](handler &cgh) {
74
76
accessor in{in_buf, cgh, sycl::read_only};
75
77
accessor out{out_buf, cgh, sycl::write_only, sycl::no_init};
76
- cgh.parallel_for <kernel_name1>(nd_range<1 >(G, G), [=](nd_item<1 > it) {
77
- group<1 > g = it.get_group ();
78
- int lid = it.get_local_id (0 );
79
- out[lid] = exclusive_scan_over_group (g, in[lid], init, binary_op);
80
- });
78
+ cgh.parallel_for <K1<BinaryOperation>>(
79
+ nd_range<1 >(G, G), [=](nd_item<1 > it) {
80
+ group<1 > g = it.get_group ();
81
+ int lid = it.get_local_id (0 );
82
+ out[lid] = exclusive_scan_over_group (g, in[lid], init, binary_op);
83
+ });
81
84
});
82
85
83
86
complete_fusion_with_check (
@@ -102,13 +105,14 @@ void test(queue q, InputContainer input, OutputContainer output,
102
105
q.submit ([&](handler &cgh) {
103
106
accessor in{in_buf, cgh, sycl::read_only};
104
107
accessor out{out_buf, cgh, sycl::write_only, sycl::no_init};
105
- cgh.parallel_for <kernel_name2>(nd_range<1 >(G, G), [=](nd_item<1 > it) {
106
- group<1 > g = it.get_group ();
107
- joint_exclusive_scan (
108
- g, in.template get_multi_ptr <access::decorated::no>(),
109
- in.template get_multi_ptr <access::decorated::no>() + N,
110
- out.template get_multi_ptr <access::decorated::no>(), binary_op);
111
- });
108
+ cgh.parallel_for <K2<BinaryOperation>>(
109
+ nd_range<1 >(G, G), [=](nd_item<1 > it) {
110
+ group<1 > g = it.get_group ();
111
+ joint_exclusive_scan (
112
+ g, in.template get_multi_ptr <access::decorated::no>(),
113
+ in.template get_multi_ptr <access::decorated::no>() + N,
114
+ out.template get_multi_ptr <access::decorated::no>(), binary_op);
115
+ });
112
116
});
113
117
complete_fusion_with_check (
114
118
fw, ext::codeplay::experimental::property::no_barriers{});
@@ -131,14 +135,15 @@ void test(queue q, InputContainer input, OutputContainer output,
131
135
q.submit ([&](handler &cgh) {
132
136
accessor in{in_buf, cgh, sycl::read_only};
133
137
accessor out{out_buf, cgh, sycl::write_only, sycl::no_init};
134
- cgh.parallel_for <kernel_name3>(nd_range<1 >(G, G), [=](nd_item<1 > it) {
135
- group<1 > g = it.get_group ();
136
- joint_exclusive_scan (
137
- g, in.template get_multi_ptr <access::decorated::no>(),
138
- in.template get_multi_ptr <access::decorated::no>() + N,
139
- out.template get_multi_ptr <access::decorated::no>(), init,
140
- binary_op);
141
- });
138
+ cgh.parallel_for <K3<BinaryOperation>>(
139
+ nd_range<1 >(G, G), [=](nd_item<1 > it) {
140
+ group<1 > g = it.get_group ();
141
+ joint_exclusive_scan (
142
+ g, in.template get_multi_ptr <access::decorated::no>(),
143
+ in.template get_multi_ptr <access::decorated::no>() + N,
144
+ out.template get_multi_ptr <access::decorated::no>(), init,
145
+ binary_op);
146
+ });
142
147
});
143
148
complete_fusion_with_check (
144
149
fw, ext::codeplay::experimental::property::no_barriers{});
@@ -160,24 +165,17 @@ int main() {
160
165
std::array<int , N> output;
161
166
std::fill (output.begin (), output.end (), 0 );
162
167
163
- test<class KernelNamePlusV >(q, input, output, sycl::plus<>(), 0 );
164
- test<class KernelNameMinimumV >(q, input, output, sycl::minimum<>(),
165
- std::numeric_limits<int >::max ());
166
- test<class KernelNameMaximumV >(q, input, output, sycl::maximum<>(),
167
- std::numeric_limits<int >::lowest ());
168
-
169
- test<class KernelNamePlusI >(q, input, output, sycl::plus<int >(), 0 );
170
- test<class KernelNameMinimumI >(q, input, output, sycl::minimum<int >(),
171
- std::numeric_limits<int >::max ());
172
- test<class KernelNameMaximumI >(q, input, output, sycl::maximum<int >(),
173
- std::numeric_limits<int >::lowest ());
174
- test<class KernelName_VzAPutpBRRJrQPB >(q, input, output,
175
- sycl::multiplies<int >(), 1 );
176
- test<class KernelName_UXdGbr >(q, input, output, sycl::bit_or<int >(), 0 );
177
- test<class KernelName_saYaodNyJknrPW >(q, input, output, sycl::bit_xor<int >(),
178
- 0 );
179
- test<class KernelName_GPcuAlvAOjrDyP >(q, input, output, sycl::bit_and<int >(),
180
- ~0 );
168
+ test (q, input, output, sycl::plus<>(), 0 );
169
+ test (q, input, output, sycl::minimum<>(), std::numeric_limits<int >::max ());
170
+ test (q, input, output, sycl::maximum<>(), std::numeric_limits<int >::lowest ());
171
+ test (q, input, output, sycl::plus<int >(), 0 );
172
+ test (q, input, output, sycl::minimum<int >(), std::numeric_limits<int >::max ());
173
+ test (q, input, output, sycl::maximum<int >(),
174
+ std::numeric_limits<int >::lowest ());
175
+ test (q, input, output, sycl::multiplies<int >(), 1 );
176
+ test (q, input, output, sycl::bit_or<int >(), 0 );
177
+ test (q, input, output, sycl::bit_xor<int >(), 0 );
178
+ test (q, input, output, sycl::bit_and<int >(), ~0 );
181
179
182
180
std::cout << " Test passed." << std::endl;
183
181
}
0 commit comments