Skip to content

Commit c7365ef

Browse files
authored
[mypyc] Implement additional ircheck checks (#12191)
* Implement more checks in ircheck: check op/register validity, and check type coercions for return and assign.
* Add checks for call operations, control ops, and literal loads.
* Add a check for duplicate ops. In particular, this will catch the case where builder.add() is called twice, which causes very weird bogus IR.
1 parent d02db50 commit c7365ef

File tree

5 files changed

+360
-47
lines changed

5 files changed

+360
-47
lines changed

mypyc/analysis/ircheck.py

Lines changed: 207 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
"""Utilities for checking that internal ir is valid and consistent."""
2-
from typing import List, Union
2+
from typing import List, Union, Set, Tuple
33
from mypyc.ir.pprint import format_func
44
from mypyc.ir.ops import (
55
OpVisitor, BasicBlock, Op, ControlOp, Goto, Branch, Return, Unreachable,
66
Assign, AssignMulti, LoadErrorValue, LoadLiteral, GetAttr, SetAttr, LoadStatic,
77
InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast,
88
Box, Unbox, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp,
9-
LoadMem, SetMem, GetElementPtr, LoadAddress, KeepAlive
9+
LoadMem, SetMem, GetElementPtr, LoadAddress, KeepAlive, Register, Integer,
10+
BaseAssign
1011
)
11-
from mypyc.ir.func_ir import FuncIR
12+
from mypyc.ir.rtypes import (
13+
RType, RPrimitive, RUnion, is_object_rprimitive, RInstance, RArray,
14+
int_rprimitive, list_rprimitive, dict_rprimitive, set_rprimitive,
15+
range_rprimitive, str_rprimitive, bytes_rprimitive, tuple_rprimitive
16+
)
17+
from mypyc.ir.func_ir import FuncIR, FUNC_STATICMETHOD
1218

1319

1420
class FnError(object):
@@ -17,8 +23,11 @@ def __init__(self, source: Union[Op, BasicBlock], desc: str) -> None:
1723
self.desc = desc
1824

1925
def __eq__(self, other: object) -> bool:
20-
return isinstance(other, FnError) and self.source == other.source and \
21-
self.desc == other.desc
26+
return (
27+
isinstance(other, FnError)
28+
and self.source == other.source
29+
and self.desc == other.desc
30+
)
2231

2332
def __repr__(self) -> str:
2433
return f"FnError(source={self.source}, desc={self.desc})"
@@ -28,19 +37,44 @@ def check_func_ir(fn: FuncIR) -> List[FnError]:
2837
"""Applies validations to a given function ir and returns a list of errors found."""
2938
errors = []
3039

40+
op_set = set()
41+
3142
for block in fn.blocks:
3243
if not block.terminated:
33-
errors.append(FnError(
34-
source=block.ops[-1] if block.ops else block,
35-
desc="Block not terminated",
36-
))
44+
errors.append(
45+
FnError(
46+
source=block.ops[-1] if block.ops else block,
47+
desc="Block not terminated",
48+
)
49+
)
50+
for op in block.ops[:-1]:
51+
if isinstance(op, ControlOp):
52+
errors.append(
53+
FnError(
54+
source=op,
55+
desc="Block has operations after control op",
56+
)
57+
)
58+
59+
if op in op_set:
60+
errors.append(
61+
FnError(
62+
source=op,
63+
desc="Func has a duplicate op",
64+
)
65+
)
66+
op_set.add(op)
67+
68+
errors.extend(check_op_sources_valid(fn))
69+
if errors:
70+
return errors
3771

3872
op_checker = OpChecker(fn)
3973
for block in fn.blocks:
4074
for op in block.ops:
4175
op.accept(op_checker)
4276

43-
return errors + op_checker.errors
77+
return op_checker.errors
4478

4579

4680
class IrCheckException(Exception):
@@ -50,11 +84,90 @@ class IrCheckException(Exception):
5084
def assert_func_ir_valid(fn: FuncIR) -> None:
    """Raise IrCheckException if check_func_ir reports any error for fn.

    The exception message is a fixed header followed by the lines that
    format_func produces for fn when given each error's (source, desc)
    pair. Returns None when no errors are reported.
    """
    errors = check_func_ir(fn)
    if not errors:
        return
    annotated_lines = format_func(fn, [(e.source, e.desc) for e in errors])
    raise IrCheckException(
        "Internal error: Generated invalid IR: \n" + "\n".join(annotated_lines)
    )
5691

5792

93+
def check_op_sources_valid(fn: FuncIR) -> List[FnError]:
    """Check that every source of every op in fn refers to something fn defines.

    An Op source must itself appear in one of fn's blocks. A Register source
    must be either one of fn.arg_regs or the dest of some BaseAssign op in fn.
    Integer sources are accepted without any lookup.

    Returns a list with one FnError per offending source, attributed to the
    op that uses it.
    """
    errors = []
    valid_ops: Set[Op] = set()
    valid_registers: Set[Register] = set(fn.arg_regs)

    # First pass: collect everything that may legitimately be referenced.
    for block in fn.blocks:
        valid_ops.update(block.ops)
        valid_registers.update(
            op.dest for op in block.ops if isinstance(op, BaseAssign)
        )

    # Second pass: verify each source against the collected sets.
    for block in fn.blocks:
        for op in block.ops:
            for source in op.sources():
                if isinstance(source, Integer):
                    continue
                if isinstance(source, Op):
                    if source not in valid_ops:
                        errors.append(
                            FnError(
                                source=op,
                                desc=f"Invalid op reference to op of type {type(source).__name__}",
                            )
                        )
                elif isinstance(source, Register):
                    if source not in valid_registers:
                        errors.append(
                            FnError(
                                source=op,
                                desc=f"Invalid op reference to register {source.name}",
                            )
                        )

    return errors
130+
131+
132+
# Names of primitive types used by can_coerce_to: when both the source and
# the destination primitive are in this set, their names must be identical.
disjoint_types = {
    int_rprimitive.name,
    bytes_rprimitive.name,
    str_rprimitive.name,
    dict_rprimitive.name,
    list_rprimitive.name,
    set_rprimitive.name,
    tuple_rprimitive.name,
    range_rprimitive.name,
}
144+
145+
146+
def can_coerce_to(src: RType, dest: RType) -> bool:
    """Check if a value of type src can be assigned to a location of type dest.

    Currently okay to have false positives.

    Rules, in order:
      * dest is an RUnion: src must be coercible to at least one item.
      * dest is anything other than an RUnion or RPrimitive: accepted
        without inspection.
      * dest is an RPrimitive: see the per-source-kind cases below.
    """
    if isinstance(dest, RUnion):
        return any(can_coerce_to(src, item) for item in dest.items)

    if not isinstance(dest, RPrimitive):
        return True

    if isinstance(src, RPrimitive):
        # Names are compared only when both types are in disjoint_types;
        # in every other case the two primitives just need the same size.
        if src.name in disjoint_types and dest.name in disjoint_types:
            return src.name == dest.name
        return src.size == dest.size
    if isinstance(src, RInstance):
        return is_object_rprimitive(dest)
    if isinstance(src, RUnion):
        # IR doesn't have the ability to narrow unions based on
        # control flow, so cannot be a strict all() here.
        return any(can_coerce_to(item, dest) for item in src.items)
    return False
169+
170+
58171
class OpChecker(OpVisitor[None]):
59172
def __init__(self, parent_fn: FuncIR) -> None:
60173
self.parent_fn = parent_fn
@@ -66,7 +179,16 @@ def fail(self, source: Op, desc: str) -> None:
66179
def check_control_op_targets(self, op: ControlOp) -> None:
    """Report each target of op that is not one of the parent function's blocks."""
    for target in op.targets():
        if target in self.parent_fn.blocks:
            continue
        self.fail(
            source=op, desc=f"Invalid control operation target: {target.label}"
        )
185+
186+
def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
    """Record a failure on op when can_coerce_to(src, dest) is false."""
    if can_coerce_to(src, dest):
        return
    self.fail(
        source=op,
        desc=f"Cannot coerce source type {src.name} to dest type {dest.name}",
    )
70192

71193
def visit_goto(self, op: Goto) -> None:
    """Validate the Goto's targets via check_control_op_targets."""
    self.check_control_op_targets(op)
@@ -75,52 +197,118 @@ def visit_branch(self, op: Branch) -> None:
75197
self.check_control_op_targets(op)
76198

77199
def visit_return(self, op: Return) -> None:
    """Check the returned value's type against the function's declared return type."""
    declared_type = self.parent_fn.decl.sig.ret_type
    self.check_type_coercion(op, op.value.type, declared_type)
79201

80202
def visit_unreachable(self, op: Unreachable) -> None:
    """Perform no per-op validation for Unreachable."""
    # Unreachables are checked at a higher level since validation
    # requires access to the entire basic block.
    pass
82206

83207
def visit_assign(self, op: Assign) -> None:
    """Check the source value's type against the destination register's type."""
    self.check_type_coercion(op, op.src.type, op.dest.type)
85209

86210
def visit_assign_multi(self, op: AssignMulti) -> None:
    """Check every source value against the item type of the destination array."""
    for src in op.src:
        # Narrows op.dest.type so that item_type can be read below.
        assert isinstance(op.dest.type, RArray)
        self.check_type_coercion(op, src.type, op.dest.type.item_type)
88214

89215
def visit_load_error_value(self, op: LoadErrorValue) -> None:
    """Perform no validation for LoadErrorValue."""
    # Currently it is assumed that all types have an error value.
    # Once this is fixed we can validate that the rtype here actually
    # has an error value.
    pass
91220

221+
def check_tuple_items_valid_literals(
    self, op: LoadLiteral, t: Tuple[object, ...]
) -> None:
    """Report every item of t, at any nesting depth, that is not a valid literal.

    Valid items are None and instances of str, bytes, bool, int, float,
    complex and tuple; nested tuples are checked recursively. Each invalid
    item produces one failure attributed to op.
    """
    allowed = (str, bytes, bool, int, float, complex, tuple)
    for item in t:
        if isinstance(item, tuple):
            self.check_tuple_items_valid_literals(op, item)
        elif item is not None and not isinstance(item, allowed):
            # The original message ended with a stray ")".
            self.fail(op, f"Invalid type for item of tuple literal: {type(item)}")
231+
92232
def visit_load_literal(self, op: LoadLiteral) -> None:
    """Check that the op's declared type is compatible with its literal value.

    The expected type name is derived from the Python type of op.value. The
    op passes if op.type.name is that name or "builtins.object"; otherwise a
    failure is recorded. Items of tuple values are validated through
    check_tuple_items_valid_literals.
    """
    value = op.value
    expected_type = None
    if value is None:
        expected_type = "builtins.object"
    elif isinstance(value, bool):
        # bool is a subclass of int, so this test must precede the int test;
        # in the original order this branch could never be reached.
        expected_type = "builtins.object"
    elif isinstance(value, int):
        expected_type = "builtins.int"
    elif isinstance(value, str):
        expected_type = "builtins.str"
    elif isinstance(value, bytes):
        expected_type = "builtins.bytes"
    elif isinstance(value, float):
        expected_type = "builtins.float"
    elif isinstance(value, complex):
        expected_type = "builtins.object"
    elif isinstance(value, tuple):
        expected_type = "builtins.tuple"
        self.check_tuple_items_valid_literals(op, value)

    assert expected_type is not None, "Missed a case for LoadLiteral check"

    if op.type.name not in [expected_type, "builtins.object"]:
        self.fail(
            op,
            f"Invalid literal value for type: value has "
            f"type {expected_type}, but op has type {op.type.name}",
        )
94260

95261
def visit_get_attr(self, op: GetAttr) -> None:
    """Perform no validation for GetAttr."""
    # Nothing to do.
    pass
97264

98265
def visit_set_attr(self, op: SetAttr) -> None:
    """Perform no validation for SetAttr."""
    # Nothing to do.
    pass
100268

269+
# Static operations cannot be checked at the function level.
101270
def visit_load_static(self, op: LoadStatic) -> None:
    """Perform no validation for LoadStatic."""
    pass
103272

104273
def visit_init_static(self, op: InitStatic) -> None:
    """Perform no validation for InitStatic."""
    pass
106275

107276
def visit_tuple_get(self, op: TupleGet) -> None:
    """Perform no validation for TupleGet."""
    # Nothing to do.
    pass
109279

110280
def visit_tuple_set(self, op: TupleSet) -> None:
    """Perform no validation for TupleSet."""
    # Nothing to do.
    pass
112283

113284
def visit_inc_ref(self, op: IncRef) -> None:
    """Perform no validation for IncRef."""
    # Nothing to do.
    pass
115287

116288
def visit_dec_ref(self, op: DecRef) -> None:
    """Perform no validation for DecRef."""
    # Nothing to do.
    pass
118291

119292
def visit_call(self, op: Call) -> None:
    """Check each argument's type against the callee's declared argument type."""
    # Length is checked in constructor, and return type is set
    # in a way that can't be incorrect
    for arg_value, arg_runtime in zip(op.args, op.fn.sig.args):
        self.check_type_coercion(op, arg_value.type, arg_runtime.type)
121297

122298
def visit_method_call(self, op: MethodCall) -> None:
    """Check the argument count and argument types of a method call.

    The method declaration is looked up on the receiver's class. For any
    method kind other than FUNC_STATICMETHOD the first declared argument is
    treated as the receiver and is not matched against op.args.
    """
    # Similar to visit_call, but we must look up method first.
    method_decl = op.receiver_type.class_ir.method_decl(op.method)
    decl_index = 0 if method_decl.kind == FUNC_STATICMETHOD else 1
    declared_args = method_decl.sig.args

    if len(op.args) + decl_index != len(declared_args):
        self.fail(op, "Incorrect number of args for method call.")

    # Skip the leading receiver argument (self) unless the method is static.
    for arg_value, arg_runtime in zip(op.args, declared_args[decl_index:]):
        self.check_type_coercion(op, arg_value.type, arg_runtime.type)
124312

125313
def visit_cast(self, op: Cast) -> None:
    """Perform no validation for Cast."""
    pass

mypyc/ir/ops.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,21 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:
214214
pass
215215

216216

217-
class Assign(Op):
217+
class BaseAssign(Op):
    """Base class for ops that assign to a register."""

    def __init__(self, dest: Register, line: int = -1) -> None:
        """Store the destination register; line is forwarded to Op.__init__."""
        super().__init__(line)
        # Register written by this op.
        self.dest = dest
222+
223+
224+
class Assign(BaseAssign):
218225
"""Assign a value to a Register (dest = src)."""
219226

220227
error_kind = ERR_NEVER
221228

222229
def __init__(self, dest: Register, src: Value, line: int = -1) -> None:
    """Create an assignment of src to dest.

    dest and line are forwarded to the base-class constructor, which is the
    code that stores dest; only src is stored here.
    """
    super().__init__(dest, line)
    self.src = src
226232

227233
def sources(self) -> List[Value]:
    """Return a new one-element list containing the assigned source value."""
    return [self.src]
@@ -234,7 +240,7 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:
234240
return visitor.visit_assign(self)
235241

236242

237-
class AssignMulti(Op):
243+
class AssignMulti(BaseAssign):
238244
"""Assign multiple values to a Register (dest = src1, src2, ...).
239245
240246
This is used to initialize RArray values. It's provided to avoid
@@ -248,12 +254,11 @@ class AssignMulti(Op):
248254
error_kind = ERR_NEVER
249255

250256
def __init__(self, dest: Register, src: List[Value], line: int = -1) -> None:
    """Create an assignment of several values to one RArray-typed register.

    dest and line are forwarded to the base-class constructor, which is the
    code that stores dest. src must be non-empty, dest.type must be an
    RArray, and the array's length must equal len(src); these invariants
    are enforced with assertions.
    """
    super().__init__(dest, line)
    assert src
    assert isinstance(dest.type, RArray)
    assert dest.type.length == len(src)
    self.src = src
257262

258263
def sources(self) -> List[Value]:
    """Return a shallow copy of the source values, so callers cannot mutate self.src."""
    return self.src[:]
@@ -490,6 +495,7 @@ def __init__(self, fn: 'FuncDecl', args: Sequence[Value], line: int) -> None:
490495
super().__init__(line)
491496
self.fn = fn
492497
self.args = list(args)
498+
assert len(self.args) == len(fn.sig.args)
493499
self.type = fn.sig.ret_type
494500

495501
def sources(self) -> List[Value]:

0 commit comments

Comments
 (0)