Skip to content

Commit ef501e5

Browse files
committed
EIG combine existing MKL and non-existing SYCL to one kernel
1 parent 4c99634 commit ef501e5

File tree

6 files changed

+63
-86
lines changed

6 files changed

+63
-86
lines changed

dpnp/backend/backend_iface.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ INP_DLLEXPORT void mkl_blas_dot_c(void* array1, void* array2, void* result1, siz
189189

190190
/**
191191
* @ingroup BACKEND_API
192-
* @brief MKL implementation of eig function
192+
* @brief Custom implementation of eig function
193193
*
194194
* @param [in] array1 Input array.
195195
*
@@ -199,7 +199,7 @@ INP_DLLEXPORT void mkl_blas_dot_c(void* array1, void* array2, void* result1, siz
199199
*
200200
*/
201201
template <typename _DataType>
202-
INP_DLLEXPORT void mkl_lapack_syevd_c(void* array1, void* result1, size_t size);
202+
INP_DLLEXPORT void custom_lapack_syevd_c(void* array1, void* result1, size_t size);
203203

204204
/**
205205
* @ingroup BACKEND_API

dpnp/backend/backend_iface_fptr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,10 +315,10 @@ static func_map_t func_map_init()
315315
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_FLT] = {eft_FLT, (void*)mkl_blas_dot_c<float>};
316316
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {eft_DBL, (void*)mkl_blas_dot_c<double>};
317317

318-
fmap[DPNPFuncName::DPNP_FN_EIG][eft_INT][eft_INT] = {eft_DBL, (void*)mkl_lapack_syevd_c<double>};
319-
fmap[DPNPFuncName::DPNP_FN_EIG][eft_LNG][eft_LNG] = {eft_DBL, (void*)mkl_lapack_syevd_c<double>};
320-
fmap[DPNPFuncName::DPNP_FN_EIG][eft_FLT][eft_FLT] = {eft_FLT, (void*)mkl_lapack_syevd_c<float>};
321-
fmap[DPNPFuncName::DPNP_FN_EIG][eft_DBL][eft_DBL] = {eft_DBL, (void*)mkl_lapack_syevd_c<double>};
318+
fmap[DPNPFuncName::DPNP_FN_EIG][eft_INT][eft_INT] = {eft_DBL, (void*)custom_lapack_syevd_c<double>};
319+
fmap[DPNPFuncName::DPNP_FN_EIG][eft_LNG][eft_LNG] = {eft_DBL, (void*)custom_lapack_syevd_c<double>};
320+
fmap[DPNPFuncName::DPNP_FN_EIG][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_lapack_syevd_c<float>};
321+
fmap[DPNPFuncName::DPNP_FN_EIG][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_lapack_syevd_c<double>};
322322

323323
fmap[DPNPFuncName::DPNP_FN_EXP][eft_INT][eft_INT] = {eft_DBL, (void*)custom_elemwise_exp_c<int, double>};
324324
fmap[DPNPFuncName::DPNP_FN_EXP][eft_LNG][eft_LNG] = {eft_DBL, (void*)custom_elemwise_exp_c<long, double>};

dpnp/backend/custom_kernels.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <cmath>
2727
#include <iostream>
2828
#include <mkl_blas_sycl.hpp>
29+
#include <mkl_lapack_sycl.hpp>
2930
#include <type_traits>
3031

3132
#include <backend_iface.hpp>
@@ -34,6 +35,7 @@
3435
#include "queue_sycl.hpp"
3536

3637
namespace mkl_blas = oneapi::mkl::blas;
38+
namespace mkl_lapack = oneapi::mkl::lapack;
3739

3840
template <typename _KernelNameSpecialization>
3941
class custom_blas_gemm_c_kernel;
@@ -161,6 +163,60 @@ void custom_blas_dot_c(void* array1_in, void* array2_in, void* result1, size_t s
161163
template void custom_blas_dot_c<long>(void* array1_in, void* array2_in, void* result1, size_t size);
162164
template void custom_blas_dot_c<int>(void* array1_in, void* array2_in, void* result1, size_t size);
163165

166+
template <typename _DataType>
167+
void custom_lapack_syevd_c(void* array_in, void* result1, size_t size)
168+
{
169+
if (!size)
170+
{
171+
return;
172+
}
173+
174+
_DataType* array = reinterpret_cast<_DataType*>(array_in);
175+
_DataType* result = reinterpret_cast<_DataType*>(result1);
176+
177+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
178+
{
179+
cl::sycl::event event;
180+
181+
_DataType* syevd_array = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * size * sizeof(_DataType)));
182+
dpnp_memory_memcpy_c(syevd_array, array, size * size * sizeof(_DataType));
183+
184+
const std::int64_t lda = std::max<size_t>(1UL, size);
185+
186+
const std::int64_t scratchpad_size = mkl_lapack::syevd_scratchpad_size<_DataType>(
187+
DPNP_QUEUE, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, size, lda);
188+
189+
_DataType* scratchpad = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(scratchpad_size * sizeof(_DataType)));
190+
191+
event = mkl_lapack::syevd(DPNP_QUEUE, // queue
192+
oneapi::mkl::job::vec, // jobz
193+
oneapi::mkl::uplo::upper, // uplo
194+
size, // The order of the matrix A (0≤n)
195+
syevd_array, // will be overwritten with eigenvectors
196+
lda,
197+
result,
198+
scratchpad,
199+
scratchpad_size);
200+
event.wait();
201+
202+
dpnp_memory_free_c(scratchpad);
203+
204+
custom_elemwise_transpose_c<_DataType>(
205+
syevd_array, {(long)size, (long)size}, {(long)size, (long)size}, {1, 0}, array, size * size);
206+
207+
dpnp_memory_free_c(syevd_array);
208+
}
209+
else
210+
{
211+
// TODO: implement SYCL kernel for int/long input
212+
}
213+
}
214+
215+
// Explicit instantiations of custom_lapack_syevd_c for the supported element types.
// Only float and double reach the oneMKL syevd path; the int and long variants
// currently compile to the empty fallback branch.
template void custom_lapack_syevd_c<int>(void* array_in, void* result1, size_t size);
template void custom_lapack_syevd_c<long>(void* array_in, void* result1, size_t size);
template void custom_lapack_syevd_c<float>(void* array_in, void* result1, size_t size);
template void custom_lapack_syevd_c<double>(void* array_in, void* result1, size_t size);
219+
164220
#if 0 // Example for OpenCL kernel
165221
#include <map>
166222
#include <typeindex>

dpnp/backend/mkl_wrap_lapack.cpp

Lines changed: 0 additions & 78 deletions
This file was deleted.

examples/example7.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ int main(int, char**)
2828
array[size * i + i] = i + 1;
2929
}
3030

31-
mkl_lapack_syevd_c<double>(array, result, size);
31+
custom_lapack_syevd_c<double>(array, result, size);
3232

3333
std::cout << "eigen values" << std::endl;
3434
for (size_t i = 0; i < size; ++i)

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@
273273
"dpnp/backend/custom_kernels_statistics.cpp",
274274
"dpnp/backend/memory_sycl.cpp",
275275
"dpnp/backend/mkl_wrap_blas1.cpp",
276-
"dpnp/backend/mkl_wrap_lapack.cpp",
277276
"dpnp/backend/mkl_wrap_rng.cpp",
278277
"dpnp/backend/queue_sycl.cpp"
279278
],

0 commit comments

Comments
 (0)