Skip to content

CLN: groupby.ops follow-up cleanup #41204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 30, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,7 @@ class GroupBy(BaseGroupBy[FrameOrSeries]):
grouper: ops.BaseGrouper
as_index: bool

@final
def __init__(
self,
obj: FrameOrSeries,
Expand Down
216 changes: 129 additions & 87 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Hashable,
Iterator,
Sequence,
overload,
)

import numpy as np
Expand Down Expand Up @@ -47,23 +48,35 @@
is_categorical_dtype,
is_complex_dtype,
is_datetime64_any_dtype,
is_datetime64tz_dtype,
is_extension_array_dtype,
is_float_dtype,
is_integer_dtype,
is_numeric_dtype,
is_period_dtype,
is_sparse,
is_timedelta64_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import ABCCategoricalIndex
from pandas.core.dtypes.missing import (
isna,
maybe_fill,
)

from pandas.core.arrays import ExtensionArray
from pandas.core.arrays import (
DatetimeArray,
ExtensionArray,
PeriodArray,
TimedeltaArray,
)
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.floating import (
Float64Dtype,
FloatingDtype,
)
from pandas.core.arrays.integer import (
Int64Dtype,
_IntegerDtype,
)
from pandas.core.arrays.masked import (
BaseMaskedArray,
BaseMaskedDtype,
Expand Down Expand Up @@ -194,7 +207,7 @@ def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):

return func, values

def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
"""
Check if we can do this operation with our cython functions.

Expand Down Expand Up @@ -230,7 +243,7 @@ def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
if how in ["prod", "cumprod"]:
raise TypeError(f"timedelta64 type does not support {how} operations")

def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
how = self.how
kind = self.kind

Expand Down Expand Up @@ -261,7 +274,15 @@ def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
out_dtype = "object"
return np.dtype(out_dtype)

def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
@overload
def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
...

@overload
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would a TypeVar work here? Might end up being useful elsewhere too

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that would be accurate here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be misunderstanding overloads, but is this expressing ExtensionDtype -> ExtensionDtype and np.dtype -> np.dtype?

So was thinking a TypeVar like TypeVar('DtypeObjT', np.dtype, ExtensionDtype)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do TypeVars not preserve subclasses?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running mypy on something like

from typing import TypeVar

# Constrained TypeVar: it resolves to exactly str or bytes — not to a
# subclass of either — which is what this example demonstrates.
AnyStr = TypeVar('AnyStr', str, bytes)


class Upper(str):
    # str subclass that uppercases its value at construction time.
    def __new__(cls, val):
        return str.__new__(cls, val.upper())


class Lower(str):
    # str subclass that lowercases its value at construction time.
    def __new__(cls, val):
        return str.__new__(cls, val.lower())


def swap_lower_upper(s: AnyStr) -> AnyStr:
    """Return bytes unchanged; for str input, swap Lower <-> Upper.

    Although the function returns a *different* str subclass than it was
    given, mypy raises no error: the constrained TypeVar collapses both
    Lower and Upper to the constraint ``str``, so subclass identity is
    not preserved through the annotation.
    """
    if isinstance(s, bytes):
        return s
    elif isinstance(s, Lower):
        return Upper(s)
    else:
        return Lower(s)

# Neither call is flagged by mypy, even though each returns the opposite
# subclass of the one passed in.
swap_lower_upper(Lower("hello"))
swap_lower_upper(Upper("hello"))

gives no error, so I don't think subclasses are preserved.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool. will make a separate PR to implement DtypeObjT

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementing this with DtypeObjT and getting rid of the overloads, I'm seeing

error: Incompatible return value type (got "dtype[signedinteger[_64Bit]]", expected "ExtensionDtype")  [return-value]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will have time to look in more detail later, initial guess would be that this return is the problem

if how in ["add", "cumsum", "sum", "prod"]:
if dtype == np.dtype(bool):
return np.dtype(np.int64)

IIRC mypy can't narrow based on type equality, so it's thinking that an ExtensionDtype can hit that return. But then not sure why it wouldn't complain with the overload, but those are confusing :)

...

def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
"""
Get the desired dtype of a result based on the
input dtype and how it was computed.
Expand All @@ -276,13 +297,6 @@ def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
np.dtype or ExtensionDtype
The desired dtype of the result.
"""
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.floating import Float64Dtype
from pandas.core.arrays.integer import (
Int64Dtype,
_IntegerDtype,
)

how = self.how

if how in ["add", "cumsum", "sum", "prod"]:
Expand Down Expand Up @@ -315,15 +329,12 @@ def _ea_wrap_cython_operation(
# TODO: general case implementation overridable by EAs.
orig_values = values

if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
if isinstance(orig_values, (DatetimeArray, PeriodArray)):
# All of the functions implemented here are ordinal, so we can
# operate on the tz-naive equivalents
npvalues = values.view("M8[ns]")
npvalues = orig_values._ndarray.view("M8[ns]")
res_values = self._cython_op_ndim_compat(
# error: Argument 1 to "_cython_op_ndim_compat" of
# "WrappedCythonOp" has incompatible type
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
npvalues, # type: ignore[arg-type]
npvalues,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
Expand All @@ -336,14 +347,31 @@ def _ea_wrap_cython_operation(
# preserve float64 dtype
return res_values

res_values = res_values.astype("i8", copy=False)
# error: Too many arguments for "ExtensionArray"
result = type(orig_values)( # type: ignore[call-arg]
res_values, dtype=orig_values.dtype
res_values = res_values.view("i8")
result = type(orig_values)(res_values, dtype=orig_values.dtype)
return result

elif isinstance(orig_values, TimedeltaArray):
# We have an ExtensionArray but not ExtensionDtype
res_values = self._cython_op_ndim_compat(
orig_values._ndarray,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
mask=None,
**kwargs,
)
if self.how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
# preserve float64 dtype
return res_values

# otherwise res_values has the same dtype as original values
result = type(orig_values)(res_values)
return result

elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
elif isinstance(values.dtype, (BooleanDtype, _IntegerDtype)):
# IntegerArray or BooleanArray
npvalues = values.to_numpy("float64", na_value=np.nan)
res_values = self._cython_op_ndim_compat(
Expand All @@ -359,17 +387,14 @@ def _ea_wrap_cython_operation(
# other cast_blocklist methods dont go through cython_operation
return res_values

dtype = self.get_result_dtype(orig_values.dtype)
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
# has no attribute "construct_array_type"
cls = dtype.construct_array_type() # type: ignore[union-attr]
dtype = self._get_result_dtype(orig_values.dtype)
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

elif is_float_dtype(values.dtype):
elif isinstance(values.dtype, FloatingDtype):
# FloatingArray
# error: "ExtensionDtype" has no attribute "numpy_dtype"
npvalues = values.to_numpy(
values.dtype.numpy_dtype, # type: ignore[attr-defined]
values.dtype.numpy_dtype,
na_value=np.nan,
)
res_values = self._cython_op_ndim_compat(
Expand All @@ -385,10 +410,8 @@ def _ea_wrap_cython_operation(
# other cast_blocklist methods dont go through cython_operation
return res_values

dtype = self.get_result_dtype(orig_values.dtype)
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
# has no attribute "construct_array_type"
cls = dtype.construct_array_type() # type: ignore[union-attr]
dtype = self._get_result_dtype(orig_values.dtype)
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

raise NotImplementedError(
Expand Down Expand Up @@ -422,12 +445,13 @@ def _masked_ea_wrap_cython_operation(
mask=mask,
**kwargs,
)
dtype = self.get_result_dtype(orig_values.dtype)
dtype = self._get_result_dtype(orig_values.dtype)
assert isinstance(dtype, BaseMaskedDtype)
cls = dtype.construct_array_type()

return cls(res_values.astype(dtype.type, copy=False), mask)

@final
def _cython_op_ndim_compat(
self,
values: np.ndarray,
Expand Down Expand Up @@ -500,7 +524,7 @@ def _call_cython_op(
if mask is not None:
mask = mask.reshape(values.shape, order="C")

out_shape = self.get_output_shape(ngroups, values)
out_shape = self._get_output_shape(ngroups, values)
func, values = self.get_cython_func_and_vals(values, is_numeric)
out_dtype = self.get_out_dtype(values.dtype)

Expand Down Expand Up @@ -550,19 +574,71 @@ def _call_cython_op(
if self.how not in self.cast_blocklist:
# e.g. if we are int64 and need to restore to datetime64/timedelta64
# "rank" is the only member of cast_blocklist we get here
res_dtype = self.get_result_dtype(orig_values.dtype)
# error: Argument 2 to "maybe_downcast_to_dtype" has incompatible type
# "Union[dtype[Any], ExtensionDtype]"; expected "Union[str, dtype[Any]]"
op_result = maybe_downcast_to_dtype(
result, res_dtype # type: ignore[arg-type]
)
res_dtype = self._get_result_dtype(orig_values.dtype)
op_result = maybe_downcast_to_dtype(result, res_dtype)
else:
op_result = result

# error: Incompatible return value type (got "Union[ExtensionArray, ndarray]",
# expected "ndarray")
return op_result # type: ignore[return-value]

@final
def cython_operation(
    self,
    *,
    values: ArrayLike,
    axis: int,
    min_count: int = -1,
    comp_ids: np.ndarray,
    ngroups: int,
    **kwargs,
) -> ArrayLike:
    """
    Call our cython function, with appropriate pre- and post- processing.

    Dispatches to the masked-EA, generic-EA, or plain-ndarray path based
    on the type of ``values``.

    Parameters
    ----------
    values : ArrayLike
        np.ndarray or ExtensionArray to operate on; at most 2-dimensional.
    axis : int
        Must be 1 when ``values`` is 2-dimensional.
    min_count : int, default -1
        Passed through to the underlying cython op.
    comp_ids : np.ndarray
        Group labels for each entry of ``values``.
    ngroups : int
        Number of groups.
    **kwargs
        Forwarded to the underlying cython op.

    Returns
    -------
    ArrayLike
        Result of the grouped operation.

    Raises
    ------
    NotImplementedError
        If ``values`` has more than 2 dimensions, or (via
        ``_disallow_invalid_ops``) if the op is unsupported for the dtype.
    """
    if values.ndim > 2:
        raise NotImplementedError("number of dimensions is currently limited to 2")
    elif values.ndim == 2:
        # Note: it is *not* the case that axis is always 0 for 1-dim values,
        # as we can have 1D ExtensionArrays that we need to treat as 2D
        assert axis == 1, axis

    dtype = values.dtype
    is_numeric = is_numeric_dtype(dtype)

    # can we do this operation with our cython functions
    # if not raise NotImplementedError
    self._disallow_invalid_ops(dtype, is_numeric)

    if not isinstance(values, np.ndarray):
        # i.e. ExtensionArray
        if isinstance(values, BaseMaskedArray) and self.uses_mask():
            # masked arrays with a mask-aware cython kernel get the
            # dedicated masked path (no NA round-trip through np.nan)
            return self._masked_ea_wrap_cython_operation(
                values,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                **kwargs,
            )
        else:
            return self._ea_wrap_cython_operation(
                values,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                **kwargs,
            )

    # plain ndarray path; mask=None since there is no EA mask here
    return self._cython_op_ndim_compat(
        values,
        min_count=min_count,
        ngroups=ngroups,
        comp_ids=comp_ids,
        mask=None,
        **kwargs,
    )


class BaseGrouper:
"""
Expand Down Expand Up @@ -799,6 +875,7 @@ def group_info(self):

ngroups = len(obs_group_ids)
comp_ids = ensure_platform_int(comp_ids)

return comp_ids, obs_group_ids, ngroups

@final
Expand Down Expand Up @@ -868,58 +945,23 @@ def _cython_operation(
how: str,
axis: int,
min_count: int = -1,
mask: np.ndarray | None = None,
**kwargs,
) -> ArrayLike:
"""
Returns the values of a cython operation.
"""
assert kind in ["transform", "aggregate"]

if values.ndim > 2:
raise NotImplementedError("number of dimensions is currently limited to 2")
elif values.ndim == 2:
# Note: it is *not* the case that axis is always 0 for 1-dim values,
# as we can have 1D ExtensionArrays that we need to treat as 2D
assert axis == 1, axis

dtype = values.dtype
is_numeric = is_numeric_dtype(dtype)

cy_op = WrappedCythonOp(kind=kind, how=how)

# can we do this operation with our cython functions
# if not raise NotImplementedError
cy_op.disallow_invalid_ops(dtype, is_numeric)

comp_ids, _, _ = self.group_info
ngroups = self.ngroups

func_uses_mask = cy_op.uses_mask()
if is_extension_array_dtype(dtype):
if isinstance(values, BaseMaskedArray) and func_uses_mask:
return cy_op._masked_ea_wrap_cython_operation(
values,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
**kwargs,
)
else:
return cy_op._ea_wrap_cython_operation(
values,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
**kwargs,
)

return cy_op._cython_op_ndim_compat(
values,
return cy_op.cython_operation(
values=values,
axis=axis,
min_count=min_count,
ngroups=self.ngroups,
comp_ids=comp_ids,
mask=mask,
ngroups=ngroups,
**kwargs,
)

Expand Down Expand Up @@ -967,8 +1009,8 @@ def _aggregate_series_fast(self, obj: Series, func: F):
indexer = get_group_index_sorter(group_index, ngroups)
obj = obj.take(indexer)
group_index = group_index.take(indexer)
grouper = libreduction.SeriesGrouper(obj, func, group_index, ngroups)
result, counts = grouper.get_result()
sgrouper = libreduction.SeriesGrouper(obj, func, group_index, ngroups)
result, counts = sgrouper.get_result()
return result, counts

@final
Expand Down Expand Up @@ -1169,8 +1211,8 @@ def agg_series(self, obj: Series, func: F):
elif obj.index._has_complex_internals:
return self._aggregate_series_pure_python(obj, func)

grouper = libreduction.SeriesBinGrouper(obj, func, self.bins)
return grouper.get_result()
sbg = libreduction.SeriesBinGrouper(obj, func, self.bins)
return sbg.get_result()


def _is_indexed_like(obj, axes, axis: int) -> bool:
Expand Down