@@ -239,8 +239,8 @@ def visit_func_def(self, o: FuncDef) -> None:
239
239
self .add ('\n ' )
240
240
if not self .is_top_level ():
241
241
self_inits = find_self_initializers (o )
242
- for init in self_inits :
243
- init_code = self .get_init (init )
242
+ for init , value in self_inits :
243
+ init_code = self .get_init (init , value )
244
244
if init_code :
245
245
self .add (init_code )
246
246
self .add ("%sdef %s(" % (self ._indent , o .name ()))
@@ -254,31 +254,24 @@ def visit_func_def(self, o: FuncDef) -> None:
254
254
if init_stmt :
255
255
if kind == ARG_NAMED and '*' not in args :
256
256
args .append ('*' )
257
- arg = '%s=' % name
258
- rvalue = init_stmt .rvalue
259
- if isinstance (rvalue , IntExpr ):
260
- arg += str (rvalue .value )
261
- elif isinstance (rvalue , StrExpr ):
262
- arg += "''"
263
- elif isinstance (rvalue , BytesExpr ):
264
- arg += "b''"
265
- elif isinstance (rvalue , FloatExpr ):
266
- arg += "0.0"
267
- elif isinstance (rvalue , UnaryExpr ) and isinstance (rvalue .expr , IntExpr ):
268
- arg += '-%s' % rvalue .expr .value
269
- elif isinstance (rvalue , NameExpr ) and rvalue .name in ('None' , 'True' , 'False' ):
270
- arg += rvalue .name
271
- else :
272
- arg += '...'
257
+ typename = self .get_str_type_of_node (init_stmt .rvalue , True )
258
+ arg = '{}: {} = ...' .format (name , typename )
273
259
elif kind == ARG_STAR :
274
260
arg = '*%s' % name
275
261
elif kind == ARG_STAR2 :
276
262
arg = '**%s' % name
277
263
else :
278
264
arg = name
279
265
args .append (arg )
266
+ retname = None
267
+ if o .name () == '__init__' :
268
+ retname = 'None'
269
+ retfield = ''
270
+ if retname is not None :
271
+ retfield = ' -> ' + retname
272
+
280
273
self .add (', ' .join (args ))
281
- self .add ("): ...\n " )
274
+ self .add ("){} : ...\n " . format ( retfield ) )
282
275
self ._state = FUNC
283
276
284
277
def visit_decorator (self , o : Decorator ) -> None :
@@ -349,7 +342,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
349
342
found = False
350
343
for item in items :
351
344
if isinstance (item , NameExpr ):
352
- init = self .get_init (item .name )
345
+ init = self .get_init (item .name , o . rvalue )
353
346
if init :
354
347
found = True
355
348
if not sep and not self ._indent and \
@@ -448,7 +441,7 @@ def visit_import(self, o: Import) -> None:
448
441
self .add_import_line ('import %s as %s\n ' % (id , target_name ))
449
442
self .record_name (target_name )
450
443
451
- def get_init (self , lvalue : str ) -> str :
444
+ def get_init (self , lvalue : str , rvalue : Node ) -> str :
452
445
"""Return initializer for a variable.
453
446
454
447
Return None if we've generated one already or if the variable is internal.
@@ -460,8 +453,8 @@ def get_init(self, lvalue: str) -> str:
460
453
if self .is_private_name (lvalue ) or self .is_not_in_all (lvalue ):
461
454
return None
462
455
self ._vars [- 1 ].append (lvalue )
463
- self .add_typing_import ( 'Any' )
464
- return '%s%s = ... # type: Any \n ' % (self ._indent , lvalue )
456
+ typename = self .get_str_type_of_node ( rvalue )
457
+ return '%s%s = ... # type: %s \n ' % (self ._indent , lvalue , typename )
465
458
466
459
def add (self , string : str ) -> None :
467
460
"""Add text to generated stub."""
@@ -484,7 +477,7 @@ def output(self) -> str:
484
477
"""Return the text for the stub."""
485
478
imports = ''
486
479
if self ._imports :
487
- imports += 'from typing import %s\n ' % ", " .join (self ._imports )
480
+ imports += 'from typing import %s\n ' % ", " .join (sorted ( self ._imports ) )
488
481
if self ._import_lines :
489
482
imports += '' .join (self ._import_lines )
490
483
if imports and self ._output :
@@ -507,6 +500,28 @@ def is_private_name(self, name: str) -> bool:
507
500
'__setstate__' ,
508
501
'__slots__' ))
509
502
503
+ def get_str_type_of_node (self , rvalue : Node ,
504
+ can_infer_optional : bool = False ) -> str :
505
+ if isinstance (rvalue , IntExpr ):
506
+ return 'int'
507
+ if isinstance (rvalue , StrExpr ):
508
+ return 'str'
509
+ if isinstance (rvalue , BytesExpr ):
510
+ return 'bytes'
511
+ if isinstance (rvalue , FloatExpr ):
512
+ return 'float'
513
+ if isinstance (rvalue , UnaryExpr ) and isinstance (rvalue .expr , IntExpr ):
514
+ return 'int'
515
+ if isinstance (rvalue , NameExpr ) and rvalue .name in ('True' , 'False' ):
516
+ return 'bool'
517
+ if can_infer_optional and \
518
+ isinstance (rvalue , NameExpr ) and rvalue .name == 'None' :
519
+ self .add_typing_import ('Optional' )
520
+ self .add_typing_import ('Any' )
521
+ return 'Optional[Any]'
522
+ self .add_typing_import ('Any' )
523
+ return 'Any'
524
+
510
525
def is_top_level (self ) -> bool :
511
526
"""Are we processing the top level of a file?"""
512
527
return self ._indent == ''
@@ -524,16 +539,16 @@ def is_recorded_name(self, name: str) -> bool:
524
539
return self .is_top_level () and name in self ._toplevel_names
525
540
526
541
527
- def find_self_initializers (fdef : FuncBase ) -> List [str ]:
528
- results = [] # type: List[str]
542
+ def find_self_initializers (fdef : FuncBase ) -> List [Tuple [ str , Node ] ]:
543
+ results = [] # type: List[Tuple[ str, Node] ]
529
544
530
545
class SelfTraverser (mypy .traverser .TraverserVisitor ):
531
546
def visit_assignment_stmt (self , o : AssignmentStmt ) -> None :
532
547
lvalue = o .lvalues [0 ]
533
548
if (isinstance (lvalue , MemberExpr ) and
534
549
isinstance (lvalue .expr , NameExpr ) and
535
550
lvalue .expr .name == 'self' ):
536
- results .append (lvalue .name )
551
+ results .append (( lvalue .name , o . rvalue ) )
537
552
538
553
fdef .accept (SelfTraverser ())
539
554
return results
0 commit comments