1
1
"""Code generation for native function bodies."""
2
2
3
- from typing import Union , Optional
3
+ from typing import List , Union , Optional
4
4
from typing_extensions import Final
5
5
6
6
from mypyc .common import (
7
7
REG_PREFIX , NATIVE_PREFIX , STATIC_PREFIX , TYPE_PREFIX , MODULE_PREFIX ,
8
8
)
9
9
from mypyc .codegen .emit import Emitter
10
10
from mypyc .ir .ops import (
11
- OpVisitor , Goto , Branch , Return , Assign , Integer , LoadErrorValue , GetAttr , SetAttr ,
11
+ Op , OpVisitor , Goto , Branch , Return , Assign , Integer , LoadErrorValue , GetAttr , SetAttr ,
12
12
LoadStatic , InitStatic , TupleGet , TupleSet , Call , IncRef , DecRef , Box , Cast , Unbox ,
13
13
BasicBlock , Value , MethodCall , Unreachable , NAMESPACE_STATIC , NAMESPACE_TYPE , NAMESPACE_MODULE ,
14
14
RaiseStandardError , CallC , LoadGlobal , Truncate , IntOp , LoadMem , GetElementPtr ,
@@ -88,8 +88,13 @@ def generate_native_function(fn: FuncIR,
88
88
next_block = blocks [i + 1 ]
89
89
body .emit_label (block )
90
90
visitor .next_block = next_block
91
- for op in block .ops :
92
- op .accept (visitor )
91
+
92
+ ops = block .ops
93
+ visitor .ops = ops
94
+ visitor .op_index = 0
95
+ while visitor .op_index < len (ops ):
96
+ ops [visitor .op_index ].accept (visitor )
97
+ visitor .op_index += 1
93
98
94
99
body .emit_line ('}' )
95
100
@@ -110,7 +115,12 @@ def __init__(self,
110
115
self .module_name = module_name
111
116
self .literals = emitter .context .literals
112
117
self .rare = False
118
+ # Next basic block to be processed after the current one (if any), set by caller
113
119
self .next_block : Optional [BasicBlock ] = None
120
+ # Ops in the basic block currently being processed, set by caller
121
+ self .ops : List [Op ] = []
122
+ # Current index within ops; visit methods can increment this to skip/merge ops
123
+ self .op_index = 0
114
124
115
125
def temp_name (self ) -> str :
116
126
return self .emitter .temp_name ()
@@ -293,16 +303,44 @@ def visit_get_attr(self, op: GetAttr) -> None:
293
303
attr_expr = self .get_attr_expr (obj , op , decl_cl )
294
304
self .emitter .emit_line ('{} = {};' .format (dest , attr_expr ))
295
305
self .emitter .emit_undefined_attr_check (
296
- attr_rtype , attr_expr , '==' , unlikely = True
306
+ attr_rtype , dest , '==' , unlikely = True
297
307
)
298
308
exc_class = 'PyExc_AttributeError'
299
- self .emitter .emit_line (
300
- 'PyErr_SetString({}, "attribute {} of {} undefined");' .format (
301
- exc_class , repr (op .attr ), repr (cl .name )))
309
+ merged_branch = None
310
+ branch = self .next_branch ()
311
+ if branch is not None :
312
+ if (branch .value is op
313
+ and branch .op == Branch .IS_ERROR
314
+ and branch .traceback_entry is not None
315
+ and not branch .negated ):
316
+ # Generate code for the following branch here to avoid
317
+ # redundant branches in the generate code.
318
+ self .emit_attribute_error (branch , cl .name , op .attr )
319
+ self .emit_line ('goto %s;' % self .label (branch .true ))
320
+ merged_branch = branch
321
+ self .emitter .emit_line ('}' )
322
+ if not merged_branch :
323
+ self .emitter .emit_line (
324
+ 'PyErr_SetString({}, "attribute {} of {} undefined");' .format (
325
+ exc_class , repr (op .attr ), repr (cl .name )))
326
+
302
327
if attr_rtype .is_refcounted :
303
- self .emitter .emit_line ('} else {' )
304
- self .emitter .emit_inc_ref (attr_expr , attr_rtype )
305
- self .emitter .emit_line ('}' )
328
+ if not merged_branch :
329
+ self .emitter .emit_line ('} else {' )
330
+ self .emitter .emit_inc_ref (dest , attr_rtype )
331
+ if merged_branch :
332
+ if merged_branch .false is not self .next_block :
333
+ self .emit_line ('goto %s;' % self .label (merged_branch .false ))
334
+ self .op_index += 1
335
+ else :
336
+ self .emitter .emit_line ('}' )
337
+
338
+ def next_branch (self ) -> Optional [Branch ]:
339
+ if self .op_index + 1 < len (self .ops ):
340
+ next_op = self .ops [self .op_index + 1 ]
341
+ if isinstance (next_op , Branch ):
342
+ return next_op
343
+ return None
306
344
307
345
def visit_set_attr (self , op : SetAttr ) -> None :
308
346
dest = self .reg (op )
@@ -603,6 +641,19 @@ def emit_traceback(self, op: Branch) -> None:
603
641
if DEBUG_ERRORS :
604
642
self .emit_line ('assert(PyErr_Occurred() != NULL && "failure w/o err!");' )
605
643
644
+ def emit_attribute_error (self , op : Branch , class_name : str , attr : str ) -> None :
645
+ assert op .traceback_entry is not None
646
+ globals_static = self .emitter .static_name ('globals' , self .module_name )
647
+ self .emit_line ('CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);' % (
648
+ self .source_path .replace ("\\ " , "\\ \\ " ),
649
+ op .traceback_entry [0 ],
650
+ class_name ,
651
+ attr ,
652
+ op .traceback_entry [1 ],
653
+ globals_static ))
654
+ if DEBUG_ERRORS :
655
+ self .emit_line ('assert(PyErr_Occurred() != NULL && "failure w/o err!");' )
656
+
606
657
def emit_signed_int_cast (self , type : RType ) -> str :
607
658
if is_tagged (type ):
608
659
return '(Py_ssize_t)'
0 commit comments