@@ -508,6 +508,10 @@ void dpnp_rng_negative_binomial_c(void* result, const double a, const double p,
508
508
event_out.wait ();
509
509
}
510
510
511
// Kernel-name tag for the first SYCL parallel_for in dpnp_rng_noncentral_chisquare_c (the one that fills idx[i] = i). Declared only; it is instantiated with _DataType at the parallel_for call site.
+ template <typename _KernelNameSpecialization>
512
+ class dpnp_rng_noncentral_chisquare_c_kernel1 ;
513
// Kernel-name tag for the second SYCL parallel_for in dpnp_rng_noncentral_chisquare_c (the one that scatters tmp[index] into result1[idx[index + i]]). Declared only; it is instantiated with _DataType at the parallel_for call site.
+ template <typename _KernelNameSpecialization>
514
+ class dpnp_rng_noncentral_chisquare_c_kernel2 ;
511
515
template <typename _DataType>
512
516
void dpnp_rng_noncentral_chisquare_c (void * result, const _DataType df, const _DataType nonc, const size_t size)
513
517
{
@@ -562,14 +566,24 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
562
566
event_out.wait ();
563
567
564
568
shape = 0.5 * df;
565
-
569
+
566
570
if (0.125 * size > sqrt (lambda))
567
571
{
568
572
size_t * idx = nullptr ;
569
573
_DataType* tmp = nullptr ;
570
574
idx = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (size * sizeof (size_t )));
571
- for (i = 0 ; i < size; i++)
575
+
576
+ cl::sycl::range<1 > gws1 (size);
577
+ auto kernel_parallel_for_func1 = [=](cl::sycl::id<1 > global_id) {
578
+ size_t i = global_id[0 ];
572
579
idx[i] = i;
580
+ };
581
+ auto kernel_func1 = [&](cl::sycl::handler& cgh) {
582
+ cgh.parallel_for <class dpnp_rng_noncentral_chisquare_c_kernel1 <_DataType>>(gws1,
583
+ kernel_parallel_for_func1);
584
+ };
585
+ event_out = DPNP_QUEUE.submit (kernel_func1);
586
+ event_out.wait ();
573
587
574
588
std::sort (idx, idx + size, [pvec](size_t i1, size_t i2) { return pvec[i1] < pvec[i2]; });
575
589
/* idx now contains original indexes of ordered Poisson outputs */
@@ -578,14 +592,13 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
578
592
tmp = reinterpret_cast <_DataType*>(dpnp_memory_alloc_c (size * sizeof (_DataType)));
579
593
for (i = 0 ; i < size;)
580
594
{
581
- size_t k, j;
595
+ size_t j;
582
596
int cv = pvec[idx[i]];
583
-
584
597
// TODO vectorize
585
598
for (j = i + 1 ; (j < size) && (pvec[idx[j]] == cv); j++)
586
599
{
587
600
}
588
- // assert(j > i);
601
+
589
602
if (j <= i)
590
603
{
591
604
throw std::runtime_error (" DPNP RNG Error: dpnp_rng_noncentral_chisquare_c() failed." );
@@ -594,13 +607,20 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
594
607
event_out = mkl_rng::generate (gamma_distribution, DPNP_RNG_ENGINE, j - i, tmp);
595
608
event_out.wait ();
596
609
597
- // TODO vectorize
598
- for (k = i; k < j; k++)
599
- result1[idx[k]] = tmp[k - i];
610
+ cl::sycl::range<1 > gws2 (j - i);
611
+ auto kernel_parallel_for_func2 = [=](cl::sycl::id<1 > global_id) {
612
+ size_t index = global_id[0 ];
613
+ result1[idx[index + i]] = tmp[index];
614
+ };
615
+ auto kernel_func2 = [&](cl::sycl::handler& cgh) {
616
+ cgh.parallel_for <class dpnp_rng_noncentral_chisquare_c_kernel2 <_DataType>>(gws2,
617
+ kernel_parallel_for_func2);
618
+ };
619
+ event_out = DPNP_QUEUE.submit (kernel_func2);
620
+ event_out.wait ();
600
621
601
622
i = j;
602
623
}
603
-
604
624
dpnp_memory_free_c (tmp);
605
625
dpnp_memory_free_c (idx);
606
626
}
0 commit comments