@@ -74,7 +74,7 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
74
74
sycl::nd_range<1 > gws (gws_range, lws_range);
75
75
76
76
auto kernel_parallel_for_func = [=](sycl::nd_item<1 > nd_it) {
77
- auto gr = nd_it.get_group ();
77
+ auto gr = nd_it.get_sub_group ();
78
78
const auto max_gr_size = gr.get_max_local_range ()[0 ];
79
79
const size_t start =
80
80
vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) +
@@ -127,8 +127,79 @@ DPCTLSyclEventRef (*dpnp_all_ext_c)(DPCTLSyclQueueRef,
127
127
const DPCTLEventVectorRef) =
128
128
dpnp_all_c<_DataType, _ResultType>;
129
129
130
- template <typename _DataType1, typename _DataType2, typename _ResultType>
131
- class dpnp_allclose_c_kernel ;
130
+ template <typename _DataType1, typename _DataType2, typename _TolType>
131
+ class dpnp_allclose_kernel ;
132
+
133
// Submits a SYCL kernel that checks whether two arrays are element-wise
// close, i.e. whether for every i in [0, size)
//     |array1[i] - array2[i]| <= atol_val + rtol_val * |array2[i]|.
// When both element types are floating-point, a pair containing an infinity
// passes only if the two values compare equal.
//
// q        - queue the fill operation and the kernel are submitted to
// array1   - first input array, read on the device
// array2   - second input array, read on the device
// result   - single bool: set to true by q.fill, then overwritten with false
//            by the kernel when a sub-group finds a mismatch
//            (presumably a device-accessible USM pointer - confirm with callers)
// size     - number of elements compared
// rtol_val - relative tolerance
// atol_val - absolute tolerance
//
// Returns the fill event when size is zero; otherwise the event of the
// submitted kernel, which depends on the fill event.
+ template <typename _DataType1, typename _DataType2, typename _TolType>
134
+ static sycl::event dpnp_allclose (sycl::queue &q,
135
+ const _DataType1 *array1,
136
+ const _DataType2 *array2,
137
+ bool *result,
138
+ const size_t size,
139
+ const _TolType rtol_val,
140
+ const _TolType atol_val)
141
+ {
142
+ sycl::event fill_event = q.fill (result, true , 1 );
143
+ if (!size) {
144
+ return fill_event;
145
+ }
146
+
147
+ constexpr size_t lws = 64 ;
148
+ constexpr size_t vec_sz = 8 ;
149
+
// one work-group of "lws" work-items per "lws * vec_sz" input elements,
// rounded up so that the tail of the arrays is covered as well
150
+ auto gws_range =
151
+ sycl::range<1 >(((size + lws * vec_sz - 1 ) / (lws * vec_sz)) * lws);
152
+ auto lws_range = sycl::range<1 >(lws);
153
+ sycl::nd_range<1 > gws (gws_range, lws_range);
154
+
155
+ auto kernel_parallel_for_func = [=](sycl::nd_item<1 > nd_it) {
156
+ auto gr = nd_it.get_sub_group ();
157
+ const auto max_gr_size = gr.get_max_local_range ()[0 ];
158
+ const auto gr_size = gr.get_local_linear_range ();
// [start, end) is the slice of the inputs handled by this sub-group;
// "end" is clamped to "size" so the last slice does not read past the arrays
159
+ const size_t start =
160
+ vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) +
161
+ gr.get_group_linear_id () * max_gr_size);
162
+ const size_t end = sycl::min (start + vec_sz * gr_size, size);
163
+
164
+ // each work-item visits up to "vec_sz" elements, striding by the sub-group size
165
+ bool partial = true ;
166
+
167
+ for (size_t i = start + gr.get_local_linear_id (); i < end; i += gr_size)
168
+ {
169
+ if constexpr (std::is_floating_point_v<_DataType1> &&
170
+ std::is_floating_point_v<_DataType2>)
171
+ {
// an infinite element is compared for exact equality and the
// tolerance formula below is skipped for this pair
172
+ if (std::isinf (array1[i]) || std::isinf (array2[i])) {
173
+ partial &= (array1[i] == array2[i]);
174
+ continue ;
175
+ }
176
+ }
177
+
178
+ // cast an integral array2 element to the tolerance type to avoid bad
179
+ // behavior on abs(MIN_INT), which leads to an undefined result
180
+ using _Arr2Type = std::conditional_t <std::is_integral_v<_DataType2>,
181
+ _TolType, _DataType2>;
182
+ _Arr2Type arr2 = static_cast <_Arr2Type>(array2[i]);
183
+
184
+ partial &= (std::abs (array1[i] - arr2) <=
185
+ (atol_val + rtol_val * std::abs (arr2)));
186
+ }
// combine the per-work-item flags across the whole sub-group
187
+ partial = sycl::all_of_group (gr, partial);
188
+
// only the sub-group leader writes, and only ever the value "false"
189
+ if (gr.leader () && (partial == false )) {
190
+ result[0 ] = false ;
191
+ }
192
+ };
193
+
194
+ auto kernel_func = [&](sycl::handler &cgh) {
// the kernel must run after "result" has been initialized to true
195
+ cgh.depends_on (fill_event);
196
+ cgh.parallel_for <
197
+ class dpnp_allclose_kernel <_DataType1, _DataType2, _TolType>>(
198
+ gws, kernel_parallel_for_func);
199
+ };
200
+
201
+ return q.submit (kernel_func);
202
+ }
132
203
133
204
template <typename _DataType1, typename _DataType2, typename _ResultType>
134
205
DPCTLSyclEventRef dpnp_allclose_c (DPCTLSyclQueueRef q_ref,
@@ -140,6 +211,9 @@ DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
140
211
double atol_val,
141
212
const DPCTLEventVectorRef dep_event_vec_ref)
142
213
{
214
+ static_assert (std::is_same_v<_ResultType, bool >,
215
+ " Boolean result type is required" );
216
+
143
217
// avoid warning unused variable
144
218
(void )dep_event_vec_ref;
145
219
@@ -152,40 +226,21 @@ DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
152
226
sycl::queue q = *(reinterpret_cast <sycl::queue *>(q_ref));
153
227
sycl::event event;
154
228
155
- DPNPC_ptr_adapter<_DataType1> input1_ptr (q_ref, array1_in, size);
156
- DPNPC_ptr_adapter<_DataType2> input2_ptr (q_ref, array2_in, size);
157
- DPNPC_ptr_adapter<_ResultType> result1_ptr (q_ref, result1, 1 , true , true );
158
- const _DataType1 *array1 = input1_ptr.get_ptr ();
159
- const _DataType2 *array2 = input2_ptr.get_ptr ();
160
- _ResultType *result = result1_ptr.get_ptr ();
161
-
162
- result[0 ] = true ;
229
+ const _DataType1 *array1 = static_cast <const _DataType1 *>(array1_in);
230
+ const _DataType2 *array2 = static_cast <const _DataType2 *>(array2_in);
231
+ bool *result = static_cast <bool *>(result1);
163
232
164
- if (!size) {
165
- return event_ref;
233
+ if (q.get_device ().has (sycl::aspect::fp64)) {
234
+ event =
235
+ dpnp_allclose (q, array1, array2, result, size, rtol_val, atol_val);
236
+ }
237
+ else {
238
+ float rtol = static_cast <float >(rtol_val);
239
+ float atol = static_cast <float >(atol_val);
240
+ event = dpnp_allclose (q, array1, array2, result, size, rtol, atol);
166
241
}
167
-
168
- sycl::range<1 > gws (size);
169
- auto kernel_parallel_for_func = [=](sycl::id<1 > global_id) {
170
- size_t i = global_id[0 ];
171
-
172
- if (std::abs (array1[i] - array2[i]) >
173
- (atol_val + rtol_val * std::abs (array2[i])))
174
- {
175
- result[0 ] = false ;
176
- }
177
- };
178
-
179
- auto kernel_func = [&](sycl::handler &cgh) {
180
- cgh.parallel_for <
181
- class dpnp_allclose_c_kernel <_DataType1, _DataType2, _ResultType>>(
182
- gws, kernel_parallel_for_func);
183
- };
184
-
185
- event = q.submit (kernel_func);
186
242
187
243
event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
188
-
189
244
return DPCTLEvent_Copy (event_ref);
190
245
}
191
246
@@ -269,7 +324,7 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
269
324
sycl::nd_range<1 > gws (gws_range, lws_range);
270
325
271
326
auto kernel_parallel_for_func = [=](sycl::nd_item<1 > nd_it) {
272
- auto gr = nd_it.get_group ();
327
+ auto gr = nd_it.get_sub_group ();
273
328
const auto max_gr_size = gr.get_max_local_range ()[0 ];
274
329
const size_t start =
275
330
vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) +
@@ -521,8 +576,6 @@ DPCTLSyclEventRef (*dpnp_any_ext_c)(DPCTLSyclQueueRef,
521
576
else { \
522
577
constexpr size_t lws = 64 ; \
523
578
constexpr unsigned int vec_sz = 8 ; \
524
- constexpr sycl::access::address_space global_space = \
525
- sycl::access::address_space::global_space; \
526
579
\
527
580
auto gws_range = sycl::range<1 >( \
528
581
((result_size + lws * vec_sz - 1 ) / (lws * vec_sz)) * lws); \
@@ -537,22 +590,28 @@ DPCTLSyclEventRef (*dpnp_any_ext_c)(DPCTLSyclQueueRef,
537
590
\
538
591
if (start + static_cast <size_t >(vec_sz) * max_sg_size < \
539
592
result_size) { \
540
- sycl::vec<_DataType_input1, vec_sz> x1 = sg.load <vec_sz>( \
541
- sycl::multi_ptr<_DataType_input1, global_space>( \
542
- &input1_data[start])); \
543
- sycl::vec<_DataType_input2, vec_sz> x2 = sg.load <vec_sz>( \
544
- sycl::multi_ptr<_DataType_input2, global_space>( \
545
- &input2_data[start])); \
593
+ auto input1_multi_ptr = sycl::address_space_cast< \
594
+ sycl::access::address_space::global_space, \
595
+ sycl::access::decorated::yes>(&input1_data[start]); \
596
+ auto input2_multi_ptr = sycl::address_space_cast< \
597
+ sycl::access::address_space::global_space, \
598
+ sycl::access::decorated::yes>(&input2_data[start]); \
599
+ auto result_multi_ptr = sycl::address_space_cast< \
600
+ sycl::access::address_space::global_space, \
601
+ sycl::access::decorated::yes>(&result[start]); \
602
+ \
603
+ sycl::vec<_DataType_input1, vec_sz> x1 = \
604
+ sg.load <vec_sz>(input1_multi_ptr); \
605
+ sycl::vec<_DataType_input2, vec_sz> x2 = \
606
+ sg.load <vec_sz>(input2_multi_ptr); \
546
607
sycl::vec<bool , vec_sz> res_vec; \
547
608
\
548
609
for (size_t k = 0 ; k < vec_sz; ++k) { \
549
610
const _DataType_input1 input1_elem = x1[k]; \
550
611
const _DataType_input2 input2_elem = x2[k]; \
551
612
res_vec[k] = __operation__; \
552
613
} \
553
- sg.store <vec_sz>( \
554
- sycl::multi_ptr<bool , global_space>(&result[start]), \
555
- res_vec); \
614
+ sg.store <vec_sz>(result_multi_ptr, res_vec); \
556
615
} \
557
616
else { \
558
617
for (size_t k = start; k < result_size; ++k) { \
0 commit comments