Skip to content

Commit 1015fcf

Browse files
Merge pull request #1519 from IntelPython/resolve-numpy-warning-tripped-in-linalg-test
Explicitly project ref-values to range of short integral types
2 parents e50c303 + 56030dd commit 1015fcf

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@
4141
]
4242

4343

44+
def _map_int_to_type(n, dt):
45+
assert isinstance(n, int)
46+
assert n > 0
47+
if dt == dpt.int8:
48+
return ((n + 128) % 256) - 128
49+
elif dt == dpt.uint8:
50+
return n % 256
51+
elif dt == dpt.int16:
52+
return ((n + 32768) % 65536) - 32768
53+
elif dt == dpt.uint16:
54+
return n % 65536
55+
return n
56+
57+
4458
def test_matrix_transpose():
4559
get_queue_or_skip()
4660

@@ -702,8 +716,8 @@ def test_vecdot_1d(dtype):
702716
v2 = dpt.ones(n, dtype=dtype)
703717

704718
r = dpt.vecdot(v1, v2)
705-
706-
assert r == n
719+
expected_value = _map_int_to_type(n, r.dtype)
720+
assert r == expected_value
707721

708722

709723
@pytest.mark.parametrize("dtype", _numeric_types)
@@ -722,7 +736,8 @@ def test_vecdot_3d(dtype):
722736
m1,
723737
m2,
724738
)
725-
assert dpt.all(r == n)
739+
expected_value = _map_int_to_type(n, r.dtype)
740+
assert dpt.all(r == expected_value)
726741

727742

728743
@pytest.mark.parametrize("dtype", _numeric_types)
@@ -741,7 +756,8 @@ def test_vecdot_axis(dtype):
741756
m1,
742757
m2,
743758
)
744-
assert dpt.all(r == n)
759+
expected_value = _map_int_to_type(n, r.dtype)
760+
assert dpt.all(r == expected_value)
745761

746762

747763
@pytest.mark.parametrize("dtype", _numeric_types)
@@ -775,6 +791,7 @@ def test_vecdot_strided(dtype):
775791
m1,
776792
m2,
777793
)
794+
ref = _map_int_to_type(ref, r.dtype)
778795
assert dpt.all(r == ref)
779796

780797

0 commit comments

Comments
 (0)