Skip to content

WIP: Use default in TypeVar so Series defaults to Series[Any], and Index to Index[Any] #1232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ mypy round.py
we get the following error message:

```text
round.py:6: error: Argument "decimals" to "round" of "DataFrame" has incompatible type "DataFrame"; expected "Union[int, Dict[Any, Any], Series[Any]]" [arg-type]
round.py:6: error: Argument "decimals" to "round" of "DataFrame" has incompatible type "DataFrame"; expected "Union[int, Dict[Any, Any], Series]" [arg-type]
Found 1 error in 1 file (checked 1 source file)
```

Expand Down
2 changes: 1 addition & 1 deletion docs/philosophy.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ lt = s < 3

In the pandas source, `lt` is a `Series` with a `dtype` of `bool`. In the pandas-stubs,
the type of `lt` is `Series[bool]`. This allows further type checking to occur in other
pandas methods. Note that in the above example, `s` is typed as `Series[Any]` because
pandas methods. Note that in the above example, `s` is typed as `Series` because
its type cannot be statically inferred.

This also allows type checking for operations on series that contain date/time data. Consider
Expand Down
6 changes: 3 additions & 3 deletions pandas-stubs/_libs/tslibs/timestamps.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ from typing import (
import numpy as np
from pandas import (
DatetimeIndex,
Index,
TimedeltaIndex,
)
from pandas.core.indexes.base import UnknownIndex
from pandas.core.series import (
Series,
TimedeltaSeries,
Expand Down Expand Up @@ -236,15 +236,15 @@ class Timestamp(datetime, SupportsIndex):
@overload
def __eq__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
@overload
def __eq__(self, other: npt.NDArray[np.datetime64] | UnknownIndex) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
def __eq__(self, other: npt.NDArray[np.datetime64] | Index) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
@overload
def __eq__(self, other: object) -> Literal[False]: ...
@overload
def __ne__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
@overload
def __ne__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
@overload
def __ne__(self, other: npt.NDArray[np.datetime64] | UnknownIndex) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
def __ne__(self, other: npt.NDArray[np.datetime64] | Index) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
@overload
def __ne__(self, other: object) -> Literal[True]: ...
def __hash__(self) -> int: ...
Expand Down
10 changes: 6 additions & 4 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ from typing import (
Protocol,
SupportsIndex,
TypedDict,
TypeVar,
overload,
)

Expand All @@ -35,6 +34,7 @@ from pandas.core.tools.datetimes import FulldatetimeDict
from typing_extensions import (
ParamSpec,
TypeAlias,
TypeVar,
)

from pandas._libs.interval import Interval
Expand Down Expand Up @@ -65,7 +65,7 @@ HashableT5 = TypeVar("HashableT5", bound=Hashable)
# array-like

ArrayLike: TypeAlias = ExtensionArray | np.ndarray
AnyArrayLike: TypeAlias = ArrayLike | Index[Any] | Series[Any]
AnyArrayLike: TypeAlias = ArrayLike | Index | Series

# list-like

Expand Down Expand Up @@ -801,7 +801,7 @@ DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
KeysArgType: TypeAlias = Any
ListLikeT = TypeVar("ListLikeT", bound=ListLike)
ListLikeExceptSeriesAndStr: TypeAlias = (
MutableSequence[Any] | np.ndarray | tuple[Any, ...] | Index[Any]
MutableSequence[Any] | np.ndarray | tuple[Any, ...] | Index
)
ListLikeU: TypeAlias = Sequence | np.ndarray | Series | Index
ListLikeHashable: TypeAlias = (
Expand Down Expand Up @@ -842,6 +842,7 @@ S1 = TypeVar(
| CategoricalDtype
| BaseOffset
| list[str],
default=Any,
)

S2 = TypeVar(
Expand Down Expand Up @@ -891,6 +892,7 @@ ByT = TypeVar(
| Period
| Interval[int | float | Timestamp | Timedelta]
| tuple,
default=Any,
)
# Use a distinct SeriesByT when using groupby with Series of known dtype.
# Essentially, an intersection between Series S1 TypeVar, and ByT TypeVar
Expand Down Expand Up @@ -949,7 +951,7 @@ ReplaceValue: TypeAlias = (
| NAType
| Sequence[Scalar | Pattern]
| Mapping[HashableT, ScalarT]
| Series[Any]
| Series
| None
)

Expand Down
8 changes: 4 additions & 4 deletions pandas-stubs/core/dtypes/missing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ isneginf_scalar = ...
@overload
def isna(obj: DataFrame) -> DataFrame: ...
@overload
def isna(obj: Series[Any]) -> Series[bool]: ...
def isna(obj: Series) -> Series[bool]: ...
@overload
def isna(obj: Index[Any] | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ...
def isna(obj: Index | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ...
@overload
def isna(
obj: Scalar | NaTType | NAType | None,
Expand All @@ -39,9 +39,9 @@ isnull = isna
@overload
def notna(obj: DataFrame) -> DataFrame: ...
@overload
def notna(obj: Series[Any]) -> Series[bool]: ...
def notna(obj: Series) -> Series[bool]: ...
@overload
def notna(obj: Index[Any] | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ...
def notna(obj: Index | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ...
@overload
def notna(obj: ScalarT | NaTType | NAType | None) -> TypeIs[ScalarT]: ...

Expand Down
45 changes: 22 additions & 23 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ from pandas.core.reshape.pivot import (
)
from pandas.core.series import (
Series,
UnknownSeries,
)
from pandas.core.window import (
Expanding,
Expand All @@ -75,7 +74,7 @@ from pandas._libs.tslibs import BaseOffset
from pandas._libs.tslibs.nattype import NaTType
from pandas._libs.tslibs.offsets import DateOffset
from pandas._typing import (
S1,
S2,
AggFuncTypeBase,
AggFuncTypeDictFrame,
AggFuncTypeDictSeries,
Expand Down Expand Up @@ -1318,11 +1317,11 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def stack(
self, level: Level | list[Level] = ..., dropna: _bool = ..., sort: _bool = ...
) -> Self | Series[Any]: ...
) -> Self | Series: ...
@overload
def stack(
self, level: Level | list[Level] = ..., future_stack: _bool = ...
) -> Self | Series[Any]: ...
) -> Self | Series: ...
def explode(
self, column: Sequence[Hashable], ignore_index: _bool = ...
) -> Self: ...
Expand Down Expand Up @@ -1382,7 +1381,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def apply(
self,
f: Callable[..., ListLikeExceptSeriesAndStr | Series[Any]],
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
axis: AxisIndex = ...,
raw: _bool = ...,
result_type: None = ...,
Expand All @@ -1392,13 +1391,13 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def apply(
self,
f: Callable[..., S1 | NAType],
f: Callable[..., S2 | NAType],
axis: AxisIndex = ...,
raw: _bool = ...,
result_type: None = ...,
args: Any = ...,
**kwargs: Any,
) -> Series[S1]: ...
) -> Series[S2]: ...
# Since non-scalar type T is not supported in Series[T],
# we separate this overload from the above one
@overload
Expand All @@ -1410,24 +1409,24 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
result_type: None = ...,
args: Any = ...,
**kwargs: Any,
) -> Series[Any]: ...
) -> Series: ...

# apply() overloads with keyword result_type, and axis does not matter
@overload
def apply(
self,
f: Callable[..., S1 | NAType],
f: Callable[..., S2 | NAType],
axis: Axis = ...,
raw: _bool = ...,
args: Any = ...,
*,
result_type: Literal["expand", "reduce"],
**kwargs: Any,
) -> Series[S1]: ...
) -> Series[S2]: ...
@overload
def apply(
self,
f: Callable[..., ListLikeExceptSeriesAndStr | Series[Any] | Mapping[Any, Any]],
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Mapping[Any, Any]],
axis: Axis = ...,
raw: _bool = ...,
args: Any = ...,
Expand All @@ -1445,12 +1444,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
*,
result_type: Literal["reduce"],
**kwargs: Any,
) -> Series[Any]: ...
) -> Series: ...
@overload
def apply(
self,
f: Callable[
..., ListLikeExceptSeriesAndStr | Series[Any] | Scalar | Mapping[Any, Any]
..., ListLikeExceptSeriesAndStr | Series | Scalar | Mapping[Any, Any]
],
axis: Axis = ...,
raw: _bool = ...,
Expand All @@ -1464,27 +1463,27 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def apply(
self,
f: Callable[..., Series[Any]],
f: Callable[..., Series],
axis: AxisIndex = ...,
raw: _bool = ...,
args: Any = ...,
*,
result_type: Literal["reduce"],
**kwargs: Any,
) -> Series[Any]: ...
) -> Series: ...

# apply() overloads with default result_type of None, and keyword axis=1 matters
@overload
def apply(
self,
f: Callable[..., S1 | NAType],
f: Callable[..., S2 | NAType],
raw: _bool = ...,
result_type: None = ...,
args: Any = ...,
*,
axis: AxisColumn,
**kwargs: Any,
) -> Series[S1]: ...
) -> Series[S2]: ...
@overload
def apply(
self,
Expand All @@ -1495,11 +1494,11 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
*,
axis: AxisColumn,
**kwargs: Any,
) -> Series[Any]: ...
) -> Series: ...
@overload
def apply(
self,
f: Callable[..., Series[Any]],
f: Callable[..., Series],
raw: _bool = ...,
result_type: None = ...,
args: Any = ...,
Expand All @@ -1512,7 +1511,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def apply(
self,
f: Callable[..., Series[Any]],
f: Callable[..., Series],
raw: _bool = ...,
args: Any = ...,
*,
Expand All @@ -1537,7 +1536,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
) -> Self: ...
def merge(
self,
right: DataFrame | Series[Any],
right: DataFrame | Series,
how: MergeHow = ...,
on: IndexLabel | AnyArrayLike | None = ...,
left_on: IndexLabel | AnyArrayLike | None = ...,
Expand Down Expand Up @@ -2011,7 +2010,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
| Callable[[DataFrame], DataFrame]
| Callable[[Any], _bool]
),
other: Scalar | UnknownSeries | DataFrame | Callable | NAType | None = ...,
other: Scalar | Series | DataFrame | Callable | NAType | None = ...,
*,
inplace: Literal[True],
axis: Axis | None = ...,
Expand All @@ -2027,7 +2026,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
| Callable[[DataFrame], DataFrame]
| Callable[[Any], _bool]
),
other: Scalar | UnknownSeries | DataFrame | Callable | NAType | None = ...,
other: Scalar | Series | DataFrame | Callable | NAType | None = ...,
*,
inplace: Literal[False] = ...,
axis: Axis | None = ...,
Expand Down
16 changes: 8 additions & 8 deletions pandas-stubs/core/generic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ from pandas import Index
import pandas.core.indexing as indexing
from pandas.core.resample import DatetimeIndexResampler
from pandas.core.series import (
UnknownSeries,
Series,
)
import sqlalchemy.engine
from typing_extensions import (
Expand Down Expand Up @@ -82,7 +82,7 @@ class NDFrame(indexing.IndexingMixin):
def ndim(self) -> int: ...
@property
def size(self) -> int: ...
def equals(self, other: UnknownSeries) -> _bool: ...
def equals(self, other: Series) -> _bool: ...
def __neg__(self) -> Self: ...
def __pos__(self) -> Self: ...
def __nonzero__(self) -> None: ...
Expand Down Expand Up @@ -306,7 +306,7 @@ class NDFrame(indexing.IndexingMixin):
labels: None = ...,
*,
axis: Axis = ...,
index: Hashable | Sequence[Hashable] | Index[Any] = ...,
index: Hashable | Sequence[Hashable] | Index = ...,
columns: Hashable | Iterable[Hashable],
level: Level | None = ...,
inplace: Literal[True],
Expand All @@ -318,7 +318,7 @@ class NDFrame(indexing.IndexingMixin):
labels: None = ...,
*,
axis: Axis = ...,
index: Hashable | Sequence[Hashable] | Index[Any],
index: Hashable | Sequence[Hashable] | Index,
columns: Hashable | Iterable[Hashable] = ...,
level: Level | None = ...,
inplace: Literal[True],
Expand All @@ -327,7 +327,7 @@ class NDFrame(indexing.IndexingMixin):
@overload
def drop(
self,
labels: Hashable | Sequence[Hashable] | Index[Any],
labels: Hashable | Sequence[Hashable] | Index,
*,
axis: Axis = ...,
index: None = ...,
Expand All @@ -342,7 +342,7 @@ class NDFrame(indexing.IndexingMixin):
labels: None = ...,
*,
axis: Axis = ...,
index: Hashable | Sequence[Hashable] | Index[Any] = ...,
index: Hashable | Sequence[Hashable] | Index = ...,
columns: Hashable | Iterable[Hashable],
level: Level | None = ...,
inplace: Literal[False] = ...,
Expand All @@ -354,7 +354,7 @@ class NDFrame(indexing.IndexingMixin):
labels: None = ...,
*,
axis: Axis = ...,
index: Hashable | Sequence[Hashable] | Index[Any],
index: Hashable | Sequence[Hashable] | Index,
columns: Hashable | Iterable[Hashable] = ...,
level: Level | None = ...,
inplace: Literal[False] = ...,
Expand All @@ -363,7 +363,7 @@ class NDFrame(indexing.IndexingMixin):
@overload
def drop(
self,
labels: Hashable | Sequence[Hashable] | Index[Any],
labels: Hashable | Sequence[Hashable] | Index,
*,
axis: Axis = ...,
index: None = ...,
Expand Down
Loading