Skip to content

Commit cddd819

Browse files
authored
Speed up caching of subtype checks (#12539)
This is very performance critical. Implement a few micro-optimizations to speed caching a bit. In particular, we use dict.get to reduce the number of dict lookups required, and avoid tuple concatenation which tends to be a bit slow, as it has to construct temporary objects. It would probably be even better to avoid using tuples as keys altogether. This could be a reasonable follow-up improvement. Avoid caching if last known value is set, since it reduces the likelihood of cache hits a lot, because the space of literal values is big (essentially infinite). Also make the global strict_optional attribute an instance-level attribute for faster access, as we might now use it more frequently. I extracted the cached subtype check code into a microbenchmark and the new implementation seems about twice as fast (in an artificial setting, though). Work on #12526 (but should generally make things a little better).
1 parent 75e907d commit cddd819

File tree

13 files changed

+68
-45
lines changed

13 files changed

+68
-45
lines changed

mypy/checker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@
8383
from mypy.plugin import Plugin, CheckerPluginInterface
8484
from mypy.sharedparse import BINARY_MAGIC_METHODS
8585
from mypy.scope import Scope
86-
from mypy import state, errorcodes as codes
86+
from mypy import errorcodes as codes
87+
from mypy.state import state
8788
from mypy.traverser import has_return_statement, all_return_statements
8889
from mypy.errorcodes import ErrorCode, UNUSED_AWAITABLE, UNUSED_COROUTINE
8990
from mypy.util import is_typeshed_file, is_dunder, is_sunder

mypy/join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT
1919
import mypy.typeops
20-
from mypy import state
20+
from mypy.state import state
2121

2222

2323
class InstanceJoiner:

mypy/meet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mypy.erasetype import erase_type
1313
from mypy.maptype import map_instance_to_supertype
1414
from mypy.typeops import tuple_fallback, make_simplified_union, is_recursive_pair
15-
from mypy import state
15+
from mypy.state import state
1616
from mypy import join
1717

1818
# TODO Describe this module.

mypy/semanal_main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
MypyFile, TypeInfo, FuncDef, Decorator, OverloadedFuncDef, Var
3333
)
3434
from mypy.semanal_typeargs import TypeArgumentAnalyzer
35-
from mypy.state import strict_optional_set
35+
import mypy.state
3636
from mypy.semanal import (
3737
SemanticAnalyzer, apply_semantic_analyzer_patches, remove_imported_names_from_symtable
3838
)
@@ -356,7 +356,7 @@ def check_type_arguments(graph: 'Graph', scc: List[str], errors: Errors) -> None
356356
state.options,
357357
is_typeshed_file(state.path or ''))
358358
with state.wrap_context():
359-
with strict_optional_set(state.options.strict_optional):
359+
with mypy.state.state.strict_optional_set(state.options.strict_optional):
360360
state.tree.accept(analyzer)
361361

362362

@@ -371,7 +371,7 @@ def check_type_arguments_in_targets(targets: List[FineGrainedDeferredNode], stat
371371
state.options,
372372
is_typeshed_file(state.path or ''))
373373
with state.wrap_context():
374-
with strict_optional_set(state.options.strict_optional):
374+
with mypy.state.state.strict_optional_set(state.options.strict_optional):
375375
for target in targets:
376376
func: Optional[Union[FuncDef, OverloadedFuncDef]] = None
377377
if isinstance(target.node, (FuncDef, OverloadedFuncDef)):

mypy/state.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
11
from contextlib import contextmanager
22
from typing import Optional, Tuple, Iterator
33

4+
from typing_extensions import Final
5+
46
# These are global mutable state. Don't add anything here unless there's a very
57
# good reason.
68

7-
# Value varies by file being processed
8-
strict_optional = False
9-
find_occurrences: Optional[Tuple[str, str]] = None
109

10+
class StrictOptionalState:
11+
# Wrap this in a class since it's faster that using a module-level attribute.
12+
13+
def __init__(self, strict_optional: bool) -> None:
14+
# Value varies by file being processed
15+
self.strict_optional = strict_optional
1116

12-
@contextmanager
13-
def strict_optional_set(value: bool) -> Iterator[None]:
14-
global strict_optional
15-
saved = strict_optional
16-
strict_optional = value
17-
try:
18-
yield
19-
finally:
20-
strict_optional = saved
17+
@contextmanager
18+
def strict_optional_set(self, value: bool) -> Iterator[None]:
19+
saved = self.strict_optional
20+
self.strict_optional = value
21+
try:
22+
yield
23+
finally:
24+
self.strict_optional = saved
25+
26+
27+
state: Final = StrictOptionalState(strict_optional=False)
28+
find_occurrences: Optional[Tuple[str, str]] = None

mypy/stubtest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import mypy.build
2323
import mypy.modulefinder
24+
import mypy.state
2425
import mypy.types
2526
from mypy import nodes
2627
from mypy.config_parser import parse_config_file
@@ -560,7 +561,7 @@ def get_position(arg_name: str) -> int:
560561
return max(index for _, index in all_args[arg_name])
561562

562563
def get_type(arg_name: str) -> mypy.types.ProperType:
563-
with mypy.state.strict_optional_set(True):
564+
with mypy.state.state.strict_optional_set(True):
564565
all_types = [
565566
arg.variable.type or arg.type_annotation for arg, _ in all_args[arg_name]
566567
]
@@ -1099,7 +1100,7 @@ def is_subtype_helper(left: mypy.types.Type, right: mypy.types.Type) -> bool:
10991100
# Special case checks against TypedDicts
11001101
return True
11011102

1102-
with mypy.state.strict_optional_set(True):
1103+
with mypy.state.state.strict_optional_set(True):
11031104
return mypy.subtypes.is_subtype(left, right)
11041105

11051106

mypy/subtypes.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from mypy.expandtype import expand_type_by_instance
2626
from mypy.typestate import TypeState, SubtypeKind
2727
from mypy.options import Options
28-
from mypy import state
28+
from mypy.state import state
2929

3030
# Flags for detected protocol members
3131
IS_SETTABLE: Final = 1
@@ -206,7 +206,8 @@ def build_subtype_kind(*,
206206
ignore_pos_arg_names: bool = False,
207207
ignore_declared_variance: bool = False,
208208
ignore_promotions: bool = False) -> SubtypeKind:
209-
return (False, # is proper subtype?
209+
return (state.strict_optional,
210+
False, # is proper subtype?
210211
ignore_type_params,
211212
ignore_pos_arg_names,
212213
ignore_declared_variance,
@@ -1316,7 +1317,11 @@ def build_subtype_kind(*,
13161317
ignore_promotions: bool = False,
13171318
erase_instances: bool = False,
13181319
keep_erased_types: bool = False) -> SubtypeKind:
1319-
return True, ignore_promotions, erase_instances, keep_erased_types
1320+
return (state.strict_optional,
1321+
True,
1322+
ignore_promotions,
1323+
erase_instances,
1324+
keep_erased_types)
13201325

13211326
def _is_proper_subtype(self, left: Type, right: Type) -> bool:
13221327
return is_proper_subtype(left, right,

mypy/suggestions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from typing_extensions import TypedDict
2929

30-
from mypy.state import strict_optional_set
30+
from mypy.state import state
3131
from mypy.types import (
3232
Type, AnyType, TypeOfAny, CallableType, UnionType, NoneType, Instance, TupleType,
3333
TypeVarType, FunctionLike, UninhabitedType,
@@ -439,7 +439,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
439439

440440
is_method = bool(node.info) and not node.is_static
441441

442-
with strict_optional_set(graph[mod].options.strict_optional):
442+
with state.strict_optional_set(graph[mod].options.strict_optional):
443443
guesses = self.get_guesses(
444444
is_method,
445445
self.get_starting_type(node),
@@ -454,7 +454,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
454454
# Now try to find the return type!
455455
self.try_type(node, best)
456456
returns = get_return_types(self.manager.all_types, node)
457-
with strict_optional_set(graph[mod].options.strict_optional):
457+
with state.strict_optional_set(graph[mod].options.strict_optional):
458458
if returns:
459459
ret_types = generate_type_combinations(returns)
460460
else:
@@ -988,7 +988,7 @@ def refine_union(t: UnionType, s: ProperType) -> Type:
988988

989989
# Turn strict optional on when simplifying the union since we
990990
# don't want to drop Nones.
991-
with strict_optional_set(True):
991+
with state.strict_optional_set(True):
992992
return make_simplified_union(new_items)
993993

994994

mypy/test/testtypes.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, CONTRAVARIANT, INVARIANT, COVARIANT
1818
from mypy.subtypes import is_subtype, is_more_precise, is_proper_subtype
1919
from mypy.test.typefixture import TypeFixture, InterfaceTypeFixture
20-
from mypy.state import strict_optional_set
20+
from mypy.state import state
2121
from mypy.typeops import true_only, false_only, make_simplified_union
2222

2323

@@ -410,15 +410,15 @@ def test_true_only_of_union(self) -> None:
410410
assert to.items[1] is tup_type
411411

412412
def test_false_only_of_true_type_is_uninhabited(self) -> None:
413-
with strict_optional_set(True):
413+
with state.strict_optional_set(True):
414414
fo = false_only(self.tuple(AnyType(TypeOfAny.special_form)))
415415
assert_type(UninhabitedType, fo)
416416

417417
def test_false_only_tuple(self) -> None:
418-
with strict_optional_set(False):
418+
with state.strict_optional_set(False):
419419
fo = false_only(self.tuple(self.fx.a))
420420
assert_equal(fo, NoneType())
421-
with strict_optional_set(True):
421+
with state.strict_optional_set(True):
422422
fo = false_only(self.tuple(self.fx.a))
423423
assert_equal(fo, UninhabitedType())
424424

@@ -437,7 +437,7 @@ def test_false_only_of_instance(self) -> None:
437437
assert self.fx.a.can_be_true
438438

439439
def test_false_only_of_union(self) -> None:
440-
with strict_optional_set(True):
440+
with state.strict_optional_set(True):
441441
tup_type = self.tuple()
442442
# Union of something that is unknown, something that is always true, something
443443
# that is always false
@@ -1059,9 +1059,9 @@ def test_literal_type(self) -> None:
10591059
# FIX generic interfaces + ranges
10601060

10611061
def assert_meet_uninhabited(self, s: Type, t: Type) -> None:
1062-
with strict_optional_set(False):
1062+
with state.strict_optional_set(False):
10631063
self.assert_meet(s, t, self.fx.nonet)
1064-
with strict_optional_set(True):
1064+
with state.strict_optional_set(True):
10651065
self.assert_meet(s, t, self.fx.uninhabited)
10661066

10671067
def assert_meet(self, s: Type, t: Type, meet: Type) -> None:

mypy/typeops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from mypy.typevars import fill_typevars
2828

29-
from mypy import state
29+
from mypy.state import state
3030

3131

3232
def is_recursive_pair(s: Type, t: Type) -> bool:

mypy/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from mypy.backports import OrderedDict
1414
import mypy.nodes
15-
from mypy import state
15+
from mypy.state import state
1616
from mypy.nodes import (
1717
INVARIANT, SymbolNode, FuncDef,
1818
ArgKind, ARG_POS, ARG_STAR, ARG_STAR2,

mypy/typestate.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from mypy.nodes import TypeInfo
1010
from mypy.types import Instance, TypeAliasType, get_proper_type, Type
1111
from mypy.server.trigger import make_trigger
12-
from mypy import state
1312

1413
# Represents that the 'left' instance is a subtype of the 'right' instance
1514
SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance]
@@ -124,20 +123,29 @@ def reset_all_subtype_caches_for(info: TypeInfo) -> None:
124123

125124
@staticmethod
126125
def is_cached_subtype_check(kind: SubtypeKind, left: Instance, right: Instance) -> bool:
126+
if left.last_known_value is not None or right.last_known_value is not None:
127+
# If there is a literal last known value, give up. There
128+
# will be an unbounded number of potential types to cache,
129+
# making caching less effective.
130+
return False
127131
info = right.type
128-
if info not in TypeState._subtype_caches:
132+
cache = TypeState._subtype_caches.get(info)
133+
if cache is None:
129134
return False
130-
cache = TypeState._subtype_caches[info]
131-
key = (state.strict_optional,) + kind
132-
if key not in cache:
135+
subcache = cache.get(kind)
136+
if subcache is None:
133137
return False
134-
return (left, right) in cache[key]
138+
return (left, right) in subcache
135139

136140
@staticmethod
137141
def record_subtype_cache_entry(kind: SubtypeKind,
138142
left: Instance, right: Instance) -> None:
143+
if left.last_known_value is not None or right.last_known_value is not None:
144+
# These are unlikely to match, due to the large space of
145+
# possible values. Avoid uselessly increasing cache sizes.
146+
return
139147
cache = TypeState._subtype_caches.setdefault(right.type, dict())
140-
cache.setdefault((state.strict_optional,) + kind, set()).add((left, right))
148+
cache.setdefault(kind, set()).add((left, right))
141149

142150
@staticmethod
143151
def reset_protocol_deps() -> None:

mypyc/irbuild/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def f(x: int) -> int:
2525

2626
from mypy.nodes import MypyFile, Expression, ClassDef
2727
from mypy.types import Type
28-
from mypy.state import strict_optional_set
28+
from mypy.state import state
2929
from mypy.build import Graph
3030

3131
from mypyc.common import TOP_LEVEL_NAME
@@ -45,7 +45,7 @@ def f(x: int) -> int:
4545
# The stubs for callable contextmanagers are busted so cast it to the
4646
# right type...
4747
F = TypeVar('F', bound=Callable[..., Any])
48-
strict_optional_dec = cast(Callable[[F], F], strict_optional_set(True))
48+
strict_optional_dec = cast(Callable[[F], F], state.strict_optional_set(True))
4949

5050

5151
@strict_optional_dec # Turn on strict optional for any type manipulations we do

0 commit comments

Comments
 (0)