@@ -76,72 +76,166 @@ class C:
76
76
new_version = type_get_version (C )
77
77
self .assertEqual (new_version , orig_version + 5 )
78
78
79
- def test_specialization_user_type_no_tag_overflow (self ):
79
+ _clear_type_cache ()
80
+
81
+
82
+ @support .cpython_only
83
+ class TypeCacheWithSpecializationTests (unittest .TestCase ):
84
+ def tearDown (self ):
85
+ _clear_type_cache ()
86
+
87
+ def _assign_and_check_valid_version (self , user_type ):
88
+ type_modified (user_type )
89
+ type_assign_version (user_type )
90
+ self .assertNotEqual (type_get_version (user_type ), 0 )
91
+
92
+ def _assign_and_check_version_0 (self , user_type ):
93
+ type_modified (user_type )
94
+ type_assign_specific_version_unsafe (user_type , 0 )
95
+ self .assertEqual (type_get_version (user_type ), 0 )
96
+
97
+ def _all_opnames (self , func ):
98
+ return set (instr .opname for instr in dis .Bytecode (func , adaptive = True ))
99
+
100
+ def _check_specialization (self , func , arg , opname , * , should_specialize ):
101
+ self .assertIn (opname , self ._all_opnames (func ))
102
+
103
+ for _ in range (100 ):
104
+ func (arg )
105
+
106
+ if should_specialize :
107
+ self .assertNotIn (opname , self ._all_opnames (func ))
108
+ else :
109
+ self .assertIn (opname , self ._all_opnames (func ))
110
+
111
+ def test_class_load_attr_specialization_user_type (self ):
80
112
class A :
81
113
def foo (self ):
82
114
pass
83
115
84
- class B :
85
- def foo (self ):
86
- pass
116
+ self ._assign_and_check_valid_version (A )
117
+
118
+ def load_foo_1 (type_ ):
119
+ type_ .foo
120
+
121
+ self ._check_specialization (load_foo_1 , A , "LOAD_ATTR" , should_specialize = True )
122
+ del load_foo_1
87
123
88
- type_modified (A )
89
- type_assign_version (A )
90
- type_modified (B )
91
- type_assign_version (B )
92
- self .assertNotEqual (type_get_version (A ), 0 )
93
- self .assertNotEqual (type_get_version (B ), 0 )
94
- self .assertNotEqual (type_get_version (A ), type_get_version (B ))
124
+ self ._assign_and_check_version_0 (A )
95
125
96
- def get_foo (type_ ):
126
+ def load_foo_2 (type_ ):
97
127
return type_ .foo
98
128
99
- self .assertIn (
100
- "LOAD_ATTR" ,
101
- [instr .opname for instr in dis .Bytecode (get_foo , adaptive = True )],
102
- )
129
+ self ._check_specialization (load_foo_2 , A , "LOAD_ATTR" , should_specialize = False )
103
130
104
- get_foo (A )
105
- get_foo (A )
131
+ def test_class_load_attr_specialization_static_type (self ):
132
+ self ._assign_and_check_valid_version (str )
133
+ self ._assign_and_check_valid_version (bytes )
106
134
107
- # check that specialization has occurred
108
- self .assertNotIn (
109
- "LOAD_ATTR" ,
110
- [instr .opname for instr in dis .Bytecode (get_foo , adaptive = True )],
111
- )
135
+ def get_capitalize_1 (type_ ):
136
+ return type_ .capitalize
112
137
113
- def test_specialization_user_type_tag_overflow (self ):
114
- class A :
115
- def foo (self ):
116
- pass
138
+ self ._check_specialization (get_capitalize_1 , str , "LOAD_ATTR" , should_specialize = True )
139
+ self .assertEqual (get_capitalize_1 (str )('hello' ), 'Hello' )
140
+ self .assertEqual (get_capitalize_1 (bytes )(b'hello' ), b'Hello' )
141
+ del get_capitalize_1
142
+
143
+ # Permanently overflow the static type version counter, and force str and bytes
144
+ # to have tp_version_tag == 0
145
+ for _ in range (2 ** 16 ):
146
+ type_modified (str )
147
+ type_assign_version (str )
148
+ type_modified (bytes )
149
+ type_assign_version (bytes )
150
+
151
+ self .assertEqual (type_get_version (str ), 0 )
152
+ self .assertEqual (type_get_version (bytes ), 0 )
153
+
154
+ def get_capitalize_2 (type_ ):
155
+ return type_ .capitalize
156
+
157
+ self ._check_specialization (get_capitalize_2 , str , "LOAD_ATTR" , should_specialize = False )
158
+ self .assertEqual (get_capitalize_2 (str )('hello' ), 'Hello' )
159
+ self .assertEqual (get_capitalize_2 (bytes )(b'hello' ), b'Hello' )
160
+
161
+ def test_property_load_attr_specialization_user_type (self ):
162
+ class G :
163
+ @property
164
+ def x (self ):
165
+ return 9
166
+
167
+ self ._assign_and_check_valid_version (G )
117
168
169
+ def load_x_1 (instance ):
170
+ instance .x
171
+
172
+ self ._check_specialization (load_x_1 , G (), "LOAD_ATTR" , should_specialize = True )
173
+ del load_x_1
174
+
175
+ self ._assign_and_check_version_0 (G )
176
+
177
+ def load_x_2 (instance ):
178
+ instance .x
179
+
180
+ self ._check_specialization (load_x_2 , G (), "LOAD_ATTR" , should_specialize = False )
181
+
182
+ def test_store_attr_specialization_user_type (self ):
118
183
class B :
119
- def foo (self ):
184
+ __slots__ = ("bar" ,)
185
+
186
+ self ._assign_and_check_valid_version (B )
187
+
188
+ def store_bar_1 (type_ ):
189
+ type_ .bar = 10
190
+
191
+ self ._check_specialization (store_bar_1 , B (), "STORE_ATTR" , should_specialize = True )
192
+ del store_bar_1
193
+
194
+ self ._assign_and_check_version_0 (B )
195
+
196
+ def store_bar_2 (type_ ):
197
+ type_ .bar = 10
198
+
199
+ self ._check_specialization (store_bar_2 , B (), "STORE_ATTR" , should_specialize = False )
200
+
201
+ def test_class_call_specialization_user_type (self ):
202
+ class F :
203
+ def __init__ (self ):
120
204
pass
121
205
122
- type_modified (A )
123
- type_assign_specific_version_unsafe (A , 0 )
124
- type_modified (B )
125
- type_assign_specific_version_unsafe (B , 0 )
126
- self .assertEqual (type_get_version (A ), 0 )
127
- self .assertEqual (type_get_version (B ), 0 )
206
+ self ._assign_and_check_valid_version (F )
128
207
129
- def get_foo (type_ ):
130
- return type_ .foo
208
+ def call_class_1 (type_ ):
209
+ type_ ()
210
+
211
+ self ._check_specialization (call_class_1 , F , "CALL" , should_specialize = True )
212
+ del call_class_1
213
+
214
+ self ._assign_and_check_version_0 (F )
215
+
216
+ def call_class_2 (type_ ):
217
+ type_ ()
218
+
219
+ self ._check_specialization (call_class_2 , F , "CALL" , should_specialize = False )
220
+
221
+ def test_to_bool_specialization_user_type (self ):
222
+ class H :
223
+ pass
224
+
225
+ self ._assign_and_check_valid_version (H )
226
+
227
+ def to_bool_1 (instance ):
228
+ not instance
229
+
230
+ self ._check_specialization (to_bool_1 , H (), "TO_BOOL" , should_specialize = True )
231
+ del to_bool_1
131
232
132
- self .assertIn (
133
- "LOAD_ATTR" ,
134
- [instr .opname for instr in dis .Bytecode (get_foo , adaptive = True )],
135
- )
233
+ self ._assign_and_check_version_0 (H )
136
234
137
- get_foo ( A )
138
- get_foo ( A )
235
+ def to_bool_2 ( instance ):
236
+ not instance
139
237
140
- # check that specialization has not occurred due to version tag == 0
141
- self .assertIn (
142
- "LOAD_ATTR" ,
143
- [instr .opname for instr in dis .Bytecode (get_foo , adaptive = True )],
144
- )
238
+ self ._check_specialization (to_bool_2 , H (), "TO_BOOL" , should_specialize = False )
145
239
146
240
147
241
if __name__ == "__main__" :
0 commit comments