Skip to content

Commit 0c7de14

Browse files
LukichevaPolinadensmirnAlexander-Makaryev
authored
Implement SYCL kernels in noncentral_chisquare (#1054)
* Implement SYCL kernels in noncentral_chisquare Co-authored-by: densmirn <[email protected]> Co-authored-by: Alexander-Makaryev <[email protected]>
1 parent 838f0cf commit 0c7de14

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -486,14 +486,18 @@ void dpnp_rng_negative_binomial_c(void* result, const double a, const double p,
486486
event_out.wait();
487487
}
488488

489+
template <typename _KernelNameSpecialization>
490+
class dpnp_rng_noncentral_chisquare_c_kernel1;
491+
template <typename _KernelNameSpecialization>
492+
class dpnp_rng_noncentral_chisquare_c_kernel2;
489493
template <typename _DataType>
490494
void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _DataType nonc, const size_t size)
491495
{
492496
if (!size || !result)
493497
{
494498
return;
495499
}
496-
DPNPC_ptr_adapter<_DataType> result1_ptr(result, size, true, true);
500+
DPNPC_ptr_adapter<_DataType> result1_ptr(result, size, false, true);
497501
_DataType* result1 = result1_ptr.get_ptr();
498502

499503
const _DataType d_zero = _DataType(0.0);
@@ -540,14 +544,23 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
540544
event_out.wait();
541545

542546
shape = 0.5 * df;
543-
544547
if (0.125 * size > sqrt(lambda))
545548
{
546549
size_t* idx = nullptr;
547550
_DataType* tmp = nullptr;
548551
idx = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(size * sizeof(size_t)));
549-
for (i = 0; i < size; i++)
552+
553+
cl::sycl::range<1> gws1(size);
554+
auto kernel_parallel_for_func1 = [=](cl::sycl::id<1> global_id) {
555+
size_t i = global_id[0];
550556
idx[i] = i;
557+
};
558+
auto kernel_func1 = [&](cl::sycl::handler& cgh) {
559+
cgh.parallel_for<class dpnp_rng_noncentral_chisquare_c_kernel1<_DataType>>(gws1,
560+
kernel_parallel_for_func1);
561+
};
562+
event_out = DPNP_QUEUE.submit(kernel_func1);
563+
event_out.wait();
551564

552565
std::sort(idx, idx + size, [pvec](size_t i1, size_t i2) { return pvec[i1] < pvec[i2]; });
553566
/* idx now contains original indexes of ordered Poisson outputs */
@@ -556,14 +569,13 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
556569
tmp = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
557570
for (i = 0; i < size;)
558571
{
559-
size_t k, j;
572+
size_t j;
560573
int cv = pvec[idx[i]];
561-
562574
// TODO vectorize
563575
for (j = i + 1; (j < size) && (pvec[idx[j]] == cv); j++)
564576
{
565577
}
566-
// assert(j > i);
578+
567579
if (j <= i)
568580
{
569581
throw std::runtime_error("DPNP RNG Error: dpnp_rng_noncentral_chisquare_c() failed.");
@@ -572,13 +584,20 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
572584
event_out = mkl_rng::generate(gamma_distribution, DPNP_RNG_ENGINE, j - i, tmp);
573585
event_out.wait();
574586

575-
// TODO vectorize
576-
for (k = i; k < j; k++)
577-
result1[idx[k]] = tmp[k - i];
587+
cl::sycl::range<1> gws2(j - i);
588+
auto kernel_parallel_for_func2 = [=](cl::sycl::id<1> global_id) {
589+
size_t index = global_id[0];
590+
result1[idx[index + i]] = tmp[index];
591+
};
592+
auto kernel_func2 = [&](cl::sycl::handler& cgh) {
593+
cgh.parallel_for<class dpnp_rng_noncentral_chisquare_c_kernel2<_DataType>>(gws2,
594+
kernel_parallel_for_func2);
595+
};
596+
event_out = DPNP_QUEUE.submit(kernel_func2);
597+
event_out.wait();
578598

579599
i = j;
580600
}
581-
582601
dpnp_memory_free_c(tmp);
583602
dpnp_memory_free_c(idx);
584603
}

0 commit comments

Comments
 (0)