Skip to content

Commit 85fc99c

Browse files
authored
stubtest: error if a function is async at runtime but not in the stub (and vice versa) (#12212)
1 parent c7365ef commit 85fc99c

File tree

2 files changed

+90
-4
lines changed

2 files changed

+90
-4
lines changed

mypy/stubtest.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,18 @@ def _verify_signature(
676676
yield 'runtime does not have **kwargs argument "{}"'.format(stub.varkw.variable.name)
677677

678678

679+
def _verify_coroutine(
680+
stub: nodes.FuncItem, runtime: Any, *, runtime_is_coroutine: bool
681+
) -> Optional[str]:
682+
if stub.is_coroutine:
683+
if not runtime_is_coroutine:
684+
return 'is an "async def" function in the stub, but not at runtime'
685+
else:
686+
if runtime_is_coroutine:
687+
return 'is an "async def" function at runtime, but not in the stub'
688+
return None
689+
690+
679691
@verify.register(nodes.FuncItem)
680692
def verify_funcitem(
681693
stub: nodes.FuncItem, runtime: MaybeMissing[Any], object_path: List[str]
@@ -693,19 +705,40 @@ def verify_funcitem(
693705
yield Error(object_path, "is inconsistent, " + message, stub, runtime)
694706

695707
signature = safe_inspect_signature(runtime)
708+
runtime_is_coroutine = inspect.iscoroutinefunction(runtime)
709+
710+
if signature:
711+
stub_sig = Signature.from_funcitem(stub)
712+
runtime_sig = Signature.from_inspect_signature(signature)
713+
runtime_sig_desc = f'{"async " if runtime_is_coroutine else ""}def {signature}'
714+
else:
715+
runtime_sig_desc = None
716+
717+
coroutine_mismatch_error = _verify_coroutine(
718+
stub,
719+
runtime,
720+
runtime_is_coroutine=runtime_is_coroutine
721+
)
722+
723+
if coroutine_mismatch_error is not None:
724+
yield Error(
725+
object_path,
726+
coroutine_mismatch_error,
727+
stub,
728+
runtime,
729+
runtime_desc=runtime_sig_desc
730+
)
731+
696732
if not signature:
697733
return
698734

699-
stub_sig = Signature.from_funcitem(stub)
700-
runtime_sig = Signature.from_inspect_signature(signature)
701-
702735
for message in _verify_signature(stub_sig, runtime_sig, function_name=stub.name):
703736
yield Error(
704737
object_path,
705738
"is inconsistent, " + message,
706739
stub,
707740
runtime,
708-
runtime_desc="def " + str(signature),
741+
runtime_desc=runtime_sig_desc,
709742
)
710743

711744

mypy/test/teststubtest.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,34 @@ def use_tmp_dir() -> Iterator[None]:
2828

2929
TEST_MODULE_NAME = "test_module"
3030

31+
32+
stubtest_typing_stub = """
33+
Any = object()
34+
35+
class _SpecialForm:
36+
def __getitem__(self, typeargs: Any) -> object: ...
37+
38+
Callable: _SpecialForm = ...
39+
Generic: _SpecialForm = ...
40+
41+
class TypeVar:
42+
def __init__(self, name, covariant: bool = ..., contravariant: bool = ...) -> None: ...
43+
44+
_T = TypeVar("_T")
45+
_T_co = TypeVar("_T_co", covariant=True)
46+
_K = TypeVar("_K")
47+
_V = TypeVar("_V")
48+
_S = TypeVar("_S", contravariant=True)
49+
_R = TypeVar("_R", covariant=True)
50+
51+
class Coroutine(Generic[_T_co, _S, _R]): ...
52+
class Iterable(Generic[_T_co]): ...
53+
class Mapping(Generic[_K, _V]): ...
54+
class Sequence(Iterable[_T_co]): ...
55+
class Tuple(Sequence[_T_co]): ...
56+
def overload(func: _T) -> _T: ...
57+
"""
58+
3159
stubtest_builtins_stub = """
3260
from typing import Generic, Mapping, Sequence, TypeVar, overload
3361
@@ -66,6 +94,8 @@ def run_stubtest(
6694
with use_tmp_dir():
6795
with open("builtins.pyi", "w") as f:
6896
f.write(stubtest_builtins_stub)
97+
with open("typing.pyi", "w") as f:
98+
f.write(stubtest_typing_stub)
6999
with open("{}.pyi".format(TEST_MODULE_NAME), "w") as f:
70100
f.write(stub)
71101
with open("{}.py".format(TEST_MODULE_NAME), "w") as f:
@@ -172,6 +202,29 @@ class X:
172202
error="X.mistyped_var",
173203
)
174204

205+
@collect_cases
206+
def test_coroutines(self) -> Iterator[Case]:
207+
yield Case(
208+
stub="async def foo() -> int: ...",
209+
runtime="def foo(): return 5",
210+
error="foo",
211+
)
212+
yield Case(
213+
stub="def bar() -> int: ...",
214+
runtime="async def bar(): return 5",
215+
error="bar",
216+
)
217+
yield Case(
218+
stub="def baz() -> int: ...",
219+
runtime="def baz(): return 5",
220+
error=None,
221+
)
222+
yield Case(
223+
stub="async def bingo() -> int: ...",
224+
runtime="async def bingo(): return 5",
225+
error=None,
226+
)
227+
175228
@collect_cases
176229
def test_arg_name(self) -> Iterator[Case]:
177230
yield Case(

0 commit comments

Comments
 (0)