-
-
Notifications
You must be signed in to change notification settings - Fork 18.6k
REF/PERF: ArrowStringArray.__setitem__ #46400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
e379a22
ArrowStringArray.__setitem__
lukemanley e21c4ff
Merge remote-tracking branch 'upstream/main' into arrowstringarray-se…
lukemanley 0e35f6a
fixes
lukemanley f292054
whatsnew
lukemanley 773f375
fix test
lukemanley f44bcbb
refactor
lukemanley 76a25a9
fix docstring
lukemanley File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ | |
TYPE_CHECKING, | ||
Any, | ||
Union, | ||
cast, | ||
overload, | ||
) | ||
|
||
|
@@ -31,6 +30,7 @@ | |
pa_version_under2p0, | ||
pa_version_under3p0, | ||
pa_version_under4p0, | ||
pa_version_under5p0, | ||
) | ||
from pandas.util._decorators import doc | ||
|
||
|
@@ -40,6 +40,7 @@ | |
is_dtype_equal, | ||
is_integer, | ||
is_integer_dtype, | ||
is_list_like, | ||
is_object_dtype, | ||
is_scalar, | ||
is_string_dtype, | ||
|
@@ -363,48 +364,139 @@ def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None: | |
""" | ||
key = check_array_indexer(self, key) | ||
|
||
if is_integer(key): | ||
key = cast(int, key) | ||
if is_list_like(key): | ||
key = np.asarray(key) | ||
if len(key) == 1: | ||
key = key[0] | ||
|
||
if not is_scalar(value): | ||
raise ValueError("Must pass scalars with scalar indexer") | ||
elif isna(value): | ||
value_is_scalar = is_scalar(value) | ||
|
||
# NA -> None | ||
if value_is_scalar: | ||
if isna(value): | ||
value = None | ||
elif not isinstance(value, str): | ||
raise ValueError("Scalar must be NA or str") | ||
|
||
# Slice data and insert in-between | ||
new_data = [ | ||
*self._data[0:key].chunks, | ||
pa.array([value], type=pa.string()), | ||
*self._data[(key + 1) :].chunks, | ||
] | ||
self._data = pa.chunked_array(new_data) | ||
else: | ||
# Convert to integer indices and iteratively assign. | ||
# TODO: Make a faster variant of this in Arrow upstream. | ||
# This is probably extremely slow. | ||
|
||
# Convert all possible input key types to an array of integers | ||
if isinstance(key, slice): | ||
key_array = np.array(range(len(self))[key]) | ||
elif is_bool_dtype(key): | ||
# TODO(ARROW-9430): Directly support setitem(booleans) | ||
key_array = np.argwhere(key).flatten() | ||
else: | ||
# TODO(ARROW-9431): Directly support setitem(integers) | ||
key_array = np.asanyarray(key) | ||
value = np.asarray(value, dtype=object) | ||
for i, v in enumerate(value): | ||
if isna(v): | ||
value[i] = None | ||
elif not isinstance(v, str): | ||
raise ValueError("Scalar must be NA or str") | ||
|
||
# reorder values to align with the mask positions | ||
if is_bool_dtype(key): | ||
pass | ||
elif isinstance(key, slice): | ||
if not value_is_scalar and key.step is not None and key.step < 0: | ||
value = value[::-1] | ||
else: | ||
if not value_is_scalar: | ||
if is_scalar(key): | ||
raise ValueError("Length of indexer and values mismatch") | ||
key = np.asarray(key) | ||
if len(key) != len(value): | ||
raise ValueError("Length of indexer and values mismatch") | ||
|
||
if np.any(key < -len(self)): | ||
min_key = np.asarray(key).min() | ||
raise IndexError( | ||
f"index {min_key} is out of bounds for array of length {len(self)}" | ||
) | ||
if np.any(key >= len(self)): | ||
max_key = np.asarray(key).max() | ||
raise IndexError( | ||
f"index {max_key} is out of bounds for array of length {len(self)}" | ||
) | ||
|
||
if is_scalar(value): | ||
value = np.broadcast_to(value, len(key_array)) | ||
# convert negative indices to positive before sorting | ||
if is_integer(key): | ||
if key < 0: | ||
key += len(self) | ||
else: | ||
value = np.asarray(value) | ||
key = np.asarray(key) | ||
key[key < 0] += len(self) | ||
if not value_is_scalar: | ||
value = value[np.argsort(key)] | ||
|
||
# fast path | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same thing here — would do something along the lines of
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done |
||
if is_integer(key) and value_is_scalar and self._data.num_chunks == 1: | ||
idx = int(key) # type: ignore[arg-type] | ||
chunk = pa.concat_arrays( | ||
[ | ||
self._data.chunks[0][:idx], | ||
pa.array([value], type=pa.string()), | ||
self._data.chunks[0][idx + 1 :], | ||
] | ||
) | ||
self._data = pa.chunked_array([chunk]) | ||
return | ||
|
||
if len(key_array) != len(value): | ||
# create mask for positions to set | ||
mask: npt.NDArray[np.bool_] | ||
if is_bool_dtype(key): | ||
mask = key # type: ignore[assignment] | ||
else: | ||
mask = np.zeros(len(self), dtype=np.bool_) | ||
mask[key] = True | ||
|
||
if not value_is_scalar: | ||
if len(value) != np.sum(mask): | ||
raise ValueError("Length of indexer and values mismatch") | ||
|
||
for k, v in zip(key_array, value): | ||
self[k] = v | ||
indices = mask.nonzero()[0] | ||
|
||
# loop through the array chunks and set the new values while | ||
# leaving the chunking layout unchanged | ||
start = stop = 0 | ||
new_data = [] | ||
|
||
for chunk in self._data.iterchunks(): | ||
start, stop = stop, stop + len(chunk) | ||
|
||
if len(indices) == 0 or indices[0] >= stop: | ||
new_data.append(chunk) | ||
continue | ||
|
||
n = int(np.searchsorted(indices, stop, side="left")) | ||
c_indices, indices = indices[:n], indices[n:] | ||
|
||
if value_is_scalar: | ||
c_value = value | ||
else: | ||
c_value, value = value[:n], value[n:] | ||
|
||
if n == 1: | ||
# fast path | ||
idx = c_indices[0] - start | ||
v = [c_value] if value_is_scalar else c_value | ||
chunk = pa.concat_arrays( | ||
[ | ||
chunk[:idx], | ||
pa.array(v, type=pa.string()), | ||
chunk[idx + 1 :], | ||
] | ||
) | ||
|
||
elif n > 0: | ||
submask = mask[start:stop] | ||
if not pa_version_under5p0: | ||
if c_value is None or isna(np.array(c_value)).all(): | ||
chunk = pc.if_else(submask, None, chunk) | ||
else: | ||
chunk = pc.replace_with_mask(chunk, submask, c_value) | ||
else: | ||
# The pyarrow compute functions were added in | ||
# version 5.0. For prior versions we implement | ||
# our own by converting to numpy and back. | ||
chunk = chunk.to_numpy(zero_copy_only=False) | ||
chunk[submask] = c_value | ||
chunk = pa.array(chunk, type=pa.string()) | ||
|
||
new_data.append(chunk) | ||
|
||
self._data = pa.chunked_array(new_data) | ||
|
||
def take( | ||
self, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i would create a helper method like
_validate_key()
to encapsulate all of this (ok on this class for now, but we likely want to push this to the ArrowExtensionArray (or maybe we need an ArrowIndexingMixin or similar); that can be later (or here if convenient). There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I refactored pretty extensively, this logic is now self contained