Skip to content

Commit 7eb3760

Browse files
Reuse dpctl.tensor.sum for dpnp.sum (#1426)
* Reuse dpctl.tensor.sun for dpnp.sum * Update tests for dpnp.sum * Fix remarks * Update tests/third_party/cupy/testing/helper.py --------- Co-authored-by: Anton <[email protected]>
1 parent 66e5a9f commit 7eb3760

File tree

8 files changed

+148
-97
lines changed

8 files changed

+148
-97
lines changed

dpnp/dpnp_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ def strides(self):
984984

985985
return self._array_obj.strides
986986

987-
def sum(self, axis=None, dtype=None, out=None, keepdims=False, initial=0, where=True):
987+
def sum(self, /, *, axis=None, dtype=None, keepdims=False, out=None, initial=0, where=True):
988988
"""
989989
Returns the sum along a given axis.
990990
@@ -994,7 +994,7 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=False, initial=0, where=
994994
995995
"""
996996

997-
return dpnp.sum(self, axis, dtype, out, keepdims, initial, where)
997+
return dpnp.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
998998

999999
# 'swapaxes',
10001000

dpnp/dpnp_iface_mathematical.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from .dpnp_utils import *
5151

5252
import dpnp
53+
from dpnp.dpnp_array import dpnp_array
5354

5455
import numpy
5556
import dpctl.tensor as dpt
@@ -173,7 +174,7 @@ def absolute(x,
173174
-------
174175
y : dpnp.ndarray
175176
An array containing the absolute value of each element in `x`.
176-
177+
177178
Limitations
178179
-----------
179180
Parameters `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
@@ -601,7 +602,7 @@ def divide(x1,
601602
-------
602603
y : dpnp.ndarray
603604
The quotient ``x1/x2``, element-wise.
604-
605+
605606
Limitations
606607
-----------
607608
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
@@ -1342,7 +1343,7 @@ def power(x1,
13421343
-------
13431344
y : dpnp.ndarray
13441345
The bases in `x1` raised to the exponents in `x2`.
1345-
1346+
13461347
Limitations
13471348
-----------
13481349
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
@@ -1568,7 +1569,7 @@ def subtract(x1,
15681569
-------
15691570
y : dpnp.ndarray
15701571
The difference of `x1` and `x2`, element-wise.
1571-
1572+
15721573
Limitations
15731574
-----------
15741575
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
@@ -1590,45 +1591,52 @@ def subtract(x1,
15901591
return _check_nd_call(numpy.subtract, dpnp_subtract, x1, x2, out=out, where=where, order=order, dtype=dtype, subok=subok, **kwargs)
15911592

15921593

1593-
def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=True):
1594+
def sum(x, /, *, axis=None, dtype=None, keepdims=False, out=None, initial=0, where=True):
15941595
"""
15951596
Sum of array elements over a given axis.
15961597
15971598
For full documentation refer to :obj:`numpy.sum`.
15981599
1600+
Returns
1601+
-------
1602+
y : dpnp.ndarray
1603+
an array containing the sums. If the sum was computed over the
1604+
entire array, a zero-dimensional array is returned. The returned
1605+
array has the data type as described in the `dtype` parameter
1606+
of the Python Array API standard for the `sum` function.
1607+
15991608
Limitations
16001609
-----------
1601-
Parameter `where`` is unsupported.
1602-
Input array data types are limited by DPNP :ref:`Data types`.
1610+
Parameters `x` is supported as either :class:`dpnp.ndarray`
1611+
or :class:`dpctl.tensor.usm_ndarray`.
1612+
Parameters `out`, `initial` and `where` are supported with their default values.
1613+
Otherwise the function will be executed sequentially on CPU.
1614+
Input array data types are limited by supported DPNP :ref:`Data types`.
16031615
16041616
Examples
16051617
--------
16061618
>>> import dpnp as np
16071619
>>> np.sum(np.array([1, 2, 3, 4, 5]))
1608-
15
1609-
>>> result = np.sum([[0, 1], [0, 5]], axis=0)
1610-
[0, 6]
1620+
array(15)
1621+
>>> np.sum(np.array(5))
1622+
array(5)
1623+
>>> result = np.sum(np.array([[0, 1], [0, 5]]), axis=0)
1624+
array([0, 6])
16111625
16121626
"""
16131627

1614-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
1615-
if x1_desc:
1616-
if where is not True:
1617-
pass
1618-
else:
1619-
if dpnp.isscalar(out):
1620-
raise TypeError("output must be an array")
1621-
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) if out is not None else None
1622-
result_obj = dpnp_sum(x1_desc, axis, dtype, out_desc, keepdims, initial, where).get_pyobj()
1623-
result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims)
16241628

1625-
if x1_desc.size == 0 and axis is None:
1626-
result = dpnp.zeros_like(result)
1627-
if out is not None:
1628-
out[...] = result
1629-
return result
1629+
if out is not None:
1630+
pass
1631+
elif initial != 0:
1632+
pass
1633+
elif where is not True:
1634+
pass
1635+
else:
1636+
y = dpt.sum(dpnp.get_usm_ndarray(x), axis=axis, dtype=dtype, keepdims=keepdims)
1637+
return dpnp_array._create_from_usm_ndarray(y)
16301638

1631-
return call_origin(numpy.sum, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
1639+
return call_origin(numpy.sum, x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
16321640

16331641

16341642
def trapz(y1, x1=None, dx=1.0, axis=-1):

tests/skipped_tests.tbl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -783,22 +783,6 @@ tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint
783783
tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint_negative
784784
tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_round_
785785
tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_trunc
786-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all
787-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all2
788-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_keepdims
789-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2
790-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes
791-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes2
792-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes3
793-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes4
794-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis
795-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis2
796-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis_huge
797-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis_transposed
798-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis_transposed2
799-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_dtype
800-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_keepdims_and_dtype
801-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_keepdims_multiple_axes
802786
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out
803787
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out_wrong_shape
804788
tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_ndarray_cumprod_2dim_with_axis
@@ -833,7 +817,6 @@ tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_1
833817
tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_15_{axis=0, func='nanprod', keepdims=False, shape=(20, 30, 40), transpose_axes=False}::test_nansum_axis_transposed
834818
tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_9_{axis=0, func='nanprod', keepdims=True, shape=(2, 3, 4), transpose_axes=False}::test_nansum_all
835819
tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_9_{axis=0, func='nanprod', keepdims=True, shape=(2, 3, 4), transpose_axes=False}::test_nansum_axis_transposed
836-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all2
837820
tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_1dim
838821
tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_1dim_with_discont
839822
tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_2dim_with_axis

tests/skipped_tests_gpu.tbl

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,8 @@ tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticBinary2_para
8484
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticBinary2_param_535_{arg1=array([[1, 2, 3], [4, 5, 6]], dtype=int64), arg2=array([[0, 1, 2], [3, 4, 5]]), dtype=float64, name='floor_divide', use_dtype=False}::test_binary
8585
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticBinary2_param_543_{arg1=array([[1, 2, 3], [4, 5, 6]], dtype=int64), arg2=array([[0, 1, 2], [3, 4, 5]], dtype=int64), dtype=float64, name='floor_divide', use_dtype=False}::test_binary
8686

87-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_external_prod_all
88-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_external_prod_axis
89-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_external_sum_all
90-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_external_sum_axis
91-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_prod_all
92-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_prod_axis
93-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all
94-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all2
95-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_keepdims
87+
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out
88+
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out_wrong_shape
9689
tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_1dim
9790
tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_2dim_without_axis
9891
tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_0_{axis=0}::test_cumsum
@@ -921,22 +914,6 @@ tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint
921914
tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint_negative
922915
tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_round_
923916
tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_trunc
924-
925-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2
926-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes
927-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes2
928-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes3
929-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes4
930-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis
931-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis2
932-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis_huge
933-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis_transposed
934-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axis_transposed2
935-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_dtype
936-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_keepdims_and_dtype
937-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_keepdims_multiple_axes
938-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out
939-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out_wrong_shape
940917
tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_ndarray_cumprod_2dim_with_axis
941918
tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_1dim
942919
tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_1dim_with_n
@@ -1321,7 +1298,7 @@ tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_h
13211298
tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_histogram_array_bins
13221299
tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_histogram_bins_not_ordered
13231300
tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_histogram_complex_weights
1324-
tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_histogram_complex_weights_uneven_bins
1301+
tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_histogram_complex_weights_uneven_bins
13251302
tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_histogram_density
13261303
tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_histogram_empty
13271304
tests/third_party/cupy/statistics_tests/test_histogram.py::TestHistogram::test_histogram_float_weights

tests/test_mathematical.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def test_ediff1d_int(self, array, data_type):
387387
expected = numpy.ediff1d(np_a)
388388
assert_array_equal(expected, result)
389389

390-
390+
391391
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
392392
def test_ediff1d_args(self):
393393
np_a = numpy.array([1, 2, 4, 7, 0])
@@ -940,6 +940,7 @@ def test_sum_empty(dtype, axis):
940940
assert_array_equal(numpy_res, dpnp_res.asnumpy())
941941

942942

943+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
943944
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True, no_bool=True))
944945
def test_sum_empty_out(dtype):
945946
a = dpnp.empty((1, 2, 0, 4), dtype=dtype)

tests/test_sum.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
1+
import pytest
2+
13
import dpnp
4+
from tests.helper import get_float_dtypes, has_support_aspect64
25

36
import numpy
47

5-
6-
def test_sum_float64():
7-
a = numpy.array([[[-2., 3.], [9.1, 0.2]], [[-2., 5.0], [-2, -1.2]], [[1.0, -2.], [5.0, -1.1]]])
8+
# Note: numpy.sum() always upcast integers to (u)int64 and float32 to
9+
# float64 for dtype=None. `np.sum` does that too for integers, but not for
10+
# float32, so we need to special-case it for these tests
11+
@pytest.mark.parametrize("dtype", get_float_dtypes())
12+
def test_sum_float(dtype):
13+
a = numpy.array([[[-2., 3.], [9.1, 0.2]], [[-2., 5.0], [-2, -1.2]], [[1.0, -2.], [5.0, -1.1]]], dtype=dtype)
814
ia = dpnp.array(a)
915

1016
for axis in range(len(a)):
1117
result = dpnp.sum(ia, axis=axis)
12-
expected = numpy.sum(a, axis=axis)
18+
if dtype == dpnp.float32 and has_support_aspect64():
19+
expected = numpy.sum(a, axis=axis, dtype=numpy.float64)
20+
else:
21+
expected = numpy.sum(a, axis=axis)
1322
numpy.testing.assert_array_equal(expected, result)
1423

1524

@@ -23,9 +32,12 @@ def test_sum_int():
2332

2433

2534
def test_sum_axis():
26-
a = numpy.array([[[-2., 3.], [9.1, 0.2]], [[-2., 5.0], [-2, -1.2]], [[1.0, -2.], [5.0, -1.1]]])
35+
a = numpy.array([[[-2., 3.], [9.1, 0.2]], [[-2., 5.0], [-2, -1.2]], [[1.0, -2.], [5.0, -1.1]]], dtype='f4')
2736
ia = dpnp.array(a)
2837

2938
result = dpnp.sum(ia, axis=1)
30-
expected = numpy.sum(a, axis=1)
39+
if has_support_aspect64():
40+
expected = numpy.sum(a, axis=1, dtype=numpy.float64)
41+
else:
42+
expected = numpy.sum(a, axis=1)
3143
numpy.testing.assert_array_equal(expected, result)

0 commit comments

Comments
 (0)