Skip to content

TYP: nargsort #52768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5632,7 +5632,7 @@ def sort_values(
>>> idx.sort_values(ascending=False, return_indexer=True)
(Index([1000, 100, 10, 1], dtype='int64'), array([3, 1, 0, 2]))
"""
idx = ensure_key_mapped(self, key)
idx = cast(Index, ensure_key_mapped(self, key))

# GH 35584. Sort missing values according to na_position kwarg
# ignore na_position for MultiIndex
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,7 +2171,11 @@ def argsort(
# lexsort is significantly faster than self._values.argsort()
target = self._sort_levels_monotonic(raise_if_incomparable=True)
return lexsort_indexer(
target._get_codes_for_sorting(), na_position=na_position
# error: Argument 1 to "lexsort_indexer" has incompatible type
# "List[Categorical]"; expected "Union[List[Union[ExtensionArray,
# ndarray[Any, Any]]], List[Series]]"
target._get_codes_for_sorting(), # type: ignore[arg-type]
na_position=na_position,
)
return self._values.argsort(*args, **kwargs)

Expand Down
5 changes: 4 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3535,7 +3535,10 @@ def sort_values(
raise ValueError(f"invalid na_position: {na_position}")

# GH 35922. Make sorting stable by leveraging nargsort
values_to_sort = ensure_key_mapped(self, key)._values if key else self._values
if key:
values_to_sort = cast(Series, ensure_key_mapped(self, key))._values
else:
values_to_sort = self._values
sorted_index = nargsort(values_to_sort, kind, bool(ascending), na_position)

if is_range_indexer(sorted_index, len(sorted_index)):
Expand Down
95 changes: 75 additions & 20 deletions pandas/core/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from pandas.core.dtypes.common import (
ensure_int64,
ensure_platform_int,
is_extension_array_dtype,
)
from pandas.core.dtypes.generic import (
ABCMultiIndex,
Expand All @@ -36,6 +35,7 @@

if TYPE_CHECKING:
from pandas._typing import (
ArrayLike,
AxisInt,
IndexKeyFunc,
Level,
Expand All @@ -45,7 +45,10 @@
npt,
)

from pandas import MultiIndex
from pandas import (
MultiIndex,
Series,
)
from pandas.core.arrays import ExtensionArray
from pandas.core.indexes.base import Index

Expand Down Expand Up @@ -79,7 +82,10 @@ def get_indexer_indexer(
The indexer for the new index.
"""

target = ensure_key_mapped(target, key, levels=level)
# error: Incompatible types in assignment (expression has type
# "Union[ExtensionArray, ndarray[Any, Any], Index, Series]", variable has
# type "Index")
target = ensure_key_mapped(target, key, levels=level) # type:ignore[assignment]
target = target._sort_levels_monotonic()

if level is not None:
Expand Down Expand Up @@ -304,7 +310,7 @@ def indexer_from_factorized(


def lexsort_indexer(
keys,
keys: list[ArrayLike] | list[Series],
orders=None,
na_position: str = "last",
key: Callable | None = None,
Expand All @@ -315,8 +321,9 @@ def lexsort_indexer(

Parameters
----------
keys : sequence of arrays
keys : list[ArrayLike] | list[Series]
Sequence of ndarrays to be sorted by the indexer
list[Series] is only if key is not None.
orders : bool or list of booleans, optional
Determines the sorting order for each element in keys. If a list,
it must be the same length as keys. This determines whether the
Expand All @@ -343,7 +350,10 @@ def lexsort_indexer(
elif orders is None:
orders = [True] * len(keys)

keys = [ensure_key_mapped(k, key) for k in keys]
# error: Incompatible types in assignment (expression has type
# "List[Union[ExtensionArray, ndarray[Any, Any], Index, Series]]", variable
# has type "Union[List[Union[ExtensionArray, ndarray[Any, Any]]], List[Series]]")
keys = [ensure_key_mapped(k, key) for k in keys] # type: ignore[assignment]

for k, order in zip(keys, orders):
if na_position not in ["last", "first"]:
Expand All @@ -354,7 +364,9 @@ def lexsort_indexer(
codes = k.copy()
n = len(codes)
mask_n = n
if mask.any():
# error: Item "ExtensionArray" of "Union[Any, ExtensionArray,
# ndarray[Any, Any]]" has no attribute "any"
if mask.any(): # type: ignore[union-attr]
n -= 1

else:
Expand All @@ -369,14 +381,40 @@ def lexsort_indexer(

if order: # ascending
if na_position == "last":
codes = np.where(mask, n, codes)
# error: Argument 1 to "where" has incompatible type "Union[Any,
# ExtensionArray, ndarray[Any, Any]]"; expected
# "Union[_SupportsArray[dtype[Any]],
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
# complex, str, bytes]]]"
codes = np.where(mask, n, codes) # type: ignore[arg-type]
elif na_position == "first":
codes += 1
# error: Incompatible types in assignment (expression has type
# "Union[Any, int, ndarray[Any, dtype[signedinteger[Any]]]]",
# variable has type "Union[Series, ExtensionArray, ndarray[Any, Any]]")
# error: Unsupported operand types for + ("ExtensionArray" and "int")
codes += 1 # type: ignore[operator,assignment]
else: # not order means descending
if na_position == "last":
codes = np.where(mask, n, n - codes - 1)
# error: Unsupported operand types for - ("int" and "ExtensionArray")
# error: Argument 1 to "where" has incompatible type "Union[Any,
# ExtensionArray, ndarray[Any, Any]]"; expected
# "Union[_SupportsArray[dtype[Any]],
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
# complex, str, bytes]]]"
codes = np.where(
mask, n, n - codes - 1 # type: ignore[operator,arg-type]
)
elif na_position == "first":
codes = np.where(mask, 0, n - codes)
# error: Unsupported operand types for - ("int" and "ExtensionArray")
# error: Argument 1 to "where" has incompatible type "Union[Any,
# ExtensionArray, ndarray[Any, Any]]"; expected
# "Union[_SupportsArray[dtype[Any]],
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
# complex, str, bytes]]]"
codes = np.where(mask, 0, n - codes) # type: ignore[operator,arg-type]

shape.append(mask_n)
labels.append(codes)
Expand All @@ -385,7 +423,7 @@ def lexsort_indexer(


def nargsort(
items,
items: ArrayLike | Index | Series,
kind: str = "quicksort",
ascending: bool = True,
na_position: str = "last",
Expand All @@ -401,6 +439,7 @@ def nargsort(

Parameters
----------
items : np.ndarray, ExtensionArray, Index, or Series
kind : str, default 'quicksort'
ascending : bool, default True
na_position : {'first', 'last'}, default 'last'
Expand All @@ -414,6 +453,7 @@ def nargsort(
"""

if key is not None:
# see TestDataFrameSortKey, TestRangeIndex::test_sort_values_key
items = ensure_key_mapped(items, key)
return nargsort(
items,
Expand All @@ -425,16 +465,27 @@ def nargsort(
)

if isinstance(items, ABCRangeIndex):
return items.argsort(ascending=ascending) # TODO: test coverage with key?
return items.argsort(ascending=ascending)
elif not isinstance(items, ABCMultiIndex):
items = extract_array(items)
else:
raise TypeError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a test that hits this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't public and the place in Index that calls nargsort specifically goes for a separate path for MultiIndex

"nargsort does not support MultiIndex. Use index.sort_values instead."
)

if mask is None:
mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?
mask = np.asarray(isna(items))

if is_extension_array_dtype(items):
return items.argsort(ascending=ascending, kind=kind, na_position=na_position)
else:
items = np.asanyarray(items)
if not isinstance(items, np.ndarray):
# i.e. ExtensionArray
return items.argsort(
ascending=ascending,
# error: Argument "kind" to "argsort" of "ExtensionArray" has
# incompatible type "str"; expected "Literal['quicksort',
# 'mergesort', 'heapsort', 'stable']"
kind=kind, # type: ignore[arg-type]
na_position=na_position,
)

idx = np.arange(len(items))
non_nans = items[~mask]
Expand Down Expand Up @@ -551,7 +602,9 @@ def _ensure_key_mapped_multiindex(
return type(index).from_arrays(mapped)


def ensure_key_mapped(values, key: Callable | None, levels=None):
def ensure_key_mapped(
values: ArrayLike | Index | Series, key: Callable | None, levels=None
) -> ArrayLike | Index | Series:
"""
Applies a callable key function to the values function and checks
that the resulting value has the same shape. Can be called on Index
Expand Down Expand Up @@ -584,8 +637,10 @@ def ensure_key_mapped(values, key: Callable | None, levels=None):
): # convert to a new Index subclass, not necessarily the same
result = Index(result)
else:
# try to revert to original type otherwise
type_of_values = type(values)
result = type_of_values(result) # try to revert to original type otherwise
# error: Too many arguments for "ExtensionArray"
result = type_of_values(result) # type: ignore[call-arg]
except TypeError:
raise TypeError(
f"User-provided `key` function returned an invalid type {type(result)} \
Expand Down
34 changes: 3 additions & 31 deletions pandas/tests/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,61 +156,33 @@ def test_lexsort_indexer(self, order, na_position, exp):
tm.assert_numpy_array_equal(result, np.array(exp, dtype=np.intp))

@pytest.mark.parametrize(
"ascending, na_position, exp, box",
"ascending, na_position, exp",
[
[
True,
"last",
list(range(5, 105)) + list(range(5)) + list(range(105, 110)),
list,
],
[
True,
"first",
list(range(5)) + list(range(105, 110)) + list(range(5, 105)),
list,
],
[
False,
"last",
list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)),
list,
],
[
False,
"first",
list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)),
list,
],
[
True,
"last",
list(range(5, 105)) + list(range(5)) + list(range(105, 110)),
lambda x: np.array(x, dtype="O"),
],
[
True,
"first",
list(range(5)) + list(range(105, 110)) + list(range(5, 105)),
lambda x: np.array(x, dtype="O"),
],
[
False,
"last",
list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)),
lambda x: np.array(x, dtype="O"),
],
[
False,
"first",
list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)),
lambda x: np.array(x, dtype="O"),
],
],
)
def test_nargsort(self, ascending, na_position, exp, box):
def test_nargsort(self, ascending, na_position, exp):
# list places NaNs last, np.array(..., dtype="O") may not place NaNs first
items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)
items = np.array([np.nan] * 5 + list(range(100)) + [np.nan] * 5, dtype="O")

# mergesort is the most difficult to get right because we want it to be
# stable.
Expand Down