Skip to content

Commit 0ff7a29

Browse files
stubgen: include __all__ in output (#16356)
Fixes #10314
1 parent 128176a commit 0ff7a29

File tree

3 files changed

+79
-19
lines changed

3 files changed

+79
-19
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
## Unreleased
44

5-
...
5+
Stubgen will now include `__all__` in its output if it is in the input file (PR [16356](https://github.com/python/mypy/pull/16356)).
66

77
#### Other Notable Changes and Fixes
88
...

mypy/stubutil.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -614,10 +614,24 @@ def get_imports(self) -> str:
614614

615615
def output(self) -> str:
616616
"""Return the text for the stub."""
617-
imports = self.get_imports()
618-
if imports and self._output:
619-
imports += "\n"
620-
return imports + "".join(self._output)
617+
pieces: list[str] = []
618+
if imports := self.get_imports():
619+
pieces.append(imports)
620+
if dunder_all := self.get_dunder_all():
621+
pieces.append(dunder_all)
622+
if self._output:
623+
pieces.append("".join(self._output))
624+
return "\n".join(pieces)
625+
626+
def get_dunder_all(self) -> str:
627+
"""Return the __all__ list for the stub."""
628+
if self._all_:
629+
# Note we emit all names in the runtime __all__ here, even if they
630+
# don't actually exist. If that happens, the runtime has a bug, and
631+
# it's not obvious what the correct behavior should be. We choose
632+
# to reflect the runtime __all__ as closely as possible.
633+
return f"__all__ = {self._all_!r}\n"
634+
return ""
621635

622636
def add(self, string: str) -> None:
623637
"""Add text to generated stub."""
@@ -651,8 +665,7 @@ def set_defined_names(self, defined_names: set[str]) -> None:
651665
self.defined_names = defined_names
652666
# Names in __all__ are required
653667
for name in self._all_ or ():
654-
if name not in self.IGNORED_DUNDERS:
655-
self.import_tracker.reexport(name)
668+
self.import_tracker.reexport(name)
656669

657670
# These are "soft" imports for objects which might appear in annotations but not have
658671
# a corresponding import statement.
@@ -751,7 +764,13 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool:
751764
return False
752765
if name == "_":
753766
return False
754-
return name.startswith("_") and (not name.endswith("__") or name in self.IGNORED_DUNDERS)
767+
if not name.startswith("_"):
768+
return False
769+
if self._all_ and name in self._all_:
770+
return False
771+
if name.startswith("__") and name.endswith("__"):
772+
return name in self.IGNORED_DUNDERS
773+
return True
755774

756775
def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool:
757776
if (
@@ -761,18 +780,21 @@ def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> b
761780
):
762781
# Special case certain names that should be exported, against our general rules.
763782
return True
783+
if name_is_alias:
784+
return False
785+
if self.export_less:
786+
return False
787+
if not self.module_name:
788+
return False
764789
is_private = self.is_private_name(name, full_module + "." + name)
790+
if is_private:
791+
return False
765792
top_level = full_module.split(".")[0]
766793
self_top_level = self.module_name.split(".", 1)[0]
767-
if (
768-
not name_is_alias
769-
and not self.export_less
770-
and (not self._all_ or name in self.IGNORED_DUNDERS)
771-
and self.module_name
772-
and not is_private
773-
and top_level in (self_top_level, "_" + self_top_level)
774-
):
794+
if top_level not in (self_top_level, "_" + self_top_level):
775795
# Export imports from the same package, since we can't reliably tell whether they
776796
# are part of the public API.
777-
return True
778-
return False
797+
return False
798+
if self._all_:
799+
return name in self._all_
800+
return True

test-data/unit/stubgen.test

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,20 +587,26 @@ __all__ = [] + ['f']
587587
def f(): ...
588588
def g(): ...
589589
[out]
590+
__all__ = ['f']
591+
590592
def f() -> None: ...
591593

592594
[case testOmitDefsNotInAll_semanal]
593595
__all__ = ['f']
594596
def f(): ...
595597
def g(): ...
596598
[out]
599+
__all__ = ['f']
600+
597601
def f() -> None: ...
598602

599603
[case testOmitDefsNotInAll_inspect]
600604
__all__ = [] + ['f']
601605
def f(): ...
602606
def g(): ...
603607
[out]
608+
__all__ = ['f']
609+
604610
def f(): ...
605611

606612
[case testVarDefsNotInAll_import]
@@ -610,6 +616,8 @@ x = 1
610616
y = 1
611617
def g(): ...
612618
[out]
619+
__all__ = ['f', 'g']
620+
613621
def f() -> None: ...
614622
def g() -> None: ...
615623

@@ -620,6 +628,8 @@ x = 1
620628
y = 1
621629
def g(): ...
622630
[out]
631+
__all__ = ['f', 'g']
632+
623633
def f(): ...
624634
def g(): ...
625635

@@ -628,6 +638,8 @@ __all__ = [] + ['f']
628638
def f(): ...
629639
class A: ...
630640
[out]
641+
__all__ = ['f']
642+
631643
def f() -> None: ...
632644

633645
class A: ...
@@ -637,6 +649,8 @@ __all__ = [] + ['f']
637649
def f(): ...
638650
class A: ...
639651
[out]
652+
__all__ = ['f']
653+
640654
def f(): ...
641655

642656
class A: ...
@@ -647,6 +661,8 @@ class A:
647661
x = 1
648662
def f(self): ...
649663
[out]
664+
__all__ = ['A']
665+
650666
class A:
651667
x: int
652668
def f(self) -> None: ...
@@ -684,6 +700,8 @@ x = 1
684700
[out]
685701
from re import match as match, sub as sub
686702

703+
__all__ = ['match', 'sub', 'x']
704+
687705
x: int
688706

689707
[case testExportModule_import]
@@ -694,6 +712,8 @@ y = 2
694712
[out]
695713
import re as re
696714

715+
__all__ = ['re', 'x']
716+
697717
x: int
698718

699719
[case testExportModule2_import]
@@ -704,6 +724,8 @@ y = 2
704724
[out]
705725
import re as re
706726

727+
__all__ = ['re', 'x']
728+
707729
x: int
708730

709731
[case testExportModuleAs_import]
@@ -714,6 +736,8 @@ y = 2
714736
[out]
715737
import re as rex
716738

739+
__all__ = ['rex', 'x']
740+
717741
x: int
718742

719743
[case testExportModuleInPackage_import]
@@ -722,13 +746,17 @@ __all__ = ['p']
722746
[out]
723747
import urllib.parse as p
724748

749+
__all__ = ['p']
750+
725751
[case testExportPackageOfAModule_import]
726752
import urllib.parse
727753
__all__ = ['urllib']
728754

729755
[out]
730756
import urllib as urllib
731757

758+
__all__ = ['urllib']
759+
732760
[case testRelativeImportAll]
733761
from .x import *
734762
[out]
@@ -741,6 +769,8 @@ x = 1
741769
class C:
742770
def g(self): ...
743771
[out]
772+
__all__ = ['f', 'x', 'C', 'g']
773+
744774
def f() -> None: ...
745775

746776
x: int
@@ -758,6 +788,8 @@ x = 1
758788
class C:
759789
def g(self): ...
760790
[out]
791+
__all__ = ['f', 'x', 'C', 'g']
792+
761793
def f(): ...
762794

763795
x: int
@@ -2343,6 +2375,8 @@ else:
23432375
[out]
23442376
import cookielib as cookielib
23452377

2378+
__all__ = ['cookielib']
2379+
23462380
[case testCannotCalculateMRO_semanal]
23472381
class X: pass
23482382

@@ -2788,6 +2822,8 @@ class A: pass
27882822
# p/__init__.pyi
27892823
from p.a import A
27902824

2825+
__all__ = ['a']
2826+
27912827
a: A
27922828
# p/a.pyi
27932829
class A: ...
@@ -2961,7 +2997,9 @@ __uri__ = ''
29612997
__version__ = ''
29622998

29632999
[out]
2964-
from m import __version__ as __version__
3000+
from m import __about__ as __about__, __author__ as __author__, __version__ as __version__
3001+
3002+
__all__ = ['__about__', '__author__', '__version__']
29653003

29663004
[case testAttrsClass_semanal]
29673005
import attrs

0 commit comments

Comments
 (0)