Skip to content

Commit 686194f

Browse files
authored
Merge e44be1e into 807179a
2 parents 807179a + e44be1e commit 686194f

File tree

6 files changed

+52
-11
lines changed

6 files changed

+52
-11
lines changed

dpnp/dpnp_array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ def __and__(self, other):
192192
# '__array_prepare__',
193193
# '__array_priority__',
194194
# '__array_struct__',
195-
# '__array_ufunc__',
195+
196+
__array_ufunc__ = None
197+
196198
# '__array_wrap__',
197199

198200
def __array_namespace__(self, /, *, api_version=None):

dpnp/tests/helper.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,20 @@ def get_all_dtypes(
161161
return dtypes
162162

163163

164+
def get_array(xp, a):
165+
"""
166+
Cast input array `a` to a type supported by `xp` initerface.
167+
168+
Implicit conversion of either DPNP or DPCTL array to a NumPy array is not
169+
allowed. Input array has to be explicitly casted with `asnumpy` function.
170+
171+
"""
172+
173+
if xp is numpy and dpnp.is_supported_array_type(a):
174+
return dpnp.asnumpy(a)
175+
return a
176+
177+
164178
def generate_random_numpy_array(
165179
shape,
166180
dtype=None,

dpnp/tests/test_arraycreation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .helper import (
1818
assert_dtype_allclose,
1919
get_all_dtypes,
20+
get_array,
2021
)
2122
from .third_party.cupy import testing
2223

@@ -768,7 +769,7 @@ def test_space_numpy_dtype(func, start_dtype, stop_dtype):
768769
],
769770
)
770771
def test_linspace_arrays(start, stop):
771-
func = lambda xp: xp.linspace(start, stop, 10)
772+
func = lambda xp: xp.linspace(get_array(xp, start), get_array(xp, stop), 10)
772773
assert func(numpy).shape == func(dpnp).shape
773774

774775

dpnp/tests/test_linalg.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,7 @@ def test_matrix_rank(self, data, dtype):
19351935

19361936
np_rank = numpy.linalg.matrix_rank(a)
19371937
dp_rank = dpnp.linalg.matrix_rank(a_dp)
1938-
assert np_rank == dp_rank
1938+
assert dp_rank.asnumpy() == np_rank
19391939

19401940
@pytest.mark.parametrize("dtype", get_all_dtypes())
19411941
@pytest.mark.parametrize(
@@ -1953,7 +1953,7 @@ def test_matrix_rank_hermitian(self, data, dtype):
19531953

19541954
np_rank = numpy.linalg.matrix_rank(a, hermitian=True)
19551955
dp_rank = dpnp.linalg.matrix_rank(a_dp, hermitian=True)
1956-
assert np_rank == dp_rank
1956+
assert dp_rank.asnumpy() == np_rank
19571957

19581958
@pytest.mark.parametrize(
19591959
"high_tol, low_tol",
@@ -1986,15 +1986,15 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
19861986
dp_rank_high_tol = dpnp.linalg.matrix_rank(
19871987
a_dp, hermitian=True, tol=dp_high_tol
19881988
)
1989-
assert np_rank_high_tol == dp_rank_high_tol
1989+
assert dp_rank_high_tol.asnumpy() == np_rank_high_tol
19901990

19911991
np_rank_low_tol = numpy.linalg.matrix_rank(
19921992
a, hermitian=True, tol=low_tol
19931993
)
19941994
dp_rank_low_tol = dpnp.linalg.matrix_rank(
19951995
a_dp, hermitian=True, tol=dp_low_tol
19961996
)
1997-
assert np_rank_low_tol == dp_rank_low_tol
1997+
assert dp_rank_low_tol.asnumpy() == np_rank_low_tol
19981998

19991999
# rtol kwarg was added in numpy 2.0
20002000
@testing.with_requires("numpy>=2.0")
@@ -2807,15 +2807,14 @@ def check_decomposition(
28072807
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
28082808
dpnp_diag_s[..., i, i] = dp_s[..., i]
28092809
reconstructed = dpnp.dot(dp_u, dpnp.dot(dpnp_diag_s, dp_vt))
2810-
# TODO: use assert dpnp.allclose() inside check_decomposition()
2811-
# when it will support complex dtypes
2812-
assert_allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)
2810+
2811+
assert dpnp.allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)
28132812

28142813
assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03)
28152814

28162815
if compute_vt:
28172816
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
2818-
if np_u[..., 0, i] * dp_u[..., 0, i] < 0:
2817+
if np_u[..., 0, i] * dpnp.asnumpy(dp_u[..., 0, i]) < 0:
28192818
np_u[..., :, i] = -np_u[..., :, i]
28202819
np_vt[..., i, :] = -np_vt[..., i, :]
28212820
for i in range(numpy.count_nonzero(np_s > tol)):

dpnp/tests/test_manipulation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .helper import (
1616
assert_dtype_allclose,
1717
get_all_dtypes,
18+
get_array,
1819
get_complex_dtypes,
1920
get_float_complex_dtypes,
2021
get_float_dtypes,
@@ -1232,7 +1233,10 @@ def test_axes(self):
12321233
def test_axes_type(self, axes):
12331234
a = numpy.ones((50, 40, 3))
12341235
ia = dpnp.array(a)
1235-
assert_equal(dpnp.rot90(ia, axes=axes), numpy.rot90(a, axes=axes))
1236+
assert_equal(
1237+
dpnp.rot90(ia, axes=axes),
1238+
numpy.rot90(a, axes=get_array(numpy, axes)),
1239+
)
12361240

12371241
def test_rotation_axes(self):
12381242
a = numpy.arange(8).reshape((2, 2, 2))

dpnp/tests/test_ndarray.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,27 @@ def test_wrong_api_version(self, api_version):
150150
)
151151

152152

153+
class TestArrayUfunc:
154+
def test_add(self):
155+
a = numpy.ones(10)
156+
b = dpnp.ones(10)
157+
msg = "An array must be any of supported type"
158+
159+
with assert_raises_regex(TypeError, msg):
160+
a + b
161+
162+
with assert_raises_regex(TypeError, msg):
163+
b + a
164+
165+
def test_add_inplace(self):
166+
a = numpy.ones(10)
167+
b = dpnp.ones(10)
168+
with assert_raises_regex(
169+
TypeError, "operand 'dpnp_array' does not support ufuncs"
170+
):
171+
a += b
172+
173+
153174
class TestItem:
154175
@pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)])
155176
def test_basic(self, args):

0 commit comments

Comments
 (0)