@@ -59,10 +59,15 @@ class RandomState:
59
59
"""
60
60
61
61
def __init__ (self , seed = None , device = None , sycl_queue = None ):
62
- self .seed = 1 if seed is None else seed
63
- self .sycl_queue = dpnp .get_normalized_queue_device (device = device , sycl_queue = sycl_queue )
64
- self .random_state = MT19937 (self .seed , self .sycl_queue )
65
- self .fallback_random_state = call_origin (numpy .random .RandomState , seed )
62
+ self ._seed = 1 if seed is None else seed
63
+ self ._sycl_queue = dpnp .get_normalized_queue_device (device = device , sycl_queue = sycl_queue )
64
+
65
+ self ._def_float_type = dpnp .float32
66
+ if self ._sycl_queue .get_sycl_device ().has_aspect_fp64 :
67
+ self ._def_float_type = dpnp .float64
68
+
69
+ self ._random_state = MT19937 (self ._seed , self ._sycl_queue )
70
+ self ._fallback_random_state = call_origin (numpy .random .RandomState , seed )
66
71
67
72
68
73
def get_state (self ):
@@ -71,17 +76,17 @@ def get_state(self):
71
76
72
77
For full documentation refer to :obj:`numpy.random.RandomState.get_state`.
73
78
"""
74
- return self .random_state
79
+ return self ._random_state
75
80
76
81
77
82
def get_sycl_queue (self ):
78
83
"""
79
84
Return a sycl queue used from the container.
80
85
"""
81
- return self .sycl_queue
86
+ return self ._sycl_queue
82
87
83
88
84
- def normal (self , loc = 0.0 , scale = 1.0 , size = None , dtype = dpnp . float64 , usm_type = "device" ):
89
+ def normal (self , loc = 0.0 , scale = 1.0 , size = None , dtype = None , usm_type = "device" ):
85
90
"""
86
91
Draw random samples from a normal (Gaussian) distribution.
87
92
@@ -92,7 +97,9 @@ def normal(self, loc=0.0, scale=1.0, size=None, dtype=dpnp.float64, usm_type="de
92
97
Parameters ``loc`` and ``scale`` are supported as scalar.
93
98
Otherwise, :obj:`numpy.random.RandomState.normal(loc, scale, size)` samples are drawn.
94
99
95
- Parameter ``dtype`` is supported only for :obj:`dpnp.float32` or :obj:`dpnp.float64`.
100
+ Parameter ``dtype`` is supported only for :obj:`dpnp.float32`, :obj:`dpnp.float64` or `None`.
101
+ If ``dtype`` is None (default), :obj:`dpnp.float64` type will be used if device supports it
102
+ or :obj:`dpnp.float32` otherwise.
96
103
Output array data type is the same as ``dtype``.
97
104
98
105
Examples
@@ -127,17 +134,19 @@ def normal(self, loc=0.0, scale=1.0, size=None, dtype=dpnp.float64, usm_type="de
127
134
elif scale < 0 or scale == 0 and numpy .signbit (scale ):
128
135
raise ValueError (f"scale={ scale } , but must be non-negative." )
129
136
130
- if not dtype in (dpnp .float32 , dpnp .float64 ):
137
+ if dtype is None :
138
+ dtype = self ._def_float_type
139
+ elif not dtype in (dpnp .float32 , dpnp .float64 ):
131
140
raise TypeError (f"dtype={ dtype } is unsupported." )
132
141
133
142
dpu .validate_usm_type (usm_type = usm_type , allow_none = False )
134
- return self .random_state .normal (loc = loc ,
135
- scale = scale ,
136
- size = size ,
137
- dtype = dtype ,
138
- usm_type = usm_type ).get_pyobj ()
143
+ return self ._random_state .normal (loc = loc ,
144
+ scale = scale ,
145
+ size = size ,
146
+ dtype = dtype ,
147
+ usm_type = usm_type ).get_pyobj ()
139
148
140
- return call_origin (self .fallback_random_state .normal , loc = loc , scale = scale , size = size )
149
+ return call_origin (self ._fallback_random_state .normal , loc = loc , scale = scale , size = size )
141
150
142
151
143
152
def rand (self , * args , usm_type = "device" ):
@@ -236,7 +245,7 @@ def randint(self, low, high=None, size=None, dtype=int, usm_type="device"):
236
245
dtype = _dtype ,
237
246
usm_type = usm_type )
238
247
239
- return call_origin (self .fallback_random_state .randint ,
248
+ return call_origin (self ._fallback_random_state .randint ,
240
249
low = low , high = high , size = size , dtype = dtype )
241
250
242
251
@@ -248,7 +257,8 @@ def randn(self, *args, usm_type="device"):
248
257
249
258
Limitations
250
259
-----------
251
- Output array data type is :obj:`dpnp.float64`.
260
+ Output array data type is :obj:`dpnp.float64` if device supports it
261
+ or :obj:`dpnp.float32` otherwise.
252
262
253
263
Examples
254
264
--------
@@ -287,7 +297,8 @@ def random_sample(self, size=None, usm_type="device"):
287
297
288
298
Limitations
289
299
-----------
290
- Output array data type is :obj:`dpnp.float64`.
300
+ Output array data type is :obj:`dpnp.float64` if device supports it
301
+ or :obj:`dpnp.float32` otherwise.
291
302
292
303
Examples
293
304
--------
@@ -305,7 +316,7 @@ def random_sample(self, size=None, usm_type="device"):
305
316
return self .uniform (low = 0.0 ,
306
317
high = 1.0 ,
307
318
size = size ,
308
- dtype = dpnp . float64 ,
319
+ dtype = None ,
309
320
usm_type = usm_type )
310
321
311
322
@@ -317,8 +328,8 @@ def standard_normal(self, size=None, usm_type="device"):
317
328
318
329
Limitations
319
330
-----------
320
- Parameter ``dtype`` is supported only for :obj:`dpnp.float32` or :obj:`dpnp.float64`.
321
- Output array data type is the same as ``dtype`` .
331
+ Output array data type is :obj:`dpnp.float64` if device supports it
332
+ or :obj:`dpnp.float32` otherwise .
322
333
323
334
Examples
324
335
--------
@@ -339,11 +350,11 @@ def standard_normal(self, size=None, usm_type="device"):
339
350
return self .normal (loc = 0.0 ,
340
351
scale = 1.0 ,
341
352
size = size ,
342
- dtype = dpnp . float64 ,
353
+ dtype = None ,
343
354
usm_type = usm_type )
344
355
345
356
346
- def uniform (self , low = 0.0 , high = 1.0 , size = None , dtype = dpnp . float64 , usm_type = "device" ):
357
+ def uniform (self , low = 0.0 , high = 1.0 , size = None , dtype = None , usm_type = "device" ):
347
358
"""
348
359
Draw samples from a uniform distribution.
349
360
@@ -356,7 +367,9 @@ def uniform(self, low=0.0, high=1.0, size=None, dtype=dpnp.float64, usm_type="de
356
367
-----------
357
368
Parameters ``low`` and ``high`` are supported as scalar.
358
369
Otherwise, :obj:`numpy.random.uniform(low, high, size)` samples are drawn.
359
- Parameter ``dtype`` is supported only for :obj:`dpnp.int32`, :obj:`dpnp.float32` or :obj:`dpnp.float64`.
370
+ Parameter ``dtype`` is supported only for :obj:`dpnp.int32`, :obj:`dpnp.float32`, :obj:`dpnp.float64` or `None`.
371
+ If ``dtype`` is None (default), :obj:`dpnp.float64` type will be used if device supports it
372
+ or :obj:`dpnp.float32` otherwise.
360
373
Output array data type is the same as ``dtype``.
361
374
362
375
Examples
@@ -394,14 +407,16 @@ def uniform(self, low=0.0, high=1.0, size=None, dtype=dpnp.float64, usm_type="de
394
407
if low > high :
395
408
low , high = high , low
396
409
397
- if not dtype in (dpnp .int32 , dpnp .float32 , dpnp .float64 ):
410
+ if dtype is None :
411
+ dtype = self ._def_float_type
412
+ elif not dtype in (dpnp .int32 , dpnp .float32 , dpnp .float64 ):
398
413
raise TypeError (f"dtype={ dtype } is unsupported." )
399
414
400
415
dpu .validate_usm_type (usm_type , allow_none = False )
401
- return self .random_state .uniform (low = low ,
402
- high = high ,
403
- size = size ,
404
- dtype = dtype ,
405
- usm_type = usm_type ).get_pyobj ()
416
+ return self ._random_state .uniform (low = low ,
417
+ high = high ,
418
+ size = size ,
419
+ dtype = dtype ,
420
+ usm_type = usm_type ).get_pyobj ()
406
421
407
- return call_origin (self .fallback_random_state .uniform , low = low , high = high , size = size )
422
+ return call_origin (self ._fallback_random_state .uniform , low = low , high = high , size = size )
0 commit comments