Skip to content

Commit f6bac71

Browse files
authored
stubtest: verify the contents of __all__ in a stub (#12214)
1 parent 6dc3a39 commit f6bac71

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

mypy/stubtest.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from contextlib import redirect_stderr, redirect_stdout
2222
from functools import singledispatch
2323
from pathlib import Path
24-
from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union, cast
24+
from typing import Any, Dict, Generic, Iterator, List, Optional, Set, Tuple, TypeVar, Union, cast
2525

2626
import typing_extensions
2727
from typing_extensions import Type, get_origin
@@ -243,6 +243,38 @@ def verify(
243243
yield Error(object_path, "is an unknown mypy node", stub, runtime)
244244

245245

246+
def _verify_exported_names(
247+
object_path: List[str], stub: nodes.MypyFile, runtime_all_as_set: Set[str]
248+
) -> Iterator[Error]:
249+
public_names_in_stub = {m for m, o in stub.names.items() if o.module_public}
250+
names_in_stub_not_runtime = sorted(public_names_in_stub - runtime_all_as_set)
251+
names_in_runtime_not_stub = sorted(runtime_all_as_set - public_names_in_stub)
252+
if not (names_in_runtime_not_stub or names_in_stub_not_runtime):
253+
return
254+
yield Error(
255+
object_path,
256+
(
257+
"module: names exported from the stub "
258+
"do not correspond to the names exported at runtime.\n"
259+
"(Note: This is probably either due to an inaccurate "
260+
"`__all__` in the stub, "
261+
"or due to a name being declared in `__all__` "
262+
"but not actually defined in the stub.)"
263+
),
264+
# pass in MISSING instead of the stub and runtime objects,
265+
# as the line numbers aren't very relevant here,
266+
# and it makes for a prettier error message.
267+
stub_object=MISSING,
268+
runtime_object=MISSING,
269+
stub_desc=(
270+
f"Names exported in the stub but not at runtime: " f"{names_in_stub_not_runtime}"
271+
),
272+
runtime_desc=(
273+
f"Names exported at runtime but not in the stub: " f"{names_in_runtime_not_stub}"
274+
),
275+
)
276+
277+
246278
@verify.register(nodes.MypyFile)
247279
def verify_mypyfile(
248280
stub: nodes.MypyFile, runtime: MaybeMissing[types.ModuleType], object_path: List[str]
@@ -254,6 +286,17 @@ def verify_mypyfile(
254286
yield Error(object_path, "is not a module", stub, runtime)
255287
return
256288

289+
runtime_all_as_set: Optional[Set[str]]
290+
291+
if hasattr(runtime, "__all__"):
292+
runtime_all_as_set = set(runtime.__all__)
293+
if "__all__" in stub.names:
294+
# Only verify the contents of the stub's __all__
295+
# if the stub actually defines __all__
296+
yield from _verify_exported_names(object_path, stub, runtime_all_as_set)
297+
else:
298+
runtime_all_as_set = None
299+
257300
# Check things in the stub
258301
to_check = {
259302
m
@@ -272,16 +315,16 @@ def _belongs_to_runtime(r: types.ModuleType, attr: str) -> bool:
272315
return not isinstance(obj, types.ModuleType)
273316

274317
runtime_public_contents = (
275-
runtime.__all__
276-
if hasattr(runtime, "__all__")
277-
else [
318+
runtime_all_as_set
319+
if runtime_all_as_set is not None
320+
else {
278321
m
279322
for m in dir(runtime)
280323
if not is_probably_private(m)
281324
# Ensure that the object's module is `runtime`, since in the absence of __all__ we
282325
# don't have a good way to detect re-exports at runtime.
283326
and _belongs_to_runtime(runtime, m)
284-
]
327+
}
285328
)
286329
# Check all things declared in module's __all__, falling back to our best guess
287330
to_check.update(runtime_public_contents)

mypy/test/teststubtest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,37 @@ def f(): return 3
947947
error=None,
948948
)
949949

950+
@collect_cases
951+
def test_all_at_runtime_not_stub(self) -> Iterator[Case]:
952+
yield Case(
953+
stub="Z: int",
954+
runtime="""
955+
__all__ = []
956+
Z = 5""",
957+
error=None,
958+
)
959+
960+
@collect_cases
961+
def test_all_in_stub_not_at_runtime(self) -> Iterator[Case]:
962+
yield Case(stub="__all__ = ()", runtime="", error="__all__")
963+
964+
@collect_cases
965+
def test_all_in_stub_different_to_all_at_runtime(self) -> Iterator[Case]:
966+
# We *should* emit an error with the module name itself,
967+
# if the stub *does* define __all__,
968+
# but the stub's __all__ is inconsistent with the runtime's __all__
969+
yield Case(
970+
stub="""
971+
__all__ = ['foo']
972+
foo: str
973+
""",
974+
runtime="""
975+
__all__ = []
976+
foo = 'foo'
977+
""",
978+
error="",
979+
)
980+
950981
@collect_cases
951982
def test_missing(self) -> Iterator[Case]:
952983
yield Case(stub="x = 5", runtime="", error="x")

0 commit comments

Comments
 (0)