Skip to content

Commit 73ba1e7

Browse files
authored
Optimize type indirection visitor (#18298)
The type indirection visitor was a performance bottleneck when type checking torch. It used to perform many set unions and hash value calculations on mypy type objects, both of which are fairly expensive. Now it mostly relies on set membership checks and set add operations on strings, which are much faster. It also avoids constructing many temporary objects. This speeds up type checking torch by about 3%, and it also appears to speed up self check by about 2%.
1 parent d3be43d commit 73ba1e7

File tree

2 files changed

+78
-63
lines changed

2 files changed

+78
-63
lines changed

mypy/indirection.py

Lines changed: 74 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Iterable, Set
3+
from typing import Iterable
44

55
import mypy.types as types
66
from mypy.types import TypeVisitor
@@ -17,105 +17,118 @@ def extract_module_names(type_name: str | None) -> list[str]:
1717
return []
1818

1919

20-
class TypeIndirectionVisitor(TypeVisitor[Set[str]]):
20+
class TypeIndirectionVisitor(TypeVisitor[None]):
2121
"""Returns all module references within a particular type."""
2222

2323
def __init__(self) -> None:
24-
self.cache: dict[types.Type, set[str]] = {}
24+
# Module references are collected here
25+
self.modules: set[str] = set()
26+
# Used to avoid infinite recursion with recursive type aliases
2527
self.seen_aliases: set[types.TypeAliasType] = set()
28+
# Used to avoid redundant work
29+
self.seen_fullnames: set[str] = set()
2630

2731
def find_modules(self, typs: Iterable[types.Type]) -> set[str]:
28-
self.seen_aliases.clear()
29-
return self._visit(typs)
32+
self.modules = set()
33+
self.seen_fullnames = set()
34+
self.seen_aliases = set()
35+
self._visit(typs)
36+
return self.modules
3037

31-
def _visit(self, typ_or_typs: types.Type | Iterable[types.Type]) -> set[str]:
38+
def _visit(self, typ_or_typs: types.Type | Iterable[types.Type]) -> None:
3239
typs = [typ_or_typs] if isinstance(typ_or_typs, types.Type) else typ_or_typs
33-
output: set[str] = set()
3440
for typ in typs:
3541
if isinstance(typ, types.TypeAliasType):
3642
# Avoid infinite recursion for recursive type aliases.
3743
if typ in self.seen_aliases:
3844
continue
3945
self.seen_aliases.add(typ)
40-
if typ in self.cache:
41-
modules = self.cache[typ]
42-
else:
43-
modules = typ.accept(self)
44-
self.cache[typ] = set(modules)
45-
output.update(modules)
46-
return output
46+
typ.accept(self)
4747

48-
def visit_unbound_type(self, t: types.UnboundType) -> set[str]:
49-
return self._visit(t.args)
48+
def _visit_module_name(self, module_name: str) -> None:
49+
if module_name not in self.modules:
50+
self.modules.update(split_module_names(module_name))
5051

51-
def visit_any(self, t: types.AnyType) -> set[str]:
52-
return set()
52+
def visit_unbound_type(self, t: types.UnboundType) -> None:
53+
self._visit(t.args)
5354

54-
def visit_none_type(self, t: types.NoneType) -> set[str]:
55-
return set()
55+
def visit_any(self, t: types.AnyType) -> None:
56+
pass
5657

57-
def visit_uninhabited_type(self, t: types.UninhabitedType) -> set[str]:
58-
return set()
58+
def visit_none_type(self, t: types.NoneType) -> None:
59+
pass
5960

60-
def visit_erased_type(self, t: types.ErasedType) -> set[str]:
61-
return set()
61+
def visit_uninhabited_type(self, t: types.UninhabitedType) -> None:
62+
pass
6263

63-
def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
64-
return set()
64+
def visit_erased_type(self, t: types.ErasedType) -> None:
65+
pass
6566

66-
def visit_type_var(self, t: types.TypeVarType) -> set[str]:
67-
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
67+
def visit_deleted_type(self, t: types.DeletedType) -> None:
68+
pass
6869

69-
def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
70-
return self._visit(t.upper_bound) | self._visit(t.default)
70+
def visit_type_var(self, t: types.TypeVarType) -> None:
71+
self._visit(t.values)
72+
self._visit(t.upper_bound)
73+
self._visit(t.default)
7174

72-
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
73-
return self._visit(t.upper_bound) | self._visit(t.default)
75+
def visit_param_spec(self, t: types.ParamSpecType) -> None:
76+
self._visit(t.upper_bound)
77+
self._visit(t.default)
7478

75-
def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
76-
return t.type.accept(self)
79+
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> None:
80+
self._visit(t.upper_bound)
81+
self._visit(t.default)
7782

78-
def visit_parameters(self, t: types.Parameters) -> set[str]:
79-
return self._visit(t.arg_types)
83+
def visit_unpack_type(self, t: types.UnpackType) -> None:
84+
t.type.accept(self)
8085

81-
def visit_instance(self, t: types.Instance) -> set[str]:
82-
out = self._visit(t.args)
86+
def visit_parameters(self, t: types.Parameters) -> None:
87+
self._visit(t.arg_types)
88+
89+
def visit_instance(self, t: types.Instance) -> None:
90+
self._visit(t.args)
8391
if t.type:
8492
# Uses of a class depend on everything in the MRO,
8593
# as changes to classes in the MRO can add types to methods,
8694
# change property types, change the MRO itself, etc.
8795
for s in t.type.mro:
88-
out.update(split_module_names(s.module_name))
96+
self._visit_module_name(s.module_name)
8997
if t.type.metaclass_type is not None:
90-
out.update(split_module_names(t.type.metaclass_type.type.module_name))
91-
return out
98+
self._visit_module_name(t.type.metaclass_type.type.module_name)
9299

93-
def visit_callable_type(self, t: types.CallableType) -> set[str]:
94-
out = self._visit(t.arg_types) | self._visit(t.ret_type)
100+
def visit_callable_type(self, t: types.CallableType) -> None:
101+
self._visit(t.arg_types)
102+
self._visit(t.ret_type)
95103
if t.definition is not None:
96-
out.update(extract_module_names(t.definition.fullname))
97-
return out
104+
fullname = t.definition.fullname
105+
if fullname not in self.seen_fullnames:
106+
self.modules.update(extract_module_names(t.definition.fullname))
107+
self.seen_fullnames.add(fullname)
98108

99-
def visit_overloaded(self, t: types.Overloaded) -> set[str]:
100-
return self._visit(t.items) | self._visit(t.fallback)
109+
def visit_overloaded(self, t: types.Overloaded) -> None:
110+
self._visit(t.items)
111+
self._visit(t.fallback)
101112

102-
def visit_tuple_type(self, t: types.TupleType) -> set[str]:
103-
return self._visit(t.items) | self._visit(t.partial_fallback)
113+
def visit_tuple_type(self, t: types.TupleType) -> None:
114+
self._visit(t.items)
115+
self._visit(t.partial_fallback)
104116

105-
def visit_typeddict_type(self, t: types.TypedDictType) -> set[str]:
106-
return self._visit(t.items.values()) | self._visit(t.fallback)
117+
def visit_typeddict_type(self, t: types.TypedDictType) -> None:
118+
self._visit(t.items.values())
119+
self._visit(t.fallback)
107120

108-
def visit_literal_type(self, t: types.LiteralType) -> set[str]:
109-
return self._visit(t.fallback)
121+
def visit_literal_type(self, t: types.LiteralType) -> None:
122+
self._visit(t.fallback)
110123

111-
def visit_union_type(self, t: types.UnionType) -> set[str]:
112-
return self._visit(t.items)
124+
def visit_union_type(self, t: types.UnionType) -> None:
125+
self._visit(t.items)
113126

114-
def visit_partial_type(self, t: types.PartialType) -> set[str]:
115-
return set()
127+
def visit_partial_type(self, t: types.PartialType) -> None:
128+
pass
116129

117-
def visit_type_type(self, t: types.TypeType) -> set[str]:
118-
return self._visit(t.item)
130+
def visit_type_type(self, t: types.TypeType) -> None:
131+
self._visit(t.item)
119132

120-
def visit_type_alias_type(self, t: types.TypeAliasType) -> set[str]:
121-
return self._visit(types.get_proper_type(t))
133+
def visit_type_alias_type(self, t: types.TypeAliasType) -> None:
134+
self._visit(types.get_proper_type(t))

mypy/test/testtypes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,14 @@ def test_recursive_nested_in_non_recursive(self) -> None:
230230
def test_indirection_no_infinite_recursion(self) -> None:
231231
A, _ = self.fx.def_alias_1(self.fx.a)
232232
visitor = TypeIndirectionVisitor()
233-
modules = A.accept(visitor)
233+
A.accept(visitor)
234+
modules = visitor.modules
234235
assert modules == {"__main__", "builtins"}
235236

236237
A, _ = self.fx.def_alias_2(self.fx.a)
237238
visitor = TypeIndirectionVisitor()
238-
modules = A.accept(visitor)
239+
A.accept(visitor)
240+
modules = visitor.modules
239241
assert modules == {"__main__", "builtins"}
240242

241243

0 commit comments

Comments
 (0)