Skip to content

Commit a8467c4

Browse files
authored
[stubgen] Add required ... rhs to NamedTuple fields with default values (#15680)
Closes #15638
1 parent b901d21 commit a8467c4

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

mypy/stubgen.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
OverloadedFuncDef,
103103
Statement,
104104
StrExpr,
105+
TempNode,
105106
TupleExpr,
106107
TypeInfo,
107108
UnaryExpr,
@@ -637,6 +638,7 @@ def __init__(
637638
self._state = EMPTY
638639
self._toplevel_names: list[str] = []
639640
self._include_private = include_private
641+
self._current_class: ClassDef | None = None
640642
self.import_tracker = ImportTracker()
641643
# Was the tree semantically analysed before?
642644
self.analyzed = analyzed
@@ -886,6 +888,7 @@ def get_fullname(self, expr: Expression) -> str:
886888
return resolved_name
887889

888890
def visit_class_def(self, o: ClassDef) -> None:
891+
self._current_class = o
889892
self.method_names = find_method_names(o.defs.body)
890893
sep: int | None = None
891894
if not self._indent and self._state != EMPTY:
@@ -922,6 +925,7 @@ def visit_class_def(self, o: ClassDef) -> None:
922925
else:
923926
self._state = CLASS
924927
self.method_names = set()
928+
self._current_class = None
925929

926930
def get_base_types(self, cdef: ClassDef) -> list[str]:
927931
"""Get list of base classes for a class."""
@@ -1330,7 +1334,20 @@ def get_init(
13301334
typename += f"[{final_arg}]"
13311335
else:
13321336
typename = self.get_str_type_of_node(rvalue)
1333-
return f"{self._indent}{lvalue}: {typename}\n"
1337+
initializer = self.get_assign_initializer(rvalue)
1338+
return f"{self._indent}{lvalue}: {typename}{initializer}\n"
1339+
1340+
def get_assign_initializer(self, rvalue: Expression) -> str:
1341+
"""Does this rvalue need some special initializer value?"""
1342+
if self._current_class and self._current_class.info:
1343+
# Current rules
1344+
# 1. Return `...` if we are dealing with `NamedTuple` and it has an existing default value
1345+
if self._current_class.info.is_named_tuple and not isinstance(rvalue, TempNode):
1346+
return " = ..."
1347+
# TODO: support other possible cases, where initializer is important
1348+
1349+
# By default, no initializer is required:
1350+
return ""
13341351

13351352
def add(self, string: str) -> None:
13361353
"""Add text to generated stub."""

test-data/unit/stubgen.test

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,62 @@ class Y(NamedTuple):
698698
a: int
699699
b: str
700700

701+
[case testNamedTupleClassSyntax_semanal]
702+
from typing import NamedTuple
703+
704+
class A(NamedTuple):
705+
x: int
706+
y: str = 'a'
707+
708+
class B(A):
709+
z1: str
710+
z2 = 1
711+
z3: str = 'b'
712+
713+
class RegularClass:
714+
x: int
715+
y: str = 'a'
716+
class NestedNamedTuple(NamedTuple):
717+
x: int
718+
y: str = 'a'
719+
z: str = 'b'
720+
[out]
721+
from typing import NamedTuple
722+
723+
class A(NamedTuple):
724+
x: int
725+
y: str = ...
726+
727+
class B(A):
728+
z1: str
729+
z2: int
730+
z3: str
731+
732+
class RegularClass:
733+
x: int
734+
y: str
735+
class NestedNamedTuple(NamedTuple):
736+
x: int
737+
y: str = ...
738+
z: str
739+
740+
741+
[case testNestedClassInNamedTuple_semanal-xfail]
742+
from typing import NamedTuple
743+
744+
# TODO: make sure that nested classes in `NamedTuple` are supported:
745+
class NamedTupleWithNestedClass(NamedTuple):
746+
class Nested:
747+
x: int
748+
y: str = 'a'
749+
[out]
750+
from typing import NamedTuple
751+
752+
class NamedTupleWithNestedClass(NamedTuple):
753+
class Nested:
754+
x: int
755+
y: str
756+
701757
[case testEmptyNamedtuple]
702758
import collections, typing
703759
X = collections.namedtuple('X', [])

0 commit comments

Comments
 (0)