Skip to content

Commit 82a8bb2

Browse files
authored
Merge pull request #83 from python/master
[mypyc] Improve support for compiling singledispatch (python#10795)
2 parents 7637cf2 + 7d69ce2 commit 82a8bb2

File tree

6 files changed

+183
-64
lines changed

6 files changed

+183
-64
lines changed

mypyc/irbuild/context.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ def __init__(self,
2222
is_nested: bool = False,
2323
contains_nested: bool = False,
2424
is_decorated: bool = False,
25-
in_non_ext: bool = False,
26-
is_singledispatch: bool = False) -> None:
25+
in_non_ext: bool = False) -> None:
2726
self.fitem = fitem
2827
self.name = name if not is_decorated else decorator_helper_name(name)
2928
self.class_name = class_name
@@ -48,7 +47,6 @@ def __init__(self,
4847
self.contains_nested = contains_nested
4948
self.is_decorated = is_decorated
5049
self.in_non_ext = in_non_ext
51-
self.is_singledispatch = is_singledispatch
5250

5351
# TODO: add field for ret_type: RType = none_rprimitive
5452

mypyc/irbuild/function.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
setup_func_for_recursive_call
5454
)
5555

56+
from mypyc.primitives.registry import builtin_names
57+
5658

5759
# Top-level transform functions
5860

@@ -219,9 +221,12 @@ def c() -> None:
219221
in_non_ext = not ir.is_ext_class
220222
class_name = cdef.name
221223

222-
builder.enter(FuncInfo(fitem, name, class_name, gen_func_ns(builder),
223-
is_nested, contains_nested, is_decorated, in_non_ext,
224-
is_singledispatch))
224+
if is_singledispatch:
225+
func_name = '__mypyc_singledispatch_main_function_{}__'.format(name)
226+
else:
227+
func_name = name
228+
builder.enter(FuncInfo(fitem, func_name, class_name, gen_func_ns(builder),
229+
is_nested, contains_nested, is_decorated, in_non_ext))
225230

226231
# Functions that contain nested functions need an environment class to store variables that
227232
# are free in their nested functions. Generator functions need an environment class to
@@ -254,9 +259,6 @@ def c() -> None:
254259
if builder.fn_info.contains_nested and not builder.fn_info.is_generator:
255260
finalize_env_class(builder)
256261

257-
if builder.fn_info.is_singledispatch:
258-
add_singledispatch_registered_impls(builder)
259-
260262
builder.ret_types[-1] = sig.ret_type
261263

262264
# Add all variables and functions that are declared/defined within this
@@ -313,6 +315,15 @@ def c() -> None:
313315
# calculate them *once* when the function definition is evaluated.
314316
calculate_arg_defaults(builder, fn_info, func_reg, symtable)
315317

318+
if is_singledispatch:
319+
# add the generated main singledispatch function
320+
builder.functions.append(func_ir)
321+
# create the dispatch function
322+
assert isinstance(fitem, FuncDef)
323+
dispatch_name = decorator_helper_name(name) if is_decorated else name
324+
dispatch_func_ir = gen_dispatch_func_ir(builder, fitem, fn_info.name, dispatch_name, sig)
325+
return dispatch_func_ir, None
326+
316327
return (func_ir, func_reg)
317328

318329

@@ -768,28 +779,62 @@ def check_if_isinstance(builder: IRBuilder, obj: Value, typ: TypeInfo, line: int
768779
class_ir = builder.mapper.type_to_ir[typ]
769780
return builder.builder.isinstance_native(obj, class_ir, line)
770781
else:
771-
class_obj = builder.load_module_attr_by_fullname(typ.fullname, line)
782+
if typ.fullname in builtin_names:
783+
builtin_addr_type, src = builtin_names[typ.fullname]
784+
class_obj = builder.add(LoadAddress(builtin_addr_type, src, line))
785+
else:
786+
class_obj = builder.load_global_str(typ.name, line)
772787
return builder.call_c(slow_isinstance_op, [obj, class_obj], line)
773788

774789

775-
def add_singledispatch_registered_impls(builder: IRBuilder) -> None:
776-
fitem = builder.fn_info.fitem
777-
assert isinstance(fitem, FuncDef)
790+
def generate_singledispatch_dispatch_function(
791+
builder: IRBuilder,
792+
main_singledispatch_function_name: str,
793+
fitem: FuncDef,
794+
) -> None:
778795
impls = builder.singledispatch_impls[fitem]
779796
line = fitem.line
780797
current_func_decl = builder.mapper.func_to_decl[fitem]
781798
arg_info = get_args(builder, current_func_decl.sig.args, line)
782-
for dispatch_type, impl in impls:
783-
func_decl = builder.mapper.func_to_decl[impl]
799+
800+
def gen_func_call_and_return(func_name: str) -> None:
801+
func = builder.load_global_str(func_name, line)
802+
# TODO: don't pass optional arguments if they weren't passed to this function
803+
ret_val = builder.builder.py_call(
804+
func, arg_info.args, line, arg_info.arg_kinds, arg_info.arg_names
805+
)
806+
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
807+
builder.nonlocal_control[-1].gen_return(builder, coerced, line)
808+
809+
# Reverse the list of registered implementations so we use the implementations defined later
810+
# if there are multiple overlapping implementations
811+
for dispatch_type, impl in reversed(impls):
784812
call_impl, next_impl = BasicBlock(), BasicBlock()
785813
should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line)
786814
builder.add_bool_branch(should_call_impl, call_impl, next_impl)
787815

788816
# Call the registered implementation
789817
builder.activate_block(call_impl)
790818

791-
ret_val = builder.builder.call(
792-
func_decl, arg_info.args, arg_info.arg_kinds, arg_info.arg_names, line
793-
)
794-
builder.nonlocal_control[-1].gen_return(builder, ret_val, line)
819+
gen_func_call_and_return(impl.name)
795820
builder.activate_block(next_impl)
821+
822+
gen_func_call_and_return(main_singledispatch_function_name)
823+
824+
825+
def gen_dispatch_func_ir(
826+
builder: IRBuilder,
827+
fitem: FuncDef,
828+
main_func_name: str,
829+
dispatch_name: str,
830+
sig: FuncSignature,
831+
) -> FuncIR:
832+
"""Create a dispatch function (a function that checks the first argument type and dispatches
833+
to the correct implementation)
834+
"""
835+
builder.enter()
836+
generate_singledispatch_dispatch_function(builder, main_func_name, fitem)
837+
args, _, blocks, _, fn_info = builder.leave()
838+
func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig)
839+
dispatch_func_ir = FuncIR(func_decl, args, blocks)
840+
return dispatch_func_ir

mypyc/irbuild/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def build_ir(modules: List[MypyFile],
6666

6767
for module in modules:
6868
# First pass to determine free symbols.
69-
pbv = PreBuildVisitor()
69+
pbv = PreBuildVisitor(errors, module)
7070
module.accept(pbv)
7171

7272
# Construct and configure builder objects (cyclic runtime dependency).

mypyc/irbuild/prebuildvisitor.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
from mypyc.errors import Errors
12
from mypy.types import Instance, get_proper_type
23
from typing import DefaultDict, Dict, List, NamedTuple, Set, Optional, Tuple
34
from collections import defaultdict
45

56
from mypy.nodes import (
67
Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr,
7-
CallExpr, RefExpr, TypeInfo
8+
CallExpr, RefExpr, TypeInfo, MypyFile
89
)
910
from mypy.traverser import TraverserVisitor
1011

@@ -23,7 +24,7 @@ class PreBuildVisitor(TraverserVisitor):
2324
The main IR build pass uses this information.
2425
"""
2526

26-
def __init__(self) -> None:
27+
def __init__(self, errors: Errors, current_file: MypyFile) -> None:
2728
super().__init__()
2829
# Dict from a function to symbols defined directly in the
2930
# function that are used as non-local (free) variables within a
@@ -57,6 +58,10 @@ def __init__(self) -> None:
5758
self.singledispatch_impls: DefaultDict[
5859
FuncDef, List[Tuple[TypeInfo, FuncDef]]] = defaultdict(list)
5960

61+
self.errors: Errors = errors
62+
63+
self.current_file: MypyFile = current_file
64+
6065
def visit_decorator(self, dec: Decorator) -> None:
6166
if dec.decorators:
6267
# Only add the function being decorated if there exist
@@ -72,12 +77,27 @@ def visit_decorator(self, dec: Decorator) -> None:
7277
else:
7378
decorators_to_store = dec.decorators.copy()
7479
removed: List[int] = []
80+
# the index of the last non-register decorator before finding a register decorator
81+
# when going through decorators from top to bottom
82+
last_non_register: Optional[int] = None
7583
for i, d in enumerate(decorators_to_store):
7684
impl = get_singledispatch_register_call_info(d, dec.func)
7785
if impl is not None:
7886
self.singledispatch_impls[impl.singledispatch_func].append(
7987
(impl.dispatch_type, dec.func))
8088
removed.append(i)
89+
if last_non_register is not None:
90+
# found a register decorator after a non-register decorator, which we
91+
# don't support because we'd have to make a copy of the function before
92+
# calling the decorator so that we can call it later, which complicates
93+
# the implementation for something that is probably not commonly used
94+
self.errors.error(
95+
"Calling decorator after registering function not supported",
96+
self.current_file.path,
97+
decorators_to_store[last_non_register].line,
98+
)
99+
else:
100+
last_non_register = i
81101
# calling register on a function that tries to dispatch based on type annotations
82102
# raises a TypeError because compiled functions don't have an __annotations__
83103
# attribute

mypyc/test-data/commandline.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def f(x: int) -> int:
108108
from typing import List, Any, AsyncIterable
109109
from typing_extensions import Final
110110
from mypy_extensions import trait, mypyc_attr
111+
from functools import singledispatch
111112

112113
def busted(b: bool) -> None:
113114
for i in range(1, 10, 0): # E: range() step can't be zero
@@ -219,3 +220,23 @@ async def async_with() -> None:
219220

220221
async def async_generators() -> AsyncIterable[int]:
221222
yield 1 # E: async generators are unimplemented
223+
224+
@singledispatch
225+
def a(arg) -> None:
226+
pass
227+
228+
@decorator # E: Calling decorator after registering function not supported
229+
@a.register
230+
def g(arg: int) -> None:
231+
pass
232+
233+
@a.register
234+
@decorator
235+
def h(arg: str) -> None:
236+
pass
237+
238+
@decorator
239+
@decorator # E: Calling decorator after registering function not supported
240+
@a.register
241+
def i(arg: Foo) -> None:
242+
pass

0 commit comments

Comments
 (0)