Skip to content

Commit 4d6a066

Browse files
authored
REF: EA quantile logic to EA._quantile (#44412)
1 parent a7b536c commit 4d6a066

File tree

6 files changed

+106
-99
lines changed

6 files changed

+106
-99
lines changed

pandas/core/array_algos/quantile.py

Lines changed: 3 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,19 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
4-
53
import numpy as np
64

75
from pandas._typing import (
86
ArrayLike,
97
npt,
108
)
119

12-
from pandas.core.dtypes.common import is_sparse
1310
from pandas.core.dtypes.missing import (
1411
isna,
1512
na_value_for_dtype,
1613
)
1714

1815
from pandas.core.nanops import nanpercentile
1916

20-
if TYPE_CHECKING:
21-
from pandas.core.arrays import ExtensionArray
22-
2317

2418
def quantile_compat(
2519
values: ArrayLike, qs: npt.NDArray[np.float64], interpolation: str
@@ -40,23 +34,12 @@ def quantile_compat(
4034
if isinstance(values, np.ndarray):
4135
fill_value = na_value_for_dtype(values.dtype, compat=False)
4236
mask = isna(values)
43-
return _quantile_with_mask(values, mask, fill_value, qs, interpolation)
37+
return quantile_with_mask(values, mask, fill_value, qs, interpolation)
4438
else:
45-
# In general we don't want to import from arrays here;
46-
# this is temporary pending discussion in GH#41428
47-
from pandas.core.arrays import BaseMaskedArray
48-
49-
if isinstance(values, BaseMaskedArray):
50-
# e.g. IntegerArray, does not implement _from_factorized
51-
out = _quantile_ea_fallback(values, qs, interpolation)
52-
53-
else:
54-
out = _quantile_ea_compat(values, qs, interpolation)
39+
return values._quantile(qs, interpolation)
5540

56-
return out
5741

58-
59-
def _quantile_with_mask(
42+
def quantile_with_mask(
6043
values: np.ndarray,
6144
mask: np.ndarray,
6245
fill_value,
@@ -114,82 +97,3 @@ def _quantile_with_mask(
11497
result = result.T
11598

11699
return result
117-
118-
119-
def _quantile_ea_compat(
120-
values: ExtensionArray, qs: npt.NDArray[np.float64], interpolation: str
121-
) -> ExtensionArray:
122-
"""
123-
ExtensionArray compatibility layer for _quantile_with_mask.
124-
125-
We pretend that an ExtensionArray with shape (N,) is actually (1, N,)
126-
for compatibility with non-EA code.
127-
128-
Parameters
129-
----------
130-
values : ExtensionArray
131-
qs : np.ndarray[float64]
132-
interpolation: str
133-
134-
Returns
135-
-------
136-
ExtensionArray
137-
"""
138-
# TODO(EA2D): make-believe not needed with 2D EAs
139-
orig = values
140-
141-
# asarray needed for Sparse, see GH#24600
142-
mask = np.asarray(values.isna())
143-
mask = np.atleast_2d(mask)
144-
145-
arr, fill_value = values._values_for_factorize()
146-
arr = np.atleast_2d(arr)
147-
148-
result = _quantile_with_mask(arr, mask, fill_value, qs, interpolation)
149-
150-
if not is_sparse(orig.dtype):
151-
# shape[0] should be 1 as long as EAs are 1D
152-
153-
if orig.ndim == 2:
154-
# i.e. DatetimeArray
155-
result = type(orig)._from_factorized(result, orig)
156-
157-
else:
158-
assert result.shape == (1, len(qs)), result.shape
159-
result = type(orig)._from_factorized(result[0], orig)
160-
161-
# error: Incompatible return value type (got "ndarray", expected "ExtensionArray")
162-
return result # type: ignore[return-value]
163-
164-
165-
def _quantile_ea_fallback(
166-
values: ExtensionArray, qs: npt.NDArray[np.float64], interpolation: str
167-
) -> ExtensionArray:
168-
"""
169-
quantile compatibility for ExtensionArray subclasses that do not
170-
implement `_from_factorized`, e.g. IntegerArray.
171-
172-
Notes
173-
-----
174-
We assume that all impacted cases are 1D-only.
175-
"""
176-
mask = np.atleast_2d(np.asarray(values.isna()))
177-
npvalues = np.atleast_2d(np.asarray(values))
178-
179-
res = _quantile_with_mask(
180-
npvalues,
181-
mask=mask,
182-
fill_value=values.dtype.na_value,
183-
qs=qs,
184-
interpolation=interpolation,
185-
)
186-
assert res.ndim == 2
187-
assert res.shape[0] == 1
188-
res = res[0]
189-
try:
190-
out = type(values)._from_sequence(res, dtype=values.dtype)
191-
except TypeError:
192-
# GH#42626: not able to safely cast Int64
193-
# for floating point output
194-
out = np.atleast_2d(np.asarray(res, dtype=np.float64))
195-
return out

pandas/core/arrays/_mixins.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
unique,
5454
value_counts,
5555
)
56+
from pandas.core.array_algos.quantile import quantile_with_mask
5657
from pandas.core.array_algos.transforms import shift
5758
from pandas.core.arrays.base import ExtensionArray
5859
from pandas.core.construction import extract_array
@@ -463,6 +464,30 @@ def value_counts(self, dropna: bool = True):
463464
index = Index(index_arr, name=result.index.name)
464465
return Series(result._values, index=index, name=result.name)
465466

467+
def _quantile(
468+
self: NDArrayBackedExtensionArrayT,
469+
qs: npt.NDArray[np.float64],
470+
interpolation: str,
471+
) -> NDArrayBackedExtensionArrayT:
472+
# TODO: disable for Categorical if not ordered?
473+
474+
# asarray needed for Sparse, see GH#24600
475+
mask = np.asarray(self.isna())
476+
mask = np.atleast_2d(mask)
477+
478+
arr = np.atleast_2d(self._ndarray)
479+
# TODO: something NDArrayBacked-specific instead of _values_for_factorize[1]?
480+
fill_value = self._values_for_factorize()[1]
481+
482+
res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
483+
484+
result = type(self)._from_factorized(res_values, self)
485+
if self.ndim == 1:
486+
assert result.shape == (1, len(qs)), result.shape
487+
result = result[0]
488+
489+
return result
490+
466491
# ------------------------------------------------------------------------
467492
# numpy-like methods
468493

pandas/core/arrays/base.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
isin,
7676
unique,
7777
)
78+
from pandas.core.array_algos.quantile import quantile_with_mask
7879
from pandas.core.sorting import (
7980
nargminmax,
8081
nargsort,
@@ -1494,6 +1495,41 @@ def _empty(cls, shape: Shape, dtype: ExtensionDtype):
14941495
)
14951496
return result
14961497

1498+
def _quantile(
1499+
self: ExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
1500+
) -> ExtensionArrayT:
1501+
"""
1502+
Compute the quantiles of self for each quantile in `qs`.
1503+
1504+
Parameters
1505+
----------
1506+
qs : np.ndarray[float64]
1507+
interpolation: str
1508+
1509+
Returns
1510+
-------
1511+
same type as self
1512+
"""
1513+
# asarray needed for Sparse, see GH#24600
1514+
mask = np.asarray(self.isna())
1515+
mask = np.atleast_2d(mask)
1516+
1517+
arr = np.atleast_2d(np.asarray(self))
1518+
fill_value = np.nan
1519+
1520+
res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
1521+
1522+
if self.ndim == 2:
1523+
# i.e. DatetimeArray
1524+
result = type(self)._from_sequence(res_values)
1525+
1526+
else:
1527+
# shape[0] should be 1 as long as EAs are 1D
1528+
assert res_values.shape == (1, len(qs)), res_values.shape
1529+
result = type(self)._from_sequence(res_values[0])
1530+
1531+
return result
1532+
14971533
def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
14981534
if any(
14991535
isinstance(other, (ABCSeries, ABCIndex, ABCDataFrame)) for other in inputs

pandas/core/arrays/masked.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
take,
6666
)
6767
from pandas.core.array_algos import masked_reductions
68+
from pandas.core.array_algos.quantile import quantile_with_mask
6869
from pandas.core.arraylike import OpsMixin
6970
from pandas.core.arrays import ExtensionArray
7071
from pandas.core.indexers import check_array_indexer
@@ -692,6 +693,38 @@ def equals(self, other) -> bool:
692693
right = other._data[~other._mask]
693694
return array_equivalent(left, right, dtype_equal=True)
694695

696+
def _quantile(
697+
self: BaseMaskedArrayT, qs: npt.NDArray[np.float64], interpolation: str
698+
) -> BaseMaskedArrayT:
699+
"""
700+
Dispatch to quantile_with_mask, needed because we do not have
701+
_from_factorized.
702+
703+
Notes
704+
-----
705+
We assume that all impacted cases are 1D-only.
706+
"""
707+
mask = np.atleast_2d(np.asarray(self.isna()))
708+
npvalues = np.atleast_2d(np.asarray(self))
709+
710+
res = quantile_with_mask(
711+
npvalues,
712+
mask=mask,
713+
fill_value=self.dtype.na_value,
714+
qs=qs,
715+
interpolation=interpolation,
716+
)
717+
assert res.ndim == 2
718+
assert res.shape[0] == 1
719+
res = res[0]
720+
try:
721+
out = type(self)._from_sequence(res, dtype=self.dtype)
722+
except TypeError:
723+
# GH#42626: not able to safely cast Int64
724+
# for floating point output
725+
out = np.asarray(res, dtype=np.float64)
726+
return out
727+
695728
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
696729
if name in {"any", "all"}:
697730
return getattr(self, name)(skipna=skipna, **kwargs)

pandas/core/arrays/sparse/array.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,12 @@ def value_counts(self, dropna: bool = True) -> Series:
863863
keys = Index(keys)
864864
return Series(counts, index=keys)
865865

866+
def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str):
867+
# Special case: the returned array isn't _really_ sparse, so we don't
868+
# wrap it in a SparseArray
869+
result = super()._quantile(qs, interpolation)
870+
return np.asarray(result)
871+
866872
# --------
867873
# Indexing
868874
# --------

pandas/core/internals/blocks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,9 @@ def quantile(
13101310
assert is_list_like(qs) # caller is responsible for this
13111311

13121312
result = quantile_compat(self.values, np.asarray(qs._values), interpolation)
1313+
# ensure_block_shape needed for cases where we start with EA and result
1314+
# is ndarray, e.g. IntegerArray, SparseArray
1315+
result = ensure_block_shape(result, ndim=2)
13131316
return new_block_2d(result, placement=self._mgr_locs)
13141317

13151318

0 commit comments

Comments
 (0)