Skip to content

Commit 12ca5d7

Browse files
committed
Update random_tests/test_sample.py
1 parent cf4ee15 commit 12ca5d7

File tree

1 file changed

+73
-61
lines changed

1 file changed

+73
-61
lines changed

tests/third_party/cupy/random_tests/test_sample.py

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -11,63 +11,61 @@
1111

1212

1313
class TestRandint(unittest.TestCase):
14+
1415
def test_lo_hi_reversed(self):
1516
with self.assertRaises(ValueError):
1617
random.randint(100, 1)
1718

1819
def test_lo_hi_equal(self):
1920
with self.assertRaises(ValueError):
20-
random.randint(3, 3, size=3)
21+
random.randint(3, 3, size=0)
2122

2223
with self.assertRaises(ValueError):
2324
# int(-0.2) is not less than int(0.3)
2425
random.randint(-0.2, 0.3)
2526

2627
def test_lo_hi_nonrandom(self):
2728
a = random.randint(-0.9, 1.1, size=3)
28-
numpy.testing.assert_array_equal(a, cupy.full((3,), 0))
29+
testing.assert_array_equal(a, cupy.full((3,), 0))
2930

3031
a = random.randint(-1.1, -0.9, size=(2, 2))
31-
numpy.testing.assert_array_equal(a, cupy.full((2, 2), -1))
32+
testing.assert_array_equal(a, cupy.full((2, 2), -1))
3233

3334
def test_zero_sizes(self):
3435
a = random.randint(10, size=(0,))
35-
numpy.testing.assert_array_equal(a, cupy.array(()))
36+
testing.assert_array_equal(a, cupy.array(()))
3637

3738
a = random.randint(10, size=0)
38-
numpy.testing.assert_array_equal(a, cupy.array(()))
39+
testing.assert_array_equal(a, cupy.array(()))
3940

4041

4142
@testing.fix_random()
4243
class TestRandint2(unittest.TestCase):
43-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
44+
4445
@_condition.repeat(3, 10)
4546
def test_bound_1(self):
46-
vals = [random.randint(0, 10, (2, 3)) for _ in range(10)]
47+
vals = [random.randint(0, 10, (2, 3)) for _ in range(20)]
4748
for val in vals:
48-
self.assertEqual(val.shape, (2, 3))
49-
self.assertEqual(min(_.min() for _ in vals), 0)
50-
self.assertEqual(max(_.max() for _ in vals), 9)
49+
assert val.shape == (2, 3)
50+
assert min(_.min() for _ in vals) == 0
51+
assert max(_.max() for _ in vals) == 9
5152

52-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
5353
@_condition.repeat(3, 10)
5454
def test_bound_2(self):
5555
vals = [random.randint(0, 2) for _ in range(20)]
5656
for val in vals:
57-
self.assertEqual(val.shape, ())
58-
self.assertEqual(min(_.min() for _ in vals), 0)
59-
self.assertEqual(max(_.max() for _ in vals), 1)
57+
assert val.shape == ()
58+
assert min(vals) == 0
59+
assert max(vals) == 1
6060

61-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
6261
@_condition.repeat(3, 10)
6362
def test_bound_overflow(self):
6463
# 100 - (-100) exceeds the range of int8
6564
val = random.randint(numpy.int8(-100), numpy.int8(100), size=20)
66-
self.assertEqual(val.shape, (20,))
67-
self.assertGreaterEqual(val.min(), -100)
68-
self.assertLess(val.max(), 100)
65+
assert val.shape == (20,)
66+
assert val.min() >= -100
67+
assert val.max() < 100
6968

70-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
7169
@_condition.repeat(3, 10)
7270
def test_bound_float1(self):
7371
# generate floats s.t. int(low) < int(high)
@@ -76,70 +74,78 @@ def test_bound_float1(self):
7674
high += 1
7775
vals = [random.randint(low, high, (2, 3)) for _ in range(10)]
7876
for val in vals:
79-
self.assertEqual(val.shape, (2, 3))
80-
self.assertEqual(min(_.min() for _ in vals), int(low))
81-
self.assertEqual(max(_.max() for _ in vals), int(high) - 1)
77+
assert val.shape == (2, 3)
78+
assert min(_.min() for _ in vals) == int(low)
79+
assert max(_.max() for _ in vals) == int(high) - 1
8280

83-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
8481
def test_bound_float2(self):
8582
vals = [random.randint(-1.0, 1.0, (2, 3)) for _ in range(10)]
8683
for val in vals:
87-
self.assertEqual(val.shape, (2, 3))
88-
self.assertEqual(min(_.min() for _ in vals), -1)
89-
self.assertEqual(max(_.max() for _ in vals), 0)
84+
assert val.shape == (2, 3)
85+
assert min(_.min() for _ in vals) == -1
86+
assert max(_.max() for _ in vals) == 0
9087

9188
@_condition.repeat(3, 10)
9289
def test_goodness_of_fit(self):
9390
mx = 5
9491
trial = 100
95-
vals = [numpy.random.randint(mx) for _ in range(trial)]
92+
vals = [random.randint(mx) for _ in range(trial)]
9693
counts = numpy.histogram(vals, bins=numpy.arange(mx + 1))[0]
9794
expected = numpy.array([float(trial) / mx] * mx)
98-
self.assertTrue(_hypothesis.chi_square_test(counts, expected))
95+
assert _hypothesis.chi_square_test(counts, expected)
9996

10097
@_condition.repeat(3, 10)
10198
def test_goodness_of_fit_2(self):
10299
mx = 5
103100
vals = random.randint(mx, size=(5, 20))
104101
counts = numpy.histogram(vals, bins=numpy.arange(mx + 1))[0]
105102
expected = numpy.array([float(vals.size) / mx] * mx)
106-
self.assertTrue(_hypothesis.chi_square_test(counts, expected))
103+
assert _hypothesis.chi_square_test(counts, expected)
107104

108105

109106
class TestRandintDtype(unittest.TestCase):
110-
# numpy.int8, numpy.uint8, numpy.int16, numpy.uint16, numpy.int32])
111-
@testing.for_dtypes([numpy.int32])
107+
108+
@testing.with_requires("numpy>=2.0")
109+
@testing.for_dtypes(
110+
[numpy.int8, numpy.uint8, numpy.int16, numpy.uint16, numpy.int32]
111+
)
112112
def test_dtype(self, dtype):
113113
size = (1000,)
114114
low = numpy.iinfo(dtype).min
115-
high = numpy.iinfo(dtype).max
116-
x = random.randint(low, high, size, dtype)
117-
self.assertLessEqual(low, min(x))
118-
self.assertLessEqual(max(x), high)
115+
high = numpy.iinfo(dtype).max + 1
116+
x = random.randint(low, high, size, dtype).get()
117+
assert low <= min(x)
118+
assert max(x) <= high
119119

120-
# @testing.for_int_dtypes(no_bool=True)
120+
@pytest.mark.skip("high=(max+1) is not supported")
121+
@testing.for_int_dtypes(no_bool=True)
121122
@testing.for_dtypes([numpy.int32])
122123
def test_dtype2(self, dtype):
123124
dtype = numpy.dtype(dtype)
124125

126+
# randint does not support 64 bit integers
127+
if dtype in (numpy.int64, numpy.uint64):
128+
return
129+
125130
iinfo = numpy.iinfo(dtype)
126131
size = (10000,)
127132

128-
x = random.randint(iinfo.min, iinfo.max, size, dtype)
129-
self.assertEqual(x.dtype, dtype)
130-
self.assertLessEqual(iinfo.min, min(x))
131-
self.assertLessEqual(max(x), iinfo.max)
133+
x = random.randint(iinfo.min, iinfo.max + 1, size, dtype).get()
134+
assert x.dtype == dtype
135+
assert iinfo.min <= min(x)
136+
assert max(x) <= iinfo.max
132137

133138
# Lower bound check
134-
with self.assertRaises(OverflowError):
139+
with self.assertRaises(ValueError):
135140
random.randint(iinfo.min - 1, iinfo.min + 10, size, dtype)
136141

137142
# Upper bound check
138-
with self.assertRaises(OverflowError):
143+
with self.assertRaises(ValueError):
139144
random.randint(iinfo.max - 10, iinfo.max + 2, size, dtype)
140145

141146

142147
class TestRandomIntegers(unittest.TestCase):
148+
143149
def test_normal(self):
144150
with mock.patch("dpnp.random.RandomState.randint") as m:
145151
random.random_integers(3, 5)
@@ -164,50 +170,53 @@ def test_size_is_not_none(self):
164170

165171
@testing.fix_random()
166172
class TestRandomIntegers2(unittest.TestCase):
173+
167174
@_condition.repeat(3, 10)
168175
def test_bound_1(self):
169-
vals = [random.random_integers(0, 10, (2, 3)).get() for _ in range(10)]
176+
vals = [random.random_integers(0, 10, (2, 3)) for _ in range(10)]
170177
for val in vals:
171-
self.assertEqual(val.shape, (2, 3))
172-
self.assertEqual(min(_.min() for _ in vals), 0)
173-
self.assertEqual(max(_.max() for _ in vals), 10)
178+
assert val.shape == (2, 3)
179+
assert min(_.min() for _ in vals) == 0
180+
assert max(_.max() for _ in vals) == 10
174181

175182
@_condition.repeat(3, 10)
176183
def test_bound_2(self):
177-
vals = [random.random_integers(0, 2).get() for _ in range(20)]
184+
vals = [random.random_integers(0, 2) for _ in range(20)]
178185
for val in vals:
179-
self.assertEqual(val.shape, ())
180-
self.assertEqual(min(vals), 0)
181-
self.assertEqual(max(vals), 2)
186+
assert val.shape == ()
187+
assert min(vals) == 0
188+
assert max(vals) == 2
182189

183190
@_condition.repeat(3, 10)
184191
def test_goodness_of_fit(self):
185192
mx = 5
186193
trial = 100
187-
vals = [random.randint(0, mx).get() for _ in range(trial)]
194+
vals = [random.randint(0, mx) for _ in range(trial)]
188195
counts = numpy.histogram(vals, bins=numpy.arange(mx + 1))[0]
189196
expected = numpy.array([float(trial) / mx] * mx)
190-
self.assertTrue(_hypothesis.chi_square_test(counts, expected))
197+
assert _hypothesis.chi_square_test(counts, expected)
191198

192199
@_condition.repeat(3, 10)
193200
def test_goodness_of_fit_2(self):
194201
mx = 5
195-
vals = random.randint(0, mx, (5, 20)).get()
202+
vals = random.randint(0, mx, (5, 20))
196203
counts = numpy.histogram(vals, bins=numpy.arange(mx + 1))[0]
197204
expected = numpy.array([float(vals.size) / mx] * mx)
198-
self.assertTrue(_hypothesis.chi_square_test(counts, expected))
205+
assert _hypothesis.chi_square_test(counts, expected)
199206

200207

208+
@pytest.mark.skip("random.choice() is not supported yet")
201209
class TestChoice(unittest.TestCase):
210+
202211
def setUp(self):
203-
self.rs_tmp = random.generator._random_states
212+
self.rs_tmp = random._generator._random_states
204213
device_id = cuda.Device().id
205214
self.m = mock.Mock()
206215
self.m.choice.return_value = 0
207-
random.generator._random_states = {device_id: self.m}
216+
random._generator._random_states = {device_id: self.m}
208217

209218
def tearDown(self):
210-
random.generator._random_states = self.rs_tmp
219+
random._generator._random_states = self.rs_tmp
211220

212221
def test_size_and_replace_and_p_are_none(self):
213222
random.choice(3)
@@ -243,10 +252,11 @@ def test_no_none(self):
243252

244253

245254
class TestRandomSample(unittest.TestCase):
255+
246256
def test_rand(self):
247-
# no keyword argument 'dtype' in dpnp
248-
with self.assertRaises(TypeError):
249-
random.rand(1, 2, 3, dtype=numpy.float32)
257+
with mock.patch("dpnp.random.RandomState.random_sample") as m:
258+
random.rand(1, 2, 3)
259+
m.assert_called_once_with(size=(1, 2, 3), usm_type="device")
250260

251261
def test_rand_default_dtype(self):
252262
with mock.patch("dpnp.random.RandomState.random_sample") as m:
@@ -280,12 +290,14 @@ def test_randn_invalid_argument(self):
280290
{"size": (1, 0)},
281291
)
282292
@testing.fix_random()
293+
@pytest.mark.skip("random.multinomial() is not fully supported")
283294
class TestMultinomial(unittest.TestCase):
295+
284296
@_condition.repeat(3, 10)
285297
@testing.for_float_dtypes()
286298
@testing.numpy_cupy_allclose(rtol=0.05)
287299
def test_multinomial(self, xp, dtype):
288300
pvals = xp.array([0.2, 0.3, 0.5], dtype)
289301
x = xp.random.multinomial(100000, pvals, self.size)
290-
self.assertEqual(x.dtype, "l")
302+
assert x.dtype.kind == "l"
291303
return x / 100000

0 commit comments

Comments
 (0)