@@ -782,12 +782,6 @@ def test_tensordot_axes_errors():
782
782
with pytest .raises (ValueError ):
783
783
dpt .tensordot (m1 , m2 , axes = - 1 )
784
784
785
- with pytest .raises (ValueError ):
786
- dpt .tensordot (m1 , m2 , axes = ((- 1 ,), (1 ,)))
787
-
788
- with pytest .raises (ValueError ):
789
- dpt .tensordot (m1 , m2 , axes = ((1 ,), (- 1 ,)))
790
-
791
785
792
786
@pytest .mark .parametrize ("dtype" , _numeric_types )
793
787
def test_vecdot_1d (dtype ):
@@ -834,7 +828,7 @@ def test_vecdot_axis(dtype):
834
828
835
829
v2 = dpt .ones ((m1 , n , m2 ), dtype = dtype )
836
830
837
- r = dpt .vecdot (v1 , v2 , axis = 1 )
831
+ r = dpt .vecdot (v1 , v2 , axis = - 2 )
838
832
839
833
assert r .shape == (
840
834
m1 ,
@@ -864,7 +858,7 @@ def test_vecdot_strided(dtype):
864
858
:, :n , ::- 1
865
859
]
866
860
867
- r = dpt .vecdot (v1 , v2 , axis = 1 )
861
+ r = dpt .vecdot (v1 , v2 , axis = - 2 )
868
862
869
863
ref = sum (
870
864
el1 * el2
@@ -903,6 +897,9 @@ def test_vector_arg_validation():
903
897
with pytest .raises (ValueError ):
904
898
dpt .vecdot (v1 , v2 , axis = 2 )
905
899
900
+ with pytest .raises (ValueError ):
901
+ dpt .vecdot (v1 , v2 , axis = - 2 )
902
+
906
903
q = dpctl .SyclQueue (
907
904
v2 .sycl_context , v2 .sycl_device , property = "enable_profiling"
908
905
)
0 commit comments