Skip to content

Commit 1471df7

Browse files
committed
Add tests for more specializations
1 parent ecaa6dd commit 1471df7

File tree

1 file changed

+141
-47
lines changed

1 file changed

+141
-47
lines changed

Lib/test/test_type_cache.py

Lines changed: 141 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -76,72 +76,166 @@ class C:
7676
new_version = type_get_version(C)
7777
self.assertEqual(new_version, orig_version + 5)
7878

79-
def test_specialization_user_type_no_tag_overflow(self):
79+
_clear_type_cache()
80+
81+
82+
@support.cpython_only
83+
class TypeCacheWithSpecializationTests(unittest.TestCase):
84+
def tearDown(self):
85+
_clear_type_cache()
86+
87+
def _assign_and_check_valid_version(self, user_type):
88+
type_modified(user_type)
89+
type_assign_version(user_type)
90+
self.assertNotEqual(type_get_version(user_type), 0)
91+
92+
def _assign_and_check_version_0(self, user_type):
93+
type_modified(user_type)
94+
type_assign_specific_version_unsafe(user_type, 0)
95+
self.assertEqual(type_get_version(user_type), 0)
96+
97+
def _all_opnames(self, func):
98+
return set(instr.opname for instr in dis.Bytecode(func, adaptive=True))
99+
100+
def _check_specialization(self, func, arg, opname, *, should_specialize):
101+
self.assertIn(opname, self._all_opnames(func))
102+
103+
for _ in range(100):
104+
func(arg)
105+
106+
if should_specialize:
107+
self.assertNotIn(opname, self._all_opnames(func))
108+
else:
109+
self.assertIn(opname, self._all_opnames(func))
110+
111+
def test_class_load_attr_specialization_user_type(self):
80112
class A:
81113
def foo(self):
82114
pass
83115

84-
class B:
85-
def foo(self):
86-
pass
116+
self._assign_and_check_valid_version(A)
117+
118+
def load_foo_1(type_):
119+
type_.foo
120+
121+
self._check_specialization(load_foo_1, A, "LOAD_ATTR", should_specialize=True)
122+
del load_foo_1
87123

88-
type_modified(A)
89-
type_assign_version(A)
90-
type_modified(B)
91-
type_assign_version(B)
92-
self.assertNotEqual(type_get_version(A), 0)
93-
self.assertNotEqual(type_get_version(B), 0)
94-
self.assertNotEqual(type_get_version(A), type_get_version(B))
124+
self._assign_and_check_version_0(A)
95125

96-
def get_foo(type_):
126+
def load_foo_2(type_):
97127
return type_.foo
98128

99-
self.assertIn(
100-
"LOAD_ATTR",
101-
[instr.opname for instr in dis.Bytecode(get_foo, adaptive=True)],
102-
)
129+
self._check_specialization(load_foo_2, A, "LOAD_ATTR", should_specialize=False)
103130

104-
get_foo(A)
105-
get_foo(A)
131+
def test_class_load_attr_specialization_static_type(self):
132+
self._assign_and_check_valid_version(str)
133+
self._assign_and_check_valid_version(bytes)
106134

107-
# check that specialization has occurred
108-
self.assertNotIn(
109-
"LOAD_ATTR",
110-
[instr.opname for instr in dis.Bytecode(get_foo, adaptive=True)],
111-
)
135+
def get_capitalize_1(type_):
136+
return type_.capitalize
112137

113-
def test_specialization_user_type_tag_overflow(self):
114-
class A:
115-
def foo(self):
116-
pass
138+
self._check_specialization(get_capitalize_1, str, "LOAD_ATTR", should_specialize=True)
139+
self.assertEqual(get_capitalize_1(str)('hello'), 'Hello')
140+
self.assertEqual(get_capitalize_1(bytes)(b'hello'), b'Hello')
141+
del get_capitalize_1
142+
143+
# Permanently overflow the static type version counter, and force str and bytes
144+
# to have tp_version_tag == 0
145+
for _ in range(2**16):
146+
type_modified(str)
147+
type_assign_version(str)
148+
type_modified(bytes)
149+
type_assign_version(bytes)
150+
151+
self.assertEqual(type_get_version(str), 0)
152+
self.assertEqual(type_get_version(bytes), 0)
153+
154+
def get_capitalize_2(type_):
155+
return type_.capitalize
156+
157+
self._check_specialization(get_capitalize_2, str, "LOAD_ATTR", should_specialize=False)
158+
self.assertEqual(get_capitalize_2(str)('hello'), 'Hello')
159+
self.assertEqual(get_capitalize_2(bytes)(b'hello'), b'Hello')
160+
161+
def test_property_load_attr_specialization_user_type(self):
162+
class G:
163+
@property
164+
def x(self):
165+
return 9
166+
167+
self._assign_and_check_valid_version(G)
117168

169+
def load_x_1(instance):
170+
instance.x
171+
172+
self._check_specialization(load_x_1, G(), "LOAD_ATTR", should_specialize=True)
173+
del load_x_1
174+
175+
self._assign_and_check_version_0(G)
176+
177+
def load_x_2(instance):
178+
instance.x
179+
180+
self._check_specialization(load_x_2, G(), "LOAD_ATTR", should_specialize=False)
181+
182+
def test_store_attr_specialization_user_type(self):
118183
class B:
119-
def foo(self):
184+
__slots__ = ("bar",)
185+
186+
self._assign_and_check_valid_version(B)
187+
188+
def store_bar_1(type_):
189+
type_.bar = 10
190+
191+
self._check_specialization(store_bar_1, B(), "STORE_ATTR", should_specialize=True)
192+
del store_bar_1
193+
194+
self._assign_and_check_version_0(B)
195+
196+
def store_bar_2(type_):
197+
type_.bar = 10
198+
199+
self._check_specialization(store_bar_2, B(), "STORE_ATTR", should_specialize=False)
200+
201+
def test_class_call_specialization_user_type(self):
202+
class F:
203+
def __init__(self):
120204
pass
121205

122-
type_modified(A)
123-
type_assign_specific_version_unsafe(A, 0)
124-
type_modified(B)
125-
type_assign_specific_version_unsafe(B, 0)
126-
self.assertEqual(type_get_version(A), 0)
127-
self.assertEqual(type_get_version(B), 0)
206+
self._assign_and_check_valid_version(F)
128207

129-
def get_foo(type_):
130-
return type_.foo
208+
def call_class_1(type_):
209+
type_()
210+
211+
self._check_specialization(call_class_1, F, "CALL", should_specialize=True)
212+
del call_class_1
213+
214+
self._assign_and_check_version_0(F)
215+
216+
def call_class_2(type_):
217+
type_()
218+
219+
self._check_specialization(call_class_2, F, "CALL", should_specialize=False)
220+
221+
def test_to_bool_specialization_user_type(self):
222+
class H:
223+
pass
224+
225+
self._assign_and_check_valid_version(H)
226+
227+
def to_bool_1(instance):
228+
not instance
229+
230+
self._check_specialization(to_bool_1, H(), "TO_BOOL", should_specialize=True)
231+
del to_bool_1
131232

132-
self.assertIn(
133-
"LOAD_ATTR",
134-
[instr.opname for instr in dis.Bytecode(get_foo, adaptive=True)],
135-
)
233+
self._assign_and_check_version_0(H)
136234

137-
get_foo(A)
138-
get_foo(A)
235+
def to_bool_2(instance):
236+
not instance
139237

140-
# check that specialization has not occurred due to version tag == 0
141-
self.assertIn(
142-
"LOAD_ATTR",
143-
[instr.opname for instr in dis.Bytecode(get_foo, adaptive=True)],
144-
)
238+
self._check_specialization(to_bool_2, H(), "TO_BOOL", should_specialize=False)
145239

146240

147241
if __name__ == "__main__":

0 commit comments

Comments
 (0)