Skip to content

Commit e07ad3b

Browse files
authored
Preserve types of functions registered with singledispatch (#10756)
When calling a function that has been registered as an implementation for a singledispatch function, currently, mypy doesn't type check arguments to those functions when those functions are called directly (instead of through the main singledispatch function). This fixes that by preserving the type of the registered function even after the register decorator has been used.
1 parent e82e8ec commit e07ad3b

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

mypy/plugins/singledispatch.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,19 @@ def singledispatch_register_callback(ctx: MethodContext) -> Type:
139139
elif isinstance(first_arg_type, CallableType):
140140
# TODO: do more checking for registered functions
141141
register_function(ctx, ctx.type, first_arg_type)
142+
# The typeshed stubs for register say that the function returned is Callable[..., T], even
143+
# though the function returned is the same as the one passed in. We return the type of the
144+
# function so that mypy can properly type check cases where the registered function is used
145+
# directly (instead of through singledispatch)
146+
return first_arg_type
142147

143-
# register doesn't modify the function it's used on
148+
# fallback in case we don't recognize the arguments
144149
return ctx.default_return_type
145150

146151

147152
def register_function(ctx: PluginContext, singledispatch_obj: Instance, func: Type,
148153
register_arg: Optional[Type] = None) -> None:
154+
"""Register a function"""
149155

150156
func = get_proper_type(func)
151157
if not isinstance(func, CallableType):
@@ -165,6 +171,7 @@ def register_function(ctx: PluginContext, singledispatch_obj: Instance, func: Ty
165171
format_type(dispatch_type), format_type(fallback_dispatch_type)
166172
), func.definition)
167173
return
174+
return
168175

169176

170177
def get_dispatch_type(func: CallableType, register_arg: Optional[Type]) -> Optional[Type]:
@@ -183,6 +190,8 @@ def call_singledispatch_function_after_register_argument(ctx: MethodContext) ->
183190
func = get_first_arg(ctx.arg_types)
184191
if func is not None:
185192
register_function(ctx, type_args.singledispatch_obj, func, type_args.register_type)
193+
# see call to register_function in the callback for register
194+
return func
186195
return ctx.default_return_type
187196

188197

test-data/unit/check-singledispatch.test

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,28 @@ def h(a: Missing) -> None:
289289
pass
290290

291291
[builtins fixtures/args.pyi]
292+
293+
[case testIncorrectArgumentTypeWhenCallingRegisteredImplDirectly]
294+
from functools import singledispatch
295+
296+
@singledispatch
297+
def f(arg, arg2: str) -> bool:
298+
return False
299+
300+
@f.register
301+
def g(arg: int, arg2: str) -> bool:
302+
pass
303+
304+
@f.register(str)
305+
def h(arg, arg2: str) -> bool:
306+
pass
307+
308+
g('a', 'a') # E: Argument 1 to "g" has incompatible type "str"; expected "int"
309+
g(1, 1) # E: Argument 2 to "g" has incompatible type "int"; expected "str"
310+
311+
# don't show errors for incorrect first argument here, because there's no type annotation for the
312+
# first argument
313+
h(1, 'a')
314+
h('a', 1) # E: Argument 2 to "h" has incompatible type "int"; expected "str"
315+
316+
[builtins fixtures/args.pyi]

0 commit comments

Comments
 (0)