PyTorch compatibility layer #14


Merged Feb 25, 2023 (95 commits)

Commits
f49f42b
Start pytorch compatibility layer
asmeurer Jan 7, 2023
4b95748
Add vendor tests for torch
asmeurer Jan 7, 2023
1ecb7ca
Replace torch expand_dims wrapper with a wrapper around unsqueeze
asmeurer Jan 9, 2023
52b6054
Add torch support to the helper functions
asmeurer Jan 9, 2023
c484dbf
Add max and min wrappers for torch
asmeurer Jan 9, 2023
3023a89
Add a wrapper for torch.prod
asmeurer Jan 9, 2023
a1bbd9b
Add the torch prod wrapper to __all__
asmeurer Jan 9, 2023
dba7fa7
Return a copy from max and min with axis=()
asmeurer Jan 10, 2023
44d91e1
Add a size() helper function
asmeurer Jan 10, 2023
1faea7b
Add any and all torch wrappers and fix some issues with prod
asmeurer Jan 10, 2023
ad4484d
Add astype torch wrapper
asmeurer Jan 11, 2023
db3241d
Cast the input to prod/all/any to tensor
asmeurer Jan 11, 2023
3f0d913
More logical order for some functions
asmeurer Jan 11, 2023
2d25dd2
Add wrappers for two-argument elementwise functions
asmeurer Jan 24, 2023
c4c0cfa
Add bitwise_invert to torch
asmeurer Jan 24, 2023
a1917f8
Add torch wrappers for broadcast_to and can_cast
asmeurer Jan 24, 2023
ae28ce0
Add torch arange wrapper
asmeurer Jan 24, 2023
c3d9334
Add a wrapper for torch.eye
asmeurer Jan 24, 2023
f4d6df1
Add pytorch linspace wrapper
asmeurer Jan 24, 2023
14b6519
Add torch squeeze wrapper
asmeurer Jan 24, 2023
db3f579
Add torch flip and roll wrappers
asmeurer Jan 24, 2023
27b7e8c
Add a torch wrapper for nonzero
asmeurer Jan 24, 2023
e886644
Fix pyflakes warning
asmeurer Jan 24, 2023
fd5d179
Add torch wrapper for where
asmeurer Jan 24, 2023
a5f3253
Add sort wrapper to torch
asmeurer Feb 2, 2023
4a71e63
Pass kwargs through some torch wrappers
asmeurer Feb 2, 2023
157fc1e
Add torch mean(), std(), and var() wrappers
asmeurer Feb 2, 2023
c3efe6a
Add torch sum() and prod() wrappers
asmeurer Feb 3, 2023
ed46247
Add unique_* wrappers to torch
asmeurer Feb 4, 2023
eaf5358
Just raise NotImplementedError in pytorch unique_all()
asmeurer Feb 4, 2023
cd25f47
Fix to_device for pytorch tensors
asmeurer Feb 10, 2023
98ed0b2
Restrict the names imported from torch into the compat submodule
asmeurer Feb 10, 2023
be4b534
Allow torch sum and prod to upcast uint8 to int64
asmeurer Feb 10, 2023
ecda017
Don't unnecessarily flip the axes in flip()
asmeurer Feb 10, 2023
3ccee1b
Comment out dead code in the torch unique_all() wrapper
asmeurer Feb 10, 2023
7699755
Use flatten instead of ravel
asmeurer Feb 10, 2023
d38bad5
Improve some error messages
asmeurer Feb 10, 2023
24c0ea3
Use a better function name and use unsqueeze instead of None indexing
asmeurer Feb 13, 2023
48d1ae1
Add pytorch-xfails.txt (still need to validate)
asmeurer Feb 18, 2023
866647d
Move main namespace linear algebra helpers to _aliases.py
asmeurer Feb 18, 2023
85b71de
Merge branch 'master' into pytorch
asmeurer Feb 18, 2023
c441f33
Fix main namespace linalg functions in numpy and cupy
asmeurer Feb 18, 2023
04eef18
Add main namespace linalg functions to the torch wrapper
asmeurer Feb 18, 2023
453ecb8
Add torch wrapper for matmul
asmeurer Feb 18, 2023
5a3bbbe
Finish torch wrappers for matmul, vecdot, and tensordot
asmeurer Feb 20, 2023
b8bbdc8
Clean up pytorch-xfails file
asmeurer Feb 21, 2023
7d176b9
Update pytorch-xfails.txt
asmeurer Feb 21, 2023
17c0a91
Merge branch 'main' into pytorch
asmeurer Feb 21, 2023
1ffcb15
Install pytorch on CI
asmeurer Feb 21, 2023
0db034d
Make the GitHub Actions workflow reusable so that we can test pytorch
asmeurer Feb 21, 2023
39cbfd4
Fix workflow path
asmeurer Feb 21, 2023
c3d0d8e
Fix variable interpolation syntax
asmeurer Feb 21, 2023
cf21cea
Allow specifying extra pytest args in the test yamls
asmeurer Feb 21, 2023
a95eeb6
Enable verbose output for the torch tests
asmeurer Feb 21, 2023
4737dc0
Revert "Enable verbose output for the torch tests"
asmeurer Feb 21, 2023
ea9c1e2
Skip the torch test that crashes the CI
asmeurer Feb 21, 2023
4904411
Skip another test that crashes on CI
asmeurer Feb 22, 2023
039af59
Disable linalg in the torch CI tests
asmeurer Feb 22, 2023
367c4b6
Add missing torch xfails
asmeurer Feb 22, 2023
f3ee38c
Do a verbose CI run for the pytorch array API tests
asmeurer Feb 22, 2023
c631fa3
Revert "Do a verbose CI run for the pytorch array API tests"
asmeurer Feb 22, 2023
33dacf9
Add some missing torch xfails
asmeurer Feb 22, 2023
82b8def
Do a verbose output run of the torch array API tests (with the correc…
asmeurer Feb 22, 2023
b014c1b
Revert "Do a verbose output run of the torch array API tests (with th…
asmeurer Feb 22, 2023
124d6c3
Add a missing torch xfail
asmeurer Feb 22, 2023
847a9e5
Add a missing torch xfail
asmeurer Feb 22, 2023
0857b86
Skip test_floor_divide, which core dumps on CI
asmeurer Feb 23, 2023
6354cd9
Add a missing torch xfail
asmeurer Feb 23, 2023
0565dee
Update the README
asmeurer Feb 23, 2023
e9b447c
Fix some formatting in the README
asmeurer Feb 23, 2023
bb4d3af
Typo fix
asmeurer Feb 23, 2023
5545635
Update torch reduction functions that don't support multiple axes
asmeurer Feb 23, 2023
a78f733
Add a test skip that crashes on CI
asmeurer Feb 23, 2023
0cefa7b
Add a CHANGELOG for the upcoming 1.1 release
asmeurer Feb 23, 2023
1aceff5
Add more skips for tests that crash on CI
asmeurer Feb 23, 2023
b016d4c
Skip a torch test that crashes CI
asmeurer Feb 23, 2023
1cd43f2
Add a torch xfail
asmeurer Feb 23, 2023
c8a5a70
Add a script to manually run the cupy tests
asmeurer Feb 23, 2023
7261dae
Add a torch skip
asmeurer Feb 23, 2023
795dbea
Merge branch 'pytorch' of github.com:asmeurer/array-api-compat into p…
asmeurer Feb 23, 2023
0368c8f
Use cupy specific skips and xfails
asmeurer Feb 24, 2023
2b456c2
Allow passing pytest args through in test_cupy.sh
asmeurer Feb 24, 2023
046ffd0
Add a shebang to test_cupy.sh
asmeurer Feb 24, 2023
c3eb0d5
Make the hypothesis examples database persistent in test_cupy.sh
asmeurer Feb 24, 2023
111a122
Fix sort() and argsort() with cupy
asmeurer Feb 24, 2023
09b5a6f
Add comments for the rest of the cupy xfails
asmeurer Feb 24, 2023
b780b9a
Merge branch 'pytorch' of github.com:asmeurer/array-api-compat into p…
asmeurer Feb 24, 2023
1908a00
Fix argument quoting in test_cupy.sh
asmeurer Feb 25, 2023
5cbd1c0
Update cupy skips and xfails
asmeurer Feb 25, 2023
6a63d5c
Update cupy xfails
asmeurer Feb 25, 2023
175f195
Update test_cupy.sh to run the vendoring tests
asmeurer Feb 25, 2023
d1c7999
Add a minor CHANGELOG entry
asmeurer Feb 25, 2023
0e923b6
Merge branch 'pytorch' of github.com:asmeurer/array-api-compat into p…
asmeurer Feb 25, 2023
2fb0a0a
Bump the version to 1.1
asmeurer Feb 25, 2023
3470b36
Add a missing torch skip
asmeurer Feb 25, 2023
7 changes: 7 additions & 0 deletions README.md
@@ -90,6 +90,13 @@ the array API:
[Stream](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html)
objects.

- `size(x)`: Equivalent to
[`x.size`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html#array_api.array.size),
i.e., the number of elements in the array. Included because PyTorch's
`Tensor` defines `size` as a method that returns the shape, and this cannot
be wrapped because this compat library doesn't wrap or extend the array
objects.

## Known Differences from the Array API Specification

There are some known differences between this library and the array API
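To make the naming clash concrete, here is a minimal runnable sketch of the `size()` helper the README entry describes; `FakeArray` is a made-up stand-in for an array object, and the real array_api_compat implementation may differ in details:

```python
import math

# Minimal sketch of the size() helper described above.
def size(x):
    # None in the shape marks an unknown (data-dependent) dimension
    if None in x.shape:
        return None
    return math.prod(x.shape)

# Hypothetical stand-in for an array object; only .shape is needed here
class FakeArray:
    def __init__(self, shape):
        self.shape = shape

print(size(FakeArray((2, 3, 4))))  # 24
print(size(FakeArray((2, None))))  # None
```

A free function sidesteps `torch.Tensor.size()`, which is a method returning the shape rather than an element count.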
34 changes: 31 additions & 3 deletions array_api_compat/common/_helpers.py
@@ -8,6 +8,7 @@
from __future__ import annotations

import sys
import math

def _is_numpy_array(x):
    # Avoid importing NumPy if it isn't already
@@ -29,11 +30,24 @@ def _is_cupy_array(x):
    # TODO: Should we reject ndarray subclasses?
    return isinstance(x, (cp.ndarray, cp.generic))

def _is_torch_array(x):
    # Avoid importing torch if it isn't already
    if 'torch' not in sys.modules:
        return False

    import torch

    # TODO: Should we reject Tensor subclasses?
    return isinstance(x, torch.Tensor)

def is_array_api_obj(x):
    """
    Check if x is an array API compatible array object.
    """
    return _is_numpy_array(x) or _is_cupy_array(x) or hasattr(x, '__array_namespace__')
    return _is_numpy_array(x) \
        or _is_cupy_array(x) \
        or _is_torch_array(x) \
        or hasattr(x, '__array_namespace__')

def get_namespace(*xs, _use_compat=True):
    """
@@ -139,6 +153,11 @@ def _cupy_to_device(x, device, /, stream=None):
        prev_stream.use()
    return arr

def _torch_to_device(x, device, /, stream=None):
    if stream is not None:
        raise NotImplementedError
    return x.to(device)

def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, Any]] = None) -> "Array":
    """
    Copy the array from the device on which it currently resides to the specified ``device``.
@@ -169,7 +188,16 @@ def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, An
    elif _is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)
    elif _is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)
    return x.to_device(device, stream=stream)

__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device']
def size(x):
    """
    Return the total number of elements of x
    """
    if None in x.shape:
        return None
    return math.prod(x.shape)

__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device', 'size']
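The `to_device()` dispatch ends with a fallback to the array's own standard method. A toy sketch of that fallback path, with `StandardArray` as a hypothetical stand-in for an array implementing the spec:

```python
# Hypothetical spec-compliant array: exposes the standard to_device()
# method; 'device' is just a string label in this sketch.
class StandardArray:
    def __init__(self, device='cpu'):
        self.device = device

    def to_device(self, device, /, *, stream=None):
        return StandardArray(device)

def to_device(x, device, /, *, stream=None):
    # (library-specific branches for numpy/cupy/torch would go here)
    return x.to_device(device, stream=stream)

moved = to_device(StandardArray(), 'gpu:0')
print(moved.device)  # gpu:0
```

Only the libraries that lack a conforming `to_device` need a special branch; anything else is assumed to follow the array API.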
14 changes: 14 additions & 0 deletions array_api_compat/torch/__init__.py
@@ -0,0 +1,14 @@
from torch import *

# Several names are not included in the above import *
import torch
for n in dir(torch):
    if not n.startswith('_'):
        exec(n + ' = torch.' + n)

# These imports may overwrite names from the import * above.
from ._aliases import *

from ..common._helpers import *

__array_api_version__ = '2021.12'
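The loop above binds the missing names with `exec`; an equivalent sketch of the same re-export trick using `globals()` assignment instead, demonstrated with the stdlib `math` module standing in for torch:

```python
import math

# Re-export every public name from a module into this namespace,
# mirroring the exec() loop in __init__.py above.
for n in dir(math):
    if not n.startswith('_'):
        globals()[n] = getattr(math, n)

print(sqrt(16.0))  # 4.0 -- 'sqrt' is now a top-level name here
```

`getattr`/`globals()` avoids building and executing source strings, though the effect is the same: names that `from torch import *` skips still end up in the compat namespace.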