Skip to content

ENH: Implement more str accessor methods for ArrowDtype #52614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Other
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
- Implemented most ``str`` accessor methods for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)

.. ---------------------------------------------------------------------------
.. _whatsnew_201.contributors:
Expand Down
200 changes: 165 additions & 35 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import operator
import re
import sys
import textwrap
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -10,6 +12,7 @@
Sequence,
cast,
)
import unicodedata

import numpy as np

Expand Down Expand Up @@ -1868,13 +1871,29 @@ def _str_join(self, sep: str):
return type(self)(pc.binary_join(self._pa_array, sep))

def _str_partition(self, sep: str, expand: bool):
    """
    Split each element at the first occurrence of `sep`.

    Each non-null element becomes the 3-tuple ``(head, sep, tail)``
    returned by ``str.partition``; nulls are propagated unchanged.
    There is no pyarrow compute kernel for partition, so this falls
    back to elementwise Python, rebuilding one list per chunk to
    preserve the backing ChunkedArray's chunking structure (GH 42357).
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None if val.as_py() is None else val.as_py().partition(sep)
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_rpartition(self, sep: str, expand: bool):
    """
    Split each element at the last occurrence of `sep`.

    Each non-null element becomes the 3-tuple ``(head, sep, tail)``
    returned by ``str.rpartition``; nulls are propagated unchanged.
    Elementwise Python fallback (no pyarrow kernel exists); chunking
    structure is preserved by rebuilding one list per chunk.
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None if val.as_py() is None else val.as_py().rpartition(sep)
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_slice(
Expand Down Expand Up @@ -1964,14 +1983,31 @@ def _str_rstrip(self, to_strip=None):
return type(self)(result)

def _str_removeprefix(self, prefix: str):
    """
    Remove `prefix` from the start of each element.

    Matches ``str.removeprefix`` semantics: elements that do not start
    with `prefix` are returned unchanged; nulls are propagated.
    """
    # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
    # starts_with = pc.starts_with(self._pa_array, pattern=prefix)
    # removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
    # result = pc.if_else(starts_with, removed, self._pa_array)
    # return type(self)(result)
    if sys.version_info < (3, 9):
        # str.removeprefix only exists on 3.9+; use the backport helper.
        # NOTE pyupgrade will remove this when we run it with --py39-plus
        # so don't remove the unnecessary `else` statement below
        from pandas.util._str_methods import removeprefix

    else:
        removeprefix = lambda arg, prefix: arg.removeprefix(prefix)
    return type(self)(
        pa.chunked_array(
            [
                [
                    None
                    if val.as_py() is None
                    else removeprefix(val.as_py(), prefix)
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_removesuffix(self, suffix: str):
    """
    Remove `suffix` from the end of each element.

    Implemented entirely with pyarrow compute kernels: slice off the
    suffix's code units, keeping the original value wherever the string
    does not actually end with `suffix`.

    NOTE(review): the middle of this method was folded out of the diff
    being reviewed; the `removed`/`result` lines are reconstructed from
    the visible `ends_with` and `return` lines — confirm against the
    merged source.
    """
    ends_with = pc.ends_with(self._pa_array, pattern=suffix)
    removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
    result = pc.if_else(ends_with, removed, self._pa_array)
    return type(self)(result)

def _str_casefold(self):
    """
    Apply ``str.casefold`` (aggressive Unicode case-folding) elementwise.

    Nulls are propagated; chunking structure is preserved. Elementwise
    Python fallback — pyarrow has no casefold kernel.
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None if val.as_py() is None else val.as_py().casefold()
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_encode(self, encoding: str, errors: str = "strict"):
    """
    Encode each string element to bytes with the given codec.

    Parameters
    ----------
    encoding : str
        Codec name passed to ``str.encode``.
    errors : str, default "strict"
        Error-handling scheme passed to ``str.encode``.

    Nulls are propagated; chunking structure is preserved.
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None
                    if val.as_py() is None
                    else val.as_py().encode(encoding, errors)
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
    # Regex group extraction has no pyarrow-backed implementation yet,
    # so surface an explicit error to the caller.
    msg = "str.extract not supported with pd.ArrowDtype(pa.string())."
    raise NotImplementedError(msg)

def _str_findall(self, pat: str, flags: int = 0):
    """
    Return, per element, the list of all non-overlapping matches of `pat`.

    The pattern is compiled once (hoisted out of the loop) and applied
    with ``re.findall`` semantics; nulls are propagated and chunking
    structure is preserved.
    """
    regex = re.compile(pat, flags=flags)
    return type(self)(
        pa.chunked_array(
            [
                [
                    None if val.as_py() is None else regex.findall(val.as_py())
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_get_dummies(self, sep: str = "|"):
    """
    Split each element on `sep` and build an indicator matrix.

    Returns
    -------
    tuple
        ``(result, labels)`` where ``result`` is a boolean array with one
        column per distinct token (sorted) and ``labels`` is the sorted
        list of those tokens. Null rows become all-False rows.
    """
    split = pc.split_pattern(self._pa_array, sep).combine_chunks()
    uniques = split.flatten().unique()
    # Sort the distinct tokens so column order is deterministic.
    uniques_sorted = uniques.take(pa.compute.array_sort_indices(uniques))
    result_data = []
    for lst in split.to_pylist():
        if lst is None:
            result_data.append([False] * len(uniques_sorted))
        else:
            # Membership of each sorted token in this row's token set.
            res = pc.is_in(uniques_sorted, pa.array(set(lst)))
            result_data.append(res.to_pylist())
    result = type(self)(pa.array(result_data))
    return result, uniques_sorted.to_pylist()

def _str_index(self, sub: str, start: int = 0, end: int | None = None):
    """
    Return the lowest index of `sub` within each element (``str.index``).

    Raises ``ValueError`` (from ``str.index``) for any non-null element
    that does not contain `sub`. Nulls are propagated; chunking structure
    is preserved.
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None
                    if val.as_py() is None
                    else val.as_py().index(sub, start, end)
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_rindex(self, sub: str, start: int = 0, end: int | None = None):
    """
    Return the highest index of `sub` within each element (``str.rindex``).

    Raises ``ValueError`` (from ``str.rindex``) for any non-null element
    that does not contain `sub`. Nulls are propagated; chunking structure
    is preserved.
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None
                    if val.as_py() is None
                    else val.as_py().rindex(sub, start, end)
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_normalize(self, form: str):
    """
    Apply Unicode normalization `form` ("NFC", "NFKC", "NFD" or "NFKD")
    to each element via ``unicodedata.normalize``.

    Nulls are propagated; chunking structure is preserved.
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None
                    if val.as_py() is None
                    else unicodedata.normalize(form, val.as_py())
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_rfind(self, sub: str, start: int = 0, end=None):
    """
    Return the highest index of `sub` within each element, or -1 when
    `sub` is not found (``str.rfind`` semantics — no exception raised).

    Nulls are propagated; chunking structure is preserved.
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None
                    if val.as_py() is None
                    else val.as_py().rfind(sub, start, end)
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_split(
Expand All @@ -2046,14 +2158,32 @@ def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True)
)

def _str_translate(self, table: dict[int, str]):
    """
    Map characters of each element through `table` via ``str.translate``
    (a mapping of Unicode ordinals, e.g. from ``str.maketrans``).

    Nulls are propagated; chunking structure is preserved.
    """
    return type(self)(
        pa.chunked_array(
            [
                [
                    None if val.as_py() is None else val.as_py().translate(table)
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

def _str_wrap(self, width: int, **kwargs):
    """
    Wrap each element to lines of at most `width` characters, joined
    with newlines.

    Extra `kwargs` are forwarded to ``textwrap.TextWrapper``. The
    wrapper is constructed once and reused for every element. Nulls are
    propagated; chunking structure is preserved.
    """
    kwargs["width"] = width
    tw = textwrap.TextWrapper(**kwargs)
    return type(self)(
        pa.chunked_array(
            [
                [
                    None if val.as_py() is None else "\n".join(tw.wrap(val.as_py()))
                    for val in chunk
                ]
                for chunk in self._pa_array.iterchunks()
            ]
        )
    )

@property
Expand Down
10 changes: 7 additions & 3 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def _wrap_result(
if expand is None:
# infer from ndim if expand is not specified
expand = result.ndim != 1

elif expand is True and not isinstance(self._orig, ABCIndex):
# required when expand=True is explicitly specified
# not needed when inferred
Expand All @@ -280,10 +279,15 @@ def _wrap_result(
result._pa_array.combine_chunks().value_lengths()
).as_py()
if result.isna().any():
# ArrowExtensionArray.fillna doesn't work for list scalars
result._pa_array = result._pa_array.fill_null([None] * max_len)
if name is not None:
labels = name
else:
labels = range(max_len)
result = {
i: ArrowExtensionArray(pa.array(res))
for i, res in enumerate(zip(*result.tolist()))
label: ArrowExtensionArray(pa.array(res))
for label, res in zip(labels, (zip(*result.tolist())))
}
elif is_object_dtype(result):

Expand Down
Loading