Skip to content

Extend support for tuple subtyping with typevar tuples #13718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,14 +678,20 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
mapped = map_instance_to_supertype(instance, template.type)
tvars = template.type.defn.type_vars
if template.type.has_type_var_tuple_type:
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
template_prefix, template_middle, template_suffix = split_with_instance(
template
)
split_result = split_with_mapped_and_template(mapped, template)
assert split_result is not None
(
mapped_prefix,
mapped_middle,
mapped_suffix,
template_prefix,
template_middle,
template_suffix,
) = split_with_mapped_and_template(mapped, template)
) = split_result

# Add a constraint for the type var tuple, and then
# remove it for the case below.
Expand Down
3 changes: 3 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
get_proper_type,
get_proper_types,
)
Expand Down Expand Up @@ -2229,6 +2230,8 @@ def format_literal_value(typ: LiteralType) -> str:
else:
# There are type arguments. Convert the arguments to strings.
return f"{base_str}[{format_list(itype.args)}]"
elif isinstance(typ, UnpackType):
return f"Unpack[{format(typ.type)}]"
elif isinstance(typ, TypeVarType):
# This is similar to non-generic instance types.
return typ.name
Expand Down
48 changes: 36 additions & 12 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from mypy.typestate import SubtypeKind, TypeState
from mypy.typevars import fill_typevars_with_any
from mypy.typevartuples import extract_unpack, split_with_instance
from mypy.typevartuples import extract_unpack, fully_split_with_mapped_and_template

# Flags for detected protocol members
IS_SETTABLE: Final = 1
Expand Down Expand Up @@ -485,8 +485,22 @@ def visit_instance(self, left: Instance) -> bool:
t = erased
nominal = True
if right.type.has_type_var_tuple_type:
left_prefix, left_middle, left_suffix = split_with_instance(left)
right_prefix, right_middle, right_suffix = split_with_instance(right)
split_result = fully_split_with_mapped_and_template(left, right)
if split_result is None:
return False

(
left_prefix,
left_mprefix,
left_middle,
left_msuffix,
left_suffix,
right_prefix,
right_mprefix,
right_middle,
right_msuffix,
right_suffix,
) = split_result

left_unpacked = extract_unpack(left_middle)
right_unpacked = extract_unpack(right_middle)
Expand All @@ -495,6 +509,15 @@ def visit_instance(self, left: Instance) -> bool:
def check_mixed(
unpacked_type: ProperType, compare_to: tuple[Type, ...]
) -> bool:
if (
isinstance(unpacked_type, Instance)
and unpacked_type.type.fullname == "builtins.tuple"
):
if not all(
is_equivalent(l, unpacked_type.args[0]) for l in compare_to
):
return False
return True
if isinstance(unpacked_type, TypeVarTupleType):
return False
if isinstance(unpacked_type, AnyType):
Expand All @@ -521,13 +544,6 @@ def check_mixed(
if not check_mixed(left_unpacked, right_middle):
return False
elif left_unpacked is None and right_unpacked is not None:
if (
isinstance(right_unpacked, Instance)
and right_unpacked.type.fullname == "builtins.tuple"
):
return all(
is_equivalent(l, right_unpacked.args[0]) for l in left_middle
)
if not check_mixed(right_unpacked, left_middle):
return False

Expand All @@ -540,16 +556,24 @@ def check_mixed(
if not is_equivalent(left_t, right_t):
return False

assert len(left_mprefix) == len(right_mprefix)
assert len(left_msuffix) == len(right_msuffix)

for left_item, right_item in zip(
left_mprefix + left_msuffix, right_mprefix + right_msuffix
):
if not is_equivalent(left_item, right_item):
return False

left_items = t.args[: right.type.type_var_tuple_prefix]
right_items = right.args[: right.type.type_var_tuple_prefix]
if right.type.type_var_tuple_suffix:
left_items += t.args[-right.type.type_var_tuple_suffix :]
right_items += right.args[-right.type.type_var_tuple_suffix :]

unpack_index = right.type.type_var_tuple_prefix
assert unpack_index is not None
type_params = zip(
left_prefix + right_suffix,
left_prefix + left_suffix,
right_prefix + right_suffix,
right.type.defn.type_vars[:unpack_index]
+ right.type.defn.type_vars[unpack_index + 1 :],
Expand Down
18 changes: 17 additions & 1 deletion mypy/test/testsubtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,22 @@ def test_type_var_tuple_with_prefix_suffix(self) -> None:
Instance(self.fx.gvi, [self.fx.a, UnpackType(self.fx.ss), self.fx.b, self.fx.c]),
)

def test_type_var_tuple_unpacked_varlength_tuple(self) -> None:
self.assert_subtype(
Instance(
self.fx.gvi,
[
UnpackType(
TupleType(
[self.fx.a, self.fx.b],
fallback=Instance(self.fx.std_tuplei, [self.fx.o]),
)
)
],
),
Instance(self.fx.gvi, [self.fx.a, self.fx.b]),
)

def test_type_var_tuple_unpacked_tuple(self) -> None:
self.assert_subtype(
Instance(
Expand Down Expand Up @@ -333,7 +349,7 @@ def test_type_var_tuple_unpacked_tuple(self) -> None:
)

def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None:
self.assert_strict_subtype(
self.assert_equivalent(
Instance(self.fx.gvi, [self.fx.a, self.fx.a]),
Instance(self.fx.gvi, [UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))]),
)
Expand Down
73 changes: 67 additions & 6 deletions mypy/typevartuples.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,70 @@ def split_with_mapped_and_template(
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
]:
] | None:
split_result = fully_split_with_mapped_and_template(mapped, template)
if split_result is None:
return None

(
mapped_prefix,
mapped_middle_prefix,
mapped_middle_middle,
mapped_middle_suffix,
mapped_suffix,
template_prefix,
template_middle_prefix,
template_middle_middle,
template_middle_suffix,
template_suffix,
) = split_result

return (
mapped_prefix + mapped_middle_prefix,
mapped_middle_middle,
mapped_middle_suffix + mapped_suffix,
template_prefix + template_middle_prefix,
template_middle_middle,
template_middle_suffix + template_suffix,
)


def fully_split_with_mapped_and_template(
mapped: Instance, template: Instance
) -> tuple[
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
tuple[Type, ...],
] | None:
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
template_prefix, template_middle, template_suffix = split_with_instance(template)

unpack_prefix = find_unpack_in_list(template_middle)
assert unpack_prefix is not None
if unpack_prefix is None:
return (
mapped_prefix,
(),
mapped_middle,
(),
mapped_suffix,
template_prefix,
(),
template_middle,
(),
template_suffix,
)

unpack_suffix = len(template_middle) - unpack_prefix - 1
# mapped_middle is too short to do the unpack
if unpack_prefix + unpack_suffix > len(mapped_middle):
return None

(
mapped_middle_prefix,
Expand All @@ -73,12 +130,16 @@ def split_with_mapped_and_template(
) = split_with_prefix_and_suffix(template_middle, unpack_prefix, unpack_suffix)

return (
mapped_prefix + mapped_middle_prefix,
mapped_prefix,
mapped_middle_prefix,
mapped_middle_middle,
mapped_middle_suffix + mapped_suffix,
template_prefix + template_middle_prefix,
mapped_middle_suffix,
mapped_suffix,
template_prefix,
template_middle_prefix,
template_middle_middle,
template_middle_suffix + template_suffix,
template_middle_suffix,
template_suffix,
)


Expand Down
43 changes: 43 additions & 0 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,46 @@ def prefix_tuple(
z = prefix_tuple(x=0, y=(True, 'a'))
reveal_type(z) # N: Revealed type is "Tuple[builtins.int, builtins.bool, builtins.str]"
[builtins fixtures/tuple.pyi]
[case testPep646TypeVarTupleUnpacking]
from typing import Generic, TypeVar, NewType, Any, Tuple
from typing_extensions import TypeVarTuple, Unpack

Shape = TypeVarTuple('Shape')

Channels = NewType("Channels", int)
Batch = NewType("Batch", int)
Height = NewType('Height', int)
Width = NewType('Width', int)

class Array(Generic[Unpack[Shape]]):
pass

def process_batch_channels(
x: Array[Batch, Unpack[Tuple[Any, ...]], Channels]
) -> None:
...

x: Array[Batch, Height, Width, Channels]
process_batch_channels(x)
y: Array[Batch, Channels]
process_batch_channels(y)
z: Array[Batch]
process_batch_channels(z) # E: Argument 1 to "process_batch_channels" has incompatible type "Array[Batch]"; expected "Array[Batch, Unpack[Tuple[Any, ...]], Channels]"

u: Array[Unpack[Tuple[Any, ...]]]

def expect_variadic_array(
x: Array[Batch, Unpack[Shape]]
) -> None:
...

def expect_variadic_array_2(
x: Array[Batch, Height, Width, Channels]
) -> None:
...

expect_variadic_array(u)
expect_variadic_array_2(u)


[builtins fixtures/tuple.pyi]