Skip to content

Commit 106a6f5

Browse files
GH456 Attempt GroupBy.aggregate improved typing
1 parent 020f93d commit 106a6f5

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ from pandas.core.groupby.groupby import (
2323
GroupBy,
2424
GroupByPlot,
2525
)
26-
from pandas.core.series import Series
26+
from pandas.core.series import (
27+
Series,
28+
UnknownSeries,
29+
)
2730
from typing_extensions import (
2831
Self,
2932
TypeAlias,
@@ -57,10 +60,21 @@ class NamedAgg(NamedTuple):
5760
aggfunc: AggScalar
5861

5962
class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
63+
@overload
64+
def aggregate(
65+
self,
66+
func: Callable[Concatenate[Series[S1], P], S2],
67+
/,
68+
*args,
69+
engine: WindowingEngine = ...,
70+
engine_kwargs: WindowingEngineKwargs = ...,
71+
**kwargs,
72+
) -> Series[S2]: ...
6073
@overload
6174
def aggregate(
6275
self,
6376
func: list[AggFuncTypeBase],
77+
/,
6478
*args,
6579
engine: WindowingEngine = ...,
6680
engine_kwargs: WindowingEngineKwargs = ...,
@@ -70,16 +84,18 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
7084
def aggregate(
7185
self,
7286
func: AggFuncTypeBase | None = ...,
87+
/,
7388
*args,
7489
engine: WindowingEngine = ...,
7590
engine_kwargs: WindowingEngineKwargs = ...,
7691
**kwargs,
77-
) -> Series: ...
92+
) -> UnknownSeries: ...
7893
agg = aggregate
7994
@overload
8095
def transform(
8196
self,
8297
func: Callable[Concatenate[Series[S1], P], Series[S2]],
98+
/,
8399
*args: Any,
84100
engine: WindowingEngine = ...,
85101
engine_kwargs: WindowingEngineKwargs = ...,
@@ -91,9 +107,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
91107
func: Callable,
92108
*args: Any,
93109
**kwargs: Any,
94-
) -> Series: ...
110+
) -> UnknownSeries: ...
95111
@overload
96-
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> Series: ...
112+
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> UnknownSeries: ...
97113
def filter(
98114
self, func: Callable | str, dropna: bool = ..., *args, **kwargs
99115
) -> Series: ...

tests/test_series.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,16 @@ def test_types_groupby_agg() -> None:
10781078
r"The provided callable <built-in function (min|sum)> is currently using",
10791079
upper="2.2.99",
10801080
):
1081-
check(assert_type(s.groupby(level=0).agg(sum), pd.Series), pd.Series)
1081+
1082+
def sum_sr(s: pd.Series[int]) -> int:
1083+
# type of `sum` not well inferred by mypy
1084+
return sum(s)
1085+
1086+
check(
1087+
assert_type(s.groupby(level=0).agg(sum_sr), "pd.Series[int]"),
1088+
pd.Series,
1089+
np.integer,
1090+
)
10821091
check(
10831092
assert_type(s.groupby(level=0).agg([min, sum]), pd.DataFrame), pd.DataFrame
10841093
)
@@ -1100,6 +1109,16 @@ def transform_func(
11001109
pd.Series,
11011110
float,
11021111
)
1112+
check(
1113+
assert_type(
1114+
s.groupby(lambda x: x).transform(
1115+
transform_func, True, engine="cython", kw_arg="foo"
1116+
),
1117+
"pd.Series[float]",
1118+
),
1119+
pd.Series,
1120+
float,
1121+
)
11031122

11041123

11051124
def test_types_groupby_aggregate() -> None:
@@ -1109,12 +1128,40 @@ def test_types_groupby_aggregate() -> None:
11091128
assert_type(s.groupby(level=0).aggregate(["min", "sum"]), pd.DataFrame),
11101129
pd.DataFrame,
11111130
)
1131+
1132+
def func(s: pd.Series[int]) -> float:
1133+
return s.astype(float).min()
1134+
1135+
s = pd.Series([1, 2, 3, 4])
1136+
s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min())
1137+
check(
1138+
assert_type(s.groupby(level=0).aggregate(func), "pd.Series[float]"),
1139+
pd.Series,
1140+
np.floating,
1141+
)
1142+
check(
1143+
assert_type(
1144+
s.groupby(level=0).aggregate(func, engine="cython"), "pd.Series[float]"
1145+
),
1146+
pd.Series,
1147+
np.floating,
1148+
)
1149+
11121150
with pytest_warns_bounded(
11131151
FutureWarning,
11141152
r"The provided callable <built-in function (min|sum)> is currently using",
11151153
upper="2.2.99",
11161154
):
1117-
check(assert_type(s.groupby(level=0).aggregate(sum), pd.Series), pd.Series)
1155+
1156+
def sum_sr(s: pd.Series[int]) -> int:
1157+
# type of `sum` not well inferred by mypy
1158+
return sum(s)
1159+
1160+
check(
1161+
assert_type(s.groupby(level=0).aggregate(sum_sr), "pd.Series[int]"),
1162+
pd.Series,
1163+
np.integer,
1164+
)
11181165
check(
11191166
assert_type(s.groupby(level=0).aggregate([min, sum]), pd.DataFrame),
11201167
pd.DataFrame,

0 commit comments

Comments
 (0)