Skip to content

Commit 3680449

Browse files
JukkaLsobolevn
andauthored
Simplify unions when erasing last known values (#12064)
When we erase last known values in an union with multiple Instance types, make sure that the resulting union doesn't have duplicate erased types. The duplicate items weren't incorrect as such, but they could cause overly complex error messages and potentially slow type checking performance. This is one of the fixes extracted from #12054. Since some of the changes may cause regressions, it's better to split the PR. Work on #12051. Co-authored-by: Nikita Sobolev <[email protected]>
1 parent af366c0 commit 3680449

File tree

4 files changed

+95
-4
lines changed

4 files changed

+95
-4
lines changed

mypy/erasetype.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import Optional, Container, Callable
1+
from typing import Optional, Container, Callable, List, Dict, cast
22

33
from mypy.types import (
44
Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType,
55
CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType,
66
DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType,
7-
get_proper_type, TypeAliasType, ParamSpecType
7+
get_proper_type, get_proper_types, TypeAliasType, ParamSpecType
88
)
99
from mypy.nodes import ARG_STAR, ARG_STAR2
1010

@@ -161,3 +161,34 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
161161
# Type aliases can't contain literal values, because they are
162162
# always constructed as explicit types.
163163
return t
164+
165+
def visit_union_type(self, t: UnionType) -> Type:
166+
new = cast(UnionType, super().visit_union_type(t))
167+
# Erasure can result in many duplicate items; merge them.
168+
# Call make_simplified_union only on lists of instance types
169+
# that all have the same fullname, to avoid simplifying too
170+
# much.
171+
instances = [item for item in new.items
172+
if isinstance(get_proper_type(item), Instance)]
173+
# Avoid merge in simple cases such as optional types.
174+
if len(instances) > 1:
175+
instances_by_name: Dict[str, List[Instance]] = {}
176+
new_items = get_proper_types(new.items)
177+
for item in new_items:
178+
if isinstance(item, Instance) and not item.args:
179+
instances_by_name.setdefault(item.type.fullname, []).append(item)
180+
merged: List[Type] = []
181+
for item in new_items:
182+
if isinstance(item, Instance) and not item.args:
183+
types = instances_by_name.get(item.type.fullname)
184+
if types is not None:
185+
if len(types) == 1:
186+
merged.append(item)
187+
else:
188+
from mypy.typeops import make_simplified_union
189+
merged.append(make_simplified_union(types))
190+
del instances_by_name[item.type.fullname]
191+
else:
192+
merged.append(item)
193+
return UnionType.make_union(merged)
194+
return new

mypy/test/testtypes.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
from typing import List, Tuple
44

55
from mypy.test.helpers import Suite, assert_equal, assert_type, skip
6-
from mypy.erasetype import erase_type
6+
from mypy.erasetype import erase_type, remove_instance_last_known_values
77
from mypy.expandtype import expand_type
88
from mypy.join import join_types, join_simple
99
from mypy.meet import meet_types, narrow_declared_type
1010
from mypy.sametypes import is_same_type
1111
from mypy.indirection import TypeIndirectionVisitor
1212
from mypy.types import (
1313
UnboundType, AnyType, CallableType, TupleType, TypeVarType, Type, Instance, NoneType,
14-
Overloaded, TypeType, UnionType, UninhabitedType, TypeVarId, TypeOfAny, get_proper_type
14+
Overloaded, TypeType, UnionType, UninhabitedType, TypeVarId, TypeOfAny, ProperType,
15+
get_proper_type
1516
)
1617
from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, CONTRAVARIANT, INVARIANT, COVARIANT
1718
from mypy.subtypes import is_subtype, is_more_precise, is_proper_subtype
@@ -1092,3 +1093,46 @@ def assert_simple_is_same(self, s: Type, t: Type, expected: bool, strict: bool)
10921093
'({} == {}) is {{}} ({{}} expected)'.format(s, t))
10931094
assert_equal(hash(s) == hash(t), expected,
10941095
'(hash({}) == hash({}) is {{}} ({{}} expected)'.format(s, t))
1096+
1097+
1098+
class RemoveLastKnownValueSuite(Suite):
1099+
def setUp(self) -> None:
1100+
self.fx = TypeFixture()
1101+
1102+
def test_optional(self) -> None:
1103+
t = UnionType.make_union([self.fx.a, self.fx.nonet])
1104+
self.assert_union_result(t, [self.fx.a, self.fx.nonet])
1105+
1106+
def test_two_instances(self) -> None:
1107+
t = UnionType.make_union([self.fx.a, self.fx.b])
1108+
self.assert_union_result(t, [self.fx.a, self.fx.b])
1109+
1110+
def test_multiple_same_instances(self) -> None:
1111+
t = UnionType.make_union([self.fx.a, self.fx.a])
1112+
assert remove_instance_last_known_values(t) == self.fx.a
1113+
t = UnionType.make_union([self.fx.a, self.fx.a, self.fx.b])
1114+
self.assert_union_result(t, [self.fx.a, self.fx.b])
1115+
t = UnionType.make_union([self.fx.a, self.fx.nonet, self.fx.a, self.fx.b])
1116+
self.assert_union_result(t, [self.fx.a, self.fx.nonet, self.fx.b])
1117+
1118+
def test_single_last_known_value(self) -> None:
1119+
t = UnionType.make_union([self.fx.lit1_inst, self.fx.nonet])
1120+
self.assert_union_result(t, [self.fx.a, self.fx.nonet])
1121+
1122+
def test_last_known_values_with_merge(self) -> None:
1123+
t = UnionType.make_union([self.fx.lit1_inst, self.fx.lit2_inst, self.fx.lit4_inst])
1124+
assert remove_instance_last_known_values(t) == self.fx.a
1125+
t = UnionType.make_union([self.fx.lit1_inst,
1126+
self.fx.b,
1127+
self.fx.lit2_inst,
1128+
self.fx.lit4_inst])
1129+
self.assert_union_result(t, [self.fx.a, self.fx.b])
1130+
1131+
def test_generics(self) -> None:
1132+
t = UnionType.make_union([self.fx.ga, self.fx.gb])
1133+
self.assert_union_result(t, [self.fx.ga, self.fx.gb])
1134+
1135+
def assert_union_result(self, t: ProperType, expected: List[Type]) -> None:
1136+
t2 = remove_instance_last_known_values(t)
1137+
assert type(t2) is UnionType
1138+
assert t2.items == expected

mypy/test/typefixture.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,11 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
157157
self.lit1 = LiteralType(1, self.a)
158158
self.lit2 = LiteralType(2, self.a)
159159
self.lit3 = LiteralType("foo", self.d)
160+
self.lit4 = LiteralType(4, self.a)
160161
self.lit1_inst = Instance(self.ai, [], last_known_value=self.lit1)
161162
self.lit2_inst = Instance(self.ai, [], last_known_value=self.lit2)
162163
self.lit3_inst = Instance(self.di, [], last_known_value=self.lit3)
164+
self.lit4_inst = Instance(self.ai, [], last_known_value=self.lit4)
163165

164166
self.type_a = TypeType.make_normalized(self.a)
165167
self.type_b = TypeType.make_normalized(self.b)

test-data/unit/check-enum.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,3 +1868,17 @@ class WithOverload(enum.IntEnum):
18681868
class SubWithOverload(WithOverload): # Should pass
18691869
pass
18701870
[builtins fixtures/tuple.pyi]
1871+
1872+
[case testEnumtValueUnionSimplification]
1873+
from enum import IntEnum
1874+
from typing import Any
1875+
1876+
class C(IntEnum):
1877+
X = 0
1878+
Y = 1
1879+
Z = 2
1880+
1881+
def f1(c: C) -> None:
1882+
x = {'x': c.value}
1883+
reveal_type(x) # N: Revealed type is "builtins.dict[builtins.str*, builtins.int]"
1884+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)