DOT combine MKL and SYCL into one kernel #111

Merged: 1 commit, Oct 5, 2020
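This PR removes the separate oneMKL wrapper mkl_blas_dot_c (and the mkl_wrap_blas1.cpp source that held it) and folds everything into a single custom_blas_dot_c template: for float and double it calls mkl_blas::dot, and for the remaining integer types it keeps the hand-written SYCL elementwise multiply followed by a oneDPL reduction. The snippet below is only a host-side sketch of that compile-time dispatch idea, not the dpnp code; blas_like_dot is a placeholder standing in for the oneMKL call so the example compiles on its own.

// Minimal sketch of the if-constexpr dispatch used by the new kernel.
#include <cstddef>
#include <iostream>
#include <type_traits>

// Placeholder for mkl_blas::dot; a plain loop keeps the sketch self-contained.
template <typename T>
T blas_like_dot(const T* a, const T* b, std::size_t n)
{
    T acc = T(0);
    for (std::size_t i = 0; i < n; ++i)
        acc += a[i] * b[i];
    return acc;
}

template <typename T>
T dot(const T* a, const T* b, std::size_t n)
{
    if constexpr (std::is_same<T, double>::value || std::is_same<T, float>::value)
    {
        return blas_like_dot(a, b, n); // BLAS path for floating-point types
    }
    else
    {
        T acc = T(0);                  // generic path for int, long, ...
        for (std::size_t i = 0; i < n; ++i)
            acc += a[i] * b[i];
        return acc;
    }
}

int main()
{
    float af[] = {1.f, 2.f, 3.f}, bf[] = {4.f, 5.f, 6.f};
    long  al[] = {1, 2, 3},       bl[] = {4, 5, 6};
    std::cout << dot(af, bf, 3) << " " << dot(al, bl, 3) << std::endl; // 32 32
    return 0;
}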
16 changes: 0 additions & 16 deletions dpnp/backend/backend_iface.hpp
@@ -171,22 +171,6 @@ INP_DLLEXPORT void custom_sum_c(void* array, void* result, size_t size);
template <typename _DataType>
INP_DLLEXPORT void custom_prod_c(void* array, void* result, size_t size);

/**
* @ingroup BACKEND_API
* @brief MKL implementation of dot function
*
* @param [in] array1 Input array.
*
* @param [in] array2 Input array.
*
* @param [out] result1 Output array.
*
* @param [in] size Number of elements in input arrays.
*
*/
template <typename _DataType>
INP_DLLEXPORT void mkl_blas_dot_c(void* array1, void* array2, void* result1, size_t size);

/**
* @ingroup BACKEND_API
* @brief MKL implementation of eig function
8 changes: 4 additions & 4 deletions dpnp/backend/backend_iface_fptr.cpp
@@ -91,11 +91,11 @@ void* get_backend_function_name(const char* func_name, const char* type_name)
{
if (!strncmp(type_name, supported_type1_name, strlen(supported_type1_name)))
{
return reinterpret_cast<void*>(mkl_blas_dot_c<double>);
return reinterpret_cast<void*>(custom_blas_dot_c<double>);
}
else if (!strncmp(type_name, supported_type2_name, strlen(supported_type2_name)))
{
return reinterpret_cast<void*>(mkl_blas_dot_c<float>);
return reinterpret_cast<void*>(custom_blas_dot_c<float>);
}
else if (!strncmp(type_name, supported_type3_name, strlen(supported_type3_name)))
{
@@ -312,8 +312,8 @@ static func_map_t func_map_init()

fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_INT] = {eft_INT, (void*)custom_blas_dot_c<int>};
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_LNG] = {eft_LNG, (void*)custom_blas_dot_c<long>};
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_FLT] = {eft_FLT, (void*)mkl_blas_dot_c<float>};
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {eft_DBL, (void*)mkl_blas_dot_c<double>};
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_blas_dot_c<float>};
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_blas_dot_c<double>};

fmap[DPNPFuncName::DPNP_FN_EIG][eft_INT][eft_INT] = {eft_DBL, (void*)mkl_lapack_syevd_c<double>};
fmap[DPNPFuncName::DPNP_FN_EIG][eft_LNG][eft_LNG] = {eft_DBL, (void*)mkl_lapack_syevd_c<double>};
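With the wrapper gone, every DPNP_FN_DOT entry in the function map points at the same custom_blas_dot_c template, and get_backend_function_name resolves float/double requests to it as well. As a rough illustration of this dispatch-table pattern (simplified stand-ins, not the real func_map_t, DPNPFuncName, or eft_* types), such a table can be modelled as a map from (function, input element types) to a result type plus a type-erased function pointer:

// Illustrative sketch of the type-keyed function-pointer table; names are placeholders.
#include <cstddef>
#include <map>
#include <tuple>

enum class FuncName { DOT };                   // stands in for DPNPFuncName
enum ElemType { e_int, e_lng, e_flt, e_dbl };  // stands in for the eft_* constants

struct FuncEntry
{
    ElemType return_type;
    void*    ptr;  // cast back to the typed kernel at the call site
};

template <typename T>
void dot_kernel(void* a, void* b, void* res, std::size_t n)
{
    const T* x = static_cast<const T*>(a);
    const T* y = static_cast<const T*>(b);
    T acc = T(0);
    for (std::size_t i = 0; i < n; ++i)
        acc += x[i] * y[i];
    *static_cast<T*>(res) = acc;
}

int main()
{
    std::map<std::tuple<FuncName, ElemType, ElemType>, FuncEntry> fmap;
    // After this PR the float/double DOT entries also point at the custom kernel.
    fmap[{FuncName::DOT, e_flt, e_flt}] = {e_flt, reinterpret_cast<void*>(dot_kernel<float>)};
    fmap[{FuncName::DOT, e_dbl, e_dbl}] = {e_dbl, reinterpret_cast<void*>(dot_kernel<double>)};

    float a[] = {1.f, 2.f}, b[] = {3.f, 4.f}, out = 0.f;
    auto fn = reinterpret_cast<void (*)(void*, void*, void*, std::size_t)>(
        fmap.at({FuncName::DOT, e_flt, e_flt}).ptr);
    fn(a, b, &out, 2);  // out == 11
    return 0;
}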
63 changes: 43 additions & 20 deletions dpnp/backend/custom_kernels.cpp
@@ -131,35 +131,58 @@ void custom_blas_dot_c(void* array1_in, void* array2_in, void* result1, size_t size)
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
_DataType* result = reinterpret_cast<_DataType*>(result1);

_DataType* local_mem = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
if (!size)
{
return;
}

// what about reduction??
cl::sycl::range<1> gws(size);
event = DPNP_QUEUE.submit([&](cl::sycl::handler& cgh) {
cgh.parallel_for<class custom_blas_dot_c_kernel<_DataType> >(gws, [=](cl::sycl::id<1> global_id)
{
const size_t index = global_id[0];
local_mem[index] = array_1[index] * array_2[index];
} // kernel lambda
); // parallel_for
} // task lambda
); // queue.submit
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
{
event = mkl_blas::dot(DPNP_QUEUE,
size,
array_1,
1, // array_1 stride
array_2,
1, // array_2 stride
result);
event.wait();
}
else
{
_DataType* local_mem = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));

event.wait();
// what about reduction??
cl::sycl::range<1> gws(size);

auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel<_DataType>>(DPNP_QUEUE);
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
const size_t index = global_id[0];
local_mem[index] = array_1[index] * array_2[index];
};

_DataType accumulator = 0;
accumulator = std::reduce(policy, local_mem, local_mem + size, _DataType(0), std::plus<_DataType>());
policy.queue().wait();
auto kernel_func = [&](cl::sycl::handler& cgh) {
cgh.parallel_for<class custom_blas_dot_c_kernel<_DataType> >(gws, kernel_parallel_for_func);
};

event = DPNP_QUEUE.submit(kernel_func);

event.wait();

result[0] = accumulator;
auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel<_DataType>>(DPNP_QUEUE);

free(local_mem, DPNP_QUEUE);
_DataType accumulator = 0;
accumulator = std::reduce(policy, local_mem, local_mem + size, _DataType(0), std::plus<_DataType>());
policy.queue().wait();

result[0] = accumulator;

free(local_mem, DPNP_QUEUE);
}
}

template void custom_blas_dot_c<long>(void* array1_in, void* array2_in, void* result1, size_t size);
template void custom_blas_dot_c<int>(void* array1_in, void* array2_in, void* result1, size_t size);
template void custom_blas_dot_c<long>(void* array1_in, void* array2_in, void* result1, size_t size);
template void custom_blas_dot_c<float>(void* array1_in, void* array2_in, void* result1, size_t size);
template void custom_blas_dot_c<double>(void* array1_in, void* array2_in, void* result1, size_t size);

#if 0 // Example for OpenCL kernel
#include <map>
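For element types that mkl_blas::dot does not cover, the else branch above multiplies the inputs elementwise on the device and then reduces the products with a oneDPL device policy. The fragment below is a stripped-down, self-contained version of that pattern; it is a sketch, not the dpnp code, and it assumes a DPC++/oneDPL toolchain, using shared USM in place of dpnp_memory_alloc_c and a local queue in place of DPNP_QUEUE.

// Stand-alone sketch of the SYCL + oneDPL fallback path (assumes DPC++ and oneDPL).
#include <oneapi/dpl/execution>
#include <oneapi/dpl/numeric>
#include <CL/sycl.hpp>
#include <cstddef>
#include <functional>
#include <iostream>

int main()
{
    cl::sycl::queue q;                                // stands in for DPNP_QUEUE
    const std::size_t n = 1024;

    long* a   = cl::sycl::malloc_shared<long>(n, q);
    long* b   = cl::sycl::malloc_shared<long>(n, q);
    long* tmp = cl::sycl::malloc_shared<long>(n, q);  // plays the role of local_mem
    for (std::size_t i = 0; i < n; ++i) { a[i] = 1; b[i] = 2; }

    // Elementwise multiply on the device...
    q.parallel_for(cl::sycl::range<1>(n),
                   [=](cl::sycl::id<1> i) { tmp[i] = a[i] * b[i]; }).wait();

    // ...then reduce the products with a oneDPL device policy, as the kernel does.
    auto policy = oneapi::dpl::execution::make_device_policy(q);
    long result = std::reduce(policy, tmp, tmp + n, 0L, std::plus<long>());

    std::cout << result << std::endl;                 // 2048

    cl::sycl::free(tmp, q);
    cl::sycl::free(b, q);
    cl::sycl::free(a, q);
    return 0;
}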
60 changes: 0 additions & 60 deletions dpnp/backend/mkl_wrap_blas1.cpp

This file was deleted.

1 change: 0 additions & 1 deletion setup.py
@@ -272,7 +272,6 @@
"dpnp/backend/custom_kernels_sorting.cpp",
"dpnp/backend/custom_kernels_statistics.cpp",
"dpnp/backend/memory_sycl.cpp",
"dpnp/backend/mkl_wrap_blas1.cpp",
"dpnp/backend/mkl_wrap_lapack.cpp",
"dpnp/backend/mkl_wrap_rng.cpp",
"dpnp/backend/queue_sycl.cpp"