27
27
28
28
#include " dpnp_fptr.hpp"
29
29
#include " dpnp_iface.hpp"
30
+ #include " dpnp_iterator.hpp"
30
31
#include " dpnp_utils.hpp"
31
32
#include " dpnpc_memory_adapter.hpp"
32
33
#include " queue_sycl.hpp"
@@ -49,27 +50,66 @@ DPCTLSyclEventRef dpnp_invert_c(DPCTLSyclQueueRef q_ref,
49
50
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
50
51
sycl::event event;
51
52
52
- DPNPC_ptr_adapter<_DataType> input1_ptr (q_ref, array1_in, size);
53
- _DataType* array1 = input1_ptr.get_ptr ();
54
- _DataType* result = reinterpret_cast <_DataType*>(result1);
53
+ _DataType* input_data = static_cast <_DataType*>(array1_in);
54
+ _DataType* result = static_cast <_DataType*>(result1);
55
55
56
- sycl::range<1 > gws (size);
57
- auto kernel_parallel_for_func = [=](sycl::id<1 > global_id) {
58
- size_t i = global_id[0 ]; /* for (size_t i = 0; i < size; ++i)*/
56
+ constexpr size_t lws = 64 ;
57
+ constexpr unsigned int vec_sz = 8 ;
58
+
59
+ auto gws_range = sycl::range<1 >(((size + lws * vec_sz - 1 ) / (lws * vec_sz)) * lws);
60
+ auto lws_range = sycl::range<1 >(lws);
61
+
62
+ auto kernel_parallel_for_func = [=](sycl::nd_item<1 > nd_it) {
63
+ auto sg = nd_it.get_sub_group ();
64
+ const auto max_sg_size = sg.get_max_local_range ()[0 ];
65
+ const size_t start =
66
+ vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
67
+
68
+ if (start + static_cast <size_t >(vec_sz) * max_sg_size < size)
59
69
{
60
- _DataType input_elem1 = array1[i];
61
- result[i] = ~input_elem1;
70
+ using multi_ptrT = sycl::multi_ptr<_DataType, sycl::access::address_space::global_space>;
71
+
72
+ sycl::vec<_DataType, vec_sz> x = sg.load <vec_sz>(multi_ptrT (&input_data[start]));
73
+ sycl::vec<_DataType, vec_sz> res_vec;
74
+
75
+ if constexpr (std::is_same_v<_DataType, bool >)
76
+ {
77
+ #pragma unroll
78
+ for (size_t k = 0 ; k < vec_sz; ++k)
79
+ {
80
+ res_vec[k] = !(x[k]);
81
+ }
82
+ }
83
+ else
84
+ {
85
+ res_vec = ~x;
86
+ }
87
+
88
+ sg.store <vec_sz>(multi_ptrT (&result[start]), res_vec);
89
+ }
90
+ else
91
+ {
92
+ for (size_t k = start + sg.get_local_id ()[0 ]; k < size; k += max_sg_size)
93
+ {
94
+ if constexpr (std::is_same_v<_DataType, bool >)
95
+ {
96
+ result[k] = !(input_data[k]);
97
+ }
98
+ else
99
+ {
100
+ result[k] = ~(input_data[k]);
101
+ }
102
+ }
62
103
}
63
104
};
64
105
65
106
auto kernel_func = [&](sycl::handler& cgh) {
66
- cgh.parallel_for <class dpnp_invert_c_kernel <_DataType>>(gws, kernel_parallel_for_func);
107
+ cgh.parallel_for <class dpnp_invert_c_kernel <_DataType>>(sycl::nd_range<1 >(gws_range, lws_range),
108
+ kernel_parallel_for_func);
67
109
};
68
-
69
110
event = q.submit (kernel_func);
70
111
71
112
event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
72
-
73
113
return DPCTLEvent_Copy (event_ref);
74
114
}
75
115
@@ -84,6 +124,7 @@ void dpnp_invert_c(void* array1_in, void* result1, size_t size)
84
124
size,
85
125
dep_event_vec_ref);
86
126
DPCTLEvent_WaitAndThrow (event_ref);
127
+ DPCTLEvent_Delete (event_ref);
87
128
}
88
129
89
130
template <typename _DataType>
@@ -98,9 +139,11 @@ DPCTLSyclEventRef (*dpnp_invert_ext_c)(DPCTLSyclQueueRef,
98
139
99
140
static void func_map_init_bitwise_1arg_1type (func_map_t & fmap)
100
141
{
142
+ fmap[DPNPFuncName::DPNP_FN_INVERT][eft_BLN][eft_BLN] = {eft_BLN, (void *)dpnp_invert_default_c<bool >};
101
143
fmap[DPNPFuncName::DPNP_FN_INVERT][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_invert_default_c<int32_t >};
102
144
fmap[DPNPFuncName::DPNP_FN_INVERT][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_invert_default_c<int64_t >};
103
145
146
+ fmap[DPNPFuncName::DPNP_FN_INVERT_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void *)dpnp_invert_ext_c<bool >};
104
147
fmap[DPNPFuncName::DPNP_FN_INVERT_EXT][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_invert_ext_c<int32_t >};
105
148
fmap[DPNPFuncName::DPNP_FN_INVERT_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_invert_ext_c<int64_t >};
106
149
@@ -114,6 +157,9 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
114
157
template <typename _KernelNameSpecialization> \
115
158
class __name__ ##_strides_kernel; \
116
159
\
160
+ template <typename _KernelNameSpecialization> \
161
+ class __name__ ##_broadcast_kernel; \
162
+ \
117
163
template <typename _DataType> \
118
164
DPCTLSyclEventRef __name__ (DPCTLSyclQueueRef q_ref, \
119
165
void * result_out, \
@@ -152,6 +198,8 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
152
198
_DataType* input2_data = static_cast <_DataType*>(const_cast <void *>(input2_in)); \
153
199
_DataType* result = static_cast <_DataType*>(result_out); \
154
200
\
201
+ bool use_broadcasting = !array_equal (input1_shape, input1_ndim, input2_shape, input2_ndim); \
202
+ \
155
203
shape_elem_type* input1_shape_offsets = new shape_elem_type[input1_ndim]; \
156
204
\
157
205
get_shape_offsets_inkernel (input1_shape, input1_ndim, input1_shape_offsets); \
@@ -167,7 +215,42 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
167
215
sycl::event event; \
168
216
sycl::range<1 > gws (result_size); \
169
217
\
170
- if (use_strides) \
218
+ if (use_broadcasting) \
219
+ { \
220
+ DPNPC_id<_DataType>* input1_it; \
221
+ const size_t input1_it_size_in_bytes = sizeof (DPNPC_id<_DataType>); \
222
+ input1_it = reinterpret_cast <DPNPC_id<_DataType>*>(dpnp_memory_alloc_c (q_ref, input1_it_size_in_bytes)); \
223
+ new (input1_it) DPNPC_id<_DataType>(q_ref, input1_data, input1_shape, input1_strides, input1_ndim); \
224
+ \
225
+ input1_it->broadcast_to_shape (result_shape, result_ndim); \
226
+ \
227
+ DPNPC_id<_DataType>* input2_it; \
228
+ const size_t input2_it_size_in_bytes = sizeof (DPNPC_id<_DataType>); \
229
+ input2_it = reinterpret_cast <DPNPC_id<_DataType>*>(dpnp_memory_alloc_c (q_ref, input2_it_size_in_bytes)); \
230
+ new (input2_it) DPNPC_id<_DataType>(q_ref, input2_data, input2_shape, input2_strides, input2_ndim); \
231
+ \
232
+ input2_it->broadcast_to_shape (result_shape, result_ndim); \
233
+ \
234
+ auto kernel_parallel_for_func = [=](sycl::id<1 > global_id) { \
235
+ const size_t i = global_id[0 ]; /* for (size_t i = 0; i < result_size; ++i) */ \
236
+ { \
237
+ const _DataType input1_elem = (*input1_it)[i]; \
238
+ const _DataType input2_elem = (*input2_it)[i]; \
239
+ result[i] = __operation__; \
240
+ } \
241
+ }; \
242
+ auto kernel_func = [&](sycl::handler& cgh) { \
243
+ cgh.parallel_for <class __name__ ##_broadcast_kernel<_DataType>>(gws, kernel_parallel_for_func); \
244
+ }; \
245
+ \
246
+ q.submit (kernel_func).wait (); \
247
+ \
248
+ input1_it->~DPNPC_id (); \
249
+ input2_it->~DPNPC_id (); \
250
+ \
251
+ return event_ref; \
252
+ } \
253
+ else if (use_strides) \
171
254
{ \
172
255
if ((result_ndim != input1_ndim) || (result_ndim != input2_ndim)) \
173
256
{ \
@@ -332,18 +415,21 @@ static void func_map_init_bitwise_2arg_1type(func_map_t& fmap)
332
415
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_bitwise_and_c_default<int32_t >};
333
416
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_bitwise_and_c_default<int64_t >};
334
417
418
+ fmap[DPNPFuncName::DPNP_FN_BITWISE_AND_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void *)dpnp_bitwise_and_c_ext<bool >};
335
419
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND_EXT][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_bitwise_and_c_ext<int32_t >};
336
420
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_bitwise_and_c_ext<int64_t >};
337
421
338
422
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_bitwise_or_c_default<int32_t >};
339
423
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_bitwise_or_c_default<int64_t >};
340
424
425
+ fmap[DPNPFuncName::DPNP_FN_BITWISE_OR_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void *)dpnp_bitwise_or_c_ext<bool >};
341
426
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR_EXT][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_bitwise_or_c_ext<int32_t >};
342
427
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_bitwise_or_c_ext<int64_t >};
343
428
344
429
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_bitwise_xor_c_default<int32_t >};
345
430
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_bitwise_xor_c_default<int64_t >};
346
431
432
+ fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void *)dpnp_bitwise_xor_c_ext<bool >};
347
433
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR_EXT][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_bitwise_xor_c_ext<int32_t >};
348
434
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_bitwise_xor_c_ext<int64_t >};
349
435
0 commit comments