41
41
]
42
42
43
43
44
+ def _map_int_to_type (n , dt ):
45
+ assert isinstance (n , int )
46
+ assert n > 0
47
+ if dt == dpt .int8 :
48
+ return ((n + 128 ) % 256 ) - 128
49
+ elif dt == dpt .uint8 :
50
+ return n % 256
51
+ elif dt == dpt .int16 :
52
+ return ((n + 32768 ) % 65536 ) - 32768
53
+ elif dt == dpt .uint16 :
54
+ return n % 65536
55
+ return n
56
+
57
+
44
58
def test_matrix_transpose ():
45
59
get_queue_or_skip ()
46
60
@@ -702,8 +716,8 @@ def test_vecdot_1d(dtype):
702
716
v2 = dpt .ones (n , dtype = dtype )
703
717
704
718
r = dpt .vecdot (v1 , v2 )
705
-
706
- assert r == n
719
+ expected_value = _map_int_to_type ( n , r . dtype )
720
+ assert r == expected_value
707
721
708
722
709
723
@pytest .mark .parametrize ("dtype" , _numeric_types )
@@ -722,7 +736,8 @@ def test_vecdot_3d(dtype):
722
736
m1 ,
723
737
m2 ,
724
738
)
725
- assert dpt .all (r == n )
739
+ expected_value = _map_int_to_type (n , r .dtype )
740
+ assert dpt .all (r == expected_value )
726
741
727
742
728
743
@pytest .mark .parametrize ("dtype" , _numeric_types )
@@ -741,7 +756,8 @@ def test_vecdot_axis(dtype):
741
756
m1 ,
742
757
m2 ,
743
758
)
744
- assert dpt .all (r == n )
759
+ expected_value = _map_int_to_type (n , r .dtype )
760
+ assert dpt .all (r == expected_value )
745
761
746
762
747
763
@pytest .mark .parametrize ("dtype" , _numeric_types )
@@ -775,6 +791,7 @@ def test_vecdot_strided(dtype):
775
791
m1 ,
776
792
m2 ,
777
793
)
794
+ ref = _map_int_to_type (ref , r .dtype )
778
795
assert dpt .all (r == ref )
779
796
780
797
0 commit comments