Skip to content

Commit 0d07c90

Browse files
authored
Extend support for tuple subtyping with typevar tuples (#13718)
This adds a new Pep646 test case which demonstrates that like with the constraints from PR #13716 we need to split twice for subtyping when handling Unpacks. It also demonstrates a weakness of the previous PR which is that the middle-prefix and the prefix may need to be handled differently so we introduce another splitting function that returns a 10-tuple instead of a 6-tuple and reimplement the 6-tuple version on top of the 10-tuple version. Complicating things further, the test case reveals that there are error cases where split_with_mapped_and_template cannot actually unpack the middle a second time because the mapped middle is too short to do the unpack. We also now have to deal with the case where there was no unpack in the template in which case we only do a single split. In addition we fix a behavioral issue where according to PEP646 we should assume that Tuple[Unpack[Tuple[Any, ...]]] is equivalent to Tuple[Any, Any] even if we don't actually know the lengths match. As such test_type_var_tuple_unpacked_variable_length_tuple changes from asserting a strict subtype to asserting equivalence. One of the messages was bad as well so we add a branch for UnpackType in message pretty-printing.
1 parent 9227bce commit 0d07c90

File tree

6 files changed

+173
-20
lines changed

6 files changed

+173
-20
lines changed

mypy/constraints.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,14 +678,20 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
678678
mapped = map_instance_to_supertype(instance, template.type)
679679
tvars = template.type.defn.type_vars
680680
if template.type.has_type_var_tuple_type:
681+
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
682+
template_prefix, template_middle, template_suffix = split_with_instance(
683+
template
684+
)
685+
split_result = split_with_mapped_and_template(mapped, template)
686+
assert split_result is not None
681687
(
682688
mapped_prefix,
683689
mapped_middle,
684690
mapped_suffix,
685691
template_prefix,
686692
template_middle,
687693
template_suffix,
688-
) = split_with_mapped_and_template(mapped, template)
694+
) = split_result
689695

690696
# Add a constraint for the type var tuple, and then
691697
# remove it for the case below.

mypy/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
UnboundType,
8585
UninhabitedType,
8686
UnionType,
87+
UnpackType,
8788
get_proper_type,
8889
get_proper_types,
8990
)
@@ -2257,6 +2258,8 @@ def format_literal_value(typ: LiteralType) -> str:
22572258
else:
22582259
# There are type arguments. Convert the arguments to strings.
22592260
return f"{base_str}[{format_list(itype.args)}]"
2261+
elif isinstance(typ, UnpackType):
2262+
return f"Unpack[{format(typ.type)}]"
22602263
elif isinstance(typ, TypeVarType):
22612264
# This is similar to non-generic instance types.
22622265
return typ.name

mypy/subtypes.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
)
6363
from mypy.typestate import SubtypeKind, TypeState
6464
from mypy.typevars import fill_typevars_with_any
65-
from mypy.typevartuples import extract_unpack, split_with_instance
65+
from mypy.typevartuples import extract_unpack, fully_split_with_mapped_and_template
6666

6767
# Flags for detected protocol members
6868
IS_SETTABLE: Final = 1
@@ -485,8 +485,22 @@ def visit_instance(self, left: Instance) -> bool:
485485
t = erased
486486
nominal = True
487487
if right.type.has_type_var_tuple_type:
488-
left_prefix, left_middle, left_suffix = split_with_instance(left)
489-
right_prefix, right_middle, right_suffix = split_with_instance(right)
488+
split_result = fully_split_with_mapped_and_template(left, right)
489+
if split_result is None:
490+
return False
491+
492+
(
493+
left_prefix,
494+
left_mprefix,
495+
left_middle,
496+
left_msuffix,
497+
left_suffix,
498+
right_prefix,
499+
right_mprefix,
500+
right_middle,
501+
right_msuffix,
502+
right_suffix,
503+
) = split_result
490504

491505
left_unpacked = extract_unpack(left_middle)
492506
right_unpacked = extract_unpack(right_middle)
@@ -495,6 +509,15 @@ def visit_instance(self, left: Instance) -> bool:
495509
def check_mixed(
496510
unpacked_type: ProperType, compare_to: tuple[Type, ...]
497511
) -> bool:
512+
if (
513+
isinstance(unpacked_type, Instance)
514+
and unpacked_type.type.fullname == "builtins.tuple"
515+
):
516+
if not all(
517+
is_equivalent(l, unpacked_type.args[0]) for l in compare_to
518+
):
519+
return False
520+
return True
498521
if isinstance(unpacked_type, TypeVarTupleType):
499522
return False
500523
if isinstance(unpacked_type, AnyType):
@@ -521,13 +544,6 @@ def check_mixed(
521544
if not check_mixed(left_unpacked, right_middle):
522545
return False
523546
elif left_unpacked is None and right_unpacked is not None:
524-
if (
525-
isinstance(right_unpacked, Instance)
526-
and right_unpacked.type.fullname == "builtins.tuple"
527-
):
528-
return all(
529-
is_equivalent(l, right_unpacked.args[0]) for l in left_middle
530-
)
531547
if not check_mixed(right_unpacked, left_middle):
532548
return False
533549

@@ -540,16 +556,24 @@ def check_mixed(
540556
if not is_equivalent(left_t, right_t):
541557
return False
542558

559+
assert len(left_mprefix) == len(right_mprefix)
560+
assert len(left_msuffix) == len(right_msuffix)
561+
562+
for left_item, right_item in zip(
563+
left_mprefix + left_msuffix, right_mprefix + right_msuffix
564+
):
565+
if not is_equivalent(left_item, right_item):
566+
return False
567+
543568
left_items = t.args[: right.type.type_var_tuple_prefix]
544569
right_items = right.args[: right.type.type_var_tuple_prefix]
545570
if right.type.type_var_tuple_suffix:
546571
left_items += t.args[-right.type.type_var_tuple_suffix :]
547572
right_items += right.args[-right.type.type_var_tuple_suffix :]
548-
549573
unpack_index = right.type.type_var_tuple_prefix
550574
assert unpack_index is not None
551575
type_params = zip(
552-
left_prefix + right_suffix,
576+
left_prefix + left_suffix,
553577
right_prefix + right_suffix,
554578
right.type.defn.type_vars[:unpack_index]
555579
+ right.type.defn.type_vars[unpack_index + 1 :],

mypy/test/testsubtypes.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,22 @@ def test_type_var_tuple_with_prefix_suffix(self) -> None:
273273
Instance(self.fx.gvi, [self.fx.a, UnpackType(self.fx.ss), self.fx.b, self.fx.c]),
274274
)
275275

276+
def test_type_var_tuple_unpacked_varlength_tuple(self) -> None:
277+
self.assert_subtype(
278+
Instance(
279+
self.fx.gvi,
280+
[
281+
UnpackType(
282+
TupleType(
283+
[self.fx.a, self.fx.b],
284+
fallback=Instance(self.fx.std_tuplei, [self.fx.o]),
285+
)
286+
)
287+
],
288+
),
289+
Instance(self.fx.gvi, [self.fx.a, self.fx.b]),
290+
)
291+
276292
def test_type_var_tuple_unpacked_tuple(self) -> None:
277293
self.assert_subtype(
278294
Instance(
@@ -333,7 +349,7 @@ def test_type_var_tuple_unpacked_tuple(self) -> None:
333349
)
334350

335351
def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None:
336-
self.assert_strict_subtype(
352+
self.assert_equivalent(
337353
Instance(self.fx.gvi, [self.fx.a, self.fx.a]),
338354
Instance(self.fx.gvi, [UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))]),
339355
)

mypy/typevartuples.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,70 @@ def split_with_mapped_and_template(
5353
tuple[Type, ...],
5454
tuple[Type, ...],
5555
tuple[Type, ...],
56-
]:
56+
] | None:
57+
split_result = fully_split_with_mapped_and_template(mapped, template)
58+
if split_result is None:
59+
return None
60+
61+
(
62+
mapped_prefix,
63+
mapped_middle_prefix,
64+
mapped_middle_middle,
65+
mapped_middle_suffix,
66+
mapped_suffix,
67+
template_prefix,
68+
template_middle_prefix,
69+
template_middle_middle,
70+
template_middle_suffix,
71+
template_suffix,
72+
) = split_result
73+
74+
return (
75+
mapped_prefix + mapped_middle_prefix,
76+
mapped_middle_middle,
77+
mapped_middle_suffix + mapped_suffix,
78+
template_prefix + template_middle_prefix,
79+
template_middle_middle,
80+
template_middle_suffix + template_suffix,
81+
)
82+
83+
84+
def fully_split_with_mapped_and_template(
85+
mapped: Instance, template: Instance
86+
) -> tuple[
87+
tuple[Type, ...],
88+
tuple[Type, ...],
89+
tuple[Type, ...],
90+
tuple[Type, ...],
91+
tuple[Type, ...],
92+
tuple[Type, ...],
93+
tuple[Type, ...],
94+
tuple[Type, ...],
95+
tuple[Type, ...],
96+
tuple[Type, ...],
97+
] | None:
5798
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
5899
template_prefix, template_middle, template_suffix = split_with_instance(template)
59100

60101
unpack_prefix = find_unpack_in_list(template_middle)
61-
assert unpack_prefix is not None
102+
if unpack_prefix is None:
103+
return (
104+
mapped_prefix,
105+
(),
106+
mapped_middle,
107+
(),
108+
mapped_suffix,
109+
template_prefix,
110+
(),
111+
template_middle,
112+
(),
113+
template_suffix,
114+
)
115+
62116
unpack_suffix = len(template_middle) - unpack_prefix - 1
117+
# mapped_middle is too short to do the unpack
118+
if unpack_prefix + unpack_suffix > len(mapped_middle):
119+
return None
63120

64121
(
65122
mapped_middle_prefix,
@@ -73,12 +130,16 @@ def split_with_mapped_and_template(
73130
) = split_with_prefix_and_suffix(template_middle, unpack_prefix, unpack_suffix)
74131

75132
return (
76-
mapped_prefix + mapped_middle_prefix,
133+
mapped_prefix,
134+
mapped_middle_prefix,
77135
mapped_middle_middle,
78-
mapped_middle_suffix + mapped_suffix,
79-
template_prefix + template_middle_prefix,
136+
mapped_middle_suffix,
137+
mapped_suffix,
138+
template_prefix,
139+
template_middle_prefix,
80140
template_middle_middle,
81-
template_middle_suffix + template_suffix,
141+
template_middle_suffix,
142+
template_suffix,
82143
)
83144

84145

test-data/unit/check-typevar-tuple.test

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,3 +304,46 @@ def prefix_tuple(
304304
z = prefix_tuple(x=0, y=(True, 'a'))
305305
reveal_type(z) # N: Revealed type is "Tuple[builtins.int, builtins.bool, builtins.str]"
306306
[builtins fixtures/tuple.pyi]
307+
[case testPep646TypeVarTupleUnpacking]
308+
from typing import Generic, TypeVar, NewType, Any, Tuple
309+
from typing_extensions import TypeVarTuple, Unpack
310+
311+
Shape = TypeVarTuple('Shape')
312+
313+
Channels = NewType("Channels", int)
314+
Batch = NewType("Batch", int)
315+
Height = NewType('Height', int)
316+
Width = NewType('Width', int)
317+
318+
class Array(Generic[Unpack[Shape]]):
319+
pass
320+
321+
def process_batch_channels(
322+
x: Array[Batch, Unpack[Tuple[Any, ...]], Channels]
323+
) -> None:
324+
...
325+
326+
x: Array[Batch, Height, Width, Channels]
327+
process_batch_channels(x)
328+
y: Array[Batch, Channels]
329+
process_batch_channels(y)
330+
z: Array[Batch]
331+
process_batch_channels(z) # E: Argument 1 to "process_batch_channels" has incompatible type "Array[Batch]"; expected "Array[Batch, Unpack[Tuple[Any, ...]], Channels]"
332+
333+
u: Array[Unpack[Tuple[Any, ...]]]
334+
335+
def expect_variadic_array(
336+
x: Array[Batch, Unpack[Shape]]
337+
) -> None:
338+
...
339+
340+
def expect_variadic_array_2(
341+
x: Array[Batch, Height, Width, Channels]
342+
) -> None:
343+
...
344+
345+
expect_variadic_array(u)
346+
expect_variadic_array_2(u)
347+
348+
349+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)