Skip to content

Include walrus assignments in conditional inference #19038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6512,7 +6512,7 @@ def refine_parent_types(self, expr: Expression, expr_type: Type) -> Mapping[Expr
# and create function that will try replaying the same lookup
# operation against arbitrary types.
if isinstance(expr, MemberExpr):
parent_expr = collapse_walrus(expr.expr)
parent_expr = self._propagate_walrus_assignments(expr.expr, output)
parent_type = self.lookup_type_or_none(parent_expr)
member_name = expr.name

Expand All @@ -6535,9 +6535,10 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None:
return member_type

elif isinstance(expr, IndexExpr):
parent_expr = collapse_walrus(expr.base)
parent_expr = self._propagate_walrus_assignments(expr.base, output)
parent_type = self.lookup_type_or_none(parent_expr)

self._propagate_walrus_assignments(expr.index, output)
index_type = self.lookup_type_or_none(expr.index)
if index_type is None:
return output
Expand Down Expand Up @@ -6611,6 +6612,24 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None:
expr = parent_expr
expr_type = output[parent_expr] = make_simplified_union(new_parent_types)

def _propagate_walrus_assignments(
self, expr: Expression, type_map: dict[Expression, Type]
) -> Expression:
"""Add assignments from walrus expressions to inferred types.

Only considers nested assignment exprs, does not recurse into other types.
This may be added later if necessary by implementing a dedicated visitor.
"""
if isinstance(expr, AssignmentExpr):
if isinstance(expr.value, AssignmentExpr):
self._propagate_walrus_assignments(expr.value, type_map)
assigned_type = self.lookup_type_or_none(expr.value)
parent_expr = collapse_walrus(expr)
if assigned_type is not None:
type_map[parent_expr] = assigned_type
return parent_expr
return expr

def refine_identity_comparison_expression(
self,
operands: list[Expression],
Expand Down
92 changes: 92 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -3979,3 +3979,95 @@ def check(mapping: Mapping[str, _T]) -> None:
reveal_type(ok1) # N: Revealed type is "Union[_T`-1, builtins.str]"
ok2: Union[_T, str] = mapping.get("", "")
[builtins fixtures/tuple.pyi]

[case testInferWalrusAssignmentAttrInCondition]
class Foo:
def __init__(self, value: bool) -> None:
self.value = value

def check_and(maybe: bool) -> None:
foo = None
if maybe and (foo := Foo(True)).value:
reveal_type(foo) # N: Revealed type is "__main__.Foo"
else:
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"

def check_and_nested(maybe: bool) -> None:
foo = None
bar = None
baz = None
if maybe and (foo := (bar := (baz := Foo(True)))).value:
reveal_type(foo) # N: Revealed type is "__main__.Foo"
reveal_type(bar) # N: Revealed type is "__main__.Foo"
reveal_type(baz) # N: Revealed type is "__main__.Foo"
else:
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
reveal_type(bar) # N: Revealed type is "Union[__main__.Foo, None]"
reveal_type(baz) # N: Revealed type is "Union[__main__.Foo, None]"

def check_or(maybe: bool) -> None:
foo = None
if maybe or (foo := Foo(True)).value:
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
else:
reveal_type(foo) # N: Revealed type is "__main__.Foo"

def check_or_nested(maybe: bool) -> None:
foo = None
bar = None
baz = None
if maybe and (foo := (bar := (baz := Foo(True)))).value:
reveal_type(foo) # N: Revealed type is "__main__.Foo"
reveal_type(bar) # N: Revealed type is "__main__.Foo"
reveal_type(baz) # N: Revealed type is "__main__.Foo"
else:
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
reveal_type(bar) # N: Revealed type is "Union[__main__.Foo, None]"
reveal_type(baz) # N: Revealed type is "Union[__main__.Foo, None]"

[case testInferWalrusAssignmentIndexInCondition]
def check_and(maybe: bool) -> None:
foo = None
bar = None
if maybe and (foo := [1])[(bar := 0)]:
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(bar) # N: Revealed type is "builtins.int"
else:
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]"

def check_and_nested(maybe: bool) -> None:
foo = None
bar = None
baz = None
if maybe and (foo := (bar := (baz := [1])))[0]:
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(bar) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(baz) # N: Revealed type is "builtins.list[builtins.int]"
else:
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
reveal_type(bar) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
reveal_type(baz) # N: Revealed type is "Union[builtins.list[builtins.int], None]"

def check_or(maybe: bool) -> None:
foo = None
bar = None
if maybe or (foo := [1])[(bar := 0)]:
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]"
else:
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(bar) # N: Revealed type is "builtins.int"

def check_or_nested(maybe: bool) -> None:
foo = None
bar = None
baz = None
if maybe or (foo := (bar := (baz := [1])))[0]:
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
reveal_type(bar) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
reveal_type(baz) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
else:
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(bar) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(baz) # N: Revealed type is "builtins.list[builtins.int]"