Skip to content

Commit cbd96f9

Browse files
authored
[mypyc] Optimize calls to final classes (#17886)
Fixes #9612 This change allows to gain more efficiency where classes are annotated with `@final` bypassing entirely the vtable for method calls and property accessors. For example: In ```python @Final class Vector: __slots__ = ("_x", "_y") def __init__(self, x: i32, y: i32) -> None: self._x = x self._y = y @Property def y(self) -> i32: return self._y def test_vector() -> None: v3 = Vector(1, 2) assert v3.y == 2 ``` The call will produce: ```c ... cpy_r_r6 = CPyDef_Vector___y(cpy_r_r0); ... ``` Instead of: ```c ... cpy_r_r1 = CPY_GET_ATTR(cpy_r_r0, CPyType_Vector, 2, farm_rush___engine___vectors2___VectorObject, int32_t); /* y */ ... ``` (which uses vtable)
1 parent 395108d commit cbd96f9

File tree

8 files changed

+125
-26
lines changed

8 files changed

+125
-26
lines changed

mypyc/codegen/emitclass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ def generate_setup_for_class(
571571
emitter.emit_line("}")
572572
else:
573573
emitter.emit_line(f"self->vtable = {vtable_name};")
574+
574575
for i in range(0, len(cl.bitmap_attrs), BITMAP_BITS):
575576
field = emitter.bitmap_field(i)
576577
emitter.emit_line(f"self->{field} = 0;")

mypyc/codegen/emitfunc.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
from mypyc.ir.pprint import generate_names_for_ir
7373
from mypyc.ir.rtypes import (
7474
RArray,
75+
RInstance,
7576
RStruct,
7677
RTuple,
7778
RType,
@@ -362,20 +363,23 @@ def visit_get_attr(self, op: GetAttr) -> None:
362363
prefer_method = cl.is_trait and attr_rtype.error_overlap
363364
if cl.get_method(op.attr, prefer_method=prefer_method):
364365
# Properties are essentially methods, so use vtable access for them.
365-
version = "_TRAIT" if cl.is_trait else ""
366-
self.emit_line(
367-
"%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */"
368-
% (
369-
dest,
370-
version,
371-
obj,
372-
self.emitter.type_struct_name(rtype.class_ir),
373-
rtype.getter_index(op.attr),
374-
rtype.struct_name(self.names),
375-
self.ctype(rtype.attr_type(op.attr)),
376-
op.attr,
366+
if cl.is_method_final(op.attr):
367+
self.emit_method_call(f"{dest} = ", op.obj, op.attr, [])
368+
else:
369+
version = "_TRAIT" if cl.is_trait else ""
370+
self.emit_line(
371+
"%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */"
372+
% (
373+
dest,
374+
version,
375+
obj,
376+
self.emitter.type_struct_name(rtype.class_ir),
377+
rtype.getter_index(op.attr),
378+
rtype.struct_name(self.names),
379+
self.ctype(rtype.attr_type(op.attr)),
380+
op.attr,
381+
)
377382
)
378-
)
379383
else:
380384
# Otherwise, use direct or offset struct access.
381385
attr_expr = self.get_attr_expr(obj, op, decl_cl)
@@ -529,11 +533,13 @@ def visit_call(self, op: Call) -> None:
529533
def visit_method_call(self, op: MethodCall) -> None:
530534
"""Call native method."""
531535
dest = self.get_dest_assign(op)
532-
obj = self.reg(op.obj)
536+
self.emit_method_call(dest, op.obj, op.method, op.args)
533537

534-
rtype = op.receiver_type
538+
def emit_method_call(self, dest: str, op_obj: Value, name: str, op_args: list[Value]) -> None:
539+
obj = self.reg(op_obj)
540+
rtype = op_obj.type
541+
assert isinstance(rtype, RInstance)
535542
class_ir = rtype.class_ir
536-
name = op.method
537543
method = rtype.class_ir.get_method(name)
538544
assert method is not None
539545

@@ -547,7 +553,7 @@ def visit_method_call(self, op: MethodCall) -> None:
547553
if method.decl.kind == FUNC_STATICMETHOD
548554
else [f"(PyObject *)Py_TYPE({obj})"] if method.decl.kind == FUNC_CLASSMETHOD else [obj]
549555
)
550-
args = ", ".join(obj_args + [self.reg(arg) for arg in op.args])
556+
args = ", ".join(obj_args + [self.reg(arg) for arg in op_args])
551557
mtype = native_function_type(method, self.emitter)
552558
version = "_TRAIT" if rtype.class_ir.is_trait else ""
553559
if is_direct:
@@ -567,7 +573,7 @@ def visit_method_call(self, op: MethodCall) -> None:
567573
rtype.struct_name(self.names),
568574
mtype,
569575
args,
570-
op.method,
576+
name,
571577
)
572578
)
573579

mypyc/ir/class_ir.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,15 @@ def __init__(
9393
is_generated: bool = False,
9494
is_abstract: bool = False,
9595
is_ext_class: bool = True,
96+
is_final_class: bool = False,
9697
) -> None:
9798
self.name = name
9899
self.module_name = module_name
99100
self.is_trait = is_trait
100101
self.is_generated = is_generated
101102
self.is_abstract = is_abstract
102103
self.is_ext_class = is_ext_class
104+
self.is_final_class = is_final_class
103105
# An augmented class has additional methods separate from what mypyc generates.
104106
# Right now the only one is dataclasses.
105107
self.is_augmented = False
@@ -199,7 +201,8 @@ def __repr__(self) -> str:
199201
"ClassIR("
200202
"name={self.name}, module_name={self.module_name}, "
201203
"is_trait={self.is_trait}, is_generated={self.is_generated}, "
202-
"is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}"
204+
"is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}, "
205+
"is_final_class={self.is_final_class}"
203206
")".format(self=self)
204207
)
205208

@@ -248,8 +251,7 @@ def has_method(self, name: str) -> bool:
248251
def is_method_final(self, name: str) -> bool:
249252
subs = self.subclasses()
250253
if subs is None:
251-
# TODO: Look at the final attribute!
252-
return False
254+
return self.is_final_class
253255

254256
if self.has_method(name):
255257
method_decl = self.method_decl(name)
@@ -349,6 +351,7 @@ def serialize(self) -> JsonDict:
349351
"is_abstract": self.is_abstract,
350352
"is_generated": self.is_generated,
351353
"is_augmented": self.is_augmented,
354+
"is_final_class": self.is_final_class,
352355
"inherits_python": self.inherits_python,
353356
"has_dict": self.has_dict,
354357
"allow_interpreted_subclasses": self.allow_interpreted_subclasses,
@@ -404,6 +407,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
404407
ir.is_abstract = data["is_abstract"]
405408
ir.is_ext_class = data["is_ext_class"]
406409
ir.is_augmented = data["is_augmented"]
410+
ir.is_final_class = data["is_final_class"]
407411
ir.inherits_python = data["inherits_python"]
408412
ir.has_dict = data["has_dict"]
409413
ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"]

mypyc/ir/rtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class RType:
6464

6565
@abstractmethod
6666
def accept(self, visitor: RTypeVisitor[T]) -> T:
67-
raise NotImplementedError
67+
raise NotImplementedError()
6868

6969
def short_name(self) -> str:
7070
return short_name(self.name)

mypyc/irbuild/ll_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,7 +1889,7 @@ def primitive_op(
18891889
# Does this primitive map into calling a Python C API
18901890
# or an internal mypyc C API function?
18911891
if desc.c_function_name:
1892-
# TODO: Generate PrimitiOps here and transform them into CallC
1892+
# TODO: Generate PrimitiveOps here and transform them into CallC
18931893
# ops only later in the lowering pass
18941894
c_desc = CFunctionDescription(
18951895
desc.name,
@@ -1908,7 +1908,7 @@ def primitive_op(
19081908
)
19091909
return self.call_c(c_desc, args, line, result_type)
19101910

1911-
# This primitve gets transformed in a lowering pass to
1911+
# This primitive gets transformed in a lowering pass to
19121912
# lower-level IR ops using a custom transform function.
19131913

19141914
coerced = []

mypyc/irbuild/prepare.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ def build_type_map(
8181
# references even if there are import cycles.
8282
for module, cdef in classes:
8383
class_ir = ClassIR(
84-
cdef.name, module.fullname, is_trait(cdef), is_abstract=cdef.info.is_abstract
84+
cdef.name,
85+
module.fullname,
86+
is_trait(cdef),
87+
is_abstract=cdef.info.is_abstract,
88+
is_final_class=cdef.info.is_final,
8589
)
8690
class_ir.is_ext_class = is_extension_class(cdef)
8791
if class_ir.is_ext_class:

mypyc/irbuild/util.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,16 @@
2727
UnaryExpr,
2828
Var,
2929
)
30+
from mypy.semanal import refers_to_fullname
31+
from mypy.types import FINAL_DECORATOR_NAMES
3032

3133
DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"}
3234

3335

36+
def is_final_decorator(d: Expression) -> bool:
37+
return refers_to_fullname(d, FINAL_DECORATOR_NAMES)
38+
39+
3440
def is_trait_decorator(d: Expression) -> bool:
3541
return isinstance(d, RefExpr) and d.fullname == "mypy_extensions.trait"
3642

@@ -119,7 +125,10 @@ def get_mypyc_attrs(stmt: ClassDef | Decorator) -> dict[str, Any]:
119125

120126
def is_extension_class(cdef: ClassDef) -> bool:
121127
if any(
122-
not is_trait_decorator(d) and not is_dataclass_decorator(d) and not get_mypyc_attr_call(d)
128+
not is_trait_decorator(d)
129+
and not is_dataclass_decorator(d)
130+
and not get_mypyc_attr_call(d)
131+
and not is_final_decorator(d)
123132
for d in cdef.decorators
124133
):
125134
return False

mypyc/test-data/run-classes.test

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2519,3 +2519,78 @@ class C:
25192519
def test_final_attribute() -> None:
25202520
assert C.A == -1
25212521
assert C.a == [-1]
2522+
2523+
[case testClassWithFinalDecorator]
2524+
from typing import final
2525+
2526+
@final
2527+
class C:
2528+
def a(self) -> int:
2529+
return 1
2530+
2531+
def test_class_final_attribute() -> None:
2532+
assert C().a() == 1
2533+
2534+
2535+
[case testClassWithFinalDecoratorCtor]
2536+
from typing import final
2537+
2538+
@final
2539+
class C:
2540+
def __init__(self) -> None:
2541+
self.a = 1
2542+
2543+
def b(self) -> int:
2544+
return 2
2545+
2546+
@property
2547+
def c(self) -> int:
2548+
return 3
2549+
2550+
def test_class_final_attribute() -> None:
2551+
assert C().a == 1
2552+
assert C().b() == 2
2553+
assert C().c == 3
2554+
2555+
[case testClassWithFinalDecoratorInheritedWithProperties]
2556+
from typing import final
2557+
2558+
class B:
2559+
def a(self) -> int:
2560+
return 2
2561+
2562+
@property
2563+
def b(self) -> int:
2564+
return self.a() + 2
2565+
2566+
@property
2567+
def c(self) -> int:
2568+
return 3
2569+
2570+
def test_class_final_attribute_basic() -> None:
2571+
assert B().a() == 2
2572+
assert B().b == 4
2573+
assert B().c == 3
2574+
2575+
@final
2576+
class C(B):
2577+
def a(self) -> int:
2578+
return 1
2579+
2580+
@property
2581+
def b(self) -> int:
2582+
return self.a() + 1
2583+
2584+
def fn(cl: B) -> int:
2585+
return cl.a()
2586+
2587+
def test_class_final_attribute_inherited() -> None:
2588+
assert C().a() == 1
2589+
assert fn(C()) == 1
2590+
assert B().a() == 2
2591+
assert fn(B()) == 2
2592+
2593+
assert B().b == 4
2594+
assert C().b == 2
2595+
assert B().c == 3
2596+
assert C().c == 3

0 commit comments

Comments
 (0)