Skip to content

Commit 885e361

Browse files
authored
Handle prefix/suffix in typevartuple *args support (#14112)
This requires handling more cases in the various places that we previously modified to support *args in general. We also need to refresh the formals-to-actuals twice in checkexpr as now it can happen in the infer_function_type_arguments_using_context call. The handling here is kind of asymmetric, because we can convert prefices into positional arguments, but there is no equivalent for suffices, so we represent that as a Tuple[Unpack[...], <suffix>] and handle that case separately in some spots. We also support various edge cases like passing in a tuple without any typevartuples involved.
1 parent 48c4a47 commit 885e361

File tree

7 files changed

+284
-62
lines changed

7 files changed

+284
-62
lines changed

mypy/applytype.py

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
import mypy.subtypes
66
from mypy.expandtype import expand_type, expand_unpack_with_variables
7-
from mypy.nodes import ARG_POS, ARG_STAR, Context
7+
from mypy.nodes import ARG_STAR, Context
88
from mypy.types import (
99
AnyType,
1010
CallableType,
1111
Parameters,
1212
ParamSpecType,
1313
PartialType,
14+
TupleType,
1415
Type,
1516
TypeVarId,
1617
TypeVarLikeType,
@@ -19,6 +20,7 @@
1920
UnpackType,
2021
get_proper_type,
2122
)
23+
from mypy.typevartuples import find_unpack_in_list, replace_starargs
2224

2325

2426
def get_target_type(
@@ -114,39 +116,57 @@ def apply_generic_arguments(
114116
# Apply arguments to argument types.
115117
var_arg = callable.var_arg()
116118
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
117-
expanded = expand_unpack_with_variables(var_arg.typ, id_to_type)
118-
assert isinstance(expanded, list)
119-
# Handle other cases later.
120-
for t in expanded:
121-
assert not isinstance(t, UnpackType)
122119
star_index = callable.arg_kinds.index(ARG_STAR)
123-
arg_kinds = (
124-
callable.arg_kinds[:star_index]
125-
+ [ARG_POS] * len(expanded)
126-
+ callable.arg_kinds[star_index + 1 :]
120+
callable = callable.copy_modified(
121+
arg_types=(
122+
[
123+
expand_type(at, id_to_type, allow_erased_callables)
124+
for at in callable.arg_types[:star_index]
125+
]
126+
+ [callable.arg_types[star_index]]
127+
+ [
128+
expand_type(at, id_to_type, allow_erased_callables)
129+
for at in callable.arg_types[star_index + 1 :]
130+
]
131+
)
127132
)
128-
arg_names = (
129-
callable.arg_names[:star_index]
130-
+ [None] * len(expanded)
131-
+ callable.arg_names[star_index + 1 :]
132-
)
133-
arg_types = (
134-
[
135-
expand_type(at, id_to_type, allow_erased_callables)
136-
for at in callable.arg_types[:star_index]
137-
]
138-
+ expanded
139-
+ [
140-
expand_type(at, id_to_type, allow_erased_callables)
141-
for at in callable.arg_types[star_index + 1 :]
133+
134+
unpacked_type = get_proper_type(var_arg.typ.type)
135+
if isinstance(unpacked_type, TupleType):
136+
# Assuming for now that because we convert prefixes to positional arguments,
137+
# the first argument is always an unpack.
138+
expanded_tuple = expand_type(unpacked_type, id_to_type)
139+
if isinstance(expanded_tuple, TupleType):
140+
# TODO: handle the case where the tuple has an unpack. This will
141+
# hit an assert below.
142+
expanded_unpack = find_unpack_in_list(expanded_tuple.items)
143+
if expanded_unpack is not None:
144+
callable = callable.copy_modified(
145+
arg_types=(
146+
callable.arg_types[:star_index]
147+
+ [expanded_tuple]
148+
+ callable.arg_types[star_index + 1 :]
149+
)
150+
)
151+
else:
152+
callable = replace_starargs(callable, expanded_tuple.items)
153+
else:
154+
# TODO: handle the case for if we get a variable length tuple.
155+
assert False, f"mypy bug: unimplemented case, {expanded_tuple}"
156+
elif isinstance(unpacked_type, TypeVarTupleType):
157+
expanded_tvt = expand_unpack_with_variables(var_arg.typ, id_to_type)
158+
assert isinstance(expanded_tvt, list)
159+
for t in expanded_tvt:
160+
assert not isinstance(t, UnpackType)
161+
callable = replace_starargs(callable, expanded_tvt)
162+
else:
163+
assert False, "mypy bug: unhandled case applying unpack"
164+
else:
165+
callable = callable.copy_modified(
166+
arg_types=[
167+
expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types
142168
]
143169
)
144-
else:
145-
arg_types = [
146-
expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types
147-
]
148-
arg_kinds = callable.arg_kinds
149-
arg_names = callable.arg_names
150170

151171
# Apply arguments to TypeGuard if any.
152172
if callable.type_guard is not None:
@@ -158,10 +178,7 @@ def apply_generic_arguments(
158178
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
159179

160180
return callable.copy_modified(
161-
arg_types=arg_types,
162181
ret_type=expand_type(callable.ret_type, id_to_type, allow_erased_callables),
163182
variables=remaining_tvars,
164183
type_guard=type_guard,
165-
arg_kinds=arg_kinds,
166-
arg_names=arg_names,
167184
)

mypy/checker.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,12 +1178,17 @@ def check_func_def(
11781178
if isinstance(arg_type, ParamSpecType):
11791179
pass
11801180
elif isinstance(arg_type, UnpackType):
1181-
arg_type = TupleType(
1182-
[arg_type],
1183-
fallback=self.named_generic_type(
1184-
"builtins.tuple", [self.named_type("builtins.object")]
1185-
),
1186-
)
1181+
if isinstance(get_proper_type(arg_type.type), TupleType):
1182+
# Instead of using Tuple[Unpack[Tuple[...]]], just use
1183+
# Tuple[...]
1184+
arg_type = arg_type.type
1185+
else:
1186+
arg_type = TupleType(
1187+
[arg_type],
1188+
fallback=self.named_generic_type(
1189+
"builtins.tuple", [self.named_type("builtins.object")]
1190+
),
1191+
)
11871192
else:
11881193
# builtins.tuple[T] is typing.Tuple[T, ...]
11891194
arg_type = self.named_generic_type("builtins.tuple", [arg_type])

mypy/checkexpr.py

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@
150150
TypeVarType,
151151
UninhabitedType,
152152
UnionType,
153+
UnpackType,
153154
flatten_nested_unions,
154155
get_proper_type,
155156
get_proper_types,
@@ -1404,13 +1405,21 @@ def check_callable_call(
14041405
)
14051406
callee = freshen_function_type_vars(callee)
14061407
callee = self.infer_function_type_arguments_using_context(callee, context)
1408+
if need_refresh:
1409+
# Argument kinds etc. may have changed due to
1410+
# ParamSpec or TypeVarTuple variables being replaced with an arbitrary
1411+
# number of arguments; recalculate actual-to-formal map
1412+
formal_to_actual = map_actuals_to_formals(
1413+
arg_kinds,
1414+
arg_names,
1415+
callee.arg_kinds,
1416+
callee.arg_names,
1417+
lambda i: self.accept(args[i]),
1418+
)
14071419
callee = self.infer_function_type_arguments(
14081420
callee, args, arg_kinds, formal_to_actual, context
14091421
)
14101422
if need_refresh:
1411-
# Argument kinds etc. may have changed due to
1412-
# ParamSpec variables being replaced with an arbitrary
1413-
# number of arguments; recalculate actual-to-formal map
14141423
formal_to_actual = map_actuals_to_formals(
14151424
arg_kinds,
14161425
arg_names,
@@ -1999,11 +2008,66 @@ def check_argument_types(
19992008
# Keep track of consumed tuple *arg items.
20002009
mapper = ArgTypeExpander(self.argument_infer_context())
20012010
for i, actuals in enumerate(formal_to_actual):
2002-
for actual in actuals:
2003-
actual_type = arg_types[actual]
2011+
orig_callee_arg_type = get_proper_type(callee.arg_types[i])
2012+
2013+
# Checking the case that we have more than one item but the first argument
2014+
# is an unpack, so this would be something like:
2015+
# [Tuple[Unpack[Ts]], int]
2016+
#
2017+
# In this case we have to check everything together, we do this by re-unifying
2018+
# the suffices to the tuple, e.g. a single actual like
2019+
# Tuple[Unpack[Ts], int]
2020+
expanded_tuple = False
2021+
if len(actuals) > 1:
2022+
first_actual_arg_type = get_proper_type(arg_types[actuals[0]])
2023+
if (
2024+
isinstance(first_actual_arg_type, TupleType)
2025+
and len(first_actual_arg_type.items) == 1
2026+
and isinstance(get_proper_type(first_actual_arg_type.items[0]), UnpackType)
2027+
):
2028+
# TODO: use walrus operator
2029+
actual_types = [first_actual_arg_type.items[0]] + [
2030+
arg_types[a] for a in actuals[1:]
2031+
]
2032+
actual_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (len(actuals) - 1)
2033+
2034+
assert isinstance(orig_callee_arg_type, TupleType)
2035+
assert orig_callee_arg_type.items
2036+
callee_arg_types = orig_callee_arg_type.items
2037+
callee_arg_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (
2038+
len(orig_callee_arg_type.items) - 1
2039+
)
2040+
expanded_tuple = True
2041+
2042+
if not expanded_tuple:
2043+
actual_types = [arg_types[a] for a in actuals]
2044+
actual_kinds = [arg_kinds[a] for a in actuals]
2045+
if isinstance(orig_callee_arg_type, UnpackType):
2046+
unpacked_type = get_proper_type(orig_callee_arg_type.type)
2047+
# Only case we know of thus far.
2048+
assert isinstance(unpacked_type, TupleType)
2049+
actual_types = [arg_types[a] for a in actuals]
2050+
actual_kinds = [arg_kinds[a] for a in actuals]
2051+
callee_arg_types = unpacked_type.items
2052+
callee_arg_kinds = [ARG_POS] * len(actuals)
2053+
else:
2054+
callee_arg_types = [orig_callee_arg_type] * len(actuals)
2055+
callee_arg_kinds = [callee.arg_kinds[i]] * len(actuals)
2056+
2057+
assert len(actual_types) == len(actuals) == len(actual_kinds)
2058+
2059+
if len(callee_arg_types) != len(actual_types):
2060+
# TODO: Improve error message
2061+
self.chk.fail("Invalid number of arguments", context)
2062+
continue
2063+
2064+
assert len(callee_arg_types) == len(actual_types)
2065+
assert len(callee_arg_types) == len(callee_arg_kinds)
2066+
for actual, actual_type, actual_kind, callee_arg_type, callee_arg_kind in zip(
2067+
actuals, actual_types, actual_kinds, callee_arg_types, callee_arg_kinds
2068+
):
20042069
if actual_type is None:
20052070
continue # Some kind of error was already reported.
2006-
actual_kind = arg_kinds[actual]
20072071
# Check that a *arg is valid as varargs.
20082072
if actual_kind == nodes.ARG_STAR and not self.is_valid_var_arg(actual_type):
20092073
self.msg.invalid_var_arg(actual_type, context)
@@ -2013,13 +2077,13 @@ def check_argument_types(
20132077
is_mapping = is_subtype(actual_type, self.chk.named_type("typing.Mapping"))
20142078
self.msg.invalid_keyword_var_arg(actual_type, is_mapping, context)
20152079
expanded_actual = mapper.expand_actual_type(
2016-
actual_type, actual_kind, callee.arg_names[i], callee.arg_kinds[i]
2080+
actual_type, actual_kind, callee.arg_names[i], callee_arg_kind
20172081
)
20182082
check_arg(
20192083
expanded_actual,
20202084
actual_type,
2021-
arg_kinds[actual],
2022-
callee.arg_types[i],
2085+
actual_kind,
2086+
callee_arg_type,
20232087
actual + 1,
20242088
i + 1,
20252089
callee,
@@ -4719,6 +4783,7 @@ def is_valid_var_arg(self, typ: Type) -> bool:
47194783
)
47204784
or isinstance(typ, AnyType)
47214785
or isinstance(typ, ParamSpecType)
4786+
or isinstance(typ, UnpackType)
47224787
)
47234788

47244789
def is_valid_keyword_var_arg(self, typ: Type) -> bool:

mypy/constraints.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,26 @@ def infer_constraints_for_callable(
133133
)
134134
)
135135

136-
assert isinstance(unpack_type.type, TypeVarTupleType)
137-
constraints.append(Constraint(unpack_type.type, SUPERTYPE_OF, TypeList(actual_types)))
136+
unpacked_type = get_proper_type(unpack_type.type)
137+
if isinstance(unpacked_type, TypeVarTupleType):
138+
constraints.append(Constraint(unpacked_type, SUPERTYPE_OF, TypeList(actual_types)))
139+
elif isinstance(unpacked_type, TupleType):
140+
# Prefixes get converted to positional args, so technically the only case we
141+
# should have here is like Tuple[Unpack[Ts], Y1, Y2, Y3]. If this turns out
142+
# not to hold we can always handle the prefixes too.
143+
inner_unpack = unpacked_type.items[0]
144+
assert isinstance(inner_unpack, UnpackType)
145+
inner_unpacked_type = get_proper_type(inner_unpack.type)
146+
assert isinstance(inner_unpacked_type, TypeVarTupleType)
147+
suffix_len = len(unpacked_type.items) - 1
148+
constraints.append(
149+
Constraint(
150+
inner_unpacked_type, SUPERTYPE_OF, TypeList(actual_types[:-suffix_len])
151+
)
152+
)
153+
else:
154+
assert False, "mypy bug: unhandled constraint inference case"
155+
138156
else:
139157
for actual in actuals:
140158
actual_arg_type = arg_types[actual]

0 commit comments

Comments
 (0)