Skip to content

Commit 77d1c87

Browse files
committed
Default dtype to be dependent on fp64 support
1 parent efd9f84 commit 77d1c87

File tree

4 files changed

+106
-65
lines changed

4 files changed

+106
-65
lines changed

dpnp/random/dpnp_iface_random.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,8 @@ def normal(loc=0.0, scale=1.0, size=None, usm_type='device'):
786786
-----------
787787
Parameters ``loc`` and ``scale`` are supported as scalar.
788788
Otherwise, :obj:`numpy.random.normal(loc, scale, size)` samples are drawn.
789-
Output array data type is :obj:`dpnp.float64`.
789+
Output array data type is :obj:`dpnp.float64` if device supports it
790+
or :obj:`dpnp.float32` otherwise.
790791
791792
Examples
792793
--------
@@ -798,7 +799,7 @@ def normal(loc=0.0, scale=1.0, size=None, usm_type='device'):
798799
return _get_random_state().normal(loc=loc,
799800
scale=scale,
800801
size=size,
801-
dtype=dpnp.float64,
802+
dtype=None,
802803
usm_type=usm_type)
803804

804805

@@ -996,7 +997,8 @@ def rand(d0, *dn, usm_type="device"):
996997
997998
Limitations
998999
-----------
999-
Output array data type is :obj:`dpnp.float64`.
1000+
Output array data type is :obj:`dpnp.float64` if device supports it
1001+
or :obj:`dpnp.float32` otherwise.
10001002
10011003
Examples
10021004
--------
@@ -1054,7 +1056,8 @@ def randn(d0, *dn, usm_type="device"):
10541056
10551057
Limitations
10561058
-----------
1057-
Output array data type is :obj:`dpnp.float64`.
1059+
Output array data type is :obj:`dpnp.float64` if device supports it
1060+
or :obj:`dpnp.float32` otherwise.
10581061
10591062
Examples
10601063
--------
@@ -1084,7 +1087,8 @@ def random(size=None, usm_type="device"):
10841087
10851088
Limitations
10861089
-----------
1087-
Output array data type is :obj:`dpnp.float64`.
1090+
Output array data type is :obj:`dpnp.float64` if device supports it
1091+
or :obj:`dpnp.float32` otherwise.
10881092
10891093
Examples
10901094
--------
@@ -1145,7 +1149,8 @@ def random_sample(size=None, usm_type="device"):
11451149
11461150
Limitations
11471151
-----------
1148-
Output array data type is :obj:`dpnp.float64`.
1152+
Output array data type is :obj:`dpnp.float64` if device supports it
1153+
or :obj:`dpnp.float32` otherwise.
11491154
11501155
Examples
11511156
--------
@@ -1172,7 +1177,8 @@ def ranf(size=None, usm_type="device"):
11721177
11731178
Limitations
11741179
-----------
1175-
Output array data type is :obj:`dpnp.float64`.
1180+
Output array data type is :obj:`dpnp.float64` if device supports it
1181+
or :obj:`dpnp.float32` otherwise.
11761182
11771183
Examples
11781184
--------
@@ -1233,7 +1239,8 @@ def sample(size=None, usm_type="device"):
12331239
12341240
Limitations
12351241
-----------
1236-
Output array data type is :obj:`dpnp.float64`.
1242+
Output array data type is :obj:`dpnp.float64` if device supports it
1243+
or :obj:`dpnp.float32` otherwise.
12371244
12381245
Examples
12391246
--------
@@ -1407,7 +1414,8 @@ def standard_normal(size=None, usm_type="device"):
14071414
14081415
Limitations
14091416
-----------
1410-
Output array data type is :obj:`dpnp.float64`.
1417+
Output array data type is :obj:`dpnp.float64` if device supports it
1418+
or :obj:`dpnp.float32` otherwise.
14111419
14121420
Examples
14131421
--------
@@ -1508,7 +1516,8 @@ def uniform(low=0.0, high=1.0, size=None, usm_type='device'):
15081516
-----------
15091517
Parameters ``low`` and ``high`` are supported as scalar.
15101518
Otherwise, :obj:`numpy.random.uniform(low, high, size)` samples are drawn.
1511-
Output array data type is :obj:`dpnp.float64`.
1519+
Output array data type is :obj:`dpnp.float64` if device supports it
1520+
or :obj:`dpnp.float32` otherwise.
15121521
15131522
Examples
15141523
--------
@@ -1524,7 +1533,7 @@ def uniform(low=0.0, high=1.0, size=None, usm_type='device'):
15241533
return _get_random_state().uniform(low=low,
15251534
high=high,
15261535
size=size,
1527-
dtype=dpnp.float64,
1536+
dtype=None,
15281537
usm_type=usm_type)
15291538

15301539

dpnp/random/dpnp_random_state.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,15 @@ class RandomState:
5959
"""
6060

6161
def __init__(self, seed=None, device=None, sycl_queue=None):
62-
self.seed = 1 if seed is None else seed
63-
self.sycl_queue = dpnp.get_normalized_queue_device(device=device, sycl_queue=sycl_queue)
64-
self.random_state = MT19937(self.seed, self.sycl_queue)
65-
self.fallback_random_state = call_origin(numpy.random.RandomState, seed)
62+
self._seed = 1 if seed is None else seed
63+
self._sycl_queue = dpnp.get_normalized_queue_device(device=device, sycl_queue=sycl_queue)
64+
65+
self._def_float_type = dpnp.float32
66+
if self._sycl_queue.get_sycl_device().has_aspect_fp64:
67+
self._def_float_type = dpnp.float64
68+
69+
self._random_state = MT19937(self._seed, self._sycl_queue)
70+
self._fallback_random_state = call_origin(numpy.random.RandomState, seed)
6671

6772

6873
def get_state(self):
@@ -71,17 +76,17 @@ def get_state(self):
7176
7277
For full documentation refer to :obj:`numpy.random.RandomState.get_state`.
7378
"""
74-
return self.random_state
79+
return self._random_state
7580

7681

7782
def get_sycl_queue(self):
7883
"""
7984
Return a sycl queue used from the container.
8085
"""
81-
return self.sycl_queue
86+
return self._sycl_queue
8287

8388

84-
def normal(self, loc=0.0, scale=1.0, size=None, dtype=dpnp.float64, usm_type="device"):
89+
def normal(self, loc=0.0, scale=1.0, size=None, dtype=None, usm_type="device"):
8590
"""
8691
Draw random samples from a normal (Gaussian) distribution.
8792
@@ -92,7 +97,9 @@ def normal(self, loc=0.0, scale=1.0, size=None, dtype=dpnp.float64, usm_type="de
9297
Parameters ``loc`` and ``scale`` are supported as scalar.
9398
Otherwise, :obj:`numpy.random.RandomState.normal(loc, scale, size)` samples are drawn.
9499
95-
Parameter ``dtype`` is supported only for :obj:`dpnp.float32` or :obj:`dpnp.float64`.
100+
Parameter ``dtype`` is supported only for :obj:`dpnp.float32`, :obj:`dpnp.float64` or `None`.
101+
If ``dtype`` is None (default), :obj:`dpnp.float64` type will be used if device supports it
102+
or :obj:`dpnp.float32` otherwise.
96103
Output array data type is the same as ``dtype``.
97104
98105
Examples
@@ -127,17 +134,19 @@ def normal(self, loc=0.0, scale=1.0, size=None, dtype=dpnp.float64, usm_type="de
127134
elif scale < 0 or scale == 0 and numpy.signbit(scale):
128135
raise ValueError(f"scale={scale}, but must be non-negative.")
129136

130-
if not dtype in (dpnp.float32, dpnp.float64):
137+
if dtype is None:
138+
dtype = self._def_float_type
139+
elif not dtype in (dpnp.float32, dpnp.float64):
131140
raise TypeError(f"dtype={dtype} is unsupported.")
132141

133142
dpu.validate_usm_type(usm_type=usm_type, allow_none=False)
134-
return self.random_state.normal(loc=loc,
135-
scale=scale,
136-
size=size,
137-
dtype=dtype,
138-
usm_type=usm_type).get_pyobj()
143+
return self._random_state.normal(loc=loc,
144+
scale=scale,
145+
size=size,
146+
dtype=dtype,
147+
usm_type=usm_type).get_pyobj()
139148

140-
return call_origin(self.fallback_random_state.normal, loc=loc, scale=scale, size=size)
149+
return call_origin(self._fallback_random_state.normal, loc=loc, scale=scale, size=size)
141150

142151

143152
def rand(self, *args, usm_type="device"):
@@ -236,7 +245,7 @@ def randint(self, low, high=None, size=None, dtype=int, usm_type="device"):
236245
dtype=_dtype,
237246
usm_type=usm_type)
238247

239-
return call_origin(self.fallback_random_state.randint,
248+
return call_origin(self._fallback_random_state.randint,
240249
low=low, high=high, size=size, dtype=dtype)
241250

242251

@@ -248,7 +257,8 @@ def randn(self, *args, usm_type="device"):
248257
249258
Limitations
250259
-----------
251-
Output array data type is :obj:`dpnp.float64`.
260+
Output array data type is :obj:`dpnp.float64` if device supports it
261+
or :obj:`dpnp.float32` otherwise.
252262
253263
Examples
254264
--------
@@ -287,7 +297,8 @@ def random_sample(self, size=None, usm_type="device"):
287297
288298
Limitations
289299
-----------
290-
Output array data type is :obj:`dpnp.float64`.
300+
Output array data type is :obj:`dpnp.float64` if device supports it
301+
or :obj:`dpnp.float32` otherwise.
291302
292303
Examples
293304
--------
@@ -305,7 +316,7 @@ def random_sample(self, size=None, usm_type="device"):
305316
return self.uniform(low=0.0,
306317
high=1.0,
307318
size=size,
308-
dtype=dpnp.float64,
319+
dtype=None,
309320
usm_type=usm_type)
310321

311322

@@ -317,8 +328,8 @@ def standard_normal(self, size=None, usm_type="device"):
317328
318329
Limitations
319330
-----------
320-
Parameter ``dtype`` is supported only for :obj:`dpnp.float32` or :obj:`dpnp.float64`.
321-
Output array data type is the same as ``dtype``.
331+
Output array data type is :obj:`dpnp.float64` if device supports it
332+
or :obj:`dpnp.float32` otherwise.
322333
323334
Examples
324335
--------
@@ -339,11 +350,11 @@ def standard_normal(self, size=None, usm_type="device"):
339350
return self.normal(loc=0.0,
340351
scale=1.0,
341352
size=size,
342-
dtype=dpnp.float64,
353+
dtype=None,
343354
usm_type=usm_type)
344355

345356

346-
def uniform(self, low=0.0, high=1.0, size=None, dtype=dpnp.float64, usm_type="device"):
357+
def uniform(self, low=0.0, high=1.0, size=None, dtype=None, usm_type="device"):
347358
"""
348359
Draw samples from a uniform distribution.
349360
@@ -356,7 +367,9 @@ def uniform(self, low=0.0, high=1.0, size=None, dtype=dpnp.float64, usm_type="de
356367
-----------
357368
Parameters ``low`` and ``high`` are supported as scalar.
358369
Otherwise, :obj:`numpy.random.uniform(low, high, size)` samples are drawn.
359-
Parameter ``dtype`` is supported only for :obj:`dpnp.int32`, :obj:`dpnp.float32` or :obj:`dpnp.float64`.
370+
Parameter ``dtype`` is supported only for :obj:`dpnp.int32`, :obj:`dpnp.float32`, :obj:`dpnp.float64` or `None`.
371+
If ``dtype`` is None (default), :obj:`dpnp.float64` type will be used if device supports it
372+
or :obj:`dpnp.float32` otherwise.
360373
Output array data type is the same as ``dtype``.
361374
362375
Examples
@@ -394,14 +407,16 @@ def uniform(self, low=0.0, high=1.0, size=None, dtype=dpnp.float64, usm_type="de
394407
if low > high:
395408
low, high = high, low
396409

397-
if not dtype in (dpnp.int32, dpnp.float32, dpnp.float64):
410+
if dtype is None:
411+
dtype = self._def_float_type
412+
elif not dtype in (dpnp.int32, dpnp.float32, dpnp.float64):
398413
raise TypeError(f"dtype={dtype} is unsupported.")
399414

400415
dpu.validate_usm_type(usm_type, allow_none=False)
401-
return self.random_state.uniform(low=low,
402-
high=high,
403-
size=size,
404-
dtype=dtype,
405-
usm_type=usm_type).get_pyobj()
416+
return self._random_state.uniform(low=low,
417+
high=high,
418+
size=size,
419+
dtype=dtype,
420+
usm_type=usm_type).get_pyobj()
406421

407-
return call_origin(self.fallback_random_state.uniform, low=low, high=high, size=size)
422+
return call_origin(self._fallback_random_state.uniform, low=low, high=high, size=size)

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,6 @@ tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{extern
335335
tests/third_party/cupy/statistics_tests/test_correlation.py::TestCov::test_cov_empty
336336
tests/third_party/cupy/statistics_tests/test_meanvar.py::TestMeanVar::test_external_mean_axis
337337

338-
tests/test_random.py::test_randn_normal_distribution
339338
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_outer
340339

341340
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_axis

0 commit comments

Comments
 (0)