Skip to content

Commit 20513fb

Browse files
authored
implement dpnp.logsumexp and dpnp.reduce_hypot (#1648)
* implement logsumexp and reduce_hypot * fix pre-commit * address comments
1 parent 9e8323e commit 20513fb

File tree

7 files changed

+314
-14
lines changed

7 files changed

+314
-14
lines changed

doc/reference/math.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Trigonometric functions
2323
dpnp.unwrap
2424
dpnp.deg2rad
2525
dpnp.rad2deg
26+
dpnp.reduce_hypot
2627

2728

2829
Hyperbolic functions
@@ -94,6 +95,7 @@ Exponents and logarithms
9495
dpnp.log1p
9596
dpnp.logaddexp
9697
dpnp.logaddexp2
98+
dpnp.logsumexp
9799

98100

99101
Other special functions

dpnp/dpnp_iface_trigonometric.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@
4040
"""
4141

4242

43+
import dpctl.tensor as dpt
4344
import numpy
4445

4546
import dpnp
4647
from dpnp.dpnp_algo import *
48+
from dpnp.dpnp_array import dpnp_array
4749
from dpnp.dpnp_utils import *
4850

4951
from .dpnp_algo.dpnp_elementwise_common import (
@@ -98,9 +100,11 @@
98100
"log1p",
99101
"log2",
100102
"logaddexp",
103+
"logsumexp",
101104
"rad2deg",
102105
"radians",
103106
"reciprocal",
107+
"reduce_hypot",
104108
"rsqrt",
105109
"sin",
106110
"sinh",
@@ -989,6 +993,10 @@ def hypot(
989993
Otherwise the function will be executed sequentially on CPU.
990994
Input array data types are limited by supported real-valued data types.
991995
996+
See Also
997+
--------
998+
:obj:`dpnp.reduce_hypot` : The square root of the sum of squares of elements in the input array.
999+
9921000
Examples
9931001
--------
9941002
>>> import dpnp as np
@@ -1303,6 +1311,7 @@ def logaddexp(
13031311
--------
13041312
:obj:`dpnp.log` : Natural logarithm, element-wise.
13051313
:obj:`dpnp.exp` : Exponential, element-wise.
1314+
:obj:`dpnp.logsumdexp` : Logarithm of the sum of exponentials of elements in the input array.
13061315
13071316
Examples
13081317
--------
@@ -1331,6 +1340,81 @@ def logaddexp(
13311340
)
13321341

13331342

1343+
def logsumexp(x, axis=None, out=None, dtype=None, keepdims=False):
1344+
"""
1345+
Calculates the logarithm of the sum of exponentials of elements in the input array.
1346+
1347+
Parameters
1348+
----------
1349+
x : {dpnp_array, usm_ndarray}
1350+
Input array, expected to have a real-valued data type.
1351+
axis : int or tuple of ints, optional
1352+
Axis or axes along which values must be computed. If a tuple
1353+
of unique integers, values are computed over multiple axes.
1354+
If ``None``, the result is computed over the entire array.
1355+
Default: ``None``.
1356+
out : {dpnp_array, usm_ndarray}, optional
1357+
If provided, the result will be inserted into this array. It should
1358+
be of the appropriate shape and dtype.
1359+
dtype : data type, optional
1360+
Data type of the returned array. If ``None``, the default data
1361+
type is inferred from the "kind" of the input array data type.
1362+
* If `x` has a real-valued floating-point data type,
1363+
the returned array will have the default real-valued
1364+
floating-point data type for the device where input
1365+
array `x` is allocated.
1366+
* If `x` has a boolean or integral data type, the returned array
1367+
will have the default floating point data type for the device
1368+
where input array `x` is allocated.
1369+
* If `x` has a complex-valued floating-point data type,
1370+
an error is raised.
1371+
If the data type (either specified or resolved) differs from the
1372+
data type of `x`, the input array elements are cast to the
1373+
specified data type before computing the result. Default: ``None``.
1374+
keepdims : bool
1375+
If ``True``, the reduced axes (dimensions) are included in the result
1376+
as singleton dimensions, so that the returned array remains
1377+
compatible with the input arrays according to Array Broadcasting
1378+
rules. Otherwise, if ``False``, the reduced axes are not included in
1379+
the returned array. Default: ``False``.
1380+
1381+
Returns
1382+
-------
1383+
out : dpnp.ndarray
1384+
An array containing the results. If the result was computed over
1385+
the entire array, a zero-dimensional array is returned. The returned
1386+
array has the data type as described in the `dtype` parameter
1387+
description above.
1388+
1389+
Note
1390+
----
1391+
This function is equivalent of `numpy.logaddexp.reduce`.
1392+
1393+
See Also
1394+
--------
1395+
:obj:`dpnp.log` : Natural logarithm, element-wise.
1396+
:obj:`dpnp.exp` : Exponential, element-wise.
1397+
:obj:`dpnp.logaddexp` : Logarithm of the sum of exponentiations of the inputs, element-wise.
1398+
1399+
Examples
1400+
--------
1401+
>>> import dpnp as np
1402+
>>> a = np.ones(10)
1403+
>>> np.logsumexp(a)
1404+
array(3.30258509)
1405+
>>> np.log(np.sum(np.exp(a)))
1406+
array(3.30258509)
1407+
1408+
"""
1409+
1410+
dpt_array = dpnp.get_usm_ndarray(x)
1411+
result = dpnp_array._create_from_usm_ndarray(
1412+
dpt.logsumexp(dpt_array, axis=axis, dtype=dtype, keepdims=keepdims)
1413+
)
1414+
1415+
return dpnp.get_result_array(result, out, casting="same_kind")
1416+
1417+
13341418
def reciprocal(x1, **kwargs):
13351419
"""
13361420
Return the reciprocal of the argument, element-wise.
@@ -1363,6 +1447,79 @@ def reciprocal(x1, **kwargs):
13631447
return call_origin(numpy.reciprocal, x1, **kwargs)
13641448

13651449

1450+
def reduce_hypot(x, axis=None, out=None, dtype=None, keepdims=False):
1451+
"""
1452+
Calculates the square root of the sum of squares of elements in the input array.
1453+
1454+
Parameters
1455+
----------
1456+
x : {dpnp_array, usm_ndarray}
1457+
Input array, expected to have a real-valued data type.
1458+
axis : int or tuple of ints, optional
1459+
Axis or axes along which values must be computed. If a tuple
1460+
of unique integers, values are computed over multiple axes.
1461+
If ``None``, the result is computed over the entire array.
1462+
Default: ``None``.
1463+
out : {dpnp_array, usm_ndarray}, optional
1464+
If provided, the result will be inserted into this array. It should
1465+
be of the appropriate shape and dtype.
1466+
dtype : data type, optional
1467+
Data type of the returned array. If ``None``, the default data
1468+
type is inferred from the "kind" of the input array data type.
1469+
* If `x` has a real-valued floating-point data type,
1470+
the returned array will have the default real-valued
1471+
floating-point data type for the device where input
1472+
array `x` is allocated.
1473+
* If `x` has a boolean or integral data type, the returned array
1474+
will have the default floating point data type for the device
1475+
where input array `x` is allocated.
1476+
* If `x` has a complex-valued floating-point data type,
1477+
an error is raised.
1478+
If the data type (either specified or resolved) differs from the
1479+
data type of `x`, the input array elements are cast to the
1480+
specified data type before computing the result. Default: ``None``.
1481+
keepdims : bool
1482+
If ``True``, the reduced axes (dimensions) are included in the result
1483+
as singleton dimensions, so that the returned array remains
1484+
compatible with the input arrays according to Array Broadcasting
1485+
rules. Otherwise, if ``False``, the reduced axes are not included in
1486+
the returned array. Default: ``False``.
1487+
1488+
Returns
1489+
-------
1490+
out : dpnp.ndarray
1491+
An array containing the results. If the result was computed over
1492+
the entire array, a zero-dimensional array is returned. The returned
1493+
array has the data type as described in the `dtype` parameter
1494+
description above.
1495+
1496+
Note
1497+
----
1498+
This function is equivalent of `numpy.hypot.reduce`.
1499+
1500+
See Also
1501+
--------
1502+
:obj:`dpnp.hypot` : Given the "legs" of a right triangle, return its hypotenuse.
1503+
1504+
Examples
1505+
--------
1506+
>>> import dpnp as np
1507+
>>> a = np.ones(10)
1508+
>>> np.reduce_hypot(a)
1509+
array(3.16227766)
1510+
>>> np.sqrt(np.sum(np.square(a)))
1511+
array(3.16227766)
1512+
1513+
"""
1514+
1515+
dpt_array = dpnp.get_usm_ndarray(x)
1516+
result = dpnp_array._create_from_usm_ndarray(
1517+
dpt.reduce_hypot(dpt_array, axis=axis, dtype=dtype, keepdims=keepdims)
1518+
)
1519+
1520+
return dpnp.get_result_array(result, out, casting="same_kind")
1521+
1522+
13661523
def rsqrt(
13671524
x,
13681525
/,

tests/helper.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,17 @@ def assert_dtype_allclose(
3434
list_64bit_types = [numpy.float64, numpy.complex128]
3535
is_inexact = lambda x: dpnp.issubdtype(x.dtype, dpnp.inexact)
3636
if is_inexact(dpnp_arr) or is_inexact(numpy_arr):
37-
tol = 8 * max(
38-
dpnp.finfo(dpnp_arr).resolution,
39-
numpy.finfo(numpy_arr.dtype).resolution,
37+
tol_dpnp = (
38+
dpnp.finfo(dpnp_arr).resolution
39+
if is_inexact(dpnp_arr)
40+
else -dpnp.inf
4041
)
42+
tol_numpy = (
43+
numpy.finfo(numpy_arr.dtype).resolution
44+
if is_inexact(numpy_arr)
45+
else -dpnp.inf
46+
)
47+
tol = 8 * max(tol_dpnp, tol_numpy)
4148
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
4249
if check_type:
4350
numpy_arr_dtype = numpy_arr.dtype

tests/test_mathematical.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,6 +1752,90 @@ def test_invalid_out(self, out):
17521752
assert_raises(TypeError, numpy.hypot, a.asnumpy(), 2, out)
17531753

17541754

1755+
class TestLogSumExp:
1756+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
1757+
@pytest.mark.parametrize("axis", [None, 2, -1, (0, 1)])
1758+
@pytest.mark.parametrize("keepdims", [True, False])
1759+
def test_logsumexp(self, dtype, axis, keepdims):
1760+
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
1761+
res = dpnp.logsumexp(a, axis=axis, keepdims=keepdims)
1762+
exp_dtype = dpnp.default_float_type(a.device)
1763+
exp = numpy.logaddexp.reduce(
1764+
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
1765+
)
1766+
1767+
assert_dtype_allclose(res, exp)
1768+
1769+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
1770+
@pytest.mark.parametrize("axis", [None, 2, -1, (0, 1)])
1771+
@pytest.mark.parametrize("keepdims", [True, False])
1772+
def test_logsumexp_out(self, dtype, axis, keepdims):
1773+
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
1774+
exp_dtype = dpnp.default_float_type(a.device)
1775+
exp = numpy.logaddexp.reduce(
1776+
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
1777+
)
1778+
dpnp_out = dpnp.empty(exp.shape, dtype=exp_dtype)
1779+
res = dpnp.logsumexp(a, axis=axis, out=dpnp_out, keepdims=keepdims)
1780+
1781+
assert res is dpnp_out
1782+
assert_dtype_allclose(res, exp)
1783+
1784+
@pytest.mark.parametrize(
1785+
"in_dtype", get_all_dtypes(no_bool=True, no_complex=True)
1786+
)
1787+
@pytest.mark.parametrize("out_dtype", get_all_dtypes(no_bool=True))
1788+
def test_logsumexp_dtype(self, in_dtype, out_dtype):
1789+
a = dpnp.ones(100, dtype=in_dtype)
1790+
res = dpnp.logsumexp(a, dtype=out_dtype)
1791+
exp = numpy.logaddexp.reduce(dpnp.asnumpy(a))
1792+
exp = exp.astype(out_dtype)
1793+
1794+
assert_allclose(res, exp, rtol=1e-06)
1795+
1796+
1797+
class TestReduceHypot:
1798+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
1799+
@pytest.mark.parametrize("axis", [None, 2, -1, (0, 1)])
1800+
@pytest.mark.parametrize("keepdims", [True, False])
1801+
def test_reduce_hypot(self, dtype, axis, keepdims):
1802+
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
1803+
res = dpnp.reduce_hypot(a, axis=axis, keepdims=keepdims)
1804+
exp_dtype = dpnp.default_float_type(a.device)
1805+
exp = numpy.hypot.reduce(
1806+
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
1807+
)
1808+
1809+
assert_dtype_allclose(res, exp)
1810+
1811+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
1812+
@pytest.mark.parametrize("axis", [None, 2, -1, (0, 1)])
1813+
@pytest.mark.parametrize("keepdims", [True, False])
1814+
def test_reduce_hypot_out(self, dtype, axis, keepdims):
1815+
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
1816+
exp_dtype = dpnp.default_float_type(a.device)
1817+
exp = numpy.hypot.reduce(
1818+
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
1819+
)
1820+
dpnp_out = dpnp.empty(exp.shape, dtype=exp_dtype)
1821+
res = dpnp.reduce_hypot(a, axis=axis, out=dpnp_out, keepdims=keepdims)
1822+
1823+
assert res is dpnp_out
1824+
assert_dtype_allclose(res, exp)
1825+
1826+
@pytest.mark.parametrize(
1827+
"in_dtype", get_all_dtypes(no_bool=True, no_complex=True)
1828+
)
1829+
@pytest.mark.parametrize("out_dtype", get_all_dtypes(no_bool=True))
1830+
def test_reduce_hypot_dtype(self, in_dtype, out_dtype):
1831+
a = dpnp.ones(99, dtype=in_dtype)
1832+
res = dpnp.reduce_hypot(a, dtype=out_dtype)
1833+
exp = numpy.hypot.reduce(dpnp.asnumpy(a))
1834+
exp = exp.astype(out_dtype)
1835+
1836+
assert_allclose(res, exp, rtol=1e-06)
1837+
1838+
17551839
class TestMaximum:
17561840
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
17571841
def test_maximum(self, dtype):

tests/test_strides.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import dpnp
88

9-
from .helper import get_all_dtypes
9+
from .helper import assert_dtype_allclose, get_all_dtypes
1010

1111

1212
def _getattr(ex, str_):
@@ -99,17 +99,33 @@ def test_strides_1arg(func_name, dtype, shape):
9999

100100

101101
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
102-
def test_strides_rsqrt(dtype):
103-
a = numpy.arange(1, 11, dtype=dtype)
104-
b = a[::2]
102+
def test_rsqrt(dtype):
103+
a = numpy.arange(1, 11, dtype=dtype)[::2]
104+
dpa = dpnp.arange(1, 11, dtype=dtype)[::2]
105105

106-
dpa = dpnp.arange(1, 11, dtype=dtype)
107-
dpb = dpa[::2]
106+
result = dpnp.rsqrt(dpa)
107+
expected = 1 / numpy.sqrt(a)
108+
assert_dtype_allclose(result, expected)
108109

109-
result = dpnp.rsqrt(dpb)
110-
expected = 1 / numpy.sqrt(b)
111110

112-
assert_allclose(result, expected, rtol=1e-06)
111+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
112+
def test_logsumexp(dtype):
113+
a = numpy.arange(10, dtype=dtype)[::2]
114+
dpa = dpnp.arange(10, dtype=dtype)[::2]
115+
116+
result = dpnp.logsumexp(dpa)
117+
expected = numpy.logaddexp.reduce(a)
118+
assert_allclose(result, expected)
119+
120+
121+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
122+
def test_reduce_hypot(dtype):
123+
a = numpy.arange(10, dtype=dtype)[::2]
124+
dpa = dpnp.arange(10, dtype=dtype)[::2]
125+
126+
result = dpnp.reduce_hypot(dpa)
127+
expected = numpy.hypot.reduce(a)
128+
assert_allclose(result, expected)
113129

114130

115131
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)