Skip to content

Rename get_namespace to array_namespace #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# 1.1.1 (2023-03-08)

## Major Changes

- Rename `get_namespace()` to `array_namespace()` (`get_namespace()` is
maintained as a backwards compatible alias).

## Minor Changes

- The minimum supported NumPy version is now 1.21. Fixed a few issues with
Expand All @@ -8,11 +13,14 @@

- Add `api_version` to `get_namespace()`.

- `get_namespace()` now works correctly with `torch` tensors.
- `array_namespace()` (*née* `get_namespace()`) now works correctly with
`torch` tensors.

- `get_namespace()` now works correctly with `numpy.array_api` arrays.
- `array_namespace()` (*née* `get_namespace()`) now works correctly with
`numpy.array_api` arrays.

- `get_namespace()` now raises `TypeError` instead of `ValueError`.
- `array_namespace()` (*née* `get_namespace()`) now raises `TypeError` instead
of `ValueError`.

- Fix the `torch.std` wrapper.

Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ later this year.
## Usage

The typical usage of this library will be to get the corresponding array API
compliant namespace from the input arrays using `get_namespace()`, like
compliant namespace from the input arrays using `array_namespace()`, like

```py
def your_function(x, y):
xp = array_api_compat.get_namespace(x, y)
xp = array_api_compat.array_namespace(x, y)
# Now use xp as the array library namespace
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
```
Expand Down Expand Up @@ -88,7 +88,7 @@ part of the specification but which are useful for using the array API:
- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array
object.

- `get_namespace(*xs)`: Get the corresponding array API namespace for the
- `array_namespace(*xs)`: Get the corresponding array API namespace for the
arrays `xs`. For example, if the arrays are NumPy arrays, the returned
namespace will be `array_api_compat.numpy`. Note that this function will
also work for namespaces that aren't supported by this compat library but
Expand Down Expand Up @@ -133,7 +133,7 @@ specification:
don't want to monkeypatch or wrap it. The helper functions `device()` and
`to_device()` are provided to work around these missing methods (see above).
`x.mT` can be replaced with `xp.linalg.matrix_transpose(x)`.
`get_namespace(x)` should be used instead of `x.__array_namespace__`.
`array_namespace(x)` should be used instead of `x.__array_namespace__`.

- Value-based casting for scalars will be in effect unless explicitly disabled
with the environment variable `NPY_PROMOTION_STATE=weak` or
Expand Down Expand Up @@ -168,7 +168,7 @@ version.

- Like NumPy/CuPy, we do not wrap the `torch.Tensor` object. It is missing the
`__array_namespace__` and `to_device` methods, so the corresponding helper
functions `get_namespace()` and `to_device()` in this library should be
functions `array_namespace()` and `to_device()` in this library should be
used instead (see above).

- The `x.size` attribute on `torch.Tensor` is a function that behaves
Expand Down
4 changes: 2 additions & 2 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from types import ModuleType
import inspect

from ._helpers import _check_device, _is_numpy_array, get_namespace
from ._helpers import _check_device, _is_numpy_array, array_namespace

# These functions are modified from the NumPy versions.

Expand Down Expand Up @@ -293,7 +293,7 @@ def _asarray(
"""
if namespace is None:
try:
xp = get_namespace(obj, _use_compat=False)
xp = array_namespace(obj, _use_compat=False)
except ValueError:
# TODO: What about lists of arrays?
raise ValueError("A namespace must be specified for asarray() with non-array input")
Expand Down
10 changes: 6 additions & 4 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _check_api_version(api_version):
if api_version is not None and api_version != '2021.12':
raise ValueError("Only the 2021.12 version of the array API specification is currently supported")

def get_namespace(*xs, api_version=None, _use_compat=True):
def array_namespace(*xs, api_version=None, _use_compat=True):
"""
Get the array API compatible namespace for the arrays `xs`.

Expand All @@ -62,7 +62,7 @@ def get_namespace(*xs, api_version=None, _use_compat=True):
Typical usage is

def your_function(x, y):
xp = array_api_compat.get_namespace(x, y)
xp = array_api_compat.array_namespace(x, y)
# Now use xp as the array library namespace
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)

Expand All @@ -72,7 +72,7 @@ def your_function(x, y):
namespaces = set()
for x in xs:
if isinstance(x, (tuple, list)):
namespaces.add(get_namespace(*x, _use_compat=_use_compat))
namespaces.add(array_namespace(*x, _use_compat=_use_compat))
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
elif _is_numpy_array(x):
Expand Down Expand Up @@ -113,6 +113,8 @@ def your_function(x, y):

return xp

# backwards compatibility alias
get_namespace = array_namespace

def _check_device(xp, device):
if xp == sys.modules.get('numpy'):
Expand Down Expand Up @@ -224,4 +226,4 @@ def size(x):
return None
return math.prod(x.shape)

__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device', 'size']
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']
41 changes: 41 additions & 0 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import array_api_compat
from array_api_compat import array_namespace
import pytest


@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
@pytest.mark.parametrize("api_version", [None, '2021.12'])
def test_array_namespace(library, api_version):
lib = pytest.importorskip(library)

array = lib.asarray([1.0, 2.0, 3.0])
namespace = array_api_compat.array_namespace(array, api_version=api_version)

if 'array_api' in library:
assert namespace == lib
else:
assert namespace == getattr(array_api_compat, library)

def test_array_namespace_multiple():
import numpy as np

x = np.asarray([1, 2])
assert array_namespace(x, x) == array_namespace((x, x)) == \
array_namespace((x, x), x) == array_api_compat.numpy

def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
pytest.raises(TypeError, lambda: array_namespace())

import numpy as np
import torch
x = np.asarray([1, 2])
y = torch.asarray([1, 2])

pytest.raises(TypeError, lambda: array_namespace(x, y))

pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12'))

def test_get_namespace():
# Backwards compatible wrapper
assert array_api_compat.get_namespace is array_api_compat.array_namespace
37 changes: 0 additions & 37 deletions tests/test_get_namespace.py

This file was deleted.