Skip to content

MAINT: Array API 2024.12 typing nits #156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ increase performance.
In particular, the following kinds of function are also in-scope:

- Functions which implement
[array API standard extension](https://data-apis.org/array-api/2023.12/extensions/index.html)
[array API standard extension](https://data-apis.org/array-api/latest/extensions/index.html)
functions in terms of functions from the base standard.
- Functions which add functionality (e.g. extra parameters) to functions from
the standard.
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:


def isclose(
a: Array,
b: Array,
a: Array | complex,
b: Array | complex,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
Expand Down
32 changes: 25 additions & 7 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from collections.abc import Sequence
from types import ModuleType
from typing import cast
from typing import TYPE_CHECKING, cast

from ._at import at
from ._utils import _compat, _helpers
Expand Down Expand Up @@ -375,8 +375,8 @@ def expand_dims(


def isclose(
a: Array,
b: Array,
a: Array | complex,
b: Array | complex,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
Expand All @@ -385,6 +385,9 @@ def isclose(
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
a, b = asarrays(a, b, xp=xp)
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the pyright bug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given

def f(x: Array | complex) -> Array: ...

x: Array | complex
x = f(x)

when you reach the end of the snippet, mypy correctly understands that x can only be Array, whereas pyright still believes it could be complex.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, so is it about reusing variable names? Is there an upstream issue we can link to?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given the upstream answer, what would you like to do here? Use different variable names? Keep the current workarounds but remove the FIXME?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've replaced Array with a fully fledged Protocol. As it would vastly enlarge the scope of this PR, I instead opened a follow-up that reverts these three hacks: #159

assert _compat.is_array_api_obj(a)
assert _compat.is_array_api_obj(b)

a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
Expand Down Expand Up @@ -419,7 +422,13 @@ def isclose(
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)


def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
def kron(
a: Array | complex,
b: Array | complex,
/,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Kronecker product of two arrays.

Expand Down Expand Up @@ -495,9 +504,14 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
if xp is None:
xp = array_namespace(a, b)
a, b = asarrays(a, b, xp=xp)
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover
assert _compat.is_array_api_obj(a)
assert _compat.is_array_api_obj(b)

singletons = (1,) * (b.ndim - a.ndim)
a = xp.broadcast_to(a, singletons + a.shape)
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover
assert _compat.is_array_api_obj(a)

nd_b, nd_a = b.ndim, a.ndim
nd_max = max(nd_b, nd_a)
Expand Down Expand Up @@ -614,8 +628,8 @@ def pad(


def setdiff1d(
x1: Array,
x2: Array,
x1: Array | complex,
x2: Array | complex,
/,
*,
assume_unique: bool = False,
Expand All @@ -628,7 +642,7 @@ def setdiff1d(

Parameters
----------
x1 : array
x1 : array | int | float | complex | bool
Input array.
x2 : array
Input comparison array.
Expand Down Expand Up @@ -665,6 +679,10 @@ def setdiff1d(
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)

if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover
assert _compat.is_array_api_obj(x1)

return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]


Expand Down
25 changes: 14 additions & 11 deletions src/array_api_extra/_lib/_utils/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,30 @@ from __future__ import annotations

from types import ModuleType

# TODO import from typing (requires Python >=3.13)
from typing_extensions import TypeIs

from ._typing import Array, Device

# pylint: disable=missing-class-docstring,unused-argument

class ArrayModule(ModuleType):
class Namespace(ModuleType):
def device(self, x: Array, /) -> Device: ...

def array_namespace(
*xs: Array,
*xs: Array | complex | None,
api_version: str | None = None,
use_compat: bool | None = None,
) -> ArrayModule: ...
) -> Namespace: ...
def device(x: Array, /) -> Device: ...
def is_array_api_obj(x: object, /) -> bool: ...
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
def is_array_api_obj(x: object, /) -> TypeIs[Array]: ...
def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_cupy_array(x: object, /) -> bool: ...
def is_dask_array(x: object, /) -> bool: ...
def is_jax_array(x: object, /) -> bool: ...
Expand Down
23 changes: 13 additions & 10 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@

from collections.abc import Generator
from types import ModuleType
from typing import cast
from typing import TYPE_CHECKING

from . import _compat
from ._compat import array_namespace, is_array_api_obj, is_numpy_array
from ._typing import Array

if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.13)
from typing_extensions import TypeIs


__all__ = ["asarrays", "in1d", "is_python_scalar", "mean"]


Expand Down Expand Up @@ -96,16 +101,16 @@ def mean(
return xp.mean(x, axis=axis, keepdims=keepdims)


def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01
def is_python_scalar(x: object) -> TypeIs[complex]: # numpydoc ignore=PR01,RT01
"""Return True if `x` is a Python scalar, False otherwise."""
# isinstance(x, float) returns True for np.float64
# isinstance(x, complex) returns True for np.complex128
return isinstance(x, int | float | complex | bool) and not is_numpy_array(x)
return isinstance(x, int | float | complex) and not is_numpy_array(x)


def asarrays(
a: Array | int | float | complex | bool,
b: Array | int | float | complex | bool,
a: Array | complex,
b: Array | complex,
xp: ModuleType,
) -> tuple[Array, Array]:
"""
Expand Down Expand Up @@ -150,9 +155,7 @@ def asarrays(
if is_array_api_obj(a):
# a is an Array API object
# b is a int | float | complex | bool

# pyright doesn't like it if you reuse the same variable name
xa = cast(Array, a)
xa = a

# https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
same_dtype = {
Expand All @@ -162,8 +165,8 @@ def asarrays(
complex: "complex floating",
}
kind = same_dtype[type(b)] # type: ignore[index]
if xp.isdtype(xa.dtype, kind):
xb = xp.asarray(b, dtype=xa.dtype)
if xp.isdtype(a.dtype, kind):
xb = xp.asarray(b, dtype=a.dtype)
else:
# Undefined behaviour. Let the function deal with it, if it can.
xb = xp.asarray(b)
Expand Down