-
-
Notifications
You must be signed in to change notification settings - Fork 18.6k
BUG: ArrowExtensionArray.mode(dropna=False) not respecting NAs #50986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
596d47d
6b28efe
d8c16ba
ad82e2f
ca5ece9
ad888bb
7997994
a009e7e
22d4ed4
1c89b6d
c791152
fa1a345
8fe50b7
b97db29
9815179
ae97a15
7acd4a5
0f2eea1
1a680a8
e8e6672
7a5aff1
4ef9a2f
5e5ed68
b3959f8
30918fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1345,7 +1345,6 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra | |
---------- | ||
dropna : bool, default True | ||
Don't consider counts of NA values. | ||
Not implemented by pyarrow. | ||
|
||
Returns | ||
------- | ||
|
@@ -1364,12 +1363,13 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra | |
else: | ||
data = self._data | ||
|
||
modes = pc.mode(data, pc.count_distinct(data).as_py()) | ||
values = modes.field(0) | ||
counts = modes.field(1) | ||
# counts sorted descending i.e counts[0] = max | ||
mask = pc.equal(counts, counts[0]) | ||
most_common = values.filter(mask) | ||
if dropna: | ||
data = data.drop_null() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know if you checked, but it might be more efficient to do this after the value_counts, so on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
So might as well do this after as you suggested There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like filtering after gives incorrect result for multi-mode tests. If filtering were to occur after, I would have to drop the NAs in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That would be something like:
to drop values based on one field of the struct, before calculating So that line is a bit more complicated as calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah thanks. This passes the tests but appears slower than dropping the NAs beforehand for this example, so I think we should just drop the NAs beforehand for now.
|
||
|
||
res = pc.value_counts(data) | ||
most_common = res.field("values").filter( | ||
pc.equal(res.field("counts"), pc.max(res.field("counts"))) | ||
) | ||
|
||
if pa.types.is_temporal(pa_type): | ||
most_common = most_common.cast(pa_type) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1339,38 +1339,31 @@ def test_quantile(data, interpolation, quantile, request): | |
tm.assert_series_equal(result, expected) | ||
|
||
|
||
@pytest.mark.parametrize("dropna", [True, False]) | ||
@pytest.mark.parametrize( | ||
"take_idx, exp_idx", | ||
[[[0, 0, 2, 2, 4, 4], [4, 0]], [[0, 0, 0, 2, 4, 4], [0]]], | ||
[[[0, 0, 2, 2, 4, 4], [0, 4]], [[0, 0, 0, 2, 4, 4], [0]]], | ||
ids=["multi_mode", "single_mode"], | ||
) | ||
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request): | ||
pa_dtype = data_for_grouping.dtype.pyarrow_dtype | ||
if pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice! |
||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=pa.ArrowNotImplementedError, | ||
reason=f"mode not supported by pyarrow for {pa_dtype}", | ||
) | ||
) | ||
elif ( | ||
pa.types.is_boolean(pa_dtype) | ||
and "multi_mode" in request.node.nodeid | ||
and pa_version_under9p0 | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
reason="https://issues.apache.org/jira/browse/ARROW-17096", | ||
) | ||
) | ||
def test_mode_dropna_true(data_for_grouping, take_idx, exp_idx): | ||
data = data_for_grouping.take(take_idx) | ||
ser = pd.Series(data) | ||
result = ser.mode(dropna=dropna) | ||
result = ser.mode(dropna=True) | ||
expected = pd.Series(data_for_grouping.take(exp_idx)) | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
def test_mode_dropna_false_mode_na(data): | ||
# GH 50982 | ||
more_nans = pd.Series([None, None, data[0]], dtype=data.dtype) | ||
result = more_nans.mode(dropna=False) | ||
expected = pd.Series([None], dtype=data.dtype) | ||
tm.assert_series_equal(result, expected) | ||
|
||
expected = pd.Series([None, data[0]], dtype=data.dtype) | ||
result = expected.mode(dropna=False) | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
def test_is_bool_dtype(): | ||
# GH 22667 | ||
data = ArrowExtensionArray(pa.array([True, False, True])) | ||
|
Uh oh!
There was an error while loading. Please reload this page.