Skip to content

Commit 06aefc6

Browse files
authored
Tweak constraints handling for splitting typevartuples (#13716)
The existing logic for splitting mapped & template into prefix, middle, and suffix does not handle the case where the template middle itself is not a singleton unpack but rather itself has a prefix & suffix. In this case we need to pull out the prefix & suffix by doing a second round of splitting on the middle. Originally we weren't sure if the PEP required implementing this double split, but one of the PEP646 test cases requires it.
1 parent a677f49 commit 06aefc6

File tree

3 files changed

+94
-4
lines changed

3 files changed

+94
-4
lines changed

mypy/constraints.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
extract_unpack,
5555
find_unpack_in_list,
5656
split_with_instance,
57+
split_with_mapped_and_template,
5758
split_with_prefix_and_suffix,
5859
)
5960

@@ -677,10 +678,14 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
677678
mapped = map_instance_to_supertype(instance, template.type)
678679
tvars = template.type.defn.type_vars
679680
if template.type.has_type_var_tuple_type:
680-
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
681-
template_prefix, template_middle, template_suffix = split_with_instance(
682-
template
683-
)
681+
(
682+
mapped_prefix,
683+
mapped_middle,
684+
mapped_suffix,
685+
template_prefix,
686+
template_middle,
687+
template_suffix,
688+
) = split_with_mapped_and_template(mapped, template)
684689

685690
# Add a constraint for the type var tuple, and then
686691
# remove it for the case below.

mypy/typevartuples.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,44 @@ def split_with_instance(
4444
)
4545

4646

47+
def split_with_mapped_and_template(
48+
mapped: Instance, template: Instance
49+
) -> tuple[
50+
tuple[Type, ...],
51+
tuple[Type, ...],
52+
tuple[Type, ...],
53+
tuple[Type, ...],
54+
tuple[Type, ...],
55+
tuple[Type, ...],
56+
]:
57+
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
58+
template_prefix, template_middle, template_suffix = split_with_instance(template)
59+
60+
unpack_prefix = find_unpack_in_list(template_middle)
61+
assert unpack_prefix is not None
62+
unpack_suffix = len(template_middle) - unpack_prefix - 1
63+
64+
(
65+
mapped_middle_prefix,
66+
mapped_middle_middle,
67+
mapped_middle_suffix,
68+
) = split_with_prefix_and_suffix(mapped_middle, unpack_prefix, unpack_suffix)
69+
(
70+
template_middle_prefix,
71+
template_middle_middle,
72+
template_middle_suffix,
73+
) = split_with_prefix_and_suffix(template_middle, unpack_prefix, unpack_suffix)
74+
75+
return (
76+
mapped_prefix + mapped_middle_prefix,
77+
mapped_middle_middle,
78+
mapped_middle_suffix + mapped_suffix,
79+
template_prefix + template_middle_prefix,
80+
template_middle_middle,
81+
template_middle_suffix + template_suffix,
82+
)
83+
84+
4785
def extract_unpack(types: Sequence[Type]) -> ProperType | None:
4886
"""Given a list of types, extracts either a single type from an unpack, or returns None."""
4987
if len(types) == 1:

test-data/unit/check-typevar-tuple.test

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,50 @@ class Array(Generic[Unpack[Shape]]):
257257

258258
x: Array[float, Height, Width] = Array()
259259
[builtins fixtures/tuple.pyi]
260+
261+
[case testPep646TypeConcatenation]
262+
from typing import Generic, TypeVar, NewType
263+
from typing_extensions import TypeVarTuple, Unpack
264+
265+
Shape = TypeVarTuple('Shape')
266+
267+
Channels = NewType("Channels", int)
268+
Batch = NewType("Batch", int)
269+
Height = NewType('Height', int)
270+
Width = NewType('Width', int)
271+
272+
class Array(Generic[Unpack[Shape]]):
273+
pass
274+
275+
276+
def add_batch_axis(x: Array[Unpack[Shape]]) -> Array[Batch, Unpack[Shape]]: ...
277+
def del_batch_axis(x: Array[Batch, Unpack[Shape]]) -> Array[Unpack[Shape]]: ...
278+
def add_batch_channels(
279+
x: Array[Unpack[Shape]]
280+
) -> Array[Batch, Unpack[Shape], Channels]: ...
281+
282+
a: Array[Height, Width]
283+
b = add_batch_axis(a)
284+
reveal_type(b) # N: Revealed type is "__main__.Array[__main__.Batch, __main__.Height, __main__.Width]"
285+
c = del_batch_axis(b)
286+
reveal_type(c) # N: Revealed type is "__main__.Array[__main__.Height, __main__.Width]"
287+
d = add_batch_channels(a)
288+
reveal_type(d) # N: Revealed type is "__main__.Array[__main__.Batch, __main__.Height, __main__.Width, __main__.Channels]"
289+
290+
[builtins fixtures/tuple.pyi]
291+
[case testPep646TypeVarConcatenation]
292+
from typing import Generic, TypeVar, NewType, Tuple
293+
from typing_extensions import TypeVarTuple, Unpack
294+
295+
T = TypeVar('T')
296+
Ts = TypeVarTuple('Ts')
297+
298+
def prefix_tuple(
299+
x: T,
300+
y: Tuple[Unpack[Ts]],
301+
) -> Tuple[T, Unpack[Ts]]:
302+
...
303+
304+
z = prefix_tuple(x=0, y=(True, 'a'))
305+
reveal_type(z) # N: Revealed type is "Tuple[builtins.int, builtins.bool, builtins.str]"
306+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)