1
1
"""Utilities for checking that internal ir is valid and consistent."""
2
- from typing import List , Union
2
+ from typing import List , Union , Set , Tuple
3
3
from mypyc .ir .pprint import format_func
4
4
from mypyc .ir .ops import (
5
5
OpVisitor , BasicBlock , Op , ControlOp , Goto , Branch , Return , Unreachable ,
6
6
Assign , AssignMulti , LoadErrorValue , LoadLiteral , GetAttr , SetAttr , LoadStatic ,
7
7
InitStatic , TupleGet , TupleSet , IncRef , DecRef , Call , MethodCall , Cast ,
8
8
Box , Unbox , RaiseStandardError , CallC , Truncate , LoadGlobal , IntOp , ComparisonOp ,
9
- LoadMem , SetMem , GetElementPtr , LoadAddress , KeepAlive
9
+ LoadMem , SetMem , GetElementPtr , LoadAddress , KeepAlive , Register , Integer ,
10
+ BaseAssign
10
11
)
11
- from mypyc .ir .func_ir import FuncIR
12
+ from mypyc .ir .rtypes import (
13
+ RType , RPrimitive , RUnion , is_object_rprimitive , RInstance , RArray ,
14
+ int_rprimitive , list_rprimitive , dict_rprimitive , set_rprimitive ,
15
+ range_rprimitive , str_rprimitive , bytes_rprimitive , tuple_rprimitive
16
+ )
17
+ from mypyc .ir .func_ir import FuncIR , FUNC_STATICMETHOD
12
18
13
19
14
20
class FnError (object ):
@@ -17,8 +23,11 @@ def __init__(self, source: Union[Op, BasicBlock], desc: str) -> None:
17
23
self .desc = desc
18
24
19
25
def __eq__ (self , other : object ) -> bool :
20
- return isinstance (other , FnError ) and self .source == other .source and \
21
- self .desc == other .desc
26
+ return (
27
+ isinstance (other , FnError )
28
+ and self .source == other .source
29
+ and self .desc == other .desc
30
+ )
22
31
23
32
def __repr__ (self ) -> str :
24
33
return f"FnError(source={ self .source } , desc={ self .desc } )"
@@ -28,19 +37,44 @@ def check_func_ir(fn: FuncIR) -> List[FnError]:
28
37
"""Applies validations to a given function ir and returns a list of errors found."""
29
38
errors = []
30
39
40
+ op_set = set ()
41
+
31
42
for block in fn .blocks :
32
43
if not block .terminated :
33
- errors .append (FnError (
34
- source = block .ops [- 1 ] if block .ops else block ,
35
- desc = "Block not terminated" ,
36
- ))
44
+ errors .append (
45
+ FnError (
46
+ source = block .ops [- 1 ] if block .ops else block ,
47
+ desc = "Block not terminated" ,
48
+ )
49
+ )
50
+ for op in block .ops [:- 1 ]:
51
+ if isinstance (op , ControlOp ):
52
+ errors .append (
53
+ FnError (
54
+ source = op ,
55
+ desc = "Block has operations after control op" ,
56
+ )
57
+ )
58
+
59
+ if op in op_set :
60
+ errors .append (
61
+ FnError (
62
+ source = op ,
63
+ desc = "Func has a duplicate op" ,
64
+ )
65
+ )
66
+ op_set .add (op )
67
+
68
+ errors .extend (check_op_sources_valid (fn ))
69
+ if errors :
70
+ return errors
37
71
38
72
op_checker = OpChecker (fn )
39
73
for block in fn .blocks :
40
74
for op in block .ops :
41
75
op .accept (op_checker )
42
76
43
- return errors + op_checker .errors
77
+ return op_checker .errors
44
78
45
79
46
80
class IrCheckException (Exception ):
@@ -50,11 +84,90 @@ class IrCheckException(Exception):
50
84
def assert_func_ir_valid (fn : FuncIR ) -> None :
51
85
errors = check_func_ir (fn )
52
86
if errors :
53
- raise IrCheckException ("Internal error: Generated invalid IR: \n " + "\n " .join (
54
- format_func (fn , [(e .source , e .desc ) for e in errors ])),
87
+ raise IrCheckException (
88
+ "Internal error: Generated invalid IR: \n "
89
+ + "\n " .join (format_func (fn , [(e .source , e .desc ) for e in errors ])),
55
90
)
56
91
57
92
93
+ def check_op_sources_valid (fn : FuncIR ) -> List [FnError ]:
94
+ errors = []
95
+ valid_ops : Set [Op ] = set ()
96
+ valid_registers : Set [Register ] = set ()
97
+
98
+ for block in fn .blocks :
99
+ valid_ops .update (block .ops )
100
+
101
+ valid_registers .update (
102
+ [op .dest for op in block .ops if isinstance (op , BaseAssign )]
103
+ )
104
+
105
+ valid_registers .update (fn .arg_regs )
106
+
107
+ for block in fn .blocks :
108
+ for op in block .ops :
109
+ for source in op .sources ():
110
+ if isinstance (source , Integer ):
111
+ pass
112
+ elif isinstance (source , Op ):
113
+ if source not in valid_ops :
114
+ errors .append (
115
+ FnError (
116
+ source = op ,
117
+ desc = f"Invalid op reference to op of type { type (source ).__name__ } " ,
118
+ )
119
+ )
120
+ elif isinstance (source , Register ):
121
+ if source not in valid_registers :
122
+ errors .append (
123
+ FnError (
124
+ source = op ,
125
+ desc = f"Invalid op reference to register { source .name } " ,
126
+ )
127
+ )
128
+
129
+ return errors
130
+
131
+
132
+ disjoint_types = set (
133
+ [
134
+ int_rprimitive .name ,
135
+ bytes_rprimitive .name ,
136
+ str_rprimitive .name ,
137
+ dict_rprimitive .name ,
138
+ list_rprimitive .name ,
139
+ set_rprimitive .name ,
140
+ tuple_rprimitive .name ,
141
+ range_rprimitive .name ,
142
+ ]
143
+ )
144
+
145
+
146
+ def can_coerce_to (src : RType , dest : RType ) -> bool :
147
+ """Check if src can be assigned to dest_rtype.
148
+
149
+ Currently okay to have false positives.
150
+ """
151
+ if isinstance (dest , RUnion ):
152
+ return any (can_coerce_to (src , d ) for d in dest .items )
153
+
154
+ if isinstance (dest , RPrimitive ):
155
+ if isinstance (src , RPrimitive ):
156
+ # If either src or dest is a disjoint type, then they must both be.
157
+ if src .name in disjoint_types and dest .name in disjoint_types :
158
+ return src .name == dest .name
159
+ return src .size == dest .size
160
+ if isinstance (src , RInstance ):
161
+ return is_object_rprimitive (dest )
162
+ if isinstance (src , RUnion ):
163
+ # IR doesn't have the ability to narrow unions based on
164
+ # control flow, so cannot be a strict all() here.
165
+ return any (can_coerce_to (s , dest ) for s in src .items )
166
+ return False
167
+
168
+ return True
169
+
170
+
58
171
class OpChecker (OpVisitor [None ]):
59
172
def __init__ (self , parent_fn : FuncIR ) -> None :
60
173
self .parent_fn = parent_fn
@@ -66,7 +179,16 @@ def fail(self, source: Op, desc: str) -> None:
66
179
def check_control_op_targets (self , op : ControlOp ) -> None :
67
180
for target in op .targets ():
68
181
if target not in self .parent_fn .blocks :
69
- self .fail (source = op , desc = f"Invalid control operation target: { target .label } " )
182
+ self .fail (
183
+ source = op , desc = f"Invalid control operation target: { target .label } "
184
+ )
185
+
186
+ def check_type_coercion (self , op : Op , src : RType , dest : RType ) -> None :
187
+ if not can_coerce_to (src , dest ):
188
+ self .fail (
189
+ source = op ,
190
+ desc = f"Cannot coerce source type { src .name } to dest type { dest .name } " ,
191
+ )
70
192
71
193
def visit_goto (self , op : Goto ) -> None :
72
194
self .check_control_op_targets (op )
@@ -75,52 +197,118 @@ def visit_branch(self, op: Branch) -> None:
75
197
self .check_control_op_targets (op )
76
198
77
199
def visit_return (self , op : Return ) -> None :
78
- pass
200
+ self . check_type_coercion ( op , op . value . type , self . parent_fn . decl . sig . ret_type )
79
201
80
202
def visit_unreachable (self , op : Unreachable ) -> None :
203
+ # Unreachables are checked at a higher level since validation
204
+ # requires access to the entire basic block.
81
205
pass
82
206
83
207
def visit_assign (self , op : Assign ) -> None :
84
- pass
208
+ self . check_type_coercion ( op , op . src . type , op . dest . type )
85
209
86
210
def visit_assign_multi (self , op : AssignMulti ) -> None :
87
- pass
211
+ for src in op .src :
212
+ assert isinstance (op .dest .type , RArray )
213
+ self .check_type_coercion (op , src .type , op .dest .type .item_type )
88
214
89
215
def visit_load_error_value (self , op : LoadErrorValue ) -> None :
216
+ # Currently it is assumed that all types have an error value.
217
+ # Once this is fixed we can validate that the rtype here actually
218
+ # has an error value.
90
219
pass
91
220
221
+ def check_tuple_items_valid_literals (
222
+ self , op : LoadLiteral , t : Tuple [object , ...]
223
+ ) -> None :
224
+ for x in t :
225
+ if x is not None and not isinstance (
226
+ x , (str , bytes , bool , int , float , complex , tuple )
227
+ ):
228
+ self .fail (op , f"Invalid type for item of tuple literal: { type (x )} )" )
229
+ if isinstance (x , tuple ):
230
+ self .check_tuple_items_valid_literals (op , x )
231
+
92
232
def visit_load_literal (self , op : LoadLiteral ) -> None :
93
- pass
233
+ expected_type = None
234
+ if op .value is None :
235
+ expected_type = "builtins.object"
236
+ elif isinstance (op .value , int ):
237
+ expected_type = "builtins.int"
238
+ elif isinstance (op .value , str ):
239
+ expected_type = "builtins.str"
240
+ elif isinstance (op .value , bytes ):
241
+ expected_type = "builtins.bytes"
242
+ elif isinstance (op .value , bool ):
243
+ expected_type = "builtins.object"
244
+ elif isinstance (op .value , float ):
245
+ expected_type = "builtins.float"
246
+ elif isinstance (op .value , complex ):
247
+ expected_type = "builtins.object"
248
+ elif isinstance (op .value , tuple ):
249
+ expected_type = "builtins.tuple"
250
+ self .check_tuple_items_valid_literals (op , op .value )
251
+
252
+ assert expected_type is not None , "Missed a case for LoadLiteral check"
253
+
254
+ if op .type .name not in [expected_type , "builtins.object" ]:
255
+ self .fail (
256
+ op ,
257
+ f"Invalid literal value for type: value has "
258
+ f"type { expected_type } , but op has type { op .type .name } " ,
259
+ )
94
260
95
261
def visit_get_attr (self , op : GetAttr ) -> None :
262
+ # Nothing to do.
96
263
pass
97
264
98
265
def visit_set_attr (self , op : SetAttr ) -> None :
266
+ # Nothing to do.
99
267
pass
100
268
269
+ # Static operations cannot be checked at the function level.
101
270
def visit_load_static (self , op : LoadStatic ) -> None :
102
271
pass
103
272
104
273
def visit_init_static (self , op : InitStatic ) -> None :
105
274
pass
106
275
107
276
def visit_tuple_get (self , op : TupleGet ) -> None :
277
+ # Nothing to do.
108
278
pass
109
279
110
280
def visit_tuple_set (self , op : TupleSet ) -> None :
281
+ # Nothing to do.
111
282
pass
112
283
113
284
def visit_inc_ref (self , op : IncRef ) -> None :
285
+ # Nothing to do.
114
286
pass
115
287
116
288
def visit_dec_ref (self , op : DecRef ) -> None :
289
+ # Nothing to do.
117
290
pass
118
291
119
292
def visit_call (self , op : Call ) -> None :
120
- pass
293
+ # Length is checked in constructor, and return type is set
294
+ # in a way that can't be incorrect
295
+ for arg_value , arg_runtime in zip (op .args , op .fn .sig .args ):
296
+ self .check_type_coercion (op , arg_value .type , arg_runtime .type )
121
297
122
298
def visit_method_call (self , op : MethodCall ) -> None :
123
- pass
299
+ # Similar to above, but we must look up method first.
300
+ method_decl = op .receiver_type .class_ir .method_decl (op .method )
301
+ if method_decl .kind == FUNC_STATICMETHOD :
302
+ decl_index = 0
303
+ else :
304
+ decl_index = 1
305
+
306
+ if len (op .args ) + decl_index != len (method_decl .sig .args ):
307
+ self .fail (op , "Incorrect number of args for method call." )
308
+
309
+ # Skip the receiver argument (self)
310
+ for arg_value , arg_runtime in zip (op .args , method_decl .sig .args [decl_index :]):
311
+ self .check_type_coercion (op , arg_value .type , arg_runtime .type )
124
312
125
313
def visit_cast (self , op : Cast ) -> None :
126
314
pass
0 commit comments