Skip to content

Commit e28925d

Browse files
authored
stubgen: use PEP 604 unions everywhere (#16519)
Fixes #12920
1 parent a8741d8 commit e28925d

File tree

3 files changed

+57
-20
lines changed

3 files changed

+57
-20
lines changed

mypy/stubgen.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
=> Generate out/urllib/parse.pyi.
2323
2424
$ stubgen -p urllib
25-
=> Generate stubs for whole urlib package (recursively).
25+
=> Generate stubs for whole urllib package (recursively).
2626
2727
For C modules, you can get more precise function signatures by parsing .rst (Sphinx)
2828
documentation for extra information. For this, use the --doc-dir option:
@@ -306,6 +306,13 @@ def visit_str_expr(self, node: StrExpr) -> str:
306306
return repr(node.value)
307307

308308
def visit_index_expr(self, node: IndexExpr) -> str:
309+
base_fullname = self.stubgen.get_fullname(node.base)
310+
if base_fullname == "typing.Union":
311+
if isinstance(node.index, TupleExpr):
312+
return " | ".join([item.accept(self) for item in node.index.items])
313+
return node.index.accept(self)
314+
if base_fullname == "typing.Optional":
315+
return f"{node.index.accept(self)} | None"
309316
base = node.base.accept(self)
310317
index = node.index.accept(self)
311318
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
@@ -682,7 +689,7 @@ def process_decorator(self, o: Decorator) -> None:
682689
self.add_decorator(qualname, require_name=False)
683690

684691
def get_fullname(self, expr: Expression) -> str:
685-
"""Return the full name resolving imports and import aliases."""
692+
"""Return the expression's full name."""
686693
if (
687694
self.analyzed
688695
and isinstance(expr, (NameExpr, MemberExpr))
@@ -691,16 +698,7 @@ def get_fullname(self, expr: Expression) -> str:
691698
):
692699
return expr.fullname
693700
name = get_qualified_name(expr)
694-
if "." not in name:
695-
real_module = self.import_tracker.module_for.get(name)
696-
real_short = self.import_tracker.reverse_alias.get(name, name)
697-
if real_module is None and real_short not in self.defined_names:
698-
real_module = "builtins" # not imported and not defined, must be a builtin
699-
else:
700-
name_module, real_short = name.split(".", 1)
701-
real_module = self.import_tracker.reverse_alias.get(name_module, name_module)
702-
resolved_name = real_short if real_module is None else f"{real_module}.{real_short}"
703-
return resolved_name
701+
return self.resolve_name(name)
704702

705703
def visit_class_def(self, o: ClassDef) -> None:
706704
self._current_class = o

mypy/stubutil.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,11 @@ def visit_any(self, t: AnyType) -> str:
226226

227227
def visit_unbound_type(self, t: UnboundType) -> str:
228228
s = t.name
229+
fullname = self.stubgen.resolve_name(s)
230+
if fullname == "typing.Union":
231+
return " | ".join([item.accept(self) for item in t.args])
232+
if fullname == "typing.Optional":
233+
return f"{t.args[0].accept(self)} | None"
229234
if self.known_modules is not None and "." in s:
230235
# see if this object is from any of the modules that we're currently processing.
231236
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
@@ -588,14 +593,18 @@ def __init__(
588593
def get_sig_generators(self) -> list[SignatureGenerator]:
589594
return []
590595

591-
def refers_to_fullname(self, name: str, fullname: str | tuple[str, ...]) -> bool:
592-
"""Return True if the variable name identifies the same object as the given fullname(s)."""
593-
if isinstance(fullname, tuple):
594-
return any(self.refers_to_fullname(name, fname) for fname in fullname)
595-
module, short = fullname.rsplit(".", 1)
596-
return self.import_tracker.module_for.get(name) == module and (
597-
name == short or self.import_tracker.reverse_alias.get(name) == short
598-
)
596+
def resolve_name(self, name: str) -> str:
597+
"""Return the full name resolving imports and import aliases."""
598+
if "." not in name:
599+
real_module = self.import_tracker.module_for.get(name)
600+
real_short = self.import_tracker.reverse_alias.get(name, name)
601+
if real_module is None and real_short not in self.defined_names:
602+
real_module = "builtins" # not imported and not defined, must be a builtin
603+
else:
604+
name_module, real_short = name.split(".", 1)
605+
real_module = self.import_tracker.reverse_alias.get(name_module, name_module)
606+
resolved_name = real_short if real_module is None else f"{real_module}.{real_short}"
607+
return resolved_name
599608

600609
def add_name(self, fullname: str, require: bool = True) -> str:
601610
"""Add a name to be imported and return the name reference.

test-data/unit/stubgen.test

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4139,3 +4139,33 @@ from dataclasses import dataclass
41394139
class X(missing.Base):
41404140
a: int
41414141
def __init__(self, *selfa_, a, **selfa__) -> None: ...
4142+
4143+
[case testAlwaysUsePEP604Union]
4144+
import typing
4145+
import typing as t
4146+
from typing import Optional, Union, Optional as O, Union as U
4147+
import x
4148+
4149+
union = Union[int, str]
4150+
bad_union = Union[int]
4151+
nested_union = Optional[Union[int, str]]
4152+
not_union = x.Union[int, str]
4153+
u = U[int, str]
4154+
o = O[int]
4155+
4156+
def f1(a: Union["int", Optional[tuple[int, t.Optional[int]]]]) -> int: ...
4157+
def f2(a: typing.Union[int | x.Union[int, int], O[float]]) -> int: ...
4158+
4159+
[out]
4160+
import x
4161+
from _typeshed import Incomplete
4162+
4163+
union = int | str
4164+
bad_union = int
4165+
nested_union = int | str | None
4166+
not_union: Incomplete
4167+
u = int | str
4168+
o = int | None
4169+
4170+
def f1(a: int | tuple[int, int | None] | None) -> int: ...
4171+
def f2(a: int | x.Union[int, int] | float | None) -> int: ...

0 commit comments

Comments
 (0)