24
24
from pandas .core .dtypes .common import (
25
25
ensure_int64 ,
26
26
ensure_platform_int ,
27
- is_extension_array_dtype ,
28
27
)
29
28
from pandas .core .dtypes .generic import (
30
29
ABCMultiIndex ,
36
35
37
36
if TYPE_CHECKING :
38
37
from pandas ._typing import (
38
+ ArrayLike ,
39
39
AxisInt ,
40
40
IndexKeyFunc ,
41
41
Level ,
45
45
npt ,
46
46
)
47
47
48
- from pandas import MultiIndex
48
+ from pandas import (
49
+ MultiIndex ,
50
+ Series ,
51
+ )
49
52
from pandas .core .arrays import ExtensionArray
50
53
from pandas .core .indexes .base import Index
51
54
@@ -79,7 +82,10 @@ def get_indexer_indexer(
79
82
The indexer for the new index.
80
83
"""
81
84
82
- target = ensure_key_mapped (target , key , levels = level )
85
+ # error: Incompatible types in assignment (expression has type
86
+ # "Union[ExtensionArray, ndarray[Any, Any], Index, Series]", variable has
87
+ # type "Index")
88
+ target = ensure_key_mapped (target , key , levels = level ) # type:ignore[assignment]
83
89
target = target ._sort_levels_monotonic ()
84
90
85
91
if level is not None :
@@ -304,7 +310,7 @@ def indexer_from_factorized(
304
310
305
311
306
312
def lexsort_indexer (
307
- keys ,
313
+ keys : list [ ArrayLike ] | list [ Series ] ,
308
314
orders = None ,
309
315
na_position : str = "last" ,
310
316
key : Callable | None = None ,
@@ -315,8 +321,9 @@ def lexsort_indexer(
315
321
316
322
Parameters
317
323
----------
318
- keys : sequence of arrays
324
+ keys : list[ArrayLike] | list[Series]
319
325
Sequence of ndarrays to be sorted by the indexer
326
+ list[Series] is only if key is not None.
320
327
orders : bool or list of booleans, optional
321
328
Determines the sorting order for each element in keys. If a list,
322
329
it must be the same length as keys. This determines whether the
@@ -343,7 +350,10 @@ def lexsort_indexer(
343
350
elif orders is None :
344
351
orders = [True ] * len (keys )
345
352
346
- keys = [ensure_key_mapped (k , key ) for k in keys ]
353
+ # error: Incompatible types in assignment (expression has type
354
+ # "List[Union[ExtensionArray, ndarray[Any, Any], Index, Series]]", variable
355
+ # has type "Union[List[Union[ExtensionArray, ndarray[Any, Any]]], List[Series]]")
356
+ keys = [ensure_key_mapped (k , key ) for k in keys ] # type: ignore[assignment]
347
357
348
358
for k , order in zip (keys , orders ):
349
359
if na_position not in ["last" , "first" ]:
@@ -354,7 +364,9 @@ def lexsort_indexer(
354
364
codes = k .copy ()
355
365
n = len (codes )
356
366
mask_n = n
357
- if mask .any ():
367
+ # error: Item "ExtensionArray" of "Union[Any, ExtensionArray,
368
+ # ndarray[Any, Any]]" has no attribute "any"
369
+ if mask .any (): # type: ignore[union-attr]
358
370
n -= 1
359
371
360
372
else :
@@ -369,14 +381,40 @@ def lexsort_indexer(
369
381
370
382
if order : # ascending
371
383
if na_position == "last" :
372
- codes = np .where (mask , n , codes )
384
+ # error: Argument 1 to "where" has incompatible type "Union[Any,
385
+ # ExtensionArray, ndarray[Any, Any]]"; expected
386
+ # "Union[_SupportsArray[dtype[Any]],
387
+ # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
388
+ # complex, str, bytes, _NestedSequence[Union[bool, int, float,
389
+ # complex, str, bytes]]]"
390
+ codes = np .where (mask , n , codes ) # type: ignore[arg-type]
373
391
elif na_position == "first" :
374
- codes += 1
392
+ # error: Incompatible types in assignment (expression has type
393
+ # "Union[Any, int, ndarray[Any, dtype[signedinteger[Any]]]]",
394
+ # variable has type "Union[Series, ExtensionArray, ndarray[Any, Any]]")
395
+ # error: Unsupported operand types for + ("ExtensionArray" and "int")
396
+ codes += 1 # type: ignore[operator,assignment]
375
397
else : # not order means descending
376
398
if na_position == "last" :
377
- codes = np .where (mask , n , n - codes - 1 )
399
+ # error: Unsupported operand types for - ("int" and "ExtensionArray")
400
+ # error: Argument 1 to "where" has incompatible type "Union[Any,
401
+ # ExtensionArray, ndarray[Any, Any]]"; expected
402
+ # "Union[_SupportsArray[dtype[Any]],
403
+ # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
404
+ # complex, str, bytes, _NestedSequence[Union[bool, int, float,
405
+ # complex, str, bytes]]]"
406
+ codes = np .where (
407
+ mask , n , n - codes - 1 # type: ignore[operator,arg-type]
408
+ )
378
409
elif na_position == "first" :
379
- codes = np .where (mask , 0 , n - codes )
410
+ # error: Unsupported operand types for - ("int" and "ExtensionArray")
411
+ # error: Argument 1 to "where" has incompatible type "Union[Any,
412
+ # ExtensionArray, ndarray[Any, Any]]"; expected
413
+ # "Union[_SupportsArray[dtype[Any]],
414
+ # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
415
+ # complex, str, bytes, _NestedSequence[Union[bool, int, float,
416
+ # complex, str, bytes]]]"
417
+ codes = np .where (mask , 0 , n - codes ) # type: ignore[operator,arg-type]
380
418
381
419
shape .append (mask_n )
382
420
labels .append (codes )
@@ -385,7 +423,7 @@ def lexsort_indexer(
385
423
386
424
387
425
def nargsort (
388
- items ,
426
+ items : ArrayLike | Index | Series ,
389
427
kind : str = "quicksort" ,
390
428
ascending : bool = True ,
391
429
na_position : str = "last" ,
@@ -401,6 +439,7 @@ def nargsort(
401
439
402
440
Parameters
403
441
----------
442
+ items : np.ndarray, ExtensionArray, Index, or Series
404
443
kind : str, default 'quicksort'
405
444
ascending : bool, default True
406
445
na_position : {'first', 'last'}, default 'last'
@@ -414,6 +453,7 @@ def nargsort(
414
453
"""
415
454
416
455
if key is not None :
456
+ # see TestDataFrameSortKey, TestRangeIndex::test_sort_values_key
417
457
items = ensure_key_mapped (items , key )
418
458
return nargsort (
419
459
items ,
@@ -425,16 +465,27 @@ def nargsort(
425
465
)
426
466
427
467
if isinstance (items , ABCRangeIndex ):
428
- return items .argsort (ascending = ascending ) # TODO: test coverage with key?
468
+ return items .argsort (ascending = ascending )
429
469
elif not isinstance (items , ABCMultiIndex ):
430
470
items = extract_array (items )
471
+ else :
472
+ raise TypeError (
473
+ "nargsort does not support MultiIndex. Use index.sort_values instead."
474
+ )
475
+
431
476
if mask is None :
432
- mask = np .asarray (isna (items )) # TODO: does this exclude MultiIndex too?
477
+ mask = np .asarray (isna (items ))
433
478
434
- if is_extension_array_dtype (items ):
435
- return items .argsort (ascending = ascending , kind = kind , na_position = na_position )
436
- else :
437
- items = np .asanyarray (items )
479
+ if not isinstance (items , np .ndarray ):
480
+ # i.e. ExtensionArray
481
+ return items .argsort (
482
+ ascending = ascending ,
483
+ # error: Argument "kind" to "argsort" of "ExtensionArray" has
484
+ # incompatible type "str"; expected "Literal['quicksort',
485
+ # 'mergesort', 'heapsort', 'stable']"
486
+ kind = kind , # type: ignore[arg-type]
487
+ na_position = na_position ,
488
+ )
438
489
439
490
idx = np .arange (len (items ))
440
491
non_nans = items [~ mask ]
@@ -551,7 +602,9 @@ def _ensure_key_mapped_multiindex(
551
602
return type (index ).from_arrays (mapped )
552
603
553
604
554
- def ensure_key_mapped (values , key : Callable | None , levels = None ):
605
+ def ensure_key_mapped (
606
+ values : ArrayLike | Index | Series , key : Callable | None , levels = None
607
+ ) -> ArrayLike | Index | Series :
555
608
"""
556
609
Applies a callable key function to the values function and checks
557
610
that the resulting value has the same shape. Can be called on Index
@@ -584,8 +637,10 @@ def ensure_key_mapped(values, key: Callable | None, levels=None):
584
637
): # convert to a new Index subclass, not necessarily the same
585
638
result = Index (result )
586
639
else :
640
+ # try to revert to original type otherwise
587
641
type_of_values = type (values )
588
- result = type_of_values (result ) # try to revert to original type otherwise
642
+ # error: Too many arguments for "ExtensionArray"
643
+ result = type_of_values (result ) # type: ignore[call-arg]
589
644
except TypeError :
590
645
raise TypeError (
591
646
f"User-provided `key` function returned an invalid type { type (result )} \
0 commit comments