Skip to content

Commit 4e81296

Browse files
authored
bpo-33536: Validate make_dataclass() field names. (GH-6906)
1 parent 5db5c06 commit 4e81296

File tree

3 files changed

+182
-108
lines changed

3 files changed

+182
-108
lines changed

Lib/dataclasses.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
import types
55
import inspect
6+
import keyword
67

78
__all__ = ['dataclass',
89
'field',
@@ -1100,6 +1101,9 @@ class C(Base):
11001101
# Copy namespace since we're going to mutate it.
11011102
namespace = namespace.copy()
11021103

1104+
# While we're looking through the field names, validate that they
1105+
# are identifiers, are not keywords, and are not duplicates.
1106+
seen = set()
11031107
anns = {}
11041108
for item in fields:
11051109
if isinstance(item, str):
@@ -1110,6 +1114,17 @@ class C(Base):
11101114
elif len(item) == 3:
11111115
name, tp, spec = item
11121116
namespace[name] = spec
1117+
else:
1118+
raise TypeError(f'Invalid field: {item!r}')
1119+
1120+
if not isinstance(name, str) or not name.isidentifier():
1121+
raise TypeError(f'Field names must be valid identifers: {name!r}')
1122+
if keyword.iskeyword(name):
1123+
raise TypeError(f'Field names must not be keywords: {name!r}')
1124+
if name in seen:
1125+
raise TypeError(f'Field name duplicated: {name!r}')
1126+
1127+
seen.add(name)
11131128
anns[name] = tp
11141129

11151130
namespace['__annotations__'] = anns

Lib/test/test_dataclasses.py

Lines changed: 165 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,114 +1826,6 @@ class R:
18261826
self.assertEqual(new_sample.x, another_new_sample.x)
18271827
self.assertEqual(sample.y, another_new_sample.y)
18281828

1829-
def test_helper_make_dataclass(self):
1830-
C = make_dataclass('C',
1831-
[('x', int),
1832-
('y', int, field(default=5))],
1833-
namespace={'add_one': lambda self: self.x + 1})
1834-
c = C(10)
1835-
self.assertEqual((c.x, c.y), (10, 5))
1836-
self.assertEqual(c.add_one(), 11)
1837-
1838-
1839-
def test_helper_make_dataclass_no_mutate_namespace(self):
1840-
# Make sure a provided namespace isn't mutated.
1841-
ns = {}
1842-
C = make_dataclass('C',
1843-
[('x', int),
1844-
('y', int, field(default=5))],
1845-
namespace=ns)
1846-
self.assertEqual(ns, {})
1847-
1848-
def test_helper_make_dataclass_base(self):
1849-
class Base1:
1850-
pass
1851-
class Base2:
1852-
pass
1853-
C = make_dataclass('C',
1854-
[('x', int)],
1855-
bases=(Base1, Base2))
1856-
c = C(2)
1857-
self.assertIsInstance(c, C)
1858-
self.assertIsInstance(c, Base1)
1859-
self.assertIsInstance(c, Base2)
1860-
1861-
def test_helper_make_dataclass_base_dataclass(self):
1862-
@dataclass
1863-
class Base1:
1864-
x: int
1865-
class Base2:
1866-
pass
1867-
C = make_dataclass('C',
1868-
[('y', int)],
1869-
bases=(Base1, Base2))
1870-
with self.assertRaisesRegex(TypeError, 'required positional'):
1871-
c = C(2)
1872-
c = C(1, 2)
1873-
self.assertIsInstance(c, C)
1874-
self.assertIsInstance(c, Base1)
1875-
self.assertIsInstance(c, Base2)
1876-
1877-
self.assertEqual((c.x, c.y), (1, 2))
1878-
1879-
def test_helper_make_dataclass_init_var(self):
1880-
def post_init(self, y):
1881-
self.x *= y
1882-
1883-
C = make_dataclass('C',
1884-
[('x', int),
1885-
('y', InitVar[int]),
1886-
],
1887-
namespace={'__post_init__': post_init},
1888-
)
1889-
c = C(2, 3)
1890-
self.assertEqual(vars(c), {'x': 6})
1891-
self.assertEqual(len(fields(c)), 1)
1892-
1893-
def test_helper_make_dataclass_class_var(self):
1894-
C = make_dataclass('C',
1895-
[('x', int),
1896-
('y', ClassVar[int], 10),
1897-
('z', ClassVar[int], field(default=20)),
1898-
])
1899-
c = C(1)
1900-
self.assertEqual(vars(c), {'x': 1})
1901-
self.assertEqual(len(fields(c)), 1)
1902-
self.assertEqual(C.y, 10)
1903-
self.assertEqual(C.z, 20)
1904-
1905-
def test_helper_make_dataclass_other_params(self):
1906-
C = make_dataclass('C',
1907-
[('x', int),
1908-
('y', ClassVar[int], 10),
1909-
('z', ClassVar[int], field(default=20)),
1910-
],
1911-
init=False)
1912-
# Make sure we have a repr, but no init.
1913-
self.assertNotIn('__init__', vars(C))
1914-
self.assertIn('__repr__', vars(C))
1915-
1916-
# Make sure random other params don't work.
1917-
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1918-
C = make_dataclass('C',
1919-
[],
1920-
xxinit=False)
1921-
1922-
def test_helper_make_dataclass_no_types(self):
1923-
C = make_dataclass('Point', ['x', 'y', 'z'])
1924-
c = C(1, 2, 3)
1925-
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1926-
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1927-
'y': 'typing.Any',
1928-
'z': 'typing.Any'})
1929-
1930-
C = make_dataclass('Point', ['x', ('y', int), 'z'])
1931-
c = C(1, 2, 3)
1932-
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1933-
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1934-
'y': int,
1935-
'z': 'typing.Any'})
1936-
19371829

19381830
class TestFieldNoAnnotation(unittest.TestCase):
19391831
def test_field_without_annotation(self):
@@ -2947,5 +2839,170 @@ def test_classvar_module_level_import(self):
29472839
self.assertNotIn('not_iv4', c.__dict__)
29482840

29492841

2842+
class TestMakeDataclass(unittest.TestCase):
2843+
def test_simple(self):
2844+
C = make_dataclass('C',
2845+
[('x', int),
2846+
('y', int, field(default=5))],
2847+
namespace={'add_one': lambda self: self.x + 1})
2848+
c = C(10)
2849+
self.assertEqual((c.x, c.y), (10, 5))
2850+
self.assertEqual(c.add_one(), 11)
2851+
2852+
2853+
def test_no_mutate_namespace(self):
2854+
# Make sure a provided namespace isn't mutated.
2855+
ns = {}
2856+
C = make_dataclass('C',
2857+
[('x', int),
2858+
('y', int, field(default=5))],
2859+
namespace=ns)
2860+
self.assertEqual(ns, {})
2861+
2862+
def test_base(self):
2863+
class Base1:
2864+
pass
2865+
class Base2:
2866+
pass
2867+
C = make_dataclass('C',
2868+
[('x', int)],
2869+
bases=(Base1, Base2))
2870+
c = C(2)
2871+
self.assertIsInstance(c, C)
2872+
self.assertIsInstance(c, Base1)
2873+
self.assertIsInstance(c, Base2)
2874+
2875+
def test_base_dataclass(self):
2876+
@dataclass
2877+
class Base1:
2878+
x: int
2879+
class Base2:
2880+
pass
2881+
C = make_dataclass('C',
2882+
[('y', int)],
2883+
bases=(Base1, Base2))
2884+
with self.assertRaisesRegex(TypeError, 'required positional'):
2885+
c = C(2)
2886+
c = C(1, 2)
2887+
self.assertIsInstance(c, C)
2888+
self.assertIsInstance(c, Base1)
2889+
self.assertIsInstance(c, Base2)
2890+
2891+
self.assertEqual((c.x, c.y), (1, 2))
2892+
2893+
def test_init_var(self):
2894+
def post_init(self, y):
2895+
self.x *= y
2896+
2897+
C = make_dataclass('C',
2898+
[('x', int),
2899+
('y', InitVar[int]),
2900+
],
2901+
namespace={'__post_init__': post_init},
2902+
)
2903+
c = C(2, 3)
2904+
self.assertEqual(vars(c), {'x': 6})
2905+
self.assertEqual(len(fields(c)), 1)
2906+
2907+
def test_class_var(self):
2908+
C = make_dataclass('C',
2909+
[('x', int),
2910+
('y', ClassVar[int], 10),
2911+
('z', ClassVar[int], field(default=20)),
2912+
])
2913+
c = C(1)
2914+
self.assertEqual(vars(c), {'x': 1})
2915+
self.assertEqual(len(fields(c)), 1)
2916+
self.assertEqual(C.y, 10)
2917+
self.assertEqual(C.z, 20)
2918+
2919+
def test_other_params(self):
2920+
C = make_dataclass('C',
2921+
[('x', int),
2922+
('y', ClassVar[int], 10),
2923+
('z', ClassVar[int], field(default=20)),
2924+
],
2925+
init=False)
2926+
# Make sure we have a repr, but no init.
2927+
self.assertNotIn('__init__', vars(C))
2928+
self.assertIn('__repr__', vars(C))
2929+
2930+
# Make sure random other params don't work.
2931+
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
2932+
C = make_dataclass('C',
2933+
[],
2934+
xxinit=False)
2935+
2936+
def test_no_types(self):
2937+
C = make_dataclass('Point', ['x', 'y', 'z'])
2938+
c = C(1, 2, 3)
2939+
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
2940+
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
2941+
'y': 'typing.Any',
2942+
'z': 'typing.Any'})
2943+
2944+
C = make_dataclass('Point', ['x', ('y', int), 'z'])
2945+
c = C(1, 2, 3)
2946+
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
2947+
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
2948+
'y': int,
2949+
'z': 'typing.Any'})
2950+
2951+
def test_invalid_type_specification(self):
2952+
for bad_field in [(),
2953+
(1, 2, 3, 4),
2954+
]:
2955+
with self.subTest(bad_field=bad_field):
2956+
with self.assertRaisesRegex(TypeError, r'Invalid field: '):
2957+
make_dataclass('C', ['a', bad_field])
2958+
2959+
# And test for things with no len().
2960+
for bad_field in [float,
2961+
lambda x:x,
2962+
]:
2963+
with self.subTest(bad_field=bad_field):
2964+
with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
2965+
make_dataclass('C', ['a', bad_field])
2966+
2967+
def test_duplicate_field_names(self):
2968+
for field in ['a', 'ab']:
2969+
with self.subTest(field=field):
2970+
with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
2971+
make_dataclass('C', [field, 'a', field])
2972+
2973+
def test_keyword_field_names(self):
2974+
for field in ['for', 'async', 'await', 'as']:
2975+
with self.subTest(field=field):
2976+
with self.assertRaisesRegex(TypeError, 'must not be keywords'):
2977+
make_dataclass('C', ['a', field])
2978+
with self.assertRaisesRegex(TypeError, 'must not be keywords'):
2979+
make_dataclass('C', [field])
2980+
with self.assertRaisesRegex(TypeError, 'must not be keywords'):
2981+
make_dataclass('C', [field, 'a'])
2982+
2983+
def test_non_identifier_field_names(self):
2984+
for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
2985+
with self.subTest(field=field):
2986+
with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
2987+
make_dataclass('C', ['a', field])
2988+
with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
2989+
make_dataclass('C', [field])
2990+
with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
2991+
make_dataclass('C', [field, 'a'])
2992+
2993+
def test_underscore_field_names(self):
2994+
# Unlike namedtuple, it's okay if dataclass field names have
2995+
# a leading underscore.
2996+
make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
2997+
2998+
def test_funny_class_names_names(self):
2999+
# No reason to prevent weird class names, since
3000+
# types.new_class allows them.
3001+
for classname in ['()', 'x,y', '*', '2@3', '']:
3002+
with self.subTest(classname=classname):
3003+
C = make_dataclass(classname, ['a', 'b'])
3004+
self.assertEqual(C.__name__, classname)
3005+
3006+
29503007
if __name__ == '__main__':
29513008
unittest.main()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
dataclasses.make_dataclass now checks for invalid field names and duplicate
2+
fields. It also raises TypeError for invalid field specifications.

0 commit comments

Comments
 (0)