Skip to content

Commit fab0312

Browse files
authored
DOT combine MKL and SYCL into one kernel (#111)
1 parent 4c99634 commit fab0312

File tree

5 files changed

+47
-101
lines changed

5 files changed

+47
-101
lines changed

dpnp/backend/backend_iface.hpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -171,22 +171,6 @@ INP_DLLEXPORT void custom_sum_c(void* array, void* result, size_t size);
171171
template <typename _DataType>
172172
INP_DLLEXPORT void custom_prod_c(void* array, void* result, size_t size);
173173

174-
/**
175-
* @ingroup BACKEND_API
176-
* @brief MKL implementation of dot function
177-
*
178-
* @param [in] array1 Input array.
179-
*
180-
* @param [in] array2 Input array.
181-
*
182-
* @param [out] result1 Output array.
183-
*
184-
* @param [in] size Number of elements in input arrays.
185-
*
186-
*/
187-
template <typename _DataType>
188-
INP_DLLEXPORT void mkl_blas_dot_c(void* array1, void* array2, void* result1, size_t size);
189-
190174
/**
191175
* @ingroup BACKEND_API
192176
* @brief MKL implementation of eig function

dpnp/backend/backend_iface_fptr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ void* get_backend_function_name(const char* func_name, const char* type_name)
9191
{
9292
if (!strncmp(type_name, supported_type1_name, strlen(supported_type1_name)))
9393
{
94-
return reinterpret_cast<void*>(mkl_blas_dot_c<double>);
94+
return reinterpret_cast<void*>(custom_blas_dot_c<double>);
9595
}
9696
else if (!strncmp(type_name, supported_type2_name, strlen(supported_type2_name)))
9797
{
98-
return reinterpret_cast<void*>(mkl_blas_dot_c<float>);
98+
return reinterpret_cast<void*>(custom_blas_dot_c<float>);
9999
}
100100
else if (!strncmp(type_name, supported_type3_name, strlen(supported_type3_name)))
101101
{
@@ -312,8 +312,8 @@ static func_map_t func_map_init()
312312

313313
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_INT] = {eft_INT, (void*)custom_blas_dot_c<int>};
314314
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_LNG] = {eft_LNG, (void*)custom_blas_dot_c<long>};
315-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_FLT] = {eft_FLT, (void*)mkl_blas_dot_c<float>};
316-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {eft_DBL, (void*)mkl_blas_dot_c<double>};
315+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_blas_dot_c<float>};
316+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_blas_dot_c<double>};
317317

318318
fmap[DPNPFuncName::DPNP_FN_EIG][eft_INT][eft_INT] = {eft_DBL, (void*)mkl_lapack_syevd_c<double>};
319319
fmap[DPNPFuncName::DPNP_FN_EIG][eft_LNG][eft_LNG] = {eft_DBL, (void*)mkl_lapack_syevd_c<double>};

dpnp/backend/custom_kernels.cpp

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -131,35 +131,58 @@ void custom_blas_dot_c(void* array1_in, void* array2_in, void* result1, size_t s
131131
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
132132
_DataType* result = reinterpret_cast<_DataType*>(result1);
133133

134-
_DataType* local_mem = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
134+
if (!size)
135+
{
136+
return;
137+
}
135138

136-
// what about reduction??
137-
cl::sycl::range<1> gws(size);
138-
event = DPNP_QUEUE.submit([&](cl::sycl::handler& cgh) {
139-
cgh.parallel_for<class custom_blas_dot_c_kernel<_DataType> >(gws, [=](cl::sycl::id<1> global_id)
140-
{
141-
const size_t index = global_id[0];
142-
local_mem[index] = array_1[index] * array_2[index];
143-
} // kernel lambda
144-
); // parallel_for
145-
} // task lambda
146-
); // queue.submit
139+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
140+
{
141+
event = mkl_blas::dot(DPNP_QUEUE,
142+
size,
143+
array_1,
144+
1, // array_1 stride
145+
array_2,
146+
1, // array_2 stride
147+
result);
148+
event.wait();
149+
}
150+
else
151+
{
152+
_DataType* local_mem = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
147153

148-
event.wait();
154+
// what about reduction??
155+
cl::sycl::range<1> gws(size);
149156

150-
auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel<_DataType>>(DPNP_QUEUE);
157+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
158+
const size_t index = global_id[0];
159+
local_mem[index] = array_1[index] * array_2[index];
160+
};
151161

152-
_DataType accumulator = 0;
153-
accumulator = std::reduce(policy, local_mem, local_mem + size, _DataType(0), std::plus<_DataType>());
154-
policy.queue().wait();
162+
auto kernel_func = [&](cl::sycl::handler& cgh) {
163+
cgh.parallel_for<class custom_blas_dot_c_kernel<_DataType> >(gws, kernel_parallel_for_func);
164+
};
165+
166+
event = DPNP_QUEUE.submit(kernel_func);
167+
168+
event.wait();
155169

156-
result[0] = accumulator;
170+
auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel<_DataType>>(DPNP_QUEUE);
157171

158-
free(local_mem, DPNP_QUEUE);
172+
_DataType accumulator = 0;
173+
accumulator = std::reduce(policy, local_mem, local_mem + size, _DataType(0), std::plus<_DataType>());
174+
policy.queue().wait();
175+
176+
result[0] = accumulator;
177+
178+
free(local_mem, DPNP_QUEUE);
179+
}
159180
}
160181

161-
template void custom_blas_dot_c<long>(void* array1_in, void* array2_in, void* result1, size_t size);
162182
template void custom_blas_dot_c<int>(void* array1_in, void* array2_in, void* result1, size_t size);
183+
template void custom_blas_dot_c<long>(void* array1_in, void* array2_in, void* result1, size_t size);
184+
template void custom_blas_dot_c<float>(void* array1_in, void* array2_in, void* result1, size_t size);
185+
template void custom_blas_dot_c<double>(void* array1_in, void* array2_in, void* result1, size_t size);
163186

164187
#if 0 // Example for OpenCL kernel
165188
#include <map>

dpnp/backend/mkl_wrap_blas1.cpp

Lines changed: 0 additions & 60 deletions
This file was deleted.

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@
272272
"dpnp/backend/custom_kernels_sorting.cpp",
273273
"dpnp/backend/custom_kernels_statistics.cpp",
274274
"dpnp/backend/memory_sycl.cpp",
275-
"dpnp/backend/mkl_wrap_blas1.cpp",
276275
"dpnp/backend/mkl_wrap_lapack.cpp",
277276
"dpnp/backend/mkl_wrap_rng.cpp",
278277
"dpnp/backend/queue_sycl.cpp"

0 commit comments

Comments
 (0)