22
22
YieldExpr , ExecStmt , Argument , BackquoteExpr , AwaitExpr ,
23
23
)
24
24
from mypy .types import Type , FunctionLike , Instance
25
+ from mypy .traverser import TraverserVisitor
25
26
from mypy .visitor import NodeVisitor
26
27
27
28
@@ -36,7 +37,7 @@ class TransformVisitor(NodeVisitor[Node]):
36
37
37
38
* Do not duplicate TypeInfo nodes. This would generally not be desirable.
38
39
* Only update some name binding cross-references, but only those that
39
- refer to Var nodes, not those targeting ClassDef, TypeInfo or FuncDef
40
+ refer to Var or FuncDef nodes, not those targeting ClassDef or TypeInfo
40
41
nodes.
41
42
* Types are not transformed, but you can override type() to also perform
42
43
type transformation.
@@ -48,6 +49,11 @@ def __init__(self) -> None:
48
49
# There may be multiple references to a Var node. Keep track of
49
50
# Var translations using a dictionary.
50
51
self .var_map = {} # type: Dict[Var, Var]
52
+ # These are uninitialized placeholder nodes used temporarily for nested
53
+ # functions while we are transforming a top-level function. This maps an
54
+ # untransformed node to a placeholder (which will later become the
55
+ # transformed node).
56
+ self .func_placeholder_map = {} # type: Dict[FuncDef, FuncDef]
51
57
52
58
def visit_mypy_file (self , node : MypyFile ) -> Node :
53
59
# NOTE: The 'names' and 'imports' instance variables will be empty!
@@ -98,6 +104,18 @@ def copy_argument(self, argument: Argument) -> Argument:
98
104
99
105
def visit_func_def (self , node : FuncDef ) -> FuncDef :
100
106
# Note that a FuncDef must be transformed to a FuncDef.
107
+
108
+ # These contortions are needed to handle the case of recursive
109
+ # references inside the function being transformed.
110
+ # Set up placholder nodes for references within this function
111
+ # to other functions defined inside it.
112
+ # Don't create an entry for this function itself though,
113
+ # since we want self-references to point to the original
114
+ # function if this is the top-level node we are transforming.
115
+ init = FuncMapInitializer (self )
116
+ for stmt in node .body .body :
117
+ stmt .accept (init )
118
+
101
119
new = FuncDef (node .name (),
102
120
[self .copy_argument (arg ) for arg in node .arguments ],
103
121
self .block (node .body ),
@@ -113,7 +131,17 @@ def visit_func_def(self, node: FuncDef) -> FuncDef:
113
131
new .is_class = node .is_class
114
132
new .is_property = node .is_property
115
133
new .original_def = node .original_def
116
- return new
134
+
135
+ if node in self .func_placeholder_map :
136
+ # There is a placeholder definition for this function. Replace
137
+ # the attributes of the placeholder with those form the transformed
138
+ # function. We know that the classes will be identical (otherwise
139
+ # this wouldn't work).
140
+ result = self .func_placeholder_map [node ]
141
+ result .__dict__ = new .__dict__
142
+ return result
143
+ else :
144
+ return new
117
145
118
146
def visit_func_expr (self , node : FuncExpr ) -> Node :
119
147
new = FuncExpr ([self .copy_argument (arg ) for arg in node .arguments ],
@@ -330,6 +358,9 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None:
330
358
target = original .node
331
359
if isinstance (target , Var ):
332
360
target = self .visit_var (target )
361
+ elif isinstance (target , FuncDef ):
362
+ # Use a placeholder node for the function if it exists.
363
+ target = self .func_placeholder_map .get (target , target )
333
364
new .node = target
334
365
new .is_def = original .is_def
335
366
@@ -527,3 +558,20 @@ def types(self, types: List[Type]) -> List[Type]:
527
558
528
559
def optional_types (self , types : List [Type ]) -> List [Type ]:
529
560
return [self .optional_type (type ) for type in types ]
561
+
562
+
563
+ class FuncMapInitializer (TraverserVisitor ):
564
+ """This traverser creates mappings from nested FuncDefs to placeholder FuncDefs.
565
+
566
+ The placholders will later be replaced with transformed nodes.
567
+ """
568
+
569
+ def __init__ (self , transformer : TransformVisitor ) -> None :
570
+ self .transformer = transformer
571
+
572
+ def visit_func_def (self , node : FuncDef ) -> None :
573
+ if node not in self .transformer .func_placeholder_map :
574
+ # Haven't seen this FuncDef before, so create a placeholder node.
575
+ self .transformer .func_placeholder_map [node ] = FuncDef (
576
+ node .name (), node .arguments , node .body , None )
577
+ super ().visit_func_def (node )
0 commit comments