Skip to content

[mypyc] Support various number-related dunders #10679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
eb23625
Generate slots for __neg__, __invert__, _int__ and __float__
JukkaL May 2, 2021
3993f34
Specialize __neg__ and __invert__ for native classes
JukkaL May 2, 2021
8899420
Add tests cases for unary operators
JukkaL May 2, 2021
acd61dd
Basic support for __add__ and __radd__
JukkaL May 2, 2021
01dea76
Add NotImplemented test case
JukkaL May 2, 2021
9dab63e
Small refactoring
JukkaL May 2, 2021
d2d3e60
Revert change to NotImplemented fixture
JukkaL May 2, 2021
7cc4d2a
[WIP] Start refactoring dunder wrapper generation
JukkaL May 2, 2021
fb77328
Improvements to reverse operator methods
JukkaL May 2, 2021
4be4daf
Add partial NotImplemented support
JukkaL May 3, 2021
bd66107
Optimize reverse operator methods and refactor unbox/cast emit
JukkaL May 3, 2021
9cc93cd
Minor tweaks
JukkaL May 3, 2021
6566954
Improve error handling to be more compatible with CPython
JukkaL May 3, 2021
f01d8c8
Refactor: move operation info from mypy.nodes to mypy.operators
JukkaL May 3, 2021
b3e92d0
Support __sub__ and __rsub__
JukkaL May 3, 2021
97981b7
Support more binary ops (__mul__ etc.)
JukkaL May 3, 2021
7265020
Support matrix multiply (@)
JukkaL May 3, 2021
a305ba5
Support true division and floor division
JukkaL May 3, 2021
ed19b25
Support __iadd__
JukkaL May 3, 2021
d91da59
Support additional in-place operator methods such as __isub__
JukkaL May 3, 2021
639a8d4
Test runtime type error
JukkaL May 3, 2021
783c10b
Fix crash in test driver
JukkaL Jun 6, 2021
c560a22
Add test case
JukkaL Jun 6, 2021
cc625d0
Improve comments and docstrings
JukkaL Jun 6, 2021
74fb47c
Fixes to reverse methods
JukkaL Jun 19, 2021
be84c84
Update and add tests
JukkaL Jun 20, 2021
2f48293
Fix lint
JukkaL Jun 20, 2021
16e7c2c
Fix test on Python 3.6
JukkaL Jun 21, 2021
db092de
Minor refactoring
JukkaL Jun 21, 2021
0cc18b9
Oops, commit missing change
JukkaL Jun 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
is_final_node,
ARG_NAMED)
from mypy import nodes
from mypy import operators
from mypy.literals import literal, literal_hash, Key
from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any
from mypy.types import (
Expand Down Expand Up @@ -1026,13 +1027,13 @@ def is_forward_op_method(self, method_name: str) -> bool:
if self.options.python_version[0] == 2 and method_name == '__div__':
return True
else:
return method_name in nodes.reverse_op_methods
return method_name in operators.reverse_op_methods

def is_reverse_op_method(self, method_name: str) -> bool:
if self.options.python_version[0] == 2 and method_name == '__rdiv__':
return True
else:
return method_name in nodes.reverse_op_method_set
return method_name in operators.reverse_op_method_set

def check_for_missing_annotations(self, fdef: FuncItem) -> None:
# Check for functions with unspecified/not fully specified types.
Expand Down Expand Up @@ -1188,7 +1189,7 @@ def check_reverse_op_method(self, defn: FuncItem,
if self.options.python_version[0] == 2 and reverse_name == '__rdiv__':
forward_name = '__div__'
else:
forward_name = nodes.normal_from_reverse_op[reverse_name]
forward_name = operators.normal_from_reverse_op[reverse_name]
forward_inst = get_proper_type(reverse_type.arg_types[1])
if isinstance(forward_inst, TypeVarType):
forward_inst = get_proper_type(forward_inst.upper_bound)
Expand Down Expand Up @@ -1327,7 +1328,7 @@ def check_inplace_operator_method(self, defn: FuncBase) -> None:
They cannot arbitrarily overlap with __add__.
"""
method = defn.name
if method not in nodes.inplace_operator_methods:
if method not in operators.inplace_operator_methods:
return
typ = bind_self(self.function_type(defn))
cls = defn.info
Expand Down Expand Up @@ -1447,7 +1448,7 @@ def check_method_or_accessor_override_for_base(self, defn: Union[FuncDef,
# (__init__, __new__, __init_subclass__ are special).
if self.check_method_override_for_base_with_name(defn, name, base):
return True
if name in nodes.inplace_operator_methods:
if name in operators.inplace_operator_methods:
# Figure out the name of the corresponding operator method.
method = '__' + name[3:]
# An inplace operator method such as __iadd__ might not be
Expand Down Expand Up @@ -5522,9 +5523,9 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, st
depending on which method is supported by the type.
"""
typ = get_proper_type(typ)
method = nodes.op_methods[operator]
method = operators.op_methods[operator]
if isinstance(typ, Instance):
if operator in nodes.ops_with_inplace_method:
if operator in operators.ops_with_inplace_method:
inplace_method = '__i' + method[2:]
if typ.type.has_readable_member(inplace_method):
return True, inplace_method
Expand Down
23 changes: 12 additions & 11 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from mypy.literals import literal
from mypy import nodes
from mypy import operators
import mypy.checker
from mypy import types
from mypy.sametypes import is_same_type
Expand Down Expand Up @@ -2169,7 +2170,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
if right_radd_method is None:
return self.concat_tuples(proper_left_type, proper_right_type)

if e.op in nodes.op_methods:
if e.op in operators.op_methods:
method = self.get_operator_method(e.op)
result, method_type = self.check_op(method, left_type, e.right, e,
allow_reverse=True)
Expand Down Expand Up @@ -2234,7 +2235,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
else:
self.msg.add_errors(local_errors)
elif operator in nodes.op_methods:
elif operator in operators.op_methods:
method = self.get_operator_method(operator)
err_count = self.msg.errors.total_errors()
sub_result, method_type = self.check_op(method, left_type, right, e,
Expand Down Expand Up @@ -2362,7 +2363,7 @@ def get_operator_method(self, op: str) -> str:
# TODO also check for "from __future__ import division"
return '__div__'
else:
return nodes.op_methods[op]
return operators.op_methods[op]

def check_method_call_by_name(self,
method: str,
Expand Down Expand Up @@ -2537,7 +2538,7 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
# which records tuples containing the method, base type, and the argument.

bias_right = is_proper_subtype(right_type, left_type)
if op_name in nodes.op_methods_that_shortcut and is_same_type(left_type, right_type):
if op_name in operators.op_methods_that_shortcut and is_same_type(left_type, right_type):
# When we do "A() + A()", for example, Python will only call the __add__ method,
# never the __radd__ method.
#
Expand Down Expand Up @@ -2575,8 +2576,8 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
# When running Python 2, we might also try calling the __cmp__ method.

is_python_2 = self.chk.options.python_version[0] == 2
if is_python_2 and op_name in nodes.ops_falling_back_to_cmp:
cmp_method = nodes.comparison_fallback_method
if is_python_2 and op_name in operators.ops_falling_back_to_cmp:
cmp_method = operators.comparison_fallback_method
left_cmp_op = lookup_operator(cmp_method, left_type)
right_cmp_op = lookup_operator(cmp_method, right_type)

Expand Down Expand Up @@ -2760,7 +2761,7 @@ def get_reverse_op_method(self, method: str) -> str:
if method == '__div__' and self.chk.options.python_version[0] == 2:
return '__rdiv__'
else:
return nodes.reverse_op_methods[method]
return operators.reverse_op_methods[method]

def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
"""Type check a boolean operation ('and' or 'or')."""
Expand Down Expand Up @@ -2867,7 +2868,7 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type:
if op == 'not':
result = self.bool_type() # type: Type
else:
method = nodes.unary_op_methods[op]
method = operators.unary_op_methods[op]
result, method_type = self.check_method_call_by_name(method, operand_type, [], [], e)
e.method_type = method_type
return result
Expand Down Expand Up @@ -4533,9 +4534,9 @@ def is_operator_method(fullname: Optional[str]) -> bool:
return False
short_name = fullname.split('.')[-1]
return (
short_name in nodes.op_methods.values() or
short_name in nodes.reverse_op_methods.values() or
short_name in nodes.unary_op_methods.values())
short_name in operators.op_methods.values() or
short_name in operators.reverse_op_methods.values() or
short_name in operators.unary_op_methods.values())


def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]:
Expand Down
4 changes: 2 additions & 2 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
)
from mypy.typetraverser import TypeTraverserVisitor
from mypy.nodes import (
TypeInfo, Context, MypyFile, op_methods, op_methods_to_symbols,
FuncDef, reverse_builtin_aliases,
TypeInfo, Context, MypyFile, FuncDef, reverse_builtin_aliases,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode,
CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode
)
from mypy.operators import op_methods, op_methods_to_symbols
from mypy.subtypes import (
is_subtype, find_member, get_member_flags,
IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC,
Expand Down
94 changes: 0 additions & 94 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,100 +1634,6 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_assignment_expr(self)


# Map from binary operator id to related method name (in Python 3).
op_methods = {
'+': '__add__',
'-': '__sub__',
'*': '__mul__',
'/': '__truediv__',
'%': '__mod__',
'divmod': '__divmod__',
'//': '__floordiv__',
'**': '__pow__',
'@': '__matmul__',
'&': '__and__',
'|': '__or__',
'^': '__xor__',
'<<': '__lshift__',
'>>': '__rshift__',
'==': '__eq__',
'!=': '__ne__',
'<': '__lt__',
'>=': '__ge__',
'>': '__gt__',
'<=': '__le__',
'in': '__contains__',
} # type: Final

op_methods_to_symbols = {v: k for (k, v) in op_methods.items()} # type: Final
op_methods_to_symbols['__div__'] = '/'

comparison_fallback_method = '__cmp__' # type: Final
ops_falling_back_to_cmp = {'__ne__', '__eq__',
'__lt__', '__le__',
'__gt__', '__ge__'} # type: Final


ops_with_inplace_method = {
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'} # type: Final

inplace_operator_methods = set(
'__i' + op_methods[op][2:] for op in ops_with_inplace_method) # type: Final

reverse_op_methods = {
'__add__': '__radd__',
'__sub__': '__rsub__',
'__mul__': '__rmul__',
'__truediv__': '__rtruediv__',
'__mod__': '__rmod__',
'__divmod__': '__rdivmod__',
'__floordiv__': '__rfloordiv__',
'__pow__': '__rpow__',
'__matmul__': '__rmatmul__',
'__and__': '__rand__',
'__or__': '__ror__',
'__xor__': '__rxor__',
'__lshift__': '__rlshift__',
'__rshift__': '__rrshift__',
'__eq__': '__eq__',
'__ne__': '__ne__',
'__lt__': '__gt__',
'__ge__': '__le__',
'__gt__': '__lt__',
'__le__': '__ge__',
} # type: Final

# Suppose we have some class A. When we do A() + A(), Python will only check
# the output of A().__add__(A()) and skip calling the __radd__ method entirely.
# This shortcut is used only for the following methods:
op_methods_that_shortcut = {
'__add__',
'__sub__',
'__mul__',
'__div__',
'__truediv__',
'__mod__',
'__divmod__',
'__floordiv__',
'__pow__',
'__matmul__',
'__and__',
'__or__',
'__xor__',
'__lshift__',
'__rshift__',
} # type: Final

normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) # type: Final
reverse_op_method_set = set(reverse_op_methods.values()) # type: Final

unary_op_methods = {
'-': '__neg__',
'+': '__pos__',
'~': '__invert__',
} # type: Final


class OpExpr(Expression):
"""Binary operation (other than . or [] or comparison operators,
which have specific nodes)."""
Expand Down
99 changes: 99 additions & 0 deletions mypy/operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Information about Python operators"""

from typing_extensions import Final


# Map from binary operator id to related method name (in Python 3).
op_methods = {
'+': '__add__',
'-': '__sub__',
'*': '__mul__',
'/': '__truediv__',
'%': '__mod__',
'divmod': '__divmod__',
'//': '__floordiv__',
'**': '__pow__',
'@': '__matmul__',
'&': '__and__',
'|': '__or__',
'^': '__xor__',
'<<': '__lshift__',
'>>': '__rshift__',
'==': '__eq__',
'!=': '__ne__',
'<': '__lt__',
'>=': '__ge__',
'>': '__gt__',
'<=': '__le__',
'in': '__contains__',
} # type: Final

op_methods_to_symbols = {v: k for (k, v) in op_methods.items()} # type: Final
op_methods_to_symbols['__div__'] = '/'

comparison_fallback_method = '__cmp__' # type: Final
ops_falling_back_to_cmp = {'__ne__', '__eq__',
'__lt__', '__le__',
'__gt__', '__ge__'} # type: Final


ops_with_inplace_method = {
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'} # type: Final

inplace_operator_methods = set(
'__i' + op_methods[op][2:] for op in ops_with_inplace_method) # type: Final

reverse_op_methods = {
'__add__': '__radd__',
'__sub__': '__rsub__',
'__mul__': '__rmul__',
'__truediv__': '__rtruediv__',
'__mod__': '__rmod__',
'__divmod__': '__rdivmod__',
'__floordiv__': '__rfloordiv__',
'__pow__': '__rpow__',
'__matmul__': '__rmatmul__',
'__and__': '__rand__',
'__or__': '__ror__',
'__xor__': '__rxor__',
'__lshift__': '__rlshift__',
'__rshift__': '__rrshift__',
'__eq__': '__eq__',
'__ne__': '__ne__',
'__lt__': '__gt__',
'__ge__': '__le__',
'__gt__': '__lt__',
'__le__': '__ge__',
} # type: Final

reverse_op_method_names = set(reverse_op_methods.values()) # type: Final

# Suppose we have some class A. When we do A() + A(), Python will only check
# the output of A().__add__(A()) and skip calling the __radd__ method entirely.
# This shortcut is used only for the following methods:
op_methods_that_shortcut = {
'__add__',
'__sub__',
'__mul__',
'__div__',
'__truediv__',
'__mod__',
'__divmod__',
'__floordiv__',
'__pow__',
'__matmul__',
'__and__',
'__or__',
'__xor__',
'__lshift__',
'__rshift__',
} # type: Final

normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) # type: Final
reverse_op_method_set = set(reverse_op_methods.values()) # type: Final

unary_op_methods = {
'-': '__neg__',
'+': '__pos__',
'~': '__invert__',
} # type: Final
4 changes: 3 additions & 1 deletion mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr,
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr
)
from mypy.operators import (
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
)
from mypy.traverser import TraverserVisitor
Expand Down
Loading