Skip to content

Commit 412b1ab

Browse files
committed
type check boolean return values
1 parent 385b1bd commit 412b1ab

File tree

1 file changed

+52
-27
lines changed

1 file changed

+52
-27
lines changed

tests/test_string_accessors.py

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -74,37 +74,62 @@ def test_string_accessors_type_preserving_index() -> None:
7474
_check(assert_type(idx.str.zfill(10), "pd.Index[str]"))
7575

7676

77-
def test_string_accessors_type_boolean():
78-
s = pd.Series(
79-
["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
80-
)
81-
check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_)
82-
check(
77+
def test_string_accessors_type_boolean_series():
78+
data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
79+
s = pd.Series(data)
80+
_check = functools.partial(check, klass=pd.Series, dtype=bool)
81+
_check(assert_type(s.str.startswith("a"), "pd.Series[bool]"))
82+
_check(
8383
assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"),
84-
pd.Series,
85-
np.bool_,
8684
)
87-
check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, np.bool_)
88-
check(
85+
_check(
86+
assert_type(s.str.contains("a"), "pd.Series[bool]"),
87+
)
88+
_check(
8989
assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"),
90-
pd.Series,
91-
np.bool_,
9290
)
93-
check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, np.bool_)
94-
check(
95-
assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, np.bool_
91+
_check(assert_type(s.str.endswith("e"), "pd.Series[bool]"))
92+
_check(assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"))
93+
_check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"))
94+
_check(assert_type(s.str.isalnum(), "pd.Series[bool]"))
95+
_check(assert_type(s.str.isalpha(), "pd.Series[bool]"))
96+
_check(assert_type(s.str.isdecimal(), "pd.Series[bool]"))
97+
_check(assert_type(s.str.isdigit(), "pd.Series[bool]"))
98+
_check(assert_type(s.str.isnumeric(), "pd.Series[bool]"))
99+
_check(assert_type(s.str.islower(), "pd.Series[bool]"))
100+
_check(assert_type(s.str.isspace(), "pd.Series[bool]"))
101+
_check(assert_type(s.str.istitle(), "pd.Series[bool]"))
102+
_check(assert_type(s.str.isupper(), "pd.Series[bool]"))
103+
_check(assert_type(s.str.match("pp"), "pd.Series[bool]"))
104+
105+
106+
def test_string_accessors_type_boolean_index():
107+
data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
108+
idx = pd.Index(data)
109+
_check = functools.partial(check, klass=np.ndarray, dtype=np.bool_)
110+
_check(assert_type(idx.str.startswith("a"), "npt.NDArray[np.bool_]"))
111+
_check(
112+
assert_type(idx.str.startswith(("a", "b")), "npt.NDArray[np.bool_]"),
96113
)
97-
check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_)
98-
check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, np.bool_)
99-
check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, np.bool_)
100-
check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, np.bool_)
101-
check(assert_type(s.str.isdigit(), "pd.Series[bool]"), pd.Series, np.bool_)
102-
check(assert_type(s.str.isnumeric(), "pd.Series[bool]"), pd.Series, np.bool_)
103-
check(assert_type(s.str.islower(), "pd.Series[bool]"), pd.Series, np.bool_)
104-
check(assert_type(s.str.isspace(), "pd.Series[bool]"), pd.Series, np.bool_)
105-
check(assert_type(s.str.istitle(), "pd.Series[bool]"), pd.Series, np.bool_)
106-
check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, np.bool_)
107-
check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_)
114+
_check(
115+
assert_type(idx.str.contains("a"), "npt.NDArray[np.bool_]"),
116+
)
117+
_check(
118+
assert_type(idx.str.contains(re.compile(r"a")), "npt.NDArray[np.bool_]"),
119+
)
120+
_check(assert_type(idx.str.endswith("e"), "npt.NDArray[np.bool_]"))
121+
_check(assert_type(idx.str.endswith(("e", "f")), "npt.NDArray[np.bool_]"))
122+
_check(assert_type(idx.str.fullmatch("apple"), "npt.NDArray[np.bool_]"))
123+
_check(assert_type(idx.str.isalnum(), "npt.NDArray[np.bool_]"))
124+
_check(assert_type(idx.str.isalpha(), "npt.NDArray[np.bool_]"))
125+
_check(assert_type(idx.str.isdecimal(), "npt.NDArray[np.bool_]"))
126+
_check(assert_type(idx.str.isdigit(), "npt.NDArray[np.bool_]"))
127+
_check(assert_type(idx.str.isnumeric(), "npt.NDArray[np.bool_]"))
128+
_check(assert_type(idx.str.islower(), "npt.NDArray[np.bool_]"))
129+
_check(assert_type(idx.str.isspace(), "npt.NDArray[np.bool_]"))
130+
_check(assert_type(idx.str.istitle(), "npt.NDArray[np.bool_]"))
131+
_check(assert_type(idx.str.isupper(), "npt.NDArray[np.bool_]"))
132+
_check(assert_type(idx.str.match("pp"), "npt.NDArray[np.bool_]"))
108133

109134

110135
def test_string_accessors_type_integer():
@@ -125,7 +150,7 @@ def test_string_accessors_encode_decode():
125150
s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]])
126151
check(
127152
assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]"),
128-
"pd.Series[str]",
153+
pd.Series,
129154
str,
130155
)
131156
check(

0 commit comments

Comments
 (0)