Skip to content

Commit e0b6a1b

Browse files
committed
Add module reference extractor for types
This commit adds a new visitor to extract all module references from arbitrary types. This visitor should be considered provisional -- it will most likely be merged in with the typechecker at a later stage to make it a less expensive call.
1 parent 20718d7 commit e0b6a1b

File tree

3 files changed

+114
-12
lines changed

3 files changed

+114
-12
lines changed

mypy/checkexpr.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from mypy.constraints import get_actual_type
3737
from mypy.checkstrformat import StringFormatterChecker
3838
from mypy.expandtype import expand_type
39+
from mypy.util import split_module_names
3940

4041
from mypy import experiments
4142

@@ -45,17 +46,6 @@
4546
None]
4647

4748

48-
def split_module_names(mod_name: str) -> Iterable[str]:
49-
"""Yields the module and all parent module names.
50-
51-
So, if `mod_name` is 'a.b.c', this function will yield
52-
['a.b.c', 'a.b', and 'a']."""
53-
yield mod_name
54-
while '.' in mod_name:
55-
mod_name = mod_name.rsplit('.', 1)[0]
56-
yield mod_name
57-
58-
5949
def extract_refexpr_names(expr: RefExpr) -> Set[str]:
6050
"""Recursively extracts all module references from a reference expression.
6151

mypy/indirection.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from typing import Optional, Set, Iterable
2+
from abc import abstractmethod
3+
4+
from mypy.visitor import NodeVisitor
5+
from mypy.types import TypeVisitor
6+
from mypy.nodes import MODULE_REF
7+
import mypy.nodes as nodes
8+
import mypy.types as types
9+
from mypy.util import split_module_names
10+
11+
12+
def extract_module_names(symbol_name: str) -> Iterable[str]:
13+
"""Returns the module and parent modules of a fully qualified symbol name."""
14+
if symbol_name is not None:
15+
while '.' in symbol_name:
16+
symbol_name = symbol_name.rsplit('.', 1)[0]
17+
yield symbol_name
18+
19+
20+
class TypeIndirectionVisitor(TypeVisitor[Set[str]]):
21+
"""Returns all module references within a particular type."""
22+
23+
def __init__(self) -> None:
24+
self.cache = {} # type: Dict[types.Type, Set[str]]
25+
26+
def find_modules(self, typs: Iterable[types.Type]) -> Set[str]:
27+
return self._visit(*typs)
28+
29+
def _visit(self, *typs: types.Type) -> Set[str]:
30+
output = set() # type: Set[str]
31+
for typ in typs:
32+
if typ in self.cache:
33+
modules = self.cache[typ]
34+
else:
35+
modules = typ.accept(self)
36+
self.cache[typ] = set(modules)
37+
output.update(modules)
38+
return output
39+
40+
def visit_unbound_type(self, t: types.UnboundType) -> Set[str]:
41+
return self._visit(*t.args)
42+
43+
def visit_type_list(self, t: types.TypeList) -> Set[str]:
44+
return self._visit(*t.items)
45+
46+
def visit_error_type(self, t: types.ErrorType) -> Set[str]:
47+
return set()
48+
49+
def visit_any(self, t: types.AnyType) -> Set[str]:
50+
return set()
51+
52+
def visit_void(self, t: types.Void) -> Set[str]:
53+
return set()
54+
55+
def visit_none_type(self, t: types.NoneTyp) -> Set[str]:
56+
return set()
57+
58+
def visit_uninhabited_type(self, t: types.UninhabitedType) -> Set[str]:
59+
return set()
60+
61+
def visit_erased_type(self, t: types.ErasedType) -> Set[str]:
62+
return set()
63+
64+
def visit_deleted_type(self, t: types.DeletedType) -> Set[str]:
65+
return set()
66+
67+
def visit_type_var(self, t: types.TypeVarType) -> Set[str]:
68+
return self._visit(*t.values) | self._visit(t.upper_bound)
69+
70+
def visit_instance(self, t: types.Instance) -> Set[str]:
71+
out = self._visit(*t.args)
72+
if t.type is not None:
73+
out.update(split_module_names(t.type.module_name))
74+
return out
75+
76+
def visit_callable_type(self, t: types.CallableType) -> Set[str]:
77+
out = self._visit(*t.arg_types) | self._visit(t.ret_type)
78+
if t.definition is not None:
79+
out.update(extract_module_names(t.definition.fullname()))
80+
return out
81+
82+
def visit_overloaded(self, t: types.Overloaded) -> Set[str]:
83+
return self._visit(*t.items()) | self._visit(t.fallback)
84+
85+
def visit_tuple_type(self, t: types.TupleType) -> Set[str]:
86+
return self._visit(*t.items) | self._visit(t.fallback)
87+
88+
def visit_star_type(self, t: types.StarType) -> Set[str]:
89+
return set()
90+
91+
def visit_union_type(self, t: types.UnionType) -> Set[str]:
92+
return self._visit(*t.items)
93+
94+
def visit_partial_type(self, t: types.PartialType) -> Set[str]:
95+
return set()
96+
97+
def visit_ellipsis_type(self, t: types.EllipsisType) -> Set[str]:
98+
return set()
99+
100+
def visit_type_type(self, t: types.TypeType) -> Set[str]:
101+
return self._visit(t.item)

mypy/util.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
import subprocess
5-
from typing import TypeVar, List, Any, Tuple, Optional
5+
from typing import TypeVar, List, Any, Tuple, Optional, Iterable
66

77

88
T = TypeVar('T')
@@ -12,6 +12,17 @@
1212
default_python2_interpreter = ['python2', 'python', '/usr/bin/python']
1313

1414

15+
def split_module_names(mod_name: str) -> Iterable[str]:
16+
"""Yields the module and all parent module names.
17+
18+
So, if `mod_name` is 'a.b.c', this function will yield
19+
['a.b.c', 'a.b', and 'a']."""
20+
yield mod_name
21+
while '.' in mod_name:
22+
mod_name = mod_name.rsplit('.', 1)[0]
23+
yield mod_name
24+
25+
1526
def short_type(obj: object) -> str:
1627
"""Return the last component of the type name of an object.
1728

0 commit comments

Comments
 (0)