Skip to content

Commit ff875ab

Browse files
Kernels for max min mean median (#102)
* implementation of kernels for max, min, mean, median funcs
1 parent 5f71e74 commit ff875ab

12 files changed

+456
-87
lines changed

dpnp/backend.pxd

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ cdef extern from "backend/backend_iface_fptr.hpp" namespace "DPNPFuncName": # n
6464
DPNP_FN_LOG1P
6565
DPNP_FN_LOG2
6666
DPNP_FN_MATMUL
67+
DPNP_FN_MAX
6768
DPNP_FN_MAXIMUM
69+
DPNP_FN_MEAN
70+
DPNP_FN_MEDIAN
71+
DPNP_FN_MIN
6872
DPNP_FN_MINIMUM
6973
DPNP_FN_MULTIPLY
7074
DPNP_FN_POWER
@@ -119,9 +123,9 @@ cdef extern from "backend/backend_iface.hpp":
119123

120124

121125
# C function pointer to the C library template functions
122-
ctypedef void(*fptr_1in_1out_t)(void *, void * , size_t)
123-
ctypedef void(*fptr_2in_1out_t)(void *, void*, void*, size_t)
124-
ctypedef void(*fptr_blas_gemm_2in_1out_t)(void *, void *, void *, size_t, size_t, size_t)
126+
ctypedef void(*fptr_1in_1out_t)(void * , void * , size_t)
127+
ctypedef void(*fptr_2in_1out_t)(void * , void*, void*, size_t)
128+
ctypedef void(*fptr_blas_gemm_2in_1out_t)(void * , void * , void * , size_t, size_t, size_t)
125129

126130
cdef dparray call_fptr_1in_1out(DPNPFuncName fptr_name, dparray x1, dparray_shape_type result_shape)
127131
cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, dparray x1, dparray x2, dparray_shape_type result_shape)

dpnp/backend/backend_iface.hpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,86 @@ INP_DLLEXPORT void custom_sort_c(void* array, void* result, size_t size);
213213
template <typename _DataType>
214214
INP_DLLEXPORT void custom_cov_c(void* array1_in, void* result1, size_t nrows, size_t ncols);
215215

216+
/**
217+
* @ingroup BACKEND_API
218+
* @brief MKL implementation of max function
219+
*
220+
* @param [in] array1_in Input array with data.
221+
*
222+
* @param [out] result1 Output array.
223+
*
224+
* @param [in] shape Shape of input array.
225+
*
226+
* @param [in] ndim Number of elements in shape.
227+
*
228+
* @param [in] axis Axis.
229+
*
230+
* @param [in] naxis Number of elements in axis.
231+
*
232+
*/
233+
template <typename _DataType>
234+
INP_DLLEXPORT void custom_max_c(void* array1_in, void* result1, size_t* shape, size_t ndim, size_t* axis, size_t naxis);
235+
236+
/**
237+
* @ingroup BACKEND_API
238+
* @brief MKL implementation of mean function
239+
*
240+
* @param [in] array Input array with data.
241+
*
242+
* @param [out] result Output array.
243+
*
244+
* @param [in] shape Shape of input array.
245+
*
246+
* @param [in] ndim Number of elements in shape.
247+
*
248+
* @param [in] axis Axis.
249+
*
250+
* @param [in] naxis Number of elements in axis.
251+
*
252+
*/
253+
template <typename _DataType, typename _ResultType>
254+
INP_DLLEXPORT void custom_mean_c(void* array, void* result, size_t* shape, size_t ndim, size_t* axis, size_t naxis);
255+
256+
/**
257+
* @ingroup BACKEND_API
258+
* @brief MKL implementation of median function
259+
*
260+
* @param [in] array Input array with data.
261+
*
262+
* @param [out] result Output array.
263+
*
264+
* @param [in] shape Shape of input array.
265+
*
266+
* @param [in] ndim Number of elements in shape.
267+
*
268+
* @param [in] axis Axis.
269+
*
270+
* @param [in] naxis Number of elements in axis.
271+
*
272+
*/
273+
template <typename _DataType, typename _ResultType>
274+
INP_DLLEXPORT void custom_median_c(void* array, void* result, size_t* shape, size_t ndim, size_t* axis, size_t naxis);
275+
276+
/**
277+
* @ingroup BACKEND_API
278+
* @brief MKL implementation of min function
279+
*
280+
* @param [in] array Input array with data.
281+
*
282+
* @param [out] result Output array.
283+
*
284+
* @param [in] shape Shape of input array.
285+
*
286+
* @param [in] ndim Number of elements in shape.
287+
*
288+
* @param [in] axis Axis.
289+
*
290+
* @param [in] naxis Number of elements in axis.
291+
*
292+
*/
293+
template <typename _DataType>
294+
INP_DLLEXPORT void custom_min_c(void* array, void* result, size_t* shape, size_t ndim, size_t* axis, size_t naxis);
295+
216296
/**
217297
* @ingroup BACKEND_API
218298
* @brief MKL implementation of argmax function

dpnp/backend/backend_iface_fptr.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,11 @@ static func_map_t func_map_init()
414414
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_blas_gemm_c<float>};
415415
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_blas_gemm_c<double>};
416416

417+
fmap[DPNPFuncName::DPNP_FN_MAX][eft_INT][eft_INT] = {eft_INT, (void*)custom_max_c<int>};
418+
fmap[DPNPFuncName::DPNP_FN_MAX][eft_LNG][eft_LNG] = {eft_LNG, (void*)custom_max_c<long>};
419+
fmap[DPNPFuncName::DPNP_FN_MAX][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_max_c<float>};
420+
fmap[DPNPFuncName::DPNP_FN_MAX][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_max_c<double>};
421+
417422
fmap[DPNPFuncName::DPNP_FN_MAXIMUM][eft_INT][eft_INT] = {eft_INT, (void*)custom_elemwise_maximum_c<int, int, int>};
418423
fmap[DPNPFuncName::DPNP_FN_MAXIMUM][eft_INT][eft_LNG] = {eft_LNG,
419424
(void*)custom_elemwise_maximum_c<int, long, long>};
@@ -446,6 +451,21 @@ static func_map_t func_map_init()
446451
fmap[DPNPFuncName::DPNP_FN_MAXIMUM][eft_DBL][eft_DBL] = {eft_DBL,
447452
(void*)custom_elemwise_maximum_c<double, double, double>};
448453

454+
fmap[DPNPFuncName::DPNP_FN_MEAN][eft_INT][eft_INT] = {eft_DBL, (void*)custom_mean_c<int, double>};
455+
fmap[DPNPFuncName::DPNP_FN_MEAN][eft_LNG][eft_LNG] = {eft_DBL, (void*)custom_mean_c<long, double>};
456+
fmap[DPNPFuncName::DPNP_FN_MEAN][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_mean_c<float, float>};
457+
fmap[DPNPFuncName::DPNP_FN_MEAN][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_mean_c<double, double>};
458+
459+
fmap[DPNPFuncName::DPNP_FN_MEDIAN][eft_INT][eft_INT] = {eft_DBL, (void*)custom_median_c<int, double>};
460+
fmap[DPNPFuncName::DPNP_FN_MEDIAN][eft_LNG][eft_LNG] = {eft_DBL, (void*)custom_median_c<long, double>};
461+
fmap[DPNPFuncName::DPNP_FN_MEDIAN][eft_FLT][eft_FLT] = {eft_DBL, (void*)custom_median_c<float, double>};
462+
fmap[DPNPFuncName::DPNP_FN_MEDIAN][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_median_c<double, double>};
463+
464+
fmap[DPNPFuncName::DPNP_FN_MIN][eft_INT][eft_INT] = {eft_INT, (void*)custom_min_c<int>};
465+
fmap[DPNPFuncName::DPNP_FN_MIN][eft_LNG][eft_LNG] = {eft_LNG, (void*)custom_min_c<long>};
466+
fmap[DPNPFuncName::DPNP_FN_MIN][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_min_c<float>};
467+
fmap[DPNPFuncName::DPNP_FN_MIN][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_min_c<double>};
468+
449469
fmap[DPNPFuncName::DPNP_FN_MINIMUM][eft_INT][eft_INT] = {eft_INT, (void*)custom_elemwise_minimum_c<int, int, int>};
450470
fmap[DPNPFuncName::DPNP_FN_MINIMUM][eft_INT][eft_LNG] = {eft_LNG,
451471
(void*)custom_elemwise_minimum_c<int, long, long>};

dpnp/backend/backend_iface_fptr.hpp

Lines changed: 57 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -59,62 +59,66 @@
5959
*/
6060
enum class DPNPFuncName : size_t
6161
{
62-
DPNP_FN_NONE, /**< Very first element of the enumeration */
63-
DPNP_FN_ADD, /**< Used in numpy.add() implementation */
64-
DPNP_FN_ARCCOS, /**< Used in numpy.arccos() implementation */
65-
DPNP_FN_ARCCOSH, /**< Used in numpy.arccosh() implementation */
66-
DPNP_FN_ARCSIN, /**< Used in numpy.arcsin() implementation */
67-
DPNP_FN_ARCSINH, /**< Used in numpy.arcsinh() implementation */
68-
DPNP_FN_ARCTAN, /**< Used in numpy.arctan() implementation */
69-
DPNP_FN_ARCTAN2, /**< Used in numpy.arctan2() implementation */
70-
DPNP_FN_ARCTANH, /**< Used in numpy.arctanh() implementation */
71-
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() implementation */
72-
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() implementation */
73-
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() implementation */
74-
DPNP_FN_CBRT, /**< Used in numpy.cbrt() implementation */
75-
DPNP_FN_CEIL, /**< Used in numpy.ceil() implementation */
76-
DPNP_FN_COS, /**< Used in numpy.cos() implementation */
77-
DPNP_FN_COSH, /**< Used in numpy.cosh() implementation */
78-
DPNP_FN_COV, /**< Used in numpy.cov() implementation */
79-
DPNP_FN_DEGREES, /**< Used in numpy.degrees() implementation */
80-
DPNP_FN_DIVIDE, /**< Used in numpy.divide() implementation */
81-
DPNP_FN_DOT, /**< Used in numpy.dot() implementation */
82-
DPNP_FN_EIG, /**< Used in numpy.linalg.eig() implementation */
83-
DPNP_FN_EXP, /**< Used in numpy.exp() implementation */
84-
DPNP_FN_EXP2, /**< Used in numpy.exp2() implementation */
85-
DPNP_FN_EXPM1, /**< Used in numpy.expm1() implementation */
86-
DPNP_FN_FABS, /**< Used in numpy.fabs() implementation */
87-
DPNP_FN_FLOOR, /**< Used in numpy.floor() implementation */
88-
DPNP_FN_FMOD, /**< Used in numpy.fmod() implementation */
89-
DPNP_FN_GAUSSIAN, /**< Used in numpy.random.randn() implementation */
90-
DPNP_FN_HYPOT, /**< Used in numpy.hypot() implementation */
91-
DPNP_FN_LOG, /**< Used in numpy.log() implementation */
92-
DPNP_FN_LOG10, /**< Used in numpy.log10() implementation */
93-
DPNP_FN_LOG2, /**< Used in numpy.log2() implementation */
94-
DPNP_FN_LOG1P, /**< Used in numpy.log1p() implementation */
95-
DPNP_FN_MATMUL, /**< Used in numpy.matmul() implementation */
96-
DPNP_FN_MAXIMUM, /**< Used in numpy.maximum() implementation */
97-
DPNP_FN_MINIMUM, /**< Used in numpy.minimum() implementation */
98-
DPNP_FN_MULTIPLY, /**< Used in numpy.multiply() implementation */
99-
DPNP_FN_POWER, /**< Used in numpy.random.power() implementation */
100-
DPNP_FN_PROD, /**< Used in numpy.prod() implementation */
101-
DPNP_FN_UNIFORM, /**< Used in numpy.random.uniform() implementation */
102-
DPNP_FN_RADIANS, /**< Used in numpy.radians() implementation */
103-
DPNP_FN_RANDOM, /**< Used in numpy.random.random() implementation */
104-
DPNP_FN_RECIP, /**< Used in numpy.recip() implementation */
105-
DPNP_FN_SIGN, /**< Used in numpy.sign() implementation */
106-
DPNP_FN_SIN, /**< Used in numpy.sin() implementation */
107-
DPNP_FN_SINH, /**< Used in numpy.sinh() implementation */
108-
DPNP_FN_SORT, /**< Used in numpy.sort() implementation */
109-
DPNP_FN_SQRT, /**< Used in numpy.sqrt() implementation */
62+
DPNP_FN_NONE, /**< Very first element of the enumeration */
63+
DPNP_FN_ADD, /**< Used in numpy.add() implementation */
64+
DPNP_FN_ARCCOS, /**< Used in numpy.arccos() implementation */
65+
DPNP_FN_ARCCOSH, /**< Used in numpy.arccosh() implementation */
66+
DPNP_FN_ARCSIN, /**< Used in numpy.arcsin() implementation */
67+
DPNP_FN_ARCSINH, /**< Used in numpy.arcsinh() implementation */
68+
DPNP_FN_ARCTAN, /**< Used in numpy.arctan() implementation */
69+
DPNP_FN_ARCTAN2, /**< Used in numpy.arctan2() implementation */
70+
DPNP_FN_ARCTANH, /**< Used in numpy.arctanh() implementation */
71+
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() implementation */
72+
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() implementation */
73+
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() implementation */
74+
DPNP_FN_CBRT, /**< Used in numpy.cbrt() implementation */
75+
DPNP_FN_CEIL, /**< Used in numpy.ceil() implementation */
76+
DPNP_FN_COS, /**< Used in numpy.cos() implementation */
77+
DPNP_FN_COSH, /**< Used in numpy.cosh() implementation */
78+
DPNP_FN_COV, /**< Used in numpy.cov() implementation */
79+
DPNP_FN_DEGREES, /**< Used in numpy.degrees() implementation */
80+
DPNP_FN_DIVIDE, /**< Used in numpy.divide() implementation */
81+
DPNP_FN_DOT, /**< Used in numpy.dot() implementation */
82+
DPNP_FN_EIG, /**< Used in numpy.linalg.eig() implementation */
83+
DPNP_FN_EXP, /**< Used in numpy.exp() implementation */
84+
DPNP_FN_EXP2, /**< Used in numpy.exp2() implementation */
85+
DPNP_FN_EXPM1, /**< Used in numpy.expm1() implementation */
86+
DPNP_FN_FABS, /**< Used in numpy.fabs() implementation */
87+
DPNP_FN_FLOOR, /**< Used in numpy.floor() implementation */
88+
DPNP_FN_FMOD, /**< Used in numpy.fmod() implementation */
89+
DPNP_FN_GAUSSIAN, /**< Used in numpy.random.randn() implementation */
90+
DPNP_FN_HYPOT, /**< Used in numpy.hypot() implementation */
91+
DPNP_FN_LOG, /**< Used in numpy.log() implementation */
92+
DPNP_FN_LOG10, /**< Used in numpy.log10() implementation */
93+
DPNP_FN_LOG2, /**< Used in numpy.log2() implementation */
94+
DPNP_FN_LOG1P, /**< Used in numpy.log1p() implementation */
95+
DPNP_FN_MATMUL, /**< Used in numpy.matmul() implementation */
96+
DPNP_FN_MAX, /**< Used in numpy.max() implementation */
97+
DPNP_FN_MAXIMUM, /**< Used in numpy.maximum() implementation */
98+
DPNP_FN_MEAN, /**< Used in numpy.mean() implementation */
99+
DPNP_FN_MEDIAN, /**< Used in numpy.median() implementation */
100+
DPNP_FN_MIN, /**< Used in numpy.min() implementation */
101+
DPNP_FN_MINIMUM, /**< Used in numpy.minimum() implementation */
102+
DPNP_FN_MULTIPLY, /**< Used in numpy.multiply() implementation */
103+
DPNP_FN_POWER, /**< Used in numpy.random.power() implementation */
104+
DPNP_FN_PROD, /**< Used in numpy.prod() implementation */
105+
DPNP_FN_UNIFORM, /**< Used in numpy.random.uniform() implementation */
106+
DPNP_FN_RADIANS, /**< Used in numpy.radians() implementation */
107+
DPNP_FN_RANDOM, /**< Used in numpy.random.random() implementation */
108+
DPNP_FN_RECIP, /**< Used in numpy.recip() implementation */
109+
DPNP_FN_SIGN, /**< Used in numpy.sign() implementation */
110+
DPNP_FN_SIN, /**< Used in numpy.sin() implementation */
111+
DPNP_FN_SINH, /**< Used in numpy.sinh() implementation */
112+
DPNP_FN_SORT, /**< Used in numpy.sort() implementation */
113+
DPNP_FN_SQRT, /**< Used in numpy.sqrt() implementation */
110114
DPNP_FN_SQUARE, /**< Used in numpy.square() implementation */
111-
DPNP_FN_SUBTRACT, /**< Used in numpy.subtract() implementation */
112-
DPNP_FN_SUM, /**< Used in numpy.sum() implementation */
113-
DPNP_FN_TAN, /**< Used in numpy.tan() implementation */
115+
DPNP_FN_SUBTRACT, /**< Used in numpy.subtract() implementation */
116+
DPNP_FN_SUM, /**< Used in numpy.sum() implementation */
117+
DPNP_FN_TAN, /**< Used in numpy.tan() implementation */
114118
DPNP_FN_TANH, /**< Used in numpy.tanh() implementation */
115119
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() implementation */
116-
DPNP_FN_TRUNC, /**< Used in numpy.trunc() implementation */
117-
DPNP_FN_LAST /**< The latest element of the enumeration */
120+
DPNP_FN_TRUNC, /**< Used in numpy.trunc() implementation */
121+
DPNP_FN_LAST /**< The latest element of the enumeration */
118122
};
119123

120124
/**

dpnp/backend/custom_kernels_searching.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ void custom_argmax_c(void* array1_in, void* result1, size_t size)
3838
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
3939
_idx_DataType* result = reinterpret_cast<_idx_DataType*>(result1);
4040

41-
auto queue = DPNP_QUEUE;
42-
4341
auto policy =
4442
oneapi::dpl::execution::make_device_policy<class custom_argmax_c_kernel<_DataType, _idx_DataType>>(DPNP_QUEUE);
4543

@@ -71,8 +69,6 @@ void custom_argmin_c(void* array1_in, void* result1, size_t size)
7169
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
7270
_idx_DataType* result = reinterpret_cast<_idx_DataType*>(result1);
7371

74-
auto queue = DPNP_QUEUE;
75-
7672
auto policy =
7773
oneapi::dpl::execution::make_device_policy<class custom_argmin_c_kernel<_DataType, _idx_DataType>>(DPNP_QUEUE);
7874

0 commit comments

Comments
 (0)