|
8 | 8 |
|
9 | 9 | from __future__ import annotations
|
10 | 10 |
|
| 11 | +import enum |
11 | 12 | import inspect
|
12 | 13 | import math
|
13 | 14 | import sys
|
@@ -481,6 +482,86 @@ def _check_api_version(api_version: str | None) -> None:
|
481 | 482 | )
|
482 | 483 |
|
483 | 484 |
|
| 485 | +class _ClsToXPInfo(enum.Enum): |
| 486 | + SCALAR = 0 |
| 487 | + MAYBE_JAX_ZERO_GRADIENT = 1 |
| 488 | + |
| 489 | + |
| 490 | +@lru_cache(100) |
| 491 | +def _cls_to_namespace( |
| 492 | + cls: type, |
| 493 | + api_version: str | None, |
| 494 | + use_compat: bool | None, |
| 495 | +) -> tuple[Namespace | None, _ClsToXPInfo | None]: |
| 496 | + if use_compat not in (None, True, False): |
| 497 | + raise ValueError("use_compat must be None, True, or False") |
| 498 | + _use_compat = use_compat in (None, True) |
| 499 | + cls_ = cast(Hashable, cls) # Make mypy happy |
| 500 | + |
| 501 | + if ( |
| 502 | + _issubclass_fast(cls_, "numpy", "ndarray") |
| 503 | + or _issubclass_fast(cls_, "numpy", "generic") |
| 504 | + ): |
| 505 | + if use_compat is True: |
| 506 | + _check_api_version(api_version) |
| 507 | + from .. import numpy as xp |
| 508 | + elif use_compat is False: |
| 509 | + import numpy as xp # type: ignore[no-redef] |
| 510 | + else: |
| 511 | + # NumPy 2.0+ have __array_namespace__; however they are not |
| 512 | + # yet fully array API compatible. |
| 513 | + from .. import numpy as xp # type: ignore[no-redef] |
| 514 | + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT |
| 515 | + |
| 516 | + # Note: this must happen _after_ the test for np.generic, |
| 517 | + # because np.float64 and np.complex128 are subclasses of float and complex. |
| 518 | + if issubclass(cls, int | float | complex | type(None)): |
| 519 | + return None, _ClsToXPInfo.SCALAR |
| 520 | + |
| 521 | + if _issubclass_fast(cls_, "cupy", "ndarray"): |
| 522 | + if _use_compat: |
| 523 | + _check_api_version(api_version) |
| 524 | + from .. import cupy as xp # type: ignore[no-redef] |
| 525 | + else: |
| 526 | + import cupy as xp # type: ignore[no-redef] |
| 527 | + return xp, None |
| 528 | + |
| 529 | + if _issubclass_fast(cls_, "torch", "Tensor"): |
| 530 | + if _use_compat: |
| 531 | + _check_api_version(api_version) |
| 532 | + from .. import torch as xp # type: ignore[no-redef] |
| 533 | + else: |
| 534 | + import torch as xp # type: ignore[no-redef] |
| 535 | + return xp, None |
| 536 | + |
| 537 | + if _issubclass_fast(cls_, "dask.array", "Array"): |
| 538 | + if _use_compat: |
| 539 | + _check_api_version(api_version) |
| 540 | + from ..dask import array as xp # type: ignore[no-redef] |
| 541 | + else: |
| 542 | + import dask.array as xp # type: ignore[no-redef] |
| 543 | + return xp, None |
| 544 | + |
| 545 | + # Backwards compatibility for jax<0.4.32 |
| 546 | + if _issubclass_fast(cls_, "jax", "Array"): |
| 547 | + return _jax_namespace(api_version, use_compat), None |
| 548 | + |
| 549 | + return None, None |
| 550 | + |
| 551 | + |
| 552 | +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: |
| 553 | + if use_compat: |
| 554 | + raise ValueError("JAX does not have an array-api-compat wrapper") |
| 555 | + import jax.numpy as jnp |
| 556 | + if not hasattr(jnp, "__array_namespace_info__"): |
| 557 | + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. |
| 558 | + # For older JAX versions, it is available via jax.experimental.array_api. |
| 559 | + # jnp.Array objects gain the __array_namespace__ method. |
| 560 | + import jax.experimental.array_api # noqa: F401 |
| 561 | + # Test api_version |
| 562 | + return jnp.empty(0).__array_namespace__(api_version=api_version) |
| 563 | + |
| 564 | + |
484 | 565 | def array_namespace(
|
485 | 566 | *xs: Array | complex | None,
|
486 | 567 | api_version: str | None = None,
|
@@ -549,105 +630,40 @@ def your_function(x, y):
|
549 | 630 | is_pydata_sparse_array
|
550 | 631 |
|
551 | 632 | """
|
552 |
| - if use_compat not in [None, True, False]: |
553 |
| - raise ValueError("use_compat must be None, True, or False") |
554 |
| - |
555 |
| - _use_compat = use_compat in [None, True] |
556 |
| - |
557 | 633 | namespaces: set[Namespace] = set()
|
558 | 634 | for x in xs:
|
559 |
| - if is_numpy_array(x): |
560 |
| - import numpy as np |
561 |
| - |
562 |
| - from .. import numpy as numpy_namespace |
563 |
| - |
564 |
| - if use_compat is True: |
565 |
| - _check_api_version(api_version) |
566 |
| - namespaces.add(numpy_namespace) |
567 |
| - elif use_compat is False: |
568 |
| - namespaces.add(np) |
569 |
| - else: |
570 |
| - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API |
571 |
| - # compatible. |
572 |
| - namespaces.add(numpy_namespace) |
573 |
| - elif is_cupy_array(x): |
574 |
| - if _use_compat: |
575 |
| - _check_api_version(api_version) |
576 |
| - from .. import cupy as cupy_namespace |
577 |
| - |
578 |
| - namespaces.add(cupy_namespace) |
579 |
| - else: |
580 |
| - import cupy as cp |
581 |
| - |
582 |
| - namespaces.add(cp) |
583 |
| - elif is_torch_array(x): |
584 |
| - if _use_compat: |
585 |
| - _check_api_version(api_version) |
586 |
| - from .. import torch as torch_namespace |
587 |
| - |
588 |
| - namespaces.add(torch_namespace) |
589 |
| - else: |
590 |
| - import torch |
591 |
| - |
592 |
| - namespaces.add(torch) |
593 |
| - elif is_dask_array(x): |
594 |
| - if _use_compat: |
595 |
| - _check_api_version(api_version) |
596 |
| - from ..dask import array as dask_namespace |
597 |
| - |
598 |
| - namespaces.add(dask_namespace) |
599 |
| - else: |
600 |
| - import dask.array as da |
601 |
| - |
602 |
| - namespaces.add(da) |
603 |
| - elif is_jax_array(x): |
604 |
| - if use_compat is True: |
605 |
| - _check_api_version(api_version) |
606 |
| - raise ValueError("JAX does not have an array-api-compat wrapper") |
607 |
| - elif use_compat is False: |
608 |
| - import jax.numpy as jnp |
609 |
| - else: |
610 |
| - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. |
611 |
| - # For older JAX versions, it is available via jax.experimental.array_api. |
612 |
| - import jax.numpy |
613 |
| - |
614 |
| - if hasattr(jax.numpy, "__array_api_version__"): |
615 |
| - jnp = jax.numpy |
616 |
| - else: |
617 |
| - import jax.experimental.array_api as jnp # type: ignore[no-redef] |
618 |
| - namespaces.add(jnp) |
619 |
| - elif is_pydata_sparse_array(x): |
620 |
| - if use_compat is True: |
621 |
| - _check_api_version(api_version) |
622 |
| - raise ValueError("`sparse` does not have an array-api-compat wrapper") |
623 |
| - else: |
624 |
| - import sparse |
625 |
| - # `sparse` is already an array namespace. We do not have a wrapper |
626 |
| - # submodule for it. |
627 |
| - namespaces.add(sparse) |
628 |
| - elif hasattr(x, "__array_namespace__"): |
629 |
| - if use_compat is True: |
| 635 | + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) |
| 636 | + if info is _ClsToXPInfo.SCALAR: |
| 637 | + continue |
| 638 | + |
| 639 | + if ( |
| 640 | + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT |
| 641 | + and _is_jax_zero_gradient_array(x) |
| 642 | + ): |
| 643 | + xp = _jax_namespace(api_version, use_compat) |
| 644 | + |
| 645 | + if xp is None: |
| 646 | + get_ns = getattr(x, "__array_namespace__", None) |
| 647 | + if get_ns is None: |
| 648 | + raise TypeError(f"{type(x).__name__} is not a supported array type") |
| 649 | + if use_compat: |
630 | 650 | raise ValueError(
|
631 | 651 | "The given array does not have an array-api-compat wrapper"
|
632 | 652 | )
|
633 |
| - x = cast("SupportsArrayNamespace[Any]", x) |
634 |
| - namespaces.add(x.__array_namespace__(api_version=api_version)) |
635 |
| - elif isinstance(x, int | float | complex) or x is None: |
636 |
| - continue |
637 |
| - else: |
638 |
| - # TODO: Support Python scalars? |
639 |
| - raise TypeError(f"{type(x).__name__} is not a supported array type") |
| 653 | + xp = get_ns(api_version=api_version) |
640 | 654 |
|
641 |
| - if not namespaces: |
642 |
| - raise TypeError("Unrecognized array input") |
| 655 | + namespaces.add(xp) |
643 | 656 |
|
644 |
| - if len(namespaces) != 1: |
| 657 | + try: |
| 658 | + (xp,) = namespaces |
| 659 | + return xp |
| 660 | + except ValueError: |
| 661 | + if not namespaces: |
| 662 | + raise TypeError( |
| 663 | + "array_namespace requires at least one non-scalar array input" |
| 664 | + ) |
645 | 665 | raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
|
646 | 666 |
|
647 |
| - (xp,) = namespaces |
648 |
| - |
649 |
| - return xp |
650 |
| - |
651 | 667 |
|
652 | 668 | # backwards compatibility alias
|
653 | 669 | get_namespace = array_namespace
|
|
0 commit comments