Skip to content

Commit fd9ce2b

Browse files
Update dpnp.sqrt using dpctl and OneMKL implementations (#1470)
* Reuse dpctl.tensor.sqrt for dpnp.sqrt * Update tests and docstrings for dpnp.sqrt * Add sqrt call from OneMKL by pybind11 extension * Update test_umath for dpnp.sqrt * Remove DPNP_FN_SQRT_EXT and update docstrings * Return deleted DPNP_FN_SQRT_EXT * Update dpnp/backend/extensions/vm/vm_py.cpp * Update dpnp/backend/extensions/vm/vm_py.cpp --------- Co-authored-by: Anton <[email protected]>
1 parent f82cdc4 commit fd9ce2b

File tree

11 files changed

+273
-53
lines changed

11 files changed

+273
-53
lines changed

dpnp/backend/extensions/vm/sqrt.hpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event sqrt_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::sqrt(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct SqrtContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::SqrtOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return sqrt_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,25 @@ struct SinOutputType
124124
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
125125
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
126126
};
127+
128+
/**
129+
* @brief A factory to define pairs of supported types for which
130+
* MKL VM library provides support in oneapi::mkl::vm::sqrt<T> function.
131+
*
132+
* @tparam T Type of input vector `a` and of result vector `y`.
133+
*/
134+
template <typename T>
135+
struct SqrtOutputType
136+
{
137+
using value_type = typename std::disjunction<
138+
dpctl_td_ns::
139+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
140+
dpctl_td_ns::
141+
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
142+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
143+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
144+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
145+
};
127146
} // namespace types
128147
} // namespace vm
129148
} // namespace ext

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "div.hpp"
3636
#include "ln.hpp"
3737
#include "sin.hpp"
38+
#include "sqrt.hpp"
3839
#include "types_matrix.hpp"
3940

4041
namespace py = pybind11;
@@ -48,6 +49,7 @@ static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
4849
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
4950
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
5051
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
52+
static unary_impl_fn_ptr_t sqrt_dispatch_vector[dpctl_td_ns::num_types];
5153

5254
PYBIND11_MODULE(_vm_impl, m)
5355
{
@@ -167,4 +169,34 @@ PYBIND11_MODULE(_vm_impl, m)
167169
"OneMKL VM library can be used",
168170
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
169171
}
172+
173+
// UnaryUfunc: ==== Sqrt(x) ====
174+
{
175+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
176+
vm_ext::SqrtContigFactory>(
177+
sqrt_dispatch_vector);
178+
179+
auto sqrt_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
180+
const event_vecT &depends = {}) {
181+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
182+
sqrt_dispatch_vector);
183+
};
184+
m.def(
185+
"_sqrt", sqrt_pyapi,
186+
"Call `sqrt` from OneMKL VM library to performs element by element "
187+
"operation of extracting the square root "
188+
"of vector `src` to resulting vector `dst`",
189+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
190+
py::arg("depends") = py::list());
191+
192+
auto sqrt_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
193+
arrayT dst) {
194+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
195+
sqrt_dispatch_vector);
196+
};
197+
m.def("_mkl_sqrt_to_call", sqrt_need_to_call_pyapi,
198+
"Check input arguments to answer if `sqrt` function from "
199+
"OneMKL VM library can be used",
200+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
201+
}
170202
}

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -729,10 +729,7 @@ static void func_map_init_elemwise_1arg_2type(func_map_t &fmap)
729729
fmap[DPNPFuncName::DPNP_FN_SQRT][eft_DBL][eft_DBL] = {
730730
eft_DBL, (void *)dpnp_sqrt_c_default<double, double>};
731731

732-
fmap[DPNPFuncName::DPNP_FN_SQRT_EXT][eft_INT][eft_INT] = {
733-
eft_DBL, (void *)dpnp_sqrt_c_ext<int32_t, double>};
734-
fmap[DPNPFuncName::DPNP_FN_SQRT_EXT][eft_LNG][eft_LNG] = {
735-
eft_DBL, (void *)dpnp_sqrt_c_ext<int64_t, double>};
732+
// Used in dpnp_std_c
736733
fmap[DPNPFuncName::DPNP_FN_SQRT_EXT][eft_FLT][eft_FLT] = {
737734
eft_FLT, (void *)dpnp_sqrt_c_ext<float, float>};
738735
fmap[DPNPFuncName::DPNP_FN_SQRT_EXT][eft_DBL][eft_DBL] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
295295
DPNP_FN_SINH_EXT
296296
DPNP_FN_SORT
297297
DPNP_FN_SORT_EXT
298-
DPNP_FN_SQRT
299-
DPNP_FN_SQRT_EXT
300298
DPNP_FN_SQUARE
301299
DPNP_FN_SQUARE_EXT
302300
DPNP_FN_STD
@@ -553,7 +551,6 @@ cpdef dpnp_descriptor dpnp_log2(dpnp_descriptor array1)
553551
cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1)
554552
cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1)
555553
cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1)
556-
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out)
557554
cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1)
558555
cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dpnp_descriptor out)
559556
cpdef dpnp_descriptor dpnp_tanh(dpnp_descriptor array1)

dpnp/dpnp_algo/dpnp_algo_trigonometric.pxi

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ __all__ += [
5454
'dpnp_radians',
5555
'dpnp_recip',
5656
'dpnp_sinh',
57-
'dpnp_sqrt',
5857
'dpnp_square',
5958
'dpnp_tan',
6059
'dpnp_tanh',
@@ -134,10 +133,6 @@ cpdef utils.dpnp_descriptor dpnp_sinh(utils.dpnp_descriptor x1):
134133
return call_fptr_1in_1out_strides(DPNP_FN_SINH_EXT, x1)
135134

136135

137-
cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1, utils.dpnp_descriptor out):
138-
return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1, dtype=None, out=out, where=True, func_name='sqrt')
139-
140-
141136
cpdef utils.dpnp_descriptor dpnp_square(utils.dpnp_descriptor x1):
142137
return call_fptr_1in_1out_strides(DPNP_FN_SQUARE_EXT, x1)
143138

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"dpnp_multiply",
5555
"dpnp_not_equal",
5656
"dpnp_sin",
57+
"dpnp_sqrt",
5758
"dpnp_subtract",
5859
]
5960

@@ -685,6 +686,57 @@ def _call_sin(src, dst, sycl_queue, depends=None):
685686
return dpnp_array._create_from_usm_ndarray(res_usm)
686687

687688

689+
_sqrt_docstring_ = """
690+
sqrt(x, out=None, order='K')
691+
Computes the non-negative square-root for each element `x_i` for input array `x`.
692+
Args:
693+
x (dpnp.ndarray):
694+
Input array.
695+
out ({None, dpnp.ndarray}, optional):
696+
Output array to populate. Array must have the correct
697+
shape and the expected data type.
698+
order ("C","F","A","K", optional): memory layout of the new
699+
output array, if parameter `out` is `None`.
700+
Default: "K".
701+
Return:
702+
dpnp.ndarray:
703+
An array containing the element-wise square-root results.
704+
"""
705+
706+
707+
def dpnp_sqrt(x, out=None, order="K"):
708+
"""
709+
Invokes sqrt() function from pybind11 extension of OneMKL VM if possible.
710+
711+
Otherwise fully relies on dpctl.tensor implementation for sqrt() function.
712+
713+
"""
714+
715+
def _call_sqrt(src, dst, sycl_queue, depends=None):
716+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
717+
718+
if depends is None:
719+
depends = []
720+
721+
if vmi._mkl_sqrt_to_call(sycl_queue, src, dst):
722+
# call pybind11 extension for sqrt() function from OneMKL VM
723+
return vmi._sqrt(sycl_queue, src, dst, depends)
724+
return ti._sqrt(src, dst, sycl_queue, depends)
725+
726+
# dpctl.tensor only works with usm_ndarray or scalar
727+
x_usm = dpnp.get_usm_ndarray(x)
728+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
729+
730+
func = UnaryElementwiseFunc(
731+
"sqrt",
732+
ti._sqrt_result_type,
733+
_call_sqrt,
734+
_sqrt_docstring_,
735+
)
736+
res_usm = func(x_usm, out=out_usm, order=order)
737+
return dpnp_array._create_from_usm_ndarray(res_usm)
738+
739+
688740
_subtract_docstring_ = """
689741
subtract(x1, x2, out=None, order="K")
690742

dpnp/dpnp_iface_bitwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def bitwise_and(x1, x2, dtype=None, out=None, where=True, **kwargs):
133133
Returns
134134
-------
135135
y : dpnp.ndarray
136-
An array containing the element-wise results.
136+
An array containing the element-wise results.
137137
138138
Limitations
139139
-----------

dpnp/dpnp_iface_trigonometric.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
"""
4141

4242

43-
import dpctl.tensor as dpt
4443
import numpy
4544

4645
import dpnp
@@ -52,6 +51,7 @@
5251
dpnp_cos,
5352
dpnp_log,
5453
dpnp_sin,
54+
dpnp_sqrt,
5555
)
5656

5757
__all__ = [
@@ -1048,51 +1048,64 @@ def sinh(x1):
10481048
return call_origin(numpy.sinh, x1, **kwargs)
10491049

10501050

1051-
def sqrt(x1, /, out=None, **kwargs):
1051+
def sqrt(
1052+
x,
1053+
/,
1054+
out=None,
1055+
*,
1056+
order="K",
1057+
where=True,
1058+
dtype=None,
1059+
subok=True,
1060+
**kwargs,
1061+
):
10521062
"""
1053-
Return the positive square-root of an array, element-wise.
1063+
Return the non-negative square-root of an array, element-wise.
10541064
10551065
For full documentation refer to :obj:`numpy.sqrt`.
10561066
1067+
Returns
1068+
-------
1069+
y : dpnp.ndarray
1070+
An array of the same shape as `x`, containing the non-negative
1071+
square-root of each element in `x`. If any element in `x` is
1072+
complex, a complex array is returned (and the square-roots of
1073+
negative reals are calculated). If all of the elements in `x`
1074+
are real, so is `y`, with negative elements returning ``nan``.
1075+
10571076
Limitations
10581077
-----------
10591078
Input array is supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
10601079
Parameter `out` is supported as :class:`dpnp.ndarray`, :class:`dpctl.tensor.usm_ndarray` or
10611080
with default value ``None``.
1081+
Parameters `where`, `dtype` and `subok` are supported with their default values.
10621082
Otherwise the function will be executed sequentially on CPU.
1063-
Keyword arguments ``kwargs`` are currently unsupported.
10641083
Input array data types are limited by supported DPNP :ref:`Data types`.
10651084
10661085
Examples
10671086
--------
10681087
>>> import dpnp as np
10691088
>>> x = np.array([1, 4, 9])
1070-
>>> out = np.sqrt(x)
1071-
>>> [i for i in out]
1072-
[1.0, 2.0, 3.0]
1089+
>>> np.sqrt(x)
1090+
array([1., 2., 3.])
1091+
1092+
>>> x2 = np.array([4, -1, np.inf])
1093+
>>> np.sqrt(x2)
1094+
array([ 2., nan, inf])
10731095
10741096
"""
10751097

1076-
x1_desc = (
1077-
dpnp.get_dpnp_descriptor(
1078-
x1, copy_when_strides=False, copy_when_nondefault_queue=False
1079-
)
1080-
if not kwargs
1081-
else None
1098+
return check_nd_call_func(
1099+
numpy.sqrt,
1100+
dpnp_sqrt,
1101+
x,
1102+
out=out,
1103+
where=where,
1104+
order=order,
1105+
dtype=dtype,
1106+
subok=subok,
1107+
**kwargs,
10821108
)
1083-
if x1_desc:
1084-
if out is not None:
1085-
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
1086-
raise TypeError("return array must be of supported array type")
1087-
out_desc = (
1088-
dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
1089-
or None
1090-
)
1091-
else:
1092-
out_desc = None
1093-
return dpnp_sqrt(x1_desc, out=out_desc).get_pyobj()
1094-
1095-
return call_origin(numpy.sqrt, x1, out=out, **kwargs)
10961109

10971110

10981111
def square(x1):

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ cpdef object dpnp_norm(object input, ord=None, axis=None):
366366

367367
input = dpnp.ravel(input, order='K')
368368
sqnorm = dpnp.dot(input, input)
369-
ret = dpnp.sqrt([sqnorm])
369+
ret = dpnp.sqrt(sqnorm)
370370
return dpnp.array(ret.reshape(1, *ret.shape), dtype=res_type)
371371

372372
len_axis = 1 if axis is None else len(axis_)

0 commit comments

Comments
 (0)