Skip to content

Commit 6ab618d

Browse files
committed
use_dpctl_remainder_func
1 parent 8ace62f commit 6ab618d

12 files changed

+308
-229
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event remainder_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
const char *in_b,
46+
char *out_y,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
type_utils::validate_type_for_device<T>(exec_q);
50+
51+
const T *a = reinterpret_cast<const T *>(in_a);
52+
const T *b = reinterpret_cast<const T *>(in_b);
53+
T *y = reinterpret_cast<T *>(out_y);
54+
55+
return mkl_vm::remainder(
56+
exec_q,
57+
n, // number of elements to be calculated
58+
a, // pointer `a` containing 1st input vector of size n
59+
b, // pointer `b` containing 2nd input vector of size n
60+
y, // pointer `y` to the output vector of size n
61+
depends);
62+
}
63+
64+
template <typename fnT, typename T>
65+
struct RemainderContigFactory
66+
{
67+
fnT get()
68+
{
69+
if constexpr (std::is_same_v<
70+
typename types::RemainderOutputType<T>::value_type,
71+
void>)
72+
{
73+
return nullptr;
74+
}
75+
else {
76+
return remainder_contig_impl<T>;
77+
}
78+
}
79+
};
80+
} // namespace vm
81+
} // namespace ext
82+
} // namespace backend
83+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ struct DivOutputType
6868
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
6969
};
7070

71+
/**
72+
* @brief A factory to define pairs of supported types for which
73+
* MKL VM library provides support in oneapi::mkl::vm::remainder<T> function.
74+
*
75+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
76+
*/
77+
template <typename T>
78+
struct RemainderOutputType
79+
{
80+
using value_type = typename std::disjunction<
81+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
82+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
83+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
84+
};
85+
7186
/**
7287
* @brief A factory to define pairs of supported types for which
7388
* MKL VM library provides support in oneapi::mkl::vm::cos<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "cos.hpp"
3535
#include "div.hpp"
3636
#include "ln.hpp"
37+
#include "remainder.hpp"
3738
#include "sin.hpp"
3839
#include "sqr.hpp"
3940
#include "sqrt.hpp"
@@ -46,6 +47,7 @@ using vm_ext::binary_impl_fn_ptr_t;
4647
using vm_ext::unary_impl_fn_ptr_t;
4748

4849
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
50+
static binary_impl_fn_ptr_t remainder_dispatch_vector[dpctl_td_ns::num_types];
4951

5052
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
5153
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
@@ -88,6 +90,37 @@ PYBIND11_MODULE(_vm_impl, m)
8890
py::arg("dst"));
8991
}
9092

93+
// BinaryUfunc: ==== REMAINDER(x1, x2) ====
94+
{
95+
vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
96+
vm_ext::RemainderContigFactory>(
97+
remainder_dispatch_vector);
98+
99+
auto remainder_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
100+
arrayT dst, const event_vecT &depends = {}) {
101+
return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
102+
remainder_dispatch_vector);
103+
};
104+
m.def("_remainder", remainder_pyapi,
105+
"Call `remainder` function from OneMKL VM library to performs "
106+
"element "
107+
"by element remainder of vector `src1` by vector `src2` "
108+
"to resulting vector `dst`",
109+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
110+
py::arg("dst"), py::arg("depends") = py::list());
111+
112+
auto remainder_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
113+
arrayT src2, arrayT dst) {
114+
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
115+
remainder_dispatch_vector);
116+
};
117+
m.def("_mkl_remainder_to_call", remainder_need_to_call_pyapi,
118+
"Check input arguments to answer if `remainder` function from "
119+
"OneMKL VM library can be used",
120+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
121+
py::arg("dst"));
122+
}
123+
91124
// UnaryUfunc: ==== Cos(x) ====
92125
{
93126
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -313,25 +313,23 @@ enum class DPNPFuncName : size_t
313313
DPNP_FN_PUT_ALONG_AXIS_EXT, /**< Used in numpy.put_along_axis() impl,
314314
requires extra parameters */
315315
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
316-
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
317-
parameters */
318-
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
319-
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
320-
parameters */
321-
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
322-
DPNP_FN_REMAINDER_EXT, /**< Used in numpy.remainder() impl, requires extra
323-
parameters */
324-
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
325-
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
326-
parameters */
327-
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
328-
DPNP_FN_REPEAT_EXT, /**< Used in numpy.repeat() impl, requires extra
329-
parameters */
330-
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
331-
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
332-
DPNP_FN_RNG_BETA_EXT, /**< Used in numpy.random.beta() impl, requires extra
333-
parameters */
334-
DPNP_FN_RNG_BINOMIAL, /**< Used in numpy.random.binomial() impl */
316+
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
317+
parameters */
318+
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
319+
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
320+
parameters */
321+
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
322+
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
323+
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
324+
parameters */
325+
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
326+
DPNP_FN_REPEAT_EXT, /**< Used in numpy.repeat() impl, requires extra
327+
parameters */
328+
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
329+
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
330+
DPNP_FN_RNG_BETA_EXT, /**< Used in numpy.random.beta() impl, requires extra
331+
parameters */
332+
DPNP_FN_RNG_BINOMIAL, /**< Used in numpy.random.binomial() impl */
335333
DPNP_FN_RNG_BINOMIAL_EXT, /**< Used in numpy.random.binomial() impl,
336334
requires extra parameters */
337335
DPNP_FN_RNG_CHISQUARE, /**< Used in numpy.random.chisquare() impl */

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -988,23 +988,6 @@ void (*dpnp_remainder_default_c)(void *,
988988
const size_t *) =
989989
dpnp_remainder_c<_DataType_output, _DataType_input1, _DataType_input2>;
990990

991-
template <typename _DataType_output,
992-
typename _DataType_input1,
993-
typename _DataType_input2>
994-
DPCTLSyclEventRef (*dpnp_remainder_ext_c)(DPCTLSyclQueueRef,
995-
void *,
996-
const void *,
997-
const size_t,
998-
const shape_elem_type *,
999-
const size_t,
1000-
const void *,
1001-
const size_t,
1002-
const shape_elem_type *,
1003-
const size_t,
1004-
const size_t *,
1005-
const DPCTLEventVectorRef) =
1006-
dpnp_remainder_c<_DataType_output, _DataType_input1, _DataType_input2>;
1007-
1008991
template <typename _KernelNameSpecialization1,
1009992
typename _KernelNameSpecialization2,
1010993
typename _KernelNameSpecialization3>
@@ -1385,39 +1368,6 @@ void func_map_init_mathematical(func_map_t &fmap)
13851368
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_DBL][eft_DBL] = {
13861369
eft_DBL, (void *)dpnp_remainder_default_c<double, double, double>};
13871370

1388-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_INT] = {
1389-
eft_INT, (void *)dpnp_remainder_ext_c<int32_t, int32_t, int32_t>};
1390-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_LNG] = {
1391-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int32_t, int64_t>};
1392-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_FLT] = {
1393-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int32_t, float>};
1394-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_DBL] = {
1395-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int32_t, double>};
1396-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_INT] = {
1397-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int64_t, int32_t>};
1398-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_LNG] = {
1399-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int64_t, int64_t>};
1400-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_FLT] = {
1401-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int64_t, float>};
1402-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_DBL] = {
1403-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int64_t, double>};
1404-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_INT] = {
1405-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, int32_t>};
1406-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_LNG] = {
1407-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, int64_t>};
1408-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_FLT] = {
1409-
eft_FLT, (void *)dpnp_remainder_ext_c<float, float, float>};
1410-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_DBL] = {
1411-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, double>};
1412-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_INT] = {
1413-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, int32_t>};
1414-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_LNG] = {
1415-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, int64_t>};
1416-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_FLT] = {
1417-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, float>};
1418-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_DBL] = {
1419-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, double>};
1420-
14211371
fmap[DPNPFuncName::DPNP_FN_TRAPZ][eft_INT][eft_INT] = {
14221372
eft_DBL, (void *)dpnp_trapz_default_c<int32_t, int32_t, double>};
14231373
fmap[DPNPFuncName::DPNP_FN_TRAPZ][eft_INT][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
188188
DPNP_FN_QR_EXT
189189
DPNP_FN_RADIANS
190190
DPNP_FN_RADIANS_EXT
191-
DPNP_FN_REMAINDER
192-
DPNP_FN_REMAINDER_EXT
193191
DPNP_FN_RECIP
194192
DPNP_FN_RECIP_EXT
195193
DPNP_FN_REPEAT
@@ -442,9 +440,6 @@ cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
442440
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
443441
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
444442
dpnp_descriptor out=*, object where=*)
445-
cpdef dpnp_descriptor dpnp_remainder(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
446-
dpnp_descriptor out=*, object where=*)
447-
448443

449444
"""
450445
Array manipulation routines

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ __all__ += [
6262
"dpnp_negative",
6363
"dpnp_power",
6464
"dpnp_prod",
65-
"dpnp_remainder",
6665
"dpnp_sign",
6766
"dpnp_sum",
6867
"dpnp_trapz",
@@ -546,14 +545,6 @@ cpdef utils.dpnp_descriptor dpnp_prod(utils.dpnp_descriptor x1,
546545
return result
547546

548547

549-
cpdef utils.dpnp_descriptor dpnp_remainder(utils.dpnp_descriptor x1_obj,
550-
utils.dpnp_descriptor x2_obj,
551-
object dtype=None,
552-
utils.dpnp_descriptor out=None,
553-
object where=True):
554-
return call_fptr_2in_1out(DPNP_FN_REMAINDER_EXT, x1_obj, x2_obj, dtype, out, where)
555-
556-
557548
cpdef utils.dpnp_descriptor dpnp_sign(utils.dpnp_descriptor x1):
558549
return call_fptr_1in_1out_strides(DPNP_FN_SIGN_EXT, x1)
559550

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
"dpnp_logical_xor",
6767
"dpnp_multiply",
6868
"dpnp_not_equal",
69+
"dpnp_remainder",
6970
"dpnp_right_shift",
7071
"dpnp_sin",
7172
"dpnp_sqrt",
@@ -86,7 +87,7 @@ def check_nd_call_func(
8687
**kwargs,
8788
):
8889
"""
89-
Checks arguments and calls function with a single input array.
90+
Checks arguments and calls a function.
9091
9192
Chooses a common internal elementwise function to call in DPNP based on input arguments
9293
or to fallback on NumPy call if any passed argument is not currently supported.
@@ -127,7 +128,6 @@ def check_nd_call_func(
127128
order
128129
)
129130
)
130-
131131
return dpnp_func(*x_args, out=out, order=order)
132132
return call_origin(
133133
origin_func,
@@ -1174,6 +1174,48 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
11741174
return dpnp_array._create_from_usm_ndarray(res_usm)
11751175

11761176

1177+
_remainder_docstring_ = """
1178+
remainder(x1, x2, out=None, order='K')
1179+
Calculates the remainder of division for each element `x1_i` of the input array
1180+
`x1` with the respective element `x2_i` of the input array `x2`.
1181+
This function is equivalent to the Python modulus operator.
1182+
Args:
1183+
x1 (dpnp.ndarray):
1184+
First input array, expected to have a real-valued data type.
1185+
x2 (dpnp.ndarray):
1186+
Second input array, also expected to have a real-valued data type.
1187+
out ({None, usm_ndarray}, optional):
1188+
Output array to populate.
1189+
Array have the correct shape and the expected data type.
1190+
order ("C","F","A","K", optional):
1191+
Memory layout of the newly output array, if parameter `out` is `None`.
1192+
Default: "K".
1193+
Returns:
1194+
dpnp.ndarray:
1195+
an array containing the element-wise remainders. The data type of
1196+
the returned array is determined by the Type Promotion Rules.
1197+
"""
1198+
1199+
1200+
remainder_func = BinaryElementwiseFunc(
1201+
"remainder",
1202+
ti._remainder_result_type,
1203+
ti._remainder,
1204+
_remainder_docstring_,
1205+
)
1206+
1207+
1208+
def dpnp_remainder(x1, x2, out=None, order="K"):
1209+
# dpctl.tensor only works with usm_ndarray or scalar
1210+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
1211+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
1212+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
1213+
1214+
res_usm = remainder_func(
1215+
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
1216+
)
1217+
return dpnp_array._create_from_usm_ndarray(res_usm)
1218+
11771219
_right_shift_docstring_ = """
11781220
right_shift(x1, x2, out=None, order='K')
11791221

0 commit comments

Comments
 (0)