Skip to content

Commit 0a720ed

Browse files
authored
Fix typevar tuple handling to expect unpack in class def (#13630)
Originally this PR was intended to add some test cases from PEP646. However it became immediately apparent that there was a major bug in the implementation where we expected the definition to look like: ``` class Foo(Generic[Ts]) ``` When it is supposed to be ``` class Foo(Generic[Unpack[Ts]]) ``` This fixes that. Also improve constraints solving involving typevar tuples.
1 parent 6a50192 commit 0a720ed

File tree

5 files changed

+176
-13
lines changed

5 files changed

+176
-13
lines changed

mypy/constraints.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,60 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
583583
if self.direction == SUBTYPE_OF and template.type.has_base(instance.type.fullname):
584584
mapped = map_instance_to_supertype(template, instance.type)
585585
tvars = mapped.type.defn.type_vars
586+
587+
if instance.type.has_type_var_tuple_type:
588+
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
589+
instance_prefix, instance_middle, instance_suffix = split_with_instance(
590+
instance
591+
)
592+
593+
# Add a constraint for the type var tuple, and then
594+
# remove it for the case below.
595+
instance_unpack = extract_unpack(instance_middle)
596+
if instance_unpack is not None:
597+
if isinstance(instance_unpack, TypeVarTupleType):
598+
res.append(
599+
Constraint(
600+
instance_unpack, SUBTYPE_OF, TypeList(list(mapped_middle))
601+
)
602+
)
603+
elif (
604+
isinstance(instance_unpack, Instance)
605+
and instance_unpack.type.fullname == "builtins.tuple"
606+
):
607+
for item in mapped_middle:
608+
res.extend(
609+
infer_constraints(
610+
instance_unpack.args[0], item, self.direction
611+
)
612+
)
613+
elif isinstance(instance_unpack, TupleType):
614+
if len(instance_unpack.items) == len(mapped_middle):
615+
for instance_arg, item in zip(
616+
instance_unpack.items, mapped_middle
617+
):
618+
res.extend(
619+
infer_constraints(instance_arg, item, self.direction)
620+
)
621+
622+
mapped_args = mapped_prefix + mapped_suffix
623+
instance_args = instance_prefix + instance_suffix
624+
625+
assert instance.type.type_var_tuple_prefix is not None
626+
assert instance.type.type_var_tuple_suffix is not None
627+
tvars_prefix, _, tvars_suffix = split_with_prefix_and_suffix(
628+
tuple(tvars),
629+
instance.type.type_var_tuple_prefix,
630+
instance.type.type_var_tuple_suffix,
631+
)
632+
tvars = list(tvars_prefix + tvars_suffix)
633+
else:
634+
mapped_args = mapped.args
635+
instance_args = instance.args
636+
586637
# N.B: We use zip instead of indexing because the lengths might have
587638
# mismatches during daemon reprocessing.
588-
for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args):
639+
for tvar, mapped_arg, instance_arg in zip(tvars, mapped_args, instance_args):
589640
# TODO(PEP612): More ParamSpec work (or is Parameters the only thing accepted)
590641
if isinstance(tvar, TypeVarType):
591642
# The constraints for generic type parameters depend on variance.
@@ -617,8 +668,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
617668
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
618669
elif isinstance(suffix, ParamSpecType):
619670
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
620-
elif isinstance(tvar, TypeVarTupleType):
621-
raise NotImplementedError
671+
else:
672+
# This case should have been handled above.
673+
assert not isinstance(tvar, TypeVarTupleType)
622674

623675
return res
624676
elif self.direction == SUPERTYPE_OF and instance.type.has_base(template.type.fullname):
@@ -710,6 +762,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
710762
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
711763
elif isinstance(suffix, ParamSpecType):
712764
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
765+
else:
766+
# This case should have been handled above.
767+
assert not isinstance(tvar, TypeVarTupleType)
713768
return res
714769
if (
715770
template.type.is_protocol

mypy/nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2858,6 +2858,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
28582858
self.metadata = {}
28592859

28602860
def add_type_vars(self) -> None:
2861+
self.has_type_var_tuple_type = False
28612862
if self.defn.type_vars:
28622863
for i, vd in enumerate(self.defn.type_vars):
28632864
if isinstance(vd, mypy.types.ParamSpecType):

mypy/semanal.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,10 +1684,16 @@ def analyze_class_typevar_declaration(self, base: Type) -> tuple[TypeVarLikeList
16841684
):
16851685
is_proto = sym.node.fullname != "typing.Generic"
16861686
tvars: TypeVarLikeList = []
1687+
have_type_var_tuple = False
16871688
for arg in unbound.args:
16881689
tag = self.track_incomplete_refs()
16891690
tvar = self.analyze_unbound_tvar(arg)
16901691
if tvar:
1692+
if isinstance(tvar[1], TypeVarTupleExpr):
1693+
if have_type_var_tuple:
1694+
self.fail("Can only use one type var tuple in a class def", base)
1695+
continue
1696+
have_type_var_tuple = True
16911697
tvars.append(tvar)
16921698
elif not self.found_incomplete_ref(tag):
16931699
self.fail("Free type variable expected in %s[...]" % sym.node.name, base)
@@ -1706,11 +1712,19 @@ def analyze_unbound_tvar(self, t: Type) -> tuple[str, TypeVarLikeExpr] | None:
17061712
# It's bound by our type variable scope
17071713
return None
17081714
return unbound.name, sym.node
1709-
if sym and isinstance(sym.node, TypeVarTupleExpr):
1710-
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
1711-
# It's bound by our type variable scope
1715+
if sym and sym.fullname == "typing_extensions.Unpack":
1716+
inner_t = unbound.args[0]
1717+
if not isinstance(inner_t, UnboundType):
17121718
return None
1713-
return unbound.name, sym.node
1719+
inner_unbound = inner_t
1720+
inner_sym = self.lookup_qualified(inner_unbound.name, inner_unbound)
1721+
if inner_sym and isinstance(inner_sym.node, PlaceholderNode):
1722+
self.record_incomplete_ref()
1723+
if inner_sym and isinstance(inner_sym.node, TypeVarTupleExpr):
1724+
if inner_sym.fullname and not self.tvar_scope.allow_binding(inner_sym.fullname):
1725+
# It's bound by our type variable scope
1726+
return None
1727+
return inner_unbound.name, inner_sym.node
17141728
if sym is None or not isinstance(sym.node, TypeVarExpr):
17151729
return None
17161730
elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):

test-data/unit/check-typevar-tuple.test

Lines changed: 93 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,18 +96,18 @@ reveal_type(h(args)) # N: Revealed type is "Tuple[builtins.str, builtins.str, b
9696

9797
[case testTypeVarTupleGenericClassDefn]
9898
from typing import Generic, TypeVar, Tuple
99-
from typing_extensions import TypeVarTuple
99+
from typing_extensions import TypeVarTuple, Unpack
100100

101101
T = TypeVar("T")
102102
Ts = TypeVarTuple("Ts")
103103

104-
class Variadic(Generic[Ts]):
104+
class Variadic(Generic[Unpack[Ts]]):
105105
pass
106106

107-
class Mixed1(Generic[T, Ts]):
107+
class Mixed1(Generic[T, Unpack[Ts]]):
108108
pass
109109

110-
class Mixed2(Generic[Ts, T]):
110+
class Mixed2(Generic[Unpack[Ts], T]):
111111
pass
112112

113113
variadic: Variadic[int, str]
@@ -133,7 +133,7 @@ Ts = TypeVarTuple("Ts")
133133
T = TypeVar("T")
134134
S = TypeVar("S")
135135

136-
class Variadic(Generic[T, Ts, S]):
136+
class Variadic(Generic[T, Unpack[Ts], S]):
137137
pass
138138

139139
def foo(t: Variadic[int, Unpack[Ts], object]) -> Tuple[int, Unpack[Ts]]:
@@ -152,7 +152,7 @@ Ts = TypeVarTuple("Ts")
152152
T = TypeVar("T")
153153
S = TypeVar("S")
154154

155-
class Variadic(Generic[T, Ts, S]):
155+
class Variadic(Generic[T, Unpack[Ts], S]):
156156
def __init__(self, t: Tuple[Unpack[Ts]]) -> None:
157157
...
158158

@@ -170,3 +170,90 @@ from typing_extensions import TypeVarTuple
170170
Ts = TypeVarTuple("Ts")
171171
B = Ts # E: Type variable "__main__.Ts" is invalid as target for type alias
172172
[builtins fixtures/tuple.pyi]
173+
174+
[case testPep646ArrayExample]
175+
from typing import Generic, Tuple, TypeVar, Protocol, NewType
176+
from typing_extensions import TypeVarTuple, Unpack
177+
178+
Shape = TypeVarTuple('Shape')
179+
180+
Height = NewType('Height', int)
181+
Width = NewType('Width', int)
182+
183+
T_co = TypeVar("T_co", covariant=True)
184+
T = TypeVar("T")
185+
186+
class SupportsAbs(Protocol[T_co]):
187+
def __abs__(self) -> T_co: pass
188+
189+
def abs(a: SupportsAbs[T]) -> T:
190+
...
191+
192+
class Array(Generic[Unpack[Shape]]):
193+
def __init__(self, shape: Tuple[Unpack[Shape]]):
194+
self._shape: Tuple[Unpack[Shape]] = shape
195+
196+
def get_shape(self) -> Tuple[Unpack[Shape]]:
197+
return self._shape
198+
199+
def __abs__(self) -> Array[Unpack[Shape]]: ...
200+
201+
def __add__(self, other: Array[Unpack[Shape]]) -> Array[Unpack[Shape]]: ...
202+
203+
shape = (Height(480), Width(640))
204+
x: Array[Height, Width] = Array(shape)
205+
reveal_type(abs(x)) # N: Revealed type is "__main__.Array[__main__.Height, __main__.Width]"
206+
reveal_type(x + x) # N: Revealed type is "__main__.Array[__main__.Height, __main__.Width]"
207+
208+
[builtins fixtures/tuple.pyi]
209+
[case testPep646ArrayExampleWithDType]
210+
from typing import Generic, Tuple, TypeVar, Protocol, NewType
211+
from typing_extensions import TypeVarTuple, Unpack
212+
213+
DType = TypeVar("DType")
214+
Shape = TypeVarTuple('Shape')
215+
216+
Height = NewType('Height', int)
217+
Width = NewType('Width', int)
218+
219+
T_co = TypeVar("T_co", covariant=True)
220+
T = TypeVar("T")
221+
222+
class SupportsAbs(Protocol[T_co]):
223+
def __abs__(self) -> T_co: pass
224+
225+
def abs(a: SupportsAbs[T]) -> T:
226+
...
227+
228+
class Array(Generic[DType, Unpack[Shape]]):
229+
def __init__(self, shape: Tuple[Unpack[Shape]]):
230+
self._shape: Tuple[Unpack[Shape]] = shape
231+
232+
def get_shape(self) -> Tuple[Unpack[Shape]]:
233+
return self._shape
234+
235+
def __abs__(self) -> Array[DType, Unpack[Shape]]: ...
236+
237+
def __add__(self, other: Array[DType, Unpack[Shape]]) -> Array[DType, Unpack[Shape]]: ...
238+
239+
shape = (Height(480), Width(640))
240+
x: Array[float, Height, Width] = Array(shape)
241+
reveal_type(abs(x)) # N: Revealed type is "__main__.Array[builtins.float, __main__.Height, __main__.Width]"
242+
reveal_type(x + x) # N: Revealed type is "__main__.Array[builtins.float, __main__.Height, __main__.Width]"
243+
244+
[builtins fixtures/tuple.pyi]
245+
246+
[case testPep646ArrayExampleInfer]
247+
from typing import Generic, Tuple, TypeVar, NewType
248+
from typing_extensions import TypeVarTuple, Unpack
249+
250+
Shape = TypeVarTuple('Shape')
251+
252+
Height = NewType('Height', int)
253+
Width = NewType('Width', int)
254+
255+
class Array(Generic[Unpack[Shape]]):
256+
pass
257+
258+
x: Array[float, Height, Width] = Array()
259+
[builtins fixtures/tuple.pyi]

test-data/unit/semanal-errors.test

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,9 +1456,11 @@ bad: Tuple[Unpack[int]] # E: builtins.int cannot be unpacked (must be tuple or
14561456
[builtins fixtures/tuple.pyi]
14571457

14581458
[case testTypeVarTuple]
1459+
from typing import Generic
14591460
from typing_extensions import TypeVarTuple, Unpack
14601461

14611462
TVariadic = TypeVarTuple('TVariadic')
1463+
TVariadic2 = TypeVarTuple('TVariadic2')
14621464
TP = TypeVarTuple('?') # E: String argument 1 "?" to TypeVarTuple(...) does not match variable name "TP"
14631465
TP2: int = TypeVarTuple('TP2') # E: Cannot declare the type of a TypeVar or similar construct
14641466
TP3 = TypeVarTuple() # E: Too few arguments for TypeVarTuple()
@@ -1467,3 +1469,7 @@ TP5 = TypeVarTuple(t='TP5') # E: TypeVarTuple() expects a string literal as fir
14671469

14681470
x: TVariadic # E: TypeVarTuple "TVariadic" is unbound
14691471
y: Unpack[TVariadic] # E: TypeVarTuple "TVariadic" is unbound
1472+
1473+
1474+
class Variadic(Generic[Unpack[TVariadic], Unpack[TVariadic2]]): # E: Can only use one type var tuple in a class def
1475+
pass

0 commit comments

Comments
 (0)