Skip to content

Commit df76db6

Browse files
committed
Generalize reachability checks to support enums
This diff adds support for performing reachability and narrowing analysis when doing certain enum checks. For example, given the following enum: class Foo(Enum): A = 1 B = 2 ...this pull request will make mypy do the following: x: Foo if x is Foo.A: reveal_type(x) # type: Literal[Foo.A] elif x is Foo.B: reveal_type(x) # type: Literal[Foo.B] else: reveal_type(x) # No output: branch inferred as unreachable This diff does not attempt to perform this same sort of narrowing for equality checks: I suspect implementing those will be harder due to their overridable nature. (E.g. you can define custom `__eq__` methods within Enum subclasses). This pull request also finally adds support for the enum behavior [described in PEP 484][0] and also sort of partially addresses #6366 [0]: https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
1 parent 44172ca commit df76db6

File tree

2 files changed

+330
-11
lines changed

2 files changed

+330
-11
lines changed

mypy/checker.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3487,21 +3487,34 @@ def find_isinstance_check(self, node: Expression
34873487
vartype = type_map[expr]
34883488
return self.conditional_callable_type_map(expr, vartype)
34893489
elif isinstance(node, ComparisonExpr):
3490-
# Check for `x is None` and `x is not None`.
3490+
operand_types = [coerce_to_literal(type_map[expr])
3491+
for expr in node.operands if expr in type_map]
3492+
34913493
is_not = node.operators == ['is not']
3492-
if any(is_literal_none(n) for n in node.operands) and (
3493-
is_not or node.operators == ['is']):
3494+
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):
34943495
if_vars = {} # type: TypeMap
34953496
else_vars = {} # type: TypeMap
3496-
for expr in node.operands:
3497-
if (literal(expr) == LITERAL_TYPE and not is_literal_none(expr)
3498-
and expr in type_map):
3497+
3498+
for i, expr in enumerate(node.operands):
3499+
var_type = operand_types[i]
3500+
other_type = operand_types[1 - i]
3501+
3502+
if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type):
34993503
# This should only be true at most once: there should be
3500-
# two elements in node.operands, and at least one of them
3501-
# should represent a None.
3502-
vartype = type_map[expr]
3503-
none_typ = [TypeRange(NoneType(), is_upper_bound=False)]
3504-
if_vars, else_vars = conditional_type_map(expr, vartype, none_typ)
3504+
# exactly two elements in node.operands and if the 'other type' is
3505+
# a singleton type, it by definition does not need to be narrowed:
3506+
# it already has the most precise type possible so does not need to
3507+
# be narrowed/included in the output map.
3508+
#
3509+
# TODO: Generalize this to handle the case where 'other_type' is
3510+
# a union of singleton types.
3511+
3512+
if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
3513+
fallback_name = other_type.fallback.type.fullname()
3514+
var_type = try_expanding_enum_to_union(var_type, fallback_name)
3515+
3516+
target_type = [TypeRange(other_type, is_upper_bound=False)]
3517+
if_vars, else_vars = conditional_type_map(expr, var_type, target_type)
35053518
break
35063519

35073520
if is_not:
@@ -4438,3 +4451,65 @@ def is_overlapping_types_no_promote(left: Type, right: Type) -> bool:
44384451
def is_private(node_name: str) -> bool:
44394452
"""Check if node is private to class definition."""
44404453
return node_name.startswith('__') and not node_name.endswith('__')
4454+
4455+
4456+
def is_singleton_type(typ: Type) -> bool:
4457+
"""Returns 'true' if this type is a "singleton type" -- if there exists
4458+
exactly only one runtime value associated with this type.
4459+
4460+
That is, given two values 'a' and 'b' that have the same type 't',
4461+
'is_singleton_type(t)' returns True if and only if the expression 'a is b' is
4462+
always true.
4463+
4464+
Currently, this returns True when given NoneTypes and enum LiteralTypes.
4465+
4466+
Note that other kinds of LiteralTypes cannot count as singleton types. For
4467+
example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed
4468+
that 'a is b' will always be true -- some implementations of Python will end up
4469+
constructing two distinct instances of 100001.
4470+
"""
4471+
# TODO: Also make this return True if the type is a bool LiteralType.
4472+
# Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented?
4473+
return isinstance(typ, NoneType) or (isinstance(typ, LiteralType) and typ.is_enum_literal())
4474+
4475+
4476+
def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> Type:
4477+
"""Attempts to recursively any enum Instances with the given target_fullname
4478+
into a Union of all of its component LiteralTypes.
4479+
4480+
For example, if we have:
4481+
4482+
class Color(Enum):
4483+
RED = 1
4484+
BLUE = 2
4485+
YELLOW = 3
4486+
4487+
...and if we call `try_expanding_enum_to_union(color_instance, 'module.Color')`,
4488+
this function will return Literal[Color.RED, Color.BLUE, Color.YELLOW].
4489+
"""
4490+
if isinstance(typ, UnionType):
4491+
new_items = [try_expanding_enum_to_union(item, target_fullname)
4492+
for item in typ.items]
4493+
return UnionType.make_simplified_union(new_items)
4494+
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname:
4495+
new_items = []
4496+
for name, symbol in typ.type.names.items():
4497+
if not isinstance(symbol.node, Var):
4498+
continue
4499+
new_items.append(LiteralType(name, typ))
4500+
return UnionType.make_simplified_union(new_items)
4501+
else:
4502+
return typ
4503+
4504+
4505+
def coerce_to_literal(typ: Type) -> Type:
4506+
"""Recursively converts any Instances that have a last_known_value into the
4507+
corresponding LiteralType.
4508+
"""
4509+
if isinstance(typ, UnionType):
4510+
new_items = [coerce_to_literal(item) for item in typ.items]
4511+
return UnionType.make_simplified_union(new_items)
4512+
elif isinstance(typ, Instance) and typ.last_known_value:
4513+
return typ.last_known_value
4514+
else:
4515+
return typ

test-data/unit/check-enum.test

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,247 @@ class SomeEnum(Enum):
610610
main:2: error: Revealed type is 'builtins.int'
611611
[out2]
612612
main:2: error: Revealed type is 'builtins.str'
613+
614+
[case testEnumReachabilityChecksBasic]
615+
from enum import Enum
616+
from typing_extensions import Literal
617+
618+
class Foo(Enum):
619+
A = 1
620+
B = 2
621+
C = 3
622+
623+
x: Literal[Foo.A, Foo.B, Foo.C]
624+
if x is Foo.A:
625+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
626+
elif x is Foo.B:
627+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.B]'
628+
elif x is Foo.C:
629+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.C]'
630+
else:
631+
reveal_type(x) # No output here: this branch is unreachable
632+
633+
if Foo.A is x:
634+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
635+
elif Foo.B is x:
636+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.B]'
637+
elif Foo.C is x:
638+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.C]'
639+
else:
640+
reveal_type(x) # No output here: this branch is unreachable
641+
642+
y: Foo
643+
if y is Foo.A:
644+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
645+
elif y is Foo.B:
646+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.B]'
647+
elif y is Foo.C:
648+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.C]'
649+
else:
650+
reveal_type(y) # No output here: this branch is unreachable
651+
652+
if Foo.A is y:
653+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
654+
elif Foo.B is y:
655+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.B]'
656+
elif Foo.C is y:
657+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.C]'
658+
else:
659+
reveal_type(y) # No output here: this branch is unreachable
660+
[builtins fixtures/bool.pyi]
661+
662+
[case testEnumReachabilityChecksIndirect]
663+
from enum import Enum
664+
from typing_extensions import Literal, Final
665+
666+
class Foo(Enum):
667+
A = 1
668+
B = 2
669+
C = 3
670+
671+
def accepts_foo_a(x: Literal[Foo.A]) -> None: ...
672+
673+
x: Foo
674+
y: Literal[Foo.A]
675+
z: Final = Foo.A
676+
677+
if x is y:
678+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
679+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
680+
else:
681+
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
682+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
683+
if y is x:
684+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
685+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
686+
else:
687+
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
688+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
689+
690+
if x is z:
691+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
692+
reveal_type(z) # E: Revealed type is '__main__.Foo'
693+
accepts_foo_a(z)
694+
else:
695+
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
696+
reveal_type(z) # E: Revealed type is '__main__.Foo'
697+
accepts_foo_a(z)
698+
if z is x:
699+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
700+
reveal_type(z) # E: Revealed type is '__main__.Foo'
701+
accepts_foo_a(z)
702+
else:
703+
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
704+
reveal_type(z) # E: Revealed type is '__main__.Foo'
705+
accepts_foo_a(z)
706+
707+
if y is z:
708+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
709+
reveal_type(z) # E: Revealed type is '__main__.Foo'
710+
accepts_foo_a(z)
711+
else:
712+
reveal_type(y) # No output: this branch is unreachable
713+
reveal_type(z) # No output: this branch is unreachable
714+
if z is y:
715+
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
716+
reveal_type(z) # E: Revealed type is '__main__.Foo'
717+
accepts_foo_a(z)
718+
else:
719+
reveal_type(y) # No output: this branch is unreachable
720+
reveal_type(z) # No output: this branch is unreachable
721+
[builtins fixtures/bool.pyi]
722+
723+
[case testEnumReachabilityNoNarrowingForUnionMessiness]
724+
from enum import Enum
725+
from typing_extensions import Literal
726+
727+
class Foo(Enum):
728+
A = 1
729+
B = 2
730+
C = 3
731+
732+
x: Foo
733+
y: Literal[Foo.A, Foo.B]
734+
z: Literal[Foo.B, Foo.C]
735+
736+
# For the sake of simplicity, no narrowing is done when the narrower type is a Union.
737+
if x is y:
738+
reveal_type(x) # E: Revealed type is '__main__.Foo'
739+
reveal_type(y) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
740+
else:
741+
reveal_type(x) # E: Revealed type is '__main__.Foo'
742+
reveal_type(y) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
743+
744+
if y is z:
745+
reveal_type(y) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
746+
reveal_type(z) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
747+
else:
748+
reveal_type(y) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
749+
reveal_type(z) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
750+
[builtins fixtures/bool.pyi]
751+
752+
[case testEnumReachabilityWithNone]
753+
# flags: --strict-optional
754+
from enum import Enum
755+
from typing import Optional
756+
757+
class Foo(Enum):
758+
A = 1
759+
B = 2
760+
C = 3
761+
762+
x: Optional[Foo]
763+
if x:
764+
reveal_type(x) # E: Revealed type is '__main__.Foo'
765+
else:
766+
reveal_type(x) # E: Revealed type is 'Union[__main__.Foo, None]'
767+
768+
if x is not None:
769+
reveal_type(x) # E: Revealed type is '__main__.Foo'
770+
else:
771+
reveal_type(x) # E: Revealed type is 'None'
772+
773+
if x is Foo.A:
774+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
775+
else:
776+
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
777+
[builtins fixtures/bool.pyi]
778+
779+
[case testEnumReachabilityWithMultipleEnums]
780+
from enum import Enum
781+
from typing import Union
782+
from typing_extensions import Literal
783+
784+
class Foo(Enum):
785+
A = 1
786+
B = 2
787+
class Bar(Enum):
788+
A = 1
789+
B = 2
790+
791+
x1: Union[Foo, Bar]
792+
if x1 is Foo.A:
793+
reveal_type(x1) # E: Revealed type is 'Literal[__main__.Foo.A]'
794+
else:
795+
reveal_type(x1) # E: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
796+
797+
x2: Union[Foo, Bar]
798+
if x2 is Bar.A:
799+
reveal_type(x2) # E: Revealed type is 'Literal[__main__.Bar.A]'
800+
else:
801+
reveal_type(x2) # E: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
802+
803+
x3: Union[Foo, Bar]
804+
if x3 is Foo.A or x3 is Bar.A:
805+
reveal_type(x3) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
806+
else:
807+
reveal_type(x3) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
808+
809+
[builtins fixtures/bool.pyi]
810+
811+
[case testEnumReachabilityPEP484Example1]
812+
# flags: --strict-optional
813+
from typing import Union
814+
from typing_extensions import Final
815+
from enum import Enum
816+
817+
class Empty(Enum):
818+
token = 0
819+
_empty: Final = Empty.token
820+
821+
def func(x: Union[int, None, Empty] = _empty) -> int:
822+
boom = x + 42 # E: Unsupported left operand type for + ("None") \
823+
# E: Unsupported left operand type for + ("Empty") \
824+
# N: Left operand is of type "Union[int, None, Empty]"
825+
if x is _empty:
826+
reveal_type(x) # E: Revealed type is 'Literal[__main__.Empty.token]'
827+
return 0
828+
elif x is None:
829+
reveal_type(x) # E: Revealed type is 'None'
830+
return 1
831+
else: # At this point typechecker knows that x can only have type int
832+
reveal_type(x) # E: Revealed type is 'builtins.int'
833+
return x + 2
834+
[builtins fixtures/primitives.pyi]
835+
836+
[case testEnumReachabilityPEP484Example2]
837+
from typing import Union
838+
from enum import Enum
839+
840+
class Reason(Enum):
841+
timeout = 1
842+
error = 2
843+
844+
def process(response: Union[str, Reason] = '') -> str:
845+
if response is Reason.timeout:
846+
reveal_type(response) # E: Revealed type is 'Literal[__main__.Reason.timeout]'
847+
return 'TIMEOUT'
848+
elif response is Reason.error:
849+
reveal_type(response) # E: Revealed type is 'Literal[__main__.Reason.error]'
850+
return 'ERROR'
851+
else:
852+
# response can be only str, all other possible values exhausted
853+
reveal_type(response) # E: Revealed type is 'builtins.str'
854+
return 'PROCESSED: ' + response
855+
856+
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)