Skip to content

Commit f91fb1a

Browse files
authored
[mypyc] Support various number-related dunders (#10679)
This adds support for these unary dunders: * `__neg__` * `__invert__` * `__int__` * `__float__` Also add support for binary, reversible dunders, such as `__add__` and `__radd__`. Finally, add support for in-place operator dunders such as `__iadd__`. The semantics of the binary dunders don't always match Python semantics, but many common use cases should work. There is one significant difference from Python that is not easy to remove: if a forward dunder method is called with an incompatible argument, it's treated the same as if it returned `NotImplemented`. This is necessary since the body of the method is never reached on incompatible argument type and there is no way to explicitly return `NotImplemented`. However, it's still recommended that the body returns `NotImplemented` as expected for Python compatibility. If a dunder returns `NotImplemented` and has a type annotation, the return type should be annotated as `Union[T, Any]`, where `T` is the return value when `NotImplemented` is not returned. Work on mypyc/mypyc#839.
1 parent 002722e commit f91fb1a

File tree

19 files changed

+1123
-198
lines changed

19 files changed

+1123
-198
lines changed

mypy/checker.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
is_final_node,
2828
ARG_NAMED)
2929
from mypy import nodes
30+
from mypy import operators
3031
from mypy.literals import literal, literal_hash, Key
3132
from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any
3233
from mypy.types import (
@@ -1026,13 +1027,13 @@ def is_forward_op_method(self, method_name: str) -> bool:
10261027
if self.options.python_version[0] == 2 and method_name == '__div__':
10271028
return True
10281029
else:
1029-
return method_name in nodes.reverse_op_methods
1030+
return method_name in operators.reverse_op_methods
10301031

10311032
def is_reverse_op_method(self, method_name: str) -> bool:
10321033
if self.options.python_version[0] == 2 and method_name == '__rdiv__':
10331034
return True
10341035
else:
1035-
return method_name in nodes.reverse_op_method_set
1036+
return method_name in operators.reverse_op_method_set
10361037

10371038
def check_for_missing_annotations(self, fdef: FuncItem) -> None:
10381039
# Check for functions with unspecified/not fully specified types.
@@ -1188,7 +1189,7 @@ def check_reverse_op_method(self, defn: FuncItem,
11881189
if self.options.python_version[0] == 2 and reverse_name == '__rdiv__':
11891190
forward_name = '__div__'
11901191
else:
1191-
forward_name = nodes.normal_from_reverse_op[reverse_name]
1192+
forward_name = operators.normal_from_reverse_op[reverse_name]
11921193
forward_inst = get_proper_type(reverse_type.arg_types[1])
11931194
if isinstance(forward_inst, TypeVarType):
11941195
forward_inst = get_proper_type(forward_inst.upper_bound)
@@ -1327,7 +1328,7 @@ def check_inplace_operator_method(self, defn: FuncBase) -> None:
13271328
They cannot arbitrarily overlap with __add__.
13281329
"""
13291330
method = defn.name
1330-
if method not in nodes.inplace_operator_methods:
1331+
if method not in operators.inplace_operator_methods:
13311332
return
13321333
typ = bind_self(self.function_type(defn))
13331334
cls = defn.info
@@ -1447,7 +1448,7 @@ def check_method_or_accessor_override_for_base(self, defn: Union[FuncDef,
14471448
# (__init__, __new__, __init_subclass__ are special).
14481449
if self.check_method_override_for_base_with_name(defn, name, base):
14491450
return True
1450-
if name in nodes.inplace_operator_methods:
1451+
if name in operators.inplace_operator_methods:
14511452
# Figure out the name of the corresponding operator method.
14521453
method = '__' + name[3:]
14531454
# An inplace operator method such as __iadd__ might not be
@@ -5529,9 +5530,9 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, st
55295530
depending on which method is supported by the type.
55305531
"""
55315532
typ = get_proper_type(typ)
5532-
method = nodes.op_methods[operator]
5533+
method = operators.op_methods[operator]
55335534
if isinstance(typ, Instance):
5534-
if operator in nodes.ops_with_inplace_method:
5535+
if operator in operators.ops_with_inplace_method:
55355536
inplace_method = '__i' + method[2:]
55365537
if typ.type.has_readable_member(inplace_method):
55375538
return True, inplace_method

mypy/checkexpr.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
)
3737
from mypy.literals import literal
3838
from mypy import nodes
39+
from mypy import operators
3940
import mypy.checker
4041
from mypy import types
4142
from mypy.sametypes import is_same_type
@@ -2169,7 +2170,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
21692170
if right_radd_method is None:
21702171
return self.concat_tuples(proper_left_type, proper_right_type)
21712172

2172-
if e.op in nodes.op_methods:
2173+
if e.op in operators.op_methods:
21732174
method = self.get_operator_method(e.op)
21742175
result, method_type = self.check_op(method, left_type, e.right, e,
21752176
allow_reverse=True)
@@ -2234,7 +2235,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
22342235
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
22352236
else:
22362237
self.msg.add_errors(local_errors)
2237-
elif operator in nodes.op_methods:
2238+
elif operator in operators.op_methods:
22382239
method = self.get_operator_method(operator)
22392240
err_count = self.msg.errors.total_errors()
22402241
sub_result, method_type = self.check_op(method, left_type, right, e,
@@ -2362,7 +2363,7 @@ def get_operator_method(self, op: str) -> str:
23622363
# TODO also check for "from __future__ import division"
23632364
return '__div__'
23642365
else:
2365-
return nodes.op_methods[op]
2366+
return operators.op_methods[op]
23662367

23672368
def check_method_call_by_name(self,
23682369
method: str,
@@ -2537,7 +2538,7 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
25372538
# which records tuples containing the method, base type, and the argument.
25382539

25392540
bias_right = is_proper_subtype(right_type, left_type)
2540-
if op_name in nodes.op_methods_that_shortcut and is_same_type(left_type, right_type):
2541+
if op_name in operators.op_methods_that_shortcut and is_same_type(left_type, right_type):
25412542
# When we do "A() + A()", for example, Python will only call the __add__ method,
25422543
# never the __radd__ method.
25432544
#
@@ -2575,8 +2576,8 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
25752576
# When running Python 2, we might also try calling the __cmp__ method.
25762577

25772578
is_python_2 = self.chk.options.python_version[0] == 2
2578-
if is_python_2 and op_name in nodes.ops_falling_back_to_cmp:
2579-
cmp_method = nodes.comparison_fallback_method
2579+
if is_python_2 and op_name in operators.ops_falling_back_to_cmp:
2580+
cmp_method = operators.comparison_fallback_method
25802581
left_cmp_op = lookup_operator(cmp_method, left_type)
25812582
right_cmp_op = lookup_operator(cmp_method, right_type)
25822583

@@ -2760,7 +2761,7 @@ def get_reverse_op_method(self, method: str) -> str:
27602761
if method == '__div__' and self.chk.options.python_version[0] == 2:
27612762
return '__rdiv__'
27622763
else:
2763-
return nodes.reverse_op_methods[method]
2764+
return operators.reverse_op_methods[method]
27642765

27652766
def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
27662767
"""Type check a boolean operation ('and' or 'or')."""
@@ -2867,7 +2868,7 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type:
28672868
if op == 'not':
28682869
result = self.bool_type() # type: Type
28692870
else:
2870-
method = nodes.unary_op_methods[op]
2871+
method = operators.unary_op_methods[op]
28712872
result, method_type = self.check_method_call_by_name(method, operand_type, [], [], e)
28722873
e.method_type = method_type
28732874
return result
@@ -4533,9 +4534,9 @@ def is_operator_method(fullname: Optional[str]) -> bool:
45334534
return False
45344535
short_name = fullname.split('.')[-1]
45354536
return (
4536-
short_name in nodes.op_methods.values() or
4537-
short_name in nodes.reverse_op_methods.values() or
4538-
short_name in nodes.unary_op_methods.values())
4537+
short_name in operators.op_methods.values() or
4538+
short_name in operators.reverse_op_methods.values() or
4539+
short_name in operators.unary_op_methods.values())
45394540

45404541

45414542
def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]:

mypy/messages.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
)
2929
from mypy.typetraverser import TypeTraverserVisitor
3030
from mypy.nodes import (
31-
TypeInfo, Context, MypyFile, op_methods, op_methods_to_symbols,
32-
FuncDef, reverse_builtin_aliases,
31+
TypeInfo, Context, MypyFile, FuncDef, reverse_builtin_aliases,
3332
ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
3433
ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode,
3534
CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode
3635
)
36+
from mypy.operators import op_methods, op_methods_to_symbols
3737
from mypy.subtypes import (
3838
is_subtype, find_member, get_member_flags,
3939
IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC,

mypy/nodes.py

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,100 +1634,6 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
16341634
return visitor.visit_assignment_expr(self)
16351635

16361636

1637-
# Map from binary operator id to related method name (in Python 3).
1638-
op_methods = {
1639-
'+': '__add__',
1640-
'-': '__sub__',
1641-
'*': '__mul__',
1642-
'/': '__truediv__',
1643-
'%': '__mod__',
1644-
'divmod': '__divmod__',
1645-
'//': '__floordiv__',
1646-
'**': '__pow__',
1647-
'@': '__matmul__',
1648-
'&': '__and__',
1649-
'|': '__or__',
1650-
'^': '__xor__',
1651-
'<<': '__lshift__',
1652-
'>>': '__rshift__',
1653-
'==': '__eq__',
1654-
'!=': '__ne__',
1655-
'<': '__lt__',
1656-
'>=': '__ge__',
1657-
'>': '__gt__',
1658-
'<=': '__le__',
1659-
'in': '__contains__',
1660-
} # type: Final
1661-
1662-
op_methods_to_symbols = {v: k for (k, v) in op_methods.items()} # type: Final
1663-
op_methods_to_symbols['__div__'] = '/'
1664-
1665-
comparison_fallback_method = '__cmp__' # type: Final
1666-
ops_falling_back_to_cmp = {'__ne__', '__eq__',
1667-
'__lt__', '__le__',
1668-
'__gt__', '__ge__'} # type: Final
1669-
1670-
1671-
ops_with_inplace_method = {
1672-
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'} # type: Final
1673-
1674-
inplace_operator_methods = set(
1675-
'__i' + op_methods[op][2:] for op in ops_with_inplace_method) # type: Final
1676-
1677-
reverse_op_methods = {
1678-
'__add__': '__radd__',
1679-
'__sub__': '__rsub__',
1680-
'__mul__': '__rmul__',
1681-
'__truediv__': '__rtruediv__',
1682-
'__mod__': '__rmod__',
1683-
'__divmod__': '__rdivmod__',
1684-
'__floordiv__': '__rfloordiv__',
1685-
'__pow__': '__rpow__',
1686-
'__matmul__': '__rmatmul__',
1687-
'__and__': '__rand__',
1688-
'__or__': '__ror__',
1689-
'__xor__': '__rxor__',
1690-
'__lshift__': '__rlshift__',
1691-
'__rshift__': '__rrshift__',
1692-
'__eq__': '__eq__',
1693-
'__ne__': '__ne__',
1694-
'__lt__': '__gt__',
1695-
'__ge__': '__le__',
1696-
'__gt__': '__lt__',
1697-
'__le__': '__ge__',
1698-
} # type: Final
1699-
1700-
# Suppose we have some class A. When we do A() + A(), Python will only check
1701-
# the output of A().__add__(A()) and skip calling the __radd__ method entirely.
1702-
# This shortcut is used only for the following methods:
1703-
op_methods_that_shortcut = {
1704-
'__add__',
1705-
'__sub__',
1706-
'__mul__',
1707-
'__div__',
1708-
'__truediv__',
1709-
'__mod__',
1710-
'__divmod__',
1711-
'__floordiv__',
1712-
'__pow__',
1713-
'__matmul__',
1714-
'__and__',
1715-
'__or__',
1716-
'__xor__',
1717-
'__lshift__',
1718-
'__rshift__',
1719-
} # type: Final
1720-
1721-
normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) # type: Final
1722-
reverse_op_method_set = set(reverse_op_methods.values()) # type: Final
1723-
1724-
unary_op_methods = {
1725-
'-': '__neg__',
1726-
'+': '__pos__',
1727-
'~': '__invert__',
1728-
} # type: Final
1729-
1730-
17311637
class OpExpr(Expression):
17321638
"""Binary operation (other than . or [] or comparison operators,
17331639
which have specific nodes)."""

mypy/operators.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Information about Python operators"""
2+
3+
from typing_extensions import Final
4+
5+
6+
# Map from binary operator id to related method name (in Python 3).
7+
op_methods = {
8+
'+': '__add__',
9+
'-': '__sub__',
10+
'*': '__mul__',
11+
'/': '__truediv__',
12+
'%': '__mod__',
13+
'divmod': '__divmod__',
14+
'//': '__floordiv__',
15+
'**': '__pow__',
16+
'@': '__matmul__',
17+
'&': '__and__',
18+
'|': '__or__',
19+
'^': '__xor__',
20+
'<<': '__lshift__',
21+
'>>': '__rshift__',
22+
'==': '__eq__',
23+
'!=': '__ne__',
24+
'<': '__lt__',
25+
'>=': '__ge__',
26+
'>': '__gt__',
27+
'<=': '__le__',
28+
'in': '__contains__',
29+
} # type: Final
30+
31+
op_methods_to_symbols = {v: k for (k, v) in op_methods.items()} # type: Final
32+
op_methods_to_symbols['__div__'] = '/'
33+
34+
comparison_fallback_method = '__cmp__' # type: Final
35+
ops_falling_back_to_cmp = {'__ne__', '__eq__',
36+
'__lt__', '__le__',
37+
'__gt__', '__ge__'} # type: Final
38+
39+
40+
ops_with_inplace_method = {
41+
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'} # type: Final
42+
43+
inplace_operator_methods = set(
44+
'__i' + op_methods[op][2:] for op in ops_with_inplace_method) # type: Final
45+
46+
reverse_op_methods = {
47+
'__add__': '__radd__',
48+
'__sub__': '__rsub__',
49+
'__mul__': '__rmul__',
50+
'__truediv__': '__rtruediv__',
51+
'__mod__': '__rmod__',
52+
'__divmod__': '__rdivmod__',
53+
'__floordiv__': '__rfloordiv__',
54+
'__pow__': '__rpow__',
55+
'__matmul__': '__rmatmul__',
56+
'__and__': '__rand__',
57+
'__or__': '__ror__',
58+
'__xor__': '__rxor__',
59+
'__lshift__': '__rlshift__',
60+
'__rshift__': '__rrshift__',
61+
'__eq__': '__eq__',
62+
'__ne__': '__ne__',
63+
'__lt__': '__gt__',
64+
'__ge__': '__le__',
65+
'__gt__': '__lt__',
66+
'__le__': '__ge__',
67+
} # type: Final
68+
69+
reverse_op_method_names = set(reverse_op_methods.values()) # type: Final
70+
71+
# Suppose we have some class A. When we do A() + A(), Python will only check
72+
# the output of A().__add__(A()) and skip calling the __radd__ method entirely.
73+
# This shortcut is used only for the following methods:
74+
op_methods_that_shortcut = {
75+
'__add__',
76+
'__sub__',
77+
'__mul__',
78+
'__div__',
79+
'__truediv__',
80+
'__mod__',
81+
'__divmod__',
82+
'__floordiv__',
83+
'__pow__',
84+
'__matmul__',
85+
'__and__',
86+
'__or__',
87+
'__xor__',
88+
'__lshift__',
89+
'__rshift__',
90+
} # type: Final
91+
92+
normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) # type: Final
93+
reverse_op_method_set = set(reverse_op_methods.values()) # type: Final
94+
95+
unary_op_methods = {
96+
'-': '__neg__',
97+
'+': '__pos__',
98+
'~': '__invert__',
99+
} # type: Final

mypy/server/deps.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
8989
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
9090
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
9191
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
92-
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr,
92+
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr
93+
)
94+
from mypy.operators import (
9395
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
9496
)
9597
from mypy.traverser import TraverserVisitor

0 commit comments

Comments
 (0)