Skip to content

Commit 9330193

Browse files
authored
Fix narrowing of IntEnum and StrEnum types (#17874)
Fix regression introduced in #17866. It should still be possible to narrow IntEnum and StrEnum types, but only when types match or are disjoint. Add more logic to rule out narrowing when types are ambigous.
1 parent 3c09b32 commit 9330193

File tree

3 files changed

+117
-10
lines changed

3 files changed

+117
-10
lines changed

mypy/checker.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5984,7 +5984,9 @@ def has_no_custom_eq_checks(t: Type) -> bool:
59845984
coerce_only_in_literal_context = True
59855985

59865986
expr_types = [operand_types[i] for i in expr_indices]
5987-
should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types))
5987+
should_narrow_by_identity = all(
5988+
map(has_no_custom_eq_checks, expr_types)
5989+
) and not is_ambiguous_mix_of_enums(expr_types)
59885990

59895991
if_map: TypeMap = {}
59905992
else_map: TypeMap = {}
@@ -8604,3 +8606,45 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
86048606
self.lvalue = True
86058607
p.capture.accept(self)
86068608
self.lvalue = False
8609+
8610+
8611+
def is_ambiguous_mix_of_enums(types: list[Type]) -> bool:
8612+
"""Do types have IntEnum/StrEnum types that are potentially overlapping with other types?
8613+
8614+
If True, we shouldn't attempt type narrowing based on enum values, as it gets
8615+
too ambiguous.
8616+
8617+
For example, return True if there's an 'int' type together with an IntEnum literal.
8618+
However, IntEnum together with a literal of the same IntEnum type is not ambiguous.
8619+
"""
8620+
# We need these things for this to be ambiguous:
8621+
# (1) an IntEnum or StrEnum type
8622+
# (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
8623+
#
8624+
# It would be slightly more correct to calculate this separately for IntEnum and
8625+
# StrEnum related types, as an IntEnum can't be confused with a StrEnum.
8626+
return len(_ambiguous_enum_variants(types)) > 1
8627+
8628+
8629+
def _ambiguous_enum_variants(types: list[Type]) -> set[str]:
8630+
result = set()
8631+
for t in types:
8632+
t = get_proper_type(t)
8633+
if isinstance(t, UnionType):
8634+
result.update(_ambiguous_enum_variants(t.items))
8635+
elif isinstance(t, Instance):
8636+
if t.last_known_value:
8637+
result.update(_ambiguous_enum_variants([t.last_known_value]))
8638+
elif t.type.is_enum and any(
8639+
base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro
8640+
):
8641+
result.add(t.type.fullname)
8642+
elif not t.type.is_enum:
8643+
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
8644+
# let's be conservative
8645+
result.add("<other>")
8646+
elif isinstance(t, LiteralType):
8647+
result.update(_ambiguous_enum_variants([t.fallback]))
8648+
else:
8649+
result.add("<other>")
8650+
return result

mypy/typeops.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,14 +1078,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
10781078
"""
10791079
typ = get_proper_type(typ)
10801080
if isinstance(typ, Instance):
1081-
if (
1082-
typ.type.is_enum
1083-
and name in ("__eq__", "__ne__")
1084-
and any(base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in typ.type.mro)
1085-
):
1086-
# IntEnum and StrEnum values have non-straightfoward equality, so treat them
1087-
# as if they had custom __eq__ and __ne__
1088-
return True
10891081
method = typ.type.get(name)
10901082
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
10911083
if method.node.info:

test-data/unit/check-narrowing.test

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2135,7 +2135,7 @@ else:
21352135
# mypy: strict-equality
21362136
from __future__ import annotations
21372137
from typing import Any
2138-
from enum import IntEnum, StrEnum
2138+
from enum import IntEnum
21392139

21402140
class IE(IntEnum):
21412141
X = 1
@@ -2178,6 +2178,71 @@ def f5(x: int) -> None:
21782178
reveal_type(x) # N: Revealed type is "builtins.int"
21792179
else:
21802180
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"
2181+
2182+
def f6(x: IE) -> None:
2183+
if x == IE.X:
2184+
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"
2185+
else:
2186+
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.Y]"
2187+
[builtins fixtures/primitives.pyi]
2188+
2189+
[case testNarrowingWithIntEnum2]
2190+
# mypy: strict-equality
2191+
from __future__ import annotations
2192+
from typing import Any
2193+
from enum import IntEnum, Enum
2194+
2195+
class MyDecimal: ...
2196+
2197+
class IE(IntEnum):
2198+
X = 1
2199+
Y = 2
2200+
2201+
class IE2(IntEnum):
2202+
X = 1
2203+
Y = 2
2204+
2205+
class E(Enum):
2206+
X = 1
2207+
Y = 2
2208+
2209+
def f1(x: IE | MyDecimal) -> None:
2210+
if x == IE.X:
2211+
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.MyDecimal]"
2212+
else:
2213+
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.MyDecimal]"
2214+
2215+
def f2(x: E | bytes) -> None:
2216+
if x == E.X:
2217+
reveal_type(x) # N: Revealed type is "Literal[__main__.E.X]"
2218+
else:
2219+
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.E.Y], builtins.bytes]"
2220+
2221+
def f3(x: IE | IE2) -> None:
2222+
if x == IE.X:
2223+
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]"
2224+
else:
2225+
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]"
2226+
2227+
def f4(x: IE | E) -> None:
2228+
if x == IE.X:
2229+
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"
2230+
elif x == E.X:
2231+
reveal_type(x) # N: Revealed type is "Literal[__main__.E.X]"
2232+
else:
2233+
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.IE.Y], Literal[__main__.E.Y]]"
2234+
2235+
def f5(x: E | str | int) -> None:
2236+
if x == E.X:
2237+
reveal_type(x) # N: Revealed type is "Literal[__main__.E.X]"
2238+
else:
2239+
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.E.Y], builtins.str, builtins.int]"
2240+
2241+
def f6(x: IE | Any) -> None:
2242+
if x == IE.X:
2243+
reveal_type(x) # N: Revealed type is "Union[__main__.IE, Any]"
2244+
else:
2245+
reveal_type(x) # N: Revealed type is "Union[__main__.IE, Any]"
21812246
[builtins fixtures/primitives.pyi]
21822247

21832248
[case testNarrowingWithStrEnum]
@@ -2205,4 +2270,10 @@ def f3(x: object) -> None:
22052270
reveal_type(x) # N: Revealed type is "builtins.object"
22062271
else:
22072272
reveal_type(x) # N: Revealed type is "builtins.object"
2273+
2274+
def f4(x: SE) -> None:
2275+
if x == SE.A:
2276+
reveal_type(x) # N: Revealed type is "Literal[__main__.SE.A]"
2277+
else:
2278+
reveal_type(x) # N: Revealed type is "Literal[__main__.SE.B]"
22082279
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)