@@ -131,35 +131,58 @@ void custom_blas_dot_c(void* array1_in, void* array2_in, void* result1, size_t s
131
131
_DataType* array_2 = reinterpret_cast <_DataType*>(array2_in);
132
132
_DataType* result = reinterpret_cast <_DataType*>(result1);
133
133
134
- _DataType* local_mem = reinterpret_cast <_DataType*>(dpnp_memory_alloc_c (size * sizeof (_DataType)));
134
+ if (!size)
135
+ {
136
+ return ;
137
+ }
135
138
136
- // what about reduction??
137
- cl::sycl::range<1 > gws (size);
138
- event = DPNP_QUEUE.submit ([&](cl::sycl::handler& cgh) {
139
- cgh.parallel_for <class custom_blas_dot_c_kernel <_DataType> >(gws, [=](cl::sycl::id<1 > global_id)
140
- {
141
- const size_t index = global_id[0 ];
142
- local_mem[index] = array_1[index] * array_2[index];
143
- } // kernel lambda
144
- ); // parallel_for
145
- } // task lambda
146
- ); // queue.submit
139
+ if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
140
+ {
141
+ event = mkl_blas::dot (DPNP_QUEUE,
142
+ size,
143
+ array_1,
144
+ 1 , // array_1 stride
145
+ array_2,
146
+ 1 , // array_2 stride
147
+ result);
148
+ event.wait ();
149
+ }
150
+ else
151
+ {
152
+ _DataType* local_mem = reinterpret_cast <_DataType*>(dpnp_memory_alloc_c (size * sizeof (_DataType)));
147
153
148
- event.wait ();
154
+ // what about reduction??
155
+ cl::sycl::range<1 > gws (size);
149
156
150
- auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel <_DataType>>(DPNP_QUEUE);
157
+ auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
158
+ const size_t index = global_id[0 ];
159
+ local_mem[index] = array_1[index] * array_2[index];
160
+ };
151
161
152
- _DataType accumulator = 0 ;
153
- accumulator = std::reduce (policy, local_mem, local_mem + size, _DataType (0 ), std::plus<_DataType>());
154
- policy.queue ().wait ();
162
+ auto kernel_func = [&](cl::sycl::handler& cgh) {
163
+ cgh.parallel_for <class custom_blas_dot_c_kernel <_DataType> >(gws, kernel_parallel_for_func);
164
+ };
165
+
166
+ event = DPNP_QUEUE.submit (kernel_func);
167
+
168
+ event.wait ();
155
169
156
- result[ 0 ] = accumulator ;
170
+ auto policy = oneapi::dpl::execution::make_device_policy< class custom_blas_dot_c_kernel <_DataType>>(DPNP_QUEUE) ;
157
171
158
- free (local_mem, DPNP_QUEUE);
172
+ _DataType accumulator = 0 ;
173
+ accumulator = std::reduce (policy, local_mem, local_mem + size, _DataType (0 ), std::plus<_DataType>());
174
+ policy.queue ().wait ();
175
+
176
+ result[0 ] = accumulator;
177
+
178
+ free (local_mem, DPNP_QUEUE);
179
+ }
159
180
}
160
181
161
// Explicit instantiations of custom_blas_dot_c for each supported element type.
// In this hunk the <long> instantiation is relocated below <int>, and <float>
// and <double> are added. Per the `if constexpr` dispatch in the function body,
// float/double are forwarded to mkl_blas::dot, while the remaining types use
// the SYCL element-wise product kernel followed by std::reduce.
- template void custom_blas_dot_c<long >(void * array1_in, void * array2_in, void * result1, size_t size);
162
182
template void custom_blas_dot_c<int >(void * array1_in, void * array2_in, void * result1, size_t size);
183
+ template void custom_blas_dot_c<long >(void * array1_in, void * array2_in, void * result1, size_t size);
184
+ template void custom_blas_dot_c<float >(void * array1_in, void * array2_in, void * result1, size_t size);
185
+ template void custom_blas_dot_c<double >(void * array1_in, void * array2_in, void * result1, size_t size);
163
186
164
187
#if 0 // Example for OpenCL kernel
165
188
#include <map>
0 commit comments