Skip to content

Commit 53b945c

Browse files
committed
BUG: rank raising for arrow string dtypes (pandas-dev#55362)
(cherry picked from commit 6a83910)
1 parent f07deef commit 53b945c

File tree

4 files changed

+106
-7
lines changed

4 files changed

+106
-7
lines changed

doc/source/whatsnew/v2.1.2.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Bug fixes
2626
- Fixed bug in :meth:`DataFrame.idxmin` and :meth:`DataFrame.idxmax` raising for arrow dtypes (:issue:`55368`)
2727
- Fixed bug in :meth:`DataFrame.interpolate` raising incorrect error message (:issue:`55347`)
2828
- Fixed bug in :meth:`Index.insert` raising when inserting ``None`` into :class:`Index` with ``dtype="string[pyarrow_numpy]"`` (:issue:`55365`)
29+
- Fixed bug in :meth:`Series.rank` raising for ``string[pyarrow_numpy]`` dtype (:issue:`55362`)
2930
- Silence ``Period[B]`` warnings introduced by :issue:`53446` during normal plotting activity (:issue:`55138`)
3031
-
3132

pandas/core/arrays/arrow/array.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,7 +1712,7 @@ def __setitem__(self, key, value) -> None:
17121712
data = pa.chunked_array([data])
17131713
self._pa_array = data
17141714

1715-
def _rank(
1715+
def _rank_calc(
17161716
self,
17171717
*,
17181718
axis: AxisInt = 0,
@@ -1721,9 +1721,6 @@ def _rank(
17211721
ascending: bool = True,
17221722
pct: bool = False,
17231723
):
1724-
"""
1725-
See Series.rank.__doc__.
1726-
"""
17271724
if pa_version_under9p0 or axis != 0:
17281725
ranked = super()._rank(
17291726
axis=axis,
@@ -1738,7 +1735,7 @@ def _rank(
17381735
else:
17391736
pa_type = pa.uint64()
17401737
result = pa.array(ranked, type=pa_type, from_pandas=True)
1741-
return type(self)(result)
1738+
return result
17421739

17431740
data = self._pa_array.combine_chunks()
17441741
sort_keys = "ascending" if ascending else "descending"
@@ -1777,7 +1774,29 @@ def _rank(
17771774
divisor = pc.count(result)
17781775
result = pc.divide(result, divisor)
17791776

1780-
return type(self)(result)
1777+
return result
1778+
1779+
def _rank(
    self,
    *,
    axis: AxisInt = 0,
    method: str = "average",
    na_option: str = "keep",
    ascending: bool = True,
    pct: bool = False,
):
    """
    Rank the values and wrap the outcome in a new array of this type.

    See Series.rank.__doc__ for the meaning of the keyword arguments; they
    are forwarded unchanged to ``_rank_calc``.
    """
    # The computation itself lives in _rank_calc; this method only re-wraps
    # whatever it returns in ``type(self)``.
    raw_ranks = self._rank_calc(
        axis=axis,
        method=method,
        na_option=na_option,
        ascending=ascending,
        pct=pct,
    )
    return type(self)(raw_ranks)
17811800

17821801
def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str) -> Self:
17831802
"""

pandas/core/arrays/string_arrow.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
lib,
1616
missing as libmissing,
1717
)
18-
from pandas.compat import pa_version_under7p0
18+
from pandas.compat import (
19+
pa_version_under7p0,
20+
pa_version_under13p0,
21+
)
1922
from pandas.util._exceptions import find_stack_level
2023

2124
from pandas.core.dtypes.common import (
@@ -48,6 +51,7 @@
4851

4952
if TYPE_CHECKING:
5053
from pandas._typing import (
54+
AxisInt,
5155
Dtype,
5256
Scalar,
5357
npt,
@@ -444,6 +448,65 @@ def _str_rstrip(self, to_strip=None):
444448
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
445449
return type(self)(result)
446450

451+
def _str_removeprefix(self, prefix: str):
    """
    Drop ``prefix`` from the start of every element that begins with it.

    The pyarrow compute path is only taken when ``pa_version_under13p0`` is
    false; otherwise the parent implementation is used.
    """
    if pa_version_under13p0:
        return super()._str_removeprefix(prefix)

    values = self._pa_array
    has_prefix = pc.starts_with(values, pattern=prefix)
    stripped = pc.utf8_slice_codeunits(values, len(prefix))
    # Keep the original element wherever the prefix is not present.
    return type(self)(pc.if_else(has_prefix, stripped, values))
458+
459+
def _str_removesuffix(self, suffix: str):
    """
    Drop ``suffix`` from the end of every element that ends with it.

    Parameters
    ----------
    suffix : str
        Suffix to remove. An empty suffix leaves every element unchanged.

    Returns
    -------
    A new array of the same type as ``self``; elements that do not end with
    ``suffix`` are passed through untouched.
    """
    if not suffix:
        # ``-len("")`` is ``0``, so the slice below would be ``[0:0]`` and any
        # element reported as matching would be truncated to "". There is
        # nothing to strip, so hand back the data unchanged.
        return type(self)(self._pa_array)
    ends_with = pc.ends_with(self._pa_array, pattern=suffix)
    removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
    # Keep the original element wherever the suffix is not present.
    result = pc.if_else(ends_with, removed, self._pa_array)
    return type(self)(result)
464+
465+
def _str_count(self, pat: str, flags: int = 0):
    """
    Count regex matches of ``pat`` in each element.

    Any non-zero ``flags`` value is delegated to the parent implementation;
    only the flag-free case goes through ``pc.count_substring_regex``.
    """
    if not flags:
        counts = pc.count_substring_regex(self._pa_array, pat)
        return self._convert_int_dtype(counts)
    return super()._str_count(pat, flags)
470+
471+
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the lowest index of ``sub`` in each element, ``-1`` when absent.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Left edge of the search window.
    end : int or None, default None
        Right edge of the search window.

    Returns
    -------
    The positions, converted with ``self._convert_int_dtype``.

    Notes
    -----
    pyarrow compute is used for the unbounded search (``start == 0`` and
    ``end is None``) and for a window with a positive ``start`` and an
    explicit ``end``. Every other combination, including a negative
    ``start``, is delegated to the parent implementation.
    """
    if start == 0 and end is None:
        result = pc.find_substring(self._pa_array, sub)
    elif start > 0 and end is not None:
        slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
        result = pc.find_substring(slices, sub)
        not_found = pc.equal(result, -1)
        # Hits are relative to the slice, which begins at ``start`` in the
        # original string, so shift by ``start`` (not by the window width).
        # E.g. "abcdef".find("d", 1, 5): slice "bcde" -> 2, answer 2 + 1 = 3.
        offset_result = pc.add(result, start)
        # Leave the -1 "not found" marker untouched.
        result = pc.if_else(not_found, result, offset_result)
    else:
        # Negative ``start`` counts from the end of each string, so a single
        # constant offset cannot be correct; use the generic implementation.
        return super()._str_find(sub, start, end)
    return self._convert_int_dtype(result)
484+
485+
def _convert_int_dtype(self, result):
    """
    Convert a pyarrow result via ``Int64Dtype().__from_arrow__``.
    """
    int64_dtype = Int64Dtype()
    return int64_dtype.__from_arrow__(result)
487+
488+
def _rank(
    self,
    *,
    axis: AxisInt = 0,
    method: str = "average",
    na_option: str = "keep",
    ascending: bool = True,
    pct: bool = False,
):
    """
    Rank the values and convert the outcome with ``_convert_int_dtype``.

    See Series.rank.__doc__ for the meaning of the keyword arguments; they
    are forwarded unchanged to ``_rank_calc``.
    """
    raw_ranks = self._rank_calc(
        axis=axis,
        method=method,
        na_option=na_option,
        ascending=ascending,
        pct=pct,
    )
    # NOTE(review): ranks are presumably floating for pct=True (and possibly
    # method="average"); confirm _convert_int_dtype copes with such input.
    return self._convert_int_dtype(raw_ranks)
509+
447510

448511
class ArrowStringArrayNumpySemantics(ArrowStringArray):
449512
_storage = "pyarrow_numpy"
@@ -527,6 +590,10 @@ def _str_map(
527590
return lib.map_infer_mask(arr, f, mask.view("uint8"))
528591

529592
def _convert_int_dtype(self, result):
    """
    Convert a pyarrow result to a NumPy array, widening int32 to int64.
    """
    if isinstance(result, pa.Array):
        # Only this branch passes ``zero_copy_only=False``.
        values = result.to_numpy(zero_copy_only=False)
    else:
        values = result.to_numpy()
    return values.astype(np.int64) if values.dtype == np.int32 else values

pandas/tests/frame/methods/test_rank.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,3 +488,15 @@ def test_rank_mixed_axis_zero(self, data, expected):
488488
df.rank()
489489
result = df.rank(numeric_only=True)
490490
tm.assert_frame_equal(result, expected)
491+
492+
@pytest.mark.parametrize(
    "dtype, exp_dtype",
    [
        ("string[pyarrow]", "Int64"),
        ("string[pyarrow_numpy]", "float64"),
    ],
)
def test_rank_string_dtype(self, dtype, exp_dtype):
    # GH#55362
    # rank(method="first") on a string Series with a missing value should
    # give the dtype paired with each string storage in the parametrization.
    pytest.importorskip("pyarrow")
    ser = Series(["foo", "foo", None, "foo"], dtype=dtype)
    expected = Series([1, 2, None, 3], dtype=exp_dtype)
    result = ser.rank(method="first")
    tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)