Skip to content

Commit a5a9e15

Browse files
authored
[mypyc] Add initial support for compiling singledispatch functions (#10753)
This PR adds initial support for compiling functions marked with singledispatch by generating IR that checks the type of the first argument and calls the correct implementation, falling back to the main singledispatch function if none of the registered implementations have a dispatch type that matches the argument. Currently, this only supports both one-argument versions of register (passing a type as an argument to register or using type annotations), and only works if register is used as a decorator.
1 parent e07ad3b commit a5a9e15

File tree

6 files changed

+191
-34
lines changed

6 files changed

+191
-34
lines changed

mypyc/irbuild/builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(self,
115115
self.encapsulating_funcs = pbv.encapsulating_funcs
116116
self.nested_fitems = pbv.nested_funcs.keys()
117117
self.fdefs_to_decorators = pbv.funcs_to_decorators
118+
self.singledispatch_impls = pbv.singledispatch_impls
118119

119120
self.visitor = visitor
120121

mypyc/irbuild/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def __init__(self,
2222
is_nested: bool = False,
2323
contains_nested: bool = False,
2424
is_decorated: bool = False,
25-
in_non_ext: bool = False) -> None:
25+
in_non_ext: bool = False,
26+
is_singledispatch: bool = False) -> None:
2627
self.fitem = fitem
2728
self.name = name if not is_decorated else decorator_helper_name(name)
2829
self.class_name = class_name
@@ -47,6 +48,7 @@ def __init__(self,
4748
self.contains_nested = contains_nested
4849
self.is_decorated = is_decorated
4950
self.in_non_ext = in_non_ext
51+
self.is_singledispatch = is_singledispatch
5052

5153
# TODO: add field for ret_type: RType = none_rprimitive
5254

mypyc/irbuild/function.py

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
instance of the callable class.
1111
"""
1212

13-
from typing import Optional, List, Tuple, Union, Dict
13+
from typing import NamedTuple, Optional, List, Sequence, Tuple, Union, Dict
1414

1515
from mypy.nodes import (
1616
ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr,
17-
FuncItem, LambdaExpr, SymbolNode, ARG_NAMED, ARG_NAMED_OPT
17+
FuncItem, LambdaExpr, SymbolNode, ARG_NAMED, ARG_NAMED_OPT, TypeInfo
1818
)
1919
from mypy.types import CallableType, get_proper_type
2020

@@ -28,7 +28,9 @@
2828
)
2929
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
3030
from mypyc.primitives.generic_ops import py_setattr_op, next_raw_op, iter_op
31-
from mypyc.primitives.misc_ops import check_stop_op, yield_from_except_op, coro_op, send_op
31+
from mypyc.primitives.misc_ops import (
32+
check_stop_op, yield_from_except_op, coro_op, send_op, slow_isinstance_op
33+
)
3234
from mypyc.primitives.dict_ops import dict_set_item_op
3335
from mypyc.common import SELF_NAME, LAMBDA_NAME, decorator_helper_name
3436
from mypyc.sametype import is_same_method_signature
@@ -84,7 +86,10 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
8486
decorated_func = load_decorated_func(builder, dec.func, func_reg)
8587
builder.assign(get_func_target(builder, dec.func), decorated_func, dec.func.line)
8688
func_reg = decorated_func
87-
else:
89+
# If the prebuild pass didn't put this function in the function to decorators map (for example
90+
# if this is a registered singledispatch implementation with no other decorators), we should
91+
# treat this function as a regular function, not a decorated function
92+
elif dec.func in builder.fdefs_to_decorators:
8893
# Obtain the the function name in order to construct the name of the helper function.
8994
name = dec.func.fullname.split('.')[-1]
9095
helper_name = decorator_helper_name(name)
@@ -206,6 +211,7 @@ def c() -> None:
206211
is_nested = fitem in builder.nested_fitems or isinstance(fitem, LambdaExpr)
207212
contains_nested = fitem in builder.encapsulating_funcs.keys()
208213
is_decorated = fitem in builder.fdefs_to_decorators
214+
is_singledispatch = fitem in builder.singledispatch_impls
209215
in_non_ext = False
210216
class_name = None
211217
if cdef:
@@ -214,7 +220,8 @@ def c() -> None:
214220
class_name = cdef.name
215221

216222
builder.enter(FuncInfo(fitem, name, class_name, gen_func_ns(builder),
217-
is_nested, contains_nested, is_decorated, in_non_ext))
223+
is_nested, contains_nested, is_decorated, in_non_ext,
224+
is_singledispatch))
218225

219226
# Functions that contain nested functions need an environment class to store variables that
220227
# are free in their nested functions. Generator functions need an environment class to
@@ -247,6 +254,9 @@ def c() -> None:
247254
if builder.fn_info.contains_nested and not builder.fn_info.is_generator:
248255
finalize_env_class(builder)
249256

257+
if builder.fn_info.is_singledispatch:
258+
add_singledispatch_registered_impls(builder)
259+
250260
builder.ret_types[-1] = sig.ret_type
251261

252262
# Add all variables and functions that are declared/defined within this
@@ -628,6 +638,23 @@ def gen_glue(builder: IRBuilder, sig: FuncSignature, target: FuncIR,
628638
return gen_glue_method(builder, sig, target, cls, base, fdef.line, do_py_ops)
629639

630640

641+
class ArgInfo(NamedTuple):
642+
args: List[Value]
643+
arg_names: List[Optional[str]]
644+
arg_kinds: List[int]
645+
646+
647+
def get_args(builder: IRBuilder, rt_args: Sequence[RuntimeArg], line: int) -> ArgInfo:
648+
# The environment operates on Vars, so we make some up
649+
fake_vars = [(Var(arg.name), arg.type) for arg in rt_args]
650+
args = [builder.read(builder.add_local_reg(var, type, is_arg=True), line)
651+
for var, type in fake_vars]
652+
arg_names = [arg.name if arg.kind in (ARG_NAMED, ARG_NAMED_OPT) else None
653+
for arg in rt_args]
654+
arg_kinds = [concrete_arg_kind(arg.kind) for arg in rt_args]
655+
return ArgInfo(args, arg_names, arg_kinds)
656+
657+
631658
def gen_glue_method(builder: IRBuilder, sig: FuncSignature, target: FuncIR,
632659
cls: ClassIR, base: ClassIR, line: int,
633660
do_pycall: bool,
@@ -664,13 +691,8 @@ def f(builder: IRBuilder, x: object) -> int: ...
664691
if target.decl.kind == FUNC_NORMAL:
665692
rt_args[0] = RuntimeArg(sig.args[0].name, RInstance(cls))
666693

667-
# The environment operates on Vars, so we make some up
668-
fake_vars = [(Var(arg.name), arg.type) for arg in rt_args]
669-
args = [builder.read(builder.add_local_reg(var, type, is_arg=True), line)
670-
for var, type in fake_vars]
671-
arg_names = [arg.name if arg.kind in (ARG_NAMED, ARG_NAMED_OPT) else None
672-
for arg in rt_args]
673-
arg_kinds = [concrete_arg_kind(arg.kind) for arg in rt_args]
694+
arg_info = get_args(builder, rt_args, line)
695+
args, arg_kinds, arg_names = arg_info.args, arg_info.arg_kinds, arg_info.arg_names
674696

675697
if do_pycall:
676698
retval = builder.builder.py_method_call(
@@ -739,3 +761,35 @@ def get_func_target(builder: IRBuilder, fdef: FuncDef) -> AssignmentTarget:
739761
return builder.lookup(fdef)
740762

741763
return builder.add_local_reg(fdef, object_rprimitive)
764+
765+
766+
def check_if_isinstance(builder: IRBuilder, obj: Value, typ: TypeInfo, line: int) -> Value:
767+
if typ in builder.mapper.type_to_ir:
768+
class_ir = builder.mapper.type_to_ir[typ]
769+
return builder.builder.isinstance_native(obj, class_ir, line)
770+
else:
771+
class_obj = builder.load_module_attr_by_fullname(typ.fullname, line)
772+
return builder.call_c(slow_isinstance_op, [obj, class_obj], line)
773+
774+
775+
def add_singledispatch_registered_impls(builder: IRBuilder) -> None:
776+
fitem = builder.fn_info.fitem
777+
assert isinstance(fitem, FuncDef)
778+
impls = builder.singledispatch_impls[fitem]
779+
line = fitem.line
780+
current_func_decl = builder.mapper.func_to_decl[fitem]
781+
arg_info = get_args(builder, current_func_decl.sig.args, line)
782+
for dispatch_type, impl in impls:
783+
func_decl = builder.mapper.func_to_decl[impl]
784+
call_impl, next_impl = BasicBlock(), BasicBlock()
785+
should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line)
786+
builder.add_bool_branch(should_call_impl, call_impl, next_impl)
787+
788+
# Call the registered implementation
789+
builder.activate_block(call_impl)
790+
791+
ret_val = builder.builder.call(
792+
func_decl, arg_info.args, arg_info.arg_kinds, arg_info.arg_names, line
793+
)
794+
builder.nonlocal_control[-1].gen_return(builder, ret_val, line)
795+
builder.activate_block(next_impl)

mypyc/irbuild/prebuildvisitor.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
from typing import Dict, List, Set
1+
from mypy.types import Instance, get_proper_type
2+
from typing import DefaultDict, Dict, List, NamedTuple, Set, Optional, Tuple
3+
from collections import defaultdict
24

35
from mypy.nodes import (
4-
Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr
6+
Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr,
7+
CallExpr, RefExpr, TypeInfo
58
)
69
from mypy.traverser import TraverserVisitor
710

@@ -50,6 +53,10 @@ def __init__(self) -> None:
5053
# Map function to its non-special decorators.
5154
self.funcs_to_decorators: Dict[FuncDef, List[Expression]] = {}
5255

56+
# Map of main singledispatch function to list of registered implementations
57+
self.singledispatch_impls: DefaultDict[
58+
FuncDef, List[Tuple[TypeInfo, FuncDef]]] = defaultdict(list)
59+
5360
def visit_decorator(self, dec: Decorator) -> None:
5461
if dec.decorators:
5562
# Only add the function being decorated if there exist
@@ -63,6 +70,20 @@ def visit_decorator(self, dec: Decorator) -> None:
6370
# Property setters are not treated as decorated methods.
6471
self.prop_setters.add(dec.func)
6572
else:
73+
removed: List[int] = []
74+
for i, d in enumerate(dec.decorators):
75+
impl = get_singledispatch_register_call_info(d, dec.func)
76+
if impl is not None:
77+
self.singledispatch_impls[impl.singledispatch_func].append(
78+
(impl.dispatch_type, dec.func))
79+
removed.append(i)
80+
for i in reversed(removed):
81+
del dec.decorators[i]
82+
# if the only decorators are register calls, we shouldn't treat this
83+
# as a decorated function because there aren't any decorators to apply
84+
if not dec.decorators:
85+
return
86+
6687
self.funcs_to_decorators[dec.func] = dec.decorators
6788
super().visit_decorator(dec)
6889

@@ -141,3 +162,45 @@ def add_free_variable(self, symbol: SymbolNode) -> None:
141162
# and mark is as a non-local symbol within that function.
142163
func = self.symbols_to_funcs[symbol]
143164
self.free_variables.setdefault(func, set()).add(symbol)
165+
166+
167+
class RegisteredImpl(NamedTuple):
168+
singledispatch_func: FuncDef
169+
dispatch_type: TypeInfo
170+
171+
172+
def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef
173+
) -> Optional[RegisteredImpl]:
174+
# @fun.register(complex)
175+
# def g(arg): ...
176+
if (isinstance(decorator, CallExpr) and len(decorator.args) == 1
177+
and isinstance(decorator.args[0], RefExpr)):
178+
callee = decorator.callee
179+
dispatch_type = decorator.args[0].node
180+
if not isinstance(dispatch_type, TypeInfo):
181+
return None
182+
183+
if isinstance(callee, MemberExpr):
184+
return registered_impl_from_possible_register_call(callee, dispatch_type)
185+
# @fun.register
186+
# def g(arg: int): ...
187+
elif isinstance(decorator, MemberExpr):
188+
# we don't know if this is a register call yet, so we can't be sure that the function
189+
# actually has arguments
190+
if not func.arguments:
191+
return None
192+
arg_type = get_proper_type(func.arguments[0].variable.type)
193+
if not isinstance(arg_type, Instance):
194+
return None
195+
info = arg_type.type
196+
return registered_impl_from_possible_register_call(decorator, info)
197+
return None
198+
199+
200+
def registered_impl_from_possible_register_call(expr: MemberExpr, dispatch_type: TypeInfo
201+
) -> Optional[RegisteredImpl]:
202+
if expr.name == 'register' and isinstance(expr.expr, NameExpr):
203+
node = expr.expr.node
204+
if isinstance(node, Decorator):
205+
return RegisteredImpl(node.func, dispatch_type)
206+
return None

mypyc/primitives/misc_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@
141141
is_borrowed=True)
142142

143143
# isinstance(obj, cls)
144-
function_op(
144+
slow_isinstance_op = function_op(
145145
name='builtins.isinstance',
146146
arg_types=[object_rprimitive, object_rprimitive],
147147
return_type=c_int_rprimitive,

mypyc/test-data/run-singledispatch.test

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Most of these tests are marked as xfails because mypyc doesn't support singledispatch yet
33
# (These tests will be re-enabled when mypyc supports singledispatch)
44

5-
[case testSpecializedImplementationUsed-xfail]
5+
[case testSpecializedImplementationUsed]
66
from functools import singledispatch
77

88
@singledispatch
@@ -17,7 +17,7 @@ def test_specialize() -> None:
1717
assert fun('a')
1818
assert not fun(3)
1919

20-
[case testSubclassesOfExpectedTypeUseSpecialized-xfail]
20+
[case testSubclassesOfExpectedTypeUseSpecialized]
2121
from functools import singledispatch
2222
class A: pass
2323
class B(A): pass
@@ -76,7 +76,7 @@ def test_singledispatch() -> None:
7676
assert fun('a') == 'str'
7777
assert fun({'a': 'b'}) == 'default'
7878

79-
[case testCanRegisterCompiledClasses-xfail]
79+
[case testCanRegisterCompiledClasses]
8080
from functools import singledispatch
8181
class A: pass
8282

@@ -136,21 +136,6 @@ def fun_specialized(arg: int) -> bool:
136136
def test_singledispatch() -> None:
137137
assert fun_specialized('a')
138138

139-
[case testTypeAnnotationsDisagreeWithRegisterArgument-xfail]
140-
from functools import singledispatch
141-
142-
@singledispatch
143-
def fun(arg) -> bool:
144-
return False
145-
146-
@fun.register(int)
147-
def fun_specialized(arg: str) -> bool:
148-
return True
149-
150-
def test_singledispatch() -> None:
151-
assert fun(3) # type: ignore
152-
assert not fun('a')
153-
154139
[case testNoneIsntATypeWhenUsedAsArgumentToRegister-xfail]
155140
from functools import singledispatch
156141

@@ -385,3 +370,55 @@ def test_verify() -> None:
385370
assert verify_list(MypyFile(), 5, ['a', 'b']) == ['in TypeInfo', 'hello']
386371
assert verify_list(TypeInfo(), str, ['a', 'b']) == ['in TypeInfo', 'hello']
387372
assert verify_list(TypeVarExpr(), 'a', ['x', 'y']) == ['x', 'y']
373+
374+
[case testArgsInRegisteredImplNamedDifferentlyFromMainFunction]
375+
from functools import singledispatch
376+
377+
@singledispatch
378+
def f(a) -> bool:
379+
return False
380+
381+
@f.register
382+
def g(b: int) -> bool:
383+
return True
384+
385+
def test_singledispatch():
386+
assert f(5)
387+
assert not f('a')
388+
389+
[case testKeywordArguments-xfail]
390+
from functools import singledispatch
391+
392+
@singledispatch
393+
def f(arg, *, kwarg: bool = False) -> bool:
394+
return not kwarg
395+
396+
@f.register
397+
def g(arg: int, *, kwarg: bool = True) -> bool:
398+
return kwarg
399+
400+
def test_keywords():
401+
assert f('a')
402+
assert f('a', kwarg=False)
403+
assert not f('a', kwarg=True)
404+
405+
assert f(1)
406+
assert f(1, kwarg=True)
407+
assert not f(1, kwarg=False)
408+
409+
[case testGeneratorAndMultipleTypesOfIterable-xfail]
410+
from functools import singledispatch
411+
from typing import *
412+
413+
@singledispatch
414+
def f(arg: Any) -> Iterable[int]:
415+
yield 1
416+
417+
@f.register
418+
def g(arg: str) -> Iterable[int]:
419+
return [0]
420+
421+
def test_iterables():
422+
assert f(1) != [1]
423+
assert list(f(1)) == [1]
424+
assert f('a') == [0]

0 commit comments

Comments
 (0)