Skip to content

Commit 5f14466

Browse files
authored
Merge branch 'master' into feature/eig_kernel
2 parents ef501e5 + d8e4d24 commit 5f14466

11 files changed

+290
-265
lines changed

dpnp/backend/backend_iface.hpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -171,22 +171,6 @@ INP_DLLEXPORT void custom_sum_c(void* array, void* result, size_t size);
171171
template <typename _DataType>
172172
INP_DLLEXPORT void custom_prod_c(void* array, void* result, size_t size);
173173

174-
/**
175-
* @ingroup BACKEND_API
176-
* @brief MKL implementation of dot function
177-
*
178-
* @param [in] array1 Input array.
179-
*
180-
* @param [in] array2 Input array.
181-
*
182-
* @param [out] result1 Output array.
183-
*
184-
* @param [in] size Number of elements in input arrays.
185-
*
186-
*/
187-
template <typename _DataType>
188-
INP_DLLEXPORT void mkl_blas_dot_c(void* array1, void* array2, void* result1, size_t size);
189-
190174
/**
191175
* @ingroup BACKEND_API
192176
* @brief Custom implementation of eig function

dpnp/backend/backend_iface_fptr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ void* get_backend_function_name(const char* func_name, const char* type_name)
9191
{
9292
if (!strncmp(type_name, supported_type1_name, strlen(supported_type1_name)))
9393
{
94-
return reinterpret_cast<void*>(mkl_blas_dot_c<double>);
94+
return reinterpret_cast<void*>(custom_blas_dot_c<double>);
9595
}
9696
else if (!strncmp(type_name, supported_type2_name, strlen(supported_type2_name)))
9797
{
98-
return reinterpret_cast<void*>(mkl_blas_dot_c<float>);
98+
return reinterpret_cast<void*>(custom_blas_dot_c<float>);
9999
}
100100
else if (!strncmp(type_name, supported_type3_name, strlen(supported_type3_name)))
101101
{
@@ -312,8 +312,8 @@ static func_map_t func_map_init()
312312

313313
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_INT] = {eft_INT, (void*)custom_blas_dot_c<int>};
314314
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_LNG] = {eft_LNG, (void*)custom_blas_dot_c<long>};
315-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_FLT] = {eft_FLT, (void*)mkl_blas_dot_c<float>};
316-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {eft_DBL, (void*)mkl_blas_dot_c<double>};
315+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_blas_dot_c<float>};
316+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_blas_dot_c<double>};
317317

318318
fmap[DPNPFuncName::DPNP_FN_EIG][eft_INT][eft_INT] = {eft_DBL, (void*)custom_lapack_syevd_c<double>};
319319
fmap[DPNPFuncName::DPNP_FN_EIG][eft_LNG][eft_LNG] = {eft_DBL, (void*)custom_lapack_syevd_c<double>};

dpnp/backend/custom_kernels.cpp

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -133,35 +133,58 @@ void custom_blas_dot_c(void* array1_in, void* array2_in, void* result1, size_t s
133133
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
134134
_DataType* result = reinterpret_cast<_DataType*>(result1);
135135

136-
_DataType* local_mem = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
136+
if (!size)
137+
{
138+
return;
139+
}
137140

138-
// what about reduction??
139-
cl::sycl::range<1> gws(size);
140-
event = DPNP_QUEUE.submit([&](cl::sycl::handler& cgh) {
141-
cgh.parallel_for<class custom_blas_dot_c_kernel<_DataType> >(gws, [=](cl::sycl::id<1> global_id)
142-
{
143-
const size_t index = global_id[0];
144-
local_mem[index] = array_1[index] * array_2[index];
145-
} // kernel lambda
146-
); // parallel_for
147-
} // task lambda
148-
); // queue.submit
141+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
142+
{
143+
event = mkl_blas::dot(DPNP_QUEUE,
144+
size,
145+
array_1,
146+
1, // array_1 stride
147+
array_2,
148+
1, // array_2 stride
149+
result);
150+
event.wait();
151+
}
152+
else
153+
{
154+
_DataType* local_mem = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
149155

150-
event.wait();
156+
// what about reduction??
157+
cl::sycl::range<1> gws(size);
151158

152-
auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel<_DataType>>(DPNP_QUEUE);
159+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
160+
const size_t index = global_id[0];
161+
local_mem[index] = array_1[index] * array_2[index];
162+
};
163+
164+
auto kernel_func = [&](cl::sycl::handler& cgh) {
165+
cgh.parallel_for<class custom_blas_dot_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
166+
};
167+
168+
event = DPNP_QUEUE.submit(kernel_func);
153169

154-
_DataType accumulator = 0;
155-
accumulator = std::reduce(policy, local_mem, local_mem + size, _DataType(0), std::plus<_DataType>());
156-
policy.queue().wait();
170+
event.wait();
171+
172+
auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel<_DataType>>(DPNP_QUEUE);
173+
174+
_DataType accumulator = 0;
175+
accumulator = std::reduce(policy, local_mem, local_mem + size, _DataType(0), std::plus<_DataType>());
176+
policy.queue().wait();
157177

158-
result[0] = accumulator;
178+
result[0] = accumulator;
159179

160-
free(local_mem, DPNP_QUEUE);
180+
free(local_mem, DPNP_QUEUE);
181+
}
161182
}
162183

163-
template void custom_blas_dot_c<long>(void* array1_in, void* array2_in, void* result1, size_t size);
164184
template void custom_blas_dot_c<int>(void* array1_in, void* array2_in, void* result1, size_t size);
185+
template void custom_blas_dot_c<long>(void* array1_in, void* array2_in, void* result1, size_t size);
186+
template void custom_blas_dot_c<float>(void* array1_in, void* array2_in, void* result1, size_t size);
187+
template void custom_blas_dot_c<double>(void* array1_in, void* array2_in, void* result1, size_t size);
165188

166189
template <typename _DataType>
167190
void custom_lapack_syevd_c(void* array_in, void* result1, size_t size)

0 commit comments

Comments
 (0)