|
| 1 | +from typing import Optional, Set, Iterable |
| 2 | +from abc import abstractmethod |
| 3 | + |
| 4 | +from mypy.visitor import NodeVisitor |
| 5 | +from mypy.types import TypeVisitor |
| 6 | +from mypy.nodes import MODULE_REF |
| 7 | +import mypy.nodes as nodes |
| 8 | +import mypy.types as types |
| 9 | +from mypy.util import split_module_names |
| 10 | + |
| 11 | + |
| 12 | +def extract_module_names(symbol_name: str) -> Iterable[str]: |
| 13 | + """Returns the module and parent modules of a fully qualified symbol name.""" |
| 14 | + if symbol_name is not None: |
| 15 | + while '.' in symbol_name: |
| 16 | + symbol_name = symbol_name.rsplit('.', 1)[0] |
| 17 | + yield symbol_name |
| 18 | + |
| 19 | + |
| 20 | +class TypeIndirectionVisitor(TypeVisitor[Set[str]]): |
| 21 | + """Returns all module references within a particular type.""" |
| 22 | + |
| 23 | + def __init__(self) -> None: |
| 24 | + self.cache = {} # type: Dict[types.Type, Set[str]] |
| 25 | + |
| 26 | + def find_modules(self, typs: Iterable[types.Type]) -> Set[str]: |
| 27 | + return self._visit(*typs) |
| 28 | + |
| 29 | + def _visit(self, *typs: types.Type) -> Set[str]: |
| 30 | + output = set() # type: Set[str] |
| 31 | + for typ in typs: |
| 32 | + if typ in self.cache: |
| 33 | + modules = self.cache[typ] |
| 34 | + else: |
| 35 | + modules = typ.accept(self) |
| 36 | + self.cache[typ] = set(modules) |
| 37 | + output.update(modules) |
| 38 | + return output |
| 39 | + |
| 40 | + def visit_unbound_type(self, t: types.UnboundType) -> Set[str]: |
| 41 | + return self._visit(*t.args) |
| 42 | + |
| 43 | + def visit_type_list(self, t: types.TypeList) -> Set[str]: |
| 44 | + return self._visit(*t.items) |
| 45 | + |
| 46 | + def visit_error_type(self, t: types.ErrorType) -> Set[str]: |
| 47 | + return set() |
| 48 | + |
| 49 | + def visit_any(self, t: types.AnyType) -> Set[str]: |
| 50 | + return set() |
| 51 | + |
| 52 | + def visit_void(self, t: types.Void) -> Set[str]: |
| 53 | + return set() |
| 54 | + |
| 55 | + def visit_none_type(self, t: types.NoneTyp) -> Set[str]: |
| 56 | + return set() |
| 57 | + |
| 58 | + def visit_uninhabited_type(self, t: types.UninhabitedType) -> Set[str]: |
| 59 | + return set() |
| 60 | + |
| 61 | + def visit_erased_type(self, t: types.ErasedType) -> Set[str]: |
| 62 | + return set() |
| 63 | + |
| 64 | + def visit_deleted_type(self, t: types.DeletedType) -> Set[str]: |
| 65 | + return set() |
| 66 | + |
| 67 | + def visit_type_var(self, t: types.TypeVarType) -> Set[str]: |
| 68 | + return self._visit(*t.values) | self._visit(t.upper_bound) |
| 69 | + |
| 70 | + def visit_instance(self, t: types.Instance) -> Set[str]: |
| 71 | + out = self._visit(*t.args) |
| 72 | + if t.type is not None: |
| 73 | + out.update(split_module_names(t.type.module_name)) |
| 74 | + return out |
| 75 | + |
| 76 | + def visit_callable_type(self, t: types.CallableType) -> Set[str]: |
| 77 | + out = self._visit(*t.arg_types) | self._visit(t.ret_type) |
| 78 | + if t.definition is not None: |
| 79 | + out.update(extract_module_names(t.definition.fullname())) |
| 80 | + return out |
| 81 | + |
| 82 | + def visit_overloaded(self, t: types.Overloaded) -> Set[str]: |
| 83 | + return self._visit(*t.items()) | self._visit(t.fallback) |
| 84 | + |
| 85 | + def visit_tuple_type(self, t: types.TupleType) -> Set[str]: |
| 86 | + return self._visit(*t.items) | self._visit(t.fallback) |
| 87 | + |
| 88 | + def visit_star_type(self, t: types.StarType) -> Set[str]: |
| 89 | + return set() |
| 90 | + |
| 91 | + def visit_union_type(self, t: types.UnionType) -> Set[str]: |
| 92 | + return self._visit(*t.items) |
| 93 | + |
| 94 | + def visit_partial_type(self, t: types.PartialType) -> Set[str]: |
| 95 | + return set() |
| 96 | + |
| 97 | + def visit_ellipsis_type(self, t: types.EllipsisType) -> Set[str]: |
| 98 | + return set() |
| 99 | + |
| 100 | + def visit_type_type(self, t: types.TypeType) -> Set[str]: |
| 101 | + return self._visit(t.item) |
0 commit comments