21
21
_def_device = dpctl .SyclQueue ().sycl_device
22
22
_def_dev_has_fp64 = _def_device .has_aspect_fp64
23
23
24
+ list_of_usm_types = ["host" , "device" , "shared" ]
25
+
24
26
25
27
def assert_cfd (data , exp_sycl_queue , exp_usm_type = None ):
26
28
assert exp_sycl_queue == data .sycl_queue
@@ -36,7 +38,7 @@ def get_default_floating():
36
38
37
39
class TestNormal :
38
40
@pytest .mark .parametrize ("dtype" , [dpnp .float32 , dpnp .float64 , None ])
39
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
41
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
40
42
def test_distr (self , dtype , usm_type ):
41
43
seed = 1234567
42
44
sycl_queue = dpctl .SyclQueue ()
@@ -91,9 +93,9 @@ def test_distr(self, dtype, usm_type):
91
93
assert_cfd (dpnp_data , sycl_queue , usm_type )
92
94
93
95
@pytest .mark .parametrize ("dtype" , [dpnp .float32 , dpnp .float64 , None ])
94
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
96
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
95
97
def test_scale (self , dtype , usm_type ):
96
- mean = 7
98
+ mean = 7.0
97
99
rs = RandomState (39567 )
98
100
func = lambda scale : rs .normal (
99
101
loc = mean , scale = scale , dtype = dtype , usm_type = usm_type
@@ -127,10 +129,8 @@ def test_scale(self, dtype, usm_type):
127
129
],
128
130
)
129
131
def test_inf_loc (self , loc ):
130
- assert_equal (
131
- RandomState (6531 ).normal (loc = loc , scale = 1 , size = 1000 ),
132
- get_default_floating ()(loc ),
133
- )
132
+ a = RandomState (6531 ).normal (loc = loc , scale = 1 , size = 1000 )
133
+ assert_equal (a , get_default_floating ()(loc ))
134
134
135
135
def test_inf_scale (self ):
136
136
a = RandomState ().normal (0 , numpy .inf , size = 1000 )
@@ -142,7 +142,7 @@ def test_inf_scale(self):
142
142
@pytest .mark .parametrize ("loc" , [numpy .inf , - numpy .inf ])
143
143
def test_inf_loc_scale (self , loc ):
144
144
a = RandomState ().normal (loc = loc , scale = numpy .inf , size = 1000 )
145
- assert_equal ( dpnp .isnan (a ).all (), False )
145
+ assert not dpnp .isnan (a ).all ()
146
146
assert_equal (dpnp .nanmin (a ), loc )
147
147
assert_equal (dpnp .nanmax (a ), loc )
148
148
@@ -252,7 +252,7 @@ def test_invalid_usm_type(self, usm_type):
252
252
253
253
254
254
class TestRand :
255
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
255
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
256
256
def test_distr (self , usm_type ):
257
257
seed = 28042
258
258
sycl_queue = dpctl .SyclQueue ()
@@ -337,7 +337,7 @@ class TestRandInt:
337
337
[int , dpnp .int32 , dpnp .int_ ],
338
338
ids = ["int" , "dpnp.int32" , "dpnp.int_" ],
339
339
)
340
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
340
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
341
341
def test_distr (self , dtype , usm_type ):
342
342
seed = 9864
343
343
low = 1
@@ -419,7 +419,7 @@ def test_negative_bounds(self):
419
419
def test_negative_interval (self ):
420
420
rs = RandomState (3567 )
421
421
422
- assert_equal ( - 5 <= rs .randint (- 5 , - 1 ) < - 1 , True )
422
+ assert - 5 <= rs .randint (- 5 , - 1 ) < - 1
423
423
424
424
x = rs .randint (- 7 , - 1 , 5 )
425
425
assert_equal (- 7 <= x , True )
@@ -486,8 +486,8 @@ def test_full_range(self):
486
486
def test_in_bounds_fuzz (self ):
487
487
for high in [4 , 8 , 16 ]:
488
488
vals = RandomState ().randint (2 , high , size = 2 ** 16 )
489
- assert_equal ( vals .max () < high , True )
490
- assert_equal ( vals .min () >= 2 , True )
489
+ assert vals .max () < high
490
+ assert vals .min () >= 2
491
491
492
492
@pytest .mark .parametrize (
493
493
"zero_size" ,
@@ -567,7 +567,7 @@ def test_invalid_usm_type(self, usm_type):
567
567
568
568
569
569
class TestRandN :
570
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
570
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
571
571
def test_distr (self , usm_type ):
572
572
seed = 3649
573
573
sycl_queue = dpctl .SyclQueue ()
@@ -796,7 +796,7 @@ def test_invalid_shape(self, seed):
796
796
797
797
798
798
class TestStandardNormal :
799
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
799
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
800
800
def test_distr (self , usm_type ):
801
801
seed = 1234567
802
802
sycl_queue = dpctl .SyclQueue ()
@@ -870,7 +870,7 @@ def test_wrong_dims(self):
870
870
871
871
872
872
class TestRandSample :
873
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
873
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
874
874
def test_distr (self , usm_type ):
875
875
seed = 12657
876
876
sycl_queue = dpctl .SyclQueue ()
@@ -944,7 +944,7 @@ class TestUniform:
944
944
@pytest .mark .parametrize (
945
945
"dtype" , [dpnp .float32 , dpnp .float64 , dpnp .int32 , None ]
946
946
)
947
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
947
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
948
948
def test_distr (self , bounds , dtype , usm_type ):
949
949
seed = 28041997
950
950
low = bounds [0 ]
@@ -1000,7 +1000,7 @@ def test_distr(self, bounds, dtype, usm_type):
1000
1000
@pytest .mark .parametrize (
1001
1001
"dtype" , [dpnp .float32 , dpnp .float64 , dpnp .int32 , None ]
1002
1002
)
1003
- @pytest .mark .parametrize ("usm_type" , [ "host" , "device" , "shared" ] )
1003
+ @pytest .mark .parametrize ("usm_type" , list_of_usm_types )
1004
1004
def test_low_high_equal (self , dtype , usm_type ):
1005
1005
seed = 28045
1006
1006
low = high = 3.75
0 commit comments