Skip to content

Commit 0d57db2

Browse files
miss-islington1st1
authored andcommitted
bpo-34776: Fix dataclasses to support __future__ "annotations" mode (GH-9518) (#17531)
(cherry picked from commit d219cc4) Co-authored-by: Yury Selivanov <[email protected]>
1 parent 79c2974 commit 0d57db2

File tree

4 files changed

+78
-34
lines changed

4 files changed

+78
-34
lines changed

Lib/dataclasses.py

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -379,23 +379,24 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
379379
# worries about external callers.
380380
if locals is None:
381381
locals = {}
382-
# __builtins__ may be the "builtins" module or
383-
# the value of its "__dict__",
384-
# so make sure "__builtins__" is the module.
385-
if globals is not None and '__builtins__' not in globals:
386-
globals['__builtins__'] = builtins
382+
if 'BUILTINS' not in locals:
383+
locals['BUILTINS'] = builtins
387384
return_annotation = ''
388385
if return_type is not MISSING:
389386
locals['_return_type'] = return_type
390387
return_annotation = '->_return_type'
391388
args = ','.join(args)
392-
body = '\n'.join(f' {b}' for b in body)
389+
body = '\n'.join(f' {b}' for b in body)
393390

394391
# Compute the text of the entire function.
395-
txt = f'def {name}({args}){return_annotation}:\n{body}'
392+
txt = f' def {name}({args}){return_annotation}:\n{body}'
396393

397-
exec(txt, globals, locals)
398-
return locals[name]
394+
local_vars = ', '.join(locals.keys())
395+
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
396+
397+
ns = {}
398+
exec(txt, globals, ns)
399+
return ns['__create_fn__'](**locals)
399400

400401

401402
def _field_assign(frozen, name, value, self_name):
@@ -406,7 +407,7 @@ def _field_assign(frozen, name, value, self_name):
406407
# self_name is what "self" is called in this function: don't
407408
# hard-code "self", since that might be a field name.
408409
if frozen:
409-
return f'__builtins__.object.__setattr__({self_name},{name!r},{value})'
410+
return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
410411
return f'{self_name}.{name}={value}'
411412

412413

@@ -483,7 +484,7 @@ def _init_param(f):
483484
return f'{f.name}:_type_{f.name}{default}'
484485

485486

486-
def _init_fn(fields, frozen, has_post_init, self_name):
487+
def _init_fn(fields, frozen, has_post_init, self_name, globals):
487488
# fields contains both real fields and InitVar pseudo-fields.
488489

489490
# Make sure we don't have fields without defaults following fields
@@ -501,12 +502,15 @@ def _init_fn(fields, frozen, has_post_init, self_name):
501502
raise TypeError(f'non-default argument {f.name!r} '
502503
'follows default argument')
503504

504-
globals = {'MISSING': MISSING,
505-
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
505+
locals = {f'_type_{f.name}': f.type for f in fields}
506+
locals.update({
507+
'MISSING': MISSING,
508+
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
509+
})
506510

507511
body_lines = []
508512
for f in fields:
509-
line = _field_init(f, frozen, globals, self_name)
513+
line = _field_init(f, frozen, locals, self_name)
510514
# line is None means that this field doesn't require
511515
# initialization (it's a pseudo-field). Just skip it.
512516
if line:
@@ -522,7 +526,6 @@ def _init_fn(fields, frozen, has_post_init, self_name):
522526
if not body_lines:
523527
body_lines = ['pass']
524528

525-
locals = {f'_type_{f.name}': f.type for f in fields}
526529
return _create_fn('__init__',
527530
[self_name] + [_init_param(f) for f in fields if f.init],
528531
body_lines,
@@ -531,20 +534,19 @@ def _init_fn(fields, frozen, has_post_init, self_name):
531534
return_type=None)
532535

533536

534-
def _repr_fn(fields):
537+
def _repr_fn(fields, globals):
535538
fn = _create_fn('__repr__',
536539
('self',),
537540
['return self.__class__.__qualname__ + f"(' +
538541
', '.join([f"{f.name}={{self.{f.name}!r}}"
539542
for f in fields]) +
540-
')"'])
543+
')"'],
544+
globals=globals)
541545
return _recursive_repr(fn)
542546

543547

544-
def _frozen_get_del_attr(cls, fields):
545-
# XXX: globals is modified on the first call to _create_fn, then
546-
# the modified version is used in the second call. Is this okay?
547-
globals = {'cls': cls,
548+
def _frozen_get_del_attr(cls, fields, globals):
549+
locals = {'cls': cls,
548550
'FrozenInstanceError': FrozenInstanceError}
549551
if fields:
550552
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
@@ -556,17 +558,19 @@ def _frozen_get_del_attr(cls, fields):
556558
(f'if type(self) is cls or name in {fields_str}:',
557559
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
558560
f'super(cls, self).__setattr__(name, value)'),
561+
locals=locals,
559562
globals=globals),
560563
_create_fn('__delattr__',
561564
('self', 'name'),
562565
(f'if type(self) is cls or name in {fields_str}:',
563566
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
564567
f'super(cls, self).__delattr__(name)'),
568+
locals=locals,
565569
globals=globals),
566570
)
567571

568572

569-
def _cmp_fn(name, op, self_tuple, other_tuple):
573+
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
570574
# Create a comparison function. If the fields in the object are
571575
# named 'x' and 'y', then self_tuple is the string
572576
# '(self.x,self.y)' and other_tuple is the string
@@ -576,14 +580,16 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
576580
('self', 'other'),
577581
[ 'if other.__class__ is self.__class__:',
578582
f' return {self_tuple}{op}{other_tuple}',
579-
'return NotImplemented'])
583+
'return NotImplemented'],
584+
globals=globals)
580585

581586

582-
def _hash_fn(fields):
587+
def _hash_fn(fields, globals):
583588
self_tuple = _tuple_str('self', fields)
584589
return _create_fn('__hash__',
585590
('self',),
586-
[f'return hash({self_tuple})'])
591+
[f'return hash({self_tuple})'],
592+
globals=globals)
587593

588594

589595
def _is_classvar(a_type, typing):
@@ -756,14 +762,14 @@ def _set_new_attribute(cls, name, value):
756762
# take. The common case is to do nothing, so instead of providing a
757763
# function that is a no-op, use None to signify that.
758764

759-
def _hash_set_none(cls, fields):
765+
def _hash_set_none(cls, fields, globals):
760766
return None
761767

762-
def _hash_add(cls, fields):
768+
def _hash_add(cls, fields, globals):
763769
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
764-
return _hash_fn(flds)
770+
return _hash_fn(flds, globals)
765771

766-
def _hash_exception(cls, fields):
772+
def _hash_exception(cls, fields, globals):
767773
# Raise an exception.
768774
raise TypeError(f'Cannot overwrite attribute __hash__ '
769775
f'in class {cls.__name__}')
@@ -805,6 +811,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
805811
# is defined by the base class, which is found first.
806812
fields = {}
807813

814+
if cls.__module__ in sys.modules:
815+
globals = sys.modules[cls.__module__].__dict__
816+
else:
817+
# Theoretically this can happen if someone writes
818+
# a custom string to cls.__module__. In which case
819+
# such dataclass won't be fully introspectable
820+
# (w.r.t. typing.get_type_hints) but will still function
821+
# correctly.
822+
globals = {}
823+
808824
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
809825
unsafe_hash, frozen))
810826

@@ -914,6 +930,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
914930
# if possible.
915931
'__dataclass_self__' if 'self' in fields
916932
else 'self',
933+
globals,
917934
))
918935

919936
# Get the fields as a list, and include only real fields. This is
@@ -922,7 +939,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
922939

923940
if repr:
924941
flds = [f for f in field_list if f.repr]
925-
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
942+
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
926943

927944
if eq:
928945
# Create _eq__ method. There's no need for a __ne__ method,
@@ -932,7 +949,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
932949
other_tuple = _tuple_str('other', flds)
933950
_set_new_attribute(cls, '__eq__',
934951
_cmp_fn('__eq__', '==',
935-
self_tuple, other_tuple))
952+
self_tuple, other_tuple,
953+
globals=globals))
936954

937955
if order:
938956
# Create and set the ordering methods.
@@ -945,13 +963,14 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
945963
('__ge__', '>='),
946964
]:
947965
if _set_new_attribute(cls, name,
948-
_cmp_fn(name, op, self_tuple, other_tuple)):
966+
_cmp_fn(name, op, self_tuple, other_tuple,
967+
globals=globals)):
949968
raise TypeError(f'Cannot overwrite attribute {name} '
950969
f'in class {cls.__name__}. Consider using '
951970
'functools.total_ordering')
952971

953972
if frozen:
954-
for fn in _frozen_get_del_attr(cls, field_list):
973+
for fn in _frozen_get_del_attr(cls, field_list, globals):
955974
if _set_new_attribute(cls, fn.__name__, fn):
956975
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
957976
f'in class {cls.__name__}')
@@ -964,7 +983,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
964983
if hash_action:
965984
# No need to call _set_new_attribute here, since by the time
966985
# we're here the overwriting is unconditional.
967-
cls.__hash__ = hash_action(cls, field_list)
986+
cls.__hash__ = hash_action(cls, field_list, globals)
968987

969988
if not getattr(cls, '__doc__'):
970989
# Create a class doc-string.

Lib/test/dataclass_textanno.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import annotations
2+
3+
import dataclasses
4+
5+
6+
class Foo:
7+
pass
8+
9+
10+
@dataclasses.dataclass
11+
class Bar:
12+
foo: Foo

Lib/test/test_dataclasses.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import unittest
1111
from unittest.mock import Mock
1212
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
13+
from typing import get_type_hints
1314
from collections import deque, OrderedDict, namedtuple
1415
from functools import total_ordering
1516

@@ -2926,6 +2927,17 @@ def test_classvar_module_level_import(self):
29262927
# won't exist on the instance.
29272928
self.assertNotIn('not_iv4', c.__dict__)
29282929

2930+
def test_text_annotations(self):
2931+
from test import dataclass_textanno
2932+
2933+
self.assertEqual(
2934+
get_type_hints(dataclass_textanno.Bar),
2935+
{'foo': dataclass_textanno.Foo})
2936+
self.assertEqual(
2937+
get_type_hints(dataclass_textanno.Bar.__init__),
2938+
{'foo': dataclass_textanno.Foo,
2939+
'return': type(None)})
2940+
29292941

29302942
class TestMakeDataclass(unittest.TestCase):
29312943
def test_simple(self):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix dataclasses to support forward references in type annotations

0 commit comments

Comments
 (0)