Skip to content

Commit f0d21be

Browse files
authored
Add scalar support in bitwise functions (partial) (#641)
1 parent aba9480 commit f0d21be

File tree

6 files changed

+33
-19
lines changed

6 files changed

+33
-19
lines changed

dpnp/backend/include/dpnp_gen_2arg_1type_tbl.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,13 @@
4646
/** */ \
4747
/** Function "__name__" executes operator "__operation__" over corresponding elements of input arrays */ \
4848
/** */ \
49+
/** @param[out] result1 Output array. */ \
4950
/** @param[in] array1 Input array 1. */ \
51+
/** @param[in] size1 Number of elements in @ref array1 */ \
5052
/** @param[in] array2 Input array 2. */ \
51-
/** @param[out] result1 Output array. */ \
52-
/** @param[in] size Number of elements in the output array. */ \
53+
/** @param[in] size2 Number of elements in @ref array2 */ \
5354
template <typename _DataType> \
54-
void __name__(void* array1, void* array2, void* result1, size_t size);
55+
void __name__(void* result1, const void* array1, const size_t size1, const void* array2, const size_t size2);
5556

5657
#endif
5758

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,8 @@ INP_DLLEXPORT void dpnp_invert_c(void* array1_in, void* result, size_t size);
666666

667667
#define MACRO_2ARG_1TYPE_OP(__name__, __operation__) \
668668
template <typename _DataType> \
669-
INP_DLLEXPORT void __name__(void* array1_in1, void* array2_in, void* result1, size_t size);
669+
INP_DLLEXPORT void __name__( \
670+
void* result1, const void* array1, const size_t size1, const void* array2, const size_t size2);
670671

671672
#include <dpnp_gen_2arg_1type_tbl.hpp>
672673

dpnp/backend/kernels/dpnp_krnl_bitwise.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,25 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
7070
class __name__##_kernel; \
7171
\
7272
template <typename _DataType> \
73-
void __name__(void* array1_in, void* array2_in, void* result1, size_t size) \
73+
void __name__(void* result1, const void* array1_in, const size_t size1, const void* array2_in, const size_t size2) \
7474
{ \
75+
if (!size1 || !size2) \
76+
{ \
77+
return; \
78+
} \
79+
\
7580
cl::sycl::event event; \
76-
_DataType* array1 = reinterpret_cast<_DataType*>(array1_in); \
77-
_DataType* array2 = reinterpret_cast<_DataType*>(array2_in); \
81+
const _DataType* array1 = reinterpret_cast<const _DataType*>(array1_in); \
82+
const _DataType* array2 = reinterpret_cast<const _DataType*>(array2_in); \
7883
_DataType* result = reinterpret_cast<_DataType*>(result1); \
7984
\
80-
cl::sycl::range<1> gws(size); \
85+
const size_t gws_size = std::max(size1, size2); \
86+
cl::sycl::range<1> gws(gws_size); \
8187
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { \
8288
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/ \
8389
{ \
84-
_DataType input_elem1 = array1[i]; \
85-
_DataType input_elem2 = array2[i]; \
90+
const _DataType input_elem1 = (size1 == 1) ? array1[0] : array1[i]; \
91+
const _DataType input_elem2 = (size2 == 1) ? array2[0] : array2[i]; \
8692
result[i] = __operation__; \
8793
} \
8894
}; \

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,13 @@ cdef extern from "dpnp_iface.hpp":
211211
ctypedef void(*fptr_1out_t)(void *, size_t)
212212
ctypedef void(*fptr_1in_1out_t)(void * , void * , size_t)
213213
ctypedef void(*fptr_2in_1out_t)(void * , void*, void*, size_t)
214+
ctypedef void(*fptr_2in_1out_new_t)(void * , void*, size_t, void*, size_t) # to be fused with fptr_2in_1out_t
214215
ctypedef void(*fptr_blas_gemm_2in_1out_t)(void * , void * , void * , size_t, size_t, size_t)
215216
ctypedef void(*dpnp_reduction_c_t)(void * , const void * , const size_t*, const size_t, const long*, const size_t, const void * , const long*)
216217

217218
cdef dparray call_fptr_1out(DPNPFuncName fptr_name, result_shape, result_dtype)
218219
cdef dparray call_fptr_1in_1out(DPNPFuncName fptr_name, dparray x1, dparray_shape_type result_shape)
219-
cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, dparray x1, dparray x2, dparray_shape_type result_shape)
220+
cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, dparray x1, dparray x2, dparray_shape_type result_shape, new_version=*)
220221

221222

222223
cpdef dparray dpnp_astype(dparray array1, dtype_target)

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ cdef dparray call_fptr_1in_1out(DPNPFuncName fptr_name, dparray x1, dparray_shap
330330
return result
331331

332332

333-
cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, dparray x1, dparray x2, dparray_shape_type result_shape):
333+
cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, dparray x1, dparray x2, dparray_shape_type result_shape, new_version=False):
334334

335335
""" Convert string type names (dparray.dtype) to C enum DPNPFuncType """
336336
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
@@ -343,8 +343,13 @@ cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, dparray x1, dparray x2,
343343
""" Create result array with type given by FPTR data """
344344
cdef dparray result = dparray(result_shape, dtype=result_type)
345345

346-
cdef fptr_2in_1out_t func = <fptr_2in_1out_t > kernel_data.ptr
347346
""" Call FPTR function """
348-
func(x1.get_data(), x2.get_data(), result.get_data(), x1.size)
347+
# parameter 'new_version' is temporary and must be removed shortly
348+
cdef fptr_2in_1out_t func_old = <fptr_2in_1out_t > kernel_data.ptr # can't define it inside 'if' due to a Cython limitation
349+
cdef fptr_2in_1out_new_t func_new = <fptr_2in_1out_new_t > kernel_data.ptr
350+
if (new_version):
351+
func_new(result.get_data(), x1.get_data(), x1.size, x2.get_data(), x2.size)
352+
else:
353+
func_old(x1.get_data(), x2.get_data(), result.get_data(), x1.size)
349354

350355
return result

dpnp/dpnp_algo/dpnp_algo_bitwise.pyx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,23 @@ __all__ += [
4747

4848

4949
cpdef dparray dpnp_bitwise_and(dparray array1, dparray array2):
50-
return call_fptr_2in_1out(DPNP_FN_BITWISE_AND, array1, array2, array1.shape)
50+
return call_fptr_2in_1out(DPNP_FN_BITWISE_AND, array1, array2, array1.shape, True)
5151

5252

5353
cpdef dparray dpnp_bitwise_or(dparray array1, dparray array2):
54-
return call_fptr_2in_1out(DPNP_FN_BITWISE_OR, array1, array2, array1.shape)
54+
return call_fptr_2in_1out(DPNP_FN_BITWISE_OR, array1, array2, array1.shape, True)
5555

5656

5757
cpdef dparray dpnp_bitwise_xor(dparray array1, dparray array2):
58-
return call_fptr_2in_1out(DPNP_FN_BITWISE_XOR, array1, array2, array1.shape)
58+
return call_fptr_2in_1out(DPNP_FN_BITWISE_XOR, array1, array2, array1.shape, True)
5959

6060

6161
cpdef dparray dpnp_invert(dparray arr):
6262
return call_fptr_1in_1out(DPNP_FN_INVERT, arr, arr.shape)
6363

6464

6565
cpdef dparray dpnp_left_shift(dparray array1, dparray array2):
66-
return call_fptr_2in_1out(DPNP_FN_LEFT_SHIFT, array1, array2, array1.shape)
66+
return call_fptr_2in_1out(DPNP_FN_LEFT_SHIFT, array1, array2, array1.shape, True)
6767

6868
cpdef dparray dpnp_right_shift(dparray array1, dparray array2):
69-
return call_fptr_2in_1out(DPNP_FN_RIGHT_SHIFT, array1, array2, array1.shape)
69+
return call_fptr_2in_1out(DPNP_FN_RIGHT_SHIFT, array1, array2, array1.shape, True)

0 commit comments

Comments
 (0)