Skip to content

Commit cf8a7b5

Browse files
authored
Merge pull request #95 from python/master
[mypyc] Fix order of dispatch type checking in singledispatch functio…
2 parents 0942c53 + 97b3b90 commit cf8a7b5

File tree

3 files changed

+78
-12
lines changed

3 files changed

+78
-12
lines changed

mypy/build.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import types
2323

2424
from typing import (AbstractSet, Any, Dict, Iterable, Iterator, List, Sequence,
25-
Mapping, NamedTuple, Optional, Set, Tuple, Union, Callable, TextIO)
25+
Mapping, NamedTuple, Optional, Set, Tuple, TypeVar, Union, Callable, TextIO)
2626
from typing_extensions import ClassVar, Final, TYPE_CHECKING
2727
from mypy_extensions import TypedDict
2828

@@ -3234,21 +3234,22 @@ def dfs(v: str) -> Iterator[Set[str]]:
32343234
yield from dfs(v)
32353235

32363236

3237-
def topsort(data: Dict[AbstractSet[str],
3238-
Set[AbstractSet[str]]]) -> Iterable[Set[AbstractSet[str]]]:
3237+
T = TypeVar("T")
3238+
3239+
3240+
def topsort(data: Dict[T, Set[T]]) -> Iterable[Set[T]]:
32393241
"""Topological sort.
32403242
32413243
Args:
3242-
data: A map from SCCs (represented as frozen sets of strings) to
3243-
sets of SCCs, its dependencies. NOTE: This data structure
3244+
data: A map from vertices to all vertices that it has an edge
3245+
connecting it to. NOTE: This data structure
32443246
is modified in place -- for normalization purposes,
32453247
self-dependencies are removed and entries representing
32463248
orphans are added.
32473249
32483250
Returns:
3249-
An iterator yielding sets of SCCs that have an equivalent
3250-
ordering. NOTE: The algorithm doesn't care about the internal
3251-
structure of SCCs.
3251+
An iterator yielding sets of vertices that have an equivalent
3252+
ordering.
32523253
32533254
Example:
32543255
Suppose the input has the following structure:

mypyc/irbuild/function.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
instance of the callable class.
1111
"""
1212

13-
from typing import NamedTuple, Optional, List, Sequence, Tuple, Union, Dict
13+
from mypy.build import topsort
14+
from typing import (
15+
NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, Iterator,
16+
)
1417

1518
from mypy.nodes import (
1619
ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr,
@@ -822,9 +825,10 @@ def gen_func_call_and_return(func_name: str) -> None:
822825
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
823826
builder.nonlocal_control[-1].gen_return(builder, coerced, line)
824827

825-
# Reverse the list of registered implementations so we use the implementations defined later
826-
# if there are multiple overlapping implementations
827-
for dispatch_type, impl in reversed(impls):
828+
# Sort the list of implementations so that we check any subclasses before we check the classes
829+
# they inherit from, to better match singledispatch's behavior of going through the argument's
830+
# MRO, and using the first implementation it finds
831+
for dispatch_type, impl in sort_with_subclasses_first(impls):
828832
call_impl, next_impl = BasicBlock(), BasicBlock()
829833
should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line)
830834
builder.add_bool_branch(should_call_impl, call_impl, next_impl)
@@ -857,3 +861,17 @@ def gen_dispatch_func_ir(
857861
func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig)
858862
dispatch_func_ir = FuncIR(func_decl, args, blocks)
859863
return dispatch_func_ir
864+
865+
866+
def sort_with_subclasses_first(
867+
impls: List[Tuple[TypeInfo, FuncDef]]
868+
) -> Iterator[Tuple[TypeInfo, FuncDef]]:
869+
870+
# graph with edges pointing from every class to their subclasses
871+
graph = {typ: set(typ.mro[1:]) for typ, _ in impls}
872+
873+
dispatch_types = topsort(graph)
874+
impl_dict = {typ: func for typ, func in impls}
875+
876+
for group in reversed(list(dispatch_types)):
877+
yield from ((typ, impl_dict[typ]) for typ in group if typ in impl_dict)

mypyc/test-data/run-singledispatch.test

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,3 +465,50 @@ def h(arg: str) -> int:
465465
def test_singledispatch():
466466
assert f('a') == 35
467467
assert f(A()) == 10
468+
469+
[case testMoreSpecificTypeBeforeLessSpecificType]
470+
from functools import singledispatch
471+
class A: pass
472+
class B(A): pass
473+
474+
@singledispatch
475+
def f(arg) -> str:
476+
return 'default'
477+
478+
@f.register
479+
def g(arg: B) -> str:
480+
return 'b'
481+
482+
@f.register
483+
def h(arg: A) -> str:
484+
return 'a'
485+
486+
def test_singledispatch():
487+
assert f(B()) == 'b'
488+
assert f(A()) == 'a'
489+
assert f(5) == 'default'
490+
491+
[case testMultipleRelatedClassesBeingRegistered]
492+
from functools import singledispatch
493+
494+
class A: pass
495+
class B(A): pass
496+
class C(B): pass
497+
498+
@singledispatch
499+
def f(arg) -> str: return 'default'
500+
501+
@f.register
502+
def _(arg: A) -> str: return 'a'
503+
504+
@f.register
505+
def _(arg: C) -> str: return 'c'
506+
507+
@f.register
508+
def _(arg: B) -> str: return 'b'
509+
510+
def test_singledispatch():
511+
assert f(A()) == 'a'
512+
assert f(B()) == 'b'
513+
assert f(C()) == 'c'
514+
assert f(1) == 'default'

0 commit comments

Comments
 (0)