Skip to content

Commit cfda1e8

Browse files
Do not trigger B901 with explicit Generator return type (#481)
1 parent b15feed commit cfda1e8

File tree

3 files changed

+69
-8
lines changed

3 files changed

+69
-8
lines changed

bugbear.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,10 +1187,27 @@ def _loop(parent, node):
11871187
for child in node.body:
11881188
yield from _loop(node, child)
11891189

1190-
def check_for_b901(self, node):
1190+
def check_for_b901(self, node: ast.FunctionDef) -> None:
11911191
if node.name == "__await__":
11921192
return
11931193

1194+
# If the user explicitly wrote the 3-argument version of Generator as the
1195+
# return annotation, they probably know what they were doing.
1196+
if (
1197+
node.returns is not None
1198+
and isinstance(node.returns, ast.Subscript)
1199+
and (
1200+
is_name(node.returns.value, "Generator")
1201+
or is_name(node.returns.value, "typing.Generator")
1202+
or is_name(node.returns.value, "collections.abc.Generator")
1203+
)
1204+
):
1205+
slice = node.returns.slice
1206+
if sys.version_info < (3, 9) and isinstance(slice, ast.Index):
1207+
slice = slice.value
1208+
if isinstance(slice, ast.Tuple) and len(slice.elts) == 3:
1209+
return
1210+
11941211
has_yield = False
11951212
return_node = None
11961213

@@ -1204,9 +1221,8 @@ def check_for_b901(self, node):
12041221
if isinstance(x, ast.Return) and x.value is not None:
12051222
return_node = x
12061223

1207-
if has_yield and return_node is not None:
1208-
self.errors.append(B901(return_node.lineno, return_node.col_offset))
1209-
break
1224+
if has_yield and return_node is not None:
1225+
self.errors.append(B901(return_node.lineno, return_node.col_offset))
12101226

12111227
# taken from pep8-naming
12121228
@classmethod
@@ -1703,6 +1719,16 @@ def compose_call_path(node):
17031719
yield node.id
17041720

17051721

1722+
def is_name(node: ast.expr, name: str) -> bool:
1723+
if "." not in name:
1724+
return isinstance(node, ast.Name) and node.id == name
1725+
else:
1726+
if not isinstance(node, ast.Attribute):
1727+
return False
1728+
rest, attr = name.rsplit(".", maxsplit=1)
1729+
return node.attr == attr and is_name(node.value, rest)
1730+
1731+
17061732
def _transform_slice_to_py39(slice: ast.expr | ast.Slice) -> ast.Slice | ast.expr:
17071733
"""Transform a py38 style slice to a py39 style slice.
17081734

tests/b901.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""
22
Should emit:
3-
B901 - on lines 9, 36
3+
B901
44
"""
55

66
def broken():
77
if True:
8-
return [1, 2, 3]
8+
return [1, 2, 3] # B901
99

1010
yield 3
1111
yield 2
@@ -32,7 +32,7 @@ def not_broken3():
3232

3333

3434
def broken2():
35-
return [3, 2, 1]
35+
return [3, 2, 1] # B901
3636

3737
yield from not_broken()
3838

@@ -75,3 +75,35 @@ class NotBroken9(object):
7575
def __await__(self):
7676
yield from function()
7777
return 42
78+
79+
80+
def broken3():
81+
if True:
82+
return [1, 2, 3] # B901
83+
else:
84+
yield 3
85+
86+
87+
def broken4() -> Iterable[str]:
88+
yield "x"
89+
return ["x"] # B901
90+
91+
92+
def broken5() -> Generator[str]:
93+
yield "x"
94+
return ["x"] # B901
95+
96+
97+
def not_broken10() -> Generator[str, int, float]:
98+
yield "x"
99+
return 1.0
100+
101+
102+
def not_broken11() -> typing.Generator[str, int, float]:
103+
yield "x"
104+
return 1.0
105+
106+
107+
def not_broken12() -> collections.abc.Generator[str, int, float]:
108+
yield "x"
109+
return 1.0

tests/test_bugbear.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,10 @@ def test_b901(self):
792792
filename = Path(__file__).absolute().parent / "b901.py"
793793
bbc = BugBearChecker(filename=str(filename))
794794
errors = list(bbc.run())
795-
self.assertEqual(errors, self.errors(B901(8, 8), B901(35, 4)))
795+
self.assertEqual(
796+
errors,
797+
self.errors(B901(8, 8), B901(35, 4), B901(82, 8), B901(89, 4), B901(94, 4)),
798+
)
796799

797800
def test_b902(self):
798801
filename = Path(__file__).absolute().parent / "b902.py"

0 commit comments

Comments
 (0)