Skip to content

Commit 457fb55

Browse files
authored
Merge pull request #1473 from IntelPython/dpctl_square
Improve dpnp.square() implementation
2 parents f294dfd + 725de87 commit 457fb55

File tree

14 files changed

+900
-74
lines changed

14 files changed

+900
-74
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ env:
2323
test_umath.py
2424
test_usm_type.py
2525
third_party/cupy/math_tests/test_explog.py
26+
third_party/cupy/math_tests/test_misc.py
2627
third_party/cupy/math_tests/test_trigonometric.py
2728
third_party/cupy/sorting_tests/test_sort.py
2829
VER_JSON_NAME: 'version.json'

dpnp/backend/extensions/vm/sqr.hpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event sqr_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::sqr(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct SqrContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::SqrOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return sqr_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,21 @@ struct SinOutputType
125125
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
126126
};
127127

128+
/**
129+
* @brief A factory to define pairs of supported types for which
130+
* MKL VM library provides support in oneapi::mkl::vm::sqr<T> function.
131+
*
132+
* @tparam T Type of input vector `a` and of result vector `y`.
133+
*/
134+
template <typename T>
135+
struct SqrOutputType
136+
{
137+
using value_type = typename std::disjunction<
138+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
139+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
140+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
141+
};
142+
128143
/**
129144
* @brief A factory to define pairs of supported types for which
130145
* MKL VM library provides support in oneapi::mkl::vm::sqrt<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "div.hpp"
3636
#include "ln.hpp"
3737
#include "sin.hpp"
38+
#include "sqr.hpp"
3839
#include "sqrt.hpp"
3940
#include "types_matrix.hpp"
4041

@@ -49,6 +50,7 @@ static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
4950
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
5051
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
5152
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
53+
static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types];
5254
static unary_impl_fn_ptr_t sqrt_dispatch_vector[dpctl_td_ns::num_types];
5355

5456
PYBIND11_MODULE(_vm_impl, m)
@@ -170,6 +172,35 @@ PYBIND11_MODULE(_vm_impl, m)
170172
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
171173
}
172174

175+
// UnaryUfunc: ==== Sqr(x) ====
176+
{
177+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
178+
vm_ext::SqrContigFactory>(
179+
sqr_dispatch_vector);
180+
181+
auto sqr_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
182+
const event_vecT &depends = {}) {
183+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
184+
sqr_dispatch_vector);
185+
};
186+
m.def(
187+
"_sqr", sqr_pyapi,
188+
"Call `sqr` from OneMKL VM library to performs element by element "
189+
"operation of squaring of vector `src` to resulting vector `dst`",
190+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
191+
py::arg("depends") = py::list());
192+
193+
auto sqr_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
194+
arrayT dst) {
195+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
196+
sqr_dispatch_vector);
197+
};
198+
m.def("_mkl_sqr_to_call", sqr_need_to_call_pyapi,
199+
"Check input arguments to answer if `sqr` function from "
200+
"OneMKL VM library can be used",
201+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
202+
}
203+
173204
// UnaryUfunc: ==== Sqrt(x) ====
174205
{
175206
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,7 @@ enum class DPNPFuncName : size_t
476476
DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters
477477
*/
478478
DPNP_FN_SQUARE, /**< Used in numpy.square() impl */
479-
DPNP_FN_SQUARE_EXT, /**< Used in numpy.square() impl, requires extra
480-
parameters */
481-
DPNP_FN_STD, /**< Used in numpy.std() impl */
479+
DPNP_FN_STD, /**< Used in numpy.std() impl */
482480
DPNP_FN_STD_EXT, /**< Used in numpy.std() impl, requires extra parameters */
483481
DPNP_FN_SUBTRACT, /**< Used in numpy.subtract() impl */
484482
DPNP_FN_SUBTRACT_EXT, /**< Used in numpy.subtract() impl, requires extra

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,15 +1156,6 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
11561156
fmap[DPNPFuncName::DPNP_FN_SQUARE][eft_DBL][eft_DBL] = {
11571157
eft_DBL, (void *)dpnp_square_c_default<double>};
11581158

1159-
fmap[DPNPFuncName::DPNP_FN_SQUARE_EXT][eft_INT][eft_INT] = {
1160-
eft_INT, (void *)dpnp_square_c_ext<int32_t>};
1161-
fmap[DPNPFuncName::DPNP_FN_SQUARE_EXT][eft_LNG][eft_LNG] = {
1162-
eft_LNG, (void *)dpnp_square_c_ext<int64_t>};
1163-
fmap[DPNPFuncName::DPNP_FN_SQUARE_EXT][eft_FLT][eft_FLT] = {
1164-
eft_FLT, (void *)dpnp_square_c_ext<float>};
1165-
fmap[DPNPFuncName::DPNP_FN_SQUARE_EXT][eft_DBL][eft_DBL] = {
1166-
eft_DBL, (void *)dpnp_square_c_ext<double>};
1167-
11681159
return;
11691160
}
11701161

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
291291
DPNP_FN_SINH_EXT
292292
DPNP_FN_SORT
293293
DPNP_FN_SORT_EXT
294-
DPNP_FN_SQUARE
295-
DPNP_FN_SQUARE_EXT
296294
DPNP_FN_STD
297295
DPNP_FN_STD_EXT
298296
DPNP_FN_SUM
@@ -543,6 +541,5 @@ cpdef dpnp_descriptor dpnp_log2(dpnp_descriptor array1)
543541
cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1)
544542
cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1)
545543
cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1)
546-
cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1)
547544
cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dpnp_descriptor out)
548545
cpdef dpnp_descriptor dpnp_tanh(dpnp_descriptor array1)

dpnp/dpnp_algo/dpnp_algo_trigonometric.pxi

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ __all__ += [
5454
'dpnp_radians',
5555
'dpnp_recip',
5656
'dpnp_sinh',
57-
'dpnp_square',
5857
'dpnp_tan',
5958
'dpnp_tanh',
6059
'dpnp_unwrap'
@@ -133,10 +132,6 @@ cpdef utils.dpnp_descriptor dpnp_sinh(utils.dpnp_descriptor x1):
133132
return call_fptr_1in_1out_strides(DPNP_FN_SINH_EXT, x1)
134133

135134

136-
cpdef utils.dpnp_descriptor dpnp_square(utils.dpnp_descriptor x1):
137-
return call_fptr_1in_1out_strides(DPNP_FN_SQUARE_EXT, x1)
138-
139-
140135
cpdef utils.dpnp_descriptor dpnp_tan(utils.dpnp_descriptor x1, utils.dpnp_descriptor out):
141136
return call_fptr_1in_1out_strides(DPNP_FN_TAN_EXT, x1, dtype=None, out=out, where=True, func_name='tan')
142137

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"dpnp_not_equal",
6060
"dpnp_sin",
6161
"dpnp_sqrt",
62+
"dpnp_square",
6263
"dpnp_subtract",
6364
]
6465

@@ -900,6 +901,57 @@ def _call_sqrt(src, dst, sycl_queue, depends=None):
900901
return dpnp_array._create_from_usm_ndarray(res_usm)
901902

902903

904+
_square_docstring_ = """
905+
square(x, out=None, order='K')
906+
Computes `x_i**2` (or `x_i*x_i`) for each element `x_i` of input array `x`.
907+
Args:
908+
x (dpnp.ndarray):
909+
Input array.
910+
out ({None, dpnp.ndarray}, optional):
911+
Output array to populate. Array must have the correct
912+
shape and the expected data type.
913+
order ("C","F","A","K", optional): memory layout of the new
914+
output array, if parameter `out` is `None`.
915+
Default: "K".
916+
Return:
917+
dpnp.ndarray:
918+
An array containing the element-wise square results.
919+
"""
920+
921+
922+
def dpnp_square(x, out=None, order="K"):
923+
"""
924+
Invokes sqr() function from pybind11 extension of OneMKL VM if possible.
925+
926+
Otherwise fully relies on dpctl.tensor implementation for square() function.
927+
928+
"""
929+
930+
def _call_square(src, dst, sycl_queue, depends=None):
931+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
932+
933+
if depends is None:
934+
depends = []
935+
936+
if vmi._mkl_sqr_to_call(sycl_queue, src, dst):
937+
# call pybind11 extension for sqr() function from OneMKL VM
938+
return vmi._sqr(sycl_queue, src, dst, depends)
939+
return ti._square(src, dst, sycl_queue, depends)
940+
941+
# dpctl.tensor only works with usm_ndarray or scalar
942+
x_usm = dpnp.get_usm_ndarray(x)
943+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
944+
945+
func = UnaryElementwiseFunc(
946+
"square",
947+
ti._square_result_type,
948+
_call_square,
949+
_square_docstring_,
950+
)
951+
res_usm = func(x_usm, out=out_usm, order=order)
952+
return dpnp_array._create_from_usm_ndarray(res_usm)
953+
954+
903955
_subtract_docstring_ = """
904956
subtract(x1, x2, out=None, order="K")
905957

dpnp/dpnp_iface_trigonometric.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
dpnp_log,
5353
dpnp_sin,
5454
dpnp_sqrt,
55+
dpnp_square,
5556
)
5657

5758
__all__ = [
@@ -1108,19 +1109,40 @@ def sqrt(
11081109
)
11091110

11101111

1111-
def square(x1):
1112+
def square(
1113+
x,
1114+
/,
1115+
out=None,
1116+
*,
1117+
order="K",
1118+
where=True,
1119+
dtype=None,
1120+
subok=True,
1121+
**kwargs,
1122+
):
11121123
"""
11131124
Return the element-wise square of the input.
11141125
11151126
For full documentation refer to :obj:`numpy.square`.
11161127
1128+
Returns
1129+
-------
1130+
y : dpnp.ndarray
1131+
Element-wise `x * x`, of the same shape and dtype as `x`.
1132+
11171133
Limitations
11181134
-----------
1119-
Input array is supported as :obj:`dpnp.ndarray`.
1135+
Input array is supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
1136+
Parameter `out` is supported as class:`dpnp.ndarray`, class:`dpctl.tensor.usm_ndarray` or
1137+
with default value ``None``.
1138+
Parameters `where`, `dtype` and `subok` are supported with their default values.
1139+
Otherwise the function will be executed sequentially on CPU.
11201140
Input array data types are limited by supported DPNP :ref:`Data types`.
11211141
11221142
See Also
11231143
--------
1144+
:obj:`dpnp..linalg.matrix_power` : Raise a square matrix
1145+
to the (integer) power `n`.
11241146
:obj:`dpnp.sqrt` : Return the positive square-root of an array,
11251147
element-wise.
11261148
:obj:`dpnp.power` : First array elements raised to powers
@@ -1129,20 +1151,23 @@ def square(x1):
11291151
Examples
11301152
--------
11311153
>>> import dpnp as np
1132-
>>> x = np.array([1, 2, 3])
1133-
>>> out = np.square(x)
1134-
>>> [i for i in out]
1135-
[1, 4, 9]
1154+
>>> x = np.array([-1j, 1])
1155+
>>> np.square(x)
1156+
array([-1.+0.j, 1.+0.j])
11361157
11371158
"""
11381159

1139-
x1_desc = dpnp.get_dpnp_descriptor(
1140-
x1, copy_when_strides=False, copy_when_nondefault_queue=False
1160+
return check_nd_call_func(
1161+
numpy.square,
1162+
dpnp_square,
1163+
x,
1164+
out=out,
1165+
where=where,
1166+
order=order,
1167+
dtype=dtype,
1168+
subok=subok,
1169+
**kwargs,
11411170
)
1142-
if x1_desc:
1143-
return dpnp_square(x1_desc).get_pyobj()
1144-
1145-
return call_origin(numpy.square, x1, **kwargs)
11461171

11471172

11481173
def tan(x1, out=None, **kwargs):

0 commit comments

Comments
 (0)