Skip to content

Commit 92192ad

Browse files
authored
MULTIPLY enable broadcasting version 0.1 (#620)
* MULTIPLY enable broadcasting version 0.1
1 parent f95f159 commit 92192ad

12 files changed

+320
-60
lines changed

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ MACRO_2ARG_3TYPES_OP(dpnp_fmod_c, cl::sycl::fmod((double)input_elem1, (double)in
6666
MACRO_2ARG_3TYPES_OP(dpnp_hypot_c, cl::sycl::hypot((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::hypot)
6767
MACRO_2ARG_3TYPES_OP(dpnp_maximum_c, cl::sycl::max(input_elem1, input_elem2), oneapi::mkl::vm::fmax)
6868
MACRO_2ARG_3TYPES_OP(dpnp_minimum_c, cl::sycl::min(input_elem1, input_elem2), oneapi::mkl::vm::fmin)
69-
MACRO_2ARG_3TYPES_OP(dpnp_multiply_c, input_elem1* input_elem2, oneapi::mkl::vm::mul)
7069
MACRO_2ARG_3TYPES_OP(dpnp_power_c, cl::sycl::pow((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::pow)
7170
MACRO_2ARG_3TYPES_OP(dpnp_subtract_c, input_elem1 - input_elem2, oneapi::mkl::vm::sub)
7271

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,33 @@ INP_DLLEXPORT void dpnp_floor_divide_c(void* array1_in, void* array2_in, void* r
725725
template <typename _DataType_input, typename _DataType_output>
726726
INP_DLLEXPORT void dpnp_modf_c(void* array1_in, void* result1_out, void* result2_out, size_t size);
727727

728+
/**
729+
* @ingroup BACKEND_API
730+
* @brief multiply function.
731+
*
732+
* @param [out] result_out Output array.
733+
* @param [in] input1_in Input 1 either array or scalar.
734+
* @param [in] input1_size Number of elements in input 1.
735+
* @param [in] input1_shape Shape of input 1.
736+
* @param [in] input1_shape_ndim Size of shape 1.
737+
* @param [in] input2_in Input 2 either array or scalar.
738+
* @param [in] input2_size Number of elements in input 2.
739+
* @param [in] input2_shape Shape of input 2.
740+
* @param [in] input2_shape_ndim Size of shape 2.
741+
* @param [in] where Mask array.
742+
*/
743+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
744+
INP_DLLEXPORT void dpnp_multiply_c(void* result_out,
745+
const void* input1_in,
746+
const size_t input1_size,
747+
const size_t* input1_shape,
748+
const size_t input1_shape_ndim,
749+
const void* input2_in,
750+
const size_t input2_size,
751+
const size_t* input2_shape,
752+
const size_t input2_shape_ndim,
753+
const size_t* where);
754+
728755
/**
729756
* @ingroup BACKEND_API
730757
* @brief Implementation of ones function

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -510,23 +510,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
510510
fmap[DPNPFuncName::DPNP_FN_MINIMUM][eft_DBL][eft_FLT] = {eft_DBL, (void*)dpnp_minimum_c<double, float, double>};
511511
fmap[DPNPFuncName::DPNP_FN_MINIMUM][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_minimum_c<double, double, double>};
512512

513-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_multiply_c<int, int, int>};
514-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_multiply_c<int, long, long>};
515-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_multiply_c<int, float, double>};
516-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<int, double, double>};
517-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_INT] = {eft_LNG, (void*)dpnp_multiply_c<long, int, long>};
518-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_multiply_c<long, long, long>};
519-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_FLT] = {eft_DBL, (void*)dpnp_multiply_c<long, float, double>};
520-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<long, double, double>};
521-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_INT] = {eft_DBL, (void*)dpnp_multiply_c<float, int, double>};
522-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_LNG] = {eft_DBL, (void*)dpnp_multiply_c<float, long, double>};
523-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_multiply_c<float, float, float>};
524-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<float, double, double>};
525-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_INT] = {eft_DBL, (void*)dpnp_multiply_c<double, int, double>};
526-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_LNG] = {eft_DBL, (void*)dpnp_multiply_c<double, long, double>};
527-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_FLT] = {eft_DBL, (void*)dpnp_multiply_c<double, float, double>};
528-
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<double, double, double>};
529-
530513
fmap[DPNPFuncName::DPNP_FN_POWER][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_power_c<int, int, int>};
531514
fmap[DPNPFuncName::DPNP_FN_POWER][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_power_c<int, long, long>};
532515
fmap[DPNPFuncName::DPNP_FN_POWER][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_power_c<int, float, double>};

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,79 @@ void dpnp_modf_c(void* array1_in, void* result1_out, void* result2_out, size_t s
232232
event.wait();
233233
}
234234

235+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
236+
class dpnp_multiply_c_kernel;
237+
238+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
239+
void dpnp_multiply_c(void* result_out,
240+
const void* input1_in,
241+
const size_t input1_size,
242+
const size_t* input1_shape,
243+
const size_t input1_shape_ndim,
244+
const void* input2_in,
245+
const size_t input2_size,
246+
const size_t* input2_shape,
247+
const size_t input2_shape_ndim,
248+
const size_t* where)
249+
{
250+
// avoid warning unused variable
251+
(void)input1_shape;
252+
(void)input1_shape_ndim;
253+
(void)input2_shape;
254+
(void)input2_shape_ndim;
255+
(void)where;
256+
257+
if (!input1_size || !input2_size)
258+
{
259+
return;
260+
}
261+
262+
const size_t result_size = (input2_size > input1_size) ? input2_size : input1_size;
263+
264+
const _DataType_input1* input1_data = reinterpret_cast<const _DataType_input1*>(input1_in);
265+
const _DataType_input2* input2_data = reinterpret_cast<const _DataType_input2*>(input2_in);
266+
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
267+
268+
cl::sycl::range<1> gws(result_size);
269+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
270+
size_t i = global_id[0]; /*for (size_t i = 0; i < result_size; ++i)*/
271+
{
272+
const _DataType_input1 input1_elem = (input1_size == 1) ? input1_data[0] : input1_data[i];
273+
const _DataType_input2 input2_elem = (input2_size == 1) ? input2_data[0] : input2_data[i];
274+
result[i] = input1_elem * input2_elem;
275+
}
276+
};
277+
auto kernel_func = [&](cl::sycl::handler& cgh) {
278+
cgh.parallel_for<class dpnp_multiply_c_kernel<_DataType_output, _DataType_input1,
279+
_DataType_input2>>(gws, kernel_parallel_for_func);
280+
};
281+
282+
cl::sycl::event event;
283+
284+
if (input1_size == input2_size)
285+
{
286+
if constexpr ((std::is_same<_DataType_input1, double>::value ||
287+
std::is_same<_DataType_input1, float>::value) &&
288+
std::is_same<_DataType_input2, _DataType_input1>::value)
289+
{
290+
_DataType_input1* input1 = const_cast<_DataType_input1*>(input1_data);
291+
_DataType_input2* input2 = const_cast<_DataType_input2*>(input2_data);
292+
// https://docs.oneapi.com/versions/latest/onemkl/mul.html
293+
event = oneapi::mkl::vm::mul(DPNP_QUEUE, result_size, input1, input2, result);
294+
}
295+
else
296+
{
297+
event = DPNP_QUEUE.submit(kernel_func);
298+
}
299+
}
300+
else
301+
{
302+
event = DPNP_QUEUE.submit(kernel_func);
303+
}
304+
305+
event.wait();
306+
}
307+
235308
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
236309
class dpnp_remainder_c_kernel;
237310

@@ -411,6 +484,34 @@ void func_map_init_mathematical(func_map_t& fmap)
411484
fmap[DPNPFuncName::DPNP_FN_MODF][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_modf_c<float, float>};
412485
fmap[DPNPFuncName::DPNP_FN_MODF][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_modf_c<double, double>};
413486

487+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_multiply_c<bool, bool, bool>};
488+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_INT] = {eft_INT, (void*)dpnp_multiply_c<int, bool, int>};
489+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_LNG] = {eft_LNG, (void*)dpnp_multiply_c<long, bool, long>};
490+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_FLT] = {eft_FLT, (void*)dpnp_multiply_c<float, bool, float>};
491+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<double, bool, double>};
492+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_BLN] = {eft_INT, (void*)dpnp_multiply_c<int, int, bool>};
493+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_multiply_c<int, int, int>};
494+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_multiply_c<long, int, long>};
495+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_multiply_c<double, int, float>};
496+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<double, int, double>};
497+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_BLN] = {eft_LNG, (void*)dpnp_multiply_c<long, long, bool>};
498+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_INT] = {eft_LNG, (void*)dpnp_multiply_c<long, long, int>};
499+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_multiply_c<long, long, long>};
500+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_FLT] = {eft_DBL, (void*)dpnp_multiply_c<double, long, float>};
501+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<double, long, double>};
502+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_BLN] = {eft_FLT, (void*)dpnp_multiply_c<float, float, bool>};
503+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_INT] = {eft_DBL, (void*)dpnp_multiply_c<double, float, int>};
504+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_LNG] = {eft_DBL, (void*)dpnp_multiply_c<double, float, long>};
505+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_multiply_c<float, float, float>};
506+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<double, float, double>};
507+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_BLN] = {eft_DBL, (void*)dpnp_multiply_c<double, double, bool>};
508+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_INT] = {eft_DBL, (void*)dpnp_multiply_c<double, double, int>};
509+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_LNG] = {eft_DBL, (void*)dpnp_multiply_c<double, double, long>};
510+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_FLT] = {eft_DBL, (void*)dpnp_multiply_c<double, double, float>};
511+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<double, double, double>};
512+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_C128] = {
513+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<double>, std::complex<double>>};
514+
414515
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_remainder_c<int, int, int>};
415516
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_remainder_c<int, long, long>};
416517
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_remainder_c<int, float, double>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ ctypedef void(*fptr_1out_t)(void *, size_t)
212212
ctypedef void(*fptr_1in_1out_t)(void * , void * , size_t)
213213
ctypedef void(*fptr_2in_1out_t)(void * , void*, void*, size_t)
214214
ctypedef void(*fptr_2in_1out_new_t)(void * , void*, size_t, void*, size_t) # to be fused with fptr_2in_1out_t
215+
ctypedef void(*fptr_2in_1out_full_t)(void *, const void *, const size_t, const long*, const size_t,
216+
const void *, const size_t, const long*, const size_t, const long*)
215217
ctypedef void(*fptr_blas_gemm_2in_1out_t)(void * , void * , void * , size_t, size_t, size_t)
216218
ctypedef void(*dpnp_reduction_c_t)(void * , const void * , const size_t*, const size_t, const long*, const size_t, const void * , const long*)
217219

@@ -282,7 +284,7 @@ cpdef dparray dpnp_divide(dparray array1, dparray array2)
282284
cpdef dparray dpnp_hypot(dparray array1, dparray array2)
283285
cpdef dparray dpnp_maximum(dparray array1, dparray array2)
284286
cpdef dparray dpnp_minimum(dparray array1, dparray array2)
285-
cpdef dparray dpnp_multiply(dparray array1, array2)
287+
cpdef dparray dpnp_multiply(object x1_obj, object x2_obj)
286288
cpdef dparray dpnp_negative(dparray array1)
287289
cpdef dparray dpnp_power(dparray array1, array2)
288290
cpdef dparray dpnp_remainder(dparray array1, dparray array2)

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ and the rest of the library
3232
3333
"""
3434

35-
3635
from dpnp.dpnp_utils cimport *
3736
import dpnp
3837
import numpy
@@ -273,23 +272,50 @@ cpdef tuple dpnp_modf(dparray x1):
273272
return result1, result2
274273

275274

276-
cpdef dparray dpnp_multiply(dparray x1, x2):
277-
cdef dparray result
278-
if dpnp.isscalar(x2):
279-
x2_ = dpnp.array([x2])
275+
cpdef dparray dpnp_multiply(object x1_obj, object x2_obj, dparray out=None, object where=True):
276+
cdef dparray_shape_type x1_shape, x2_shape, result_shape
280277

281-
types_map = {
282-
(dpnp.int32, dpnp.float64): dpnp.float64,
283-
(dpnp.int64, dpnp.float64): dpnp.float64,
284-
}
278+
cdef bint x1_obj_is_dparray = isinstance(x1_obj, dparray)
279+
cdef bint x2_obj_is_dparray = isinstance(x2_obj, dparray)
285280

286-
res_type = types_map.get((x1.dtype.type, x2_.dtype.type), x1.dtype)
287-
result = dparray(x1.shape, dtype=res_type)
288-
for i in range(x1.size):
289-
result[i] = x1[i] * x2
290-
return result.reshape(x1.shape)
281+
cdef dparray x1_dparray, x2_dparray
282+
283+
common_type = find_common_type(x1_obj, x2_obj)
284+
285+
if x1_obj_is_dparray:
286+
x1_dparray = x1_obj
291287
else:
292-
return call_fptr_2in_1out(DPNP_FN_MULTIPLY, x1, x2, x1.shape)
288+
x1_dparray = dparray((1,), dtype=common_type)
289+
copy_values_to_dparray(x1_dparray, (x1_obj,))
290+
291+
if x2_obj_is_dparray:
292+
x2_dparray = x2_obj
293+
else:
294+
x2_dparray = dparray((1,), dtype=common_type)
295+
copy_values_to_dparray(x2_dparray, (x2_obj,))
296+
297+
x1_shape = x1_dparray.shape
298+
x2_shape = x2_dparray.shape
299+
result_shape = get_common_shape(x1_shape, x2_shape)
300+
301+
# Convert string type names (dparray.dtype) to C enum DPNPFuncType
302+
cdef DPNPFuncType x1_c_type = dpnp_dtype_to_DPNPFuncType(x1_dparray.dtype)
303+
cdef DPNPFuncType x2_c_type = dpnp_dtype_to_DPNPFuncType(x2_dparray.dtype)
304+
305+
# get the FPTR data structure
306+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MULTIPLY, x1_c_type, x2_c_type)
307+
308+
cdef DPNPFuncType result_c_type = get_output_c_type(DPNP_FN_MULTIPLY, kernel_data.return_type, out, None)
309+
310+
# Create result array
311+
cdef dparray result = create_output_array(result_shape, result_c_type, out)
312+
313+
cdef fptr_2in_1out_full_t func = <fptr_2in_1out_full_t > kernel_data.ptr
314+
# Call FPTR function
315+
func(result.get_data(), x1_dparray.get_data(), x1_dparray.size, x1_shape.data(), x1_shape.size(),
316+
x2_dparray.get_data(), x2_dparray.size, x2_shape.data(), x2_shape.size(), NULL)
317+
318+
return result
293319

294320

295321
cpdef dparray dpnp_nancumprod(dparray x1):

dpnp/dpnp_iface_mathematical.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ def modf(x, **kwargs):
972972
return call_origin(numpy.modf, x, **kwargs)
973973

974974

975-
def multiply(x1, x2, **kwargs):
975+
def multiply(x1, x2, out=None, where=True, **kwargs):
976976
"""
977977
Multiply arguments element-wise.
978978
@@ -995,33 +995,32 @@ def multiply(x1, x2, **kwargs):
995995
[1, 4, 9, 16, 25]
996996
997997
"""
998+
x1_is_scalar, x2_is_scalar = dpnp.isscalar(x1), dpnp.isscalar(x2)
999+
x1_is_dparray, x2_is_dparray = isinstance(x1, dparray), isinstance(x2, dparray)
9981000

999-
is_x1_dparray = isinstance(x1, dparray)
1000-
is_x2_dparray = isinstance(x2, dparray)
1001-
1002-
is_x1_scalar = dpnp.isscalar(x1)
1003-
is_x2_scalar = dpnp.isscalar(x2)
1004-
1005-
if not use_origin_backend(x1):
1006-
if kwargs:
1001+
if not use_origin_backend(x1) and not kwargs:
1002+
if not x1_is_dparray and not x1_is_scalar:
1003+
pass
1004+
elif not x2_is_dparray and not x2_is_scalar:
1005+
pass
1006+
elif x1_is_scalar and x2_is_scalar:
10071007
pass
1008-
elif not (is_x1_dparray or is_x1_scalar):
1008+
elif x1_is_dparray and x1.ndim == 0:
10091009
pass
1010-
elif not (is_x2_dparray or is_x2_scalar):
1010+
elif x2_is_dparray and x2.ndim == 0:
10111011
pass
1012-
elif is_x1_scalar and is_x2_scalar:
1012+
elif x1_is_dparray and x2_is_dparray and x1.size != x2.size:
10131013
pass
1014-
elif (is_x1_dparray and is_x2_dparray) and (x1.size != x2.size):
1014+
elif x1_is_dparray and x2_is_dparray and x1.shape != x2.shape:
1015+
pass
1016+
elif out is not None and not isinstance(out, dparray):
10151017
pass
1016-
elif (is_x1_dparray and is_x2_dparray) and (x1.shape != x2.shape):
1018+
elif not where:
10171019
pass
10181020
else:
1019-
if is_x1_scalar:
1020-
return dpnp_multiply(x2, x1)
1021-
else:
1022-
return dpnp_multiply(x1, x2)
1021+
return dpnp_multiply(x1, x2, out, where)
10231022

1024-
return call_origin(numpy.multiply, x1, x2, **kwargs)
1023+
return call_origin(numpy.multiply, x1, x2, out=out, where=where, **kwargs)
10251024

10261025

10271026
def nancumprod(x1, **kwargs):

dpnp/dpnp_utils/dpnp_algo_utils.pxd

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ Returns a tuple of:
9595
2. dtype
9696
"""
9797

98+
cpdef find_common_type(object x1_obj, object x2_obj)
99+
"""
100+
Find common type of 2 input objects
101+
"""
102+
98103
cdef long copy_values_to_dparray(dparray dst, input_obj, size_t dst_idx=*) except -1
99104
"""
100105
Copy values to `dst` by iterating element by element in `input_obj`
@@ -125,6 +130,11 @@ cpdef nd2dp_array(arr)
125130
Convert ndarray to dparray
126131
"""
127132

133+
cdef dparray_shape_type get_common_shape(dparray_shape_type input1_shape, dparray_shape_type input2_shape)
134+
"""
135+
Calculate common shape from input shapes
136+
"""
137+
128138
cdef dparray_shape_type get_reduction_output_shape(dparray_shape_type input_shape, object axis, cpp_bool keepdims)
129139
"""
130140
Calculate output array shape in reduction functions

0 commit comments

Comments
 (0)