Skip to content

Commit c244872

Browse files
committed
Merge branch 'main' into typ_v4
2 parents 2954efd + 6ae28ee commit c244872

File tree

2 files changed

+161
-156
lines changed

2 files changed

+161
-156
lines changed

array_api_compat/common/_helpers.py

Lines changed: 106 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from __future__ import annotations
1010

11+
import enum
1112
import inspect
1213
import math
1314
import sys
@@ -481,6 +482,86 @@ def _check_api_version(api_version: str | None) -> None:
481482
)
482483

483484

485+
class _ClsToXPInfo(enum.Enum):
486+
SCALAR = 0
487+
MAYBE_JAX_ZERO_GRADIENT = 1
488+
489+
490+
@lru_cache(100)
491+
def _cls_to_namespace(
492+
cls: type,
493+
api_version: str | None,
494+
use_compat: bool | None,
495+
) -> tuple[Namespace | None, _ClsToXPInfo | None]:
496+
if use_compat not in (None, True, False):
497+
raise ValueError("use_compat must be None, True, or False")
498+
_use_compat = use_compat in (None, True)
499+
cls_ = cast(Hashable, cls) # Make mypy happy
500+
501+
if (
502+
_issubclass_fast(cls_, "numpy", "ndarray")
503+
or _issubclass_fast(cls_, "numpy", "generic")
504+
):
505+
if use_compat is True:
506+
_check_api_version(api_version)
507+
from .. import numpy as xp
508+
elif use_compat is False:
509+
import numpy as xp # type: ignore[no-redef]
510+
else:
511+
# NumPy 2.0+ have __array_namespace__; however they are not
512+
# yet fully array API compatible.
513+
from .. import numpy as xp # type: ignore[no-redef]
514+
return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT
515+
516+
# Note: this must happen _after_ the test for np.generic,
517+
# because np.float64 and np.complex128 are subclasses of float and complex.
518+
if issubclass(cls, int | float | complex | type(None)):
519+
return None, _ClsToXPInfo.SCALAR
520+
521+
if _issubclass_fast(cls_, "cupy", "ndarray"):
522+
if _use_compat:
523+
_check_api_version(api_version)
524+
from .. import cupy as xp # type: ignore[no-redef]
525+
else:
526+
import cupy as xp # type: ignore[no-redef]
527+
return xp, None
528+
529+
if _issubclass_fast(cls_, "torch", "Tensor"):
530+
if _use_compat:
531+
_check_api_version(api_version)
532+
from .. import torch as xp # type: ignore[no-redef]
533+
else:
534+
import torch as xp # type: ignore[no-redef]
535+
return xp, None
536+
537+
if _issubclass_fast(cls_, "dask.array", "Array"):
538+
if _use_compat:
539+
_check_api_version(api_version)
540+
from ..dask import array as xp # type: ignore[no-redef]
541+
else:
542+
import dask.array as xp # type: ignore[no-redef]
543+
return xp, None
544+
545+
# Backwards compatibility for jax<0.4.32
546+
if _issubclass_fast(cls_, "jax", "Array"):
547+
return _jax_namespace(api_version, use_compat), None
548+
549+
return None, None
550+
551+
552+
def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace:
553+
if use_compat:
554+
raise ValueError("JAX does not have an array-api-compat wrapper")
555+
import jax.numpy as jnp
556+
if not hasattr(jnp, "__array_namespace_info__"):
557+
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
558+
# For older JAX versions, it is available via jax.experimental.array_api.
559+
# jnp.Array objects gain the __array_namespace__ method.
560+
import jax.experimental.array_api # noqa: F401
561+
# Test api_version
562+
return jnp.empty(0).__array_namespace__(api_version=api_version)
563+
564+
484565
def array_namespace(
485566
*xs: Array | complex | None,
486567
api_version: str | None = None,
@@ -549,105 +630,40 @@ def your_function(x, y):
549630
is_pydata_sparse_array
550631
551632
"""
552-
if use_compat not in [None, True, False]:
553-
raise ValueError("use_compat must be None, True, or False")
554-
555-
_use_compat = use_compat in [None, True]
556-
557633
namespaces: set[Namespace] = set()
558634
for x in xs:
559-
if is_numpy_array(x):
560-
import numpy as np
561-
562-
from .. import numpy as numpy_namespace
563-
564-
if use_compat is True:
565-
_check_api_version(api_version)
566-
namespaces.add(numpy_namespace)
567-
elif use_compat is False:
568-
namespaces.add(np)
569-
else:
570-
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
571-
# compatible.
572-
namespaces.add(numpy_namespace)
573-
elif is_cupy_array(x):
574-
if _use_compat:
575-
_check_api_version(api_version)
576-
from .. import cupy as cupy_namespace
577-
578-
namespaces.add(cupy_namespace)
579-
else:
580-
import cupy as cp
581-
582-
namespaces.add(cp)
583-
elif is_torch_array(x):
584-
if _use_compat:
585-
_check_api_version(api_version)
586-
from .. import torch as torch_namespace
587-
588-
namespaces.add(torch_namespace)
589-
else:
590-
import torch
591-
592-
namespaces.add(torch)
593-
elif is_dask_array(x):
594-
if _use_compat:
595-
_check_api_version(api_version)
596-
from ..dask import array as dask_namespace
597-
598-
namespaces.add(dask_namespace)
599-
else:
600-
import dask.array as da
601-
602-
namespaces.add(da)
603-
elif is_jax_array(x):
604-
if use_compat is True:
605-
_check_api_version(api_version)
606-
raise ValueError("JAX does not have an array-api-compat wrapper")
607-
elif use_compat is False:
608-
import jax.numpy as jnp
609-
else:
610-
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
611-
# For older JAX versions, it is available via jax.experimental.array_api.
612-
import jax.numpy
613-
614-
if hasattr(jax.numpy, "__array_api_version__"):
615-
jnp = jax.numpy
616-
else:
617-
import jax.experimental.array_api as jnp # type: ignore[no-redef]
618-
namespaces.add(jnp)
619-
elif is_pydata_sparse_array(x):
620-
if use_compat is True:
621-
_check_api_version(api_version)
622-
raise ValueError("`sparse` does not have an array-api-compat wrapper")
623-
else:
624-
import sparse
625-
# `sparse` is already an array namespace. We do not have a wrapper
626-
# submodule for it.
627-
namespaces.add(sparse)
628-
elif hasattr(x, "__array_namespace__"):
629-
if use_compat is True:
635+
xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat)
636+
if info is _ClsToXPInfo.SCALAR:
637+
continue
638+
639+
if (
640+
info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT
641+
and _is_jax_zero_gradient_array(x)
642+
):
643+
xp = _jax_namespace(api_version, use_compat)
644+
645+
if xp is None:
646+
get_ns = getattr(x, "__array_namespace__", None)
647+
if get_ns is None:
648+
raise TypeError(f"{type(x).__name__} is not a supported array type")
649+
if use_compat:
630650
raise ValueError(
631651
"The given array does not have an array-api-compat wrapper"
632652
)
633-
x = cast("SupportsArrayNamespace[Any]", x)
634-
namespaces.add(x.__array_namespace__(api_version=api_version))
635-
elif isinstance(x, int | float | complex) or x is None:
636-
continue
637-
else:
638-
# TODO: Support Python scalars?
639-
raise TypeError(f"{type(x).__name__} is not a supported array type")
653+
xp = get_ns(api_version=api_version)
640654

641-
if not namespaces:
642-
raise TypeError("Unrecognized array input")
655+
namespaces.add(xp)
643656

644-
if len(namespaces) != 1:
657+
try:
658+
(xp,) = namespaces
659+
return xp
660+
except ValueError:
661+
if not namespaces:
662+
raise TypeError(
663+
"array_namespace requires at least one non-scalar array input"
664+
)
645665
raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
646666

647-
(xp,) = namespaces
648-
649-
return xp
650-
651667

652668
# backwards compatibility alias
653669
get_namespace = array_namespace

0 commit comments

Comments
 (0)