Skip to content

Commit 8da099b

Browse files
authored
ABS add MKL kernel (#127)
1 parent c9f1c2a commit 8da099b

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

dpnp/backend/custom_kernels_mathematical.cpp

Lines changed: 24 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -50,25 +50,33 @@ void custom_elemwise_absolute_c(void* array1_in, const std::vector<long>& input_
5050
size_t* input_offset_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(input_shape_size * sizeof(long)));
5151
size_t* result_offset_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(input_shape_size * sizeof(long)));
5252

53-
cl::sycl::range<1> gws(size);
54-
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
55-
const size_t idx = global_id[0];
53+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
54+
{
55+
// https://docs.oneapi.com/versions/latest/onemkl/abs.html
56+
event = oneapi::mkl::vm::abs(DPNP_QUEUE, size, array1, result);
57+
}
58+
else
59+
{
60+
cl::sycl::range<1> gws(size);
61+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
62+
const size_t idx = global_id[0];
5663

57-
if (array1[idx] >= 0)
58-
{
59-
result[idx] = array1[idx];
60-
}
61-
else
62-
{
63-
result[idx] = -1 * array1[idx];
64-
}
65-
};
64+
if (array1[idx] >= 0)
65+
{
66+
result[idx] = array1[idx];
67+
}
68+
else
69+
{
70+
result[idx] = -1 * array1[idx];
71+
}
72+
};
6673

67-
auto kernel_func = [&](cl::sycl::handler& cgh) {
68-
cgh.parallel_for<class custom_elemwise_absolute_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
69-
};
74+
auto kernel_func = [&](cl::sycl::handler& cgh) {
75+
cgh.parallel_for<class custom_elemwise_absolute_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
76+
};
7077

71-
event = DPNP_QUEUE.submit(kernel_func);
78+
event = DPNP_QUEUE.submit(kernel_func);
79+
}
7280

7381
event.wait();
7482

0 commit comments

Comments
 (0)