Skip to content

Commit d569ccc

Browse files
authored
stubgen: Add support for yield statements (#10745)
1 parent beba94c commit d569ccc

File tree

3 files changed

+133
-11
lines changed

3 files changed

+133
-11
lines changed

mypy/stubgen.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
from mypy.find_sources import create_source_list, InvalidSourceList
9292
from mypy.build import build
9393
from mypy.errors import CompileError, Errors
94-
from mypy.traverser import has_return_statement
94+
from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression
9595
from mypy.moduleinspect import ModuleInspect
9696

9797

@@ -550,13 +550,17 @@ def visit_mypy_file(self, o: MypyFile) -> None:
550550
self.path = o.path
551551
self.defined_names = find_defined_names(o)
552552
self.referenced_names = find_referenced_names(o)
553-
typing_imports = ["Any", "Optional", "TypeVar"]
554-
for t in typing_imports:
555-
if t not in self.defined_names:
556-
alias = None
557-
else:
558-
alias = '_' + t
559-
self.import_tracker.add_import_from("typing", [(t, alias)])
553+
known_imports = {
554+
"typing": ["Any", "TypeVar"],
555+
"collections.abc": ["Generator"],
556+
}
557+
for pkg, imports in known_imports.items():
558+
for t in imports:
559+
if t not in self.defined_names:
560+
alias = None
561+
else:
562+
alias = '_' + t
563+
self.import_tracker.add_import_from(pkg, [(t, alias)])
560564
super().visit_mypy_file(o)
561565
undefined_names = [name for name in self._all_ or []
562566
if name not in self._toplevel_names]
@@ -661,6 +665,23 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
661665
# Always assume abstract methods return Any unless explicitly annotated. Also
662666
# some dunder methods should not have a None return type.
663667
retname = None # implicit Any
668+
elif has_yield_expression(o):
669+
self.add_abc_import('Generator')
670+
yield_name = 'None'
671+
send_name = 'None'
672+
return_name = 'None'
673+
for expr, in_assignment in all_yield_expressions(o):
674+
if expr.expr is not None and not self.is_none_expr(expr.expr):
675+
self.add_typing_import('Any')
676+
yield_name = 'Any'
677+
if in_assignment:
678+
self.add_typing_import('Any')
679+
send_name = 'Any'
680+
if has_return_statement(o):
681+
self.add_typing_import('Any')
682+
return_name = 'Any'
683+
generator_name = self.typing_name('Generator')
684+
retname = f'{generator_name}[{yield_name}, {send_name}, {return_name}]'
664685
elif not has_return_statement(o) and not is_abstract:
665686
retname = 'None'
666687
retfield = ''
@@ -671,6 +692,9 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
671692
self.add("){}: ...\n".format(retfield))
672693
self._state = FUNC
673694

695+
def is_none_expr(self, expr: Expression) -> bool:
696+
return isinstance(expr, NameExpr) and expr.name == "None"
697+
674698
def visit_decorator(self, o: Decorator) -> None:
675699
if self.is_private_name(o.func.name, o.func.fullname):
676700
return
@@ -1106,6 +1130,14 @@ def add_typing_import(self, name: str) -> None:
11061130
name = self.typing_name(name)
11071131
self.import_tracker.require_name(name)
11081132

1133+
def add_abc_import(self, name: str) -> None:
1134+
"""Add a name to be imported from collections.abc, unless it's imported already.
1135+
1136+
The import will be internal to the stub.
1137+
"""
1138+
name = self.typing_name(name)
1139+
self.import_tracker.require_name(name)
1140+
11091141
def add_import_line(self, line: str) -> None:
11101142
"""Add a line of text to the import section, unless it's already there."""
11111143
if line not in self._import_lines:

mypy/traverser.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Generic node traverser visitor"""
22

3-
from typing import List
3+
from typing import List, Tuple
44
from mypy_extensions import mypyc_attr
55

66
from mypy.visitor import NodeVisitor
@@ -319,9 +319,22 @@ def has_return_statement(fdef: FuncBase) -> bool:
319319
return seeker.found
320320

321321

322-
class ReturnCollector(TraverserVisitor):
322+
class YieldSeeker(TraverserVisitor):
323+
def __init__(self) -> None:
324+
self.found = False
325+
326+
def visit_yield_expr(self, o: YieldExpr) -> None:
327+
self.found = True
328+
329+
330+
def has_yield_expression(fdef: FuncBase) -> bool:
331+
seeker = YieldSeeker()
332+
fdef.accept(seeker)
333+
return seeker.found
334+
335+
336+
class FuncCollectorBase(TraverserVisitor):
323337
def __init__(self) -> None:
324-
self.return_statements: List[ReturnStmt] = []
325338
self.inside_func = False
326339

327340
def visit_func_def(self, defn: FuncDef) -> None:
@@ -330,6 +343,12 @@ def visit_func_def(self, defn: FuncDef) -> None:
330343
super().visit_func_def(defn)
331344
self.inside_func = False
332345

346+
347+
class ReturnCollector(FuncCollectorBase):
348+
def __init__(self) -> None:
349+
super().__init__()
350+
self.return_statements: List[ReturnStmt] = []
351+
333352
def visit_return_stmt(self, stmt: ReturnStmt) -> None:
334353
self.return_statements.append(stmt)
335354

@@ -338,3 +357,24 @@ def all_return_statements(node: Node) -> List[ReturnStmt]:
338357
v = ReturnCollector()
339358
node.accept(v)
340359
return v.return_statements
360+
361+
362+
class YieldCollector(FuncCollectorBase):
363+
def __init__(self) -> None:
364+
super().__init__()
365+
self.in_assignment = False
366+
self.yield_expressions: List[Tuple[YieldExpr, bool]] = []
367+
368+
def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
369+
self.in_assignment = True
370+
super().visit_assignment_stmt(stmt)
371+
self.in_assignment = False
372+
373+
def visit_yield_expr(self, expr: YieldExpr) -> None:
374+
self.yield_expressions.append((expr, self.in_assignment))
375+
376+
377+
def all_yield_expressions(node: Node) -> List[Tuple[YieldExpr, bool]]:
378+
v = YieldCollector()
379+
node.accept(v)
380+
return v.yield_expressions

test-data/unit/stubgen.test

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,56 @@ def f(): ...
944944
[out]
945945
def f() -> None: ...
946946

947+
[case testFunctionYields]
948+
def f():
949+
yield 123
950+
def g():
951+
x = yield
952+
def h1():
953+
yield
954+
return
955+
def h2():
956+
yield
957+
return "abc"
958+
def all():
959+
x = yield 123
960+
return "abc"
961+
[out]
962+
from collections.abc import Generator
963+
from typing import Any
964+
965+
def f() -> Generator[Any, None, None]: ...
966+
def g() -> Generator[None, Any, None]: ...
967+
def h1() -> Generator[None, None, None]: ...
968+
def h2() -> Generator[None, None, Any]: ...
969+
def all() -> Generator[Any, Any, Any]: ...
970+
971+
[case testFunctionYieldsNone]
972+
def f():
973+
yield
974+
def g():
975+
yield None
976+
977+
[out]
978+
from collections.abc import Generator
979+
980+
def f() -> Generator[None, None, None]: ...
981+
def g() -> Generator[None, None, None]: ...
982+
983+
[case testGeneratorAlreadyDefined]
984+
class Generator:
985+
pass
986+
987+
def f():
988+
yield 123
989+
[out]
990+
from collections.abc import Generator as _Generator
991+
from typing import Any
992+
993+
class Generator: ...
994+
995+
def f() -> _Generator[Any, None, None]: ...
996+
947997
[case testCallable]
948998
from typing import Callable
949999

0 commit comments

Comments
 (0)