Skip to content

Commit 675e074

Browse files
authored
Rewrite references to inner functions in treetransform (#2065)
Fixes #1323. (This is an updated version of @rwbarton's PR #1791).
1 parent 7ab19c1 commit 675e074

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

mypy/treetransform.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr,
2323
)
2424
from mypy.types import Type, FunctionLike, Instance
25+
from mypy.traverser import TraverserVisitor
2526
from mypy.visitor import NodeVisitor
2627

2728

@@ -36,7 +37,7 @@ class TransformVisitor(NodeVisitor[Node]):
3637
3738
* Do not duplicate TypeInfo nodes. This would generally not be desirable.
3839
* Only update some name binding cross-references, but only those that
39-
refer to Var nodes, not those targeting ClassDef, TypeInfo or FuncDef
40+
refer to Var or FuncDef nodes, not those targeting ClassDef or TypeInfo
4041
nodes.
4142
* Types are not transformed, but you can override type() to also perform
4243
type transformation.
@@ -48,6 +49,11 @@ def __init__(self) -> None:
4849
# There may be multiple references to a Var node. Keep track of
4950
# Var translations using a dictionary.
5051
self.var_map = {} # type: Dict[Var, Var]
52+
# These are uninitialized placeholder nodes used temporarily for nested
53+
# functions while we are transforming a top-level function. This maps an
54+
# untransformed node to a placeholder (which will later become the
55+
# transformed node).
56+
self.func_placeholder_map = {} # type: Dict[FuncDef, FuncDef]
5157

5258
def visit_mypy_file(self, node: MypyFile) -> Node:
5359
# NOTE: The 'names' and 'imports' instance variables will be empty!
@@ -98,6 +104,18 @@ def copy_argument(self, argument: Argument) -> Argument:
98104

99105
def visit_func_def(self, node: FuncDef) -> FuncDef:
100106
# Note that a FuncDef must be transformed to a FuncDef.
107+
108+
# These contortions are needed to handle the case of recursive
109+
# references inside the function being transformed.
110+
# Set up placholder nodes for references within this function
111+
# to other functions defined inside it.
112+
# Don't create an entry for this function itself though,
113+
# since we want self-references to point to the original
114+
# function if this is the top-level node we are transforming.
115+
init = FuncMapInitializer(self)
116+
for stmt in node.body.body:
117+
stmt.accept(init)
118+
101119
new = FuncDef(node.name(),
102120
[self.copy_argument(arg) for arg in node.arguments],
103121
self.block(node.body),
@@ -113,7 +131,17 @@ def visit_func_def(self, node: FuncDef) -> FuncDef:
113131
new.is_class = node.is_class
114132
new.is_property = node.is_property
115133
new.original_def = node.original_def
116-
return new
134+
135+
if node in self.func_placeholder_map:
136+
# There is a placeholder definition for this function. Replace
137+
# the attributes of the placeholder with those form the transformed
138+
# function. We know that the classes will be identical (otherwise
139+
# this wouldn't work).
140+
result = self.func_placeholder_map[node]
141+
result.__dict__ = new.__dict__
142+
return result
143+
else:
144+
return new
117145

118146
def visit_func_expr(self, node: FuncExpr) -> Node:
119147
new = FuncExpr([self.copy_argument(arg) for arg in node.arguments],
@@ -330,6 +358,9 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None:
330358
target = original.node
331359
if isinstance(target, Var):
332360
target = self.visit_var(target)
361+
elif isinstance(target, FuncDef):
362+
# Use a placeholder node for the function if it exists.
363+
target = self.func_placeholder_map.get(target, target)
333364
new.node = target
334365
new.is_def = original.is_def
335366

@@ -527,3 +558,20 @@ def types(self, types: List[Type]) -> List[Type]:
527558

528559
def optional_types(self, types: List[Type]) -> List[Type]:
529560
return [self.optional_type(type) for type in types]
561+
562+
563+
class FuncMapInitializer(TraverserVisitor):
564+
"""This traverser creates mappings from nested FuncDefs to placeholder FuncDefs.
565+
566+
The placholders will later be replaced with transformed nodes.
567+
"""
568+
569+
def __init__(self, transformer: TransformVisitor) -> None:
570+
self.transformer = transformer
571+
572+
def visit_func_def(self, node: FuncDef) -> None:
573+
if node not in self.transformer.func_placeholder_map:
574+
# Haven't seen this FuncDef before, so create a placeholder node.
575+
self.transformer.func_placeholder_map[node] = FuncDef(
576+
node.name(), node.arguments, node.body, None)
577+
super().visit_func_def(node)

test-data/unit/check-typevar-values.test

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,33 @@ a = g
479479
b = g
480480
b = g
481481
b = f # E: Incompatible types in assignment (expression has type Callable[[T], T], variable has type Callable[[U], U])
482+
483+
[case testInnerFunctionWithTypevarValues]
484+
from typing import TypeVar
485+
T = TypeVar('T', int, str)
486+
U = TypeVar('U', int, str)
487+
def outer(x: T) -> T:
488+
def inner(y: T) -> T:
489+
return x
490+
def inner2(y: U) -> U:
491+
return y
492+
inner(x)
493+
inner(3) # E: Argument 1 to "inner" has incompatible type "int"; expected "str"
494+
inner2(x)
495+
inner2(3)
496+
outer(3)
497+
return x
498+
[out]
499+
main: note: In function "outer":
500+
501+
[case testInnerFunctionMutualRecursionWithTypevarValues]
502+
from typing import TypeVar
503+
T = TypeVar('T', int, str)
504+
def outer(x: T) -> T:
505+
def inner1(y: T) -> T:
506+
return inner2(y)
507+
def inner2(y: T) -> T:
508+
return inner1('a') # E: Argument 1 to "inner1" has incompatible type "str"; expected "int"
509+
return inner1(x)
510+
[out]
511+
main: note: In function "inner2":

0 commit comments

Comments
 (0)