Skip to content

Commit 7bbc7af

Browse files
authored
Allow TypedDict key with literal type during construction (#7645)
Also refactored an operation away from `mypy.plugins.common` so that it is more cleanly reusable. Also kept an alias in the original location to avoid breaking existing plugins. Fixes #7644.
1 parent e4d99c6 commit 7bbc7af

File tree

4 files changed

+76
-46
lines changed

4 files changed

+76
-46
lines changed

mypy/checkexpr.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
6060
from mypy.typeops import (
6161
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
62-
function_type, callable_type,
62+
function_type, callable_type, try_getting_str_literals
6363
)
6464
import mypy.errorcodes as codes
6565

@@ -493,11 +493,19 @@ def check_typeddict_call_with_dict(self, callee: TypedDictType,
493493

494494
item_names = [] # List[str]
495495
for item_name_expr, item_arg in kwargs.items:
496-
if not isinstance(item_name_expr, StrExpr):
496+
literal_value = None
497+
if item_name_expr:
498+
key_type = self.accept(item_name_expr)
499+
values = try_getting_str_literals(item_name_expr, key_type)
500+
if values and len(values) == 1:
501+
literal_value = values[0]
502+
if literal_value is None:
497503
key_context = item_name_expr or item_arg
498-
self.chk.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, key_context)
504+
self.chk.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
505+
key_context)
499506
return AnyType(TypeOfAny.from_error)
500-
item_names.append(item_name_expr.value)
507+
else:
508+
item_names.append(literal_value)
501509

502510
return self.check_typeddict_call_with_kwargs(
503511
callee, OrderedDict(zip(item_names, item_args)), context)

mypy/plugins/common.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,14 @@
22

33
from mypy.nodes import (
44
ARG_POS, MDEF, Argument, Block, CallExpr, Expression, SYMBOL_FUNCBASE_TYPES,
5-
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, StrExpr,
5+
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var
66
)
77
from mypy.plugin import ClassDefContext
88
from mypy.semanal import set_callable_name
9-
from mypy.types import (
10-
CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance, UnionType,
11-
get_proper_type, get_proper_types
12-
)
9+
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, get_proper_type
1310
from mypy.typevars import fill_typevars
1411
from mypy.util import get_unique_redefinition_name
12+
from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API
1513

1614

1715
def _get_decorator_bool_argument(
@@ -130,38 +128,3 @@ def add_method(
130128

131129
info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True)
132130
info.defn.defs.body.append(func)
133-
134-
135-
def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]:
136-
"""If the given expression or type corresponds to a string literal
137-
or a union of string literals, returns a list of the underlying strings.
138-
Otherwise, returns None.
139-
140-
Specifically, this function is guaranteed to return a list with
141-
one or more strings if one one the following is true:
142-
143-
1. 'expr' is a StrExpr
144-
2. 'typ' is a LiteralType containing a string
145-
3. 'typ' is a UnionType containing only LiteralType of strings
146-
"""
147-
typ = get_proper_type(typ)
148-
149-
if isinstance(expr, StrExpr):
150-
return [expr.value]
151-
152-
if isinstance(typ, Instance) and typ.last_known_value is not None:
153-
possible_literals = [typ.last_known_value] # type: List[Type]
154-
elif isinstance(typ, UnionType):
155-
possible_literals = list(typ.items)
156-
else:
157-
possible_literals = [typ]
158-
159-
strings = []
160-
for lit in get_proper_types(possible_literals):
161-
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
162-
val = lit.value
163-
assert isinstance(val, str)
164-
strings.append(val)
165-
else:
166-
return None
167-
return strings

mypy/typeops.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
from mypy.types import (
1111
TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded,
1212
TypeVarType, TypeType, UninhabitedType, FormalArgument, UnionType, NoneType,
13-
AnyType, TypeOfAny, TypeType, ProperType, get_proper_type, get_proper_types, copy_type
13+
AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types,
14+
copy_type
1415
)
1516
from mypy.nodes import (
16-
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2,
17+
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, Expression,
18+
StrExpr
1719
)
1820
from mypy.maptype import map_instance_to_supertype
1921
from mypy.expandtype import expand_type_by_instance, expand_type
@@ -417,3 +419,38 @@ def callable_type(fdef: FuncItem, fallback: Instance,
417419
column=fdef.column,
418420
implicit=True,
419421
)
422+
423+
424+
def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]:
425+
"""If the given expression or type corresponds to a string literal
426+
or a union of string literals, returns a list of the underlying strings.
427+
Otherwise, returns None.
428+
429+
Specifically, this function is guaranteed to return a list with
430+
one or more strings if one one the following is true:
431+
432+
1. 'expr' is a StrExpr
433+
2. 'typ' is a LiteralType containing a string
434+
3. 'typ' is a UnionType containing only LiteralType of strings
435+
"""
436+
typ = get_proper_type(typ)
437+
438+
if isinstance(expr, StrExpr):
439+
return [expr.value]
440+
441+
if isinstance(typ, Instance) and typ.last_known_value is not None:
442+
possible_literals = [typ.last_known_value] # type: List[Type]
443+
elif isinstance(typ, UnionType):
444+
possible_literals = list(typ.items)
445+
else:
446+
possible_literals = [typ]
447+
448+
strings = []
449+
for lit in get_proper_types(possible_literals):
450+
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
451+
val = lit.value
452+
assert isinstance(val, str)
453+
strings.append(val)
454+
else:
455+
return None
456+
return strings

test-data/unit/check-typeddict.test

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,3 +1883,25 @@ assert isinstance(u2, Mapping)
18831883
reveal_type(u2) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
18841884
[builtins fixtures/dict.pyi]
18851885
[typing fixtures/typing-full.pyi]
1886+
1887+
[case testTypedDictLiteralTypeKeyInCreation]
1888+
from typing import TypedDict, Final, Literal
1889+
1890+
class Value(TypedDict):
1891+
num: int
1892+
1893+
num: Final = 'num'
1894+
v: Value = {num: 5}
1895+
v = {num: ''} # E: Incompatible types (expression has type "str", TypedDict item "num" has type "int")
1896+
1897+
bad: Final = 2
1898+
v = {bad: 3} # E: Expected TypedDict key to be string literal
1899+
union: Literal['num', 'foo']
1900+
v = {union: 2} # E: Expected TypedDict key to be string literal
1901+
num2: Literal['num']
1902+
v = {num2: 2}
1903+
bad2: Literal['bad']
1904+
v = {bad2: 2} # E: Extra key 'bad' for TypedDict "Value"
1905+
1906+
[builtins fixtures/dict.pyi]
1907+
[typing fixtures/typing-full.pyi]

0 commit comments

Comments
 (0)