Skip to content

Commit c755928

Browse files
authored
[mypyc] Simplify generated code for native attribute get (#11978)
The implementation merges consecutive GetAttr and Branch ops. The main benefit is that this makes the generated code smaller and easier to read, making it easier to spot possible improvements in the code.
1 parent 82bc8df commit c755928

File tree

4 files changed

+99
-15
lines changed

4 files changed

+99
-15
lines changed

mypyc/codegen/emitfunc.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Code generation for native function bodies."""
22

3-
from typing import Union, Optional
3+
from typing import List, Union, Optional
44
from typing_extensions import Final
55

66
from mypyc.common import (
77
REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX,
88
)
99
from mypyc.codegen.emit import Emitter
1010
from mypyc.ir.ops import (
11-
OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
11+
Op, OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
1212
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
1313
BasicBlock, Value, MethodCall, Unreachable, NAMESPACE_STATIC, NAMESPACE_TYPE, NAMESPACE_MODULE,
1414
RaiseStandardError, CallC, LoadGlobal, Truncate, IntOp, LoadMem, GetElementPtr,
@@ -88,8 +88,13 @@ def generate_native_function(fn: FuncIR,
8888
next_block = blocks[i + 1]
8989
body.emit_label(block)
9090
visitor.next_block = next_block
91-
for op in block.ops:
92-
op.accept(visitor)
91+
92+
ops = block.ops
93+
visitor.ops = ops
94+
visitor.op_index = 0
95+
while visitor.op_index < len(ops):
96+
ops[visitor.op_index].accept(visitor)
97+
visitor.op_index += 1
9398

9499
body.emit_line('}')
95100

@@ -110,7 +115,12 @@ def __init__(self,
110115
self.module_name = module_name
111116
self.literals = emitter.context.literals
112117
self.rare = False
118+
# Next basic block to be processed after the current one (if any), set by caller
113119
self.next_block: Optional[BasicBlock] = None
120+
# Ops in the basic block currently being processed, set by caller
121+
self.ops: List[Op] = []
122+
# Current index within ops; visit methods can increment this to skip/merge ops
123+
self.op_index = 0
114124

115125
def temp_name(self) -> str:
116126
return self.emitter.temp_name()
@@ -293,16 +303,44 @@ def visit_get_attr(self, op: GetAttr) -> None:
293303
attr_expr = self.get_attr_expr(obj, op, decl_cl)
294304
self.emitter.emit_line('{} = {};'.format(dest, attr_expr))
295305
self.emitter.emit_undefined_attr_check(
296-
attr_rtype, attr_expr, '==', unlikely=True
306+
attr_rtype, dest, '==', unlikely=True
297307
)
298308
exc_class = 'PyExc_AttributeError'
299-
self.emitter.emit_line(
300-
'PyErr_SetString({}, "attribute {} of {} undefined");'.format(
301-
exc_class, repr(op.attr), repr(cl.name)))
309+
merged_branch = None
310+
branch = self.next_branch()
311+
if branch is not None:
312+
if (branch.value is op
313+
and branch.op == Branch.IS_ERROR
314+
and branch.traceback_entry is not None
315+
and not branch.negated):
316+
# Generate code for the following branch here to avoid
317+
# redundant branches in the generate code.
318+
self.emit_attribute_error(branch, cl.name, op.attr)
319+
self.emit_line('goto %s;' % self.label(branch.true))
320+
merged_branch = branch
321+
self.emitter.emit_line('}')
322+
if not merged_branch:
323+
self.emitter.emit_line(
324+
'PyErr_SetString({}, "attribute {} of {} undefined");'.format(
325+
exc_class, repr(op.attr), repr(cl.name)))
326+
302327
if attr_rtype.is_refcounted:
303-
self.emitter.emit_line('} else {')
304-
self.emitter.emit_inc_ref(attr_expr, attr_rtype)
305-
self.emitter.emit_line('}')
328+
if not merged_branch:
329+
self.emitter.emit_line('} else {')
330+
self.emitter.emit_inc_ref(dest, attr_rtype)
331+
if merged_branch:
332+
if merged_branch.false is not self.next_block:
333+
self.emit_line('goto %s;' % self.label(merged_branch.false))
334+
self.op_index += 1
335+
else:
336+
self.emitter.emit_line('}')
337+
338+
def next_branch(self) -> Optional[Branch]:
339+
if self.op_index + 1 < len(self.ops):
340+
next_op = self.ops[self.op_index + 1]
341+
if isinstance(next_op, Branch):
342+
return next_op
343+
return None
306344

307345
def visit_set_attr(self, op: SetAttr) -> None:
308346
dest = self.reg(op)
@@ -603,6 +641,19 @@ def emit_traceback(self, op: Branch) -> None:
603641
if DEBUG_ERRORS:
604642
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
605643

644+
def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None:
645+
assert op.traceback_entry is not None
646+
globals_static = self.emitter.static_name('globals', self.module_name)
647+
self.emit_line('CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);' % (
648+
self.source_path.replace("\\", "\\\\"),
649+
op.traceback_entry[0],
650+
class_name,
651+
attr,
652+
op.traceback_entry[1],
653+
globals_static))
654+
if DEBUG_ERRORS:
655+
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
656+
606657
def emit_signed_int_cast(self, type: RType) -> str:
607658
if is_tagged(type):
608659
return '(Py_ssize_t)'

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,8 @@ void _CPy_GetExcInfo(PyObject **p_type, PyObject **p_value, PyObject **p_traceba
498498
void CPyError_OutOfMemory(void);
499499
void CPy_TypeError(const char *expected, PyObject *value);
500500
void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyObject *globals);
501+
void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
502+
const char *attrname, int line, PyObject *globals);
501503

502504

503505
// Misc operations

mypyc/lib-rt/exc_ops.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,11 @@ void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyOb
225225
error:
226226
_PyErr_ChainExceptions(exc, val, tb);
227227
}
228+
229+
void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
230+
const char *attrname, int line, PyObject *globals) {
231+
char buf[500];
232+
snprintf(buf, sizeof(buf), "attribute '%.200s' of '%.200s' undefined", classname, attrname);
233+
PyErr_SetString(PyExc_AttributeError, buf);
234+
CPy_AddTraceback(filename, funcname, line, globals);
235+
}

mypyc/test/test_emitfunc.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,22 +281,39 @@ def test_get_attr(self) -> None:
281281
self.assert_emit(
282282
GetAttr(self.r, 'y', 1),
283283
"""cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_y;
284-
if (unlikely(((mod___AObject *)cpy_r_r)->_y == CPY_INT_TAG)) {
284+
if (unlikely(cpy_r_r0 == CPY_INT_TAG)) {
285285
PyErr_SetString(PyExc_AttributeError, "attribute 'y' of 'A' undefined");
286286
} else {
287-
CPyTagged_INCREF(((mod___AObject *)cpy_r_r)->_y);
287+
CPyTagged_INCREF(cpy_r_r0);
288288
}
289289
""")
290290

291291
def test_get_attr_non_refcounted(self) -> None:
292292
self.assert_emit(
293293
GetAttr(self.r, 'x', 1),
294294
"""cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_x;
295-
if (unlikely(((mod___AObject *)cpy_r_r)->_x == 2)) {
295+
if (unlikely(cpy_r_r0 == 2)) {
296296
PyErr_SetString(PyExc_AttributeError, "attribute 'x' of 'A' undefined");
297297
}
298298
""")
299299

300+
def test_get_attr_merged(self) -> None:
301+
op = GetAttr(self.r, 'y', 1)
302+
branch = Branch(op, BasicBlock(8), BasicBlock(9), Branch.IS_ERROR)
303+
branch.traceback_entry = ('foobar', 123)
304+
self.assert_emit(
305+
op,
306+
"""\
307+
cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_y;
308+
if (unlikely(cpy_r_r0 == CPY_INT_TAG)) {
309+
CPy_AttributeError("prog.py", "foobar", "A", "y", 123, CPyStatic_prog___globals);
310+
goto CPyL8;
311+
}
312+
CPyTagged_INCREF(cpy_r_r0);
313+
goto CPyL9;
314+
""",
315+
next_branch=branch)
316+
300317
def test_set_attr(self) -> None:
301318
self.assert_emit(
302319
SetAttr(self.r, 'y', self.m, 1),
@@ -428,7 +445,8 @@ def assert_emit(self,
428445
expected: str,
429446
next_block: Optional[BasicBlock] = None,
430447
*,
431-
rare: bool = False) -> None:
448+
rare: bool = False,
449+
next_branch: Optional[Branch] = None) -> None:
432450
block = BasicBlock(0)
433451
block.ops.append(op)
434452
value_names = generate_names_for_ir(self.registers, [block])
@@ -440,6 +458,11 @@ def assert_emit(self,
440458
visitor = FunctionEmitterVisitor(emitter, declarations, 'prog.py', 'prog')
441459
visitor.next_block = next_block
442460
visitor.rare = rare
461+
if next_branch:
462+
visitor.ops = [op, next_branch]
463+
else:
464+
visitor.ops = [op]
465+
visitor.op_index = 0
443466

444467
op.accept(visitor)
445468
frags = declarations.fragments + emitter.fragments

0 commit comments

Comments
 (0)