Skip to content

stubgen: Add support for yield statements #10745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
from mypy.find_sources import create_source_list, InvalidSourceList
from mypy.build import build
from mypy.errors import CompileError, Errors
from mypy.traverser import has_return_statement
from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression
from mypy.moduleinspect import ModuleInspect


Expand Down Expand Up @@ -550,13 +550,17 @@ def visit_mypy_file(self, o: MypyFile) -> None:
self.path = o.path
self.defined_names = find_defined_names(o)
self.referenced_names = find_referenced_names(o)
typing_imports = ["Any", "Optional", "TypeVar"]
for t in typing_imports:
if t not in self.defined_names:
alias = None
else:
alias = '_' + t
self.import_tracker.add_import_from("typing", [(t, alias)])
known_imports = {
"typing": ["Any", "TypeVar"],
"collections.abc": ["Generator"],
}
for pkg, imports in known_imports.items():
for t in imports:
if t not in self.defined_names:
alias = None
else:
alias = '_' + t
self.import_tracker.add_import_from(pkg, [(t, alias)])
super().visit_mypy_file(o)
undefined_names = [name for name in self._all_ or []
if name not in self._toplevel_names]
Expand Down Expand Up @@ -662,6 +666,23 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
# Always assume abstract methods return Any unless explicitly annotated. Also
# some dunder methods should not have a None return type.
retname = None # implicit Any
elif has_yield_expression(o):
self.add_abc_import('Generator')
yield_name = 'None'
send_name = 'None'
return_name = 'None'
for expr, in_assignment in all_yield_expressions(o):
if expr.expr is not None and not self.is_none_expr(expr.expr):
self.add_typing_import('Any')
yield_name = 'Any'
if in_assignment:
self.add_typing_import('Any')
send_name = 'Any'
if has_return_statement(o):
self.add_typing_import('Any')
return_name = 'Any'
generator_name = self.typing_name('Generator')
retname = f'{generator_name}[{yield_name}, {send_name}, {return_name}]'
elif not has_return_statement(o) and not is_abstract:
retname = 'None'
retfield = ''
Expand All @@ -672,6 +693,9 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
self.add("){}: ...\n".format(retfield))
self._state = FUNC

def is_none_expr(self, expr: Expression) -> bool:
return isinstance(expr, NameExpr) and expr.name == "None"

def visit_decorator(self, o: Decorator) -> None:
if self.is_private_name(o.func.name, o.func.fullname):
return
Expand Down Expand Up @@ -1107,6 +1131,14 @@ def add_typing_import(self, name: str) -> None:
name = self.typing_name(name)
self.import_tracker.require_name(name)

def add_abc_import(self, name: str) -> None:
"""Add a name to be imported from collections.abc, unless it's imported already.

The import will be internal to the stub.
"""
name = self.typing_name(name)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name typing_name is not really correct anymore, since we use it for collections.abc as well. Should we rename it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't feel strongly, it's still a typing concept even if it lives in collections.abc.

self.import_tracker.require_name(name)

def add_import_line(self, line: str) -> None:
"""Add a line of text to the import section, unless it's already there."""
if line not in self._import_lines:
Expand Down
46 changes: 43 additions & 3 deletions mypy/traverser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generic node traverser visitor"""

from typing import List
from typing import List, Tuple
from mypy_extensions import mypyc_attr

from mypy.visitor import NodeVisitor
Expand Down Expand Up @@ -319,9 +319,22 @@ def has_return_statement(fdef: FuncBase) -> bool:
return seeker.found


class ReturnCollector(TraverserVisitor):
class YieldSeeker(TraverserVisitor):
def __init__(self) -> None:
self.found = False

def visit_yield_expr(self, o: YieldExpr) -> None:
self.found = True


def has_yield_expression(fdef: FuncBase) -> bool:
seeker = YieldSeeker()
fdef.accept(seeker)
return seeker.found


class FuncCollectorBase(TraverserVisitor):
def __init__(self) -> None:
self.return_statements: List[ReturnStmt] = []
self.inside_func = False

def visit_func_def(self, defn: FuncDef) -> None:
Expand All @@ -330,6 +343,12 @@ def visit_func_def(self, defn: FuncDef) -> None:
super().visit_func_def(defn)
self.inside_func = False


class ReturnCollector(FuncCollectorBase):
def __init__(self) -> None:
super().__init__()
self.return_statements: List[ReturnStmt] = []

def visit_return_stmt(self, stmt: ReturnStmt) -> None:
self.return_statements.append(stmt)

Expand All @@ -338,3 +357,24 @@ def all_return_statements(node: Node) -> List[ReturnStmt]:
v = ReturnCollector()
node.accept(v)
return v.return_statements


class YieldCollector(FuncCollectorBase):
def __init__(self) -> None:
super().__init__()
self.in_assignment = False
self.yield_expressions: List[Tuple[YieldExpr, bool]] = []

def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
self.in_assignment = True
super().visit_assignment_stmt(stmt)
self.in_assignment = False

def visit_yield_expr(self, expr: YieldExpr) -> None:
self.yield_expressions.append((expr, self.in_assignment))


def all_yield_expressions(node: Node) -> List[Tuple[YieldExpr, bool]]:
v = YieldCollector()
node.accept(v)
return v.yield_expressions
50 changes: 50 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,56 @@ def f(): ...
[out]
def f() -> None: ...

[case testFunctionYields]
def f():
yield 123
def g():
x = yield
def h1():
yield
return
def h2():
yield
return "abc"
def all():
x = yield 123
return "abc"
[out]
from collections.abc import Generator
from typing import Any

def f() -> Generator[Any, None, None]: ...
def g() -> Generator[None, Any, None]: ...
def h1() -> Generator[None, None, None]: ...
def h2() -> Generator[None, None, Any]: ...
def all() -> Generator[Any, Any, Any]: ...

[case testFunctionYieldsNone]
def f():
yield
def g():
yield None

[out]
from collections.abc import Generator

def f() -> Generator[None, None, None]: ...
def g() -> Generator[None, None, None]: ...

[case testGeneratorAlreadyDefined]
class Generator:
pass

def f():
yield 123
[out]
from collections.abc import Generator as _Generator
from typing import Any

class Generator: ...

def f() -> _Generator[Any, None, None]: ...

[case testCallable]
from typing import Callable

Expand Down