Skip to content

Commit c08ea5e

Browse files
Use setup_method to rudece code duplication in test_linalg
1 parent 16c292c commit c08ea5e

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

tests/test_linalg.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from .helper import (
1515
assert_dtype_allclose,
1616
get_all_dtypes,
17-
get_complex_dtypes,
1817
get_float_complex_dtypes,
1918
has_support_aspect64,
2019
is_cpu_device,
@@ -679,6 +678,11 @@ def test_norm3(array, ord, axis):
679678

680679

681680
class TestQr:
681+
# Set numpy.random.seed for test methods to prevent
682+
# random generation of the input singular matrix
683+
def setup_method(self):
684+
numpy.random.seed(76)
685+
682686
# TODO: New packages that fix issue CMPLRLLVM-53771 are only available in internal CI.
683687
# Skip the tests on cpu until these packages are available for the external CI.
684688
# Specifically dpcpp_linux-64>=2024.1.0
@@ -703,7 +707,6 @@ class TestQr:
703707
ids=["r", "raw", "complete", "reduced"],
704708
)
705709
def test_qr(self, dtype, shape, mode):
706-
numpy.random.seed(76)
707710
a = numpy.random.randn(*shape).astype(dtype)
708711
if numpy.issubdtype(dtype, numpy.complexfloating):
709712
a += 1j * numpy.random.randn(*shape)
@@ -776,7 +779,6 @@ def test_qr_empty(self, dtype, shape, mode):
776779
ids=["r", "raw", "complete", "reduced"],
777780
)
778781
def test_qr_strides(self, mode):
779-
numpy.random.seed(76)
780782
a = numpy.random.randn(5, 5)
781783
ia = inp.array(a)
782784

@@ -1037,6 +1039,11 @@ def test_slogdet_errors(self):
10371039

10381040

10391041
class TestSvd:
1042+
# Set numpy.random.seed for test methods to prevent
1043+
# random generation of the input singular matrix
1044+
def setup_method(self):
1045+
numpy.random.seed(76)
1046+
10401047
def get_tol(self, dtype):
10411048
tol = 1e-06
10421049
if dtype in (inp.float32, inp.complex64):
@@ -1134,7 +1141,6 @@ def test_svd(self, dtype, shape):
11341141
ids=["(2, 2)", "(16, 16)"],
11351142
)
11361143
def test_svd_hermitian(self, dtype, compute_vt, shape):
1137-
numpy.random.seed(76)
11381144
a = numpy.random.randn(*shape).astype(dtype)
11391145
if numpy.issubdtype(dtype, numpy.complexfloating):
11401146
a += 1j * numpy.random.randn(*shape)
@@ -1177,6 +1183,11 @@ def test_svd_errors(self):
11771183

11781184

11791185
class TestPinv:
1186+
# Set numpy.random.seed for test methods to prevent
1187+
# random generation of the input singular matrix
1188+
def setup_method(self):
1189+
numpy.random.seed(76)
1190+
11801191
def get_tol(self, dtype):
11811192
tol = 1e-06
11821193
if dtype in (inp.float32, inp.complex64):
@@ -1212,7 +1223,6 @@ def check_types_shapes(self, dp_B, np_B):
12121223
],
12131224
)
12141225
def test_pinv(self, dtype, shape):
1215-
numpy.random.seed(76)
12161226
a = numpy.random.randn(*shape).astype(dtype)
12171227
if numpy.issubdtype(dtype, numpy.complexfloating):
12181228
a += 1j * numpy.random.randn(*shape)
@@ -1240,7 +1250,6 @@ def test_pinv(self, dtype, shape):
12401250
ids=["(2, 2)", "(16, 16)"],
12411251
)
12421252
def test_pinv_hermitian(self, dtype, shape):
1243-
numpy.random.seed(76)
12441253
a = numpy.random.randn(*shape).astype(dtype)
12451254
if numpy.issubdtype(dtype, numpy.complexfloating):
12461255
a += 1j * numpy.random.randn(*shape)
@@ -1281,7 +1290,6 @@ def test_pinv_empty(self, dtype, shape):
12811290
assert_dtype_allclose(B_dp, B)
12821291

12831292
def test_pinv_strides(self):
1284-
numpy.random.seed(76)
12851293
a = numpy.random.randn(5, 5)
12861294
a_dp = inp.array(a)
12871295

0 commit comments

Comments
 (0)