Skip to content

[mypyc] Optimize dunder methods #17934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mypyc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mypyc.build import mypycify

setup(name='mypyc_output',
ext_modules=mypycify({}, opt_level="{}", debug_level="{}"),
ext_modules=mypycify({}, opt_level="{}", debug_level="{}", strict_dunder_typing={}),
)
"""

Expand All @@ -38,10 +38,11 @@ def main() -> None:

opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1")
strict_dunder_typing = bool(int(os.getenv("MYPYC_STRICT_DUNDER_TYPING", "0")))

setup_file = os.path.join(build_dir, "setup.py")
with open(setup_file, "w") as f:
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level))
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level, strict_dunder_typing))

# We don't use run_setup (like we do in the test suite) because it throws
# away the error code from distutils, and we don't care about the slight
Expand Down
5 changes: 5 additions & 0 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def mypycify(
skip_cgen_input: Any | None = None,
target_dir: str | None = None,
include_runtime_files: bool | None = None,
strict_dunder_typing: bool = False,
) -> list[Extension]:
"""Main entry point to building using mypyc.

Expand Down Expand Up @@ -509,6 +510,9 @@ def mypycify(
should be directly #include'd instead of linked
separately in order to reduce compiler invocations.
Defaults to False in multi_file mode, True otherwise.
strict_dunder_typing: If True, dunder methods keep their declared return type
instead of being forced to return a generic object (as is
otherwise done for comparison methods such as `__eq__`),
which can lead to more optimization opportunities. Defaults to False.
"""

# Figure out our configuration
Expand All @@ -519,6 +523,7 @@ def mypycify(
separate=separate is not False,
target_dir=target_dir,
include_runtime_files=include_runtime_files,
strict_dunder_typing=strict_dunder_typing,
)

# Generate all the actual important C code
Expand Down
55 changes: 34 additions & 21 deletions mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
pytype_from_template_op,
type_object_op,
)
from mypyc.subtype import is_subtype


def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
Expand Down Expand Up @@ -801,30 +802,42 @@ def create_ne_from_eq(builder: IRBuilder, cdef: ClassDef) -> None:

def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None:
"""Generate a "__ne__" method from a "__eq__" method."""
with builder.enter_method(cls, "__ne__", object_rprimitive):
rhs_arg = builder.add_argument("rhs", object_rprimitive)

# If __eq__ returns NotImplemented, then __ne__ should also
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
func_ir = cls.get_method("__eq__")
assert func_ir
eq_sig = func_ir.decl.sig
strict_typing = builder.options.strict_dunders_typing
with builder.enter_method(cls, "__ne__", eq_sig.ret_type):
rhs_type = eq_sig.args[0].type if strict_typing else object_rprimitive
rhs_arg = builder.add_argument("rhs", rhs_type)
eqval = builder.add(MethodCall(builder.self(), "__eq__", [rhs_arg], line))
not_implemented = builder.add(
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
)
builder.add(
Branch(
builder.translate_is_op(eqval, not_implemented, "is", line),
not_implemented_block,
regular_block,
Branch.BOOL,
)
)

builder.activate_block(regular_block)
retval = builder.coerce(builder.unary_op(eqval, "not", line), object_rprimitive, line)
builder.add(Return(retval))
can_return_not_implemented = is_subtype(not_implemented_op.type, eq_sig.ret_type)
return_bool = is_subtype(eq_sig.ret_type, bool_rprimitive)

builder.activate_block(not_implemented_block)
builder.add(Return(not_implemented))
if not strict_typing or can_return_not_implemented:
# If __eq__ returns NotImplemented, then __ne__ should also
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
not_implemented = builder.add(
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
)
builder.add(
Branch(
builder.translate_is_op(eqval, not_implemented, "is", line),
not_implemented_block,
regular_block,
Branch.BOOL,
)
)
builder.activate_block(regular_block)
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
builder.add(Return(retval))
builder.activate_block(not_implemented_block)
builder.add(Return(not_implemented))
else:
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
builder.add(Return(retval))


def load_non_ext_class(
Expand Down
14 changes: 8 additions & 6 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@


def transform_func_def(builder: IRBuilder, fdef: FuncDef) -> None:
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, builder.mapper.fdef_to_sig(fdef))
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, sig)

# If the function that was visited was a nested function, then either look it up in our
# current environment or define it if it was not already defined.
Expand All @@ -113,9 +114,8 @@ def transform_overloaded_func_def(builder: IRBuilder, o: OverloadedFuncDef) -> N


def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
func_ir, func_reg = gen_func_item(
builder, dec.func, dec.func.name, builder.mapper.fdef_to_sig(dec.func)
)
sig = builder.mapper.fdef_to_sig(dec.func, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, dec.func, dec.func.name, sig)
decorated_func: Value | None = None
if func_reg:
decorated_func = load_decorated_func(builder, dec.func, func_reg)
Expand Down Expand Up @@ -416,7 +416,8 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None
# Perform the function of visit_method for methods inside extension classes.
name = fdef.name
class_ir = builder.mapper.type_to_ir[cdef.info]
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
builder.functions.append(func_ir)

if is_decorated(builder, fdef):
Expand Down Expand Up @@ -481,7 +482,8 @@ def handle_non_ext_method(
) -> None:
# Perform the function of visit_method for methods inside non-extension classes.
name = fdef.name
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
assert func_reg is not None
builder.functions.append(func_ir)

Expand Down
53 changes: 42 additions & 11 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
from mypy.operators import op_methods
from mypy.operators import op_methods, unary_op_methods
from mypy.types import AnyType, TypeOfAny
from mypyc.common import (
BITMAP_BITS,
Expand Down Expand Up @@ -167,6 +167,7 @@
buf_init_item,
fast_isinstance_op,
none_object_op,
not_implemented_op,
var_object_size,
)
from mypyc.primitives.registry import (
Expand Down Expand Up @@ -1398,11 +1399,48 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
if base_op in float_op_to_id:
return self.float_op(lreg, rreg, base_op, line)

dunder_op = self.dunder_op(lreg, rreg, op, line)
if dunder_op:
return dunder_op

primitive_ops_candidates = binary_ops.get(op, [])
target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line)
assert target, "Unsupported binary operation: %s" % op
return target

def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None:
    """Dispatch an operator to a direct dunder method call, if applicable.

    For example, `a + b` is turned into a call to `a.__add__(b)`. This can
    be faster than going through the generic Python C API (such as
    `PyNumber_Add(a, b)`, which requires a call into the Python shared
    library), since the method may already be compiled and optimized.

    Args:
        lreg: The left operand, or the only operand of a unary operation.
        rreg: The right operand, or None for a unary operation.
        op: The operator, looked up in `op_methods` (binary operation) or
            `unary_op_methods` (unary operation) to find the method name.
        line: Line number passed on to the generated method call.

    Returns:
        The value produced by the generated method call, or None if the
        operation can't be dispatched this way, in which case the caller
        has to fall back to another implementation.
    """
    ltype = lreg.type
    if not isinstance(ltype, RInstance):
        return None

    # Test the optional operand against None explicitly rather than relying
    # on the truthiness of a Value object.
    if rreg is None:
        method_name = unary_op_methods.get(op)
    else:
        method_name = op_methods.get(op)
    if method_name is None:
        return None

    if not ltype.class_ir.has_method(method_name):
        return None

    decl = ltype.class_ir.method_decl(method_name)
    args: list[Value]
    if rreg is None:
        # A unary dunder method takes only "self".
        if len(decl.sig.args) != 1:
            return None
        args = []
    else:
        # A binary dunder method takes "self" and one operand that must
        # accept the type of the right operand.
        if len(decl.sig.args) != 2 or not is_subtype(rreg.type, decl.sig.args[1].type):
            return None
        if is_subtype(not_implemented_op.type, decl.sig.ret_type):
            # If the method is able to return NotImplemented, we should not optimize it.
            # Give up here so that the operation is handled through the Python API.
            return None
        args = [rreg]

    return self.gen_method_call(lreg, method_name, args, decl.sig.ret_type, line)

def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
"""Check if a tagged integer is a short integer.

Expand Down Expand Up @@ -1558,16 +1596,9 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
if isinstance(value, Float):
return Float(-value.value, value.line)
if isinstance(typ, RInstance):
if expr_op == "-":
method = "__neg__"
elif expr_op == "+":
method = "__pos__"
elif expr_op == "~":
method = "__invert__"
else:
method = ""
if method and typ.class_ir.has_method(method):
return self.gen_method_call(value, method, [], None, line)
result = self.dunder_op(value, None, expr_op, line)
if result is not None:
return result
call_c_ops_candidates = unary_ops.get(expr_op, [])
target = self.matching_call_c(call_c_ops_candidates, [value], line)
assert target, "Unsupported unary operation: %s" % expr_op
Expand Down
15 changes: 9 additions & 6 deletions mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType:
else:
return self.type_to_rtype(typ)

def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignature:
if isinstance(fdef.type, CallableType):
arg_types = [
self.get_arg_rtype(typ, kind)
Expand Down Expand Up @@ -199,11 +199,14 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
)
]

# We force certain dunder methods to return objects to support letting them
# return NotImplemented. It also avoids some pointless boxing and unboxing,
# since tp_richcompare needs an object anyways.
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
ret = object_rprimitive
if not strict_dunders_typing:
# We force certain dunder methods to return objects to support letting them
# return NotImplemented. It also avoids some pointless boxing and unboxing,
# since tp_richcompare needs an object anyways.
# However, it also prevents some optimizations.
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
ret = object_rprimitive

return FuncSignature(args, ret)

def is_native_module(self, module: str) -> bool:
Expand Down
Loading
Loading