Skip to content

TYP: Type annotations, part 4 #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions array_api_compat/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
specification for more details.

"""
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # pyright: ignore[reportReturnType]
wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType]

return inner

Expand Down
14 changes: 8 additions & 6 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
from collections.abc import Sequence
from types import NoneType
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from ._helpers import _check_device, array_namespace
from ._helpers import device as _get_device
from ._helpers import is_cupy_namespace as _is_cupy_namespace
from ._helpers import is_cupy_namespace
from ._typing import Array, Device, DType, Namespace

if TYPE_CHECKING:
Expand Down Expand Up @@ -381,8 +383,8 @@ def clip(
# TODO: np.clip has other ufunc kwargs
out: Array | None = None,
) -> Array:
def _isscalar(a: object) -> TypeIs[int | float | None]:
return isinstance(a, (int, float, type(None)))
def _isscalar(a: object) -> TypeIs[float | None]:
return isinstance(a, int | float | NoneType)

min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
Expand Down Expand Up @@ -450,7 +452,7 @@ def reshape(
shape: tuple[int, ...],
xp: Namespace,
*,
copy: Optional[bool] = None,
copy: bool | None = None,
**kwargs: object,
) -> Array:
if copy is True:
Expand Down Expand Up @@ -657,7 +659,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]

Expand Down
62 changes: 29 additions & 33 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,46 @@
import math
import sys
import warnings
from collections.abc import Collection, Hashable
from collections.abc import Hashable
from functools import lru_cache
from types import NoneType
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
SupportsIndex,
TypeAlias,
TypeGuard,
TypeVar,
cast,
overload,
)

from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace

if TYPE_CHECKING:

import cupy as cp
import dask.array as da
import jax
import ndonnx as ndx
import numpy as np
import numpy.typing as npt
import sparse # pyright: ignore[reportMissingTypeStubs]
import sparse
import torch

# TODO: import from typing (requires Python >=3.13)
from typing_extensions import TypeIs, TypeVar

_SizeT = TypeVar("_SizeT", bound = int | None)
from typing_extensions import TypeIs

_ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
_CupyArray: TypeAlias = Any # cupy has no py.typed

_ArrayApiObj: TypeAlias = (
npt.NDArray[Any]
| cp.ndarray
| da.Array
| jax.Array
| ndx.Array
| sparse.SparseArray
| torch.Tensor
| SupportsArrayNamespace[Any]
| _CupyArray
| SupportsArrayNamespace
)

_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
Expand Down Expand Up @@ -95,7 +91,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
return dtype == jax.float0


def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
"""
Return True if `x` is a NumPy array.

Expand Down Expand Up @@ -266,7 +262,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
return _issubclass_fast(cls, "sparse", "SparseArray")


def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
"""
Return True if `x` is an array API compatible array object.

Expand Down Expand Up @@ -581,7 +577,7 @@ def your_function(x, y):

namespaces.add(cupy_namespace)
else:
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
import cupy as cp

namespaces.add(cp)
elif is_torch_array(x):
Expand Down Expand Up @@ -618,14 +614,14 @@ def your_function(x, y):
if hasattr(jax.numpy, "__array_api_version__"):
jnp = jax.numpy
else:
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports]
import jax.experimental.array_api as jnp # type: ignore[no-redef]
namespaces.add(jnp)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
raise ValueError("`sparse` does not have an array-api-compat wrapper")
else:
import sparse # pyright: ignore[reportMissingTypeStubs]
import sparse
# `sparse` is already an array namespace. We do not have a wrapper
# submodule for it.
namespaces.add(sparse)
Expand All @@ -634,9 +630,9 @@ def your_function(x, y):
raise ValueError(
"The given array does not have an array-api-compat wrapper"
)
x = cast("SupportsArrayNamespace[Any]", x)
x = cast(SupportsArrayNamespace, x)
namespaces.add(x.__array_namespace__(api_version=api_version))
elif isinstance(x, (bool, int, float, complex, type(None))):
elif isinstance(x, int | float | complex | NoneType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif isinstance(x, int | float | complex | NoneType):
elif x is None or isinstance(x, int | float | complex):

(I'll spare you the pseudo-philosophical rant this time)

continue
else:
# TODO: Support Python scalars?
Expand Down Expand Up @@ -732,7 +728,7 @@ def device(x: _ArrayApiObj, /) -> Device:
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the Dask array to determine type
if is_numpy_array(x._meta): # pyright: ignore
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
Expand Down Expand Up @@ -761,7 +757,7 @@ def device(x: _ArrayApiObj, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner) # pyright: ignore
return x.device # pyright: ignore
return x.device # type: ignore # pyright: ignore


# Prevent shadowing, used below
Expand All @@ -770,11 +766,11 @@ def device(x: _ArrayApiObj, /) -> Device:

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(
x: _CupyArray,
x: cp.ndarray,
device: Device,
/,
stream: int | Any | None = None,
) -> _CupyArray:
) -> cp.ndarray:
import cupy as cp

if device == "cpu":
Expand Down Expand Up @@ -803,7 +799,7 @@ def _torch_to_device(
x: torch.Tensor,
device: torch.device | str | int,
/,
stream: None = None,
stream: int | Any | None = None,
) -> torch.Tensor:
if stream is not None:
raise NotImplementedError
Expand Down Expand Up @@ -869,7 +865,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
# cupy does not yet have to_device
return _cupy_to_device(x, device, stream=stream)
elif is_torch_array(x):
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType]
return _torch_to_device(x, device, stream=stream)
elif is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
Expand All @@ -894,12 +890,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -


@overload
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
def size(x: HasShape[int]) -> int: ...
@overload
def size(x: HasShape[Collection[None]]) -> None: ...
def size(x: HasShape[int | None]) -> int | None: ...
@overload
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
def size(x: HasShape[float]) -> int | None: ... # Dask special case
def size(x: HasShape[float | None]) -> int | None:
"""
Return the total number of elements of x.

Expand All @@ -914,9 +910,9 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
if None in x.shape:
return None
out = math.prod(cast("Collection[SupportsIndex]", x.shape))
out = math.prod(cast(tuple[float, ...], x.shape))
# dask.array.Array.shape can contain NaN
return None if math.isnan(out) else out
return None if math.isnan(out) else cast(int, out)


@lru_cache(100)
Expand All @@ -932,7 +928,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
return None


def is_writeable_array(x: object) -> bool:
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
Return False if `x` is not an array API compatible object.
Expand Down Expand Up @@ -970,7 +966,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
return None


def is_lazy_array(x: object) -> bool:
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
"""Return True if x is potentially a future or it may be otherwise impossible or
expensive to eagerly read its contents, regardless of their size, e.g. by
calling ``bool(x)`` or ``float(x)``.
Expand Down Expand Up @@ -1007,7 +1003,7 @@ def is_lazy_array(x: object) -> bool:
# on __bool__ (dask is one such example, which however is special-cased above).

# Select a single point of the array
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
s = size(cast(HasShape, x))
if s is None:
return True
xp = array_namespace(x)
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if np.__version__[0] == "2":
from numpy.lib.array_utils import normalize_axis_tuple
else:
from numpy.core.numeric import normalize_axis_tuple
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]

from .._internal import get_xp
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
Expand Down
Loading
Loading