Skip to content

Commit d9c1062

Browse files
tharvikJukkaL
authored andcommitted
Improve stubgen (#1899)
* stubgen: add type to method arguments * stubgen: add type to assignments * stubgen: set return type for __init__ * stubgen: sort import output * stubgen: update test-data * remove specific for Optional * Revert "remove specific for Optional" This reverts commit 61f6a59. * only infer Optional for func
1 parent f8374bd commit d9c1062

File tree

2 files changed

+83
-71
lines changed

2 files changed

+83
-71
lines changed

mypy/stubgen.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def visit_func_def(self, o: FuncDef) -> None:
239239
self.add('\n')
240240
if not self.is_top_level():
241241
self_inits = find_self_initializers(o)
242-
for init in self_inits:
243-
init_code = self.get_init(init)
242+
for init, value in self_inits:
243+
init_code = self.get_init(init, value)
244244
if init_code:
245245
self.add(init_code)
246246
self.add("%sdef %s(" % (self._indent, o.name()))
@@ -254,31 +254,24 @@ def visit_func_def(self, o: FuncDef) -> None:
254254
if init_stmt:
255255
if kind == ARG_NAMED and '*' not in args:
256256
args.append('*')
257-
arg = '%s=' % name
258-
rvalue = init_stmt.rvalue
259-
if isinstance(rvalue, IntExpr):
260-
arg += str(rvalue.value)
261-
elif isinstance(rvalue, StrExpr):
262-
arg += "''"
263-
elif isinstance(rvalue, BytesExpr):
264-
arg += "b''"
265-
elif isinstance(rvalue, FloatExpr):
266-
arg += "0.0"
267-
elif isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr):
268-
arg += '-%s' % rvalue.expr.value
269-
elif isinstance(rvalue, NameExpr) and rvalue.name in ('None', 'True', 'False'):
270-
arg += rvalue.name
271-
else:
272-
arg += '...'
257+
typename = self.get_str_type_of_node(init_stmt.rvalue, True)
258+
arg = '{}: {} = ...'.format(name, typename)
273259
elif kind == ARG_STAR:
274260
arg = '*%s' % name
275261
elif kind == ARG_STAR2:
276262
arg = '**%s' % name
277263
else:
278264
arg = name
279265
args.append(arg)
266+
retname = None
267+
if o.name() == '__init__':
268+
retname = 'None'
269+
retfield = ''
270+
if retname is not None:
271+
retfield = ' -> ' + retname
272+
280273
self.add(', '.join(args))
281-
self.add("): ...\n")
274+
self.add("){}: ...\n".format(retfield))
282275
self._state = FUNC
283276

284277
def visit_decorator(self, o: Decorator) -> None:
@@ -349,7 +342,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
349342
found = False
350343
for item in items:
351344
if isinstance(item, NameExpr):
352-
init = self.get_init(item.name)
345+
init = self.get_init(item.name, o.rvalue)
353346
if init:
354347
found = True
355348
if not sep and not self._indent and \
@@ -448,7 +441,7 @@ def visit_import(self, o: Import) -> None:
448441
self.add_import_line('import %s as %s\n' % (id, target_name))
449442
self.record_name(target_name)
450443

451-
def get_init(self, lvalue: str) -> str:
444+
def get_init(self, lvalue: str, rvalue: Node) -> str:
452445
"""Return initializer for a variable.
453446
454447
Return None if we've generated one already or if the variable is internal.
@@ -460,8 +453,8 @@ def get_init(self, lvalue: str) -> str:
460453
if self.is_private_name(lvalue) or self.is_not_in_all(lvalue):
461454
return None
462455
self._vars[-1].append(lvalue)
463-
self.add_typing_import('Any')
464-
return '%s%s = ... # type: Any\n' % (self._indent, lvalue)
456+
typename = self.get_str_type_of_node(rvalue)
457+
return '%s%s = ... # type: %s\n' % (self._indent, lvalue, typename)
465458

466459
def add(self, string: str) -> None:
467460
"""Add text to generated stub."""
@@ -484,7 +477,7 @@ def output(self) -> str:
484477
"""Return the text for the stub."""
485478
imports = ''
486479
if self._imports:
487-
imports += 'from typing import %s\n' % ", ".join(self._imports)
480+
imports += 'from typing import %s\n' % ", ".join(sorted(self._imports))
488481
if self._import_lines:
489482
imports += ''.join(self._import_lines)
490483
if imports and self._output:
@@ -507,6 +500,28 @@ def is_private_name(self, name: str) -> bool:
507500
'__setstate__',
508501
'__slots__'))
509502

503+
def get_str_type_of_node(self, rvalue: Node,
504+
can_infer_optional: bool = False) -> str:
505+
if isinstance(rvalue, IntExpr):
506+
return 'int'
507+
if isinstance(rvalue, StrExpr):
508+
return 'str'
509+
if isinstance(rvalue, BytesExpr):
510+
return 'bytes'
511+
if isinstance(rvalue, FloatExpr):
512+
return 'float'
513+
if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr):
514+
return 'int'
515+
if isinstance(rvalue, NameExpr) and rvalue.name in ('True', 'False'):
516+
return 'bool'
517+
if can_infer_optional and \
518+
isinstance(rvalue, NameExpr) and rvalue.name == 'None':
519+
self.add_typing_import('Optional')
520+
self.add_typing_import('Any')
521+
return 'Optional[Any]'
522+
self.add_typing_import('Any')
523+
return 'Any'
524+
510525
def is_top_level(self) -> bool:
511526
"""Are we processing the top level of a file?"""
512527
return self._indent == ''
@@ -524,16 +539,16 @@ def is_recorded_name(self, name: str) -> bool:
524539
return self.is_top_level() and name in self._toplevel_names
525540

526541

527-
def find_self_initializers(fdef: FuncBase) -> List[str]:
528-
results = [] # type: List[str]
542+
def find_self_initializers(fdef: FuncBase) -> List[Tuple[str, Node]]:
543+
results = [] # type: List[Tuple[str, Node]]
529544

530545
class SelfTraverser(mypy.traverser.TraverserVisitor):
531546
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
532547
lvalue = o.lvalues[0]
533548
if (isinstance(lvalue, MemberExpr) and
534549
isinstance(lvalue.expr, NameExpr) and
535550
lvalue.expr.name == 'self'):
536-
results.append(lvalue.name)
551+
results.append((lvalue.name, o.rvalue))
537552

538553
fdef.accept(SelfTraverser())
539554
return results

test-data/unit/stubgen.test

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,38 +20,42 @@ def g(arg): ...
2020
def f(a, b=2): ...
2121
def g(b=-1, c=0): ...
2222
[out]
23-
def f(a, b=2): ...
24-
def g(b=-1, c=0): ...
23+
def f(a, b: int = ...): ...
24+
def g(b: int = ..., c: int = ...): ...
2525

2626
[case testDefaultArgNone]
2727
def f(x=None): ...
2828
[out]
29-
def f(x=None): ...
29+
from typing import Any, Optional
30+
31+
def f(x: Optional[Any] = ...): ...
3032

3133
[case testDefaultArgBool]
3234
def f(x=True, y=False): ...
3335
[out]
34-
def f(x=True, y=False): ...
36+
def f(x: bool = ..., y: bool = ...): ...
3537

3638
[case testDefaultArgStr]
3739
def f(x='foo'): ...
3840
[out]
39-
def f(x=''): ...
41+
def f(x: str = ...): ...
4042

4143
[case testDefaultArgBytes]
4244
def f(x=b'foo'): ...
4345
[out]
44-
def f(x=b''): ...
46+
def f(x: bytes = ...): ...
4547

4648
[case testDefaultArgFloat]
4749
def f(x=1.2): ...
4850
[out]
49-
def f(x=0.0): ...
51+
def f(x: float = ...): ...
5052

5153
[case testDefaultArgOther]
5254
def f(x=ord): ...
5355
[out]
54-
def f(x=...): ...
56+
from typing import Any
57+
58+
def f(x: Any = ...): ...
5559

5660
[case testVarArgs]
5761
def f(x, *y): ...
@@ -77,38 +81,30 @@ def g(): ...
7781
[case testVariable]
7882
x = 1
7983
[out]
80-
from typing import Any
81-
82-
x = ... # type: Any
84+
x = ... # type: int
8385

8486
[case testMultipleVariable]
8587
x = y = 1
8688
[out]
87-
from typing import Any
88-
89-
x = ... # type: Any
90-
y = ... # type: Any
89+
x = ... # type: int
90+
y = ... # type: int
9191

9292
[case testClassVariable]
9393
class C:
9494
x = 1
9595
[out]
96-
from typing import Any
97-
9896
class C:
99-
x = ... # type: Any
97+
x = ... # type: int
10098

10199
[case testSelfAssignment]
102100
class C:
103101
def __init__(self):
104102
self.x = 1
105103
x.y = 2
106104
[out]
107-
from typing import Any
108-
109105
class C:
110-
x = ... # type: Any
111-
def __init__(self): ...
106+
x = ... # type: int
107+
def __init__(self) -> None: ...
112108

113109
[case testSelfAndClassBodyAssignment]
114110
x = 1
@@ -118,13 +114,11 @@ class C:
118114
self.x = 1
119115
self.x = 1
120116
[out]
121-
from typing import Any
122-
123-
x = ... # type: Any
117+
x = ... # type: int
124118

125119
class C:
126-
x = ... # type: Any
127-
def __init__(self): ...
120+
x = ... # type: int
121+
def __init__(self) -> None: ...
128122

129123
[case testEmptyClass]
130124
class A: ...
@@ -189,8 +183,8 @@ y = ... # type: Any
189183
def f(x, *, y=1): ...
190184
def g(x, *, y=1, z=2): ...
191185
[out]
192-
def f(x, *, y=1): ...
193-
def g(x, *, y=1, z=2): ...
186+
def f(x, *, y: int = ...): ...
187+
def g(x, *, y: int = ..., z: int = ...): ...
194188

195189
[case testProperty]
196190
class A:
@@ -311,10 +305,8 @@ class A:
311305
x = 1
312306
def f(self): ...
313307
[out]
314-
from typing import Any
315-
316308
class A:
317-
x = ... # type: Any
309+
x = ... # type: int
318310
def f(self): ...
319311

320312
[case testMultiplePrivateDefs]
@@ -332,32 +324,29 @@ from re import match, search, sub
332324
__all__ = ['match', 'sub', 'x']
333325
x = 1
334326
[out]
335-
from typing import Any
336327
from re import match as match, sub as sub
337328

338-
x = ... # type: Any
329+
x = ... # type: int
339330

340331
[case testExportModule_import]
341332
import re
342333
__all__ = ['re', 'x']
343334
x = 1
344335
y = 2
345336
[out]
346-
from typing import Any
347337
import re as re
348338

349-
x = ... # type: Any
339+
x = ... # type: int
350340

351341
[case testExportModuleAs_import]
352342
import re as rex
353343
__all__ = ['rex', 'x']
354344
x = 1
355345
y = 2
356346
[out]
357-
from typing import Any
358347
import re as rex
359348

360-
x = ... # type: Any
349+
x = ... # type: int
361350

362351
[case testExportModuleInPackage_import]
363352
import urllib.parse as p
@@ -384,11 +373,9 @@ x = 1
384373
class C:
385374
def g(self): ...
386375
[out]
387-
from typing import Any
388-
389376
def f(): ...
390377

391-
x = ... # type: Any
378+
x = ... # type: int
392379

393380
class C:
394381
def g(self): ...
@@ -521,11 +508,9 @@ class A:
521508
def f(self): ...
522509
def g(self): ...
523510
[out]
524-
from typing import Any
525-
526511
class A:
527512
class B:
528-
x = ... # type: Any
513+
x = ... # type: int
529514
def f(self): ...
530515
def g(self): ...
531516

@@ -558,5 +543,17 @@ def syslog(a): pass
558543
[out]
559544
def syslog(a): ...
560545

546+
[case testInferOptionalOnlyFunc]
547+
class A:
548+
x = None
549+
def __init__(self, a=None) -> None:
550+
self.x = []
551+
[out]
552+
from typing import Any, Optional
553+
554+
class A:
555+
x = ... # type: Any
556+
def __init__(self, a: Optional[Any] = ...) -> None: ...
557+
561558
-- More features/fixes:
562559
-- do not export deleted names

0 commit comments

Comments
 (0)