Skip to content

[mypyc] Optimize calls to final classes #17886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ def generate_setup_for_class(
emitter.emit_line("}")
else:
emitter.emit_line(f"self->vtable = {vtable_name};")

for i in range(0, len(cl.bitmap_attrs), BITMAP_BITS):
field = emitter.bitmap_field(i)
emitter.emit_line(f"self->{field} = 0;")
Expand Down
42 changes: 24 additions & 18 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from mypyc.ir.pprint import generate_names_for_ir
from mypyc.ir.rtypes import (
RArray,
RInstance,
RStruct,
RTuple,
RType,
Expand Down Expand Up @@ -362,20 +363,23 @@ def visit_get_attr(self, op: GetAttr) -> None:
prefer_method = cl.is_trait and attr_rtype.error_overlap
if cl.get_method(op.attr, prefer_method=prefer_method):
# Properties are essentially methods, so use vtable access for them.
version = "_TRAIT" if cl.is_trait else ""
self.emit_line(
"%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */"
% (
dest,
version,
obj,
self.emitter.type_struct_name(rtype.class_ir),
rtype.getter_index(op.attr),
rtype.struct_name(self.names),
self.ctype(rtype.attr_type(op.attr)),
op.attr,
if cl.is_method_final(op.attr):
self.emit_method_call(f"{dest} = ", op.obj, op.attr, [])
else:
version = "_TRAIT" if cl.is_trait else ""
self.emit_line(
"%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */"
% (
dest,
version,
obj,
self.emitter.type_struct_name(rtype.class_ir),
rtype.getter_index(op.attr),
rtype.struct_name(self.names),
self.ctype(rtype.attr_type(op.attr)),
op.attr,
)
)
)
else:
# Otherwise, use direct or offset struct access.
attr_expr = self.get_attr_expr(obj, op, decl_cl)
Expand Down Expand Up @@ -529,11 +533,13 @@ def visit_call(self, op: Call) -> None:
def visit_method_call(self, op: MethodCall) -> None:
"""Call native method."""
dest = self.get_dest_assign(op)
obj = self.reg(op.obj)
self.emit_method_call(dest, op.obj, op.method, op.args)

rtype = op.receiver_type
def emit_method_call(self, dest: str, op_obj: Value, name: str, op_args: list[Value]) -> None:
obj = self.reg(op_obj)
rtype = op_obj.type
assert isinstance(rtype, RInstance)
class_ir = rtype.class_ir
name = op.method
method = rtype.class_ir.get_method(name)
assert method is not None

Expand All @@ -547,7 +553,7 @@ def visit_method_call(self, op: MethodCall) -> None:
if method.decl.kind == FUNC_STATICMETHOD
else [f"(PyObject *)Py_TYPE({obj})"] if method.decl.kind == FUNC_CLASSMETHOD else [obj]
)
args = ", ".join(obj_args + [self.reg(arg) for arg in op.args])
args = ", ".join(obj_args + [self.reg(arg) for arg in op_args])
mtype = native_function_type(method, self.emitter)
version = "_TRAIT" if rtype.class_ir.is_trait else ""
if is_direct:
Expand All @@ -567,7 +573,7 @@ def visit_method_call(self, op: MethodCall) -> None:
rtype.struct_name(self.names),
mtype,
args,
op.method,
name,
)
)

Expand Down
10 changes: 7 additions & 3 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ def __init__(
is_generated: bool = False,
is_abstract: bool = False,
is_ext_class: bool = True,
is_final_class: bool = False,
) -> None:
self.name = name
self.module_name = module_name
self.is_trait = is_trait
self.is_generated = is_generated
self.is_abstract = is_abstract
self.is_ext_class = is_ext_class
self.is_final_class = is_final_class
# An augmented class has additional methods separate from what mypyc generates.
# Right now the only one is dataclasses.
self.is_augmented = False
Expand Down Expand Up @@ -199,7 +201,8 @@ def __repr__(self) -> str:
"ClassIR("
"name={self.name}, module_name={self.module_name}, "
"is_trait={self.is_trait}, is_generated={self.is_generated}, "
"is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}"
"is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}, "
"is_final_class={self.is_final_class}"
")".format(self=self)
)

Expand Down Expand Up @@ -248,8 +251,7 @@ def has_method(self, name: str) -> bool:
def is_method_final(self, name: str) -> bool:
subs = self.subclasses()
if subs is None:
# TODO: Look at the final attribute!
return False
return self.is_final_class

if self.has_method(name):
method_decl = self.method_decl(name)
Expand Down Expand Up @@ -349,6 +351,7 @@ def serialize(self) -> JsonDict:
"is_abstract": self.is_abstract,
"is_generated": self.is_generated,
"is_augmented": self.is_augmented,
"is_final_class": self.is_final_class,
"inherits_python": self.inherits_python,
"has_dict": self.has_dict,
"allow_interpreted_subclasses": self.allow_interpreted_subclasses,
Expand Down Expand Up @@ -404,6 +407,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
ir.is_abstract = data["is_abstract"]
ir.is_ext_class = data["is_ext_class"]
ir.is_augmented = data["is_augmented"]
ir.is_final_class = data["is_final_class"]
ir.inherits_python = data["inherits_python"]
ir.has_dict = data["has_dict"]
ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"]
Expand Down
2 changes: 1 addition & 1 deletion mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class RType:

@abstractmethod
def accept(self, visitor: RTypeVisitor[T]) -> T:
raise NotImplementedError
raise NotImplementedError()

def short_name(self) -> str:
return short_name(self.name)
Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,7 +1889,7 @@ def primitive_op(
# Does this primitive map into calling a Python C API
# or an internal mypyc C API function?
if desc.c_function_name:
# TODO: Generate PrimitiOps here and transform them into CallC
# TODO: Generate PrimitiveOps here and transform them into CallC
# ops only later in the lowering pass
c_desc = CFunctionDescription(
desc.name,
Expand All @@ -1908,7 +1908,7 @@ def primitive_op(
)
return self.call_c(c_desc, args, line, result_type)

# This primitve gets transformed in a lowering pass to
# This primitive gets transformed in a lowering pass to
# lower-level IR ops using a custom transform function.

coerced = []
Expand Down
6 changes: 5 additions & 1 deletion mypyc/irbuild/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ def build_type_map(
# references even if there are import cycles.
for module, cdef in classes:
class_ir = ClassIR(
cdef.name, module.fullname, is_trait(cdef), is_abstract=cdef.info.is_abstract
cdef.name,
module.fullname,
is_trait(cdef),
is_abstract=cdef.info.is_abstract,
is_final_class=cdef.info.is_final,
)
class_ir.is_ext_class = is_extension_class(cdef)
if class_ir.is_ext_class:
Expand Down
11 changes: 10 additions & 1 deletion mypyc/irbuild/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@
UnaryExpr,
Var,
)
from mypy.semanal import refers_to_fullname
from mypy.types import FINAL_DECORATOR_NAMES

DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"}


def is_final_decorator(d: Expression) -> bool:
return refers_to_fullname(d, FINAL_DECORATOR_NAMES)


def is_trait_decorator(d: Expression) -> bool:
return isinstance(d, RefExpr) and d.fullname == "mypy_extensions.trait"

Expand Down Expand Up @@ -119,7 +125,10 @@ def get_mypyc_attrs(stmt: ClassDef | Decorator) -> dict[str, Any]:

def is_extension_class(cdef: ClassDef) -> bool:
if any(
not is_trait_decorator(d) and not is_dataclass_decorator(d) and not get_mypyc_attr_call(d)
not is_trait_decorator(d)
and not is_dataclass_decorator(d)
and not get_mypyc_attr_call(d)
and not is_final_decorator(d)
for d in cdef.decorators
):
return False
Expand Down
75 changes: 75 additions & 0 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -2503,3 +2503,78 @@ class C:
def test_final_attribute() -> None:
assert C.A == -1
assert C.a == [-1]

[case testClassWithFinalDecorator]
from typing import final

@final
class C:
def a(self) -> int:
return 1

def test_class_final_attribute() -> None:
assert C().a() == 1


[case testClassWithFinalDecoratorCtor]
from typing import final

@final
class C:
def __init__(self) -> None:
self.a = 1

def b(self) -> int:
return 2

@property
def c(self) -> int:
return 3

def test_class_final_attribute() -> None:
assert C().a == 1
assert C().b() == 2
assert C().c == 3

[case testClassWithFinalDecoratorInheritedWithProperties]
from typing import final

class B:
def a(self) -> int:
return 2

@property
def b(self) -> int:
return self.a() + 2

@property
def c(self) -> int:
return 3

def test_class_final_attribute_basic() -> None:
assert B().a() == 2
assert B().b == 4
assert B().c == 3

@final
class C(B):
def a(self) -> int:
return 1

@property
def b(self) -> int:
return self.a() + 1

def fn(cl: B) -> int:
return cl.a()

def test_class_final_attribute_inherited() -> None:
assert C().a() == 1
assert fn(C()) == 1
assert B().a() == 2
assert fn(B()) == 2

assert B().b == 4
assert C().b == 2
assert B().c == 3
assert C().c == 3
Loading