Skip to content

[attrs] Support attr.s(eq=..., order=...) #7619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 71 additions & 6 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,62 @@ def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'Attribute':
)


def _determine_eq_order(ctx: 'mypy.plugin.ClassDefContext') -> Tuple[bool, bool]:
"""
Validate the combination of *cmp*, *eq*, and *order*. Derive the effective
values of eq and order.
"""
cmp = _get_decorator_optional_bool_argument(ctx, 'cmp')
eq = _get_decorator_optional_bool_argument(ctx, 'eq')
order = _get_decorator_optional_bool_argument(ctx, 'order')

if cmp is not None and any((eq is not None, order is not None)):
ctx.api.fail("Don't mix `cmp` with `eq' and `order`", ctx.reason)

# cmp takes precedence due to bw-compatibility.
if cmp is not None:
ctx.api.fail("cmp is deprecated, use eq and order", ctx.reason)
return cmp, cmp

# If left None, equality is on and ordering mirrors equality.
if eq is None:
eq = True

if order is None:
order = eq

if eq is False and order is True:
ctx.api.fail("eq must be True if order is True", ctx.reason)

return eq, order


def _get_decorator_optional_bool_argument(
ctx: 'mypy.plugin.ClassDefContext',
name: str,
default: Optional[bool] = None,
) -> Optional[bool]:
"""Return the Optional[bool] argument for the decorator.

This handles both @decorator(...) and @decorator.
"""
if isinstance(ctx.reason, CallExpr):
attr_value = _get_argument(ctx.reason, name)
if attr_value:
if isinstance(attr_value, NameExpr):
if attr_value.fullname == 'builtins.True':
return True
if attr_value.fullname == 'builtins.False':
return False
if attr_value.fullname == 'builtins.None':
return None
ctx.api.fail('"{}" argument must be True or False.'.format(name), ctx.reason)
return default
return default
else:
return default


def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext',
auto_attribs_default: bool = False) -> None:
"""Add necessary dunder methods to classes decorated with attr.s.
Expand All @@ -193,7 +249,8 @@ def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext',

init = _get_decorator_bool_argument(ctx, 'init', True)
frozen = _get_frozen(ctx)
cmp = _get_decorator_bool_argument(ctx, 'cmp', True)
eq, order = _determine_eq_order(ctx)

auto_attribs = _get_decorator_bool_argument(ctx, 'auto_attribs', auto_attribs_default)
kw_only = _get_decorator_bool_argument(ctx, 'kw_only', False)

Expand Down Expand Up @@ -231,8 +288,10 @@ def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext',
adder = MethodAdder(ctx)
if init:
_add_init(ctx, attributes, adder)
if cmp:
_add_cmp(ctx, adder)
if eq:
_add_eq(ctx, adder)
if order:
_add_order(ctx, adder)
if frozen:
_make_frozen(ctx, attributes)

Expand Down Expand Up @@ -529,16 +588,22 @@ def _parse_assignments(
return lvalues, rvalues


def _add_cmp(ctx: 'mypy.plugin.ClassDefContext', adder: 'MethodAdder') -> None:
"""Generate all the cmp methods for this class."""
def _add_eq(ctx: 'mypy.plugin.ClassDefContext', adder: 'MethodAdder') -> None:
"""Generate __eq__ and __ne__ for this class."""
# For __ne__ and __eq__ the type is:
# def __ne__(self, other: object) -> bool
bool_type = ctx.api.named_type('__builtins__.bool')
object_type = ctx.api.named_type('__builtins__.object')
args = [Argument(Var('other', object_type), object_type, None, ARG_POS)]
for method in ['__ne__', '__eq__']:
adder.add_method(method, args, bool_type)
# For the rest we use:


def _add_order(ctx: 'mypy.plugin.ClassDefContext', adder: 'MethodAdder') -> None:
"""Generate all the ordering methods for this class."""
bool_type = ctx.api.named_type('__builtins__.bool')
object_type = ctx.api.named_type('__builtins__.object')
# Make the types be:
# AT = TypeVar('AT')
# def __lt__(self: AT, other: AT) -> bool
# This way comparisons with subclasses will work correctly.
Expand Down
55 changes: 53 additions & 2 deletions test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ A(1) != 1
1 != A(1)
[builtins fixtures/attr.pyi]

[case testAttrsCmpFalse]
[case testAttrsEqFalse]
from attr import attrib, attrs
@attrs(auto_attribs=True, cmp=False)
@attrs(auto_attribs=True, eq=False)
class A:
a: int
reveal_type(A) # N: Revealed type is 'def (a: builtins.int) -> __main__.A'
Expand Down Expand Up @@ -245,6 +245,57 @@ A(1) != 1
1 != A(1)
[builtins fixtures/attr.pyi]

[case testAttrsOrderFalse]
from attr import attrib, attrs
@attrs(auto_attribs=True, order=False)
class A:
a: int
reveal_type(A) # N: Revealed type is 'def (a: builtins.int) -> __main__.A'
reveal_type(A.__eq__) # N: Revealed type is 'def (self: __main__.A, other: builtins.object) -> builtins.bool'
reveal_type(A.__ne__) # N: Revealed type is 'def (self: __main__.A, other: builtins.object) -> builtins.bool'

A(1) < A(2) # E: Unsupported left operand type for < ("A")
A(1) <= A(2) # E: Unsupported left operand type for <= ("A")
A(1) > A(2) # E: Unsupported left operand type for > ("A")
A(1) >= A(2) # E: Unsupported left operand type for >= ("A")
A(1) == A(2)
A(1) != A(2)

A(1) < 1 # E: Unsupported left operand type for < ("A")
A(1) <= 1 # E: Unsupported left operand type for <= ("A")
A(1) > 1 # E: Unsupported left operand type for > ("A")
A(1) >= 1 # E: Unsupported left operand type for >= ("A")
A(1) == 1
A(1) != 1

1 < A(1) # E: Unsupported left operand type for < ("int")
1 <= A(1) # E: Unsupported left operand type for <= ("int")
1 > A(1) # E: Unsupported left operand type for > ("int")
1 >= A(1) # E: Unsupported left operand type for >= ("int")
1 == A(1)
1 != A(1)
[builtins fixtures/attr.pyi]

[case testAttrsCmpEqOrderValues]
from attr import attrib, attrs
@attrs(cmp=True) # E: cmp is deprecated, use eq and order
class DeprecatedTrue:
...

@attrs(cmp=False) # E: cmp is deprecated, use eq and order
class DeprecatedFalse:
...

@attrs(cmp=False, eq=True) # E: Don't mix `cmp` with `eq' and `order` # E: cmp is deprecated, use eq and order
class Mixed:
...

@attrs(order=True, eq=False) # E: eq must be True if order is True
class Confused:
...
[builtins fixtures/attr.pyi]


[case testAttrsInheritance]
import attr
@attr.s
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -2937,7 +2937,7 @@ class Frozen:
@attr.s(init=False)
class NoInit:
x: int = attr.ib()
@attr.s(cmp=False)
@attr.s(eq=False)
class NoCmp:
x: int = attr.ib()

Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ B(1, 2) < B(1, 2)
[file b.py]
from a import A
import attr
@attr.s(cmp=False)
@attr.s(eq=False)
class B(A):
b = attr.ib() # type: int
[file a.py]
Expand All @@ -1016,7 +1016,7 @@ class A:

[file a.py.2]
import attr
@attr.s(cmp=False, init=False)
@attr.s(eq=False, init=False)
class A:
a = attr.ib() # type: int
[builtins fixtures/list.pyi]
Expand Down
24 changes: 18 additions & 6 deletions test-data/unit/lib-stub/attr.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ _ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]
def attrib(default: None = ...,
validator: None = ...,
repr: bool = ...,
cmp: bool = ...,
cmp: Optional[bool] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: None = ...,
Expand All @@ -22,13 +22,15 @@ def attrib(default: None = ...,
converter: None = ...,
factory: None = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the other arguments.
@overload
def attrib(default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
cmp: bool = ...,
cmp: Optional[bool] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: Optional[_ConverterType[_T]] = ...,
Expand All @@ -37,13 +39,15 @@ def attrib(default: None = ...,
converter: Optional[_ConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
) -> _T: ...
# This form catches an explicit default argument.
@overload
def attrib(default: _T,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
cmp: bool = ...,
cmp: Optional[bool] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: Optional[_ConverterType[_T]] = ...,
Expand All @@ -52,13 +56,15 @@ def attrib(default: _T,
converter: Optional[_ConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
) -> _T: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def attrib(default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
cmp: bool = ...,
cmp: Optional[bool] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: Optional[_ConverterType[_T]] = ...,
Expand All @@ -67,14 +73,16 @@ def attrib(default: Optional[_T] = ...,
converter: Optional[_ConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
) -> Any: ...

@overload
def attrs(maybe_cls: _C,
these: Optional[Mapping[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: bool = ...,
cmp: Optional[bool] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
Expand All @@ -84,13 +92,15 @@ def attrs(maybe_cls: _C,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
) -> _C: ...
@overload
def attrs(maybe_cls: None = ...,
these: Optional[Mapping[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: bool = ...,
cmp: Optional[bool] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
Expand All @@ -100,6 +110,8 @@ def attrs(maybe_cls: None = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
) -> Callable[[_C], _C]: ...


Expand Down