Skip to content

Commit 29a2063

Browse files
authored
Add support of bool type in bitwise operations (#1334)
* Add support of bool type in bitwise operations
* Update dpnp/dpnp_algo/dpnp_algo_bitwise.pyx
1 parent 684f393 commit 29a2063

14 files changed

+401
-165
lines changed

dpnp/backend/include/dpnp_gen_2arg_1type_tbl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2016-2020, Intel Corporation
2+
// Copyright (c) 2016-2023, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without
@@ -104,7 +104,7 @@
104104

105105
#endif
106106

107-
MACRO_2ARG_1TYPE_OP(dpnp_bitwise_and_c, input1_elem& input2_elem)
107+
MACRO_2ARG_1TYPE_OP(dpnp_bitwise_and_c, input1_elem & input2_elem)
108108
MACRO_2ARG_1TYPE_OP(dpnp_bitwise_or_c, input1_elem | input2_elem)
109109
MACRO_2ARG_1TYPE_OP(dpnp_bitwise_xor_c, input1_elem ^ input2_elem)
110110
MACRO_2ARG_1TYPE_OP(dpnp_left_shift_c, input1_elem << input2_elem)

dpnp/backend/kernels/dpnp_krnl_bitwise.cpp

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include "dpnp_fptr.hpp"
2929
#include "dpnp_iface.hpp"
30+
#include "dpnp_iterator.hpp"
3031
#include "dpnp_utils.hpp"
3132
#include "dpnpc_memory_adapter.hpp"
3233
#include "queue_sycl.hpp"
@@ -49,27 +50,66 @@ DPCTLSyclEventRef dpnp_invert_c(DPCTLSyclQueueRef q_ref,
4950
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
5051
sycl::event event;
5152

52-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
53-
_DataType* array1 = input1_ptr.get_ptr();
54-
_DataType* result = reinterpret_cast<_DataType*>(result1);
53+
_DataType* input_data = static_cast<_DataType*>(array1_in);
54+
_DataType* result = static_cast<_DataType*>(result1);
5555

56-
sycl::range<1> gws(size);
57-
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
58-
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/
56+
constexpr size_t lws = 64;
57+
constexpr unsigned int vec_sz = 8;
58+
59+
auto gws_range = sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
60+
auto lws_range = sycl::range<1>(lws);
61+
62+
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
63+
auto sg = nd_it.get_sub_group();
64+
const auto max_sg_size = sg.get_max_local_range()[0];
65+
const size_t start =
66+
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
67+
68+
if (start + static_cast<size_t>(vec_sz) * max_sg_size < size)
5969
{
60-
_DataType input_elem1 = array1[i];
61-
result[i] = ~input_elem1;
70+
using multi_ptrT = sycl::multi_ptr<_DataType, sycl::access::address_space::global_space>;
71+
72+
sycl::vec<_DataType, vec_sz> x = sg.load<vec_sz>(multi_ptrT(&input_data[start]));
73+
sycl::vec<_DataType, vec_sz> res_vec;
74+
75+
if constexpr (std::is_same_v<_DataType, bool>)
76+
{
77+
#pragma unroll
78+
for (size_t k = 0; k < vec_sz; ++k)
79+
{
80+
res_vec[k] = !(x[k]);
81+
}
82+
}
83+
else
84+
{
85+
res_vec = ~x;
86+
}
87+
88+
sg.store<vec_sz>(multi_ptrT(&result[start]), res_vec);
89+
}
90+
else
91+
{
92+
for (size_t k = start + sg.get_local_id()[0]; k < size; k += max_sg_size)
93+
{
94+
if constexpr (std::is_same_v<_DataType, bool>)
95+
{
96+
result[k] = !(input_data[k]);
97+
}
98+
else
99+
{
100+
result[k] = ~(input_data[k]);
101+
}
102+
}
62103
}
63104
};
64105

65106
auto kernel_func = [&](sycl::handler& cgh) {
66-
cgh.parallel_for<class dpnp_invert_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
107+
cgh.parallel_for<class dpnp_invert_c_kernel<_DataType>>(sycl::nd_range<1>(gws_range, lws_range),
108+
kernel_parallel_for_func);
67109
};
68-
69110
event = q.submit(kernel_func);
70111

71112
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
72-
73113
return DPCTLEvent_Copy(event_ref);
74114
}
75115

@@ -84,6 +124,7 @@ void dpnp_invert_c(void* array1_in, void* result1, size_t size)
84124
size,
85125
dep_event_vec_ref);
86126
DPCTLEvent_WaitAndThrow(event_ref);
127+
DPCTLEvent_Delete(event_ref);
87128
}
88129

89130
template <typename _DataType>
@@ -98,9 +139,11 @@ DPCTLSyclEventRef (*dpnp_invert_ext_c)(DPCTLSyclQueueRef,
98139

99140
static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
100141
{
142+
fmap[DPNPFuncName::DPNP_FN_INVERT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_invert_default_c<bool>};
101143
fmap[DPNPFuncName::DPNP_FN_INVERT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_invert_default_c<int32_t>};
102144
fmap[DPNPFuncName::DPNP_FN_INVERT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_invert_default_c<int64_t>};
103145

146+
fmap[DPNPFuncName::DPNP_FN_INVERT_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_invert_ext_c<bool>};
104147
fmap[DPNPFuncName::DPNP_FN_INVERT_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_invert_ext_c<int32_t>};
105148
fmap[DPNPFuncName::DPNP_FN_INVERT_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_invert_ext_c<int64_t>};
106149

@@ -114,6 +157,9 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
114157
template <typename _KernelNameSpecialization> \
115158
class __name__##_strides_kernel; \
116159
\
160+
template <typename _KernelNameSpecialization> \
161+
class __name__##_broadcast_kernel; \
162+
\
117163
template <typename _DataType> \
118164
DPCTLSyclEventRef __name__(DPCTLSyclQueueRef q_ref, \
119165
void* result_out, \
@@ -152,6 +198,8 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
152198
_DataType* input2_data = static_cast<_DataType*>(const_cast<void*>(input2_in)); \
153199
_DataType* result = static_cast<_DataType*>(result_out); \
154200
\
201+
bool use_broadcasting = !array_equal(input1_shape, input1_ndim, input2_shape, input2_ndim); \
202+
\
155203
shape_elem_type* input1_shape_offsets = new shape_elem_type[input1_ndim]; \
156204
\
157205
get_shape_offsets_inkernel(input1_shape, input1_ndim, input1_shape_offsets); \
@@ -167,7 +215,42 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
167215
sycl::event event; \
168216
sycl::range<1> gws(result_size); \
169217
\
170-
if (use_strides) \
218+
if (use_broadcasting) \
219+
{ \
220+
DPNPC_id<_DataType>* input1_it; \
221+
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType>); \
222+
input1_it = reinterpret_cast<DPNPC_id<_DataType>*>(dpnp_memory_alloc_c(q_ref, input1_it_size_in_bytes)); \
223+
new (input1_it) DPNPC_id<_DataType>(q_ref, input1_data, input1_shape, input1_strides, input1_ndim); \
224+
\
225+
input1_it->broadcast_to_shape(result_shape, result_ndim); \
226+
\
227+
DPNPC_id<_DataType>* input2_it; \
228+
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType>); \
229+
input2_it = reinterpret_cast<DPNPC_id<_DataType>*>(dpnp_memory_alloc_c(q_ref, input2_it_size_in_bytes)); \
230+
new (input2_it) DPNPC_id<_DataType>(q_ref, input2_data, input2_shape, input2_strides, input2_ndim); \
231+
\
232+
input2_it->broadcast_to_shape(result_shape, result_ndim); \
233+
\
234+
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) { \
235+
const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
236+
{ \
237+
const _DataType input1_elem = (*input1_it)[i]; \
238+
const _DataType input2_elem = (*input2_it)[i]; \
239+
result[i] = __operation__; \
240+
} \
241+
}; \
242+
auto kernel_func = [&](sycl::handler& cgh) { \
243+
cgh.parallel_for<class __name__##_broadcast_kernel<_DataType>>(gws, kernel_parallel_for_func); \
244+
}; \
245+
\
246+
q.submit(kernel_func).wait(); \
247+
\
248+
input1_it->~DPNPC_id(); \
249+
input2_it->~DPNPC_id(); \
250+
\
251+
return event_ref; \
252+
} \
253+
else if (use_strides) \
171254
{ \
172255
if ((result_ndim != input1_ndim) || (result_ndim != input2_ndim)) \
173256
{ \
@@ -332,18 +415,21 @@ static void func_map_init_bitwise_2arg_1type(func_map_t& fmap)
332415
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_bitwise_and_c_default<int32_t>};
333416
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_bitwise_and_c_default<int64_t>};
334417

418+
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_bitwise_and_c_ext<bool>};
335419
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_bitwise_and_c_ext<int32_t>};
336420
fmap[DPNPFuncName::DPNP_FN_BITWISE_AND_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_bitwise_and_c_ext<int64_t>};
337421

338422
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_bitwise_or_c_default<int32_t>};
339423
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_bitwise_or_c_default<int64_t>};
340424

425+
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_bitwise_or_c_ext<bool>};
341426
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_bitwise_or_c_ext<int32_t>};
342427
fmap[DPNPFuncName::DPNP_FN_BITWISE_OR_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_bitwise_or_c_ext<int64_t>};
343428

344429
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_bitwise_xor_c_default<int32_t>};
345430
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_bitwise_xor_c_default<int64_t>};
346431

432+
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_bitwise_xor_c_ext<bool>};
347433
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_bitwise_xor_c_ext<int32_t>};
348434
fmap[DPNPFuncName::DPNP_FN_BITWISE_XOR_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_bitwise_xor_c_ext<int64_t>};
349435

dpnp/dpnp_algo/dpnp_algo_bitwise.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# cython: language_level=3
22
# -*- coding: utf-8 -*-
33
# *****************************************************************************
4-
# Copyright (c) 2016-2020, Intel Corporation
4+
# Copyright (c) 2016-2023, Intel Corporation
55
# All rights reserved.
66
#
77
# Redistribution and use in source and binary forms, with or without
@@ -68,8 +68,8 @@ cpdef utils.dpnp_descriptor dpnp_bitwise_xor(utils.dpnp_descriptor x1_obj,
6868
return call_fptr_2in_1out_strides(DPNP_FN_BITWISE_XOR_EXT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
6969

7070

71-
cpdef utils.dpnp_descriptor dpnp_invert(utils.dpnp_descriptor arr):
72-
return call_fptr_1in_1out(DPNP_FN_INVERT_EXT, arr, arr.shape)
71+
cpdef utils.dpnp_descriptor dpnp_invert(utils.dpnp_descriptor arr, utils.dpnp_descriptor out=None):
72+
return call_fptr_1in_1out(DPNP_FN_INVERT_EXT, arr, arr.shape, out=out, func_name="invert")
7373

7474

7575
cpdef utils.dpnp_descriptor dpnp_left_shift(utils.dpnp_descriptor x1_obj,

dpnp/dpnp_array.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ def __abs__(self):
125125
def __add__(self, other):
126126
return dpnp.add(self, other)
127127

128-
# '__and__',
128+
def __and__(self, other):
129+
return dpnp.bitwise_and(self, other)
130+
129131
# '__array__',
130132
# '__array_finalize__',
131133
# '__array_function__',
@@ -193,9 +195,17 @@ def __gt__(self, other):
193195

194196
# '__hash__',
195197
# '__iadd__',
196-
# '__iand__',
198+
199+
def __iand__(self, other):
200+
dpnp.bitwise_and(self, other, out=self)
201+
return self
202+
197203
# '__ifloordiv__',
198-
# '__ilshift__',
204+
205+
def __ilshift__(self, other):
206+
dpnp.left_shift(self, other, out=self)
207+
return self
208+
199209
# '__imatmul__',
200210
# '__imod__',
201211
# '__imul__',
@@ -209,18 +219,28 @@ def __index__(self):
209219
def __int__(self):
210220
return self._array_obj.__int__()
211221

212-
# '__invert__',
213-
# '__ior__',
222+
def __invert__(self):
223+
return dpnp.invert(self)
224+
225+
def __ior__(self, other):
226+
dpnp.bitwise_or(self, other, out=self)
227+
return self
214228

215229
def __ipow__(self, other):
216230
dpnp.power(self, other, out=self)
217231
return self
218232

219-
# '__irshift__',
233+
def __irshift__(self, other):
234+
dpnp.right_shift(self, other, out=self)
235+
return self
236+
220237
# '__isub__',
221238
# '__iter__',
222239
# '__itruediv__',
223-
# '__ixor__',
240+
241+
def __ixor__(self, other):
242+
dpnp.bitwise_xor(self, other, out=self)
243+
return self
224244

225245
def __le__(self, other):
226246
return dpnp.less_equal(self, other)
@@ -232,7 +252,8 @@ def __len__(self):
232252

233253
return self._array_obj.__len__()
234254

235-
# '__lshift__',
255+
def __lshift__(self, other):
256+
return dpnp.left_shift(self, other)
236257

237258
def __lt__(self, other):
238259
return dpnp.less(self, other)
@@ -253,7 +274,10 @@ def __neg__(self):
253274
return dpnp.negative(self)
254275

255276
# '__new__',
256-
# '__or__',
277+
278+
def __or__(self, other):
279+
return dpnp.bitwise_or(self, other)
280+
257281
# '__pos__',
258282

259283
def __pow__(self, other):
@@ -262,7 +286,9 @@ def __pow__(self, other):
262286
def __radd__(self, other):
263287
return dpnp.add(other, self)
264288

265-
# '__rand__',
289+
def __rand__(self, other):
290+
return dpnp.bitwise_and(other, self)
291+
266292
# '__rdivmod__',
267293
# '__reduce__',
268294
# '__reduce_ex__',
@@ -271,7 +297,9 @@ def __repr__(self):
271297
return dpt.usm_ndarray_repr(self._array_obj, prefix="array")
272298

273299
# '__rfloordiv__',
274-
# '__rlshift__',
300+
301+
def __rlshift__(self, other):
302+
return dpnp.left_shift(other, self)
275303

276304
def __rmatmul__(self, other):
277305
return dpnp.matmul(other, self)
@@ -282,21 +310,27 @@ def __rmod__(self, other):
282310
def __rmul__(self, other):
283311
return dpnp.multiply(other, self)
284312

285-
# '__ror__',
286-
313+
def __ror__(self, other):
314+
return dpnp.bitwise_or(other, self)
315+
287316
def __rpow__(self, other):
288317
return dpnp.power(other, self)
289318

290-
# '__rrshift__',
291-
# '__rshift__',
319+
def __rrshift__(self, other):
320+
return dpnp.right_shift(other, self)
321+
322+
def __rshift__(self, other):
323+
return dpnp.right_shift(self, other)
292324

293325
def __rsub__(self, other):
294326
return dpnp.subtract(other, self)
295327

296328
def __rtruediv__(self, other):
297329
return dpnp.true_divide(other, self)
298330

299-
# '__rxor__',
331+
def __rxor__(self, other):
332+
return dpnp.bitwise_xor(other, self)
333+
300334
# '__setattr__',
301335

302336
def __setitem__(self, key, val):
@@ -334,7 +368,8 @@ def __sub__(self, other):
334368
def __truediv__(self, other):
335369
return dpnp.true_divide(self, other)
336370

337-
# '__xor__',
371+
def __xor__(self, other):
372+
return dpnp.bitwise_xor(self, other)
338373

339374
@staticmethod
340375
def _create_from_usm_ndarray(usm_ary : dpt.usm_ndarray):

dpnp/dpnp_iface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def from_dlpack(obj, /):
251251
def get_dpnp_descriptor(ext_obj,
252252
copy_when_strides=True,
253253
copy_when_nondefault_queue=True,
254+
alloc_dtype=None,
254255
alloc_usm_type=None,
255256
alloc_queue=None):
256257
"""
@@ -274,7 +275,7 @@ def get_dpnp_descriptor(ext_obj,
274275
# If input object is a scalar, it means it was allocated on host memory.
275276
# We need to copy it to USM memory according to compute follows data paradigm.
276277
if isscalar(ext_obj):
277-
ext_obj = array(ext_obj, usm_type=alloc_usm_type, sycl_queue=alloc_queue)
278+
ext_obj = array(ext_obj, dtype=alloc_dtype, usm_type=alloc_usm_type, sycl_queue=alloc_queue)
278279

279280
# while dpnp functions have no implementation with strides support
280281
# we need to create a non-strided copy

0 commit comments

Comments
 (0)