Skip to content

Commit 0075894

Browse files
committed
use gloabl list for usm_type
1 parent a654957 commit 0075894

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

dpnp/tests/test_random_state.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
_def_device = dpctl.SyclQueue().sycl_device
2222
_def_dev_has_fp64 = _def_device.has_aspect_fp64
2323

24+
list_of_usm_types = ["host", "device", "shared"]
25+
2426

2527
def assert_cfd(data, exp_sycl_queue, exp_usm_type=None):
2628
assert exp_sycl_queue == data.sycl_queue
@@ -36,7 +38,7 @@ def get_default_floating():
3638

3739
class TestNormal:
3840
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64, None])
39-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
41+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
4042
def test_distr(self, dtype, usm_type):
4143
seed = 1234567
4244
sycl_queue = dpctl.SyclQueue()
@@ -91,9 +93,9 @@ def test_distr(self, dtype, usm_type):
9193
assert_cfd(dpnp_data, sycl_queue, usm_type)
9294

9395
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64, None])
94-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
96+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
9597
def test_scale(self, dtype, usm_type):
96-
mean = 7
98+
mean = 7.0
9799
rs = RandomState(39567)
98100
func = lambda scale: rs.normal(
99101
loc=mean, scale=scale, dtype=dtype, usm_type=usm_type
@@ -127,10 +129,8 @@ def test_scale(self, dtype, usm_type):
127129
],
128130
)
129131
def test_inf_loc(self, loc):
130-
assert_equal(
131-
RandomState(6531).normal(loc=loc, scale=1, size=1000),
132-
get_default_floating()(loc),
133-
)
132+
a = RandomState(6531).normal(loc=loc, scale=1, size=1000)
133+
assert_equal(a, get_default_floating()(loc))
134134

135135
def test_inf_scale(self):
136136
a = RandomState().normal(0, numpy.inf, size=1000)
@@ -142,7 +142,7 @@ def test_inf_scale(self):
142142
@pytest.mark.parametrize("loc", [numpy.inf, -numpy.inf])
143143
def test_inf_loc_scale(self, loc):
144144
a = RandomState().normal(loc=loc, scale=numpy.inf, size=1000)
145-
assert_equal(dpnp.isnan(a).all(), False)
145+
assert not dpnp.isnan(a).all()
146146
assert_equal(dpnp.nanmin(a), loc)
147147
assert_equal(dpnp.nanmax(a), loc)
148148

@@ -252,7 +252,7 @@ def test_invalid_usm_type(self, usm_type):
252252

253253

254254
class TestRand:
255-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
255+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
256256
def test_distr(self, usm_type):
257257
seed = 28042
258258
sycl_queue = dpctl.SyclQueue()
@@ -337,7 +337,7 @@ class TestRandInt:
337337
[int, dpnp.int32, dpnp.int_],
338338
ids=["int", "dpnp.int32", "dpnp.int_"],
339339
)
340-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
340+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
341341
def test_distr(self, dtype, usm_type):
342342
seed = 9864
343343
low = 1
@@ -419,7 +419,7 @@ def test_negative_bounds(self):
419419
def test_negative_interval(self):
420420
rs = RandomState(3567)
421421

422-
assert_equal(-5 <= rs.randint(-5, -1) < -1, True)
422+
assert -5 <= rs.randint(-5, -1) < -1
423423

424424
x = rs.randint(-7, -1, 5)
425425
assert_equal(-7 <= x, True)
@@ -486,8 +486,8 @@ def test_full_range(self):
486486
def test_in_bounds_fuzz(self):
487487
for high in [4, 8, 16]:
488488
vals = RandomState().randint(2, high, size=2**16)
489-
assert_equal(vals.max() < high, True)
490-
assert_equal(vals.min() >= 2, True)
489+
assert vals.max() < high
490+
assert vals.min() >= 2
491491

492492
@pytest.mark.parametrize(
493493
"zero_size",
@@ -567,7 +567,7 @@ def test_invalid_usm_type(self, usm_type):
567567

568568

569569
class TestRandN:
570-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
570+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
571571
def test_distr(self, usm_type):
572572
seed = 3649
573573
sycl_queue = dpctl.SyclQueue()
@@ -796,7 +796,7 @@ def test_invalid_shape(self, seed):
796796

797797

798798
class TestStandardNormal:
799-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
799+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
800800
def test_distr(self, usm_type):
801801
seed = 1234567
802802
sycl_queue = dpctl.SyclQueue()
@@ -870,7 +870,7 @@ def test_wrong_dims(self):
870870

871871

872872
class TestRandSample:
873-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
873+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
874874
def test_distr(self, usm_type):
875875
seed = 12657
876876
sycl_queue = dpctl.SyclQueue()
@@ -944,7 +944,7 @@ class TestUniform:
944944
@pytest.mark.parametrize(
945945
"dtype", [dpnp.float32, dpnp.float64, dpnp.int32, None]
946946
)
947-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
947+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
948948
def test_distr(self, bounds, dtype, usm_type):
949949
seed = 28041997
950950
low = bounds[0]
@@ -1000,7 +1000,7 @@ def test_distr(self, bounds, dtype, usm_type):
10001000
@pytest.mark.parametrize(
10011001
"dtype", [dpnp.float32, dpnp.float64, dpnp.int32, None]
10021002
)
1003-
@pytest.mark.parametrize("usm_type", ["host", "device", "shared"])
1003+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
10041004
def test_low_high_equal(self, dtype, usm_type):
10051005
seed = 28045
10061006
low = high = 3.75

0 commit comments

Comments
 (0)