Skip to content

Add sparse compatibility layer. #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def is_numpy_array(x):
is_torch_array
is_dask_array
is_jax_array
is_sparse_array
"""
# Avoid importing NumPy if it isn't already
if 'numpy' not in sys.modules:
Expand Down Expand Up @@ -79,6 +80,7 @@ def is_cupy_array(x):
is_torch_array
is_dask_array
is_jax_array
is_sparse_array
"""
# Avoid importing NumPy if it isn't already
if 'cupy' not in sys.modules:
Expand All @@ -105,6 +107,7 @@ def is_torch_array(x):
is_cupy_array
is_dask_array
is_jax_array
is_sparse_array
"""
# Avoid importing torch if it isn't already
if 'torch' not in sys.modules:
Expand All @@ -131,6 +134,7 @@ def is_dask_array(x):
is_cupy_array
is_torch_array
is_jax_array
is_sparse_array
"""
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
Expand All @@ -157,6 +161,7 @@ def is_jax_array(x):
is_cupy_array
is_torch_array
is_dask_array
is_sparse_array
"""
# Avoid importing jax if it isn't already
if 'jax' not in sys.modules:
Expand All @@ -166,6 +171,35 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_sparse_array(x) -> bool:
"""
Return True if `x` is a `sparse` array.

This function does not import `sparse` if it has not already been imported
and is therefore cheap to use.


See Also
--------

array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_torch_array
is_dask_array
is_jax_array
"""
# Avoid importing jax if it isn't already
if 'sparse' not in sys.modules:
return False

import sparse

# TODO: Account for other backends.
return isinstance(x, sparse.SparseArray)

def is_array_api_obj(x):
"""
Return True if `x` is an array API compatible array object.
Expand All @@ -185,6 +219,7 @@ def is_array_api_obj(x):
or is_torch_array(x) \
or is_dask_array(x) \
or is_jax_array(x) \
or is_sparse_array(x) \
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
Expand Down Expand Up @@ -253,6 +288,7 @@ def your_function(x, y):
is_torch_array
is_dask_array
is_jax_array
is_sparse_array

"""
if use_compat not in [None, True, False]:
Expand Down Expand Up @@ -312,6 +348,13 @@ def your_function(x, y):
# not have a wrapper submodule for it.
import jax.experimental.array_api as jnp
namespaces.add(jnp)
elif is_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
raise ValueError("`sparse` does not have an array-api-compat wrapper")
else:
import sparse
namespaces.add(sparse)
elif hasattr(x, '__array_namespace__'):
if use_compat is True:
raise ValueError("The given array does not have an array-api-compat wrapper")
Expand Down Expand Up @@ -406,8 +449,23 @@ def device(x: Array, /) -> Device:
return x.device()
else:
return x.device
elif is_sparse_array(x):
# `sparse` will gain `.device`, so check for this first.
x_device = getattr(x, 'device', None)
if x_device is not None:
return x_device
# Everything but DOK has this attr.
try:
inner = x.data
except AttributeError:
return "cpu"
# Return the device of the constituent array
return device(inner)
return x.device

# Prevent shadowing, used below
_device = device

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
import cupy as cp
Expand Down Expand Up @@ -523,6 +581,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# This import adds to_device to x
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
elif is_sparse_array(x) and device == _device(x):
# Perform trivial check to return the same array if
# device is same instead of err-ing.
return x
return x.to_device(device, stream=stream)

def size(x):
Expand All @@ -549,6 +611,7 @@ def size(x):
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"is_sparse_array",
"size",
"to_device",
]
Expand Down
5 changes: 5 additions & 0 deletions array_api_compat/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from sparse import * # noqa: F403
from ..common._aliases import * # noqa: F403
from ..common._helpers import * # noqa: F401,F403

__array_api_version__ = '2022.12'
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ jax[cpu]
numpy
pytest
torch
sparse >=0.15.1