Skip to content

Commit 18c1868

Browse files
committed
Secure support of int32 for RandomState.uniform in tests.
1 parent a789726 commit 18c1868

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

dpnp/random/dpnp_iface_random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def uniform(self, low=0.0, high=1.0, size=None, dtype=numpy.float64, usm_type="d
125125
-----------
126126
Parameters ``low`` and ``high`` are supported as scalar.
127127
Otherwise, :obj:`numpy.random.uniform(low, high, size)` samples are drawn.
128-
Parameter ``dtype`` is supported only for :obj:`dpnp.float32` or :obj:`dpnp.float64`.
128+
Parameter ``dtype`` is supported only for :obj:`dpnp.int32`, :obj:`dpnp.float32` or :obj:`dpnp.float64`.
129129
Output array data type is the same as ``dtype``.
130130
"""
131131

@@ -137,7 +137,7 @@ def uniform(self, low=0.0, high=1.0, size=None, dtype=numpy.float64, usm_type="d
137137
else:
138138
if low > high:
139139
low, high = high, low
140-
if not (dpnp.is_type_supported(dtype) and dtype in {dpnp.float32, dpnp.float64}):
140+
if not (dpnp.is_type_supported(dtype) and dtype in {dpnp.int32, dpnp.float32, dpnp.float64}):
141141
raise TypeError(f"{dtype} is unsupported.")
142142
return self.random_state.uniform(low, high, size, dtype, usm_type).get_pyobj()
143143

tests/test_randomstate.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy
66

77
from dpnp.random import RandomState
8-
from numpy.testing import (assert_allclose, assert_raises, assert_array_almost_equal)
8+
from numpy.testing import (assert_allclose, assert_raises, assert_array_equal, assert_array_almost_equal)
99

1010

1111
class TestSeed:
@@ -87,14 +87,28 @@ class TestUniform:
8787
@pytest.mark.parametrize("usm_type",
8888
["host", "device", "shared"],
8989
ids=['host', 'device', 'shared'])
90-
def test_uniform(self, dtype, usm_type):
90+
def test_uniform_float(self, dtype, usm_type):
9191
seed = 28041997
9292
actual = dpnp.asnumpy(RandomState(seed).uniform(low=1.23, high=10.54, size=(3, 2), dtype=dtype, usm_type=usm_type))
9393
desired = numpy.array([[3.700744485249743, 8.390019132522866],
9494
[2.60340195777826, 4.473366308724508],
9595
[1.773701806552708, 4.193498786306009]])
9696
assert_array_almost_equal(actual, desired, decimal=6)
9797

98+
@pytest.mark.parametrize("dtype",
99+
[dpnp.int32, numpy.int32, numpy.intc],
100+
ids=['dpnp.int32', 'numpy.int32', 'numpy.intc'])
101+
@pytest.mark.parametrize("usm_type",
102+
["host", "device", "shared"],
103+
ids=['host', 'device', 'shared'])
104+
def test_uniform_int(self, dtype, usm_type):
105+
seed = 28041997
106+
actual = dpnp.asnumpy(RandomState(seed).uniform(low=1.23, high=10.54, size=(3, 2), dtype=dtype, usm_type=usm_type))
107+
desired = numpy.array([[3, 8],
108+
[2, 4],
109+
[1, 4]])
110+
assert_array_equal(actual, desired)
111+
98112
@pytest.mark.parametrize("high",
99113
[dpnp.array([3]), numpy.array([3])],
100114
ids=['dpnp.array([3])', 'numpy.array([3])'])
@@ -109,8 +123,8 @@ def test_fallback(self, low, high):
109123
assert_array_almost_equal(actual, desired, decimal=15)
110124

111125
@pytest.mark.parametrize("dtype",
112-
[dpnp.float16, numpy.integer, dpnp.int, dpnp.bool, numpy.int64, dpnp.int32],
113-
ids=['dpnp.float16', 'numpy.integer', 'dpnp.int', 'dpnp.bool', 'numpy.int64', 'dpnp.int32'])
126+
[dpnp.float16, numpy.integer, dpnp.int, dpnp.bool, numpy.int64],
127+
ids=['dpnp.float16', 'numpy.integer', 'dpnp.int', 'dpnp.bool', 'numpy.int64'])
114128
def test_invalid_dtype(self, dtype):
115129
# dtype must be float32 or float64
116130
assert_raises(TypeError, RandomState().uniform, dtype=dtype)

0 commit comments

Comments
 (0)