Skip to content

Commit d25d768

Browse files
authored
TYP: nargsort (#52768)
* TYP: nargsort * mypy fixup * mypy fixup
1 parent 26de031 commit d25d768

File tree

5 files changed

+88
-54
lines changed

5 files changed

+88
-54
lines changed

pandas/core/indexes/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5633,7 +5633,7 @@ def sort_values(
56335633
>>> idx.sort_values(ascending=False, return_indexer=True)
56345634
(Index([1000, 100, 10, 1], dtype='int64'), array([3, 1, 0, 2]))
56355635
"""
5636-
idx = ensure_key_mapped(self, key)
5636+
idx = cast(Index, ensure_key_mapped(self, key))
56375637

56385638
# GH 35584. Sort missing values according to na_position kwarg
56395639
# ignore na_position for MultiIndex

pandas/core/indexes/multi.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2171,7 +2171,11 @@ def argsort(
21712171
# lexsort is significantly faster than self._values.argsort()
21722172
target = self._sort_levels_monotonic(raise_if_incomparable=True)
21732173
return lexsort_indexer(
2174-
target._get_codes_for_sorting(), na_position=na_position
2174+
# error: Argument 1 to "lexsort_indexer" has incompatible type
2175+
# "List[Categorical]"; expected "Union[List[Union[ExtensionArray,
2176+
# ndarray[Any, Any]]], List[Series]]"
2177+
target._get_codes_for_sorting(), # type: ignore[arg-type]
2178+
na_position=na_position,
21752179
)
21762180
return self._values.argsort(*args, **kwargs)
21772181

pandas/core/series.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3535,7 +3535,10 @@ def sort_values(
35353535
raise ValueError(f"invalid na_position: {na_position}")
35363536

35373537
# GH 35922. Make sorting stable by leveraging nargsort
3538-
values_to_sort = ensure_key_mapped(self, key)._values if key else self._values
3538+
if key:
3539+
values_to_sort = cast(Series, ensure_key_mapped(self, key))._values
3540+
else:
3541+
values_to_sort = self._values
35393542
sorted_index = nargsort(values_to_sort, kind, bool(ascending), na_position)
35403543

35413544
if is_range_indexer(sorted_index, len(sorted_index)):

pandas/core/sorting.py

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from pandas.core.dtypes.common import (
2525
ensure_int64,
2626
ensure_platform_int,
27-
is_extension_array_dtype,
2827
)
2928
from pandas.core.dtypes.generic import (
3029
ABCMultiIndex,
@@ -36,6 +35,7 @@
3635

3736
if TYPE_CHECKING:
3837
from pandas._typing import (
38+
ArrayLike,
3939
AxisInt,
4040
IndexKeyFunc,
4141
Level,
@@ -45,7 +45,10 @@
4545
npt,
4646
)
4747

48-
from pandas import MultiIndex
48+
from pandas import (
49+
MultiIndex,
50+
Series,
51+
)
4952
from pandas.core.arrays import ExtensionArray
5053
from pandas.core.indexes.base import Index
5154

@@ -79,7 +82,10 @@ def get_indexer_indexer(
7982
The indexer for the new index.
8083
"""
8184

82-
target = ensure_key_mapped(target, key, levels=level)
85+
# error: Incompatible types in assignment (expression has type
86+
# "Union[ExtensionArray, ndarray[Any, Any], Index, Series]", variable has
87+
# type "Index")
88+
target = ensure_key_mapped(target, key, levels=level) # type:ignore[assignment]
8389
target = target._sort_levels_monotonic()
8490

8591
if level is not None:
@@ -304,7 +310,7 @@ def indexer_from_factorized(
304310

305311

306312
def lexsort_indexer(
307-
keys,
313+
keys: list[ArrayLike] | list[Series],
308314
orders=None,
309315
na_position: str = "last",
310316
key: Callable | None = None,
@@ -315,8 +321,9 @@ def lexsort_indexer(
315321
316322
Parameters
317323
----------
318-
keys : sequence of arrays
324+
keys : list[ArrayLike] | list[Series]
319325
Sequence of ndarrays to be sorted by the indexer
326+
list[Series] is only if key is not None.
320327
orders : bool or list of booleans, optional
321328
Determines the sorting order for each element in keys. If a list,
322329
it must be the same length as keys. This determines whether the
@@ -343,7 +350,10 @@ def lexsort_indexer(
343350
elif orders is None:
344351
orders = [True] * len(keys)
345352

346-
keys = [ensure_key_mapped(k, key) for k in keys]
353+
# error: Incompatible types in assignment (expression has type
354+
# "List[Union[ExtensionArray, ndarray[Any, Any], Index, Series]]", variable
355+
# has type "Union[List[Union[ExtensionArray, ndarray[Any, Any]]], List[Series]]")
356+
keys = [ensure_key_mapped(k, key) for k in keys] # type: ignore[assignment]
347357

348358
for k, order in zip(keys, orders):
349359
if na_position not in ["last", "first"]:
@@ -354,7 +364,9 @@ def lexsort_indexer(
354364
codes = k.copy()
355365
n = len(codes)
356366
mask_n = n
357-
if mask.any():
367+
# error: Item "ExtensionArray" of "Union[Any, ExtensionArray,
368+
# ndarray[Any, Any]]" has no attribute "any"
369+
if mask.any(): # type: ignore[union-attr]
358370
n -= 1
359371

360372
else:
@@ -369,14 +381,40 @@ def lexsort_indexer(
369381

370382
if order: # ascending
371383
if na_position == "last":
372-
codes = np.where(mask, n, codes)
384+
# error: Argument 1 to "where" has incompatible type "Union[Any,
385+
# ExtensionArray, ndarray[Any, Any]]"; expected
386+
# "Union[_SupportsArray[dtype[Any]],
387+
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
388+
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
389+
# complex, str, bytes]]]"
390+
codes = np.where(mask, n, codes) # type: ignore[arg-type]
373391
elif na_position == "first":
374-
codes += 1
392+
# error: Incompatible types in assignment (expression has type
393+
# "Union[Any, int, ndarray[Any, dtype[signedinteger[Any]]]]",
394+
# variable has type "Union[Series, ExtensionArray, ndarray[Any, Any]]")
395+
# error: Unsupported operand types for + ("ExtensionArray" and "int")
396+
codes += 1 # type: ignore[operator,assignment]
375397
else: # not order means descending
376398
if na_position == "last":
377-
codes = np.where(mask, n, n - codes - 1)
399+
# error: Unsupported operand types for - ("int" and "ExtensionArray")
400+
# error: Argument 1 to "where" has incompatible type "Union[Any,
401+
# ExtensionArray, ndarray[Any, Any]]"; expected
402+
# "Union[_SupportsArray[dtype[Any]],
403+
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
404+
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
405+
# complex, str, bytes]]]"
406+
codes = np.where(
407+
mask, n, n - codes - 1 # type: ignore[operator,arg-type]
408+
)
378409
elif na_position == "first":
379-
codes = np.where(mask, 0, n - codes)
410+
# error: Unsupported operand types for - ("int" and "ExtensionArray")
411+
# error: Argument 1 to "where" has incompatible type "Union[Any,
412+
# ExtensionArray, ndarray[Any, Any]]"; expected
413+
# "Union[_SupportsArray[dtype[Any]],
414+
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
415+
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
416+
# complex, str, bytes]]]"
417+
codes = np.where(mask, 0, n - codes) # type: ignore[operator,arg-type]
380418

381419
shape.append(mask_n)
382420
labels.append(codes)
@@ -385,7 +423,7 @@ def lexsort_indexer(
385423

386424

387425
def nargsort(
388-
items,
426+
items: ArrayLike | Index | Series,
389427
kind: str = "quicksort",
390428
ascending: bool = True,
391429
na_position: str = "last",
@@ -401,6 +439,7 @@ def nargsort(
401439
402440
Parameters
403441
----------
442+
items : np.ndarray, ExtensionArray, Index, or Series
404443
kind : str, default 'quicksort'
405444
ascending : bool, default True
406445
na_position : {'first', 'last'}, default 'last'
@@ -414,6 +453,7 @@ def nargsort(
414453
"""
415454

416455
if key is not None:
456+
# see TestDataFrameSortKey, TestRangeIndex::test_sort_values_key
417457
items = ensure_key_mapped(items, key)
418458
return nargsort(
419459
items,
@@ -425,16 +465,27 @@ def nargsort(
425465
)
426466

427467
if isinstance(items, ABCRangeIndex):
428-
return items.argsort(ascending=ascending) # TODO: test coverage with key?
468+
return items.argsort(ascending=ascending)
429469
elif not isinstance(items, ABCMultiIndex):
430470
items = extract_array(items)
471+
else:
472+
raise TypeError(
473+
"nargsort does not support MultiIndex. Use index.sort_values instead."
474+
)
475+
431476
if mask is None:
432-
mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?
477+
mask = np.asarray(isna(items))
433478

434-
if is_extension_array_dtype(items):
435-
return items.argsort(ascending=ascending, kind=kind, na_position=na_position)
436-
else:
437-
items = np.asanyarray(items)
479+
if not isinstance(items, np.ndarray):
480+
# i.e. ExtensionArray
481+
return items.argsort(
482+
ascending=ascending,
483+
# error: Argument "kind" to "argsort" of "ExtensionArray" has
484+
# incompatible type "str"; expected "Literal['quicksort',
485+
# 'mergesort', 'heapsort', 'stable']"
486+
kind=kind, # type: ignore[arg-type]
487+
na_position=na_position,
488+
)
438489

439490
idx = np.arange(len(items))
440491
non_nans = items[~mask]
@@ -551,7 +602,9 @@ def _ensure_key_mapped_multiindex(
551602
return type(index).from_arrays(mapped)
552603

553604

554-
def ensure_key_mapped(values, key: Callable | None, levels=None):
605+
def ensure_key_mapped(
606+
values: ArrayLike | Index | Series, key: Callable | None, levels=None
607+
) -> ArrayLike | Index | Series:
555608
"""
556609
Applies a callable key function to the values function and checks
557610
that the resulting value has the same shape. Can be called on Index
@@ -584,8 +637,10 @@ def ensure_key_mapped(values, key: Callable | None, levels=None):
584637
): # convert to a new Index subclass, not necessarily the same
585638
result = Index(result)
586639
else:
640+
# try to revert to original type otherwise
587641
type_of_values = type(values)
588-
result = type_of_values(result) # try to revert to original type otherwise
642+
# error: Too many arguments for "ExtensionArray"
643+
result = type_of_values(result) # type: ignore[call-arg]
589644
except TypeError:
590645
raise TypeError(
591646
f"User-provided `key` function returned an invalid type {type(result)} \

pandas/tests/test_sorting.py

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -156,61 +156,33 @@ def test_lexsort_indexer(self, order, na_position, exp):
156156
tm.assert_numpy_array_equal(result, np.array(exp, dtype=np.intp))
157157

158158
@pytest.mark.parametrize(
159-
"ascending, na_position, exp, box",
159+
"ascending, na_position, exp",
160160
[
161161
[
162162
True,
163163
"last",
164164
list(range(5, 105)) + list(range(5)) + list(range(105, 110)),
165-
list,
166165
],
167166
[
168167
True,
169168
"first",
170169
list(range(5)) + list(range(105, 110)) + list(range(5, 105)),
171-
list,
172170
],
173171
[
174172
False,
175173
"last",
176174
list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)),
177-
list,
178175
],
179176
[
180177
False,
181178
"first",
182179
list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)),
183-
list,
184-
],
185-
[
186-
True,
187-
"last",
188-
list(range(5, 105)) + list(range(5)) + list(range(105, 110)),
189-
lambda x: np.array(x, dtype="O"),
190-
],
191-
[
192-
True,
193-
"first",
194-
list(range(5)) + list(range(105, 110)) + list(range(5, 105)),
195-
lambda x: np.array(x, dtype="O"),
196-
],
197-
[
198-
False,
199-
"last",
200-
list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)),
201-
lambda x: np.array(x, dtype="O"),
202-
],
203-
[
204-
False,
205-
"first",
206-
list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)),
207-
lambda x: np.array(x, dtype="O"),
208180
],
209181
],
210182
)
211-
def test_nargsort(self, ascending, na_position, exp, box):
183+
def test_nargsort(self, ascending, na_position, exp):
212184
# list places NaNs last, np.array(..., dtype="O") may not place NaNs first
213-
items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)
185+
items = np.array([np.nan] * 5 + list(range(100)) + [np.nan] * 5, dtype="O")
214186

215187
# mergesort is the most difficult to get right because we want it to be
216188
# stable.

0 commit comments

Comments
 (0)