Skip to content

Commit a867181

Browse files
Zac-HDilevkivskyi
authored andcommitted
Track optional TypdeDict keys (#687)
Backport of python/cpython#17214 (BPO-38834)
1 parent d63e9e5 commit a867181

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

src_py3/test_typing_extensions.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,9 @@ class Point2D(TypedDict):
428428
x: int
429429
y: int
430430
431+
class Point2Dor3D(Point2D, total=False):
432+
z: int
433+
431434
class LabelPoint2D(Point2D, Label): ...
432435
433436
class Options(TypedDict, total=False):
@@ -442,7 +445,7 @@ class Options(TypedDict, total=False):
442445
ann_module = ann_module2 = ann_module3 = None
443446
A = B = CSub = G = CoolEmployee = CoolEmployeeWithDefault = object
444447
XMeth = XRepr = HasCallProtocol = NoneAndForward = Loop = object
445-
Point2D = LabelPoint2D = Options = object
448+
Point2D = Point2Dor3D = LabelPoint2D = Options = object
446449

447450
gth = get_type_hints
448451

@@ -1481,7 +1484,7 @@ def test_typeddict_create_errors(self):
14811484

14821485
def test_typeddict_errors(self):
14831486
Emp = TypedDict('Emp', {'name': str, 'id': int})
1484-
if hasattr(typing, 'TypedDict'):
1487+
if sys.version_info[:2] >= (3, 9):
14851488
self.assertEqual(TypedDict.__module__, 'typing')
14861489
else:
14871490
self.assertEqual(TypedDict.__module__, 'typing_extensions')
@@ -1543,6 +1546,11 @@ def test_total(self):
15431546
self.assertEqual(Options(log_level=2), {'log_level': 2})
15441547
self.assertEqual(Options.__total__, False)
15451548

1549+
@skipUnless(PY36, 'Python 3.6 required')
1550+
def test_optional_keys(self):
1551+
assert Point2Dor3D.__required_keys__ == frozenset(['x', 'y'])
1552+
assert Point2Dor3D.__optional_keys__ == frozenset(['z'])
1553+
15461554

15471555
@skipUnless(TYPING_3_5_3, "Python >= 3.5.3 required")
15481556
class AnnotatedTests(BaseTestCase):
@@ -1817,7 +1825,14 @@ def test_typing_extensions_includes_standard(self):
18171825
self.assertIn('runtime', a)
18181826

18191827
def test_typing_extensions_defers_when_possible(self):
1820-
exclude = {'overload', 'Text', 'TYPE_CHECKING', 'Final', 'get_type_hints'}
1828+
exclude = {
1829+
'overload',
1830+
'Text',
1831+
'TypedDict',
1832+
'TYPE_CHECKING',
1833+
'Final',
1834+
'get_type_hints'
1835+
}
18211836
for item in typing_extensions.__all__:
18221837
if item not in exclude and hasattr(typing, item):
18231838
self.assertIs(

src_py3/typing_extensions.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,7 +1569,9 @@ def runtime_checkable(cls):
15691569
runtime = runtime_checkable
15701570

15711571

1572-
if hasattr(typing, 'TypedDict'):
1572+
if sys.version_info[:2] >= (3, 9):
1573+
# The standard library TypedDict in Python 3.8 does not store runtime information
1574+
# about which (if any) keys are optional. See https://bugs.python.org/issue38834
15731575
TypedDict = typing.TypedDict
15741576
else:
15751577
def _check_fails(cls, other):
@@ -1652,9 +1654,20 @@ def __new__(cls, name, bases, ns, total=True):
16521654
anns = ns.get('__annotations__', {})
16531655
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
16541656
anns = {n: typing._type_check(tp, msg) for n, tp in anns.items()}
1657+
required = set(anns if total else ())
1658+
optional = set(() if total else anns)
1659+
16551660
for base in bases:
1656-
anns.update(base.__dict__.get('__annotations__', {}))
1661+
base_anns = base.__dict__.get('__annotations__', {})
1662+
anns.update(base_anns)
1663+
if getattr(base, '__total__', True):
1664+
required.update(base_anns)
1665+
else:
1666+
optional.update(base_anns)
1667+
16571668
tp_dict.__annotations__ = anns
1669+
tp_dict.__required_keys__ = frozenset(required)
1670+
tp_dict.__optional_keys__ = frozenset(optional)
16581671
if not hasattr(tp_dict, '__total__'):
16591672
tp_dict.__total__ = total
16601673
return tp_dict
@@ -1682,8 +1695,9 @@ class Point2D(TypedDict):
16821695
16831696
assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
16841697
1685-
The type info could be accessed via Point2D.__annotations__. TypedDict
1686-
supports two additional equivalent forms::
1698+
The type info can be accessed via the Point2D.__annotations__ dict, and
1699+
the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets.
1700+
TypedDict supports two additional equivalent forms::
16871701
16881702
Point2D = TypedDict('Point2D', x=int, y=int, label=str)
16891703
Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})

0 commit comments

Comments
 (0)