Skip to content

Commit a04bdbf

Browse files
authored
[mypyc] Support ERR_ALWAYS (#9073)
Related to mypyc/mypyc#734, with a focus on exceptions related ops. This PR adds a new error kind: ERR_ALWAYS, which indicates the op always fails. It adds temporary false value to ensure such behavior in the exception handling transform and makes the raise op void.
1 parent eae1860 commit a04bdbf

File tree

8 files changed

+79
-62
lines changed

8 files changed

+79
-62
lines changed

mypyc/codegen/emitfunc.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,7 @@ def visit_branch(self, op: Branch) -> None:
130130

131131
self.emit_line('if ({}) {{'.format(cond))
132132

133-
if op.traceback_entry is not None:
134-
globals_static = self.emitter.static_name('globals', self.module_name)
135-
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
136-
self.source_path.replace("\\", "\\\\"),
137-
op.traceback_entry[0],
138-
op.traceback_entry[1],
139-
globals_static))
140-
if DEBUG_ERRORS:
141-
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
133+
self.emit_traceback(op)
142134

143135
self.emit_lines(
144136
'goto %s;' % self.label(op.true),
@@ -422,7 +414,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
422414
self.emitter.emit_line('{} = 0;'.format(self.reg(op)))
423415

424416
def visit_call_c(self, op: CallC) -> None:
425-
dest = self.get_dest_assign(op)
417+
if op.is_void:
418+
dest = ''
419+
else:
420+
dest = self.get_dest_assign(op)
426421
args = ', '.join(self.reg(arg) for arg in op.args)
427422
self.emitter.emit_line("{}{}({});".format(dest, op.function_name, args))
428423

@@ -472,3 +467,14 @@ def emit_dec_ref(self, dest: str, rtype: RType, is_xdec: bool) -> None:
472467

473468
def emit_declaration(self, line: str) -> None:
474469
self.declarations.emit_line(line)
470+
471+
def emit_traceback(self, op: Branch) -> None:
472+
if op.traceback_entry is not None:
473+
globals_static = self.emitter.static_name('globals', self.module_name)
474+
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
475+
self.source_path.replace("\\", "\\\\"),
476+
op.traceback_entry[0],
477+
op.traceback_entry[1],
478+
globals_static))
479+
if DEBUG_ERRORS:
480+
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')

mypyc/ir/ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def terminated(self) -> bool:
295295
ERR_FALSE = 2 # type: Final
296296
# Generates negative integer on exception
297297
ERR_NEG_INT = 3 # type: Final
298+
# Always fails
299+
ERR_ALWAYS = 4 # type: Final
298300

299301
# Hack: using this line number for an op will suppress it in tracebacks
300302
NO_TRACEBACK_LINE_NO = -10000
@@ -1167,7 +1169,10 @@ def __init__(self,
11671169

11681170
def to_str(self, env: Environment) -> str:
11691171
args_str = ', '.join(env.format('%r', arg) for arg in self.args)
1170-
return env.format('%r = %s(%s)', self, self.function_name, args_str)
1172+
if self.is_void:
1173+
return env.format('%s(%s)', self.function_name, args_str)
1174+
else:
1175+
return env.format('%r = %s(%s)', self, self.function_name, args_str)
11711176

11721177
def sources(self) -> List[Value]:
11731178
return self.args

mypyc/irbuild/statement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def transform_raise_stmt(builder: IRBuilder, s: RaiseStmt) -> None:
243243
return
244244

245245
exc = builder.accept(s.expr)
246-
builder.primitive_op(raise_exception_op, [exc], s.line)
246+
builder.call_c(raise_exception_op, [exc], s.line)
247247
builder.add(Unreachable())
248248

249249

@@ -614,7 +614,7 @@ def transform_assert_stmt(builder: IRBuilder, a: AssertStmt) -> None:
614614
message = builder.accept(a.msg)
615615
exc_type = builder.load_module_attr_by_fullname('builtins.AssertionError', a.line)
616616
exc = builder.py_call(exc_type, [message], a.line)
617-
builder.primitive_op(raise_exception_op, [exc], a.line)
617+
builder.call_c(raise_exception_op, [exc], a.line)
618618
builder.add(Unreachable())
619619
builder.activate_block(ok_block)
620620

mypyc/primitives/exc_ops.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
"""Exception-related primitive ops."""
22

3-
from mypyc.ir.ops import ERR_NEVER, ERR_FALSE
3+
from mypyc.ir.ops import ERR_NEVER, ERR_FALSE, ERR_ALWAYS
44
from mypyc.ir.rtypes import bool_rprimitive, object_rprimitive, void_rtype, exc_rtuple
55
from mypyc.primitives.registry import (
6-
simple_emit, call_emit, call_void_emit, call_and_fail_emit, custom_op,
6+
simple_emit, call_emit, call_void_emit, call_and_fail_emit, custom_op, c_custom_op
77
)
88

99
# If the argument is a class, raise an instance of the class. Otherwise, assume
1010
# that the argument is an exception object, and raise it.
11-
#
12-
# TODO: Making this raise conditionally is kind of hokey.
13-
raise_exception_op = custom_op(
11+
raise_exception_op = c_custom_op(
1412
arg_types=[object_rprimitive],
15-
result_type=bool_rprimitive,
16-
error_kind=ERR_FALSE,
17-
format_str='raise_exception({args[0]}); {dest} = 0',
18-
emit=call_and_fail_emit('CPy_Raise'))
13+
return_type=void_rtype,
14+
c_function_name='CPy_Raise',
15+
error_kind=ERR_ALWAYS)
1916

2017
# Raise StopIteration exception with the specified value (which can be NULL).
2118
set_stop_iteration_value = custom_op(

mypyc/test-data/irbuild-basic.test

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,24 +1322,22 @@ def foo():
13221322
r0 :: object
13231323
r1 :: str
13241324
r2, r3 :: object
1325-
r4 :: bool
13261325
L0:
13271326
r0 = builtins :: module
13281327
r1 = unicode_1 :: static ('Exception')
13291328
r2 = getattr r0, r1
13301329
r3 = py_call(r2)
1331-
raise_exception(r3); r4 = 0
1330+
CPy_Raise(r3)
13321331
unreachable
13331332
def bar():
13341333
r0 :: object
13351334
r1 :: str
13361335
r2 :: object
1337-
r3 :: bool
13381336
L0:
13391337
r0 = builtins :: module
13401338
r1 = unicode_1 :: static ('Exception')
13411339
r2 = getattr r0, r1
1342-
raise_exception(r2); r3 = 0
1340+
CPy_Raise(r2)
13431341
unreachable
13441342

13451343
[case testModuleTopLevel_toplevel]

mypyc/test-data/irbuild-statements.test

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -614,8 +614,7 @@ def complex_msg(x, s):
614614
r4 :: object
615615
r5 :: str
616616
r6, r7 :: object
617-
r8 :: bool
618-
r9 :: None
617+
r8 :: None
619618
L0:
620619
r0 = builtins.None :: object
621620
r1 = x is not r0
@@ -629,11 +628,11 @@ L2:
629628
r5 = unicode_3 :: static ('AssertionError')
630629
r6 = getattr r4, r5
631630
r7 = py_call(r6, s)
632-
raise_exception(r7); r8 = 0
631+
CPy_Raise(r7)
633632
unreachable
634633
L3:
635-
r9 = None
636-
return r9
634+
r8 = None
635+
return r8
637636

638637
[case testDelList]
639638
def delList() -> None:

mypyc/test-data/irbuild-try.test

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -277,14 +277,13 @@ def a(b):
277277
r1 :: object
278278
r2 :: str
279279
r3, r4 :: object
280-
r5 :: bool
281-
r6, r7, r8 :: tuple[object, object, object]
282-
r9 :: str
283-
r10 :: object
284-
r11 :: str
285-
r12, r13 :: object
286-
r14, r15 :: bool
287-
r16 :: None
280+
r5, r6, r7 :: tuple[object, object, object]
281+
r8 :: str
282+
r9 :: object
283+
r10 :: str
284+
r11, r12 :: object
285+
r13, r14 :: bool
286+
r15 :: None
288287
L0:
289288
L1:
290289
if b goto L2 else goto L3 :: bool
@@ -294,39 +293,39 @@ L2:
294293
r2 = unicode_2 :: static ('Exception')
295294
r3 = getattr r1, r2
296295
r4 = py_call(r3, r0)
297-
raise_exception(r4); r5 = 0
296+
CPy_Raise(r4)
298297
unreachable
299298
L3:
300299
L4:
301300
L5:
302-
r7 = <error> :: tuple[object, object, object]
303-
r6 = r7
301+
r6 = <error> :: tuple[object, object, object]
302+
r5 = r6
304303
goto L7
305304
L6: (handler for L1, L2, L3)
306-
r8 = error_catch
307-
r6 = r8
305+
r7 = error_catch
306+
r5 = r7
308307
L7:
309-
r9 = unicode_3 :: static ('finally')
310-
r10 = builtins :: module
311-
r11 = unicode_4 :: static ('print')
312-
r12 = getattr r10, r11
313-
r13 = py_call(r12, r9)
314-
if is_error(r6) goto L9 else goto L8
308+
r8 = unicode_3 :: static ('finally')
309+
r9 = builtins :: module
310+
r10 = unicode_4 :: static ('print')
311+
r11 = getattr r9, r10
312+
r12 = py_call(r11, r8)
313+
if is_error(r5) goto L9 else goto L8
315314
L8:
316-
reraise_exc; r14 = 0
315+
reraise_exc; r13 = 0
317316
unreachable
318317
L9:
319318
goto L13
320319
L10: (handler for L7, L8)
321-
if is_error(r6) goto L12 else goto L11
320+
if is_error(r5) goto L12 else goto L11
322321
L11:
323-
restore_exc_info r6
322+
restore_exc_info r5
324323
L12:
325-
r15 = keep_propagating
324+
r14 = keep_propagating
326325
unreachable
327326
L13:
328-
r16 = None
329-
return r16
327+
r15 = None
328+
return r15
330329

331330
[case testWith]
332331
from typing import Any

mypyc/transform/exceptions.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
from typing import List, Optional
1313

1414
from mypyc.ir.ops import (
15-
BasicBlock, LoadErrorValue, Return, Branch, RegisterOp, ERR_NEVER, ERR_MAGIC,
16-
ERR_FALSE, ERR_NEG_INT, NO_TRACEBACK_LINE_NO,
15+
BasicBlock, LoadErrorValue, Return, Branch, RegisterOp, LoadInt, ERR_NEVER, ERR_MAGIC,
16+
ERR_FALSE, ERR_NEG_INT, ERR_ALWAYS, NO_TRACEBACK_LINE_NO, Environment
1717
)
1818
from mypyc.ir.func_ir import FuncIR
19+
from mypyc.ir.rtypes import bool_rprimitive
1920

2021

2122
def insert_exception_handling(ir: FuncIR) -> None:
@@ -29,7 +30,7 @@ def insert_exception_handling(ir: FuncIR) -> None:
2930
error_label = add_handler_block(ir)
3031
break
3132
if error_label:
32-
ir.blocks = split_blocks_at_errors(ir.blocks, error_label, ir.traceback_name)
33+
ir.blocks = split_blocks_at_errors(ir.blocks, error_label, ir.traceback_name, ir.env)
3334

3435

3536
def add_handler_block(ir: FuncIR) -> BasicBlock:
@@ -44,7 +45,8 @@ def add_handler_block(ir: FuncIR) -> BasicBlock:
4445

4546
def split_blocks_at_errors(blocks: List[BasicBlock],
4647
default_error_handler: BasicBlock,
47-
func_name: Optional[str]) -> List[BasicBlock]:
48+
func_name: Optional[str],
49+
env: Environment) -> List[BasicBlock]:
4850
new_blocks = [] # type: List[BasicBlock]
4951

5052
# First split blocks on ops that may raise.
@@ -60,6 +62,7 @@ def split_blocks_at_errors(blocks: List[BasicBlock],
6062
block.error_handler = None
6163

6264
for op in ops:
65+
target = op
6366
cur_block.ops.append(op)
6467
if isinstance(op, RegisterOp) and op.error_kind != ERR_NEVER:
6568
# Split
@@ -77,14 +80,24 @@ def split_blocks_at_errors(blocks: List[BasicBlock],
7780
elif op.error_kind == ERR_NEG_INT:
7881
variant = Branch.NEG_INT_EXPR
7982
negated = False
83+
elif op.error_kind == ERR_ALWAYS:
84+
variant = Branch.BOOL_EXPR
85+
negated = True
86+
# this is a hack to represent the always fail
87+
# semantics, using a temporary bool with value false
88+
tmp = LoadInt(0, rtype=bool_rprimitive)
89+
cur_block.ops.append(tmp)
90+
env.add_op(tmp)
91+
target = tmp
8092
else:
8193
assert False, 'unknown error kind %d' % op.error_kind
8294

8395
# Void ops can't generate errors since error is always
8496
# indicated by a special value stored in a register.
85-
assert not op.is_void, "void op generating errors?"
97+
if op.error_kind != ERR_ALWAYS:
98+
assert not op.is_void, "void op generating errors?"
8699

87-
branch = Branch(op,
100+
branch = Branch(target,
88101
true_label=error_label,
89102
false_label=new_block,
90103
op=variant,

0 commit comments

Comments
 (0)