Skip to content

Commit 8f3c37c

Browse files
committed
use dpctl.tensor.sign and dpctl.tensor.negative in dpnp
1 parent 42e02d9 commit 8f3c37c

File tree

8 files changed

+179
-74
lines changed

8 files changed

+179
-74
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,6 @@ enum class DPNPFuncName : size_t
288288
DPNP_FN_NANVAR_EXT, /**< Used in numpy.nanvar() impl, requires extra
289289
parameters */
290290
DPNP_FN_NEGATIVE, /**< Used in numpy.negative() impl */
291-
DPNP_FN_NEGATIVE_EXT, /**< Used in numpy.negative() impl, requires extra
292-
parameters */
293291
DPNP_FN_NONZERO, /**< Used in numpy.nonzero() impl */
294292
DPNP_FN_NOT_EQUAL_EXT, /**< Used in numpy.not_equal() impl, requires extra
295293
parameters */
@@ -449,10 +447,8 @@ enum class DPNPFuncName : size_t
449447
DPNP_FN_SEARCHSORTED_EXT, /**< Used in numpy.searchsorted() impl, requires
450448
extra parameters */
451449
DPNP_FN_SIGN, /**< Used in numpy.sign() impl */
452-
DPNP_FN_SIGN_EXT, /**< Used in numpy.sign() impl, requires extra parameters
453-
*/
454-
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
455-
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
450+
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
451+
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
456452
DPNP_FN_SINH_EXT, /**< Used in numpy.sinh() impl, requires extra parameters
457453
*/
458454
DPNP_FN_SORT, /**< Used in numpy.sort() impl */

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,15 +1132,6 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
11321132
fmap[DPNPFuncName::DPNP_FN_NEGATIVE][eft_DBL][eft_DBL] = {
11331133
eft_DBL, (void *)dpnp_negative_c_default<double>};
11341134

1135-
fmap[DPNPFuncName::DPNP_FN_NEGATIVE_EXT][eft_INT][eft_INT] = {
1136-
eft_INT, (void *)dpnp_negative_c_ext<int32_t>};
1137-
fmap[DPNPFuncName::DPNP_FN_NEGATIVE_EXT][eft_LNG][eft_LNG] = {
1138-
eft_LNG, (void *)dpnp_negative_c_ext<int64_t>};
1139-
fmap[DPNPFuncName::DPNP_FN_NEGATIVE_EXT][eft_FLT][eft_FLT] = {
1140-
eft_FLT, (void *)dpnp_negative_c_ext<float>};
1141-
fmap[DPNPFuncName::DPNP_FN_NEGATIVE_EXT][eft_DBL][eft_DBL] = {
1142-
eft_DBL, (void *)dpnp_negative_c_ext<double>};
1143-
11441135
fmap[DPNPFuncName::DPNP_FN_RECIP][eft_INT][eft_INT] = {
11451136
eft_INT, (void *)dpnp_recip_c_default<int32_t>};
11461137
fmap[DPNPFuncName::DPNP_FN_RECIP][eft_LNG][eft_LNG] = {
@@ -1168,15 +1159,6 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
11681159
fmap[DPNPFuncName::DPNP_FN_SIGN][eft_DBL][eft_DBL] = {
11691160
eft_DBL, (void *)dpnp_sign_c_default<double>};
11701161

1171-
fmap[DPNPFuncName::DPNP_FN_SIGN_EXT][eft_INT][eft_INT] = {
1172-
eft_INT, (void *)dpnp_sign_c_ext<int32_t>};
1173-
fmap[DPNPFuncName::DPNP_FN_SIGN_EXT][eft_LNG][eft_LNG] = {
1174-
eft_LNG, (void *)dpnp_sign_c_ext<int64_t>};
1175-
fmap[DPNPFuncName::DPNP_FN_SIGN_EXT][eft_FLT][eft_FLT] = {
1176-
eft_FLT, (void *)dpnp_sign_c_ext<float>};
1177-
fmap[DPNPFuncName::DPNP_FN_SIGN_EXT][eft_DBL][eft_DBL] = {
1178-
eft_DBL, (void *)dpnp_sign_c_ext<double>};
1179-
11801162
fmap[DPNPFuncName::DPNP_FN_SQUARE][eft_INT][eft_INT] = {
11811163
eft_INT, (void *)dpnp_square_c_default<int32_t>};
11821164
fmap[DPNPFuncName::DPNP_FN_SQUARE][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
168168
DPNP_FN_MODF_EXT
169169
DPNP_FN_NANVAR
170170
DPNP_FN_NANVAR_EXT
171-
DPNP_FN_NEGATIVE
172-
DPNP_FN_NEGATIVE_EXT
173171
DPNP_FN_NONZERO
174172
DPNP_FN_ONES
175173
DPNP_FN_ONES_LIKE
@@ -267,8 +265,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
267265
DPNP_FN_RNG_ZIPF_EXT
268266
DPNP_FN_SEARCHSORTED
269267
DPNP_FN_SEARCHSORTED_EXT
270-
DPNP_FN_SIGN
271-
DPNP_FN_SIGN_EXT
272268
DPNP_FN_SINH
273269
DPNP_FN_SINH_EXT
274270
DPNP_FN_SORT
@@ -440,7 +436,6 @@ cpdef dpnp_descriptor dpnp_maximum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
440436
dpnp_descriptor out=*, object where=*)
441437
cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
442438
dpnp_descriptor out=*, object where=*)
443-
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
444439
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
445440
dpnp_descriptor out=*, object where=*)
446441
cpdef dpnp_descriptor dpnp_remainder(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,9 @@ __all__ += [
5959
"dpnp_nancumsum",
6060
"dpnp_nanprod",
6161
"dpnp_nansum",
62-
"dpnp_negative",
6362
"dpnp_power",
6463
"dpnp_prod",
6564
"dpnp_remainder",
66-
"dpnp_sign",
6765
"dpnp_sum",
6866
"dpnp_trapz",
6967
"dpnp_trunc"
@@ -472,10 +470,6 @@ cpdef utils.dpnp_descriptor dpnp_nansum(utils.dpnp_descriptor x1):
472470
return dpnp_sum(result)
473471

474472

475-
cpdef utils.dpnp_descriptor dpnp_negative(dpnp_descriptor x1):
476-
return call_fptr_1in_1out_strides(DPNP_FN_NEGATIVE_EXT, x1)
477-
478-
479473
cpdef utils.dpnp_descriptor dpnp_power(utils.dpnp_descriptor x1_obj,
480474
utils.dpnp_descriptor x2_obj,
481475
object dtype=None,
@@ -554,10 +548,6 @@ cpdef utils.dpnp_descriptor dpnp_remainder(utils.dpnp_descriptor x1_obj,
554548
return call_fptr_2in_1out(DPNP_FN_REMAINDER_EXT, x1_obj, x2_obj, dtype, out, where)
555549

556550

557-
cpdef utils.dpnp_descriptor dpnp_sign(utils.dpnp_descriptor x1):
558-
return call_fptr_1in_1out_strides(DPNP_FN_SIGN_EXT, x1)
559-
560-
561551
cpdef utils.dpnp_descriptor dpnp_sum(utils.dpnp_descriptor x1,
562552
object axis=None,
563553
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,42 @@ def dpnp_multiply(x1, x2, out=None, order="K"):
11321132
return dpnp_array._create_from_usm_ndarray(res_usm)
11331133

11341134

1135+
_negative_docstring = """
1136+
negative(x, out=None, order="K")
1137+
1138+
Computes the numerical negative for each element `x_i` of input array `x`.
1139+
1140+
Args:
1141+
x (dpnp.ndarray):
1142+
Input array, expected to have numeric data type.
1143+
out ({None, dpnp.ndarray}, optional):
1144+
Output array to populate.
1145+
Array have the correct shape and the expected data type.
1146+
order ("C","F","A","K", optional):
1147+
Memory layout of the newly output array, if parameter `out` is `None`.
1148+
Default: "K".
1149+
Returns:
1150+
dpnp.ndarray:
1151+
An array containing the negative of `x`.
1152+
"""
1153+
1154+
1155+
negative_func = UnaryElementwiseFunc(
1156+
"negative", ti._negative_result_type, ti._negative, _negative_docstring
1157+
)
1158+
1159+
1160+
def dpnp_negative(x, out=None, order="K"):
1161+
"""Invokes negative() from dpctl.tensor implementation for negative() function."""
1162+
1163+
# dpctl.tensor only works with usm_ndarray
1164+
x1_usm = dpnp.get_usm_ndarray(x)
1165+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
1166+
1167+
res_usm = negative_func(x1_usm, out=out_usm, order=order)
1168+
return dpnp_array._create_from_usm_ndarray(res_usm)
1169+
1170+
11351171
_not_equal_docstring_ = """
11361172
not_equal(x1, x2, out=None, order="K")
11371173
@@ -1217,6 +1253,47 @@ def dpnp_right_shift(x1, x2, out=None, order="K"):
12171253
return dpnp_array._create_from_usm_ndarray(res_usm)
12181254

12191255

1256+
_sign_docstring = """
1257+
sign(x, out=None, order="K")
1258+
1259+
Computes an indication of the sign of each element `x_i` of input array `x`
1260+
using the signum function.
1261+
1262+
The signum function returns `-1` if `x_i` is less than `0`,
1263+
`0` if `x_i` is equal to `0`, and `1` if `x_i` is greater than `0`.
1264+
1265+
Args:
1266+
x (dpnp.ndarray):
1267+
Input array, expected to have numeric data type.
1268+
out ({None, dpnp.ndarray}, optional):
1269+
Output array to populate.
1270+
Array have the correct shape and the expected data type.
1271+
order ("C","F","A","K", optional):
1272+
Memory layout of the newly output array, if parameter `out` is `None`.
1273+
Default: "K".
1274+
Returns:
1275+
dpnp.ndarray:
1276+
An array containing the element-wise results. The data type of the
1277+
returned array is determined by the Type Promotion Rules.
1278+
"""
1279+
1280+
1281+
sign_func = UnaryElementwiseFunc(
1282+
"sign", ti._sign_result_type, ti._sign, _sign_docstring
1283+
)
1284+
1285+
1286+
def dpnp_sign(x, out=None, order="K"):
1287+
"""Invokes sign() from dpctl.tensor implementation for sign() function."""
1288+
1289+
# dpctl.tensor only works with usm_ndarray
1290+
x1_usm = dpnp.get_usm_ndarray(x)
1291+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
1292+
1293+
res_usm = sign_func(x1_usm, out=out_usm, order=order)
1294+
return dpnp_array._create_from_usm_ndarray(res_usm)
1295+
1296+
12201297
_sin_docstring = """
12211298
sin(x, out=None, order='K')
12221299
Computes sine for each element `x_i` of input array `x`.

dpnp/dpnp_iface_mathematical.py

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,13 @@
4949

5050
from .dpnp_algo import *
5151
from .dpnp_algo.dpnp_elementwise_common import (
52+
check_nd_call_func,
5253
dpnp_add,
5354
dpnp_divide,
5455
dpnp_floor_divide,
5556
dpnp_multiply,
57+
dpnp_negative,
58+
dpnp_sign,
5659
dpnp_subtract,
5760
)
5861
from .dpnp_utils import *
@@ -1404,37 +1407,64 @@ def nansum(x1, **kwargs):
14041407
return call_origin(numpy.nansum, x1, **kwargs)
14051408

14061409

1407-
def negative(x1, **kwargs):
1410+
def negative(
1411+
x,
1412+
/,
1413+
out=None,
1414+
*,
1415+
order="K",
1416+
where=True,
1417+
dtype=None,
1418+
subok=True,
1419+
**kwargs,
1420+
):
14081421
"""
14091422
Negative element-wise.
14101423
14111424
For full documentation refer to :obj:`numpy.negative`.
14121425
1426+
Returns
1427+
-------
1428+
out : dpnp.ndarray
1429+
The numerical negative of each element of `x`.
1430+
14131431
Limitations
14141432
-----------
1415-
Parameter ``x1`` is supported as :obj:`dpnp.ndarray`.
1416-
Keyword arguments ``kwargs`` are currently unsupported.
1417-
Otherwise the functions will be executed sequentially on CPU.
1418-
Input array data types are limited by supported DPNP :ref:`Data types`.
1433+
Parameters `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
1434+
Parameters `where`, `dtype` and `subok` are supported with their default values.
1435+
Keyword arguments `kwargs` are currently unsupported.
1436+
Otherwise the functions will be executed sequentially on CPU.
1437+
Input array data types are limited by supported DPNP :ref:`Data types`.
14191438
1420-
.. see also: :obj:`dpnp.copysign` : Change the sign of x1 to that of x2, element-wise.
1439+
See Also
1440+
-----------
1441+
:obj:`dpnp.copysign` : Change the sign of `x1` to that of `x2`, element-wise.
14211442
14221443
Examples
14231444
--------
14241445
>>> import dpnp as np
1425-
>>> result = np.negative([1, -1])
1426-
>>> [x for x in result]
1427-
[-1, 1]
1446+
>>> np.negative(np.array([1, -1]))
1447+
array([-1, 1])
1448+
1449+
The ``-`` operator can be used as a shorthand for ``negative`` on
1450+
:class:`dpnp.ndarray`.
14281451
1452+
>>> x = np.array([1., -1.])
1453+
>>> -x
1454+
array([-1., 1.])
14291455
"""
14301456

1431-
x1_desc = dpnp.get_dpnp_descriptor(
1432-
x1, copy_when_strides=False, copy_when_nondefault_queue=False
1457+
return check_nd_call_func(
1458+
numpy.negative,
1459+
dpnp_negative,
1460+
x,
1461+
out=out,
1462+
where=where,
1463+
order=order,
1464+
dtype=dtype,
1465+
subok=subok,
1466+
**kwargs,
14331467
)
1434-
if x1_desc and not kwargs:
1435-
return dpnp_negative(x1_desc).get_pyobj()
1436-
1437-
return call_origin(numpy.negative, x1, **kwargs)
14381468

14391469

14401470
def power(x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs):
@@ -1686,35 +1716,70 @@ def round_(a, decimals=0, out=None):
16861716
return around(a, decimals, out)
16871717

16881718

1689-
def sign(x1, **kwargs):
1719+
def sign(
1720+
x,
1721+
/,
1722+
out=None,
1723+
*,
1724+
order="K",
1725+
where=True,
1726+
dtype=None,
1727+
subok=True,
1728+
**kwargs,
1729+
):
16901730
"""
16911731
Returns an element-wise indication of the sign of a number.
16921732
16931733
For full documentation refer to :obj:`numpy.sign`.
16941734
1735+
Returns
1736+
-------
1737+
out : dpnp.ndarray
1738+
The indication of the sign of each element of `x`.
1739+
16951740
Limitations
16961741
-----------
1697-
Parameter ``x1`` is supported as :obj:`dpnp.ndarray`.
1698-
Keyword arguments ``kwargs`` are currently unsupported.
1699-
Otherwise the functions will be executed sequentially on CPU.
1700-
Input array data types are limited by supported DPNP :ref:`Data types`.
1742+
Parameters `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
1743+
Parameters `where`, `dtype` and `subok` are supported with their default values.
1744+
Keyword argument `kwargs` is currently unsupported.
1745+
Otherwise the functions will be executed sequentially on CPU.
1746+
Input array data types are limited by supported DPNP :ref:`Data types`.
17011747
17021748
Examples
17031749
--------
17041750
>>> import dpnp as np
1705-
>>> result = np.sign(np.array([-5., 4.5]))
1706-
>>> [x for x in result]
1707-
[-1.0, 1.0]
1708-
1709-
"""
1710-
1711-
x1_desc = dpnp.get_dpnp_descriptor(
1712-
x1, copy_when_strides=False, copy_when_nondefault_queue=False
1713-
)
1714-
if x1_desc and not kwargs:
1715-
return dpnp_sign(x1_desc).get_pyobj()
1716-
1717-
return call_origin(numpy.sign, x1, **kwargs)
1751+
>>> np.sign(np.array([-5., 4.5]))
1752+
array([-1.0, 1.0])
1753+
>>> np.sign(np.array(0))
1754+
array(0)
1755+
>>> np.sign(np.array(5-2j))
1756+
array([1+0j])
1757+
1758+
"""
1759+
1760+
if numpy.iscomplexobj(x):
1761+
return call_origin(
1762+
numpy.sign,
1763+
x,
1764+
out=out,
1765+
where=where,
1766+
order=order,
1767+
dtype=dtype,
1768+
subok=subok,
1769+
**kwargs,
1770+
)
1771+
else:
1772+
return check_nd_call_func(
1773+
numpy.sign,
1774+
dpnp_sign,
1775+
x,
1776+
out=out,
1777+
where=where,
1778+
order=order,
1779+
dtype=dtype,
1780+
subok=subok,
1781+
**kwargs,
1782+
)
17181783

17191784

17201785
def subtract(

0 commit comments

Comments
 (0)