Skip to content

Commit d559a76

Browse files
committed
add tests
1 parent 279c69a commit d559a76

File tree

4 files changed

+9
-2
lines changed

4 files changed

+9
-2
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,11 @@ def _call_func(src, dst, sycl_queue, depends=None):
122122
if vmi._is_available() and mkl_fn_to_call is not None:
123123
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src, dst):
124124
# call pybind11 extension for unary function from OneMKL VM
125+
print("call vm backend")
125126
return getattr(vmi, mkl_impl_fn)(
126127
sycl_queue, src, dst, depends
127128
)
129+
print("call dpctl backend")
128130
return unary_dp_impl_fn(src, dst, sycl_queue, depends)
129131

130132
super().__init__(
@@ -185,6 +187,7 @@ def __call__(
185187
if dtype is not None:
186188
x_usm = dpt.astype(x_usm, dtype, copy=False)
187189

190+
print("VAHID")
188191
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
189192
res_usm = super().__call__(x_usm, out=out_usm, order=order)
190193

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,9 @@ def _get_result_shape_vecdot(x1, x2, x1_ndim, x2_ndim):
332332

333333
if x1_ndim == 1 and x2_ndim == 1:
334334
result_shape = ()
335-
elif x1_is_1D and not x2_is_1D:
335+
elif x1_ndim == 1:
336336
result_shape = x2_shape[:-1]
337-
elif x2_is_1D and not x1_is_1D:
337+
elif x2_ndim == 1:
338338
result_shape = x1_shape[:-1]
339339
else:
340340
if x1_ndim != x2_ndim:

dpnp/tests/helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def assert_dtype_allclose(
3636
3737
"""
3838

39+
assert dpnp_arr.shape == numpy_arr.shape
40+
3941
list_64bit_types = [numpy.float64, numpy.complex128]
4042
is_inexact = lambda x: hasattr(x, "dtype") and dpnp.issubdtype(
4143
x.dtype, dpnp.inexact

dpnp/tests/test_product.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,8 @@ def setup_method(self):
10001000
((1, 4, 5), (3, 1, 5)),
10011001
((1, 1, 4, 5), (3, 1, 5)),
10021002
((1, 4, 5), (1, 3, 1, 5)),
1003+
((2, 1), (1, 1, 1)),
1004+
((1, 1, 3), (3,)),
10031005
],
10041006
)
10051007
def test_basic(self, dtype, shape1, shape2):

0 commit comments

Comments
 (0)