Skip to content

Implement strides for 11 element-wise functions #985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
37 changes: 32 additions & 5 deletions dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,39 @@
/** */ \
/** Function "__name__" executes operator "__operation1__" over corresponding elements of input arrays */ \
/** */ \
/** @param[in] array1 Input array 1. */ \
/** @param[in] array2 Input array 2. */ \
/** @param[out] result1 Output array. */ \
/** @param[in] size Number of elements in the output array. */ \
/** @param[out] result_out Output array. */ \
/** @param[in] result_size Output array size. */ \
/** @param[in] result_ndim Number of output array dimensions. */ \
/** @param[in] result_shape Output array shape. */ \
/** @param[in] result_strides Output array strides. */ \
/** @param[in] input1_in Input array 1. */ \
/** @param[in] input1_size Input array 1 size. */ \
/** @param[in] input1_ndim Number of input array 1 dimensions. */ \
/** @param[in] input1_shape Input array 1 shape. */ \
/** @param[in] input1_strides Input array 1 strides. */ \
/** @param[in] input2_in Input array 2. */ \
/** @param[in] input2_size Input array 2 size. */ \
/** @param[in] input2_ndim Number of input array 2 dimensions. */ \
/** @param[in] input2_shape Input array 2 shape. */ \
/** @param[in] input2_strides Input array 2 strides. */ \
/** @param[in] where Where condition. */ \
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output> \
void __name__(void* array1, void* array2, void* result1, size_t size);
void __name__(void* result_out, \
const size_t result_size, \
const size_t result_ndim, \
const size_t* result_shape, \
const size_t* result_strides, \
const void* input1_in, \
const size_t input1_size, \
const size_t input1_ndim, \
const size_t* input1_shape, \
const size_t* input1_strides, \
const void* input2_in, \
const size_t input2_size, \
const size_t input2_ndim, \
const size_t* input2_shape, \
const size_t* input2_strides, \
const size_t* where)

#endif

Expand Down
10 changes: 8 additions & 2 deletions dpnp/backend/include/dpnp_iface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,14 +966,20 @@ INP_DLLEXPORT void dpnp_invert_c(void* array1_in, void* result, size_t size);
#define MACRO_2ARG_3TYPES_OP(__name__, __operation1__, __operation2__) \
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> \
INP_DLLEXPORT void __name__(void* result_out, \
const size_t result_size, \
const size_t result_ndim, \
const size_t* result_shape, \
const size_t* result_strides, \
const void* input1_in, \
const size_t input1_size, \
const size_t input1_ndim, \
const size_t* input1_shape, \
const size_t input1_shape_ndim, \
const size_t* input1_strides, \
const void* input2_in, \
const size_t input2_size, \
const size_t input2_ndim, \
const size_t* input2_shape, \
const size_t input2_shape_ndim, \
const size_t* input2_strides, \
const size_t* where);

#include <dpnp_gen_2arg_3type_tbl.hpp>
Expand Down
12 changes: 8 additions & 4 deletions dpnp/backend/kernels/dpnp_krnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,6 @@ void dpnp_dot_c(void* result_out,
const size_t* input2_shape,
const size_t* input2_strides)
{
(void)result_strides;

DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size);
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size);

Expand All @@ -195,14 +193,20 @@ void dpnp_dot_c(void* result_out,
// there is no support of strides in multiply function
// so result can be wrong if input array has non-standard (c-contiguous) strides
dpnp_multiply_c<_DataType_output, _DataType_input1, _DataType_input2>(result,
result_size,
result_ndim,
result_shape,
result_strides,
input1_in,
input1_size,
input1_shape,
input1_ndim,
input1_shape,
input1_strides,
input2_in,
input2_size,
input2_shape,
input2_ndim,
input2_shape,
input2_strides,
NULL);
return;
}
Expand Down
Loading