Skip to content

Commit 45b8da0

Browse files
committed
split out into separate file
1 parent 231b54d commit 45b8da0

File tree

5 files changed

+191
-33
lines changed

5 files changed

+191
-33
lines changed

pandas-stubs/core/indexes/base.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,15 @@ class Index(IndexOpsMixin[S1]):
264264
@property
265265
def str(
266266
self,
267-
) -> StringMethods[Self, MultiIndex, np_ndarray_bool, Index[list[str]]]: ...
267+
) -> StringMethods[
268+
Self,
269+
MultiIndex,
270+
np_ndarray_bool,
271+
Index[list[str]],
272+
Index[int],
273+
Index[bytes],
274+
Index[str],
275+
]: ...
268276
def is_(self, other) -> bool: ...
269277
def __len__(self) -> int: ...
270278
def __array__(self, dtype=...) -> np.ndarray: ...

pandas-stubs/core/series.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1156,7 +1156,15 @@ class Series(IndexOpsMixin[S1], NDFrame):
11561156
@property
11571157
def str(
11581158
self,
1159-
) -> StringMethods[Self, DataFrame, Series[bool], Series[list[str]]]: ...
1159+
) -> StringMethods[
1160+
Self,
1161+
DataFrame,
1162+
Series[bool],
1163+
Series[list[str]],
1164+
Series[int],
1165+
Series[bytes],
1166+
Series[str],
1167+
]: ...
11601168
@property
11611169
def dt(self) -> CombinedDatetimelikeProperties: ...
11621170
@property

pandas-stubs/core/strings.pyi

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@ _TS = TypeVar("_TS", bound=DataFrame | MultiIndex)
3737
_TS2 = TypeVar("_TS2", bound=Series[list[str]] | Index[list[str]])
3838
# The _TM type is what is used for the result of str.match
3939
_TM = TypeVar("_TM", bound=Series[bool] | np_ndarray_bool)
40+
# The _TI type is what is used for the result of str.index / str.find
41+
_TI = TypeVar("_TI", bound=Series[int] | Index[int])
42+
# The _TE type is what is used for the result of str.encode
43+
_TE = TypeVar("_TE", bound=Series[bytes] | Index[bytes])
44+
# The _TD type is what is used for the result of str.encode
45+
_TD = TypeVar("_TD", bound=Series[str] | Index[str])
4046

41-
class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
47+
class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _TD]):
4248
def __init__(self, data: T) -> None: ...
4349
def __getitem__(self, key: slice | int) -> T: ...
4450
def __iter__(self) -> T: ...
@@ -113,15 +119,15 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
113119
@overload
114120
def rpartition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ...
115121
def get(self, i: int) -> T: ...
116-
def join(self, sep: str) -> T: ...
122+
def join(self, sep: str) -> _TD: ...
117123
def contains(
118124
self,
119125
pat: str | re.Pattern[str],
120126
case: bool = ...,
121127
flags: int = ...,
122128
na: Scalar | NaTType | None = ...,
123129
regex: bool = ...,
124-
) -> Series[bool]: ...
130+
) -> _TM: ...
125131
def match(
126132
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
127133
) -> _TM: ...
@@ -151,8 +157,8 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
151157
def slice_replace(
152158
self, start: int | None = ..., stop: int | None = ..., repl: str | None = ...
153159
) -> T: ...
154-
def decode(self, encoding: str, errors: str = ...) -> Series[str]: ...
155-
def encode(self, encoding: str, errors: str = ...) -> Series[bytes]: ...
160+
def decode(self, encoding: str, errors: str = ...) -> _TD: ...
161+
def encode(self, encoding: str, errors: str = ...) -> _TE: ...
156162
def strip(self, to_strip: str | None = ...) -> T: ...
157163
def lstrip(self, to_strip: str | None = ...) -> T: ...
158164
def rstrip(self, to_strip: str | None = ...) -> T: ...
@@ -167,9 +173,9 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
167173
) -> T: ...
168174
def get_dummies(self, sep: str = ...) -> pd.DataFrame: ...
169175
def translate(self, table: dict[int, int | str | None] | None) -> T: ...
170-
def count(self, pat: str, flags: int = ...) -> Series[int]: ...
171-
def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ...
172-
def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ...
176+
def count(self, pat: str, flags: int = ...) -> _TI: ...
177+
def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _TM: ...
178+
def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _TM: ...
173179
def findall(self, pat: str, flags: int = ...) -> _TS2: ...
174180
@overload
175181
def extract(
@@ -184,37 +190,29 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
184190
self, pat: str, flags: int = ..., *, expand: Literal[False]
185191
) -> Series[type[object]]: ...
186192
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
187-
def find(
188-
self, sub: str, start: int = ..., end: int | None = ...
189-
) -> Series[int]: ...
190-
def rfind(
191-
self, sub: str, start: int = ..., end: int | None = ...
192-
) -> Series[int]: ...
193+
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ...
194+
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ...
193195
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ...
194-
def index(
195-
self, sub: str, start: int = ..., end: int | None = ...
196-
) -> Series[int]: ...
197-
def rindex(
198-
self, sub: str, start: int = ..., end: int | None = ...
199-
) -> Series[int]: ...
200-
def len(self) -> Series[int]: ...
196+
def index(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ...
197+
def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ...
198+
def len(self) -> _TI: ...
201199
def lower(self) -> T: ...
202200
def upper(self) -> T: ...
203201
def title(self) -> T: ...
204202
def capitalize(self) -> T: ...
205203
def swapcase(self) -> T: ...
206204
def casefold(self) -> T: ...
207-
def isalnum(self) -> Series[bool]: ...
208-
def isalpha(self) -> Series[bool]: ...
209-
def isdigit(self) -> Series[bool]: ...
210-
def isspace(self) -> Series[bool]: ...
211-
def islower(self) -> Series[bool]: ...
212-
def isupper(self) -> Series[bool]: ...
213-
def istitle(self) -> Series[bool]: ...
214-
def isnumeric(self) -> Series[bool]: ...
215-
def isdecimal(self) -> Series[bool]: ...
205+
def isalnum(self) -> _TM: ...
206+
def isalpha(self) -> _TM: ...
207+
def isdigit(self) -> _TM: ...
208+
def isspace(self) -> _TM: ...
209+
def islower(self) -> _TM: ...
210+
def isupper(self) -> _TM: ...
211+
def istitle(self) -> _TM: ...
212+
def isnumeric(self) -> _TM: ...
213+
def isdecimal(self) -> _TM: ...
216214
def fullmatch(
217215
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
218-
) -> Series[bool]: ...
216+
) -> _TM: ...
219217
def removeprefix(self, prefix: str) -> T: ...
220218
def removesuffix(self, suffix: str) -> T: ...

test

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
test
2+
ind abc

tests/test_string_accessors.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import functools
2+
import re
3+
from typing import Any
4+
5+
import numpy as np
6+
import pandas as pd
7+
import pytest
8+
from typing_extensions import assert_type
9+
10+
from tests import check
11+
12+
13+
@pytest.mark.parametrize("constructor", ["series", "index"])
14+
@pytest.mark.parametrize(
15+
("method", "kwargs"),
16+
[
17+
("capitalize", {}),
18+
],
19+
)
20+
def test_string_accessors_type_preserving_series(
21+
constructor: Any, method: str, kwargs: Any
22+
) -> None:
23+
data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
24+
s = pd.Series(data)
25+
_check = functools.partial(check, klass=pd.Series, dtype=str)
26+
_check(assert_type(s.str.capitalize(), "pd.Series[str]"))
27+
_check(assert_type(s.str.casefold(), "pd.Series[str]"))
28+
check(assert_type(s.str.cat(sep="X"), str), str)
29+
_check(assert_type(s.str.center(10), "pd.Series[str]"))
30+
_check(assert_type(s.str.get(2), "pd.Series[str]"))
31+
_check(assert_type(s.str.ljust(80), "pd.Series[str]"))
32+
_check(assert_type(s.str.lower(), "pd.Series[str]"))
33+
_check(assert_type(s.str.lstrip("a"), "pd.Series[str]"))
34+
_check(assert_type(s.str.normalize("NFD"), "pd.Series[str]"))
35+
_check(assert_type(s.str.pad(80, "right"), "pd.Series[str]"))
36+
_check(assert_type(s.str.removeprefix("a"), "pd.Series[str]"))
37+
_check(assert_type(s.str.removesuffix("e"), "pd.Series[str]"))
38+
_check(assert_type(s.str.repeat(2), "pd.Series[str]"))
39+
_check(assert_type(s.str.replace("a", "X"), "pd.Series[str]"))
40+
_check(assert_type(s.str.rjust(80), "pd.Series[str]"))
41+
_check(assert_type(s.str.rstrip(), "pd.Series[str]"))
42+
_check(assert_type(s.str.slice(0, 4, 2), "pd.Series[str]"))
43+
_check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]"))
44+
_check(assert_type(s.str.strip(), "pd.Series[str]"))
45+
_check(assert_type(s.str.swapcase(), "pd.Series[str]"))
46+
_check(assert_type(s.str.title(), "pd.Series[str]"))
47+
_check(
48+
assert_type(s.str.translate({241: "n"}), "pd.Series[str]"),
49+
)
50+
_check(assert_type(s.str.upper(), "pd.Series[str]"))
51+
_check(assert_type(s.str.wrap(80), "pd.Series[str]"))
52+
_check(assert_type(s.str.zfill(10), "pd.Series[str]"))
53+
54+
55+
def test_string_accessors_type_boolean():
56+
s = pd.Series(
57+
["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
58+
)
59+
check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_)
60+
check(
61+
assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"),
62+
pd.Series,
63+
np.bool_,
64+
)
65+
check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, np.bool_)
66+
check(
67+
assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"),
68+
pd.Series,
69+
np.bool_,
70+
)
71+
check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, np.bool_)
72+
check(
73+
assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, np.bool_
74+
)
75+
check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_)
76+
check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, np.bool_)
77+
check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, np.bool_)
78+
check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, np.bool_)
79+
check(assert_type(s.str.isdigit(), "pd.Series[bool]"), pd.Series, np.bool_)
80+
check(assert_type(s.str.isnumeric(), "pd.Series[bool]"), pd.Series, np.bool_)
81+
check(assert_type(s.str.islower(), "pd.Series[bool]"), pd.Series, np.bool_)
82+
check(assert_type(s.str.isspace(), "pd.Series[bool]"), pd.Series, np.bool_)
83+
check(assert_type(s.str.istitle(), "pd.Series[bool]"), pd.Series, np.bool_)
84+
check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, np.bool_)
85+
check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_)
86+
87+
88+
def test_string_accessors_type_integer():
89+
s = pd.Series(
90+
["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
91+
)
92+
check(assert_type(s.str.find("p"), "pd.Series[int]"), pd.Series, np.int64)
93+
check(assert_type(s.str.index("p"), "pd.Series[int]"), pd.Series, np.int64)
94+
check(assert_type(s.str.rfind("e"), "pd.Series[int]"), pd.Series, np.int64)
95+
check(assert_type(s.str.rindex("p"), "pd.Series[int]"), pd.Series, np.int64)
96+
check(assert_type(s.str.count("pp"), "pd.Series[int]"), pd.Series, np.integer)
97+
check(assert_type(s.str.len(), "pd.Series[int]"), pd.Series, np.integer)
98+
99+
100+
def test_string_accessors_encode_decode():
101+
s_str = pd.Series(["a1", "b2", "c3"])
102+
s_bytes = pd.Series([b"a1", b"b2", b"c3"])
103+
s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]])
104+
check(
105+
assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]"),
106+
"pd.Series[str]",
107+
str,
108+
)
109+
check(
110+
assert_type(s_str.str.encode("latin-1"), "pd.Series[bytes]"), pd.Series, bytes
111+
)
112+
check(assert_type(s2.str.join("-"), "pd.Series[str]"), pd.Series, str)
113+
114+
115+
def test_string_accessors_list():
116+
s = pd.Series(
117+
["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
118+
)
119+
check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"), pd.Series, list)
120+
check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list)
121+
# GH 194
122+
check(
123+
assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"),
124+
pd.Series,
125+
list,
126+
)
127+
check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list)
128+
check(
129+
assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"),
130+
pd.Series,
131+
list,
132+
)
133+
134+
135+
# def test_string_accessors_expanding():
136+
# check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
137+
# check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
138+
# check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame)
139+
# check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame)
140+
# check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame)
141+
# check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame)
142+
# check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame)

0 commit comments

Comments
 (0)