Skip to content

Commit 0d1b224

Browse files
Implement SYCL kernels in noncentral_chisquare
1 parent 457a598 commit 0d1b224

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 29 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -508,6 +508,10 @@ void dpnp_rng_negative_binomial_c(void* result, const double a, const double p,
508508
event_out.wait();
509509
}
510510

511+
template <typename _KernelNameSpecialization>
512+
class dpnp_rng_noncentral_chisquare_c_kernel1;
513+
template <typename _KernelNameSpecialization>
514+
class dpnp_rng_noncentral_chisquare_c_kernel2;
511515
template <typename _DataType>
512516
void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _DataType nonc, const size_t size)
513517
{
@@ -562,14 +566,24 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
562566
event_out.wait();
563567

564568
shape = 0.5 * df;
565-
569+
566570
if (0.125 * size > sqrt(lambda))
567571
{
568572
size_t* idx = nullptr;
569573
_DataType* tmp = nullptr;
570574
idx = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(size * sizeof(size_t)));
571-
for (i = 0; i < size; i++)
575+
576+
cl::sycl::range<1> gws1(size);
577+
auto kernel_parallel_for_func1 = [=](cl::sycl::id<1> global_id) {
578+
size_t i = global_id[0];
572579
idx[i] = i;
580+
};
581+
auto kernel_func1 = [&](cl::sycl::handler& cgh) {
582+
cgh.parallel_for<class dpnp_rng_noncentral_chisquare_c_kernel1<_DataType>>(gws1,
583+
kernel_parallel_for_func1);
584+
};
585+
event_out = DPNP_QUEUE.submit(kernel_func1);
586+
event_out.wait();
573587

574588
std::sort(idx, idx + size, [pvec](size_t i1, size_t i2) { return pvec[i1] < pvec[i2]; });
575589
/* idx now contains original indexes of ordered Poisson outputs */
@@ -578,14 +592,13 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
578592
tmp = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
579593
for (i = 0; i < size;)
580594
{
581-
size_t k, j;
595+
size_t j;
582596
int cv = pvec[idx[i]];
583-
584597
// TODO vectorize
585598
for (j = i + 1; (j < size) && (pvec[idx[j]] == cv); j++)
586599
{
587600
}
588-
// assert(j > i);
601+
589602
if (j <= i)
590603
{
591604
throw std::runtime_error("DPNP RNG Error: dpnp_rng_noncentral_chisquare_c() failed.");
@@ -594,13 +607,20 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
594607
event_out = mkl_rng::generate(gamma_distribution, DPNP_RNG_ENGINE, j - i, tmp);
595608
event_out.wait();
596609

597-
// TODO vectorize
598-
for (k = i; k < j; k++)
599-
result1[idx[k]] = tmp[k - i];
610+
cl::sycl::range<1> gws2(j - i);
611+
auto kernel_parallel_for_func2 = [=](cl::sycl::id<1> global_id) {
612+
size_t index = global_id[0];
613+
result1[idx[index + i]] = tmp[index];
614+
};
615+
auto kernel_func2 = [&](cl::sycl::handler& cgh) {
616+
cgh.parallel_for<class dpnp_rng_noncentral_chisquare_c_kernel2<_DataType>>(gws2,
617+
kernel_parallel_for_func2);
618+
};
619+
event_out = DPNP_QUEUE.submit(kernel_func2);
620+
event_out.wait();
600621

601622
i = j;
602623
}
603-
604624
dpnp_memory_free_c(tmp);
605625
dpnp_memory_free_c(idx);
606626
}

0 commit comments

Comments
 (0)