Skip to content

Improve stubgen #1899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 7, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def visit_func_def(self, o: FuncDef) -> None:
self.add('\n')
if not self.is_top_level():
self_inits = find_self_initializers(o)
for init in self_inits:
init_code = self.get_init(init)
for init, value in self_inits:
init_code = self.get_init(init, value)
if init_code:
self.add(init_code)
self.add("%sdef %s(" % (self._indent, o.name()))
Expand All @@ -254,31 +254,24 @@ def visit_func_def(self, o: FuncDef) -> None:
if init_stmt:
if kind == ARG_NAMED and '*' not in args:
args.append('*')
arg = '%s=' % name
rvalue = init_stmt.rvalue
if isinstance(rvalue, IntExpr):
arg += str(rvalue.value)
elif isinstance(rvalue, StrExpr):
arg += "''"
elif isinstance(rvalue, BytesExpr):
arg += "b''"
elif isinstance(rvalue, FloatExpr):
arg += "0.0"
elif isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr):
arg += '-%s' % rvalue.expr.value
elif isinstance(rvalue, NameExpr) and rvalue.name in ('None', 'True', 'False'):
arg += rvalue.name
else:
arg += '...'
typename = self.get_str_type_of_node(init_stmt.rvalue, True)
arg = '{}: {} = ...'.format(name, typename)
elif kind == ARG_STAR:
arg = '*%s' % name
elif kind == ARG_STAR2:
arg = '**%s' % name
else:
arg = name
args.append(arg)
retname = None
if o.name() == '__init__':
retname = 'None'
retfield = ''
if retname is not None:
retfield = ' -> ' + retname

self.add(', '.join(args))
self.add("): ...\n")
self.add("){}: ...\n".format(retfield))
self._state = FUNC

def visit_decorator(self, o: Decorator) -> None:
Expand Down Expand Up @@ -349,7 +342,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
found = False
for item in items:
if isinstance(item, NameExpr):
init = self.get_init(item.name)
init = self.get_init(item.name, o.rvalue)
if init:
found = True
if not sep and not self._indent and \
Expand Down Expand Up @@ -448,7 +441,7 @@ def visit_import(self, o: Import) -> None:
self.add_import_line('import %s as %s\n' % (id, target_name))
self.record_name(target_name)

def get_init(self, lvalue: str) -> str:
def get_init(self, lvalue: str, rvalue: Node) -> str:
"""Return initializer for a variable.

Return None if we've generated one already or if the variable is internal.
Expand All @@ -460,8 +453,8 @@ def get_init(self, lvalue: str) -> str:
if self.is_private_name(lvalue) or self.is_not_in_all(lvalue):
return None
self._vars[-1].append(lvalue)
self.add_typing_import('Any')
return '%s%s = ... # type: Any\n' % (self._indent, lvalue)
typename = self.get_str_type_of_node(rvalue)
return '%s%s = ... # type: %s\n' % (self._indent, lvalue, typename)

def add(self, string: str) -> None:
"""Add text to generated stub."""
Expand All @@ -484,7 +477,7 @@ def output(self) -> str:
"""Return the text for the stub."""
imports = ''
if self._imports:
imports += 'from typing import %s\n' % ", ".join(self._imports)
imports += 'from typing import %s\n' % ", ".join(sorted(self._imports))
if self._import_lines:
imports += ''.join(self._import_lines)
if imports and self._output:
Expand All @@ -507,6 +500,28 @@ def is_private_name(self, name: str) -> bool:
'__setstate__',
'__slots__'))

def get_str_type_of_node(self, rvalue: Node,
can_infer_optional: bool = False) -> str:
if isinstance(rvalue, IntExpr):
return 'int'
if isinstance(rvalue, StrExpr):
return 'str'
if isinstance(rvalue, BytesExpr):
return 'bytes'
if isinstance(rvalue, FloatExpr):
return 'float'
if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr):
return 'int'
if isinstance(rvalue, NameExpr) and rvalue.name in ('True', 'False'):
return 'bool'
if can_infer_optional and \
isinstance(rvalue, NameExpr) and rvalue.name == 'None':
self.add_typing_import('Optional')
self.add_typing_import('Any')
return 'Optional[Any]'
self.add_typing_import('Any')
return 'Any'

def is_top_level(self) -> bool:
"""Are we processing the top level of a file?"""
return self._indent == ''
Expand All @@ -524,16 +539,16 @@ def is_recorded_name(self, name: str) -> bool:
return self.is_top_level() and name in self._toplevel_names


def find_self_initializers(fdef: FuncBase) -> List[str]:
results = [] # type: List[str]
def find_self_initializers(fdef: FuncBase) -> List[Tuple[str, Node]]:
results = [] # type: List[Tuple[str, Node]]

class SelfTraverser(mypy.traverser.TraverserVisitor):
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
lvalue = o.lvalues[0]
if (isinstance(lvalue, MemberExpr) and
isinstance(lvalue.expr, NameExpr) and
lvalue.expr.name == 'self'):
results.append(lvalue.name)
results.append((lvalue.name, o.rvalue))

fdef.accept(SelfTraverser())
return results
Expand Down
85 changes: 41 additions & 44 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,42 @@ def g(arg): ...
def f(a, b=2): ...
def g(b=-1, c=0): ...
[out]
def f(a, b=2): ...
def g(b=-1, c=0): ...
def f(a, b: int = ...): ...
def g(b: int = ..., c: int = ...): ...

[case testDefaultArgNone]
def f(x=None): ...
[out]
def f(x=None): ...
from typing import Any, Optional

def f(x: Optional[Any] = ...): ...

[case testDefaultArgBool]
def f(x=True, y=False): ...
[out]
def f(x=True, y=False): ...
def f(x: bool = ..., y: bool = ...): ...

[case testDefaultArgStr]
def f(x='foo'): ...
[out]
def f(x=''): ...
def f(x: str = ...): ...

[case testDefaultArgBytes]
def f(x=b'foo'): ...
[out]
def f(x=b''): ...
def f(x: bytes = ...): ...

[case testDefaultArgFloat]
def f(x=1.2): ...
[out]
def f(x=0.0): ...
def f(x: float = ...): ...

[case testDefaultArgOther]
def f(x=ord): ...
[out]
def f(x=...): ...
from typing import Any

def f(x: Any = ...): ...

[case testVarArgs]
def f(x, *y): ...
Expand All @@ -77,38 +81,30 @@ def g(): ...
[case testVariable]
x = 1
[out]
from typing import Any

x = ... # type: Any
x = ... # type: int

[case testMultipleVariable]
x = y = 1
[out]
from typing import Any

x = ... # type: Any
y = ... # type: Any
x = ... # type: int
y = ... # type: int

[case testClassVariable]
class C:
x = 1
[out]
from typing import Any

class C:
x = ... # type: Any
x = ... # type: int

[case testSelfAssignment]
class C:
def __init__(self):
self.x = 1
x.y = 2
[out]
from typing import Any

class C:
x = ... # type: Any
def __init__(self): ...
x = ... # type: int
def __init__(self) -> None: ...

[case testSelfAndClassBodyAssignment]
x = 1
Expand All @@ -118,13 +114,11 @@ class C:
self.x = 1
self.x = 1
[out]
from typing import Any

x = ... # type: Any
x = ... # type: int

class C:
x = ... # type: Any
def __init__(self): ...
x = ... # type: int
def __init__(self) -> None: ...

[case testEmptyClass]
class A: ...
Expand Down Expand Up @@ -189,8 +183,8 @@ y = ... # type: Any
def f(x, *, y=1): ...
def g(x, *, y=1, z=2): ...
[out]
def f(x, *, y=1): ...
def g(x, *, y=1, z=2): ...
def f(x, *, y: int = ...): ...
def g(x, *, y: int = ..., z: int = ...): ...

[case testProperty]
class A:
Expand Down Expand Up @@ -311,10 +305,8 @@ class A:
x = 1
def f(self): ...
[out]
from typing import Any

class A:
x = ... # type: Any
x = ... # type: int
def f(self): ...

[case testMultiplePrivateDefs]
Expand All @@ -332,32 +324,29 @@ from re import match, search, sub
__all__ = ['match', 'sub', 'x']
x = 1
[out]
from typing import Any
from re import match as match, sub as sub

x = ... # type: Any
x = ... # type: int

[case testExportModule_import]
import re
__all__ = ['re', 'x']
x = 1
y = 2
[out]
from typing import Any
import re as re

x = ... # type: Any
x = ... # type: int

[case testExportModuleAs_import]
import re as rex
__all__ = ['rex', 'x']
x = 1
y = 2
[out]
from typing import Any
import re as rex

x = ... # type: Any
x = ... # type: int

[case testExportModuleInPackage_import]
import urllib.parse as p
Expand All @@ -384,11 +373,9 @@ x = 1
class C:
def g(self): ...
[out]
from typing import Any

def f(): ...

x = ... # type: Any
x = ... # type: int

class C:
def g(self): ...
Expand Down Expand Up @@ -521,11 +508,9 @@ class A:
def f(self): ...
def g(self): ...
[out]
from typing import Any

class A:
class B:
x = ... # type: Any
x = ... # type: int
def f(self): ...
def g(self): ...

Expand Down Expand Up @@ -558,5 +543,17 @@ def syslog(a): pass
[out]
def syslog(a): ...

[case testInferOptionalOnlyFunc]
class A:
x = None
def __init__(self, a=None) -> None:
self.x = []
[out]
from typing import Any, Optional

class A:
x = ... # type: Any
def __init__(self, a: Optional[Any] = ...) -> None: ...

-- More features/fixes:
-- do not export deleted names