Skip to content

Commit c9d4c61

Browse files
authored
[mypyc] Optimize dunder methods (#17934)
This change gives mypyc the ability to optionally optimize dunder methods that can guarantee strict adherence to its signature typing. The optimization allows to bypass vtable for dunder methods in certain cases that are applicable. Currently, mypy has adopted the convention of accept dunder methods that return `NotImplemented` value even when its signature do not reflect this possibility. With this change and by enabling an special flag, mypyc will expect strict typing be honored and will unleash more optimizations like native call without vtable lookup for some cases on dunder method calls. For example it could avoid calls to RichCompare Python API making the code can be fully optimized in by the C compiler when some comparison with dunders are required. Example: ```python @Final class A: def __init__(self, x: i32) -> None: self.x: Final = x def __lt__(self, other: "A") -> bool: return self.x < other.x A(1) < A(2) ``` would produce: ```c char CPyDef_A_____lt__(PyObject *cpy_r_self, PyObject *cpy_r_other) { int32_t cpy_r_r0; int32_t cpy_r_r1; char cpy_r_r2; cpy_r_r0 = ((AObject *)cpy_r_self)->_x; cpy_r_r1 = ((AObject *)cpy_r_other)->_x; cpy_r_r2 = cpy_r_r0 < cpy_r_r1; return cpy_r_r2; } ... cpy_r_r29 = CPyDef_A_____lt__(cpy_r_r27, cpy_r_r28); ... ``` Instead of: ```c PyObject *CPyDef_A_____lt__(PyObject *cpy_r_self, PyObject *cpy_r_other) { int32_t cpy_r_r0; int32_t cpy_r_r1; char cpy_r_r2; PyObject *cpy_r_r3; cpy_r_r0 = ((AObject *)cpy_r_self)->_x; cpy_r_r1 = ((AObject *)cpy_r_other)->_x; cpy_r_r2 = cpy_r_r0 < cpy_r_r1; cpy_r_r3 = cpy_r_r2 ? Py_True : Py_False; CPy_INCREF(cpy_r_r3); return cpy_r_r3; } ... cpy_r_r29 = PyObject_RichCompare(cpy_r_r27, cpy_r_r28, 0); ... ``` Default behavior is kept. Tests run with both of strict typing enabled and disabled.
1 parent bd2aafc commit c9d4c61

File tree

11 files changed

+210
-72
lines changed

11 files changed

+210
-72
lines changed

mypyc/__main__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from mypyc.build import mypycify
2525
2626
setup(name='mypyc_output',
27-
ext_modules=mypycify({}, opt_level="{}", debug_level="{}"),
27+
ext_modules=mypycify({}, opt_level="{}", debug_level="{}", strict_dunder_typing={}),
2828
)
2929
"""
3030

@@ -38,10 +38,11 @@ def main() -> None:
3838

3939
opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
4040
debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1")
41+
strict_dunder_typing = bool(int(os.getenv("MYPYC_STRICT_DUNDER_TYPING", "0")))
4142

4243
setup_file = os.path.join(build_dir, "setup.py")
4344
with open(setup_file, "w") as f:
44-
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level))
45+
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level, strict_dunder_typing))
4546

4647
# We don't use run_setup (like we do in the test suite) because it throws
4748
# away the error code from distutils, and we don't care about the slight

mypyc/build.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ def mypycify(
470470
skip_cgen_input: Any | None = None,
471471
target_dir: str | None = None,
472472
include_runtime_files: bool | None = None,
473+
strict_dunder_typing: bool = False,
473474
) -> list[Extension]:
474475
"""Main entry point to building using mypyc.
475476
@@ -509,6 +510,9 @@ def mypycify(
509510
should be directly #include'd instead of linked
510511
separately in order to reduce compiler invocations.
511512
Defaults to False in multi_file mode, True otherwise.
513+
strict_dunder_typing: If True, force dunder methods to have the return type
514+
of the method strictly, which can lead to more
515+
optimization opportunities. Defaults to False.
512516
"""
513517

514518
# Figure out our configuration
@@ -519,6 +523,7 @@ def mypycify(
519523
separate=separate is not False,
520524
target_dir=target_dir,
521525
include_runtime_files=include_runtime_files,
526+
strict_dunder_typing=strict_dunder_typing,
522527
)
523528

524529
# Generate all the actual important C code

mypyc/irbuild/classdef.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
pytype_from_template_op,
8282
type_object_op,
8383
)
84+
from mypyc.subtype import is_subtype
8485

8586

8687
def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
@@ -801,30 +802,42 @@ def create_ne_from_eq(builder: IRBuilder, cdef: ClassDef) -> None:
801802

802803
def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None:
803804
"""Generate a "__ne__" method from a "__eq__" method."""
804-
with builder.enter_method(cls, "__ne__", object_rprimitive):
805-
rhs_arg = builder.add_argument("rhs", object_rprimitive)
806-
807-
# If __eq__ returns NotImplemented, then __ne__ should also
808-
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
805+
func_ir = cls.get_method("__eq__")
806+
assert func_ir
807+
eq_sig = func_ir.decl.sig
808+
strict_typing = builder.options.strict_dunders_typing
809+
with builder.enter_method(cls, "__ne__", eq_sig.ret_type):
810+
rhs_type = eq_sig.args[0].type if strict_typing else object_rprimitive
811+
rhs_arg = builder.add_argument("rhs", rhs_type)
809812
eqval = builder.add(MethodCall(builder.self(), "__eq__", [rhs_arg], line))
810-
not_implemented = builder.add(
811-
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
812-
)
813-
builder.add(
814-
Branch(
815-
builder.translate_is_op(eqval, not_implemented, "is", line),
816-
not_implemented_block,
817-
regular_block,
818-
Branch.BOOL,
819-
)
820-
)
821813

822-
builder.activate_block(regular_block)
823-
retval = builder.coerce(builder.unary_op(eqval, "not", line), object_rprimitive, line)
824-
builder.add(Return(retval))
814+
can_return_not_implemented = is_subtype(not_implemented_op.type, eq_sig.ret_type)
815+
return_bool = is_subtype(eq_sig.ret_type, bool_rprimitive)
825816

826-
builder.activate_block(not_implemented_block)
827-
builder.add(Return(not_implemented))
817+
if not strict_typing or can_return_not_implemented:
818+
# If __eq__ returns NotImplemented, then __ne__ should also
819+
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
820+
not_implemented = builder.add(
821+
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
822+
)
823+
builder.add(
824+
Branch(
825+
builder.translate_is_op(eqval, not_implemented, "is", line),
826+
not_implemented_block,
827+
regular_block,
828+
Branch.BOOL,
829+
)
830+
)
831+
builder.activate_block(regular_block)
832+
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
833+
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
834+
builder.add(Return(retval))
835+
builder.activate_block(not_implemented_block)
836+
builder.add(Return(not_implemented))
837+
else:
838+
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
839+
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
840+
builder.add(Return(retval))
828841

829842

830843
def load_non_ext_class(

mypyc/irbuild/function.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@
9696

9797

9898
def transform_func_def(builder: IRBuilder, fdef: FuncDef) -> None:
99-
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, builder.mapper.fdef_to_sig(fdef))
99+
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
100+
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, sig)
100101

101102
# If the function that was visited was a nested function, then either look it up in our
102103
# current environment or define it if it was not already defined.
@@ -113,9 +114,8 @@ def transform_overloaded_func_def(builder: IRBuilder, o: OverloadedFuncDef) -> N
113114

114115

115116
def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
116-
func_ir, func_reg = gen_func_item(
117-
builder, dec.func, dec.func.name, builder.mapper.fdef_to_sig(dec.func)
118-
)
117+
sig = builder.mapper.fdef_to_sig(dec.func, builder.options.strict_dunders_typing)
118+
func_ir, func_reg = gen_func_item(builder, dec.func, dec.func.name, sig)
119119
decorated_func: Value | None = None
120120
if func_reg:
121121
decorated_func = load_decorated_func(builder, dec.func, func_reg)
@@ -416,7 +416,8 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None
416416
# Perform the function of visit_method for methods inside extension classes.
417417
name = fdef.name
418418
class_ir = builder.mapper.type_to_ir[cdef.info]
419-
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
419+
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
420+
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
420421
builder.functions.append(func_ir)
421422

422423
if is_decorated(builder, fdef):
@@ -481,7 +482,8 @@ def handle_non_ext_method(
481482
) -> None:
482483
# Perform the function of visit_method for methods inside non-extension classes.
483484
name = fdef.name
484-
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
485+
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
486+
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
485487
assert func_reg is not None
486488
builder.functions.append(func_ir)
487489

mypyc/irbuild/ll_builder.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from mypy.argmap import map_actuals_to_formals
1616
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
17-
from mypy.operators import op_methods
17+
from mypy.operators import op_methods, unary_op_methods
1818
from mypy.types import AnyType, TypeOfAny
1919
from mypyc.common import (
2020
BITMAP_BITS,
@@ -167,6 +167,7 @@
167167
buf_init_item,
168168
fast_isinstance_op,
169169
none_object_op,
170+
not_implemented_op,
170171
var_object_size,
171172
)
172173
from mypyc.primitives.registry import (
@@ -1398,11 +1399,48 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13981399
if base_op in float_op_to_id:
13991400
return self.float_op(lreg, rreg, base_op, line)
14001401

1402+
dunder_op = self.dunder_op(lreg, rreg, op, line)
1403+
if dunder_op:
1404+
return dunder_op
1405+
14011406
primitive_ops_candidates = binary_ops.get(op, [])
14021407
target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line)
14031408
assert target, "Unsupported binary operation: %s" % op
14041409
return target
14051410

1411+
def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None:
1412+
"""
1413+
Dispatch a dunder method if applicable.
1414+
For example for `a + b` it will use `a.__add__(b)` which can lead to higher performance
1415+
due to the fact that the method could be already compiled and optimized instead of going
1416+
all the way through `PyNumber_Add(a, b)` python api (making a jump into the python DL).
1417+
"""
1418+
ltype = lreg.type
1419+
if not isinstance(ltype, RInstance):
1420+
return None
1421+
1422+
method_name = op_methods.get(op) if rreg else unary_op_methods.get(op)
1423+
if method_name is None:
1424+
return None
1425+
1426+
if not ltype.class_ir.has_method(method_name):
1427+
return None
1428+
1429+
decl = ltype.class_ir.method_decl(method_name)
1430+
if not rreg and len(decl.sig.args) != 1:
1431+
return None
1432+
1433+
if rreg and (len(decl.sig.args) != 2 or not is_subtype(rreg.type, decl.sig.args[1].type)):
1434+
return None
1435+
1436+
if rreg and is_subtype(not_implemented_op.type, decl.sig.ret_type):
1437+
# If the method is able to return NotImplemented, we should not optimize it.
1438+
# We can just let go so it will be handled through the python api.
1439+
return None
1440+
1441+
args = [rreg] if rreg else []
1442+
return self.gen_method_call(lreg, method_name, args, decl.sig.ret_type, line)
1443+
14061444
def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
14071445
"""Check if a tagged integer is a short integer.
14081446
@@ -1558,16 +1596,9 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
15581596
if isinstance(value, Float):
15591597
return Float(-value.value, value.line)
15601598
if isinstance(typ, RInstance):
1561-
if expr_op == "-":
1562-
method = "__neg__"
1563-
elif expr_op == "+":
1564-
method = "__pos__"
1565-
elif expr_op == "~":
1566-
method = "__invert__"
1567-
else:
1568-
method = ""
1569-
if method and typ.class_ir.has_method(method):
1570-
return self.gen_method_call(value, method, [], None, line)
1599+
result = self.dunder_op(value, None, expr_op, line)
1600+
if result is not None:
1601+
return result
15711602
call_c_ops_candidates = unary_ops.get(expr_op, [])
15721603
target = self.matching_call_c(call_c_ops_candidates, [value], line)
15731604
assert target, "Unsupported unary operation: %s" % expr_op

mypyc/irbuild/mapper.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType:
160160
else:
161161
return self.type_to_rtype(typ)
162162

163-
def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
163+
def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignature:
164164
if isinstance(fdef.type, CallableType):
165165
arg_types = [
166166
self.get_arg_rtype(typ, kind)
@@ -199,11 +199,14 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
199199
)
200200
]
201201

202-
# We force certain dunder methods to return objects to support letting them
203-
# return NotImplemented. It also avoids some pointless boxing and unboxing,
204-
# since tp_richcompare needs an object anyways.
205-
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
206-
ret = object_rprimitive
202+
if not strict_dunders_typing:
203+
# We force certain dunder methods to return objects to support letting them
204+
# return NotImplemented. It also avoids some pointless boxing and unboxing,
205+
# since tp_richcompare needs an object anyways.
206+
# However, it also prevents some optimizations.
207+
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
208+
ret = object_rprimitive
209+
207210
return FuncSignature(args, ret)
208211

209212
def is_native_module(self, module: str) -> bool:

0 commit comments

Comments
 (0)