Skip to content

Commit 84d224d

Browse files
committed
Default dtype to be dependent on fp64 support
1 parent efd9f84 commit 84d224d

File tree

7 files changed, with 147 additions and 97 deletions.

7 files changed

+147
-97
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -79,10 +79,11 @@ static inline DPCTLSyclEventRef dpnp_rng_generate(const _DistrType& distr,
7979
const int64_t size,
8080
_DataType* result) {
8181
DPCTLSyclEventRef event_ref = nullptr;
82+
sycl::event event;
8283

8384
// perform rng generation
8485
try {
85-
auto event = mkl_rng::generate<_DistrType, _EngineType>(distr, engine, size, result);
86+
event = mkl_rng::generate<_DistrType, _EngineType>(distr, engine, size, result);
8687
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
8788
} catch (const std::exception &e) {
8889
// TODO: add error reporting
@@ -1377,6 +1378,7 @@ DPCTLSyclEventRef dpnp_rng_normal_c(DPCTLSyclQueueRef q_ref,
13771378
{
13781379
// avoid warning unused variable
13791380
(void)dep_event_vec_ref;
1381+
(void)q_ref;
13801382

13811383
DPCTLSyclEventRef event_ref = nullptr;
13821384

@@ -1385,8 +1387,6 @@ DPCTLSyclEventRef dpnp_rng_normal_c(DPCTLSyclQueueRef q_ref,
13851387
return event_ref;
13861388
}
13871389

1388-
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
1389-
13901390
mt19937_struct* random_state = static_cast<mt19937_struct *>(random_state_in);
13911391
_DataType* result = static_cast<_DataType *>(result_out);
13921392

@@ -2135,7 +2135,7 @@ DPCTLSyclEventRef dpnp_rng_uniform_c(DPCTLSyclQueueRef q_ref,
21352135
return event_ref;
21362136
}
21372137

2138-
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
2138+
sycl::queue *q = reinterpret_cast<sycl::queue *>(q_ref);
21392139

21402140
mt19937_struct* random_state = static_cast<mt19937_struct *>(random_state_in);
21412141
_DataType* result = static_cast<_DataType *>(result_out);
@@ -2148,7 +2148,7 @@ DPCTLSyclEventRef dpnp_rng_uniform_c(DPCTLSyclQueueRef q_ref,
21482148
mkl_rng::mt19937 *engine = static_cast<mkl_rng::mt19937 *>(random_state->engine);
21492149

21502150
if constexpr (std::is_same<_DataType, int32_t>::value) {
2151-
if (q.get_device().has(sycl::aspect::fp64)) {
2151+
if (q->get_device().has(sycl::aspect::fp64)) {
21522152
/**
21532153
* A note from oneMKL for oneapi::mkl::rng::uniform (Discrete):
21542154
* The oneapi::mkl::rng::uniform_method::standard uses the s BRNG type on GPU devices.
@@ -2161,7 +2161,8 @@ DPCTLSyclEventRef dpnp_rng_uniform_c(DPCTLSyclQueueRef q_ref,
21612161

21622162
// perform generation
21632163
try {
2164-
auto event = mkl_rng::generate(distribution, *engine, size, result);
2164+
auto event = mkl_rng::generate<mkl_rng::uniform<_DataType, method_type>, mkl_rng::mt19937>(
2165+
distribution, *engine, size, result);
21652166
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
21662167
return DPCTLEvent_Copy(event_ref);
21672168
} catch (const oneapi::mkl::unsupported_device&) {

dpnp/backend/src/dpnp_random_state.cpp

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -30,28 +30,25 @@ namespace mkl_rng = oneapi::mkl::rng;
3030

3131
void MT19937_InitScalarSeed(mt19937_struct *mt19937, DPCTLSyclQueueRef q_ref, uint32_t seed)
3232
{
33-
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
34-
mt19937->engine = new mkl_rng::mt19937(q, seed);
35-
return;
33+
sycl::queue *q = reinterpret_cast<sycl::queue *>(q_ref);
34+
mt19937->engine = new mkl_rng::mt19937(*q, seed);
3635
}
3736

3837
void MT19937_InitVectorSeed(mt19937_struct *mt19937, DPCTLSyclQueueRef q_ref, uint32_t *seed, unsigned int n) {
39-
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
38+
sycl::queue *q = reinterpret_cast<sycl::queue *>(q_ref);
4039

4140
switch (n) {
42-
case 1: mt19937->engine = new mkl_rng::mt19937(q, {seed[0]}); break;
43-
case 2: mt19937->engine = new mkl_rng::mt19937(q, {seed[0], seed[1]}); break;
44-
case 3: mt19937->engine = new mkl_rng::mt19937(q, {seed[0], seed[1], seed[2]}); break;
41+
case 1: mt19937->engine = new mkl_rng::mt19937(*q, {seed[0]}); break;
42+
case 2: mt19937->engine = new mkl_rng::mt19937(*q, {seed[0], seed[1]}); break;
43+
case 3: mt19937->engine = new mkl_rng::mt19937(*q, {seed[0], seed[1], seed[2]}); break;
4544
default:
4645
// TODO need to get rid of the limitation for seed vector length
4746
throw std::runtime_error("Too long seed vector");
4847
}
49-
return;
5048
}
5149

5250
void MT19937_Delete(mt19937_struct *mt19937) {
53-
mkl_rng::mt19937 *engine = reinterpret_cast<mkl_rng::mt19937 *>(mt19937->engine);
51+
mkl_rng::mt19937 *engine = static_cast<mkl_rng::mt19937 *>(mt19937->engine);
5452
mt19937->engine = nullptr;
5553
delete engine;
56-
return;
5754
}

dpnp/random/dpnp_algo_random.pyx

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -404,9 +404,9 @@ cdef class MT19937:
404404
# call FPTR function
405405
event_ref = func(self.get_queue_ref(), result.get_data(), loc, scale, result.size, self.get_mt19937(), NULL)
406406

407-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
408-
c_dpctl.DPCTLEvent_Delete(event_ref)
409-
407+
if event_ref != NULL:
408+
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
409+
c_dpctl.DPCTLEvent_Delete(event_ref)
410410
return result
411411

412412

@@ -440,9 +440,9 @@ cdef class MT19937:
440440
# call FPTR function
441441
event_ref = func(self.get_queue_ref(), result.get_data(), low, high, result.size, self.get_mt19937(), NULL)
442442

443-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
444-
c_dpctl.DPCTLEvent_Delete(event_ref)
445-
443+
if event_ref != NULL:
444+
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
445+
c_dpctl.DPCTLEvent_Delete(event_ref)
446446
return result
447447

448448

dpnp/random/dpnp_iface_random.py

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -786,7 +786,8 @@ def normal(loc=0.0, scale=1.0, size=None, usm_type='device'):
786786
-----------
787787
Parameters ``loc`` and ``scale`` are supported as scalar.
788788
Otherwise, :obj:`numpy.random.normal(loc, scale, size)` samples are drawn.
789-
Output array data type is :obj:`dpnp.float64`.
789+
Output array data type is :obj:`dpnp.float64` if device supports it
790+
or :obj:`dpnp.float32` otherwise.
790791
791792
Examples
792793
--------
@@ -798,7 +799,7 @@ def normal(loc=0.0, scale=1.0, size=None, usm_type='device'):
798799
return _get_random_state().normal(loc=loc,
799800
scale=scale,
800801
size=size,
801-
dtype=dpnp.float64,
802+
dtype=None,
802803
usm_type=usm_type)
803804

804805

@@ -996,7 +997,8 @@ def rand(d0, *dn, usm_type="device"):
996997
997998
Limitations
998999
-----------
999-
Output array data type is :obj:`dpnp.float64`.
1000+
Output array data type is :obj:`dpnp.float64` if device supports it
1001+
or :obj:`dpnp.float32` otherwise.
10001002
10011003
Examples
10021004
--------
@@ -1054,7 +1056,8 @@ def randn(d0, *dn, usm_type="device"):
10541056
10551057
Limitations
10561058
-----------
1057-
Output array data type is :obj:`dpnp.float64`.
1059+
Output array data type is :obj:`dpnp.float64` if device supports it
1060+
or :obj:`dpnp.float32` otherwise.
10581061
10591062
Examples
10601063
--------
@@ -1084,7 +1087,8 @@ def random(size=None, usm_type="device"):
10841087
10851088
Limitations
10861089
-----------
1087-
Output array data type is :obj:`dpnp.float64`.
1090+
Output array data type is :obj:`dpnp.float64` if device supports it
1091+
or :obj:`dpnp.float32` otherwise.
10881092
10891093
Examples
10901094
--------
@@ -1145,7 +1149,8 @@ def random_sample(size=None, usm_type="device"):
11451149
11461150
Limitations
11471151
-----------
1148-
Output array data type is :obj:`dpnp.float64`.
1152+
Output array data type is :obj:`dpnp.float64` if device supports it
1153+
or :obj:`dpnp.float32` otherwise.
11491154
11501155
Examples
11511156
--------
@@ -1172,7 +1177,8 @@ def ranf(size=None, usm_type="device"):
11721177
11731178
Limitations
11741179
-----------
1175-
Output array data type is :obj:`dpnp.float64`.
1180+
Output array data type is :obj:`dpnp.float64` if device supports it
1181+
or :obj:`dpnp.float32` otherwise.
11761182
11771183
Examples
11781184
--------
@@ -1233,7 +1239,8 @@ def sample(size=None, usm_type="device"):
12331239
12341240
Limitations
12351241
-----------
1236-
Output array data type is :obj:`dpnp.float64`.
1242+
Output array data type is :obj:`dpnp.float64` if device supports it
1243+
or :obj:`dpnp.float32` otherwise.
12371244
12381245
Examples
12391246
--------
@@ -1407,7 +1414,8 @@ def standard_normal(size=None, usm_type="device"):
14071414
14081415
Limitations
14091416
-----------
1410-
Output array data type is :obj:`dpnp.float64`.
1417+
Output array data type is :obj:`dpnp.float64` if device supports it
1418+
or :obj:`dpnp.float32` otherwise.
14111419
14121420
Examples
14131421
--------
@@ -1508,7 +1516,8 @@ def uniform(low=0.0, high=1.0, size=None, usm_type='device'):
15081516
-----------
15091517
Parameters ``low`` and ``high`` are supported as scalar.
15101518
Otherwise, :obj:`numpy.random.uniform(low, high, size)` samples are drawn.
1511-
Output array data type is :obj:`dpnp.float64`.
1519+
Output array data type is :obj:`dpnp.float64` if device supports it
1520+
or :obj:`dpnp.float32` otherwise.
15121521
15131522
Examples
15141523
--------
@@ -1524,7 +1533,7 @@ def uniform(low=0.0, high=1.0, size=None, usm_type='device'):
15241533
return _get_random_state().uniform(low=low,
15251534
high=high,
15261535
size=size,
1527-
dtype=dpnp.float64,
1536+
dtype=None,
15281537
usm_type=usm_type)
15291538

15301539

0 commit comments

Comments
 (0)