Skip to content

Commit 815436c

Browse files
authored
Update test_mixins.py
1 parent bbeb3ce commit 815436c

File tree

1 file changed

+27
-61
lines changed

1 file changed

+27
-61
lines changed

tests/test_mixins.py

Lines changed: 27 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,44 @@
11
import unittest
22

3-
import numpy
4-
5-
import dpnp as inp
6-
7-
from .helper import get_float_dtypes
3+
from tests.third_party.cupy import testing
84

95

106
class TestMatMul(unittest.TestCase):
11-
def test_matmul(self):
12-
array_data = [1.0, 2.0, 3.0, 4.0]
13-
size = 2
14-
15-
for dtype in get_float_dtypes():
16-
# DPNP
17-
array1 = inp.reshape(
18-
inp.array(array_data, dtype=dtype), (size, size)
19-
)
20-
array2 = inp.reshape(
21-
inp.array(array_data, dtype=dtype), (size, size)
22-
)
23-
result = inp.matmul(array1, array2)
24-
# print(result)
25-
26-
# original
27-
array_1 = numpy.array(array_data, dtype=dtype).reshape((size, size))
28-
array_2 = numpy.array(array_data, dtype=dtype).reshape((size, size))
29-
expected = numpy.matmul(array_1, array_2)
30-
# print(expected)
31-
32-
# passed
33-
numpy.testing.assert_array_equal(expected, result)
34-
# still failed
35-
# self.assertEqual(expected, result)
36-
37-
def test_matmul2(self):
38-
array_data1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
39-
array_data2 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
7+
@testing.for_float_dtypes()
8+
@testing.numpy_cupy_allclose()
9+
def test_matmul(self, xp, dtype):
10+
data = [1.0, 2.0, 3.0, 4.0]
11+
shape = (2, 2)
4012

41-
for dtype in get_float_dtypes():
42-
# DPNP
43-
array1 = inp.reshape(inp.array(array_data1, dtype=dtype), (3, 2))
44-
array2 = inp.reshape(inp.array(array_data2, dtype=dtype), (2, 4))
45-
result = inp.matmul(array1, array2)
46-
# print(result)
13+
a = xp.array(data, dtype=dtype).reshape(shape)
14+
b = xp.array(data, dtype=dtype).reshape(shape)
4715

48-
# original
49-
array_1 = numpy.array(array_data1, dtype=dtype).reshape((3, 2))
50-
array_2 = numpy.array(array_data2, dtype=dtype).reshape((2, 4))
51-
expected = numpy.matmul(array_1, array_2)
52-
# print(expected)
16+
return xp.matmul(a, b)
5317

54-
numpy.testing.assert_array_equal(expected, result)
18+
@testing.for_float_dtypes()
19+
@testing.numpy_cupy_allclose()
20+
def test_matmul2(self, xp, dtype):
21+
data1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
22+
data2 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
5523

56-
def test_matmul3(self):
57-
array_data1 = numpy.full((513, 513), 5)
58-
array_data2 = numpy.full((513, 513), 2)
24+
a = xp.array(data1, dtype=dtype).reshape(3, 2)
25+
b = xp.array(data2, dtype=dtype).reshape(2, 4)
5926

60-
for dtype in get_float_dtypes():
61-
out = numpy.empty((513, 513), dtype=dtype)
27+
return xp.matmul(a, b)
6228

63-
# DPNP
64-
array1 = inp.array(array_data1, dtype=dtype)
65-
array2 = inp.array(array_data2, dtype=dtype)
66-
out1 = inp.array(out, dtype=dtype)
67-
result = inp.matmul(array1, array2, out=out1)
29+
@testing.for_float_dtypes()
30+
@testing.numpy_cupy_allclose()
31+
def test_matmul3(self, xp, dtype):
32+
data1 = xp.full((513, 513), 5)
33+
data2 = xp.full((513, 513), 2)
34+
out = xp.empty((513, 513), dtype=dtype)
6835

69-
# original
70-
array_1 = numpy.array(array_data1, dtype=dtype)
71-
array_2 = numpy.array(array_data2, dtype=dtype)
72-
expected = numpy.matmul(array_1, array_2, out=out)
36+
a = xp.array(data1, dtype=dtype)
37+
b = xp.array(data2, dtype=dtype)
7338

74-
numpy.testing.assert_array_equal(expected, result)
39+
xp.matmul(a, b, out=out)
7540

41+
return out
7642

7743
if __name__ == "__main__":
7844
unittest.main()

0 commit comments

Comments
 (0)