Skip to content

Commit 8b5b29b

Browse files
committed
address_comments
1 parent 086dadc commit 8b5b29b

File tree

2 files changed

+20
-75
lines changed

2 files changed

+20
-75
lines changed

dpnp/tests/test_random.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,9 +1067,7 @@ def test_seed(self):
10671067

10681068
class TestPermutationsTestShuffle:
10691069
@pytest.mark.parametrize(
1070-
"dtype",
1071-
[dpnp.float32, dpnp.float64, dpnp.int32, dpnp.int64],
1072-
ids=["float32", "float64", "int32", "int64"],
1070+
"dtype", [dpnp.float32, dpnp.float64, dpnp.int32, dpnp.int64]
10731071
)
10741072
def test_shuffle(self, dtype):
10751073
seed = 28041990
@@ -1086,9 +1084,7 @@ def test_shuffle(self, dtype):
10861084
assert_array_equal(actual_x, desired_x)
10871085

10881086
@pytest.mark.parametrize(
1089-
"dtype",
1090-
[dpnp.float32, dpnp.float64, dpnp.int32, dpnp.int64],
1091-
ids=["float32", "float64", "int32", "int64"],
1087+
"dtype", [dpnp.float32, dpnp.float64, dpnp.int32, dpnp.int64]
10921088
)
10931089
def test_no_miss_numbers(self, dtype):
10941090
seed = 28041990

dpnp/tests/test_random_state.py

Lines changed: 18 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,8 @@ def get_default_floating():
3535

3636

3737
class TestNormal:
38-
@pytest.mark.parametrize(
39-
"dtype",
40-
[dpnp.float32, dpnp.float64, None],
41-
ids=["float32", "float64", "None"],
42-
)
43-
@pytest.mark.parametrize(
44-
"usm_type",
45-
["host", "device", "shared"],
46-
ids=["host", "device", "shared"],
47-
)
38+
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64, None])
39+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
4840
def test_distr(self, dtype, usm_type):
4941
seed = 1234567
5042
sycl_queue = dpctl.SyclQueue()
@@ -98,16 +90,8 @@ def test_distr(self, dtype, usm_type):
9890
# check if compute follows data isn't broken
9991
assert_cfd(dpnp_data, sycl_queue, usm_type)
10092

101-
@pytest.mark.parametrize(
102-
"dtype",
103-
[dpnp.float32, dpnp.float64, None],
104-
ids=["float32", "float64", "None"],
105-
)
106-
@pytest.mark.parametrize(
107-
"usm_type",
108-
["host", "device", "shared"],
109-
ids=["host", "device", "shared"],
110-
)
93+
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64, None])
94+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
11195
def test_scale(self, dtype, usm_type):
11296
mean = 7
11397
rs = RandomState(39567)
@@ -150,14 +134,12 @@ def test_inf_loc(self, loc):
150134

151135
def test_inf_scale(self):
152136
a = RandomState().normal(0, numpy.inf, size=1000)
153-
assert_equal(dpnp.isnan(a).any(), False)
154-
assert_equal(dpnp.isinf(a).all(), True)
137+
assert not dpnp.isnan(a).any()
138+
assert dpnp.isinf(a).all()
155139
assert_equal(a.max(), numpy.inf)
156140
assert_equal(a.min(), -numpy.inf)
157141

158-
@pytest.mark.parametrize(
159-
"loc", [numpy.inf, -numpy.inf], ids=["numpy.inf", "-numpy.inf"]
160-
)
142+
@pytest.mark.parametrize("loc", [numpy.inf, -numpy.inf])
161143
def test_inf_loc_scale(self, loc):
162144
a = RandomState().normal(loc=loc, scale=numpy.inf, size=1000)
163145
assert_equal(dpnp.isnan(a).all(), False)
@@ -270,11 +252,7 @@ def test_invalid_usm_type(self, usm_type):
270252

271253

272254
class TestRand:
273-
@pytest.mark.parametrize(
274-
"usm_type",
275-
["host", "device", "shared"],
276-
ids=["host", "device", "shared"],
277-
)
255+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
278256
def test_distr(self, usm_type):
279257
seed = 28042
280258
sycl_queue = dpctl.SyclQueue()
@@ -359,11 +337,7 @@ class TestRandInt:
359337
[int, dpnp.int32, dpnp.int_],
360338
ids=["int", "dpnp.int32", "dpnp.int_"],
361339
)
362-
@pytest.mark.parametrize(
363-
"usm_type",
364-
["host", "device", "shared"],
365-
ids=["host", "device", "shared"],
366-
)
340+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
367341
def test_distr(self, dtype, usm_type):
368342
seed = 9864
369343
low = 1
@@ -593,11 +567,7 @@ def test_invalid_usm_type(self, usm_type):
593567

594568

595569
class TestRandN:
596-
@pytest.mark.parametrize(
597-
"usm_type",
598-
["host", "device", "shared"],
599-
ids=["host", "device", "shared"],
600-
)
570+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
601571
def test_distr(self, usm_type):
602572
seed = 3649
603573
sycl_queue = dpctl.SyclQueue()
@@ -627,6 +597,7 @@ def test_distr(self, usm_type):
627597

628598
# TODO: discuss with oneMKL: there is a difference between CPU and GPU
629599
# generated samples since 9 digit while precision=15 for float64
600+
# precision = dpnp.finfo(numpy.float64).precision
630601
precision = dpnp.finfo(numpy.float32).precision
631602
assert_array_almost_equal(data, expected, decimal=precision)
632603

@@ -677,9 +648,7 @@ def test_wrong_dims(self):
677648

678649
class TestSeed:
679650
@pytest.mark.parametrize(
680-
"func",
681-
["normal", "standard_normal", "random_sample", "uniform"],
682-
ids=["normal", "standard_normal", "random_sample", "uniform"],
651+
"func", ["normal", "standard_normal", "random_sample", "uniform"]
683652
)
684653
def test_scalar(self, func):
685654
seed = 28041997
@@ -827,11 +796,7 @@ def test_invalid_shape(self, seed):
827796

828797

829798
class TestStandardNormal:
830-
@pytest.mark.parametrize(
831-
"usm_type",
832-
["host", "device", "shared"],
833-
ids=["host", "device", "shared"],
834-
)
799+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
835800
def test_distr(self, usm_type):
836801
seed = 1234567
837802
sycl_queue = dpctl.SyclQueue()
@@ -905,11 +870,7 @@ def test_wrong_dims(self):
905870

906871

907872
class TestRandSample:
908-
@pytest.mark.parametrize(
909-
"usm_type",
910-
["host", "device", "shared"],
911-
ids=["host", "device", "shared"],
912-
)
873+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
913874
def test_distr(self, usm_type):
914875
seed = 12657
915876
sycl_queue = dpctl.SyclQueue()
@@ -981,15 +942,9 @@ class TestUniform:
981942
ids=["(low, high)=[1.23, 10.54]", "(low, high)=[10.54, 1.23]"],
982943
)
983944
@pytest.mark.parametrize(
984-
"dtype",
985-
[dpnp.float32, dpnp.float64, dpnp.int32, None],
986-
ids=["float32", "float64", "int32", "None"],
987-
)
988-
@pytest.mark.parametrize(
989-
"usm_type",
990-
["host", "device", "shared"],
991-
ids=["host", "device", "shared"],
945+
"dtype", [dpnp.float32, dpnp.float64, dpnp.int32, None]
992946
)
947+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
993948
def test_distr(self, bounds, dtype, usm_type):
994949
seed = 28041997
995950
low = bounds[0]
@@ -1043,15 +998,9 @@ def test_distr(self, bounds, dtype, usm_type):
1043998
assert_cfd(actual, sycl_queue, usm_type)
1044999

10451000
@pytest.mark.parametrize(
1046-
"dtype",
1047-
[dpnp.float32, dpnp.float64, dpnp.int32, None],
1048-
ids=["float32", "float64", "int32", "None"],
1049-
)
1050-
@pytest.mark.parametrize(
1051-
"usm_type",
1052-
["host", "device", "shared"],
1053-
ids=["host", "device", "shared"],
1001+
"dtype", [dpnp.float32, dpnp.float64, dpnp.int32, None]
10541002
)
1003+
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
10551004
def test_low_high_equal(self, dtype, usm_type):
10561005
seed = 28045
10571006
low = high = 3.75

0 commit comments

Comments
 (0)