Skip to content

Commit adad411

Browse files
committed
replace bool with bool_t in pandas/core/generic
1 parent b835ca2 commit adad411

File tree

4 files changed

+152
-14
lines changed

4 files changed

+152
-14
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ repos:
193193
language: python
194194
types: [rst]
195195
files: ^doc/source/(development|reference)/
196+
- id: no-bool-in-core-generic
197+
name: Use bool_t instead of bool in pandas/core/generic.py
198+
entry: python scripts/no_bool_in_generic.py
199+
language: python
200+
files: ^pandas/core/generic\.py$
196201
- repo: https://github.com/asottile/yesqa
197202
rev: v1.2.2
198203
hooks:

pandas/core/generic.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ class NDFrame(PandasObject, SelectionMixin, indexing.IndexingMixin):
234234
def __init__(
235235
self,
236236
data: Manager,
237-
copy: bool = False,
237+
copy: bool_t = False,
238238
attrs: Optional[Mapping[Optional[Hashable], Any]] = None,
239239
):
240240
# copy kwarg is retained for mypy compat, is not used
@@ -251,7 +251,7 @@ def __init__(
251251

252252
@classmethod
253253
def _init_mgr(
254-
cls, mgr, axes, dtype: Optional[Dtype] = None, copy: bool = False
254+
cls, mgr, axes, dtype: Optional[Dtype] = None, copy: bool_t = False
255255
) -> Manager:
256256
""" passed a manager and a axes dict """
257257
for a, axe in axes.items():
@@ -344,8 +344,8 @@ def flags(self) -> Flags:
344344
def set_flags(
345345
self: FrameOrSeries,
346346
*,
347-
copy: bool = False,
348-
allows_duplicate_labels: Optional[bool] = None,
347+
copy: bool_t = False,
348+
allows_duplicate_labels: Optional[bool_t] = None,
349349
) -> FrameOrSeries:
350350
"""
351351
Return a new object with updated flags.
@@ -434,7 +434,7 @@ def _data(self):
434434
_stat_axis_name = "index"
435435
_AXIS_ORDERS: List[str]
436436
_AXIS_TO_AXIS_NUMBER: Dict[Axis, int] = {0: 0, "index": 0, "rows": 0}
437-
_AXIS_REVERSED: bool
437+
_AXIS_REVERSED: bool_t
438438
_info_axis_number: int
439439
_info_axis_name: str
440440
_AXIS_LEN: int
@@ -461,7 +461,7 @@ def _construct_axes_dict(self, axes=None, **kwargs):
461461
@final
462462
@classmethod
463463
def _construct_axes_from_arguments(
464-
cls, args, kwargs, require_all: bool = False, sentinel=None
464+
cls, args, kwargs, require_all: bool_t = False, sentinel=None
465465
):
466466
"""
467467
Construct and returns axes if supplied in args/kwargs.
@@ -662,7 +662,7 @@ def _obj_with_exclusions(self: FrameOrSeries) -> FrameOrSeries:
662662
""" internal compat with SelectionMixin """
663663
return self
664664

665-
def set_axis(self, labels, axis: Axis = 0, inplace: bool = False):
665+
def set_axis(self, labels, axis: Axis = 0, inplace: bool_t = False):
666666
"""
667667
Assign desired index to given axis.
668668
@@ -693,7 +693,7 @@ def set_axis(self, labels, axis: Axis = 0, inplace: bool = False):
693693
return self._set_axis_nocheck(labels, axis, inplace)
694694

695695
@final
696-
def _set_axis_nocheck(self, labels, axis: Axis, inplace: bool):
696+
def _set_axis_nocheck(self, labels, axis: Axis, inplace: bool_t):
697697
# NDFrame.rename with inplace=False calls set_axis(inplace=True) on a copy.
698698
if inplace:
699699
setattr(self, self._get_axis_name(axis), labels)
@@ -932,8 +932,8 @@ def rename(
932932
index: Optional[Renamer] = None,
933933
columns: Optional[Renamer] = None,
934934
axis: Optional[Axis] = None,
935-
copy: bool = True,
936-
inplace: bool = False,
935+
copy: bool_t = True,
936+
inplace: bool_t = False,
937937
level: Optional[Level] = None,
938938
errors: str = "ignore",
939939
) -> Optional[FrameOrSeries]:
@@ -1339,13 +1339,13 @@ def _set_axis_name(self, name, axis=0, inplace=False):
13391339
# Comparison Methods
13401340

13411341
@final
1342-
def _indexed_same(self, other) -> bool:
1342+
def _indexed_same(self, other) -> bool_t:
13431343
return all(
13441344
self._get_axis(a).equals(other._get_axis(a)) for a in self._AXIS_ORDERS
13451345
)
13461346

13471347
@final
1348-
def equals(self, other: object) -> bool:
1348+
def equals(self, other: object) -> bool_t:
13491349
"""
13501350
Test whether two objects contain the same elements.
13511351
@@ -5006,15 +5006,15 @@ def filter(
50065006
return self.reindex(**{name: [r for r in items if r in labels]})
50075007
elif like:
50085008

5009-
def f(x) -> bool:
5009+
def f(x) -> bool_t:
50105010
assert like is not None # needed for mypy
50115011
return like in ensure_str(x)
50125012

50135013
values = labels.map(f)
50145014
return self.loc(axis=axis)[values]
50155015
elif regex:
50165016

5017-
def f(x) -> bool:
5017+
def f(x) -> bool_t:
50185018
return matcher.search(ensure_str(x)) is not None
50195019

50205020
matcher = re.compile(regex)

scripts/no_bool_in_generic.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
Check that pandas/core/generic.py doesn't use bool as a type annotation.
3+
4+
There is already the method `bool`, so the alias `bool_t` should be used instead.
5+
6+
This is meant to be run as a pre-commit hook - to run it manually, you can do:
7+
8+
pre-commit run no-bool-in-core-generic --all-files
9+
10+
To automatically fixup pandas/core/generic.py, you can pass `--replace`,
11+
though note that you will also need the additional dependency `tokenize-rt`
12+
(which is left out from the pre-commit hook so that it uses the same virtualenv
13+
as the other local ones).
14+
"""
15+
16+
import argparse
17+
import ast
18+
from typing import (
19+
Optional,
20+
Sequence,
21+
Set,
22+
Tuple,
23+
)
24+
25+
ERROR_MESSAGE = (
26+
"Found annotation 'bool' at line {line}, column {col} - use 'bool_t' instead"
27+
)
28+
Offset = Tuple[int, int]
29+
30+
31+
class Visitor(ast.NodeVisitor):
    """
    Collect the source offsets of every ``bool`` used inside a type annotation.

    After ``visit`` has run on a parsed module, ``to_replace`` holds the
    ``(lineno, col_offset)`` pair of each offending ``bool`` name.
    """

    def __init__(self) -> None:
        self.to_replace: Set[Offset] = set()

    def _record_bools(self, annotation) -> None:
        """Record each ``bool`` name found anywhere inside ``annotation``."""
        if not isinstance(annotation, ast.AST):
            # Unannotated arguments and functions carry ``None`` here.
            return
        # Walk the whole expression so that nested uses such as
        # ``Optional[bool]`` are found, not only a bare ``bool``.
        for child in ast.walk(annotation):
            if isinstance(child, ast.Name) and child.id == "bool":
                self.to_replace.add((child.lineno, child.col_offset))

    def generic_visit(self, node):
        """Inspect ``node``'s annotation fields, then recurse into its children."""
        # ``annotation`` is carried by arguments and annotated assignments,
        # ``returns`` by function definitions.
        self._record_bools(getattr(node, "annotation", None))
        self._record_bools(getattr(node, "returns", None))
        super().generic_visit(node)
47+
48+
49+
def replace_bool_with_bool_t(visitor: Visitor, content: str) -> str:
    """
    Return ``content`` with the flagged ``bool`` annotations renamed.

    Every token whose offset appears in ``visitor.to_replace`` has its source
    text swapped for ``bool_t``; all other tokens are emitted untouched.
    """
    # Imported lazily: tokenize-rt is only needed when ``--replace`` is used.
    from tokenize_rt import (
        src_to_tokens,
        tokens_to_src,
    )

    flagged = visitor.to_replace
    rewritten = [
        token._replace(src="bool_t") if token.offset in flagged else token
        for token in src_to_tokens(content)
    ]

    fixed_src: str = tokens_to_src(rewritten)
    return fixed_src
63+
64+
65+
def check_for_bool_in_generic(content: str, *, replace: bool) -> Optional[str]:
    """
    Check ``content`` for ``bool`` annotations, optionally rewriting them.

    Parameters
    ----------
    content : str
        Python source code to inspect.
    replace : bool
        If True, return the source with the offending annotations rewritten
        to ``bool_t``; if False, raise on the first offending annotation.

    Returns
    -------
    str
        ``content`` unchanged when nothing is flagged, otherwise the
        rewritten source (only reachable with ``replace=True``).

    Raises
    ------
    RuntimeError
        If ``replace`` is False and a ``bool`` annotation is found.
    """
    tree = ast.parse(content)

    visitor = Visitor()
    visitor.visit(tree)

    if not visitor.to_replace:
        # Nothing to replace.
        return content

    if not replace:
        # Report the earliest occurrence; ``set.pop`` would pick an arbitrary
        # one, making the message depend on set iteration order.
        line, col = min(visitor.to_replace)
        msg = ERROR_MESSAGE.format(line=line, col=col)
        raise RuntimeError(msg)

    return replace_bool_with_bool_t(visitor, content)
81+
82+
83+
def main(argv: Optional[Sequence[str]] = None) -> None:
    """
    Command-line entry point: check, or with ``--replace`` fix up, each path.

    Parameters
    ----------
    argv : sequence of str, optional
        Command-line arguments, handed to ``ArgumentParser.parse_args``
        (which reads ``sys.argv`` when given None).

    Raises
    ------
    RuntimeError
        Propagated from ``check_for_bool_in_generic`` when ``--replace`` is
        not given and a file contains a ``bool`` annotation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("paths", nargs="*")
    parser.add_argument("--replace", action="store_true")
    args = parser.parse_args(argv)

    for path in args.paths:
        with open(path, encoding="utf-8") as fd:
            content = fd.read()
        new_content = check_for_bool_in_generic(content, replace=args.replace)
        if not args.replace or new_content is None:
            continue
        if new_content == content:
            # Nothing was rewritten, so leave the file on disk untouched.
            continue
        with open(path, "w", encoding="utf-8") as fd:
            fd.write(new_content)
97+
98+
99+
if __name__ == "__main__":
100+
main()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
3+
from scripts.no_bool_in_generic import check_for_bool_in_generic
4+
5+
BAD_FILE = "def foo(a: bool) -> bool:\n ..."
6+
GOOD_FILE = "def foo(a: bool_t) -> bool_t:\n ..."
7+
8+
9+
def test_bad_file():
    """Without ``replace``, a ``bool`` annotation raises and reports its location."""
    expected_msg = (
        r"Found annotation 'bool' at line 1, column 11 - use 'bool_t' instead"
    )
    with pytest.raises(RuntimeError, match=expected_msg):
        check_for_bool_in_generic(BAD_FILE, replace=False)
14+
15+
16+
def test_good_file():
    """A file that already uses ``bool_t`` passes the check without raising."""
    check_for_bool_in_generic(GOOD_FILE, replace=False)
20+
21+
22+
def test_bad_file_with_replace():
    """With ``replace=True`` the ``bool`` annotations are rewritten to ``bool_t``."""
    fixed = check_for_bool_in_generic(BAD_FILE, replace=True)
    assert fixed == GOOD_FILE
27+
28+
29+
def test_good_file_with_replace():
    """With ``replace=True`` an already-clean file comes back unchanged."""
    untouched = check_for_bool_in_generic(GOOD_FILE, replace=True)
    assert untouched == GOOD_FILE

0 commit comments

Comments
 (0)