Skip to content

Commit c22c01b

Browse files
committed
Add mask as return value of _values_for_argsort
1 parent 891a419 commit c22c01b

17 files changed

+109
-12
lines changed

pandas/core/arrays/base.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def isna(self) -> ArrayLike:
361361
"""
362362
raise AbstractMethodError(self)
363363

364-
def _values_for_argsort(self) -> np.ndarray:
364+
def _values_for_argsort(self) -> Tuple[np.ndarray, np.ndarray]:
365365
"""
366366
Return values for sorting.
367367
@@ -376,7 +376,7 @@ def _values_for_argsort(self) -> np.ndarray:
376376
ExtensionArray.argsort
377377
"""
378378
# Note: this is used in `ExtensionArray.argsort`.
379-
return np.array(self)
379+
return np.array(self), self.isna()
380380

381381
def argsort(self, ascending=True, kind='quicksort', *args, **kwargs):
382382
"""
@@ -406,8 +406,29 @@ def argsort(self, ascending=True, kind='quicksort', *args, **kwargs):
406406
# 1. _values_for_argsort : construct the values passed to np.argsort
407407
# 2. argsort : total control over sorting.
408408
ascending = nv.validate_argsort_with_ascending(ascending, args, kwargs)
409-
values = self._values_for_argsort()
410-
result = np.argsort(values, kind=kind, **kwargs)
409+
values, mask = self._values_for_argsort()
410+
411+
def permutation(mask):
412+
# Return a permutation which maps the indices of the
413+
# subarray without nan to the indices of the original array.
414+
permu = np.arange(len(mask))
415+
nan_loc = np.arange(len(mask))[mask]
416+
offset = 0
417+
for x in nan_loc:
418+
permu[x - offset:] += 1
419+
offset += 1
420+
return permu
421+
422+
if mask.any():
423+
notmask = ~mask
424+
notnull = np.argsort(values[notmask], kind=kind, **kwargs)
425+
permu = permutation(mask)
426+
notnull = permu[notnull]
427+
allnan = np.arange(len(self))[mask]
428+
result = np.append(notnull, allnan)
429+
else:
430+
result = np.argsort(values, kind=kind, **kwargs)
431+
411432
if not ascending:
412433
result = result[::-1]
413434
return result

pandas/core/arrays/categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1521,7 +1521,7 @@ def check_for_ordered(self, op):
15211521
"Categorical to an ordered one\n".format(op=op))
15221522

15231523
def _values_for_argsort(self):
1524-
return self._codes.copy()
1524+
return self._codes.copy(), self.isna()
15251525

15261526
def argsort(self, *args, **kwargs):
15271527
# TODO(PY2): use correct signature

pandas/core/arrays/datetimelike.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ def _from_factorized(cls, values, original):
621621
return cls(values, dtype=original.dtype)
622622

623623
def _values_for_argsort(self):
624-
return self._data
624+
return self._data, self._isnan
625625

626626
# ------------------------------------------------------------------
627627
# Additional array methods

pandas/core/arrays/integer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,8 @@ def value_counts(self, dropna=True):
512512

513513
return Series(array, index=index)
514514

515-
def _values_for_argsort(self) -> np.ndarray:
515+
#def _values_for_argsort(self) -> np.ndarray:
516+
def _values_for_argsort(self):
516517
"""Return values for sorting.
517518
518519
Returns
@@ -526,8 +527,9 @@ def _values_for_argsort(self) -> np.ndarray:
526527
ExtensionArray.argsort
527528
"""
528529
data = self._data.copy()
530+
mask = self._mask
529531
data[self._mask] = data.min() - 1
530-
return data
532+
return data, mask
531533

532534
@classmethod
533535
def _create_comparison_method(cls, op):

pandas/core/arrays/numpy_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def copy(self, deep=False):
281281
return type(self)(self._ndarray.copy())
282282

283283
def _values_for_argsort(self):
284-
return self._ndarray
284+
return self._ndarray, self.isna()
285285

286286
def _values_for_factorize(self):
287287
return self._ndarray, -1

pandas/core/arrays/period.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def _check_timedeltalike_freq_compat(self, other):
674674
_raise_on_incompatible(self, other)
675675

676676
def _values_for_argsort(self):
677-
return self._data
677+
return self._data, self._isnan
678678

679679

680680
PeriodArray._add_comparison_ops()

pandas/tests/extension/base/methods.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def test_argsort_missing(self, data_missing_for_sorting):
4444
expected = pd.Series(np.array([1, -1, 0], dtype=np.int64))
4545
self.assert_series_equal(result, expected)
4646

47+
def test_argsort_nan_loc(self, data_multiple_nan):
48+
result = data_multiple_nan.argsort()
49+
expected = np.array([3, 9, 7, 1, 0, 6, 2, 4, 5, 8])
50+
tm.assert_numpy_array_equal(result, expected)
51+
4752
@pytest.mark.parametrize('ascending', [True, False])
4853
def test_sort_values(self, data_for_sorting, ascending):
4954
ser = pd.Series(data_for_sorting)

pandas/tests/extension/decimal/test_decimal.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ def data_missing_for_sorting():
4646
decimal.Decimal('0')])
4747

4848

49+
@pytest.fixture
50+
def data_multiple_nan():
51+
return DecimalArray([decimal.Decimal(x) for x in
52+
[5, 4, np.nan, 1, np.nan, np.nan, 6, 3, np.nan, 2]])
53+
54+
4955
@pytest.fixture
5056
def na_cmp():
5157
return lambda x, y: x.is_nan() and y.is_nan()

pandas/tests/extension/json/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _concat_same_type(cls, to_concat):
171171
return cls(data)
172172

173173
def _values_for_factorize(self):
174-
frozen = self._values_for_argsort()
174+
frozen, _ = self._values_for_argsort()
175175
if len(frozen) == 0:
176176
# _factorize_array expects 1-d array, this is a len-0 2-d array.
177177
frozen = frozen.ravel()
@@ -182,7 +182,7 @@ def _values_for_argsort(self):
182182
# If all the elemnts of self are the same size P, NumPy will
183183
# cast them to an (N, P) array, instead of an (N,) array of tuples.
184184
frozen = [()] + [tuple(x.items()) for x in self]
185-
return np.array(frozen, dtype=object)[1:]
185+
return np.array(frozen, dtype=object)[1:], self.isna()
186186

187187

188188
def make_data():

pandas/tests/extension/json/test_json.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def data_missing_for_sorting():
5050
return JSONArray([{'b': 1}, {}, {'a': 4}])
5151

5252

53+
@pytest.fixture
54+
def data_multiple_nan():
55+
return JSONArray([{'e': 5}, {'d': 4}, {}, {'a': 1}, {},
56+
{}, {'f': 6}, {'c': 3}, {}, {'b': 2}])
57+
58+
5359
@pytest.fixture
5460
def na_value(dtype):
5561
return dtype.na_value

pandas/tests/extension/test_categorical.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ def data_missing_for_sorting():
6868
ordered=True)
6969

7070

71+
@pytest.fixture
72+
def data_multiple_nan():
73+
return Categorical(['E', 'D', None, 'A', None,
74+
None, 'F', 'C', None, 'B'], ordered=True)
75+
76+
7177
@pytest.fixture
7278
def na_value():
7379
return np.nan

pandas/tests/extension/test_datetime.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ def data_missing_for_sorting(dtype):
4545
dtype=dtype)
4646

4747

48+
@pytest.fixture
49+
def data_multiple_nan(dtype):
50+
idx = pd.date_range(start='2000-01-01', end='2000-01-06')
51+
return DatetimeArray(np.array([idx[4], idx[3], 'NaT', idx[0], 'NaT',
52+
'NaT', idx[5], idx[2], 'NaT', idx[1]],
53+
dtype='datetime64[ns]'),
54+
dtype=dtype)
55+
56+
4857
@pytest.fixture
4958
def data_for_grouping(dtype):
5059
"""

pandas/tests/extension/test_integer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ def data_for_sorting(dtype):
5757
return integer_array([1, 2, 0], dtype=dtype)
5858

5959

60+
@pytest.fixture
61+
def data_multiple_nan(dtype):
62+
return integer_array([5, 4, np.nan, 1, np.nan,
63+
np.nan, 6, 3, np.nan, 2], dtype=dtype)
64+
65+
6066
@pytest.fixture
6167
def data_missing_for_sorting(dtype):
6268
return integer_array([1, np.nan, 0], dtype=dtype)

pandas/tests/extension/test_interval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def data_missing_for_sorting():
5757
return IntervalArray.from_tuples([(1, 2), None, (0, 1)])
5858

5959

60+
@pytest.fixture
61+
def data_multiple_nan():
62+
return IntervalArray.from_tuples([(5, 6), (4, 5), None, (1, 2),
63+
None, None, (6, 7), (3, 4),
64+
None, (2, 3)])
65+
66+
6067
@pytest.fixture
6168
def na_value():
6269
return np.nan

pandas/tests/extension/test_numpy.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,23 @@ def data_missing_for_sorting(allow_in_pandas, dtype):
101101
)
102102

103103

104+
@pytest.fixture
105+
def data_multiple_nan(allow_in_pandas, dtype):
106+
"""Length-10 array with a known sort order.
107+
108+
This should be three items [B, NA, A] with
109+
A < B and NA missing.
110+
"""
111+
if dtype.numpy_dtype == 'object':
112+
return PandasArray(
113+
np.array([(5,), (4,), np.nan, (1,), np.nan,
114+
np.nan, (6,), (3,), np.nan, (2,)])
115+
)
116+
return PandasArray(
117+
np.array([5, 4, np.nan, 1, np.nan, np.nan, 6, 3, np.nan, 2])
118+
)
119+
120+
104121
@pytest.fixture
105122
def data_for_grouping(allow_in_pandas, dtype):
106123
"""Data for factorization, grouping, and unique tests.

pandas/tests/extension/test_period.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def data_missing_for_sorting(dtype):
4040
return PeriodArray([2018, iNaT, 2017], freq=dtype.freq)
4141

4242

43+
@pytest.fixture
44+
def data_multiple_nan(dtype):
45+
return PeriodArray([2005, 2004, iNaT, 2001, iNaT, iNaT,
46+
2006, 2003, iNaT, 2002], freq=dtype.freq)
47+
48+
4349
@pytest.fixture
4450
def data_for_grouping(dtype):
4551
B = 2018

pandas/tests/extension/test_sparse.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def data_missing_for_sorting(request):
6565
return SparseArray([2, np.nan, 1], fill_value=request.param)
6666

6767

68+
@pytest.fixture(params=[0, np.nan])
69+
def data_multiple_nan(request):
70+
return SparseArray([5, 4, np.nan, 1, np.nan,
71+
np.nan, 6, 3, np.nan, 2], fill_value=request.param)
72+
73+
6874
@pytest.fixture
6975
def na_value():
7076
return np.nan

0 commit comments

Comments
 (0)