Skip to content

Commit f86f3e5

Browse files
ev-brbetatim
andcommitted
ENH: rm __array__, add __buffer__
On python 3.12 and above, delegate to numpy's __buffer__ On earlier python's raise an error and ask the user to updgrade. Otherwise, on python 3.11 and below, np.array(array_api_strict_array) becomes a 0D object array. Co-authored-by: Tim Head <[email protected]>
1 parent 9eebf2c commit f86f3e5

File tree

4 files changed

+45
-73
lines changed

4 files changed

+45
-73
lines changed

.github/workflows/array-api-tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: ['3.10', '3.11', '3.12', '3.13']
15-
numpy-version: ['1.26', 'dev']
14+
python-version: ['3.12', '3.13']
15+
numpy-version: ['1.26', '2.2', 'dev']
1616
exclude:
1717
- python-version: '3.13'
1818
numpy-version: '1.26'
@@ -38,7 +38,7 @@ jobs:
3838
if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then
3939
python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy;
4040
else
41-
python -m pip install 'numpy>=1.26,<2.0';
41+
python -m pip install 'numpy=='${{ matrix.numpy-version }};
4242
fi
4343
python -m pip install ${GITHUB_WORKSPACE}/array-api-strict
4444
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt

array_api_strict/_array_object.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ def __hash__(self) -> int:
6767
CPU_DEVICE = Device()
6868
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"))
6969

70-
# See https://github.com/data-apis/array-api-strict/issues/67 and the comment
71-
# on __array__ below.
72-
_allow_array = True
70+
_default = object()
7371

7472

7573
class Array:
@@ -151,40 +149,28 @@ def __repr__(self) -> str:
151149

152150
__str__ = __repr__
153151

154-
# In the future, _allow_array will be set to False, which will disallow
155-
# __array__. This means calling `np.func()` on an array_api_strict array
156-
# will give an error. If we don't explicitly disallow it, NumPy defaults
157-
# to creating an object dtype array, which would lead to confusing error
158-
# messages at best and surprising bugs at worst. The reason for doing this
159-
# is that __array__ is not actually supported by the standard, so it can
160-
# lead to code assuming np.asarray(other_array) would always work in the
161-
# standard.
162-
#
163-
# This was implemented historically for compatibility, and removing it has
152+
# `__array__` was implemented historically for compatibility, and removing it has
164153
# caused issues for some libraries (see
165154
# https://github.com/data-apis/array-api-strict/issues/67).
166-
def __array__(
167-
self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
168-
) -> npt.NDArray[Any]:
169-
# We have to allow this to be internally enabled as there's no other
170-
# easy way to parse a list of Array objects in asarray().
171-
if _allow_array:
172-
if self._device != CPU_DEVICE:
173-
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
174-
# copy keyword is new in 2.0.0; for older versions don't use it
175-
# retry without that keyword.
176-
if np.__version__[0] < '2':
177-
return np.asarray(self._array, dtype=dtype)
178-
elif np.__version__.startswith('2.0.0-dev0'):
179-
# Handle dev version for which we can't know based on version
180-
# number whether or not the copy keyword is supported.
181-
try:
182-
return np.asarray(self._array, dtype=dtype, copy=copy)
183-
except TypeError:
184-
return np.asarray(self._array, dtype=dtype)
185-
else:
186-
return np.asarray(self._array, dtype=dtype, copy=copy)
187-
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")
155+
156+
# Instead of `__array__` we now implement the buffer protocol.
157+
# Note that it makes array-apis-strict requiring python>=3.12
158+
def __buffer__(self, flags):
159+
if self._device != CPU_DEVICE:
160+
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
161+
return self._array.__buffer__(flags)
162+
163+
# We do not define __release_buffer__, per the discussion at
164+
# https://github.com/data-apis/array-api-strict/pull/115#pullrequestreview-2917178729
165+
166+
def __array__(self, *args, **kwds):
167+
# a stub for python < 3.12; otherwise numpy silently produces object arrays
168+
import sys
169+
minor, major = sys.version_info.minor, sys.version_info.major
170+
if major < 3 or minor < 12:
171+
raise TypeError(
172+
"Interoperation with NumPy requires python >= 3.12. Please upgrade."
173+
)
188174

189175
# These are various helper functions to make the array behavior match the
190176
# spec in places where it either deviates from or is more strict than

array_api_strict/_creation_functions.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
from collections.abc import Generator
4-
from contextlib import contextmanager
53
from enum import Enum
64
from typing import TYPE_CHECKING, Literal
75

@@ -26,21 +24,6 @@ class Undef(Enum):
2624
_undef = Undef.UNDEF
2725

2826

29-
@contextmanager
30-
def allow_array() -> Generator[None]:
31-
"""
32-
Temporarily enable Array.__array__. This is needed for np.array to parse
33-
list of lists of Array objects.
34-
"""
35-
from . import _array_object
36-
original_value = _array_object._allow_array
37-
try:
38-
_array_object._allow_array = True
39-
yield
40-
finally:
41-
_array_object._allow_array = original_value
42-
43-
4427
def _check_valid_dtype(dtype: DType | None) -> None:
4528
# Note: Only spelling dtypes as the dtype objects is supported.
4629
if dtype not in (None,) + _all_dtypes:
@@ -123,8 +106,8 @@ def asarray(
123106
# Give a better error message in this case. NumPy would convert this
124107
# to an object array. TODO: This won't handle large integers in lists.
125108
raise OverflowError("Integer out of bounds for array dtypes")
126-
with allow_array():
127-
res = np.array(obj, dtype=_np_dtype, copy=copy)
109+
110+
res = np.array(obj, dtype=_np_dtype, copy=copy)
128111
return Array._new(res, device=device)
129112

130113

array_api_strict/tests/test_array_object.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import operator
23
from builtins import all as all_
34

@@ -526,6 +527,10 @@ def test_array_properties():
526527
assert b.mT.shape == (3, 2)
527528

528529

530+
@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312,
531+
reason="array conversion relies on buffer protocol, and "
532+
"requires python >= 3.12"
533+
)
529534
def test_array_conversion():
530535
# Check that arrays on the CPU device can be converted to NumPy
531536
# but arrays on other devices can't. Note this is testing the logic in
@@ -536,25 +541,23 @@ def test_array_conversion():
536541

537542
for device in ("device1", "device2"):
538543
a = ones((2, 3), device=array_api_strict.Device(device))
539-
with pytest.raises(RuntimeError, match="Can not convert array"):
544+
with pytest.raises((RuntimeError, ValueError)):
540545
np.asarray(a)
541546

542-
def test__array__():
543-
# __array__ should work for now
547+
# __buffer__ should work for now for conversion to numpy
544548
a = ones((2, 3))
545-
np.array(a)
546-
547-
# Test the _allow_array private global flag for disabling it in the
548-
# future.
549-
from .. import _array_object
550-
original_value = _array_object._allow_array
551-
try:
552-
_array_object._allow_array = False
553-
a = ones((2, 3))
554-
with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"):
555-
np.array(a)
556-
finally:
557-
_array_object._allow_array = original_value
549+
na = np.array(a)
550+
assert na.shape == (2, 3)
551+
assert na.dtype == np.float64
552+
553+
@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312,
554+
reason="conversion to numpy errors out unless python >= 3.12"
555+
)
556+
def test_array_conversion_2():
557+
a = ones((2, 3))
558+
with pytest.raises(TypeError):
559+
np.array(a)
560+
558561

559562
def test_allow_newaxis():
560563
a = ones(5)

0 commit comments

Comments
 (0)