Skip to content

Commit 040f3ab

Browse files
authored
[mypyc] Generate smaller code for casts (#12839)
Merge a cast op followed by a branch that does an error check and adds a traceback entry. Since casts are very common, this reduces the size of the generated code a fair amount. Old code generated for a cast: ``` if (likely(PyUnicode_Check(cpy_r_x))) cpy_r_r0 = cpy_r_x; else { CPy_TypeError("str", cpy_r_x); cpy_r_r0 = NULL; } if (unlikely(cpy_r_r0 == NULL)) { CPy_AddTraceback("t/t.py", "foo", 2, CPyStatic_globals); goto CPyL2; } ``` New code: ``` if (likely(PyUnicode_Check(cpy_r_x))) cpy_r_r0 = cpy_r_x; else { CPy_TypeErrorTraceback("t/t.py", "foo", 2, CPyStatic_globals, "str", cpy_r_x); goto CPyL2; } ```
1 parent c8efeed commit 040f3ab

File tree

7 files changed

+270
-57
lines changed

7 files changed

+270
-57
lines changed

mypyc/codegen/emit.py

Lines changed: 118 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from mypy.backports import OrderedDict
44
from typing import List, Set, Dict, Optional, Callable, Union, Tuple
5+
from typing_extensions import Final
6+
57
import sys
68

79
from mypyc.common import (
@@ -23,6 +25,10 @@
2325
from mypyc.sametype import is_same_type
2426
from mypyc.codegen.literals import Literals
2527

28+
# Whether to insert debug asserts for all error handling, to quickly
29+
# catch errors propagating without exceptions set.
30+
DEBUG_ERRORS: Final = False
31+
2632

2733
class HeaderDeclaration:
2834
"""A representation of a declaration in C.
@@ -104,6 +110,20 @@ def __init__(self, label: str) -> None:
104110
self.label = label
105111

106112

113+
class TracebackAndGotoHandler(ErrorHandler):
114+
"""Add traceback item and goto label on error."""
115+
116+
def __init__(self,
117+
label: str,
118+
source_path: str,
119+
module_name: str,
120+
traceback_entry: Tuple[str, int]) -> None:
121+
self.label = label
122+
self.source_path = source_path
123+
self.module_name = module_name
124+
self.traceback_entry = traceback_entry
125+
126+
107127
class ReturnHandler(ErrorHandler):
108128
"""Return a constant value on error."""
109129

@@ -439,18 +459,6 @@ def emit_cast(self,
439459
likely: If the cast is likely to succeed (can be False for unions)
440460
"""
441461
error = error or AssignHandler()
442-
if isinstance(error, AssignHandler):
443-
handle_error = '%s = NULL;' % dest
444-
elif isinstance(error, GotoHandler):
445-
handle_error = 'goto %s;' % error.label
446-
else:
447-
assert isinstance(error, ReturnHandler)
448-
handle_error = 'return %s;' % error.value
449-
if raise_exception:
450-
raise_exc = f'CPy_TypeError("{self.pretty_name(typ)}", {src}); '
451-
err = raise_exc + handle_error
452-
else:
453-
err = handle_error
454462

455463
# Special case casting *from* optional
456464
if src_type and is_optional_type(src_type) and not is_object_rprimitive(typ):
@@ -465,9 +473,9 @@ def emit_cast(self,
465473
self.emit_arg_check(src, dest, typ, check.format(src), optional)
466474
self.emit_lines(
467475
f' {dest} = {src};',
468-
'else {',
469-
err,
470-
'}')
476+
'else {')
477+
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
478+
self.emit_line('}')
471479
return
472480

473481
# TODO: Verify refcount handling.
@@ -500,9 +508,9 @@ def emit_cast(self,
500508
self.emit_arg_check(src, dest, typ, check.format(prefix, src), optional)
501509
self.emit_lines(
502510
f' {dest} = {src};',
503-
'else {',
504-
err,
505-
'}')
511+
'else {')
512+
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
513+
self.emit_line('}')
506514
elif is_bytes_rprimitive(typ):
507515
if declare_dest:
508516
self.emit_line(f'PyObject *{dest};')
@@ -512,9 +520,9 @@ def emit_cast(self,
512520
self.emit_arg_check(src, dest, typ, check.format(src, src), optional)
513521
self.emit_lines(
514522
f' {dest} = {src};',
515-
'else {',
516-
err,
517-
'}')
523+
'else {')
524+
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
525+
self.emit_line('}')
518526
elif is_tuple_rprimitive(typ):
519527
if declare_dest:
520528
self.emit_line(f'{self.ctype(typ)} {dest};')
@@ -525,9 +533,9 @@ def emit_cast(self,
525533
check.format(src), optional)
526534
self.emit_lines(
527535
f' {dest} = {src};',
528-
'else {',
529-
err,
530-
'}')
536+
'else {')
537+
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
538+
self.emit_line('}')
531539
elif isinstance(typ, RInstance):
532540
if declare_dest:
533541
self.emit_line(f'PyObject *{dest};')
@@ -551,10 +559,10 @@ def emit_cast(self,
551559
check = f'(likely{check})'
552560
self.emit_arg_check(src, dest, typ, check, optional)
553561
self.emit_lines(
554-
f' {dest} = {src};',
555-
'else {',
556-
err,
557-
'}')
562+
f' {dest} = {src};'.format(dest, src),
563+
'else {')
564+
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
565+
self.emit_line('}')
558566
elif is_none_rprimitive(typ):
559567
if declare_dest:
560568
self.emit_line(f'PyObject *{dest};')
@@ -565,9 +573,9 @@ def emit_cast(self,
565573
check.format(src), optional)
566574
self.emit_lines(
567575
f' {dest} = {src};',
568-
'else {',
569-
err,
570-
'}')
576+
'else {')
577+
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
578+
self.emit_line('}')
571579
elif is_object_rprimitive(typ):
572580
if declare_dest:
573581
self.emit_line(f'PyObject *{dest};')
@@ -576,21 +584,51 @@ def emit_cast(self,
576584
if optional:
577585
self.emit_line('}')
578586
elif isinstance(typ, RUnion):
579-
self.emit_union_cast(src, dest, typ, declare_dest, err, optional, src_type)
587+
self.emit_union_cast(src, dest, typ, declare_dest, error, optional, src_type,
588+
raise_exception)
580589
elif isinstance(typ, RTuple):
581590
assert not optional
582-
self.emit_tuple_cast(src, dest, typ, declare_dest, err, src_type)
591+
self.emit_tuple_cast(src, dest, typ, declare_dest, error, src_type)
583592
else:
584593
assert False, 'Cast not implemented: %s' % typ
585594

595+
def emit_cast_error_handler(self,
596+
error: ErrorHandler,
597+
src: str,
598+
dest: str,
599+
typ: RType,
600+
raise_exception: bool) -> None:
601+
if raise_exception:
602+
if isinstance(error, TracebackAndGotoHandler):
603+
# Merge raising and emitting traceback entry into a single call.
604+
self.emit_type_error_traceback(
605+
error.source_path, error.module_name, error.traceback_entry,
606+
typ=typ,
607+
src=src)
608+
self.emit_line('goto %s;' % error.label)
609+
return
610+
self.emit_line('CPy_TypeError("{}", {}); '.format(self.pretty_name(typ), src))
611+
if isinstance(error, AssignHandler):
612+
self.emit_line('%s = NULL;' % dest)
613+
elif isinstance(error, GotoHandler):
614+
self.emit_line('goto %s;' % error.label)
615+
elif isinstance(error, TracebackAndGotoHandler):
616+
self.emit_line('%s = NULL;' % dest)
617+
self.emit_traceback(error.source_path, error.module_name, error.traceback_entry)
618+
self.emit_line('goto %s;' % error.label)
619+
else:
620+
assert isinstance(error, ReturnHandler)
621+
self.emit_line('return %s;' % error.value)
622+
586623
def emit_union_cast(self,
587624
src: str,
588625
dest: str,
589626
typ: RUnion,
590627
declare_dest: bool,
591-
err: str,
628+
error: ErrorHandler,
592629
optional: bool,
593-
src_type: Optional[RType]) -> None:
630+
src_type: Optional[RType],
631+
raise_exception: bool) -> None:
594632
"""Emit cast to a union type.
595633
596634
The arguments are similar to emit_cast.
@@ -613,11 +651,11 @@ def emit_union_cast(self,
613651
likely=False)
614652
self.emit_line(f'if ({dest} != NULL) goto {good_label};')
615653
# Handle cast failure.
616-
self.emit_line(err)
654+
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
617655
self.emit_label(good_label)
618656

619657
def emit_tuple_cast(self, src: str, dest: str, typ: RTuple, declare_dest: bool,
620-
err: str, src_type: Optional[RType]) -> None:
658+
error: ErrorHandler, src_type: Optional[RType]) -> None:
621659
"""Emit cast to a tuple type.
622660
623661
The arguments are similar to emit_cast.
@@ -740,7 +778,8 @@ def emit_unbox(self,
740778
self.emit_line('} else {')
741779

742780
cast_temp = self.temp_name()
743-
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, err='', src_type=None)
781+
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, error=error,
782+
src_type=None)
744783
self.emit_line(f'if (unlikely({cast_temp} == NULL)) {{')
745784

746785
# self.emit_arg_check(src, dest, typ,
@@ -886,3 +925,44 @@ def emit_gc_clear(self, target: str, rtype: RType) -> None:
886925
self.emit_line(f'Py_CLEAR({target});')
887926
else:
888927
assert False, 'emit_gc_clear() not implemented for %s' % repr(rtype)
928+
929+
def emit_traceback(self,
930+
source_path: str,
931+
module_name: str,
932+
traceback_entry: Tuple[str, int]) -> None:
933+
return self._emit_traceback('CPy_AddTraceback', source_path, module_name, traceback_entry)
934+
935+
def emit_type_error_traceback(
936+
self,
937+
source_path: str,
938+
module_name: str,
939+
traceback_entry: Tuple[str, int],
940+
*,
941+
typ: RType,
942+
src: str) -> None:
943+
func = 'CPy_TypeErrorTraceback'
944+
type_str = f'"{self.pretty_name(typ)}"'
945+
return self._emit_traceback(
946+
func, source_path, module_name, traceback_entry, type_str=type_str, src=src)
947+
948+
def _emit_traceback(self,
949+
func: str,
950+
source_path: str,
951+
module_name: str,
952+
traceback_entry: Tuple[str, int],
953+
type_str: str = '',
954+
src: str = '') -> None:
955+
globals_static = self.static_name('globals', module_name)
956+
line = '%s("%s", "%s", %d, %s' % (
957+
func,
958+
source_path.replace("\\", "\\\\"),
959+
traceback_entry[0],
960+
traceback_entry[1],
961+
globals_static)
962+
if type_str:
963+
assert src
964+
line += f', {type_str}, {src}'
965+
line += ');'
966+
self.emit_line(line)
967+
if DEBUG_ERRORS:
968+
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')

mypyc/codegen/emitfunc.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mypyc.common import (
77
REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX,
88
)
9-
from mypyc.codegen.emit import Emitter
9+
from mypyc.codegen.emit import Emitter, TracebackAndGotoHandler, DEBUG_ERRORS
1010
from mypyc.ir.ops import (
1111
Op, OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
1212
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
@@ -23,10 +23,6 @@
2323
from mypyc.ir.pprint import generate_names_for_ir
2424
from mypyc.analysis.blockfreq import frequently_executed_blocks
2525

26-
# Whether to insert debug asserts for all error handling, to quickly
27-
# catch errors propagating without exceptions set.
28-
DEBUG_ERRORS = False
29-
3026

3127
def native_function_type(fn: FuncIR, emitter: Emitter) -> str:
3228
args = ', '.join(emitter.ctype(arg.type) for arg in fn.args) or 'void'
@@ -322,7 +318,7 @@ def visit_get_attr(self, op: GetAttr) -> None:
322318
and branch.traceback_entry is not None
323319
and not branch.negated):
324320
# Generate code for the following branch here to avoid
325-
# redundant branches in the generate code.
321+
# redundant branches in the generated code.
326322
self.emit_attribute_error(branch, cl.name, op.attr)
327323
self.emit_line('goto %s;' % self.label(branch.true))
328324
merged_branch = branch
@@ -485,8 +481,24 @@ def visit_box(self, op: Box) -> None:
485481
self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)
486482

487483
def visit_cast(self, op: Cast) -> None:
484+
branch = self.next_branch()
485+
handler = None
486+
if branch is not None:
487+
if (branch.value is op
488+
and branch.op == Branch.IS_ERROR
489+
and branch.traceback_entry is not None
490+
and not branch.negated
491+
and branch.false is self.next_block):
492+
# Generate code also for the following branch here to avoid
493+
# redundant branches in the generated code.
494+
handler = TracebackAndGotoHandler(self.label(branch.true),
495+
self.source_path,
496+
self.module_name,
497+
branch.traceback_entry)
498+
self.op_index += 1
499+
488500
self.emitter.emit_cast(self.reg(op.src), self.reg(op), op.type,
489-
src_type=op.src.type)
501+
src_type=op.src.type, error=handler)
490502

491503
def visit_unbox(self, op: Unbox) -> None:
492504
self.emitter.emit_unbox(self.reg(op.src), self.reg(op), op.type)
@@ -647,14 +659,7 @@ def emit_declaration(self, line: str) -> None:
647659

648660
def emit_traceback(self, op: Branch) -> None:
649661
if op.traceback_entry is not None:
650-
globals_static = self.emitter.static_name('globals', self.module_name)
651-
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
652-
self.source_path.replace("\\", "\\\\"),
653-
op.traceback_entry[0],
654-
op.traceback_entry[1],
655-
globals_static))
656-
if DEBUG_ERRORS:
657-
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
662+
self.emitter.emit_traceback(self.source_path, self.module_name, op.traceback_entry)
658663

659664
def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None:
660665
assert op.traceback_entry is not None

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,8 @@ void _CPy_GetExcInfo(PyObject **p_type, PyObject **p_value, PyObject **p_traceba
500500
void CPyError_OutOfMemory(void);
501501
void CPy_TypeError(const char *expected, PyObject *value);
502502
void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyObject *globals);
503+
void CPy_TypeErrorTraceback(const char *filename, const char *funcname, int line,
504+
PyObject *globals, const char *expected, PyObject *value);
503505
void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
504506
const char *attrname, int line, PyObject *globals);
505507

mypyc/lib-rt/exc_ops.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,13 @@ void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyOb
233233
_PyErr_ChainExceptions(exc, val, tb);
234234
}
235235

236+
CPy_NOINLINE
237+
void CPy_TypeErrorTraceback(const char *filename, const char *funcname, int line,
238+
PyObject *globals, const char *expected, PyObject *value) {
239+
CPy_TypeError(expected, value);
240+
CPy_AddTraceback(filename, funcname, line, globals);
241+
}
242+
236243
void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
237244
const char *attrname, int line, PyObject *globals) {
238245
char buf[500];

mypyc/test-data/run-functions.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,3 +1220,18 @@ def sub(s: str, f: Callable[[str], str]) -> str: ...
12201220
def sub(s: bytes, f: Callable[[bytes], bytes]) -> bytes: ...
12211221
def sub(s, f):
12221222
return f(s)
1223+
1224+
[case testContextManagerSpecialCase]
1225+
from typing import Generator, Callable, Iterator
1226+
from contextlib import contextmanager
1227+
1228+
@contextmanager
1229+
def f() -> Iterator[None]:
1230+
yield
1231+
1232+
def g() -> None:
1233+
a = ['']
1234+
with f():
1235+
a.pop()
1236+
1237+
g()

0 commit comments

Comments
 (0)