Skip to content

Commit d841859

Browse files
authored
Properly support union of TypedDicts as dict literal context (#14505)
Fixes #14481 (regression) Fixes #13274 Fixes #8533 Most notably, if literal matches multiple items in union, it is not an error, it is only an error if it matches none of them, so I adjust the error message accordingly. An import caveat is that an unrelated error like `{"key": 42 + "no"}` can cause no item to match (an hence an extra error), but I think it is fine, since we still show the actual error, and avoiding this would require some dirty hacks. Also note there was an (obvious) bug in one of the fixtures, that caused one of repros not repro in tests, fixing it required tweaking an unrelated test.
1 parent cb14d6f commit d841859

File tree

5 files changed

+131
-33
lines changed

5 files changed

+131
-33
lines changed

mypy/checkexpr.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4188,6 +4188,17 @@ def fast_dict_type(self, e: DictExpr) -> Type | None:
41884188
self.resolved_type[e] = dt
41894189
return dt
41904190

4191+
def check_typeddict_literal_in_context(
4192+
self, e: DictExpr, typeddict_context: TypedDictType
4193+
) -> Type:
4194+
orig_ret_type = self.check_typeddict_call_with_dict(
4195+
callee=typeddict_context, kwargs=e, context=e, orig_callee=None
4196+
)
4197+
ret_type = get_proper_type(orig_ret_type)
4198+
if isinstance(ret_type, TypedDictType):
4199+
return ret_type.copy_modified()
4200+
return typeddict_context.copy_modified()
4201+
41914202
def visit_dict_expr(self, e: DictExpr) -> Type:
41924203
"""Type check a dict expression.
41934204
@@ -4197,15 +4208,20 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
41974208
# an error, but returns the TypedDict type that matches the literal it found
41984209
# that would cause a second error when that TypedDict type is returned upstream
41994210
# to avoid the second error, we always return TypedDict type that was requested
4200-
typeddict_context = self.find_typeddict_context(self.type_context[-1], e)
4201-
if typeddict_context:
4202-
orig_ret_type = self.check_typeddict_call_with_dict(
4203-
callee=typeddict_context, kwargs=e, context=e, orig_callee=None
4204-
)
4205-
ret_type = get_proper_type(orig_ret_type)
4206-
if isinstance(ret_type, TypedDictType):
4207-
return ret_type.copy_modified()
4208-
return typeddict_context.copy_modified()
4211+
typeddict_contexts = self.find_typeddict_context(self.type_context[-1], e)
4212+
if typeddict_contexts:
4213+
if len(typeddict_contexts) == 1:
4214+
return self.check_typeddict_literal_in_context(e, typeddict_contexts[0])
4215+
# Multiple items union, check if at least one of them matches cleanly.
4216+
for typeddict_context in typeddict_contexts:
4217+
with self.msg.filter_errors() as err, self.chk.local_type_map() as tmap:
4218+
ret_type = self.check_typeddict_literal_in_context(e, typeddict_context)
4219+
if err.has_new_errors():
4220+
continue
4221+
self.chk.store_types(tmap)
4222+
return ret_type
4223+
# No item matched without an error, so we can't unambiguously choose the item.
4224+
self.msg.typeddict_context_ambiguous(typeddict_contexts, e)
42094225

42104226
# fast path attempt
42114227
dt = self.fast_dict_type(e)
@@ -4271,26 +4287,20 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
42714287

42724288
def find_typeddict_context(
42734289
self, context: Type | None, dict_expr: DictExpr
4274-
) -> TypedDictType | None:
4290+
) -> list[TypedDictType]:
42754291
context = get_proper_type(context)
42764292
if isinstance(context, TypedDictType):
4277-
return context
4293+
return [context]
42784294
elif isinstance(context, UnionType):
42794295
items = []
42804296
for item in context.items:
4281-
item_context = self.find_typeddict_context(item, dict_expr)
4282-
if item_context is not None and self.match_typeddict_call_with_dict(
4283-
item_context, dict_expr, dict_expr
4284-
):
4285-
items.append(item_context)
4286-
if len(items) == 1:
4287-
# Only one union item is valid TypedDict for the given dict_expr, so use the
4288-
# context as it's unambiguous.
4289-
return items[0]
4290-
if len(items) > 1:
4291-
self.msg.typeddict_context_ambiguous(items, dict_expr)
4297+
item_contexts = self.find_typeddict_context(item, dict_expr)
4298+
for item_context in item_contexts:
4299+
if self.match_typeddict_call_with_dict(item_context, dict_expr, dict_expr):
4300+
items.append(item_context)
4301+
return items
42924302
# No TypedDict type in context.
4293-
return None
4303+
return []
42944304

42954305
def visit_lambda_expr(self, e: LambdaExpr) -> Type:
42964306
"""Type check lambda expression."""

mypy/messages.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1705,7 +1705,9 @@ def typeddict_key_not_found(
17051705

17061706
def typeddict_context_ambiguous(self, types: list[TypedDictType], context: Context) -> None:
17071707
formatted_types = ", ".join(list(format_type_distinctly(*types)))
1708-
self.fail(f"Type of TypedDict is ambiguous, could be any of ({formatted_types})", context)
1708+
self.fail(
1709+
f"Type of TypedDict is ambiguous, none of ({formatted_types}) matches cleanly", context
1710+
)
17091711

17101712
def typeddict_key_cannot_be_deleted(
17111713
self, typ: TypedDictType, item_name: str, context: Context

test-data/unit/check-typeddict.test

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -895,15 +895,25 @@ c: Union[A, B] = {'@type': 'a-type', 'a': 'Test'}
895895
reveal_type(c) # N: Revealed type is "Union[TypedDict('__main__.A', {'@type': Literal['a-type'], 'a': builtins.str}), TypedDict('__main__.B', {'@type': Literal['b-type'], 'b': builtins.int})]"
896896
[builtins fixtures/dict.pyi]
897897

898-
[case testTypedDictUnionAmbiguousCase]
898+
[case testTypedDictUnionAmbiguousCaseBothMatch]
899899
from typing import Union, Mapping, Any, cast
900900
from typing_extensions import TypedDict, Literal
901901

902-
A = TypedDict('A', {'@type': Literal['a-type'], 'a': str})
903-
B = TypedDict('B', {'@type': Literal['a-type'], 'a': str})
902+
A = TypedDict('A', {'@type': Literal['a-type'], 'value': str})
903+
B = TypedDict('B', {'@type': Literal['b-type'], 'value': str})
904+
905+
c: Union[A, B] = {'@type': 'a-type', 'value': 'Test'}
906+
[builtins fixtures/dict.pyi]
907+
908+
[case testTypedDictUnionAmbiguousCaseNoMatch]
909+
from typing import Union, Mapping, Any, cast
910+
from typing_extensions import TypedDict, Literal
904911

905-
c: Union[A, B] = {'@type': 'a-type', 'a': 'Test'} # E: Type of TypedDict is ambiguous, could be any of ("A", "B") \
906-
# E: Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "Union[A, B]")
912+
A = TypedDict('A', {'@type': Literal['a-type'], 'value': int})
913+
B = TypedDict('B', {'@type': Literal['b-type'], 'value': int})
914+
915+
c: Union[A, B] = {'@type': 'a-type', 'value': 'Test'} # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \
916+
# E: Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "Union[A, B]")
907917
[builtins fixtures/dict.pyi]
908918

909919
-- Use dict literals
@@ -2786,3 +2796,79 @@ TDC = TypedDict("TDC", {"val": int, "next": Optional[Self]}) # E: Self type can
27862796

27872797
[builtins fixtures/dict.pyi]
27882798
[typing fixtures/typing-typeddict.pyi]
2799+
2800+
[case testUnionOfEquivalentTypedDictsInferred]
2801+
from typing import TypedDict, Dict
2802+
2803+
D = TypedDict("D", {"foo": int}, total=False)
2804+
2805+
def f(d: Dict[str, D]) -> None:
2806+
args = d["a"]
2807+
args.update(d.get("b", {})) # OK
2808+
[builtins fixtures/dict.pyi]
2809+
[typing fixtures/typing-typeddict.pyi]
2810+
2811+
[case testUnionOfEquivalentTypedDictsDeclared]
2812+
from typing import TypedDict, Union
2813+
2814+
class A(TypedDict, total=False):
2815+
name: str
2816+
class B(TypedDict, total=False):
2817+
name: str
2818+
2819+
def foo(data: Union[A, B]) -> None: ...
2820+
foo({"name": "Robert"}) # OK
2821+
[builtins fixtures/dict.pyi]
2822+
[typing fixtures/typing-typeddict.pyi]
2823+
2824+
[case testUnionOfEquivalentTypedDictsEmpty]
2825+
from typing import TypedDict, Union
2826+
2827+
class Foo(TypedDict, total=False):
2828+
foo: str
2829+
class Bar(TypedDict, total=False):
2830+
bar: str
2831+
2832+
def foo(body: Union[Foo, Bar] = {}) -> None: # OK
2833+
...
2834+
[builtins fixtures/dict.pyi]
2835+
[typing fixtures/typing-typeddict.pyi]
2836+
2837+
[case testUnionOfEquivalentTypedDictsDistinct]
2838+
from typing import TypedDict, Union, Literal
2839+
2840+
class A(TypedDict):
2841+
type: Literal['a']
2842+
value: bool
2843+
class B(TypedDict):
2844+
type: Literal['b']
2845+
value: str
2846+
2847+
Response = Union[A, B]
2848+
def method(message: Response) -> None: ...
2849+
2850+
method({'type': 'a', 'value': True}) # OK
2851+
method({'type': 'b', 'value': 'abc'}) # OK
2852+
method({'type': 'a', 'value': 'abc'}) # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \
2853+
# E: Argument 1 to "method" has incompatible type "Dict[str, str]"; expected "Union[A, B]"
2854+
[builtins fixtures/dict.pyi]
2855+
[typing fixtures/typing-typeddict.pyi]
2856+
2857+
[case testUnionOfEquivalentTypedDictsNested]
2858+
from typing import TypedDict, Union
2859+
2860+
class A(TypedDict, total=False):
2861+
foo: C
2862+
class B(TypedDict, total=False):
2863+
foo: D
2864+
class C(TypedDict, total=False):
2865+
c: str
2866+
class D(TypedDict, total=False):
2867+
d: str
2868+
2869+
def foo(data: Union[A, B]) -> None: ...
2870+
foo({"foo": {"c": "foo"}}) # OK
2871+
foo({"foo": {"e": "foo"}}) # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \
2872+
# E: Argument 1 to "foo" has incompatible type "Dict[str, Dict[str, str]]"; expected "Union[A, B]"
2873+
[builtins fixtures/dict.pyi]
2874+
[typing fixtures/typing-typeddict.pyi]

test-data/unit/check-unions.test

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,14 +971,14 @@ if x:
971971
[builtins fixtures/dict.pyi]
972972
[out]
973973

974-
[case testUnpackUnionNoCrashOnPartialNoneList]
974+
[case testUnpackUnionNoCrashOnPartialList]
975975
# flags: --strict-optional
976976
from typing import Dict, Tuple, List, Any
977977

978978
a: Any
979979
d: Dict[str, Tuple[List[Tuple[str, str]], str]]
980-
x, _ = d.get(a, ([], []))
981-
reveal_type(x) # N: Revealed type is "Union[builtins.list[Tuple[builtins.str, builtins.str]], builtins.list[<nothing>]]"
980+
x, _ = d.get(a, ([], ""))
981+
reveal_type(x) # N: Revealed type is "builtins.list[Tuple[builtins.str, builtins.str]]"
982982

983983
for y in x: pass
984984
[builtins fixtures/dict.pyi]

test-data/unit/fixtures/dict.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class dict(Mapping[KT, VT]):
2929
@overload
3030
def get(self, k: KT) -> Optional[VT]: pass
3131
@overload
32-
def get(self, k: KT, default: Union[KT, T]) -> Union[VT, T]: pass
32+
def get(self, k: KT, default: Union[VT, T]) -> Union[VT, T]: pass
3333
def __len__(self) -> int: ...
3434

3535
class int: # for convenience

0 commit comments

Comments
 (0)