Skip to content

Commit 9db5e4f

Browse files
committed
Formatting and minor fixes
1 parent e4c3187 commit 9db5e4f

File tree

4 files changed

+78
-30
lines changed

4 files changed

+78
-30
lines changed

dpnp/backend/src/dpnp_random_state.hpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,40 @@
3232

3333
namespace mkl_rng = oneapi::mkl::rng;
3434

35+
// Structure storing MKL engine for MT199374x32x10 algorithm
3536
struct mt19937_struct
3637
{
3738
mkl_rng::mt19937* engine;
3839
};
3940

40-
void MT19937_InitScalarSeed(mt19937_struct *mt19937, DPCTLSyclQueueRef QRef, uint32_t seed=1)
41+
/**
42+
* @brief Create a MKL engine from scalar seed
43+
*
44+
* Invoke a common seed initialization of the engine for MT199374x32x10 algorithm.
45+
*/
46+
void MT19937_InitScalarSeed(mt19937_struct *mt19937, DPCTLSyclQueueRef q_ref, uint32_t seed = 1)
4147
{
42-
sycl::queue q = *(reinterpret_cast<sycl::queue *>(QRef));
48+
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
4349
mt19937->engine = new mkl_rng::mt19937(q, seed);
4450
return;
4551
}
4652

47-
void MT19937_InitVectorSeed(mt19937_struct *mt19937, DPCTLSyclQueueRef QRef, uint32_t * seed, unsigned int n) {
48-
sycl::queue q = *(reinterpret_cast<sycl::queue *>(QRef));
53+
/**
54+
* @brief Create a MKL engine from seed vector
55+
*
56+
* Invoke an extended seed initialization of the engine for MT199374x32x10 algorithm.
57+
*
58+
* @note the vector size is limited by length=3
59+
*/
60+
void MT19937_InitVectorSeed(mt19937_struct *mt19937, DPCTLSyclQueueRef q_ref, uint32_t * seed, unsigned int n) {
61+
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
4962

5063
switch (n) {
5164
case 1: mt19937->engine = new mkl_rng::mt19937(q, {seed[0]}); break;
5265
case 2: mt19937->engine = new mkl_rng::mt19937(q, {seed[0], seed[1]}); break;
5366
case 3: mt19937->engine = new mkl_rng::mt19937(q, {seed[0], seed[1], seed[2]}); break;
5467
default:
68+
// TODO need to get rid of the limitation for seed vector length
5569
throw std::runtime_error("Too long seed vector");
5670
}
5771
return;

dpnp/random/dpnp_algo_random.pyx

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# cython: language_level=3
22
# -*- coding: utf-8 -*-
33
# *****************************************************************************
4-
# Copyright (c) 2016-2020, Intel Corporation
4+
# Copyright (c) 2016-2022, Intel Corporation
55
# All rights reserved.
66
#
77
# Redistribution and use in source and binary forms, with or without
@@ -287,59 +287,73 @@ cdef extern from "dpnp_random_state.hpp":
287287

288288

289289
cdef class MT19937:
290-
"""Class storing MKL engine for MT199374x32x10 algorithm
291290
"""
291+
Class storing MKL engine for MT199374x32x10 algorithm.
292+
"""
293+
292294
cdef mt19937_struct mt19937
293-
cdef c_dpctl.DPCTLSyclQueueRef QRef
294-
cdef c_dpctl.SyclQueue Queue
295+
cdef c_dpctl.DPCTLSyclQueueRef q_ref
296+
cdef c_dpctl.SyclQueue q
295297

296298
def __cinit__(self, seed, sycl_queue=None):
297299
cdef bint is_vector_seed = False
298300
cdef uint32_t scalar_seed = 0
299301
cdef unsigned int vector_seed_len = 0
300302
cdef unsigned int *vector_seed = NULL
301303

302-
self.QRef = NULL
304+
self.q_ref = NULL
303305
if sycl_queue is None:
304306
sycl_queue = dpctl.SyclQueue()
305-
if not isinstance(sycl_queue, dpctl.SyclQueue):
306-
raise TypeError
307307

308-
if isinstance(seed, int):
308+
# keep a refference on SYCL queue
309+
self.q = <c_dpctl.SyclQueue> sycl_queue
310+
self.q_ref = c_dpctl.DPCTLQueue_Copy((self.q).get_queue_ref())
311+
if (self.q_ref is NULL):
312+
raise ValueError("SyclQueue copy failed")
313+
314+
# get a scalar seed value or a vector of seeds
315+
if isinstance(seed, int) and seed >= 0:
309316
scalar_seed = <uint32_t> seed
310317
elif isinstance(seed, (list, tuple)):
311318
is_vector_seed = True
312319
vector_seed_len = len(seed)
320+
if vector_seed_len > 3:
321+
raise ValueError(
322+
f"{vector_seed_len} length of seed vector isn't supported, "
323+
"the length is limited by 3")
324+
313325
vector_seed = <uint32_t *> malloc(vector_seed_len * sizeof(uint32_t))
314326
if (not vector_seed):
315-
raise MemoryError
327+
raise MemoryError(f"Could not allocate memory for seed vector of length {vector_seed_len}")
316328

329+
# convert input seed's type to uint32_t one (expected in MKL function)
317330
try:
318331
for i in range(vector_seed_len):
319332
vector_seed[i] = <uint32_t> seed[i]
320333
except (ValueError, TypeError) as e:
321334
free(vector_seed)
322335
raise e
323336
else:
324-
raise TypeError("Seed must be an uint32_t, or a sequence of uint32_t elements")
337+
raise TypeError("Seed must be an unsigned int, or a sequence of unsigned int elements")
325338

326-
self.Queue = <c_dpctl.SyclQueue> sycl_queue
327-
self.QRef = c_dpctl.DPCTLQueue_Copy((self.Queue).get_queue_ref())
328339
if is_vector_seed:
329-
MT19937_InitVectorSeed(&self.mt19937, self.QRef, vector_seed, vector_seed_len)
340+
MT19937_InitVectorSeed(&self.mt19937, self.q_ref, vector_seed, vector_seed_len)
330341
free(vector_seed)
331342
else:
332-
MT19937_InitScalarSeed(&self.mt19937, self.QRef, scalar_seed)
343+
MT19937_InitScalarSeed(&self.mt19937, self.q_ref, scalar_seed)
333344

334345
def __dealloc__(self):
335346
MT19937_Delete(&self.mt19937)
336-
c_dpctl.DPCTLQueue_Delete(self.QRef)
347+
c_dpctl.DPCTLQueue_Delete(self.q_ref)
337348

338-
cdef mt19937_struct * mt19937(self):
349+
cdef mt19937_struct * get_mt19937(self):
339350
return &self.mt19937
340351

352+
cdef c_dpctl.SyclQueue get_queue(self):
353+
return self.q
354+
341355
cdef c_dpctl.DPCTLSyclQueueRef get_queue_ref(self):
342-
return self.QRef
356+
return self.q_ref
343357

344358
cpdef utils.dpnp_descriptor uniform(self, low, high, size, dtype, usm_type):
345359
cdef shape_type_c result_shape
@@ -365,11 +379,11 @@ cdef class MT19937:
365379
None,
366380
device=None,
367381
usm_type=usm_type,
368-
sycl_queue=self.Queue)
382+
sycl_queue=self.get_queue())
369383

370384
func = <fptr_dpnp_rng_uniform_c_1out_t > kernel_data.ptr
371385
# call FPTR function
372-
event_ref = func(self.QRef, result.get_data(), low, high, result.size, &self.mt19937, NULL)
386+
event_ref = func(self.get_queue_ref(), result.get_data(), low, high, result.size, self.get_mt19937(), NULL)
373387

374388
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
375389
c_dpctl.DPCTLEvent_Delete(event_ref)

dpnp/random/dpnp_iface_random.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,28 @@
9999

100100

101101
class RandomState:
102-
def __init__(self, seed, sycl_queue=None):
102+
"""
103+
A container for the Mersenne Twister pseudo-random number generator.
104+
105+
For full documentation refer to :obj:`numpy.random.RandomState`.
106+
"""
107+
108+
def __init__(self, seed=1, sycl_queue=None):
109+
if seed is None:
110+
seed = 1
111+
103112
self.random_state = MT19937(seed, sycl_queue)
104113

105114
def uniform(self, low=0.0, high=1.0, size=None, dtype=numpy.float64, usm_type="device"):
115+
"""
116+
Draw samples from a uniform distribution.
117+
118+
Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high).
119+
In other words, any value within the given interval is equally likely to be drawn by uniform.
120+
121+
For full documentation refer to :obj:`numpy.random.RandomState.uniform`.
122+
"""
123+
106124
if not use_origin_backend(low):
107125
if not dpnp.isscalar(low):
108126
pass

tests/test_sycl_queue.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,22 +220,24 @@ def test_broadcasting(func, data1, data2, device):
220220

221221

222222
@pytest.mark.parametrize("usm_type",
223-
["host", "device", "shared"])
224-
def test_uniform(usm_type):
223+
["host", "device", "shared"])
224+
@pytest.mark.parametrize("size",
225+
[None, (), 3, (2, 1), (4, 2, 5)],
226+
ids=['None', '()', '3', '(2,1)', '(4,2,5)'])
227+
def test_uniform(usm_type, size):
225228
low = 1.0
226229
high = 2.0
227-
size = 3
228230
res = dpnp.random.uniform(low, high, size=size, usm_type=usm_type)
229231

230232
res_usm_type = res.get_array().usm_type
231233
assert usm_type == res_usm_type
232234

233235

234236
@pytest.mark.parametrize("usm_type",
235-
["host", "device", "shared"])
237+
["host", "device", "shared"])
236238
@pytest.mark.parametrize("seed",
237-
[123, (12, 58), [134, 99], (147, 56, 896), [1, 654, 78]],
238-
ids=['123', '(12, 58)', '[134, 99]', '(147, 56, 896)', '[1, 654, 78]'])
239+
[None, (), 123, (12, 58), (147, 56, 896), [1, 654, 78]],
240+
ids=['None', '()', '123', '(12,58)', '(147,56,896)', '[1,654,78]'])
239241
def test_rs_uniform(usm_type, seed):
240242
seed = 123
241243
sycl_queue = dpctl.SyclQueue()

0 commit comments

Comments
 (0)