1
1
from __future__ import annotations
2
2
3
- from typing import Iterable , Set
3
+ from typing import Iterable
4
4
5
5
import mypy .types as types
6
6
from mypy .types import TypeVisitor
@@ -17,105 +17,118 @@ def extract_module_names(type_name: str | None) -> list[str]:
17
17
return []
18
18
19
19
20
- class TypeIndirectionVisitor (TypeVisitor [Set [ str ] ]):
20
+ class TypeIndirectionVisitor (TypeVisitor [None ]):
21
21
"""Returns all module references within a particular type."""
22
22
23
23
def __init__ (self ) -> None :
24
- self .cache : dict [types .Type , set [str ]] = {}
24
+ # Module references are collected here
25
+ self .modules : set [str ] = set ()
26
+ # User to avoid infinite recursion with recursive type aliases
25
27
self .seen_aliases : set [types .TypeAliasType ] = set ()
28
+ # Used to avoid redundant work
29
+ self .seen_fullnames : set [str ] = set ()
26
30
27
31
def find_modules (self , typs : Iterable [types .Type ]) -> set [str ]:
28
- self .seen_aliases .clear ()
29
- return self ._visit (typs )
32
+ self .modules = set ()
33
+ self .seen_fullnames = set ()
34
+ self .seen_aliases = set ()
35
+ self ._visit (typs )
36
+ return self .modules
30
37
31
- def _visit (self , typ_or_typs : types .Type | Iterable [types .Type ]) -> set [ str ] :
38
+ def _visit (self , typ_or_typs : types .Type | Iterable [types .Type ]) -> None :
32
39
typs = [typ_or_typs ] if isinstance (typ_or_typs , types .Type ) else typ_or_typs
33
- output : set [str ] = set ()
34
40
for typ in typs :
35
41
if isinstance (typ , types .TypeAliasType ):
36
42
# Avoid infinite recursion for recursive type aliases.
37
43
if typ in self .seen_aliases :
38
44
continue
39
45
self .seen_aliases .add (typ )
40
- if typ in self .cache :
41
- modules = self .cache [typ ]
42
- else :
43
- modules = typ .accept (self )
44
- self .cache [typ ] = set (modules )
45
- output .update (modules )
46
- return output
46
+ typ .accept (self )
47
47
48
- def visit_unbound_type (self , t : types .UnboundType ) -> set [str ]:
49
- return self ._visit (t .args )
48
+ def _visit_module_name (self , module_name : str ) -> None :
49
+ if module_name not in self .modules :
50
+ self .modules .update (split_module_names (module_name ))
50
51
51
- def visit_any (self , t : types .AnyType ) -> set [ str ] :
52
- return set ( )
52
+ def visit_unbound_type (self , t : types .UnboundType ) -> None :
53
+ self . _visit ( t . args )
53
54
54
- def visit_none_type (self , t : types .NoneType ) -> set [ str ] :
55
- return set ()
55
+ def visit_any (self , t : types .AnyType ) -> None :
56
+ pass
56
57
57
- def visit_uninhabited_type (self , t : types .UninhabitedType ) -> set [ str ] :
58
- return set ()
58
+ def visit_none_type (self , t : types .NoneType ) -> None :
59
+ pass
59
60
60
- def visit_erased_type (self , t : types .ErasedType ) -> set [ str ] :
61
- return set ()
61
+ def visit_uninhabited_type (self , t : types .UninhabitedType ) -> None :
62
+ pass
62
63
63
- def visit_deleted_type (self , t : types .DeletedType ) -> set [ str ] :
64
- return set ()
64
+ def visit_erased_type (self , t : types .ErasedType ) -> None :
65
+ pass
65
66
66
- def visit_type_var (self , t : types .TypeVarType ) -> set [ str ] :
67
- return self . _visit ( t . values ) | self . _visit ( t . upper_bound ) | self . _visit ( t . default )
67
+ def visit_deleted_type (self , t : types .DeletedType ) -> None :
68
+ pass
68
69
69
- def visit_param_spec (self , t : types .ParamSpecType ) -> set [str ]:
70
- return self ._visit (t .upper_bound ) | self ._visit (t .default )
70
+ def visit_type_var (self , t : types .TypeVarType ) -> None :
71
+ self ._visit (t .values )
72
+ self ._visit (t .upper_bound )
73
+ self ._visit (t .default )
71
74
72
- def visit_type_var_tuple (self , t : types .TypeVarTupleType ) -> set [str ]:
73
- return self ._visit (t .upper_bound ) | self ._visit (t .default )
75
+ def visit_param_spec (self , t : types .ParamSpecType ) -> None :
76
+ self ._visit (t .upper_bound )
77
+ self ._visit (t .default )
74
78
75
- def visit_unpack_type (self , t : types .UnpackType ) -> set [str ]:
76
- return t .type .accept (self )
79
+ def visit_type_var_tuple (self , t : types .TypeVarTupleType ) -> None :
80
+ self ._visit (t .upper_bound )
81
+ self ._visit (t .default )
77
82
78
- def visit_parameters (self , t : types .Parameters ) -> set [ str ] :
79
- return self . _visit ( t . arg_types )
83
+ def visit_unpack_type (self , t : types .UnpackType ) -> None :
84
+ t . type . accept ( self )
80
85
81
- def visit_instance (self , t : types .Instance ) -> set [str ]:
82
- out = self ._visit (t .args )
86
+ def visit_parameters (self , t : types .Parameters ) -> None :
87
+ self ._visit (t .arg_types )
88
+
89
+ def visit_instance (self , t : types .Instance ) -> None :
90
+ self ._visit (t .args )
83
91
if t .type :
84
92
# Uses of a class depend on everything in the MRO,
85
93
# as changes to classes in the MRO can add types to methods,
86
94
# change property types, change the MRO itself, etc.
87
95
for s in t .type .mro :
88
- out . update ( split_module_names ( s .module_name ) )
96
+ self . _visit_module_name ( s .module_name )
89
97
if t .type .metaclass_type is not None :
90
- out .update (split_module_names (t .type .metaclass_type .type .module_name ))
91
- return out
98
+ self ._visit_module_name (t .type .metaclass_type .type .module_name )
92
99
93
- def visit_callable_type (self , t : types .CallableType ) -> set [str ]:
94
- out = self ._visit (t .arg_types ) | self ._visit (t .ret_type )
100
+ def visit_callable_type (self , t : types .CallableType ) -> None :
101
+ self ._visit (t .arg_types )
102
+ self ._visit (t .ret_type )
95
103
if t .definition is not None :
96
- out .update (extract_module_names (t .definition .fullname ))
97
- return out
104
+ fullname = t .definition .fullname
105
+ if fullname not in self .seen_fullnames :
106
+ self .modules .update (extract_module_names (t .definition .fullname ))
107
+ self .seen_fullnames .add (fullname )
98
108
99
- def visit_overloaded (self , t : types .Overloaded ) -> set [str ]:
100
- return self ._visit (t .items ) | self ._visit (t .fallback )
109
+ def visit_overloaded (self , t : types .Overloaded ) -> None :
110
+ self ._visit (t .items )
111
+ self ._visit (t .fallback )
101
112
102
- def visit_tuple_type (self , t : types .TupleType ) -> set [str ]:
103
- return self ._visit (t .items ) | self ._visit (t .partial_fallback )
113
+ def visit_tuple_type (self , t : types .TupleType ) -> None :
114
+ self ._visit (t .items )
115
+ self ._visit (t .partial_fallback )
104
116
105
- def visit_typeddict_type (self , t : types .TypedDictType ) -> set [str ]:
106
- return self ._visit (t .items .values ()) | self ._visit (t .fallback )
117
+ def visit_typeddict_type (self , t : types .TypedDictType ) -> None :
118
+ self ._visit (t .items .values ())
119
+ self ._visit (t .fallback )
107
120
108
- def visit_literal_type (self , t : types .LiteralType ) -> set [ str ] :
109
- return self ._visit (t .fallback )
121
+ def visit_literal_type (self , t : types .LiteralType ) -> None :
122
+ self ._visit (t .fallback )
110
123
111
- def visit_union_type (self , t : types .UnionType ) -> set [ str ] :
112
- return self ._visit (t .items )
124
+ def visit_union_type (self , t : types .UnionType ) -> None :
125
+ self ._visit (t .items )
113
126
114
- def visit_partial_type (self , t : types .PartialType ) -> set [ str ] :
115
- return set ()
127
+ def visit_partial_type (self , t : types .PartialType ) -> None :
128
+ pass
116
129
117
- def visit_type_type (self , t : types .TypeType ) -> set [ str ] :
118
- return self ._visit (t .item )
130
+ def visit_type_type (self , t : types .TypeType ) -> None :
131
+ self ._visit (t .item )
119
132
120
- def visit_type_alias_type (self , t : types .TypeAliasType ) -> set [ str ] :
121
- return self ._visit (types .get_proper_type (t ))
133
+ def visit_type_alias_type (self , t : types .TypeAliasType ) -> None :
134
+ self ._visit (types .get_proper_type (t ))
0 commit comments