Commit 68656e8

Force enum literals to simplify when inferring unions
While working on overhauling #7169, I discovered that simply "deconstructing" enums into unions leads to false positives in some real-world code. This is an existing problem, but it became more prominent as I worked on improving type inference in the above PR.

Here's a simplified example of one such problem I ran into:

```
from enum import Enum

class Foo(Enum):
    A = 1
    B = 2

class Wrapper:
    def __init__(self, x: bool, y: Foo) -> None:
        if x:
            if y is Foo.A:
                # 'y' is of type Literal[Foo.A] here
                pass
            else:
                # ...and of type Literal[Foo.B] here
                pass
            # We join these two types after the if/else to end up with
            # Literal[Foo.A, Foo.B]
            self.y = y
        else:
            # ...and so this fails! 'Foo' is not considered a subtype of
            # 'Literal[Foo.A, Foo.B]'
            self.y = y
```

I considered three different ways of fixing this:

1. Modify our various type comparison operations (`is_same`, `is_subtype`, `is_proper_subtype`, etc...) to consider `Foo` and `Literal[Foo.A, Foo.B]` equivalent.
2. Modify the 'join' logic so that when we join enum literals, we check and see if we can merge them back into the original class, undoing the "deconstruction".
3. Modify the `make_simplified_union` logic to do the reconstruction instead.

I rejected the first two options. The first approach is the most sound one, but seemed complicated to implement. We have a lot of different type comparison operations, and attempting to modify them all seems error-prone. I also didn't really like the idea of having two equally valid representations of the same type, and would rather push mypy to always standardize on one, just from a usability point of view.

The second option seemed workable but limited to me. Modifying join would fix the specific example above, but I wasn't confident that was the only place we'd need to patch.

So I went with modifying `make_simplified_union` instead.
The main disadvantage of this approach is that we still get false positives when working with Unions that come directly from the semantic analysis phase. For example, we still get an error with the following program:

    x: Literal[Foo.A, Foo.B]
    y: Foo

    # Error, we still think 'x' is of type 'Literal[Foo.A, Foo.B]'
    x = y

I think this is an acceptable tradeoff for now: I can't imagine too many people running into this. If they do, we can always explore finding a way of simplifying unions after the semantic analysis phase, or bite the bullet and implement approach 1.
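Approach 3 can be pictured with a small standalone model. This is only a sketch and not mypy's implementation: enum members stand in for `Literal[...]` types, an enum class stands in for the full enum type, and `collapse_enum_literals` is a hypothetical helper name chosen for illustration.

```python
from enum import Enum
from typing import Any, Dict, List, Set


class Foo(Enum):
    A = 1
    B = 2
    C = 3


class Bar(Enum):
    X = 1
    Y = 2


def collapse_enum_literals(items: List[Any]) -> List[Any]:
    """Toy model of the reconstruction step (assumption: not mypy's code).

    Enum members play the role of enum literal types. When every member
    of an enum appears in `items`, the members are replaced by the enum
    class itself, emitted once at the position of the first member.
    """
    # First pass: collect the distinct members seen for each enum class.
    seen: Dict[type, Set[Enum]] = {}
    for item in items:
        if isinstance(item, Enum):
            seen.setdefault(type(item), set()).add(item)

    # An enum can be collapsed only if all of its members are present.
    complete = {cls for cls, members in seen.items() if members == set(cls)}

    # Second pass: rebuild the list, substituting complete enums.
    result: List[Any] = []
    emitted: Set[type] = set()
    for item in items:
        cls = type(item) if isinstance(item, Enum) else None
        if cls in complete:
            if cls not in emitted:
                result.append(cls)
                emitted.add(cls)
            continue
        result.append(item)
    return result


if __name__ == "__main__":
    # Only two of Foo's three members: nothing is collapsed.
    print(collapse_enum_literals([Foo.A, Foo.B, 1]))
    # All members of both enums are present: both collapse.
    print(collapse_enum_literals([Foo.A, Bar.X, Foo.B, Bar.Y, Foo.C]))
```

The model keeps one representation per type: once all members are present, only the enum class survives, which is the standardization argued for in the message.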
1 parent 84126ab commit 68656e8

3 files changed: +148 -8 lines changed


mypy/checker.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@
 from mypy.typeops import (
     map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
     erase_def_to_union_or_bound, erase_to_union_or_bound,
-    true_only, false_only, function_type,
+    true_only, false_only, function_type, get_enum_values,
 )
 from mypy import message_registry
 from mypy.subtypes import (

mypy/typeops.py

Lines changed: 50 additions & 4 deletions
@@ -5,7 +5,7 @@
 since these may assume that MROs are ready.
 """
 
-from typing import cast, Optional, List, Sequence, Set
+from typing import cast, Optional, List, Sequence, Set, Dict
 import sys
 
 from mypy.types import (
@@ -300,6 +300,11 @@ def make_simplified_union(items: Sequence[Type],
     * [int, int] -> int
     * [int, Any] -> Union[int, Any] (Any types are not simplified away!)
     * [Any, Any] -> Any
+    * [Literal[Foo.A], Literal[Foo.B]] -> Foo (assuming Foo is a enum with two variants A and B)
+
+    Note that we only collapse enum literals into the original enum when all literal variants
+    are present. Since enums are effectively final and there are a fixed number of possible
+    variants, it's safe to treat those two types as equivalent.
 
     Note: This must NOT be used during semantic analysis, since TypeInfos may not
     be fully initialized.
@@ -316,6 +321,8 @@ def make_simplified_union(items: Sequence[Type],
 
     from mypy.subtypes import is_proper_subtype
 
+    enums_found = {}  # type: Dict[str, int]
+    enum_max_members = {}  # type: Dict[str, int]
     removed = set()  # type: Set[int]
     for i, ti in enumerate(items):
         if i in removed: continue
@@ -327,13 +334,52 @@ def make_simplified_union(items: Sequence[Type],
                 removed.add(j)
                 cbt = cbt or tj.can_be_true
                 cbf = cbf or tj.can_be_false
+
         # if deleted subtypes had more general truthiness, use that
         if not ti.can_be_true and cbt:
-            items[i] = true_or_false(ti)
+            items[i] = ti = true_or_false(ti)
         elif not ti.can_be_false and cbf:
-            items[i] = true_or_false(ti)
+            items[i] = ti = true_or_false(ti)
+
+        # Keep track of all enum Literal types we encounter, in case
+        # we can coalesce them together
+        if isinstance(ti, LiteralType) and ti.is_enum_literal():
+            enum_name = ti.fallback.type.fullname()
+            if enum_name not in enum_max_members:
+                enum_max_members[enum_name] = len(get_enum_values(ti.fallback))
+            enums_found[enum_name] = enums_found.get(enum_name, 0) + 1
+        if isinstance(ti, Instance) and ti.type.is_enum:
+            enum_name = ti.type.fullname()
+            if enum_name not in enum_max_members:
+                enum_max_members[enum_name] = len(get_enum_values(ti))
+            enums_found[enum_name] = enum_max_members[enum_name]
+
+    enums_to_compress = {n for (n, c) in enums_found.items() if c >= enum_max_members[n]}
+    enums_encountered = set()  # type: Set[str]
+    simplified_set = []  # type: List[ProperType]
+    for i, item in enumerate(items):
+        if i in removed:
+            continue
+
+        # Try seeing if this is an enum or enum literal, and if it's
+        # one we should be collapsing away.
+        if isinstance(item, LiteralType):
+            instance = item.fallback  # type: Optional[Instance]
+        elif isinstance(item, Instance):
+            instance = item
+        else:
+            instance = None
+
+        if instance and instance.type.is_enum:
+            enum_name = instance.type.fullname()
+            if enum_name in enums_encountered:
+                continue
+            if enum_name in enums_to_compress:
+                simplified_set.append(instance)
+                enums_encountered.add(enum_name)
+                continue
+        simplified_set.append(item)
 
-    simplified_set = [items[i] for i in range(len(items)) if i not in removed]
     return UnionType.make_union(simplified_set, line, column)
 
 

test-data/unit/check-enum.test

Lines changed: 97 additions & 3 deletions
@@ -629,6 +629,7 @@ elif x is Foo.C:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.C]'
 else:
     reveal_type(x)  # No output here: this branch is unreachable
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
 
 if Foo.A is x:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -638,6 +639,7 @@ elif Foo.C is x:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.C]'
 else:
     reveal_type(x)  # No output here: this branch is unreachable
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
 
 y: Foo
 if y is Foo.A:
@@ -648,6 +650,7 @@ elif y is Foo.C:
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.C]'
 else:
     reveal_type(y)  # No output here: this branch is unreachable
+reveal_type(y)  # N: Revealed type is '__main__.Foo'
 
 if Foo.A is y:
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -657,6 +660,7 @@ elif Foo.C is y:
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.C]'
 else:
     reveal_type(y)  # No output here: this branch is unreachable
+reveal_type(y)  # N: Revealed type is '__main__.Foo'
 [builtins fixtures/bool.pyi]
 
 [case testEnumReachabilityChecksIndirect]
@@ -686,6 +690,8 @@ if y is x:
 else:
     reveal_type(x)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
+reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
 
 if x is z:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -703,6 +709,8 @@ else:
     reveal_type(x)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
     reveal_type(z)  # N: Revealed type is '__main__.Foo*'
     accepts_foo_a(z)
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
+reveal_type(z)  # N: Revealed type is '__main__.Foo*'
 
 if y is z:
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -718,6 +726,8 @@ if z is y:
 else:
     reveal_type(y)  # No output: this branch is unreachable
     reveal_type(z)  # No output: this branch is unreachable
+reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
+reveal_type(z)  # N: Revealed type is '__main__.Foo*'
 [builtins fixtures/bool.pyi]
 
 [case testEnumReachabilityNoNarrowingForUnionMessiness]
@@ -740,13 +750,17 @@ if x is y:
 else:
     reveal_type(x)  # N: Revealed type is '__main__.Foo'
     reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
+reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
 
 if y is z:
     reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
     reveal_type(z)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
 else:
     reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
     reveal_type(z)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
+reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
+reveal_type(z)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
 [builtins fixtures/bool.pyi]
 
 [case testEnumReachabilityWithNone]
@@ -764,16 +778,19 @@ if x:
     reveal_type(x)  # N: Revealed type is '__main__.Foo'
 else:
     reveal_type(x)  # N: Revealed type is 'Union[__main__.Foo, None]'
+reveal_type(x)  # N: Revealed type is 'Union[__main__.Foo, None]'
 
 if x is not None:
     reveal_type(x)  # N: Revealed type is '__main__.Foo'
 else:
     reveal_type(x)  # N: Revealed type is 'None'
+reveal_type(x)  # N: Revealed type is 'Union[__main__.Foo, None]'
 
 if x is Foo.A:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.A]'
 else:
     reveal_type(x)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
+reveal_type(x)  # N: Revealed type is 'Union[__main__.Foo, None]'
 [builtins fixtures/bool.pyi]
 
 [case testEnumReachabilityWithMultipleEnums]
@@ -793,18 +810,21 @@ if x1 is Foo.A:
     reveal_type(x1)  # N: Revealed type is 'Literal[__main__.Foo.A]'
 else:
     reveal_type(x1)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
+reveal_type(x1)  # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
 
 x2: Union[Foo, Bar]
 if x2 is Bar.A:
     reveal_type(x2)  # N: Revealed type is 'Literal[__main__.Bar.A]'
 else:
     reveal_type(x2)  # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
+reveal_type(x2)  # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
 
 x3: Union[Foo, Bar]
 if x3 is Foo.A or x3 is Bar.A:
     reveal_type(x3)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
 else:
     reveal_type(x3)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
+reveal_type(x3)  # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
 
 [builtins fixtures/bool.pyi]
 
@@ -823,7 +843,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
                         # E: Unsupported left operand type for + ("Empty") \
                         # N: Left operand is of type "Union[int, None, Empty]"
     if x is _empty:
-        reveal_type(x)  # N: Revealed type is 'Literal[__main__.Empty.token]'
+        reveal_type(x)  # N: Revealed type is '__main__.Empty'
         return 0
     elif x is None:
         reveal_type(x)  # N: Revealed type is 'None'
@@ -870,7 +890,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
                         # E: Unsupported left operand type for + ("Empty") \
                         # N: Left operand is of type "Union[int, None, Empty]"
     if x is _empty:
-        reveal_type(x)  # N: Revealed type is 'Literal[__main__.Empty.token]'
+        reveal_type(x)  # N: Revealed type is '__main__.Empty'
         return 0
     elif x is None:
         reveal_type(x)  # N: Revealed type is 'None'
@@ -899,7 +919,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
                         # E: Unsupported left operand type for + ("Empty") \
                         # N: Left operand is of type "Union[int, None, Empty]"
     if x is _empty:
-        reveal_type(x)  # N: Revealed type is 'Literal[__main__.Empty.token]'
+        reveal_type(x)  # N: Revealed type is '__main__.Empty'
         return 0
     elif x is None:
         reveal_type(x)  # N: Revealed type is 'None'
@@ -908,3 +928,77 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
         reveal_type(x)  # N: Revealed type is 'builtins.int'
         return x + 2
 [builtins fixtures/primitives.pyi]
+
+[case testEnumUnionCompression]
+from typing import Union
+from typing_extensions import Literal
+from enum import Enum
+
+class Foo(Enum):
+    A = 1
+    B = 2
+    C = 3
+
+class Bar(Enum):
+    X = 1
+    Y = 2
+
+x1: Literal[Foo.A, Foo.B, Foo.B, Foo.B, 1, None]
+assert x1 is not None
+reveal_type(x1)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[1]]'
+
+x2: Literal[1, Foo.A, Foo.B, Foo.C, None]
+assert x2 is not None
+reveal_type(x2)  # N: Revealed type is 'Union[Literal[1], __main__.Foo]'
+
+x3: Literal[Foo.A, Foo.B, 1, Foo.C, Foo.C, Foo.C, None]
+assert x3 is not None
+reveal_type(x3)  # N: Revealed type is 'Union[__main__.Foo, Literal[1]]'
+
+x4: Literal[Foo.A, Foo.B, Foo.C, Foo.C, Foo.C, None]
+assert x4 is not None
+reveal_type(x4)  # N: Revealed type is '__main__.Foo'
+
+x5: Union[Literal[Foo.A], Foo, None]
+assert x5 is not None
+reveal_type(x5)  # N: Revealed type is '__main__.Foo'
+
+x6: Literal[Foo.A, Bar.X, Foo.B, Bar.Y, Foo.C, None]
+assert x6 is not None
+reveal_type(x6)  # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
+
+# TODO: We should really simplify this down into just 'Bar' as well.
+no_forcing: Literal[Bar.X, Bar.X, Bar.Y]
+reveal_type(no_forcing)  # N: Revealed type is 'Union[Literal[__main__.Bar.X], Literal[__main__.Bar.X], Literal[__main__.Bar.Y]]'
+
+[case testEnumUnionCompressionAssignment]
+from typing_extensions import Literal
+from enum import Enum
+
+class Foo(Enum):
+    A = 1
+    B = 2
+
+class Wrapper1:
+    def __init__(self, x: object, y: Foo) -> None:
+        if x:
+            if y is Foo.A:
+                pass
+            else:
+                pass
+            self.y = y
+        else:
+            self.y = y
+        reveal_type(self.y)  # N: Revealed type is '__main__.Foo'
+
+class Wrapper2:
+    def __init__(self, x: object, y: Foo) -> None:
+        if x:
+            self.y = y
+        else:
+            if y is Foo.A:
+                pass
+            else:
+                pass
+            self.y = y
+        reveal_type(self.y)  # N: Revealed type is '__main__.Foo'
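
The docstring note in `make_simplified_union` relies on enums being "effectively final" with a fixed set of variants. A minimal runtime sketch of that premise, independent of mypy: the helper `can_subclass` and the classes `Color` and `MemberlessBase` are hypothetical names used only for illustration.

```python
from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2


class MemberlessBase(Enum):
    """An enum with no members; Python still allows subclassing it."""


def can_subclass(enum_cls: type) -> bool:
    """Return True if Python allows deriving a new class from enum_cls."""
    try:
        class _Sub(enum_cls):  # type: ignore[misc, valid-type]
            pass
    except TypeError:
        # Raised when enum_cls already defines members.
        return False
    return True


if __name__ == "__main__":
    # Once members exist, the set of variants is closed.
    print(can_subclass(Color))
    print(can_subclass(MemberlessBase))
    print(list(Color))
```

Because an enum with members cannot be extended, the list of members seen at class definition time is exhaustive, which is what makes a complete set of literals interchangeable with the enum type.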
