Skip to content

Commit a482d79

Browse files
committed
test_usm_ndarray_linalg changed to reflect vecdot and tensordot changes
1 parent 8239b74 commit a482d79

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -782,12 +782,6 @@ def test_tensordot_axes_errors():
782782
with pytest.raises(ValueError):
783783
dpt.tensordot(m1, m2, axes=-1)
784784

785-
with pytest.raises(ValueError):
786-
dpt.tensordot(m1, m2, axes=((-1,), (1,)))
787-
788-
with pytest.raises(ValueError):
789-
dpt.tensordot(m1, m2, axes=((1,), (-1,)))
790-
791785

792786
@pytest.mark.parametrize("dtype", _numeric_types)
793787
def test_vecdot_1d(dtype):
@@ -834,7 +828,7 @@ def test_vecdot_axis(dtype):
834828

835829
v2 = dpt.ones((m1, n, m2), dtype=dtype)
836830

837-
r = dpt.vecdot(v1, v2, axis=1)
831+
r = dpt.vecdot(v1, v2, axis=-2)
838832

839833
assert r.shape == (
840834
m1,
@@ -864,7 +858,7 @@ def test_vecdot_strided(dtype):
864858
:, :n, ::-1
865859
]
866860

867-
r = dpt.vecdot(v1, v2, axis=1)
861+
r = dpt.vecdot(v1, v2, axis=-2)
868862

869863
ref = sum(
870864
el1 * el2
@@ -903,6 +897,9 @@ def test_vector_arg_validation():
903897
with pytest.raises(ValueError):
904898
dpt.vecdot(v1, v2, axis=2)
905899

900+
with pytest.raises(ValueError):
901+
dpt.vecdot(v1, v2, axis=-2)
902+
906903
q = dpctl.SyclQueue(
907904
v2.sycl_context, v2.sycl_device, property="enable_profiling"
908905
)

0 commit comments

Comments
 (0)