5
5
import numpy
6
6
7
7
from dpnp .random import RandomState
8
- from numpy .testing import (assert_allclose , assert_raises , assert_array_almost_equal )
8
+ from numpy .testing import (assert_allclose , assert_raises , assert_array_equal , assert_array_almost_equal )
9
9
10
10
11
11
class TestSeed :
@@ -87,14 +87,28 @@ class TestUniform:
87
87
@pytest .mark .parametrize ("usm_type" ,
88
88
["host" , "device" , "shared" ],
89
89
ids = ['host' , 'device' , 'shared' ])
90
- def test_uniform (self , dtype , usm_type ):
90
+ def test_uniform_float (self , dtype , usm_type ):
91
91
seed = 28041997
92
92
actual = dpnp .asnumpy (RandomState (seed ).uniform (low = 1.23 , high = 10.54 , size = (3 , 2 ), dtype = dtype , usm_type = usm_type ))
93
93
desired = numpy .array ([[3.700744485249743 , 8.390019132522866 ],
94
94
[2.60340195777826 , 4.473366308724508 ],
95
95
[1.773701806552708 , 4.193498786306009 ]])
96
96
assert_array_almost_equal (actual , desired , decimal = 6 )
97
97
98
+ @pytest .mark .parametrize ("dtype" ,
99
+ [dpnp .int32 , numpy .int32 , numpy .intc ],
100
+ ids = ['dpnp.int32' , 'numpy.int32' , 'numpy.intc' ])
101
+ @pytest .mark .parametrize ("usm_type" ,
102
+ ["host" , "device" , "shared" ],
103
+ ids = ['host' , 'device' , 'shared' ])
104
+ def test_uniform_int (self , dtype , usm_type ):
105
+ seed = 28041997
106
+ actual = dpnp .asnumpy (RandomState (seed ).uniform (low = 1.23 , high = 10.54 , size = (3 , 2 ), dtype = dtype , usm_type = usm_type ))
107
+ desired = numpy .array ([[3 , 8 ],
108
+ [2 , 4 ],
109
+ [1 , 4 ]])
110
+ assert_array_equal (actual , desired )
111
+
98
112
@pytest .mark .parametrize ("high" ,
99
113
[dpnp .array ([3 ]), numpy .array ([3 ])],
100
114
ids = ['dpnp.array([3])' , 'numpy.array([3])' ])
@@ -109,8 +123,8 @@ def test_fallback(self, low, high):
109
123
assert_array_almost_equal (actual , desired , decimal = 15 )
110
124
111
125
@pytest .mark .parametrize ("dtype" ,
112
- [dpnp .float16 , numpy .integer , dpnp .int , dpnp .bool , numpy .int64 , dpnp . int32 ],
113
- ids = ['dpnp.float16' , 'numpy.integer' , 'dpnp.int' , 'dpnp.bool' , 'numpy.int64' , 'dpnp.int32' ])
126
+ [dpnp .float16 , numpy .integer , dpnp .int , dpnp .bool , numpy .int64 ],
127
+ ids = ['dpnp.float16' , 'numpy.integer' , 'dpnp.int' , 'dpnp.bool' , 'numpy.int64' ])
114
128
def test_invalid_dtype (self , dtype ):
115
129
# dtype must be float32 or float64
116
130
assert_raises (TypeError , RandomState ().uniform , dtype = dtype )
0 commit comments