Skip to content

Commit 266871d

Browse files
committed
use_dpctl_round_func_in_dpnp
1 parent 46d16a5 commit 266871d

File tree

5 files changed

+210
-0
lines changed

5 files changed

+210
-0
lines changed

dpnp/backend/extensions/vm/round.hpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event round_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::round(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct RoundContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::RoundOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return round_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,21 @@ struct LnOutputType
136136
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
137137
};
138138

139+
/**
140+
* @brief A factory to define pairs of supported types for which
141+
* MKL VM library provides support in oneapi::mkl::vm::round<T> function.
142+
*
143+
* @tparam T Type of input vector `a` and of result vector `y`.
144+
*/
145+
template <typename T>
146+
struct RoundOutputType
147+
{
148+
using value_type = typename std::disjunction<
149+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
150+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
151+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
152+
};
153+
139154
/**
140155
* @brief A factory to define pairs of supported types for which
141156
* MKL VM library provides support in oneapi::mkl::vm::sin<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "div.hpp"
3737
#include "floor.hpp"
3838
#include "ln.hpp"
39+
#include "round.hpp"
3940
#include "sin.hpp"
4041
#include "sqr.hpp"
4142
#include "sqrt.hpp"
@@ -54,6 +55,7 @@ static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types];
5455
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
5556
static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
5657
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
58+
static unary_impl_fn_ptr_t round_dispatch_vector[dpctl_td_ns::num_types];
5759
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
5860
static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types];
5961
static unary_impl_fn_ptr_t sqrt_dispatch_vector[dpctl_td_ns::num_types];
@@ -206,6 +208,34 @@ PYBIND11_MODULE(_vm_impl, m)
206208
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
207209
}
208210

211+
// UnaryUfunc: ==== Round(x) ====
212+
{
213+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
214+
vm_ext::RoundContigFactory>(
215+
round_dispatch_vector);
216+
217+
auto round_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
218+
const event_vecT &depends = {}) {
219+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
220+
round_dispatch_vector);
221+
};
222+
m.def("_round", round_pyapi,
223+
"Call `round` function from OneMKL VM library to compute "
224+
"the rounded value of vector elements",
225+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
226+
py::arg("depends") = py::list());
227+
228+
auto round_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
229+
arrayT dst) {
230+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
231+
round_dispatch_vector);
232+
};
233+
m.def("_mkl_round_to_call", round_need_to_call_pyapi,
234+
"Check input arguments to answer if `round` function from "
235+
"OneMKL VM library can be used",
236+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
237+
}
238+
209239
// UnaryUfunc: ==== Sin(x) ====
210240
{
211241
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"dpnp_logical_xor",
6464
"dpnp_multiply",
6565
"dpnp_not_equal",
66+
"dpnp_round",
6667
"dpnp_sin",
6768
"dpnp_sqrt",
6869
"dpnp_square",
@@ -1062,6 +1063,58 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
10621063
return dpnp_array._create_from_usm_ndarray(res_usm)
10631064

10641065

1066+
_round_docstring = """
1067+
round(x, out=None, order='K')
1068+
1069+
Rounds each element `x_i` of the input array `x` to
1070+
the nearest integer-valued number.
1071+
1072+
Args:
1073+
x (dpnp.ndarray):
1074+
Input array, expected to have numeric data type.
1075+
out ({None, dpnp.ndarray}, optional):
1076+
Output array to populate. Array must have the correct
1077+
shape and the expected data type.
1078+
order ("C","F","A","K", optional): memory layout of the new
1079+
output array, if parameter `out` is `None`.
1080+
Default: "K".
1081+
Return:
1082+
dpnp.ndarray:
1083+
An array containing the element-wise rounded value. The data type
1084+
of the returned array is determined by the Type Promotion Rules.
1085+
"""
1086+
1087+
1088+
def dpnp_round(x, out=None, order="K"):
    """
    Round the elements of `x`, using the OneMKL VM pybind11 extension when
    `vmi._mkl_round_to_call` accepts the arguments.

    In every other case the dpctl.tensor implementation of round() is used.

    """

    def _call_round(src, dst, sycl_queue, depends=None):
        """Callback passed to dpctl.tensor's UnaryElementwiseFunc."""

        deps = [] if depends is None else depends

        if not vmi._mkl_round_to_call(sycl_queue, src, dst):
            # fall back to the dpctl.tensor kernel
            return ti._round(src, dst, sycl_queue, deps)
        # round() from the OneMKL VM pybind11 extension
        return vmi._round(sycl_queue, src, dst, deps)

    # dpctl.tensor operates on usm_ndarray objects only
    usm_in = dpnp.get_usm_ndarray(x)
    usm_out = dpnp.get_usm_ndarray(out) if out is not None else None

    round_func = UnaryElementwiseFunc(
        "round", ti._round_result_type, _call_round, _round_docstring
    )
    return dpnp_array._create_from_usm_ndarray(
        round_func(usm_in, out=usm_out, order=order)
    )
1116+
1117+
10651118
_sin_docstring = """
10661119
sin(x, out=None, order='K')
10671120
Computes sine for each element `x_i` of input array `x`.

dpnp/dpnp_iface_mathematical.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,11 +372,19 @@ def ceil(
372372
373373
Limitations
374374
-----------
375376
Parameter `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
376377
Parameters `where`, `dtype`, and `subok` are supported with their default values.
377378
Keyword arguments `kwargs` are currently unsupported.
378379
Otherwise the function will be executed sequentially on CPU.
379380
Input array data types are limited by real-value data types.
380388
381389
See Also
382390
--------
@@ -821,6 +829,7 @@ def floor(
821829
822830
Limitations
823831
-----------
824833
Parameter `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
825834
Parameters `where`, `dtype`, and `subok` are supported with their default values.
826835
Keyword arguments `kwargs` are currently unsupported.
@@ -836,6 +845,23 @@ def floor(
836845
-----
837846
Some spreadsheet programs calculate the "floor-towards-zero", in other words floor(-2.5) == -2.
838847
DPNP instead uses the definition of floor where floor(-2.5) == -3.
839865
840866
Examples
841867
--------
@@ -2028,11 +2054,19 @@ def trunc(
20282054
20292055
Limitations
20302056
-----------
20312058
Parameter `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
20322059
Parameters `where`, `dtype`, and `subok` are supported with their default values.
20332060
Keyword arguments `kwargs` are currently unsupported.
20342061
Otherwise the function will be executed sequentially on CPU.
20352062
Input array data types are limited by real-value data types.
20362070
20372071
See Also
20382072
--------

0 commit comments

Comments
 (0)