Skip to content

Commit 7d69ce2

Browse files
authored
[mypyc] Improve support for compiling singledispatch (#10795)
This makes several improvements to the support for compiling singledispatch that was introduced in #10753 by: * Making sure registered implementations defined later in a file take precedence when multiple overlap * Using non-native calls to registered implementations to allow for adding other decorators to registered functions (099b047) * Creating a separate function that dispatches to the correct implementation instead of adding code to dispatch to one of the registered implementations directly into the main singledispatch function, allowing the main singledispatch function to be a generator (59555e4) * Avoiding a compilation error when trying to dispatch on an ABC (2d40421)
1 parent 66cae4b commit 7d69ce2

File tree

6 files changed

+183
-64
lines changed

6 files changed

+183
-64
lines changed

mypyc/irbuild/context.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ def __init__(self,
2222
is_nested: bool = False,
2323
contains_nested: bool = False,
2424
is_decorated: bool = False,
25-
in_non_ext: bool = False,
26-
is_singledispatch: bool = False) -> None:
25+
in_non_ext: bool = False) -> None:
2726
self.fitem = fitem
2827
self.name = name if not is_decorated else decorator_helper_name(name)
2928
self.class_name = class_name
@@ -48,7 +47,6 @@ def __init__(self,
4847
self.contains_nested = contains_nested
4948
self.is_decorated = is_decorated
5049
self.in_non_ext = in_non_ext
51-
self.is_singledispatch = is_singledispatch
5250

5351
# TODO: add field for ret_type: RType = none_rprimitive
5452

mypyc/irbuild/function.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
setup_func_for_recursive_call
5454
)
5555

56+
from mypyc.primitives.registry import builtin_names
57+
5658

5759
# Top-level transform functions
5860

@@ -219,9 +221,12 @@ def c() -> None:
219221
in_non_ext = not ir.is_ext_class
220222
class_name = cdef.name
221223

222-
builder.enter(FuncInfo(fitem, name, class_name, gen_func_ns(builder),
223-
is_nested, contains_nested, is_decorated, in_non_ext,
224-
is_singledispatch))
224+
if is_singledispatch:
225+
func_name = '__mypyc_singledispatch_main_function_{}__'.format(name)
226+
else:
227+
func_name = name
228+
builder.enter(FuncInfo(fitem, func_name, class_name, gen_func_ns(builder),
229+
is_nested, contains_nested, is_decorated, in_non_ext))
225230

226231
# Functions that contain nested functions need an environment class to store variables that
227232
# are free in their nested functions. Generator functions need an environment class to
@@ -254,9 +259,6 @@ def c() -> None:
254259
if builder.fn_info.contains_nested and not builder.fn_info.is_generator:
255260
finalize_env_class(builder)
256261

257-
if builder.fn_info.is_singledispatch:
258-
add_singledispatch_registered_impls(builder)
259-
260262
builder.ret_types[-1] = sig.ret_type
261263

262264
# Add all variables and functions that are declared/defined within this
@@ -313,6 +315,15 @@ def c() -> None:
313315
# calculate them *once* when the function definition is evaluated.
314316
calculate_arg_defaults(builder, fn_info, func_reg, symtable)
315317

318+
if is_singledispatch:
319+
# add the generated main singledispatch function
320+
builder.functions.append(func_ir)
321+
# create the dispatch function
322+
assert isinstance(fitem, FuncDef)
323+
dispatch_name = decorator_helper_name(name) if is_decorated else name
324+
dispatch_func_ir = gen_dispatch_func_ir(builder, fitem, fn_info.name, dispatch_name, sig)
325+
return dispatch_func_ir, None
326+
316327
return (func_ir, func_reg)
317328

318329

@@ -768,28 +779,62 @@ def check_if_isinstance(builder: IRBuilder, obj: Value, typ: TypeInfo, line: int
768779
class_ir = builder.mapper.type_to_ir[typ]
769780
return builder.builder.isinstance_native(obj, class_ir, line)
770781
else:
771-
class_obj = builder.load_module_attr_by_fullname(typ.fullname, line)
782+
if typ.fullname in builtin_names:
783+
builtin_addr_type, src = builtin_names[typ.fullname]
784+
class_obj = builder.add(LoadAddress(builtin_addr_type, src, line))
785+
else:
786+
class_obj = builder.load_global_str(typ.name, line)
772787
return builder.call_c(slow_isinstance_op, [obj, class_obj], line)
773788

774789

775-
def add_singledispatch_registered_impls(builder: IRBuilder) -> None:
776-
fitem = builder.fn_info.fitem
777-
assert isinstance(fitem, FuncDef)
790+
def generate_singledispatch_dispatch_function(
791+
builder: IRBuilder,
792+
main_singledispatch_function_name: str,
793+
fitem: FuncDef,
794+
) -> None:
778795
impls = builder.singledispatch_impls[fitem]
779796
line = fitem.line
780797
current_func_decl = builder.mapper.func_to_decl[fitem]
781798
arg_info = get_args(builder, current_func_decl.sig.args, line)
782-
for dispatch_type, impl in impls:
783-
func_decl = builder.mapper.func_to_decl[impl]
799+
800+
def gen_func_call_and_return(func_name: str) -> None:
801+
func = builder.load_global_str(func_name, line)
802+
# TODO: don't pass optional arguments if they weren't passed to this function
803+
ret_val = builder.builder.py_call(
804+
func, arg_info.args, line, arg_info.arg_kinds, arg_info.arg_names
805+
)
806+
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
807+
builder.nonlocal_control[-1].gen_return(builder, coerced, line)
808+
809+
# Reverse the list of registered implementations so we use the implementations defined later
810+
# if there are multiple overlapping implementations
811+
for dispatch_type, impl in reversed(impls):
784812
call_impl, next_impl = BasicBlock(), BasicBlock()
785813
should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line)
786814
builder.add_bool_branch(should_call_impl, call_impl, next_impl)
787815

788816
# Call the registered implementation
789817
builder.activate_block(call_impl)
790818

791-
ret_val = builder.builder.call(
792-
func_decl, arg_info.args, arg_info.arg_kinds, arg_info.arg_names, line
793-
)
794-
builder.nonlocal_control[-1].gen_return(builder, ret_val, line)
819+
gen_func_call_and_return(impl.name)
795820
builder.activate_block(next_impl)
821+
822+
gen_func_call_and_return(main_singledispatch_function_name)
823+
824+
825+
def gen_dispatch_func_ir(
826+
builder: IRBuilder,
827+
fitem: FuncDef,
828+
main_func_name: str,
829+
dispatch_name: str,
830+
sig: FuncSignature,
831+
) -> FuncIR:
832+
"""Create a dispatch function (a function that checks the first argument type and dispatches
833+
to the correct implementation)
834+
"""
835+
builder.enter()
836+
generate_singledispatch_dispatch_function(builder, main_func_name, fitem)
837+
args, _, blocks, _, fn_info = builder.leave()
838+
func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig)
839+
dispatch_func_ir = FuncIR(func_decl, args, blocks)
840+
return dispatch_func_ir

mypyc/irbuild/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def build_ir(modules: List[MypyFile],
6666

6767
for module in modules:
6868
# First pass to determine free symbols.
69-
pbv = PreBuildVisitor()
69+
pbv = PreBuildVisitor(errors, module)
7070
module.accept(pbv)
7171

7272
# Construct and configure builder objects (cyclic runtime dependency).

mypyc/irbuild/prebuildvisitor.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
from mypyc.errors import Errors
12
from mypy.types import Instance, get_proper_type
23
from typing import DefaultDict, Dict, List, NamedTuple, Set, Optional, Tuple
34
from collections import defaultdict
45

56
from mypy.nodes import (
67
Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr,
7-
CallExpr, RefExpr, TypeInfo
8+
CallExpr, RefExpr, TypeInfo, MypyFile
89
)
910
from mypy.traverser import TraverserVisitor
1011

@@ -23,7 +24,7 @@ class PreBuildVisitor(TraverserVisitor):
2324
The main IR build pass uses this information.
2425
"""
2526

26-
def __init__(self) -> None:
27+
def __init__(self, errors: Errors, current_file: MypyFile) -> None:
2728
super().__init__()
2829
# Dict from a function to symbols defined directly in the
2930
# function that are used as non-local (free) variables within a
@@ -57,6 +58,10 @@ def __init__(self) -> None:
5758
self.singledispatch_impls: DefaultDict[
5859
FuncDef, List[Tuple[TypeInfo, FuncDef]]] = defaultdict(list)
5960

61+
self.errors: Errors = errors
62+
63+
self.current_file: MypyFile = current_file
64+
6065
def visit_decorator(self, dec: Decorator) -> None:
6166
if dec.decorators:
6267
# Only add the function being decorated if there exist
@@ -72,12 +77,27 @@ def visit_decorator(self, dec: Decorator) -> None:
7277
else:
7378
decorators_to_store = dec.decorators.copy()
7479
removed: List[int] = []
80+
# the index of the last non-register decorator before finding a register decorator
81+
# when going through decorators from top to bottom
82+
last_non_register: Optional[int] = None
7583
for i, d in enumerate(decorators_to_store):
7684
impl = get_singledispatch_register_call_info(d, dec.func)
7785
if impl is not None:
7886
self.singledispatch_impls[impl.singledispatch_func].append(
7987
(impl.dispatch_type, dec.func))
8088
removed.append(i)
89+
if last_non_register is not None:
90+
# found a register decorator after a non-register decorator, which we
91+
# don't support because we'd have to make a copy of the function before
92+
# calling the decorator so that we can call it later, which complicates
93+
# the implementation for something that is probably not commonly used
94+
self.errors.error(
95+
"Calling decorator after registering function not supported",
96+
self.current_file.path,
97+
decorators_to_store[last_non_register].line,
98+
)
99+
else:
100+
last_non_register = i
81101
# calling register on a function that tries to dispatch based on type annotations
82102
# raises a TypeError because compiled functions don't have an __annotations__
83103
# attribute

mypyc/test-data/commandline.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def f(x: int) -> int:
108108
from typing import List, Any, AsyncIterable
109109
from typing_extensions import Final
110110
from mypy_extensions import trait, mypyc_attr
111+
from functools import singledispatch
111112

112113
def busted(b: bool) -> None:
113114
for i in range(1, 10, 0): # E: range() step can't be zero
@@ -219,3 +220,23 @@ async def async_with() -> None:
219220

220221
async def async_generators() -> AsyncIterable[int]:
221222
yield 1 # E: async generators are unimplemented
223+
224+
@singledispatch
225+
def a(arg) -> None:
226+
pass
227+
228+
@decorator # E: Calling decorator after registering function not supported
229+
@a.register
230+
def g(arg: int) -> None:
231+
pass
232+
233+
@a.register
234+
@decorator
235+
def h(arg: str) -> None:
236+
pass
237+
238+
@decorator
239+
@decorator # E: Calling decorator after registering function not supported
240+
@a.register
241+
def i(arg: Foo) -> None:
242+
pass

0 commit comments

Comments
 (0)