Skip to content

Commit b826141

Browse files
committed
BUG: Maintain column order with groupby.nth
1 parent e0e948d commit b826141

File tree

4 files changed

+40
-8
lines changed

4 files changed

+40
-8
lines changed

doc/source/whatsnew/v0.24.0.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,6 +1297,7 @@ Groupby/Resample/Rolling
12971297
- :func:`RollingGroupby.agg` and :func:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`)
12981298
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
12991299
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
1300+
- Bug in :func:`pandas.core.groupby.GroupBy.nth` where column order was not always preserved (:issue:`20760`)
13001301

13011302
Reshaping
13021303
^^^^^^^^^

pandas/core/groupby/groupby.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,8 @@ def _set_group_selection(self):
492492

493493
if len(groupers):
494494
# GH12839 clear selected obj cache when group selection changes
495-
self._group_selection = ax.difference(Index(groupers)).tolist()
495+
self._group_selection = ax.difference(Index(groupers),
496+
sort=False).tolist()
496497
self._reset_cache('_selected_obj')
497498

498499
def _set_result_index_ordered(self, result):

pandas/core/indexes/base.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2923,17 +2923,20 @@ def intersection(self, other):
29232923
taken.name = None
29242924
return taken
29252925

2926-
def difference(self, other):
2926+
def difference(self, other, sort=True):
29272927
"""
29282928
Return a new Index with elements from the index that are not in
29292929
`other`.
29302930
29312931
This is the set difference of two Index objects.
2932-
It's sorted if sorting is possible.
29332932
29342933
Parameters
29352934
----------
29362935
other : Index or array-like
2936+
sort : bool, default True
2937+
Sort the resulting index if possible
2938+
2939+
.. versionadded:: 0.24.0
29372940
29382941
Returns
29392942
-------
@@ -2942,10 +2945,12 @@ def difference(self, other):
29422945
Examples
29432946
--------
29442947
2945-
>>> idx1 = pd.Index([1, 2, 3, 4])
2948+
>>> idx1 = pd.Index([2, 1, 3, 4])
29462949
>>> idx2 = pd.Index([3, 4, 5, 6])
29472950
>>> idx1.difference(idx2)
29482951
Int64Index([1, 2], dtype='int64')
2952+
>>> idx1.difference(idx2, sort=False)
2953+
Int64Index([2, 1], dtype='int64')
29492954
29502955
"""
29512956
self._assert_can_do_setop(other)
@@ -2964,10 +2969,11 @@ def difference(self, other):
29642969
label_diff = np.setdiff1d(np.arange(this.size), indexer,
29652970
assume_unique=True)
29662971
the_diff = this.values.take(label_diff)
2967-
try:
2968-
the_diff = sorting.safe_sort(the_diff)
2969-
except TypeError:
2970-
pass
2972+
if sort:
2973+
try:
2974+
the_diff = sorting.safe_sort(the_diff)
2975+
except TypeError:
2976+
pass
29712977

29722978
return this._shallow_copy(the_diff, name=result_name, freq=None)
29732979

pandas/tests/groupby/test_nth.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,27 @@ def test_nth_empty():
390390
names=['a', 'b']),
391391
columns=['c'])
392392
assert_frame_equal(result, expected)
393+
394+
395+
def test_nth_column_order():
396+
# GH 20760
397+
# Check that nth preserves column order
398+
df = DataFrame([[1, 'b', 100],
399+
[1, 'a', 50],
400+
[1, 'a', np.nan],
401+
[2, 'c', 200],
402+
[2, 'd', 150]],
403+
columns=['A', 'C', 'B'])
404+
result = df.groupby('A').nth(0)
405+
expected = DataFrame([['b', 100.0],
406+
['c', 200.0]],
407+
columns=['C', 'B'],
408+
index=Index([1, 2], name='A'))
409+
assert_frame_equal(result, expected)
410+
411+
result = df.groupby('A').nth(-1, dropna='any')
412+
expected = DataFrame([['a', 50.0],
413+
['d', 150.0]],
414+
columns=['C', 'B'],
415+
index=Index([1, 2], name='A'))
416+
assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)