Skip to content

Commit aaf5290

Browse files
authored
Add __array_namespace__ method (#2252)
The PR proposes to implement `__array_namespace__` method of dpnp ndarray, which is required to be compliant with python array API. The method will return dpnp as an array namespace, member functions of which implement data API. The array namespace is assumed to be stored inside dpctl tensor. So dpnp ndarray constructor is updated to explicitly pass `array_namespace=dpnp` into `dpt.usm_ndarray` call. And also to set the namespace through `_set_namespace(dpnp)` every time dpnp ndarray is created from usm_ndarray.
1 parent 303a203 commit aaf5290

File tree

3 files changed

+80
-22
lines changed

3 files changed

+80
-22
lines changed

dpnp/dpnp_array.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
offset=offset,
9595
order=order,
9696
buffer_ctor_kwargs={"queue": sycl_queue_normalized},
97+
array_namespace=dpnp,
9798
)
9899

99100
@property
@@ -201,6 +202,31 @@ def __and__(self, other):
201202
# '__array_ufunc__',
202203
# '__array_wrap__',
203204

205+
def __array_namespace__(self, /, *, api_version=None):
206+
"""
207+
Returns array namespace, member functions of which implement data API.
208+
209+
Parameters
210+
----------
211+
api_version : str, optional
212+
Request namespace compliant with given version of array API. If
213+
``None``, namespace for the most recent supported version is
214+
returned.
215+
Default: ``None``.
216+
217+
Returns
218+
-------
219+
out : any
220+
An object representing the array API namespace. It should have
221+
every top-level function defined in the specification as
222+
an attribute. It may contain other public names as well, but it is
223+
recommended to only include those names that are part of the
224+
specification.
225+
226+
"""
227+
228+
return self._array_obj.__array_namespace__(api_version=api_version)
229+
204230
def __bool__(self):
205231
"""``True`` if self else ``False``."""
206232
return self._array_obj.__bool__()
@@ -327,15 +353,7 @@ def __getitem__(self, key):
327353
key = _get_unwrapped_index_key(key)
328354

329355
item = self._array_obj.__getitem__(key)
330-
if not isinstance(item, dpt.usm_ndarray):
331-
raise RuntimeError(
332-
"Expected dpctl.tensor.usm_ndarray, got {}"
333-
"".format(type(item))
334-
)
335-
336-
res = self.__new__(dpnp_array)
337-
res._array_obj = item
338-
return res
356+
return dpnp_array._create_from_usm_ndarray(item)
339357

340358
# '__getstate__',
341359

@@ -606,6 +624,7 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray):
606624
)
607625
res = dpnp_array.__new__(dpnp_array)
608626
res._array_obj = usm_ary
627+
res._array_obj._set_namespace(dpnp)
609628
return res
610629

611630
def all(self, axis=None, out=None, keepdims=False, *, where=True):
@@ -1749,17 +1768,16 @@ def transpose(self, *axes):
17491768
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
17501769
axes = axes[0]
17511770

1752-
res = self.__new__(dpnp_array)
17531771
if ndim == 2 and axes_len == 0:
1754-
res._array_obj = self._array_obj.T
1772+
usm_res = self._array_obj.T
17551773
else:
17561774
if len(axes) == 0 or axes[0] is None:
17571775
# self.transpose().shape == self.shape[::-1]
17581776
# self.transpose(None).shape == self.shape[::-1]
17591777
axes = tuple((ndim - x - 1) for x in range(ndim))
17601778

1761-
res._array_obj = dpt.permute_dims(self._array_obj, axes)
1762-
return res
1779+
usm_res = dpt.permute_dims(self._array_obj, axes)
1780+
return dpnp_array._create_from_usm_ndarray(usm_res)
17631781

17641782
def var(
17651783
self,

dpnp/dpnp_iface_indexing.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -622,14 +622,8 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
622622
out_strides = a_straides[:-2] + (1,)
623623
out_offset = a_element_offset
624624

625-
return dpnp_array._create_from_usm_ndarray(
626-
dpt.usm_ndarray(
627-
out_shape,
628-
dtype=a.dtype,
629-
buffer=a.get_array(),
630-
strides=out_strides,
631-
offset=out_offset,
632-
)
625+
return dpnp_array(
626+
out_shape, buffer=a, strides=out_strides, offset=out_offset
633627
)
634628

635629

dpnp/tests/test_ndarray.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import dpctl.tensor as dpt
22
import numpy
33
import pytest
4-
from numpy.testing import assert_allclose, assert_array_equal
4+
from numpy.testing import (
5+
assert_allclose,
6+
assert_array_equal,
7+
assert_raises_regex,
8+
)
59

610
import dpnp
711

@@ -104,6 +108,48 @@ def test_flags_writable():
104108
assert not a.imag.flags.writable
105109

106110

111+
class TestArrayNamespace:
112+
def test_basic(self):
113+
a = dpnp.arange(2)
114+
xp = a.__array_namespace__()
115+
assert xp is dpnp
116+
117+
@pytest.mark.parametrize("api_version", [None, "2023.12"])
118+
def test_api_version(self, api_version):
119+
a = dpnp.arange(2)
120+
xp = a.__array_namespace__(api_version=api_version)
121+
assert xp is dpnp
122+
123+
@pytest.mark.parametrize("api_version", ["2021.12", "2022.12", "2024.12"])
124+
def test_unsupported_api_version(self, api_version):
125+
a = dpnp.arange(2)
126+
assert_raises_regex(
127+
ValueError,
128+
"Only 2023.12 is supported",
129+
a.__array_namespace__,
130+
api_version=api_version,
131+
)
132+
133+
@pytest.mark.parametrize(
134+
"api_version",
135+
[
136+
2023,
137+
(2022,),
138+
[
139+
2021,
140+
],
141+
],
142+
)
143+
def test_wrong_api_version(self, api_version):
144+
a = dpnp.arange(2)
145+
assert_raises_regex(
146+
TypeError,
147+
"Expected type str",
148+
a.__array_namespace__,
149+
api_version=api_version,
150+
)
151+
152+
107153
class TestItem:
108154
@pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)])
109155
def test_basic(self, args):

0 commit comments

Comments
 (0)