Skip to content

Commit a9f3b5e

Browse files
authored
Use helper methods for a lot of ArgKind checks (#10793)
Part of the motivation here is that I want to make positional-only arguments be properly reflected in the argument kinds, and having most of the logic done through helpers will make that easier.
1 parent 3552971 commit a9f3b5e

17 files changed

+83
-75
lines changed

mypy/argmap.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
2929
for ai, actual_kind in enumerate(actual_kinds):
3030
if actual_kind == nodes.ARG_POS:
3131
if fi < nformals:
32-
if formal_kinds[fi] in [nodes.ARG_POS, nodes.ARG_OPT,
33-
nodes.ARG_NAMED, nodes.ARG_NAMED_OPT]:
32+
if not formal_kinds[fi].is_star():
3433
formal_to_actual[fi].append(ai)
3534
fi += 1
3635
elif formal_kinds[fi] == nodes.ARG_STAR:
@@ -52,14 +51,14 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
5251
# Assume that it is an iterable (if it isn't, there will be
5352
# an error later).
5453
while fi < nformals:
55-
if formal_kinds[fi] in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT, nodes.ARG_STAR2):
54+
if formal_kinds[fi].is_named(star=True):
5655
break
5756
else:
5857
formal_to_actual[fi].append(ai)
5958
if formal_kinds[fi] == nodes.ARG_STAR:
6059
break
6160
fi += 1
62-
elif actual_kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT):
61+
elif actual_kind.is_named():
6362
assert actual_names is not None, "Internal error: named kinds without names given"
6463
name = actual_names[ai]
6564
if name in formal_names:

mypy/checkexpr.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
3333
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode,
3434
ParamSpecExpr,
35-
ArgKind, ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
35+
ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
3636
)
3737
from mypy.literals import literal
3838
from mypy import nodes
@@ -1111,7 +1111,7 @@ def infer_arg_types_in_context(
11111111

11121112
for i, actuals in enumerate(formal_to_actual):
11131113
for ai in actuals:
1114-
if arg_kinds[ai] not in (nodes.ARG_STAR, nodes.ARG_STAR2):
1114+
if not arg_kinds[ai].is_star():
11151115
res[ai] = self.accept(args[ai], callee.arg_types[i])
11161116

11171117
# Fill in the rest of the argument types.
@@ -1364,18 +1364,14 @@ def check_argument_count(self,
13641364

13651365
# Check for too many or few values for formals.
13661366
for i, kind in enumerate(callee.arg_kinds):
1367-
if kind == nodes.ARG_POS and (not formal_to_actual[i] and
1368-
not is_unexpected_arg_error):
1369-
# No actual for a mandatory positional formal.
1367+
if kind.is_required() and not formal_to_actual[i] and not is_unexpected_arg_error:
1368+
# No actual for a mandatory formal
13701369
if messages:
1371-
messages.too_few_arguments(callee, context, actual_names)
1372-
ok = False
1373-
elif kind == nodes.ARG_NAMED and (not formal_to_actual[i] and
1374-
not is_unexpected_arg_error):
1375-
# No actual for a mandatory named formal
1376-
if messages:
1377-
argname = callee.arg_names[i] or "?"
1378-
messages.missing_named_argument(callee, context, argname)
1370+
if kind.is_positional():
1371+
messages.too_few_arguments(callee, context, actual_names)
1372+
else:
1373+
argname = callee.arg_names[i] or "?"
1374+
messages.missing_named_argument(callee, context, argname)
13791375
ok = False
13801376
elif not kind.is_star() and is_duplicate_mapping(
13811377
formal_to_actual[i], actual_types, actual_kinds):
@@ -1385,7 +1381,7 @@ def check_argument_count(self,
13851381
if messages:
13861382
messages.duplicate_argument_value(callee, i, context)
13871383
ok = False
1388-
elif (kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT) and formal_to_actual[i] and
1384+
elif (kind.is_named() and formal_to_actual[i] and
13891385
actual_kinds[formal_to_actual[i][0]] not in [nodes.ARG_NAMED, nodes.ARG_STAR2]):
13901386
# Positional argument when expecting a keyword argument.
13911387
if messages:
@@ -1925,7 +1921,7 @@ def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, C
19251921
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)):
19261922
if new_kind == target_kind:
19271923
continue
1928-
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT):
1924+
elif new_kind.is_positional() and target_kind.is_positional():
19291925
new_kinds[i] = ARG_POS
19301926
else:
19311927
too_complex = True

mypy/join.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype,
1515
is_protocol_implementation, find_member
1616
)
17-
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, INVARIANT, COVARIANT, CONTRAVARIANT
17+
from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT
1818
import mypy.typeops
1919
from mypy import state
2020
from mypy import meet
@@ -536,11 +536,10 @@ def combine_arg_names(t: CallableType, s: CallableType) -> List[Optional[str]]:
536536
"""
537537
num_args = len(t.arg_types)
538538
new_names = []
539-
named = (ARG_NAMED, ARG_NAMED_OPT)
540539
for i in range(num_args):
541540
t_name = t.arg_names[i]
542541
s_name = s.arg_names[i]
543-
if t_name == s_name or t.arg_kinds[i] in named or s.arg_kinds[i] in named:
542+
if t_name == s_name or t.arg_kinds[i].is_named() or s.arg_kinds[i].is_named():
544543
new_names.append(t_name)
545544
else:
546545
new_names.append(None)

mypy/messages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,12 +1711,12 @@ def format(typ: Type) -> str:
17111711
for arg_name, arg_type, arg_kind in zip(
17121712
func.arg_names, func.arg_types, func.arg_kinds):
17131713
if (arg_kind == ARG_POS and arg_name is None
1714-
or verbosity == 0 and arg_kind in (ARG_POS, ARG_OPT)):
1714+
or verbosity == 0 and arg_kind.is_positional()):
17151715

17161716
arg_strings.append(format(arg_type))
17171717
else:
17181718
constructor = ARG_CONSTRUCTOR_NAMES[arg_kind]
1719-
if arg_kind in (ARG_STAR, ARG_STAR2) or arg_name is None:
1719+
if arg_kind.is_star() or arg_name is None:
17201720
arg_strings.append("{}({})".format(
17211721
constructor,
17221722
format(arg_type)))
@@ -1849,7 +1849,7 @@ def [T <: int] f(self, x: int, y: T) -> None
18491849
for i in range(len(tp.arg_types)):
18501850
if s:
18511851
s += ', '
1852-
if tp.arg_kinds[i] in (ARG_NAMED, ARG_NAMED_OPT) and not asterisk:
1852+
if tp.arg_kinds[i].is_named() and not asterisk:
18531853
s += '*, '
18541854
asterisk = True
18551855
if tp.arg_kinds[i] == ARG_STAR:
@@ -1861,7 +1861,7 @@ def [T <: int] f(self, x: int, y: T) -> None
18611861
if name:
18621862
s += name + ': '
18631863
s += format_type_bare(tp.arg_types[i])
1864-
if tp.arg_kinds[i] in (ARG_OPT, ARG_NAMED_OPT):
1864+
if tp.arg_kinds[i].is_optional():
18651865
s += ' = ...'
18661866

18671867
# If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list

mypy/nodes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,26 @@ class ArgKind(Enum):
15341534
# In an argument list, keyword-only and also optional
15351535
ARG_NAMED_OPT = 5
15361536

1537+
def is_positional(self, star: bool = False) -> bool:
1538+
return (
1539+
self == ARG_POS
1540+
or self == ARG_OPT
1541+
or (star and self == ARG_STAR)
1542+
)
1543+
1544+
def is_named(self, star: bool = False) -> bool:
1545+
return (
1546+
self == ARG_NAMED
1547+
or self == ARG_NAMED_OPT
1548+
or (star and self == ARG_STAR2)
1549+
)
1550+
1551+
def is_required(self) -> bool:
1552+
return self == ARG_POS or self == ARG_NAMED
1553+
1554+
def is_optional(self) -> bool:
1555+
return self == ARG_OPT or self == ARG_NAMED_OPT
1556+
15371557
def is_star(self) -> bool:
15381558
return self == ARG_STAR or self == ARG_STAR2
15391559

mypy/plugins/functools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Dict, NamedTuple, Optional
33

44
import mypy.plugin
5-
from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR2, Argument, FuncItem, Var
5+
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
66
from mypy.plugins.common import add_method_to_class
77
from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType
88

@@ -65,7 +65,7 @@ def _find_other_type(method: _MethodInfo) -> Type:
6565
cur_pos_arg = 0
6666
other_arg = None
6767
for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
68-
if arg_kind in (ARG_POS, ARG_OPT):
68+
if arg_kind.is_positional():
6969
if cur_pos_arg == first_arg_pos:
7070
other_arg = arg_type
7171
break

mypy/plugins/singledispatch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from mypy.messages import format_type
22
from mypy.plugins.common import add_method_to_class
33
from mypy.nodes import (
4-
ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, ARG_STAR, ARG_OPT, Context
4+
ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context
55
)
66
from mypy.subtypes import is_subtype
77
from mypy.types import (
@@ -98,7 +98,7 @@ def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
9898
)
9999
return ctx.default_return_type
100100

101-
elif func_type.arg_kinds[0] not in (ARG_POS, ARG_OPT, ARG_STAR):
101+
elif not func_type.arg_kinds[0].is_positional(star=True):
102102
fail(
103103
ctx,
104104
'First argument to singledispatch function must be a positional argument',

mypy/strconv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ def func_helper(self, o: 'mypy.nodes.FuncItem') -> List[object]:
6262
extra: List[Tuple[str, List[mypy.nodes.Var]]] = []
6363
for arg in o.arguments:
6464
kind: mypy.nodes.ArgKind = arg.kind
65-
if kind in (mypy.nodes.ARG_POS, mypy.nodes.ARG_NAMED):
65+
if kind.is_required():
6666
args.append(arg.variable)
67-
elif kind in (mypy.nodes.ARG_OPT, mypy.nodes.ARG_NAMED_OPT):
67+
elif kind.is_optional():
6868
assert arg.initializer is not None
6969
args.append(('default', [arg.variable, arg.initializer]))
7070
elif kind == mypy.nodes.ARG_STAR:

mypy/stubgen.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr,
7474
ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo,
7575
IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, Block,
76-
Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT
76+
Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED,
7777
)
7878
from mypy.stubgenc import generate_stub_for_c_module
7979
from mypy.stubutil import (
@@ -631,8 +631,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
631631
if not isinstance(get_proper_type(annotated_type), AnyType):
632632
annotation = ": {}".format(self.print_annotation(annotated_type))
633633
if arg_.initializer:
634-
if kind in (ARG_NAMED, ARG_NAMED_OPT) and not any(arg.startswith('*')
635-
for arg in args):
634+
if kind.is_named() and not any(arg.startswith('*') for arg in args):
636635
args.append('*')
637636
if not annotation:
638637
typename = self.get_str_type_of_node(arg_.initializer, True, False)

mypy/stubtest.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def _verify_arg_default_value(
334334
) -> Iterator[str]:
335335
"""Checks whether argument default values are compatible."""
336336
if runtime_arg.default != inspect.Parameter.empty:
337-
if stub_arg.kind not in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT):
337+
if stub_arg.kind.is_required():
338338
yield (
339339
'runtime argument "{}" has a default value but stub argument does not'.format(
340340
runtime_arg.name
@@ -363,7 +363,7 @@ def _verify_arg_default_value(
363363
)
364364
)
365365
else:
366-
if stub_arg.kind in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT):
366+
if stub_arg.kind.is_optional():
367367
yield (
368368
'stub argument "{}" has a default value but runtime argument does not'.format(
369369
stub_arg.variable.name
@@ -406,7 +406,7 @@ def has_default(arg: Any) -> bool:
406406
if isinstance(arg, inspect.Parameter):
407407
return arg.default != inspect.Parameter.empty
408408
if isinstance(arg, nodes.Argument):
409-
return arg.kind in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT)
409+
return arg.kind.is_optional()
410410
raise AssertionError
411411

412412
def get_desc(arg: Any) -> str:
@@ -433,9 +433,9 @@ def from_funcitem(stub: nodes.FuncItem) -> "Signature[nodes.Argument]":
433433
stub_sig: Signature[nodes.Argument] = Signature()
434434
stub_args = maybe_strip_cls(stub.name, stub.arguments)
435435
for stub_arg in stub_args:
436-
if stub_arg.kind in (nodes.ARG_POS, nodes.ARG_OPT):
436+
if stub_arg.kind.is_positional():
437437
stub_sig.pos.append(stub_arg)
438-
elif stub_arg.kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT):
438+
elif stub_arg.kind.is_named():
439439
stub_sig.kwonly[stub_arg.variable.name] = stub_arg
440440
elif stub_arg.kind == nodes.ARG_STAR:
441441
stub_sig.varpos = stub_arg
@@ -531,9 +531,9 @@ def get_kind(arg_name: str) -> nodes.ArgKind:
531531
initializer=None,
532532
kind=get_kind(arg_name),
533533
)
534-
if arg.kind in (nodes.ARG_POS, nodes.ARG_OPT):
534+
if arg.kind.is_positional():
535535
sig.pos.append(arg)
536-
elif arg.kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT):
536+
elif arg.kind.is_named():
537537
sig.kwonly[arg.variable.name] = arg
538538
elif arg.kind == nodes.ARG_STAR:
539539
sig.varpos = arg

mypy/subtypes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# import mypy.solve
1919
from mypy.nodes import (
2020
FuncBase, Var, Decorator, OverloadedFuncDef, TypeInfo, CONTRAVARIANT, COVARIANT,
21-
ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2
21+
2222
)
2323
from mypy.maptype import map_instance_to_supertype
2424
from mypy.expandtype import expand_type_by_instance
@@ -950,8 +950,8 @@ def _incompatible(left_arg: Optional[FormalArgument],
950950

951951
i = right_star.pos
952952
assert i is not None
953-
while i < len(left.arg_kinds) and left.arg_kinds[i] in (ARG_POS, ARG_OPT):
954-
if allow_partial_overlap and left.arg_kinds[i] == ARG_OPT:
953+
while i < len(left.arg_kinds) and left.arg_kinds[i].is_positional():
954+
if allow_partial_overlap and left.arg_kinds[i].is_optional():
955955
break
956956

957957
left_by_position = left.argument_by_position(i)
@@ -970,7 +970,7 @@ def _incompatible(left_arg: Optional[FormalArgument],
970970
right_names = {name for name in right.arg_names if name is not None}
971971
left_only_names = set()
972972
for name, kind in zip(left.arg_names, left.arg_kinds):
973-
if name is None or kind in (ARG_STAR, ARG_STAR2) or name in right_names:
973+
if name is None or kind.is_star() or name in right_names:
974974
continue
975975
left_only_names.add(name)
976976

mypy/suggestions.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
)
3838
from mypy.build import State, Graph
3939
from mypy.nodes import (
40-
ArgKind, ARG_STAR, ARG_NAMED, ARG_STAR2, ARG_NAMED_OPT, FuncDef, MypyFile, SymbolTable,
40+
ArgKind, ARG_STAR, ARG_STAR2, FuncDef, MypyFile, SymbolTable,
4141
Decorator, RefExpr,
4242
SymbolNode, TypeInfo, Expression, ReturnStmt, CallExpr,
4343
reverse_builtin_aliases,
@@ -479,7 +479,7 @@ def format_args(self,
479479
arg = '*' + arg
480480
elif kind == ARG_STAR2:
481481
arg = '**' + arg
482-
elif kind in (ARG_NAMED, ARG_NAMED_OPT):
482+
elif kind.is_named():
483483
if name:
484484
arg = "%s=%s" % (name, arg)
485485
args.append(arg)
@@ -763,8 +763,7 @@ def any_score_callable(t: CallableType, is_method: bool, ignore_return: bool) ->
763763

764764
def is_tricky_callable(t: CallableType) -> bool:
765765
"""Is t a callable that we need to put a ... in for syntax reasons?"""
766-
return t.is_ellipsis_args or any(
767-
k in (ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT) for k in t.arg_kinds)
766+
return t.is_ellipsis_args or any(k.is_star() or k.is_named() for k in t.arg_kinds)
768767

769768

770769
class TypeFormatter(TypeStrVisitor):

mypy/typeanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type],
778778
assert found.fullname is not None
779779
kind = ARG_KINDS_BY_CONSTRUCTOR[found.fullname]
780780
kinds.append(kind)
781-
if arg.name is not None and kind in {ARG_STAR, ARG_STAR2}:
781+
if arg.name is not None and kind.is_star():
782782
self.fail("{} arguments should not have names".format(
783783
arg.constructor), arg)
784784
return None

0 commit comments

Comments
 (0)