Skip to content

Commit a73b113

Browse files
jairov4 authored and MINI\jairo committed
Implement dunder method optimization
1 parent 706680f commit a73b113

File tree

11 files changed

+210
-72
lines changed

11 files changed

+210
-72
lines changed

mypyc/__main__.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from mypyc.build import mypycify
2525
2626
setup(name='mypyc_output',
27-
ext_modules=mypycify({}, opt_level="{}", debug_level="{}"),
27+
ext_modules=mypycify({}, opt_level="{}", debug_level="{}", strict_dunder_typing={}),
2828
)
2929
"""
3030

@@ -38,10 +38,11 @@ def main() -> None:
3838

3939
opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
4040
debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1")
41+
strict_dunder_typing = bool(int(os.getenv("MYPYC_STRICT_DUNDER_TYPING", "0")))
4142

4243
setup_file = os.path.join(build_dir, "setup.py")
4344
with open(setup_file, "w") as f:
44-
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level))
45+
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level, strict_dunder_typing))
4546

4647
# We don't use run_setup (like we do in the test suite) because it throws
4748
# away the error code from distutils, and we don't care about the slight

mypyc/build.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -470,6 +470,7 @@ def mypycify(
470470
skip_cgen_input: Any | None = None,
471471
target_dir: str | None = None,
472472
include_runtime_files: bool | None = None,
473+
strict_dunder_typing: bool = False,
473474
) -> list[Extension]:
474475
"""Main entry point to building using mypyc.
475476
@@ -509,6 +510,9 @@ def mypycify(
509510
should be directly #include'd instead of linked
510511
separately in order to reduce compiler invocations.
511512
Defaults to False in multi_file mode, True otherwise.
513+
strict_dunder_typing: If True, force dunder methods to have the return type
514+
of the method strictly, which can lead to more
515+
optimization opportunities. Defaults to False.
512516
"""
513517

514518
# Figure out our configuration
@@ -519,6 +523,7 @@ def mypycify(
519523
separate=separate is not False,
520524
target_dir=target_dir,
521525
include_runtime_files=include_runtime_files,
526+
strict_dunder_typing=strict_dunder_typing,
522527
)
523528

524529
# Generate all the actual important C code

mypyc/irbuild/classdef.py

Lines changed: 34 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@
8181
pytype_from_template_op,
8282
type_object_op,
8383
)
84+
from mypyc.subtype import is_subtype
8485

8586

8687
def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
@@ -801,30 +802,42 @@ def create_ne_from_eq(builder: IRBuilder, cdef: ClassDef) -> None:
801802

802803
def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None:
803804
"""Generate a "__ne__" method from a "__eq__" method."""
804-
with builder.enter_method(cls, "__ne__", object_rprimitive):
805-
rhs_arg = builder.add_argument("rhs", object_rprimitive)
806-
807-
# If __eq__ returns NotImplemented, then __ne__ should also
808-
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
805+
func_ir = cls.get_method("__eq__")
806+
assert func_ir
807+
eq_sig = func_ir.decl.sig
808+
strict_typing = builder.options.strict_dunders_typing
809+
with builder.enter_method(cls, "__ne__", eq_sig.ret_type):
810+
rhs_type = eq_sig.args[0].type if strict_typing else object_rprimitive
811+
rhs_arg = builder.add_argument("rhs", rhs_type)
809812
eqval = builder.add(MethodCall(builder.self(), "__eq__", [rhs_arg], line))
810-
not_implemented = builder.add(
811-
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
812-
)
813-
builder.add(
814-
Branch(
815-
builder.translate_is_op(eqval, not_implemented, "is", line),
816-
not_implemented_block,
817-
regular_block,
818-
Branch.BOOL,
819-
)
820-
)
821813

822-
builder.activate_block(regular_block)
823-
retval = builder.coerce(builder.unary_op(eqval, "not", line), object_rprimitive, line)
824-
builder.add(Return(retval))
814+
can_return_not_implemented = is_subtype(not_implemented_op.type, eq_sig.ret_type)
815+
return_bool = is_subtype(eq_sig.ret_type, bool_rprimitive)
825816

826-
builder.activate_block(not_implemented_block)
827-
builder.add(Return(not_implemented))
817+
if not strict_typing or can_return_not_implemented:
818+
# If __eq__ returns NotImplemented, then __ne__ should also
819+
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
820+
not_implemented = builder.add(
821+
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
822+
)
823+
builder.add(
824+
Branch(
825+
builder.translate_is_op(eqval, not_implemented, "is", line),
826+
not_implemented_block,
827+
regular_block,
828+
Branch.BOOL,
829+
)
830+
)
831+
builder.activate_block(regular_block)
832+
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
833+
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
834+
builder.add(Return(retval))
835+
builder.activate_block(not_implemented_block)
836+
builder.add(Return(not_implemented))
837+
else:
838+
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
839+
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
840+
builder.add(Return(retval))
828841

829842

830843
def load_non_ext_class(

mypyc/irbuild/function.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,8 @@
9696

9797

9898
def transform_func_def(builder: IRBuilder, fdef: FuncDef) -> None:
99-
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, builder.mapper.fdef_to_sig(fdef))
99+
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
100+
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, sig)
100101

101102
# If the function that was visited was a nested function, then either look it up in our
102103
# current environment or define it if it was not already defined.
@@ -113,9 +114,8 @@ def transform_overloaded_func_def(builder: IRBuilder, o: OverloadedFuncDef) -> N
113114

114115

115116
def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
116-
func_ir, func_reg = gen_func_item(
117-
builder, dec.func, dec.func.name, builder.mapper.fdef_to_sig(dec.func)
118-
)
117+
sig = builder.mapper.fdef_to_sig(dec.func, builder.options.strict_dunders_typing)
118+
func_ir, func_reg = gen_func_item(builder, dec.func, dec.func.name, sig)
119119
decorated_func: Value | None = None
120120
if func_reg:
121121
decorated_func = load_decorated_func(builder, dec.func, func_reg)
@@ -416,7 +416,8 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None
416416
# Perform the function of visit_method for methods inside extension classes.
417417
name = fdef.name
418418
class_ir = builder.mapper.type_to_ir[cdef.info]
419-
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
419+
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
420+
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
420421
builder.functions.append(func_ir)
421422

422423
if is_decorated(builder, fdef):
@@ -481,7 +482,8 @@ def handle_non_ext_method(
481482
) -> None:
482483
# Perform the function of visit_method for methods inside non-extension classes.
483484
name = fdef.name
484-
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
485+
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
486+
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
485487
assert func_reg is not None
486488
builder.functions.append(func_ir)
487489

mypyc/irbuild/ll_builder.py

Lines changed: 42 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
from mypy.argmap import map_actuals_to_formals
1616
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
17-
from mypy.operators import op_methods
17+
from mypy.operators import op_methods, unary_op_methods
1818
from mypy.types import AnyType, TypeOfAny
1919
from mypyc.common import (
2020
BITMAP_BITS,
@@ -167,6 +167,7 @@
167167
buf_init_item,
168168
fast_isinstance_op,
169169
none_object_op,
170+
not_implemented_op,
170171
var_object_size,
171172
)
172173
from mypyc.primitives.registry import (
@@ -1398,11 +1399,48 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13981399
if base_op in float_op_to_id:
13991400
return self.float_op(lreg, rreg, base_op, line)
14001401

1402+
dunder_op = self.dunder_op(lreg, rreg, op, line)
1403+
if dunder_op:
1404+
return dunder_op
1405+
14011406
primitive_ops_candidates = binary_ops.get(op, [])
14021407
target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line)
14031408
assert target, "Unsupported binary operation: %s" % op
14041409
return target
14051410

1411+
def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None:
1412+
"""
1413+
Dispatch a dunder method if applicable.
1414+
For example for `a + b` it will use `a.__add__(b)` which can lead to higher performance
1415+
due to the fact that the method could be already compiled and optimized instead of going
1416+
all the way through `PyNumber_Add(a, b)` python api (making a jump into the python DL).
1417+
"""
1418+
ltype = lreg.type
1419+
if not isinstance(ltype, RInstance):
1420+
return None
1421+
1422+
method_name = op_methods.get(op) if rreg else unary_op_methods.get(op)
1423+
if method_name is None:
1424+
return None
1425+
1426+
if not ltype.class_ir.has_method(method_name):
1427+
return None
1428+
1429+
decl = ltype.class_ir.method_decl(method_name)
1430+
if not rreg and len(decl.sig.args) != 1:
1431+
return None
1432+
1433+
if rreg and (len(decl.sig.args) != 2 or not is_subtype(rreg.type, decl.sig.args[1].type)):
1434+
return None
1435+
1436+
if rreg and is_subtype(not_implemented_op.type, decl.sig.ret_type):
1437+
# If the method is able to return NotImplemented, we should not optimize it.
1438+
# We can just let go so it will be handled through the python api.
1439+
return None
1440+
1441+
args = [rreg] if rreg else []
1442+
return self.gen_method_call(lreg, method_name, args, decl.sig.ret_type, line)
1443+
14061444
def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
14071445
"""Check if a tagged integer is a short integer.
14081446
@@ -1558,16 +1596,9 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
15581596
if isinstance(value, Float):
15591597
return Float(-value.value, value.line)
15601598
if isinstance(typ, RInstance):
1561-
if expr_op == "-":
1562-
method = "__neg__"
1563-
elif expr_op == "+":
1564-
method = "__pos__"
1565-
elif expr_op == "~":
1566-
method = "__invert__"
1567-
else:
1568-
method = ""
1569-
if method and typ.class_ir.has_method(method):
1570-
return self.gen_method_call(value, method, [], None, line)
1599+
result = self.dunder_op(value, None, expr_op, line)
1600+
if result is not None:
1601+
return result
15711602
call_c_ops_candidates = unary_ops.get(expr_op, [])
15721603
target = self.matching_call_c(call_c_ops_candidates, [value], line)
15731604
assert target, "Unsupported unary operation: %s" % expr_op

mypyc/irbuild/mapper.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType:
160160
else:
161161
return self.type_to_rtype(typ)
162162

163-
def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
163+
def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignature:
164164
if isinstance(fdef.type, CallableType):
165165
arg_types = [
166166
self.get_arg_rtype(typ, kind)
@@ -199,11 +199,14 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
199199
)
200200
]
201201

202-
# We force certain dunder methods to return objects to support letting them
203-
# return NotImplemented. It also avoids some pointless boxing and unboxing,
204-
# since tp_richcompare needs an object anyways.
205-
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
206-
ret = object_rprimitive
202+
if not strict_dunders_typing:
203+
# We force certain dunder methods to return objects to support letting them
204+
# return NotImplemented. It also avoids some pointless boxing and unboxing,
205+
# since tp_richcompare needs an object anyways.
206+
# However, it also prevents some optimizations.
207+
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
208+
ret = object_rprimitive
209+
207210
return FuncSignature(args, ret)
208211

209212
def is_native_module(self, module: str) -> bool:

0 commit comments

Comments
 (0)