Skip to content

Commit 348c3c3

Browse files
committed
Add tests
1 parent 133cfd1 commit 348c3c3

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

dpnp/tests/test_ndarray.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import dpctl.tensor as dpt
22
import numpy
33
import pytest
4-
from numpy.testing import assert_allclose, assert_array_equal
4+
from numpy.testing import (
5+
assert_allclose,
6+
assert_array_equal,
7+
assert_raises_regex,
8+
)
59

610
import dpnp
711

@@ -104,6 +108,48 @@ def test_flags_writable():
104108
assert not a.imag.flags.writable
105109

106110

111+
class TestArrayNamespace:
112+
def test_basic(self):
113+
a = dpnp.arange(2)
114+
xp = a.__array_namespace__()
115+
assert xp is dpnp
116+
117+
@pytest.mark.parametrize("api_version", [None, "2023.12"])
118+
def test_api_version(self, api_version):
119+
a = dpnp.arange(2)
120+
xp = a.__array_namespace__(api_version=api_version)
121+
assert xp is dpnp
122+
123+
@pytest.mark.parametrize("api_version", ["2021.12", "2022.12", "2024.12"])
124+
def test_unsupported_api_version(self, api_version):
125+
a = dpnp.arange(2)
126+
assert_raises_regex(
127+
ValueError,
128+
"Only 2023.12 is supported",
129+
a.__array_namespace__,
130+
api_version=api_version,
131+
)
132+
133+
@pytest.mark.parametrize(
134+
"api_version",
135+
[
136+
2023,
137+
(2022,),
138+
[
139+
2021,
140+
],
141+
],
142+
)
143+
def test_wrong_api_version(self, api_version):
144+
a = dpnp.arange(2)
145+
assert_raises_regex(
146+
TypeError,
147+
"Expected type str",
148+
a.__array_namespace__,
149+
api_version=api_version,
150+
)
151+
152+
107153
class TestItem:
108154
@pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)])
109155
def test_basic(self, args):

0 commit comments

Comments
 (0)