Skip to content

Commit 2d65f39

Browse files
committed
Merge branch 'main' into 2022.12
2 parents d26a847 + 5eea1c7 commit 2d65f39

File tree

5 files changed

+45
-9
lines changed

5 files changed

+45
-9
lines changed

.github/workflows/array-api-tests.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,9 @@ jobs:
4141
- name: Checkout array-api-tests
4242
uses: actions/checkout@v3
4343
with:
44-
# repository: data-apis/array-api-tests
44+
repository: data-apis/array-api-tests
4545
submodules: 'true'
4646
path: array-api-tests
47-
48-
repository: data-apis/array-api-tests
49-
ref: master
5047
- name: Set up Python ${{ matrix.python-version }}
5148
uses: actions/setup-python@v1
5249
with:

array_api_compat/common/_aliases.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def _asarray(
331331

332332
return xp.asarray(obj, dtype=dtype, **kwargs)
333333

334-
# xp.reshape calls the keyword argument 'newshape' instead of 'shape'
334+
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
335335
def reshape(x: ndarray,
336336
/,
337337
shape: Tuple[int, ...],
@@ -340,8 +340,9 @@ def reshape(x: ndarray,
340340
if copy is True:
341341
x = x.copy()
342342
elif copy is False:
343-
x.shape = shape
344-
return x
343+
y = x.view()
344+
y.shape = shape
345+
return y
345346
return xp.reshape(x, shape, **kwargs)
346347

347348
# The descending keyword is new in sort and argsort, and 'kind' replaced with

array_api_compat/common/_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ def _cupy_to_device(x, device, /, stream=None):
154154

155155
if device == x.device:
156156
return x
157+
elif device == "cpu":
158+
# allowing us to use `to_device(x, "cpu")`
159+
# is useful for portable test swapping between
160+
# host and device backends
161+
return x.get()
157162
elif not isinstance(device, _Device):
158163
raise ValueError(f"Unsupported device {device!r}")
159164
else:

array_api_compat/torch/_aliases.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,16 @@ def where(condition: array, x1: array, x2: array, /) -> array:
478478
x1, x2 = _fix_promotion(x1, x2)
479479
return torch.where(condition, x1, x2)
480480

481+
# torch.reshape doesn't have the copy keyword
482+
def reshape(x: array,
483+
/,
484+
shape: Tuple[int, ...],
485+
copy: Optional[bool] = None,
486+
**kwargs) -> array:
487+
if copy is not None:
488+
raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
489+
return torch.reshape(x, shape, **kwargs)
490+
481491
# torch.arange doesn't support returning empty arrays
482492
# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
483493
# keyword argument combinations
@@ -680,8 +690,8 @@ def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array:
680690
'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal',
681691
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
682692
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
683-
'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip',
684-
'roll', 'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
693+
'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', ''flip', 'roll',
694+
'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
685695
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
686696
'broadcast_arrays', 'unique_all', 'unique_counts',
687697
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',

tests/test_common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from ._helpers import import_
2+
from array_api_compat import to_device, device
3+
4+
import pytest
5+
import numpy as np
6+
from numpy.testing import assert_allclose
7+
8+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
9+
def test_to_device_host(library):
10+
# different libraries have different semantics
11+
# for DtoH transfers; ensure that we support a portable
12+
# shim for common array libs
13+
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
14+
xp = import_('array_api_compat.' + library)
15+
expected = np.array([1, 2, 3])
16+
x = xp.asarray([1, 2, 3])
17+
x = to_device(x, "cpu")
18+
# torch will return a genuine Device object, but
19+
# the other libs will do something different with
20+
# a `device(x)` query; however, what's really important
21+
# here is that we can test portably after calling
22+
# to_device(x, "cpu") to return to host
23+
assert_allclose(x, expected)

0 commit comments

Comments
 (0)