Skip to content

Commit de00bde

Browse files
committed
appease linter
1 parent aa0d364 commit de00bde

File tree

6 files changed

+48
-43
lines changed

6 files changed

+48
-43
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ reportAny = false
195195
reportExplicitAny = false
196196
# data-apis/array-api-strict#6
197197
reportUnknownMemberType = false
198+
# no array-api-compat type stubs
199+
reportUnknownVariableType = false
198200

199201

200202
# Ruff
@@ -230,9 +232,6 @@ ignore = [
230232
"PLR09", # Too many <...>
231233
"PLR2004", # Magic value used in comparison
232234
"ISC001", # Conflicts with formatter
233-
# "N802", # Function name should be lowercase
234-
# "N806", # Variable in function should be lowercase
235-
# "PD008", # pandas-use-of-dot-at
236235
]
237236

238237
[tool.ruff.lint.per-file-ignores]

src/array_api_extra/_funcs.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import operator
44
import warnings
5-
from collections.abc import Callable
6-
from typing import Any
5+
6+
# https://github.com/pylint-dev/pylint/issues/10112
7+
from collections.abc import Callable # pylint: disable=import-error
8+
from typing import ClassVar
79

810
from ._lib import _utils
911
from ._lib._compat import (
@@ -12,7 +14,7 @@
1214
is_dask_array,
1315
is_writeable_array,
1416
)
15-
from ._lib._typing import Array, ModuleType
17+
from ._lib._typing import Array, Index, ModuleType, Untyped
1618

1719
__all__ = [
1820
"at",
@@ -559,7 +561,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
559561
_undef = object()
560562

561563

562-
class at:
564+
class at: # pylint: disable=invalid-name
563565
"""
564566
Update operations for read-only arrays.
565567
@@ -651,14 +653,14 @@ class at:
651653
"""
652654

653655
x: Array
654-
idx: Any
655-
__slots__ = ("idx", "x")
656+
idx: Index
657+
__slots__: ClassVar[tuple[str, str]] = ("idx", "x")
656658

657-
def __init__(self, x: Array, idx: Any = _undef, /):
659+
def __init__(self, x: Array, idx: Index = _undef, /):
658660
self.x = x
659661
self.idx = idx
660662

661-
def __getitem__(self, idx: Any) -> Any:
663+
def __getitem__(self, idx: Index) -> at:
662664
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
663665
which looks prettier than ``at(x, slice(start, stop, step))``
664666
and feels more intuitive coming from the JAX documentation.
@@ -677,8 +679,8 @@ def _common(
677679
copy: bool | None = True,
678680
xp: ModuleType | None = None,
679681
_is_update: bool = True,
680-
**kwargs: Any,
681-
) -> tuple[Any, None] | tuple[None, Array]:
682+
**kwargs: Untyped,
683+
) -> tuple[Untyped, None] | tuple[None, Array]:
682684
"""Perform common prepocessing.
683685
684686
Returns
@@ -706,11 +708,11 @@ def _common(
706708
if not writeable:
707709
msg = "Cannot modify parameter in place"
708710
raise ValueError(msg)
709-
elif copy is None:
711+
elif copy is None: # type: ignore[redundant-expr]
710712
writeable = is_writeable_array(x)
711713
copy = _is_update and not writeable
712714
else:
713-
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable]
715+
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable]
714716
raise ValueError(msg)
715717

716718
if copy:
@@ -741,7 +743,7 @@ def _common(
741743

742744
return None, x
743745

744-
def get(self, **kwargs: Any) -> Any:
746+
def get(self, **kwargs: Untyped) -> Untyped:
745747
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
746748
that the output is either a copy or a view; it also allows passing
747749
keyword arguments to the backend.
@@ -766,7 +768,7 @@ def get(self, **kwargs: Any) -> Any:
766768
assert x is not None
767769
return x[self.idx]
768770

769-
def set(self, y: Array, /, **kwargs: Any) -> Array:
771+
def set(self, y: Array, /, **kwargs: Untyped) -> Array:
770772
"""Apply ``x[idx] = y`` and return the update array"""
771773
res, x = self._common("set", y, **kwargs)
772774
if res is not None:
@@ -781,7 +783,7 @@ def _iop(
781783
elwise_op: Callable[[Array, Array], Array],
782784
y: Array,
783785
/,
784-
**kwargs: Any,
786+
**kwargs: Untyped,
785787
) -> Array:
786788
"""x[idx] += y or equivalent in-place operation on a subset of x
787789
@@ -799,33 +801,33 @@ def _iop(
799801
x[self.idx] = elwise_op(x[self.idx], y)
800802
return x
801803

802-
def add(self, y: Array, /, **kwargs: Any) -> Array:
804+
def add(self, y: Array, /, **kwargs: Untyped) -> Array:
803805
"""Apply ``x[idx] += y`` and return the updated array"""
804806
return self._iop("add", operator.add, y, **kwargs)
805807

806-
def subtract(self, y: Array, /, **kwargs: Any) -> Array:
808+
def subtract(self, y: Array, /, **kwargs: Untyped) -> Array:
807809
"""Apply ``x[idx] -= y`` and return the updated array"""
808810
return self._iop("subtract", operator.sub, y, **kwargs)
809811

810-
def multiply(self, y: Array, /, **kwargs: Any) -> Array:
812+
def multiply(self, y: Array, /, **kwargs: Untyped) -> Array:
811813
"""Apply ``x[idx] *= y`` and return the updated array"""
812814
return self._iop("multiply", operator.mul, y, **kwargs)
813815

814-
def divide(self, y: Array, /, **kwargs: Any) -> Array:
816+
def divide(self, y: Array, /, **kwargs: Untyped) -> Array:
815817
"""Apply ``x[idx] /= y`` and return the updated array"""
816818
return self._iop("divide", operator.truediv, y, **kwargs)
817819

818-
def power(self, y: Array, /, **kwargs: Any) -> Array:
820+
def power(self, y: Array, /, **kwargs: Untyped) -> Array:
819821
"""Apply ``x[idx] **= y`` and return the updated array"""
820822
return self._iop("power", operator.pow, y, **kwargs)
821823

822-
def min(self, y: Array, /, **kwargs: Any) -> Array:
824+
def min(self, y: Array, /, **kwargs: Untyped) -> Array:
823825
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
824826
xp = array_namespace(self.x)
825827
y = xp.asarray(y)
826828
return self._iop("min", xp.minimum, y, **kwargs)
827829

828-
def max(self, y: Array, /, **kwargs: Any) -> Array:
830+
def max(self, y: Array, /, **kwargs: Untyped) -> Array:
829831
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
830832
xp = array_namespace(self.x)
831833
y = xp.asarray(y)

src/array_api_extra/_lib/_compat.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44

55
try:
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
7-
array_namespace, # pyright: ignore[reportUnknownVariableType]
8-
device, # pyright: ignore[reportUnknownVariableType]
9-
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10-
is_dask_array, # pyright: ignore[reportUnknownVariableType]
11-
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
7+
array_namespace,
8+
device,
9+
is_array_api_obj,
10+
is_dask_array,
11+
is_writeable_array,
1212
)
1313
except ImportError:
1414
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
15-
array_namespace, # pyright: ignore[reportUnknownVariableType]
15+
array_namespace,
1616
device,
17-
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
18-
is_dask_array, # pyright: ignore[reportUnknownVariableType]
19-
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
17+
is_array_api_obj,
18+
is_dask_array,
19+
is_writeable_array,
2020
)
2121

2222
__all__ = (

src/array_api_extra/_lib/_typing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
# To be changed to a Protocol later (see data-apis/array-api#589)
1111
Array = Any # type: ignore[no-any-explicit]
1212
Device = Any # type: ignore[no-any-explicit]
13+
Index = Any # type: ignore[no-any-explicit]
14+
Untyped = Any # type: ignore[no-any-explicit]
1315
else:
1416

1517
def no_op_decorator(f): # pyright: ignore[reportUnreachable]
@@ -19,4 +21,4 @@ def no_op_decorator(f): # pyright: ignore[reportUnreachable]
1921

2022
__all__ = ["ModuleType", "override"]
2123
if typing.TYPE_CHECKING:
22-
__all__ += ["Array", "Device"]
24+
__all__ += ["Array", "Device", "Index", "Untyped"]

tests/test_at.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
import pytest
9-
from array_api_compat import (
9+
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
1010
array_namespace,
1111
is_dask_array,
1212
is_pydata_sparse_array,
@@ -16,7 +16,7 @@
1616
from array_api_extra import at
1717

1818
if TYPE_CHECKING:
19-
from array_api_extra._lib._typing import Array
19+
from array_api_extra._lib._typing import Array, Untyped
2020

2121
all_libraries = (
2222
"array_api_strict",
@@ -31,7 +31,7 @@
3131

3232

3333
@pytest.fixture(params=all_libraries)
34-
def array(request):
34+
def array(request: pytest.FixtureRequest) -> Array:
3535
library = request.param
3636
if library == "numpy_readonly":
3737
x = np.asarray([10.0, 20.0, 30.0])
@@ -55,7 +55,7 @@ def assert_array_equal(a: Array, b: Array) -> None:
5555

5656

5757
@contextmanager
58-
def assert_copy(array, copy: bool | None):
58+
def assert_copy(array: Array, copy: bool | None) -> Untyped: # type: ignore[no-any-decorated]
5959
# dask arrays are writeable, but writing to them will hot-swap the
6060
# dask graph inside the collection so that anything that references
6161
# the original graph, i.e. the input collection, won't be mutated.
@@ -86,7 +86,9 @@ def assert_copy(array, copy: bool | None):
8686
("max", 25.0, [10.0, 25.0, 30.0]),
8787
],
8888
)
89-
def test_update_ops(array, copy, op, arg, expect):
89+
def test_update_ops(
90+
array: Array, copy: bool | None, op: str, arg: float, expect: list[float]
91+
):
9092
if is_pydata_sparse_array(array):
9193
pytest.skip("at() does not support updates on sparse arrays")
9294

@@ -97,7 +99,7 @@ def test_update_ops(array, copy, op, arg, expect):
9799

98100

99101
@pytest.mark.parametrize("copy", [True, False, None])
100-
def test_get(array, copy):
102+
def test_get(array: Array, copy: bool | None):
101103
expect_copy = copy
102104

103105
# dask is mutable, but __getitem__ never returns a view
@@ -117,7 +119,7 @@ def test_get(array, copy):
117119
y[:] = 40
118120

119121

120-
def test_get_bool_indices(array):
122+
def test_get_bool_indices(array: Array):
121123
"""get() with a boolean array index always returns a copy"""
122124
# sparse violates the array API as it doesn't support
123125
# a boolean index that is another sparse array.

vendor_tests/test_vendor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_vendor_compat():
1414
)
1515

1616
x = xp.asarray([1, 2, 3])
17-
assert array_namespace(x) is xp # type: ignore[no-untyped-call]
17+
assert array_namespace(x) is xp
1818
device(x)
1919
assert is_array_api_obj(x)
2020
assert not is_array_api_obj(123)

0 commit comments

Comments
 (0)