1
1
# cython: language_level=3
2
2
# -*- coding: utf-8 -*-
3
3
# *****************************************************************************
4
- # Copyright (c) 2016-2020 , Intel Corporation
4
+ # Copyright (c) 2016-2022 , Intel Corporation
5
5
# All rights reserved.
6
6
#
7
7
# Redistribution and use in source and binary forms, with or without
@@ -287,59 +287,73 @@ cdef extern from "dpnp_random_state.hpp":
287
287
288
288
289
289
cdef class MT19937:
290
- """ Class storing MKL engine for MT199374x32x10 algorithm
291
290
"""
291
+ Class storing MKL engine for MT199374x32x10 algorithm.
292
+ """
293
+
292
294
cdef mt19937_struct mt19937
293
- cdef c_dpctl.DPCTLSyclQueueRef QRef
294
- cdef c_dpctl.SyclQueue Queue
295
+ cdef c_dpctl.DPCTLSyclQueueRef q_ref
296
+ cdef c_dpctl.SyclQueue q
295
297
296
298
def __cinit__ (self , seed , sycl_queue = None ):
297
299
cdef bint is_vector_seed = False
298
300
cdef uint32_t scalar_seed = 0
299
301
cdef unsigned int vector_seed_len = 0
300
302
cdef unsigned int * vector_seed = NULL
301
303
302
- self .QRef = NULL
304
+ self .q_ref = NULL
303
305
if sycl_queue is None :
304
306
sycl_queue = dpctl.SyclQueue()
305
- if not isinstance (sycl_queue, dpctl.SyclQueue):
306
- raise TypeError
307
307
308
- if isinstance (seed, int ):
308
+ # keep a refference on SYCL queue
309
+ self .q = < c_dpctl.SyclQueue> sycl_queue
310
+ self .q_ref = c_dpctl.DPCTLQueue_Copy((self .q).get_queue_ref())
311
+ if (self .q_ref is NULL ):
312
+ raise ValueError (" SyclQueue copy failed" )
313
+
314
+ # get a scalar seed value or a vector of seeds
315
+ if isinstance (seed, int ) and seed >= 0 :
309
316
scalar_seed = < uint32_t> seed
310
317
elif isinstance (seed, (list , tuple )):
311
318
is_vector_seed = True
312
319
vector_seed_len = len (seed)
320
+ if vector_seed_len > 3 :
321
+ raise ValueError (
322
+ f" {vector_seed_len} length of seed vector isn't supported, "
323
+ " the length is limited by 3" )
324
+
313
325
vector_seed = < uint32_t * > malloc(vector_seed_len * sizeof(uint32_t))
314
326
if (not vector_seed):
315
- raise MemoryError
327
+ raise MemoryError (f " Could not allocate memory for seed vector of length {vector_seed_len} " )
316
328
329
+ # convert input seed's type to uint32_t one (expected in MKL function)
317
330
try :
318
331
for i in range (vector_seed_len):
319
332
vector_seed[i] = < uint32_t> seed[i]
320
333
except (ValueError , TypeError ) as e:
321
334
free(vector_seed)
322
335
raise e
323
336
else :
324
- raise TypeError (" Seed must be an uint32_t , or a sequence of uint32_t elements" )
337
+ raise TypeError (" Seed must be an unsigned int , or a sequence of unsigned int elements" )
325
338
326
- self .Queue = < c_dpctl.SyclQueue> sycl_queue
327
- self .QRef = c_dpctl.DPCTLQueue_Copy((self .Queue).get_queue_ref())
328
339
if is_vector_seed:
329
- MT19937_InitVectorSeed(& self .mt19937, self .QRef , vector_seed, vector_seed_len)
340
+ MT19937_InitVectorSeed(& self .mt19937, self .q_ref , vector_seed, vector_seed_len)
330
341
free(vector_seed)
331
342
else :
332
- MT19937_InitScalarSeed(& self .mt19937, self .QRef , scalar_seed)
343
+ MT19937_InitScalarSeed(& self .mt19937, self .q_ref , scalar_seed)
333
344
334
345
def __dealloc__ (self ):
335
346
MT19937_Delete(& self .mt19937)
336
- c_dpctl.DPCTLQueue_Delete(self .QRef )
347
+ c_dpctl.DPCTLQueue_Delete(self .q_ref )
337
348
338
- cdef mt19937_struct * mt19937 (self ):
349
+ cdef mt19937_struct * get_mt19937 (self ):
339
350
return & self .mt19937
340
351
352
+ cdef c_dpctl.SyclQueue get_queue(self ):
353
+ return self .q
354
+
341
355
cdef c_dpctl.DPCTLSyclQueueRef get_queue_ref(self ):
342
- return self .QRef
356
+ return self .q_ref
343
357
344
358
cpdef utils.dpnp_descriptor uniform(self , low, high, size, dtype, usm_type):
345
359
cdef shape_type_c result_shape
@@ -365,11 +379,11 @@ cdef class MT19937:
365
379
None ,
366
380
device = None ,
367
381
usm_type = usm_type,
368
- sycl_queue = self .Queue )
382
+ sycl_queue = self .get_queue() )
369
383
370
384
func = < fptr_dpnp_rng_uniform_c_1out_t > kernel_data.ptr
371
385
# call FPTR function
372
- event_ref = func(self .QRef , result.get_data(), low, high, result.size, & self .mt19937 , NULL )
386
+ event_ref = func(self .get_queue_ref() , result.get_data(), low, high, result.size, self .get_mt19937() , NULL )
373
387
374
388
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
375
389
c_dpctl.DPCTLEvent_Delete(event_ref)
0 commit comments