Skip to content

Commit e72027d

Browse files
authored
REF: standardize __array_ufunc__ patterns (#45113)
1 parent db6a491 commit e72027d

File tree

7 files changed

+70
-15
lines changed

7 files changed

+70
-15
lines changed

pandas/core/arrays/categorical.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@
8585
notna,
8686
)
8787

88-
from pandas.core import ops
88+
from pandas.core import (
89+
arraylike,
90+
ops,
91+
)
8992
from pandas.core.accessor import (
9093
PandasDelegate,
9194
delegate_names,
@@ -1516,6 +1519,14 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
15161519
if result is not NotImplemented:
15171520
return result
15181521

1522+
if method == "reduce":
1523+
# e.g. TestCategoricalAnalytics::test_min_max_ordered
1524+
result = arraylike.dispatch_reduction_ufunc(
1525+
self, ufunc, method, *inputs, **kwargs
1526+
)
1527+
if result is not NotImplemented:
1528+
return result
1529+
15191530
# for all other cases, raise for now (similarly as what happens in
15201531
# Series.__array_prepare__)
15211532
raise TypeError(

pandas/core/arrays/numpy_.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import numbers
4-
53
import numpy as np
64

75
from pandas._libs import lib
@@ -130,8 +128,6 @@ def dtype(self) -> PandasDtype:
130128
def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
131129
return np.asarray(self._ndarray, dtype=dtype)
132130

133-
_HANDLED_TYPES = (np.ndarray, numbers.Number)
134-
135131
def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
136132
# Lightly modified version of
137133
# https://numpy.org/doc/stable/reference/generated/numpy.lib.mixins.NDArrayOperatorsMixin.html

pandas/core/arrays/sparse/array.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
notna,
7474
)
7575

76+
from pandas.core import arraylike
7677
import pandas.core.algorithms as algos
7778
from pandas.core.arraylike import OpsMixin
7879
from pandas.core.arrays import ExtensionArray
@@ -1415,7 +1416,9 @@ def any(self, axis=0, *args, **kwargs):
14151416

14161417
return values.any().item()
14171418

1418-
def sum(self, axis: int = 0, min_count: int = 0, *args, **kwargs) -> Scalar:
1419+
def sum(
1420+
self, axis: int = 0, min_count: int = 0, skipna: bool = True, *args, **kwargs
1421+
) -> Scalar:
14191422
"""
14201423
Sum of non-NA/null values
14211424
@@ -1437,6 +1440,11 @@ def sum(self, axis: int = 0, min_count: int = 0, *args, **kwargs) -> Scalar:
14371440
nv.validate_sum(args, kwargs)
14381441
valid_vals = self._valid_sp_values
14391442
sp_sum = valid_vals.sum()
1443+
has_na = self.sp_index.ngaps > 0 and not self._null_fill_value
1444+
1445+
if has_na and not skipna:
1446+
return na_value_for_dtype(self.dtype.subtype, compat=False)
1447+
14401448
if self._null_fill_value:
14411449
if check_below_min_count(valid_vals.shape, None, min_count):
14421450
return na_value_for_dtype(self.dtype.subtype, compat=False)
@@ -1589,6 +1597,21 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
15891597
if result is not NotImplemented:
15901598
return result
15911599

1600+
if "out" in kwargs:
1601+
# e.g. tests.arrays.sparse.test_arithmetics.test_ndarray_inplace
1602+
res = arraylike.dispatch_ufunc_with_out(
1603+
self, ufunc, method, *inputs, **kwargs
1604+
)
1605+
return res
1606+
1607+
if method == "reduce":
1608+
result = arraylike.dispatch_reduction_ufunc(
1609+
self, ufunc, method, *inputs, **kwargs
1610+
)
1611+
if result is not NotImplemented:
1612+
# e.g. tests.series.test_ufunc.TestNumpyReductions
1613+
return result
1614+
15921615
if len(inputs) == 1:
15931616
# No alignment necessary.
15941617
sp_values = getattr(ufunc, method)(self.sp_values, **kwargs)
@@ -1611,7 +1634,8 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
16111634
sp_values, self.sp_index, SparseDtype(sp_values.dtype, fill_value)
16121635
)
16131636

1614-
result = getattr(ufunc, method)(*(np.asarray(x) for x in inputs), **kwargs)
1637+
new_inputs = tuple(np.asarray(x) for x in inputs)
1638+
result = getattr(ufunc, method)(*new_inputs, **kwargs)
16151639
if out:
16161640
if len(out) == 1:
16171641
out = out[0]

pandas/core/indexes/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,12 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str_t, *inputs, **kwargs):
878878
if result is not NotImplemented:
879879
return result
880880

881+
if "out" in kwargs:
882+
# e.g. test_dti_isub_tdi
883+
return arraylike.dispatch_ufunc_with_out(
884+
self, ufunc, method, *inputs, **kwargs
885+
)
886+
881887
if method == "reduce":
882888
result = arraylike.dispatch_reduction_ufunc(
883889
self, ufunc, method, *inputs, **kwargs

pandas/tests/arithmetic/test_datetime64.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,7 +2107,7 @@ def test_dti_isub_tdi(self, tz_naive_fixture):
21072107
np.subtract(out, tdi, out=out)
21082108
tm.assert_datetime_array_equal(out, expected._data)
21092109

2110-
msg = "cannot subtract .* from a TimedeltaArray"
2110+
msg = "cannot subtract a datelike from a TimedeltaArray"
21112111
with pytest.raises(TypeError, match=msg):
21122112
tdi -= dti
21132113

@@ -2116,11 +2116,9 @@ def test_dti_isub_tdi(self, tz_naive_fixture):
21162116
result -= tdi.values
21172117
tm.assert_index_equal(result, expected)
21182118

2119-
msg = "cannot subtract DatetimeArray from ndarray"
21202119
with pytest.raises(TypeError, match=msg):
21212120
tdi.values -= dti
21222121

2123-
msg = "cannot subtract a datelike from a TimedeltaArray"
21242122
with pytest.raises(TypeError, match=msg):
21252123
tdi._values -= dti
21262124

pandas/tests/arrays/categorical/test_analytics.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,32 @@ def test_min_max_not_ordered_raises(self, aggregation):
2929
with pytest.raises(TypeError, match=msg):
3030
agg_func()
3131

32-
def test_min_max_ordered(self):
32+
ufunc = np.minimum if aggregation == "min" else np.maximum
33+
with pytest.raises(TypeError, match=msg):
34+
ufunc.reduce(cat)
35+
36+
def test_min_max_ordered(self, index_or_series_or_array):
3337
cat = Categorical(["a", "b", "c", "d"], ordered=True)
34-
_min = cat.min()
35-
_max = cat.max()
38+
obj = index_or_series_or_array(cat)
39+
_min = obj.min()
40+
_max = obj.max()
3641
assert _min == "a"
3742
assert _max == "d"
3843

44+
assert np.minimum.reduce(obj) == "a"
45+
assert np.maximum.reduce(obj) == "d"
46+
# TODO: raises if we pass axis=0 (on Index and Categorical, not Series)
47+
3948
cat = Categorical(
4049
["a", "b", "c", "d"], categories=["d", "c", "b", "a"], ordered=True
4150
)
42-
_min = cat.min()
43-
_max = cat.max()
51+
obj = index_or_series_or_array(cat)
52+
_min = obj.min()
53+
_max = obj.max()
4454
assert _min == "d"
4555
assert _max == "a"
56+
assert np.minimum.reduce(obj) == "d"
57+
assert np.maximum.reduce(obj) == "a"
4658

4759
@pytest.mark.parametrize(
4860
"categories,expected",

pandas/tests/extension/decimal/array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
is_list_like,
2626
is_scalar,
2727
)
28+
from pandas.core import arraylike
2829
from pandas.core.arraylike import OpsMixin
2930
from pandas.core.arrays import (
3031
ExtensionArray,
@@ -121,6 +122,13 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
121122
inputs = tuple(x._data if isinstance(x, DecimalArray) else x for x in inputs)
122123
result = getattr(ufunc, method)(*inputs, **kwargs)
123124

125+
if method == "reduce":
126+
result = arraylike.dispatch_reduction_ufunc(
127+
self, ufunc, method, *inputs, **kwargs
128+
)
129+
if result is not NotImplemented:
130+
return result
131+
124132
def reconstruct(x):
125133
if isinstance(x, (decimal.Decimal, numbers.Number)):
126134
return x

0 commit comments

Comments
 (0)