Skip to content

Commit 8033909

Browse files
committed
Improve match narrowing and reachability analysis
Fixes #12534, #15878
1 parent 854a9f8 commit 8033909

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

mypy/checker.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4944,7 +4944,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
49444944
self.push_type_map(pattern_map)
49454945
self.push_type_map(pattern_type.captures)
49464946
if g is not None:
4947-
with self.binder.frame_context(can_skip=True, fall_through=3):
4947+
with self.binder.frame_context(can_skip=False, fall_through=3):
49484948
gt = get_proper_type(self.expr_checker.accept(g))
49494949

49504950
if isinstance(gt, DeletedType):
@@ -4953,6 +4953,21 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
49534953
guard_map, guard_else_map = self.find_isinstance_check(g)
49544954
else_map = or_conditional_maps(else_map, guard_else_map)
49554955

4956+
# If the guard narrowed the subject, copy the narrowed types over
4957+
if isinstance(p, AsPattern):
4958+
case_target = p.pattern or p.name
4959+
if isinstance(case_target, NameExpr):
4960+
for type_map in (guard_map, else_map):
4961+
if not type_map:
4962+
continue
4963+
for expr in list(type_map):
4964+
if not (
4965+
isinstance(expr, NameExpr)
4966+
and expr.fullname == case_target.fullname
4967+
):
4968+
continue
4969+
type_map[s.subject] = type_map[expr]
4970+
49564971
self.push_type_map(guard_map)
49574972
self.accept(b)
49584973
else:

test-data/unit/check-python310.test

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1372,7 +1372,7 @@ match m:
13721372
reveal_type(m) # N: Revealed type is "__main__.Medal"
13731373

13741374
[case testMatchNarrowUsingPatternGuardSpecialCase]
1375-
def f(x: int | str) -> int: # E: Missing return statement
1375+
def f(x: int | str) -> int:
13761376
match x:
13771377
case x if isinstance(x, str):
13781378
return 0
@@ -1973,3 +1973,46 @@ def f2(x: T) -> None:
19731973
case DataFrame(): # type: ignore[misc]
19741974
pass
19751975
[builtins fixtures/primitives.pyi]
1976+
1977+
[case testMatchGuardReachability]
1978+
# flags: --warn-unreachable
1979+
def f1(e: int) -> int:
1980+
match e:
1981+
case x if True:
1982+
return x
1983+
case _:
1984+
return 0 # E: Statement is unreachable
1985+
e = 0 # E: Statement is unreachable
1986+
1987+
1988+
def f2(e: int) -> int:
1989+
match e:
1990+
case x if bool():
1991+
return x
1992+
case _:
1993+
return 0
1994+
e = 0 # E: Statement is unreachable
1995+
1996+
def f3(e: int | str | bytes) -> int:
1997+
match e:
1998+
case x if isinstance(x, int):
1999+
return x
2000+
case [x]:
2001+
return 0 # E: Statement is unreachable
2002+
case str(x):
2003+
return 0
2004+
reveal_type(e) # N: Revealed type is "builtins.bytes"
2005+
return 0
2006+
2007+
def f4(e: int | str | bytes) -> int:
2008+
match e:
2009+
case int(x):
2010+
pass
2011+
case [x]:
2012+
return 0 # E: Statement is unreachable
2013+
case x if isinstance(x, str):
2014+
return 0
2015+
reveal_type(e) # N: Revealed type is "Union[builtins.int, builtins.bytes]"
2016+
return 0
2017+
2018+
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)