Skip to content

bpo-33536: Validate make_dataclass() field names. #6906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import types
import inspect
import keyword

__all__ = ['dataclass',
'field',
Expand Down Expand Up @@ -1100,6 +1101,9 @@ class C(Base):
# Copy namespace since we're going to mutate it.
namespace = namespace.copy()

# While we're looking through the field names, validate that they
# are identifiers, are not keywords, and not duplicates.
seen = set()
anns = {}
for item in fields:
if isinstance(item, str):
Expand All @@ -1110,6 +1114,17 @@ class C(Base):
elif len(item) == 3:
name, tp, spec = item
namespace[name] = spec
else:
raise TypeError(f'Invalid field: {item!r}')

if not isinstance(name, str) or not name.isidentifier():
raise TypeError(f'Field names must be valid identifers: {name!r}')
if keyword.iskeyword(name):
raise TypeError(f'Field names must not be keywords: {name!r}')
if name in seen:
raise TypeError(f'Field name duplicated: {name!r}')

seen.add(name)
anns[name] = tp

namespace['__annotations__'] = anns
Expand Down
273 changes: 165 additions & 108 deletions Lib/test/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,114 +1826,6 @@ class R:
self.assertEqual(new_sample.x, another_new_sample.x)
self.assertEqual(sample.y, another_new_sample.y)

def test_helper_make_dataclass(self):
C = make_dataclass('C',
[('x', int),
('y', int, field(default=5))],
namespace={'add_one': lambda self: self.x + 1})
c = C(10)
self.assertEqual((c.x, c.y), (10, 5))
self.assertEqual(c.add_one(), 11)


def test_helper_make_dataclass_no_mutate_namespace(self):
# Make sure a provided namespace isn't mutated.
ns = {}
C = make_dataclass('C',
[('x', int),
('y', int, field(default=5))],
namespace=ns)
self.assertEqual(ns, {})

def test_helper_make_dataclass_base(self):
class Base1:
pass
class Base2:
pass
C = make_dataclass('C',
[('x', int)],
bases=(Base1, Base2))
c = C(2)
self.assertIsInstance(c, C)
self.assertIsInstance(c, Base1)
self.assertIsInstance(c, Base2)

def test_helper_make_dataclass_base_dataclass(self):
@dataclass
class Base1:
x: int
class Base2:
pass
C = make_dataclass('C',
[('y', int)],
bases=(Base1, Base2))
with self.assertRaisesRegex(TypeError, 'required positional'):
c = C(2)
c = C(1, 2)
self.assertIsInstance(c, C)
self.assertIsInstance(c, Base1)
self.assertIsInstance(c, Base2)

self.assertEqual((c.x, c.y), (1, 2))

def test_helper_make_dataclass_init_var(self):
def post_init(self, y):
self.x *= y

C = make_dataclass('C',
[('x', int),
('y', InitVar[int]),
],
namespace={'__post_init__': post_init},
)
c = C(2, 3)
self.assertEqual(vars(c), {'x': 6})
self.assertEqual(len(fields(c)), 1)

def test_helper_make_dataclass_class_var(self):
C = make_dataclass('C',
[('x', int),
('y', ClassVar[int], 10),
('z', ClassVar[int], field(default=20)),
])
c = C(1)
self.assertEqual(vars(c), {'x': 1})
self.assertEqual(len(fields(c)), 1)
self.assertEqual(C.y, 10)
self.assertEqual(C.z, 20)

def test_helper_make_dataclass_other_params(self):
C = make_dataclass('C',
[('x', int),
('y', ClassVar[int], 10),
('z', ClassVar[int], field(default=20)),
],
init=False)
# Make sure we have a repr, but no init.
self.assertNotIn('__init__', vars(C))
self.assertIn('__repr__', vars(C))

# Make sure random other params don't work.
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
C = make_dataclass('C',
[],
xxinit=False)

def test_helper_make_dataclass_no_types(self):
C = make_dataclass('Point', ['x', 'y', 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
'y': 'typing.Any',
'z': 'typing.Any'})

C = make_dataclass('Point', ['x', ('y', int), 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
'y': int,
'z': 'typing.Any'})


class TestFieldNoAnnotation(unittest.TestCase):
def test_field_without_annotation(self):
Expand Down Expand Up @@ -2947,5 +2839,170 @@ def test_classvar_module_level_import(self):
self.assertNotIn('not_iv4', c.__dict__)


class TestMakeDataclass(unittest.TestCase):
    # Tests for make_dataclass(): building classes from field specs, and
    # rejecting malformed specs and invalid or duplicated field names.

    def test_simple(self):
        # One required field, one defaulted field, and a method injected
        # through ``namespace``.
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        # Only the (absence of a) side effect on ``ns`` matters, so the
        # returned class is discarded.
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        # Plain base classes passed via ``bases`` end up in the hierarchy.
        class Base1:
            pass

        class Base2:
            pass

        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        # A dataclass base contributes its fields to the generated
        # __init__, so both ``x`` and ``y`` are required.
        @dataclass
        class Base1:
            x: int

        class Base2:
            pass

        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        # An InitVar reaches __post_init__ but is not stored on the
        # instance and is not listed by fields().
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        # ClassVar entries become class attributes, whether given a bare
        # default or a field() spec, and are excluded from fields().
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        # A field given as a bare string gets the annotation 'typing.Any'
        # (the string, not the object).
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        # Bare names and (name, type) tuples may be mixed in one call.
        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        # Sized field specs of an unsupported length get an explicit
        # 'Invalid field' error.
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        # The loop variable is deliberately not called ``field`` so the
        # dataclasses.field() helper is not shadowed; the subTest label
        # is kept so failure output is unchanged.
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        # A keyword is rejected wherever it appears in the field list.
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name, 'a'])

    def test_non_identifier_field_names(self):
        # The error message this test accompanies spells the word
        # "identifers".  Match that and the corrected "identifiers" so
        # the test does not pin a typo.
        pattern = r'must be valid identifi?ers'
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, pattern):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, pattern):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, pattern):
                    make_dataclass('C', [name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
dataclasses.make_dataclass() now checks for invalid field names and duplicate
fields. It also checks for invalid field specifications.