Skip to content

Commit 20718d7

Browse files
committed
Modify typechecker to record all module references
This commit modifies the typechecking stage to record all module references that a particular module makes. This information will be used in a future commit. The motivation for this change is that it turns out that the set of all module references is often larger then the set of explicitly imported modules, and having this information would allow us to make incremental mode more aggressive.
1 parent b9f677b commit 20718d7

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

mypy/checker.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ class TypeChecker(NodeVisitor[Type]):
8686
msg = None # type: MessageBuilder
8787
# Types of type checked nodes
8888
type_map = None # type: Dict[Node, Type]
89+
# Types of type checked nodes within this specific module
90+
module_type_map = None # type: Dict[Node, Type]
8991

9092
# Helper for managing conditional types
9193
binder = None # type: ConditionalTypeBinder
@@ -117,6 +119,10 @@ class TypeChecker(NodeVisitor[Type]):
117119
is_typeshed_stub = False
118120
options = None # type: Options
119121

122+
# The set of all dependencies (suppressed or not) that this module accesses, either
123+
# directly or indirectly.
124+
module_refs = None # type: Set[str]
125+
120126
def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Options) -> None:
121127
"""Construct a type checker.
122128
@@ -127,6 +133,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
127133
self.options = options
128134
self.msg = MessageBuilder(errors, modules)
129135
self.type_map = {}
136+
self.module_type_map = {}
130137
self.binder = ConditionalTypeBinder()
131138
self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg)
132139
self.return_types = []
@@ -138,6 +145,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
138145
self.deferred_nodes = []
139146
self.pass_num = 0
140147
self.current_node_deferred = False
148+
self.module_refs = set()
141149

142150
def visit_file(self, file_node: MypyFile, path: str) -> None:
143151
"""Type check a mypy file with the given path."""
@@ -148,6 +156,8 @@ def visit_file(self, file_node: MypyFile, path: str) -> None:
148156
self.weak_opts = file_node.weak_opts
149157
self.enter_partial_types()
150158
self.is_typeshed_stub = self.errors.is_typeshed_file(path)
159+
self.module_type_map = {}
160+
self.module_refs = set()
151161

152162
for d in file_node.defs:
153163
self.accept(d)
@@ -2186,6 +2196,8 @@ def check_type_equivalency(self, t1: Type, t2: Type, node: Context,
21862196
def store_type(self, node: Node, typ: Type) -> None:
21872197
"""Store the type of a node in the type map."""
21882198
self.type_map[node] = typ
2199+
if typ is not None:
2200+
self.module_type_map[node] = typ
21892201

21902202
def typing_mode_none(self) -> bool:
21912203
if self.is_dynamic_function() and not self.options.check_untyped_defs:

mypy/checkexpr.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Expression type checker. This file is conceptually part of TypeChecker."""
22

3-
from typing import cast, Dict, List, Tuple, Callable, Union, Optional
3+
from typing import cast, Dict, Set, List, Iterable, Tuple, Callable, Union, Optional
44

55
from mypy.types import (
66
Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef,
@@ -16,7 +16,7 @@
1616
ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator,
1717
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
1818
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr,
19-
TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR2
19+
TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR2, MODULE_REF,
2020
)
2121
from mypy.nodes import function_type
2222
from mypy import nodes
@@ -45,6 +45,43 @@
4545
None]
4646

4747

48+
def split_module_names(mod_name: str) -> Iterable[str]:
49+
"""Yields the module and all parent module names.
50+
51+
So, if `mod_name` is 'a.b.c', this function will yield
52+
['a.b.c', 'a.b', and 'a']."""
53+
yield mod_name
54+
while '.' in mod_name:
55+
mod_name = mod_name.rsplit('.', 1)[0]
56+
yield mod_name
57+
58+
59+
def extract_refexpr_names(expr: RefExpr) -> Set[str]:
60+
"""Recursively extracts all module references from a reference expression.
61+
62+
Note that currently, the only two subclasses of RefExpr are NameExpr and
63+
MemberExpr."""
64+
output = set() # type: Set[str]
65+
while expr.kind == MODULE_REF or expr.fullname is not None:
66+
if expr.kind == MODULE_REF:
67+
output.add(expr.fullname)
68+
elif expr.fullname is not None and '.' in expr.fullname:
69+
output.add(expr.fullname.rsplit('.', 1)[0])
70+
71+
if isinstance(expr, NameExpr):
72+
if expr.info is not None:
73+
output.update(split_module_names(expr.info.module_name))
74+
break
75+
elif isinstance(expr, MemberExpr):
76+
if isinstance(expr.expr, RefExpr):
77+
expr = expr.expr
78+
else:
79+
break
80+
else:
81+
raise AssertionError("Unknown RefExpr subclass: {}".format(type(expr)))
82+
return output
83+
84+
4885
class Finished(Exception):
4986
"""Raised if we can terminate overload argument check early (no match)."""
5087

@@ -70,11 +107,16 @@ def __init__(self,
70107
self.msg = msg
71108
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)
72109

110+
def _visit_typeinfo(self, info: nodes.TypeInfo) -> None:
111+
if info is not None:
112+
self.chk.module_refs.update(split_module_names(info.module_name))
113+
73114
def visit_name_expr(self, e: NameExpr) -> Type:
74115
"""Type check a name expression.
75116
76117
It can be of any kind: local, member or global.
77118
"""
119+
self.chk.module_refs.update(extract_refexpr_names(e))
78120
result = self.analyze_ref_expr(e)
79121
return self.chk.narrow_type_from_binder(e, result)
80122

@@ -858,6 +900,7 @@ def apply_generic_arguments2(self, overload: Overloaded, types: List[Type],
858900

859901
def visit_member_expr(self, e: MemberExpr) -> Type:
860902
"""Visit member expression (of form e.id)."""
903+
self.chk.module_refs.update(extract_refexpr_names(e))
861904
result = self.analyze_ordinary_member_access(e, False)
862905
return self.chk.narrow_type_from_binder(e, result)
863906

0 commit comments

Comments
 (0)