Skip to content

Commit 020f93d

Browse files
GH456: First attempt at improved typing for GroupBy.transform
1 parent 320cf41 commit 020f93d

File tree

3 files changed

+86
-6
lines changed

3 files changed

+86
-6
lines changed

pandas-stubs/_typing.pyi

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,44 @@ GroupByObjectNonScalar: TypeAlias = (
925925
| list[Grouper]
926926
)
927927
GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar | Series
928+
GroupByFuncStrs: TypeAlias = Literal[
929+
# Reduction/aggregation functions
930+
"all",
931+
"any",
932+
"corrwith",
933+
"count",
934+
"first",
935+
"idxmax",
936+
"idxmin",
937+
"last",
938+
"max",
939+
"mean",
940+
"median",
941+
"min",
942+
"nunique",
943+
"prod",
944+
"quantile",
945+
"sem",
946+
"size",
947+
"skew",
948+
"std",
949+
"sum",
950+
"var",
951+
# Transformation functions
952+
"bfill",
953+
"cumcount",
954+
"cummax",
955+
"cummin",
956+
"cumprod",
957+
"cumsum",
958+
"diff",
959+
"ffill",
960+
"fillna",
961+
"ngroup",
962+
"pct_change",
963+
"rank",
964+
"shift",
965+
]
928966

929967
StataDateFormat: TypeAlias = Literal[
930968
"tc",

pandas-stubs/core/groupby/generic.pyi

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from collections.abc import (
77
)
88
from typing import (
99
Any,
10+
Concatenate,
1011
Generic,
1112
Literal,
1213
NamedTuple,
@@ -31,15 +32,18 @@ from typing_extensions import (
3132
from pandas._libs.tslibs.timestamps import Timestamp
3233
from pandas._typing import (
3334
S1,
35+
S2,
3436
AggFuncTypeBase,
3537
AggFuncTypeFrame,
3638
ByT,
3739
CorrelationMethod,
3840
Dtype,
41+
GroupByFuncStrs,
3942
IndexLabel,
4043
Level,
4144
ListLike,
4245
NsmallestNlargestKeep,
46+
P,
4347
Scalar,
4448
TakeIndexer,
4549
WindowingEngine,
@@ -72,14 +76,24 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
7276
**kwargs,
7377
) -> Series: ...
7478
agg = aggregate
79+
@overload
7580
def transform(
7681
self,
77-
func: Callable | str,
78-
*args,
82+
func: Callable[Concatenate[Series[S1], P], Series[S2]],
83+
*args: Any,
7984
engine: WindowingEngine = ...,
8085
engine_kwargs: WindowingEngineKwargs = ...,
81-
**kwargs,
86+
**kwargs: Any,
87+
) -> Series[S2]: ...
88+
@overload
89+
def transform(
90+
self,
91+
func: Callable,
92+
*args: Any,
93+
**kwargs: Any,
8294
) -> Series: ...
95+
@overload
96+
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> Series: ...
8397
def filter(
8498
self, func: Callable | str, dropna: bool = ..., *args, **kwargs
8599
) -> Series: ...
@@ -206,14 +220,24 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
206220
**kwargs,
207221
) -> DataFrame: ...
208222
agg = aggregate
223+
@overload
209224
def transform(
210225
self,
211-
func: Callable | str,
212-
*args,
226+
func: Callable[Concatenate[DataFrame, P], DataFrame],
227+
*args: Any,
213228
engine: WindowingEngine = ...,
214229
engine_kwargs: WindowingEngineKwargs = ...,
215-
**kwargs,
230+
**kwargs: Any,
216231
) -> DataFrame: ...
232+
@overload
233+
def transform(
234+
self,
235+
func: Callable,
236+
*args: Any,
237+
**kwargs: Any,
238+
) -> DataFrame: ...
239+
@overload
240+
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> DataFrame: ...
217241
def filter(
218242
self, func: Callable, dropna: bool = ..., *args, **kwargs
219243
) -> DataFrame: ...

tests/test_series.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,24 @@ def test_types_groupby_agg() -> None:
10841084
)
10851085

10861086

1087+
def test_types_groupby_transform() -> None:
1088+
s: pd.Series[int] = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
1089+
1090+
def transform_func(
1091+
x: pd.Series[int], pos_arg: bool, kw_arg: str
1092+
) -> pd.Series[float]:
1093+
return x / (2.0 if pos_arg else 1.0)
1094+
1095+
check(
1096+
assert_type(
1097+
s.groupby(lambda x: x).transform(transform_func, True, kw_arg="foo"),
1098+
"pd.Series[float]",
1099+
),
1100+
pd.Series,
1101+
float,
1102+
)
1103+
1104+
10871105
def test_types_groupby_aggregate() -> None:
10881106
s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
10891107
check(assert_type(s.groupby(level=0).aggregate("sum"), pd.Series), pd.Series)

0 commit comments

Comments
 (0)