
Various fixes for the torch wrapper #25


Merged · 33 commits · Mar 10, 2023
Commits (33)
19dc410
Add torch wrappers for ones(), zeros(), and empty()
asmeurer Mar 6, 2023
642bbd1
Allow the xfails and skips files to be specified as inputs
asmeurer Mar 6, 2023
6b6b487
Use package-version as an input instead
asmeurer Mar 6, 2023
80a20e8
Add GitHub Actions tests for NumPy 1.21
asmeurer Mar 6, 2023
c24abe4
Fix workflow names
asmeurer Mar 6, 2023
e414cf4
Try fixing the pytest invocation
asmeurer Mar 6, 2023
592eded
Use a different approach to get a different xfails file for numpy 1.21
asmeurer Mar 7, 2023
2490692
Fix typo
asmeurer Mar 7, 2023
ab3a1c9
Don't run NumPy 1.21 in Python 3.11
asmeurer Mar 7, 2023
68dd2ea
Use better syntax to skip 3.11
asmeurer Mar 7, 2023
912fd11
Set the Python version in an environment variable
asmeurer Mar 7, 2023
8358e1c
Move the Numpy 1.21 skip to the main config file
asmeurer Mar 7, 2023
eaf7b0a
Try to fix workflow syntax
asmeurer Mar 7, 2023
20c56be
Try to fix workflow syntax
asmeurer Mar 7, 2023
78dc3e5
Move the if down to the steps
asmeurer Mar 7, 2023
e7e1fb5
Fix torch std() with integral float correction
asmeurer Mar 7, 2023
30f4fac
Make get_namespace() raise TypeError instead of ValueError
asmeurer Mar 7, 2023
5ff7366
Add some more tests for get_namespace
asmeurer Mar 7, 2023
bb5da85
Add an example to the get_namespace docstring
asmeurer Mar 7, 2023
54f5f64
Fix some compatibility issues with NumPy 1.21
asmeurer Mar 7, 2023
5eba5fe
Update NumPy 1.21 XFAILs
asmeurer Mar 7, 2023
a1bb958
Add notes about minumum supported versions in the README
asmeurer Mar 7, 2023
e722706
Add the api_version keyword to get_namespace
asmeurer Mar 8, 2023
008e1b0
Update NumPy 1.12 XFAILs
asmeurer Mar 8, 2023
6ba689e
Update NumPy 1.21 xfails
asmeurer Mar 8, 2023
f652dd4
Add some missing NumPy 1.21 XFAILs
asmeurer Mar 8, 2023
902fefd
Add a CHANGELOG for a 1.1.1 release
asmeurer Mar 8, 2023
d2f9ba9
Add some missing NumPy 1.21 XFAILs
asmeurer Mar 8, 2023
c992c88
Fix the job name for NumPy 1.21
asmeurer Mar 8, 2023
14a7ca3
Add missing NumPy 1.21 XFAILs
asmeurer Mar 9, 2023
9fb5d98
Add missing NumPy 1.21 xfails
asmeurer Mar 10, 2023
8403874
Add a numpy 1.21 xfail
asmeurer Mar 10, 2023
f85f427
Add a missing NumPy 1.21 xfail
asmeurer Mar 10, 2023
11 changes: 11 additions & 0 deletions .github/workflows/array-api-tests-numpy-1-21.yml
@@ -0,0 +1,11 @@
name: Array API Tests (NumPy 1.21)

on: [push, pull_request]

jobs:
array-api-tests-numpy:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: numpy
package-version: '== 1.21.*'
xfails-file-extra: '-1-21'
4 changes: 2 additions & 2 deletions .github/workflows/array-api-tests-numpy.yml
@@ -1,9 +1,9 @@
name: Array API Tests (NumPy)
name: Array API Tests (NumPy Latest)

on: [push, pull_request]

jobs:
array-api-tests-numpy:
array-api-tests-numpy-1-21:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: numpy
2 changes: 1 addition & 1 deletion .github/workflows/array-api-tests-torch.yml
@@ -1,4 +1,4 @@
name: Array API Tests (PyTorch)
name: Array API Tests (PyTorch Latest)

on: [push, pull_request]

20 changes: 18 additions & 2 deletions .github/workflows/array-api-tests.yml
@@ -6,9 +6,21 @@ on:
package-name:
required: true
type: string
package-version:
required: false
type: string
default: '>= 0'
pytest-extra-args:
required: false
type: string
# This is not how I would prefer to implement this but it's the only way
# that seems possible with GitHub Actions' limited expressions syntax
xfails-file-extra:
required: false
type: string
skips-file-extra:
required: false
type: string


env:
@@ -41,11 +53,15 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
# to put this in the numpy 1.21 config file.
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
run: |
python -m pip install --upgrade pip
python -m pip install ${{ inputs.package-name }}
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}'
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
- name: Run the array API testsuite (${{ inputs.package-name }})
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
env:
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.package-name }}
# This enables the NEP 50 type promotion behavior (without it a lot of
@@ -54,4 +70,4 @@ jobs:
run: |
export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat"
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest ${PYTEST_ARGS} --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-skips.txt array_api_tests/
pytest array_api_tests/ --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}${{ inputs.xfails-file-extra }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}${{ inputs.skips-file-extra}}-skips.txt ${PYTEST_ARGS}
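Note (illustrative, not part of the workflow): the xfails/skips file names and the NumPy 1.21 / Python 3.11 skip are composed purely from the workflow inputs; a rough Python model of that logic, with hypothetical helper names, is:

```python
# Illustrative sketch only; mirrors how the reusable workflow combines its
# inputs. These helpers do not exist in the repository.
def xfails_path(package_name: str, xfails_file_extra: str = "") -> str:
    # e.g. ("numpy", "-1-21") -> "numpy-1-21-xfails.txt"
    return f"{package_name}{xfails_file_extra}-xfails.txt"

def skip_install(python_version: str, package_name: str, package_version: str) -> bool:
    # NumPy 1.21 does not support Python 3.11, so that combination is skipped.
    return (python_version == "3.11"
            and package_name == "numpy"
            and "1.21" in package_version)

assert xfails_path("numpy", "-1-21") == "numpy-1-21-xfails.txt"
assert xfails_path("numpy") == "numpy-xfails.txt"
assert skip_install("3.11", "numpy", "== 1.21.*")
```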
21 changes: 21 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,24 @@
# 1.1.1 (2023-03-08)

## Minor Changes

- The minimum supported NumPy version is now 1.21. Fixed a few issues with
NumPy 1.21 (with `unique_*` and `asarray`), although there are also a few
known issues with this version (see the README).

- Add `api_version` to `get_namespace()`.

- `get_namespace()` now works correctly with `torch` tensors.

- `get_namespace()` now works correctly with `numpy.array_api` arrays.

- `get_namespace()` now raises `TypeError` instead of `ValueError`.

- Fix the `torch.std` wrapper.

- Add `torch` wrappers for `ones`, `empty`, and `zeros` so that `shape` can be
passed as a keyword argument.

# 1.1 (2023-02-24)

## Major Changes
20 changes: 20 additions & 0 deletions README.md
@@ -141,11 +141,29 @@ specification:
50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and
https://github.com/numpy/numpy/issues/22341)

- `asarray()` does not support `copy=False`.

- Functions which are not wrapped may not have the same type annotations
as the spec.

- Functions which are not wrapped may not use positional-only arguments.

The minimum supported NumPy version is 1.21. However, this older version of
NumPy has a few issues:

- `unique_*` will not compare nans as unequal.
- `finfo()` has no `smallest_normal`.
- No `from_dlpack` or `__dlpack__`.
- `argmax()` and `argmin()` do not have `keepdims`.
- `qr()` doesn't support matrix stacks.
- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not
supported even in the latest NumPy).
- Type promotion behavior will be value based for 0-D arrays (and there is no
`NPY_PROMOTION_STATE=weak` to disable this).

If any of these are an issue, it is recommended to bump your minimum NumPy
version.

### PyTorch

- Like NumPy/CuPy, we do not wrap the `torch.Tensor` object. It is missing the
@@ -190,6 +208,8 @@ specification:
- As with NumPy, type annotations and positional-only arguments may not
exactly match the spec for functions that are not wrapped at all.

The minimum supported PyTorch version is 1.13.

## Vendoring

This library supports vendoring as an installation method. To vendor the
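Note (illustrative): to make the README's NumPy 1.21 caveats concrete, here is a small hedged sketch; behavior depends on the installed NumPy version, and the exact output is not guaranteed.

```python
# Demonstrates two of the NumPy 1.21 limitations listed in the README.
# Run against your installed NumPy; 1.21 and newer releases differ.
import numpy as np

# finfo() gained smallest_normal only after 1.21.
print(hasattr(np.finfo(np.float64), "smallest_normal"))

# np.unique() on 1.21 has no equal_nan argument, so NaNs are always collapsed
# into a single value and cannot be treated as unequal, as the spec requires.
print(np.unique(np.array([np.nan, np.nan, 1.0])))
```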
31 changes: 25 additions & 6 deletions array_api_compat/common/_aliases.py
@@ -11,6 +11,7 @@

from typing import NamedTuple
from types import ModuleType
import inspect

from ._helpers import _check_device, _is_numpy_array, get_namespace

@@ -161,13 +162,23 @@ class UniqueInverseResult(NamedTuple):
inverse_indices: ndarray


def _unique_kwargs(xp):
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
# trying to parse version numbers, just check if equal_nan is in the
# signature.
s = inspect.signature(xp.unique)
if 'equal_nan' in s.parameters:
return {'equal_nan': False}
return {}

def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
kwargs = _unique_kwargs(xp)
values, indices, inverse_indices, counts = xp.unique(
x,
return_counts=True,
return_index=True,
return_inverse=True,
equal_nan=False,
**kwargs,
)
# np.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
@@ -181,24 +192,26 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult:


def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
kwargs = _unique_kwargs(xp)
res = xp.unique(
x,
return_counts=True,
return_index=False,
return_inverse=False,
equal_nan=False,
**kwargs
)

return UniqueCountsResult(*res)


def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
kwargs = _unique_kwargs(xp)
values, inverse_indices = xp.unique(
x,
return_counts=False,
return_index=False,
return_inverse=True,
equal_nan=False,
**kwargs,
)
# xp.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
@@ -207,12 +220,13 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:


def unique_values(x: ndarray, /, xp) -> ndarray:
kwargs = _unique_kwargs(xp)
return xp.unique(
x,
return_counts=False,
return_index=False,
return_inverse=False,
equal_nan=False,
**kwargs,
)

def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
@@ -295,8 +309,13 @@ def _asarray(
_check_device(xp, device)
if _is_numpy_array(obj):
import numpy as np
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
COPY_TRUE = (True, np._CopyMode.ALWAYS)
if hasattr(np, '_CopyMode'):
# Not present in older NumPys
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
COPY_TRUE = (True, np._CopyMode.ALWAYS)
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
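Note (illustrative): the `_unique_kwargs` and `_asarray` changes above share one idiom — feature-detect what the installed NumPy supports rather than parsing version strings. A standalone hedged sketch of that idiom (the helper name is hypothetical, not the library's actual code):

```python
# Standalone sketch of the feature-detection idiom used in the diff above.
import inspect
import numpy as np

def supported_kwargs(func, **candidates):
    """Keep only the keyword arguments that `func` actually accepts."""
    params = inspect.signature(func).parameters
    return {k: v for k, v in candidates.items() if k in params}

# On NumPy releases whose unique() has equal_nan, this passes equal_nan=False;
# on older releases (such as 1.21) it passes nothing and NaNs stay collapsed.
kwargs = supported_kwargs(np.unique, equal_nan=False)
values = np.unique(np.array([1.0, np.nan, np.nan]), **kwargs)

# The same spirit applies to np._CopyMode: guard with hasattr() instead of a
# version check, since older NumPy simply lacks the attribute.
copy_false = (False, np._CopyMode.IF_NEEDED) if hasattr(np, "_CopyMode") else (False,)
```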
27 changes: 22 additions & 5 deletions array_api_compat/common/_helpers.py
@@ -49,33 +49,50 @@ def is_array_api_obj(x):
or _is_torch_array(x) \
or hasattr(x, '__array_namespace__')

def get_namespace(*xs, _use_compat=True):
def _check_api_version(api_version):
if api_version is not None and api_version != '2021.12':
raise ValueError("Only the 2021.12 version of the array API specification is currently supported")

def get_namespace(*xs, api_version=None, _use_compat=True):
"""
Get the array API compatible namespace for the arrays `xs`.

`xs` should contain one or more arrays.

Typical usage is

def your_function(x, y):
xp = array_api_compat.get_namespace(x, y)
# Now use xp as the array library namespace
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)

api_version should be the newest version of the spec that you need support
for (currently the compat library wrapped APIs only support v2021.12).
"""
namespaces = set()
for x in xs:
if isinstance(x, (tuple, list)):
namespaces.add(get_namespace(*x, _use_compat=_use_compat))
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__())
namespaces.add(x.__array_namespace__(api_version=api_version))
elif _is_numpy_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import numpy as numpy_namespace
namespaces.add(numpy_namespace)
else:
import numpy as np
namespaces.add(np)
elif _is_cupy_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import cupy as cupy_namespace
namespaces.add(cupy_namespace)
else:
import cupy as cp
namespaces.add(cp)
elif _is_torch_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import torch as torch_namespace
namespaces.add(torch_namespace)
@@ -84,13 +101,13 @@ def get_namespace(*xs, _use_compat=True):
namespaces.add(torch)
else:
# TODO: Support Python scalars?
raise ValueError("The input is not a supported array type")
raise TypeError("The input is not a supported array type")

if not namespaces:
raise ValueError("Unrecognized array input")
raise TypeError("Unrecognized array input")

if len(namespaces) != 1:
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")
raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")

xp, = namespaces

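Note (illustrative): a hedged usage sketch of the `get_namespace()` behavior added above; it assumes `array_api_compat` and NumPy are importable, and the exact error messages are not guaranteed.

```python
import numpy as np
import array_api_compat

x = np.asarray([1.0, 2.0, 3.0])

# api_version is passed through; only 2021.12 is accepted for the wrapped namespaces.
xp = array_api_compat.get_namespace(x, api_version="2021.12")
print(xp.mean(x, axis=0))

try:
    array_api_compat.get_namespace("not an array")
except TypeError as exc:          # previously a ValueError
    print("unsupported input:", exc)

try:
    array_api_compat.get_namespace(x, api_version="2022.12")
except ValueError as exc:         # unsupported spec version
    print("unsupported api_version:", exc)
```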
38 changes: 31 additions & 7 deletions array_api_compat/torch/_aliases.py
@@ -361,8 +361,10 @@ def std(x: array,
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
# implement it here for now.

# if isinstance(correction, float):
# correction = int(correction)
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
raise NotImplementedError("float correction in torch std() is not yet supported")

# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
@@ -372,10 +374,10 @@
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.std(x, tuple(range(x.ndim)), correction=correction, **kwargs)
res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.std(x, axis, correction=correction, keepdims=keepdims, **kwargs)
return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)

def var(x: array,
/,
@@ -519,6 +521,28 @@ def full(shape: Union[int, Tuple[int, ...]],

return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)

# ones, zeros, and empty do not accept shape as a keyword argument
def ones(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
return torch.ones(shape, dtype=dtype, device=device, **kwargs)

def zeros(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
return torch.zeros(shape, dtype=dtype, device=device, **kwargs)

def empty(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
return torch.empty(shape, dtype=dtype, device=device, **kwargs)

# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: array, /, *, axis: int = 0) -> array:
return torch.unsqueeze(x, axis)
@@ -585,7 +609,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
'expand_dims', 'astype', 'broadcast_arrays', 'unique_all',
'unique_counts', 'unique_inverse', 'unique_values',
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot']
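Note (illustrative): a hedged sketch of what the new torch wrappers allow, assuming PyTorch and this package are importable.

```python
from array_api_compat import torch as xp

# torch.ones()/zeros()/empty() name their first parameter `size`, so the
# spec's shape= keyword is rejected by plain torch; the wrappers restore it.
a = xp.ones(shape=(2, 3), dtype=xp.float64)
b = xp.zeros((3,), dtype=xp.float32)
c = xp.empty(shape=(2, 2))

# An integral float correction such as 1.0 is accepted (converted to int);
# a genuinely fractional value such as 0.5 raises NotImplementedError.
s = xp.std(a, correction=1.0)
print(a.shape, s)
```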