@@ -254,7 +254,7 @@ def visit_func_def(self, o: FuncDef) -> None:
254
254
if init_stmt :
255
255
if kind == ARG_NAMED and '*' not in args :
256
256
args .append ('*' )
257
- typename = self .get_str_type_of_node (init_stmt .rvalue )
257
+ typename = self .get_str_type_of_node (init_stmt .rvalue , True )
258
258
arg = '{}: {} = ...' .format (name , typename )
259
259
elif kind == ARG_STAR :
260
260
arg = '*%s' % name
@@ -500,7 +500,8 @@ def is_private_name(self, name: str) -> bool:
500
500
'__setstate__' ,
501
501
'__slots__' ))
502
502
503
- def get_str_type_of_node (self , rvalue : Node ) -> str :
503
+ def get_str_type_of_node (self , rvalue : Node ,
504
+ can_infer_optional : bool = False ) -> str :
504
505
if isinstance (rvalue , IntExpr ):
505
506
return 'int'
506
507
if isinstance (rvalue , StrExpr ):
@@ -513,7 +514,8 @@ def get_str_type_of_node(self, rvalue: Node) -> str:
513
514
return 'int'
514
515
if isinstance (rvalue , NameExpr ) and rvalue .name in ('True' , 'False' ):
515
516
return 'bool'
516
- if isinstance (rvalue , NameExpr ) and rvalue .name == 'None' :
517
+ if can_infer_optional and \
518
+ isinstance (rvalue , NameExpr ) and rvalue .name == 'None' :
517
519
self .add_typing_import ('Optional' )
518
520
self .add_typing_import ('Any' )
519
521
return 'Optional[Any]'
0 commit comments