@@ -133,35 +133,58 @@ void custom_blas_dot_c(void* array1_in, void* array2_in, void* result1, size_t s
133
133
_DataType* array_2 = reinterpret_cast <_DataType*>(array2_in);
134
134
_DataType* result = reinterpret_cast <_DataType*>(result1);
135
135
136
- _DataType* local_mem = reinterpret_cast <_DataType*>(dpnp_memory_alloc_c (size * sizeof (_DataType)));
136
+ if (!size)
137
+ {
138
+ return ;
139
+ }
137
140
138
- // what about reduction??
139
- cl::sycl::range<1 > gws (size);
140
- event = DPNP_QUEUE.submit ([&](cl::sycl::handler& cgh) {
141
- cgh.parallel_for <class custom_blas_dot_c_kernel <_DataType> >(gws, [=](cl::sycl::id<1 > global_id)
142
- {
143
- const size_t index = global_id[0 ];
144
- local_mem[index] = array_1[index] * array_2[index];
145
- } // kernel lambda
146
- ); // parallel_for
147
- } // task lambda
148
- ); // queue.submit
141
+ if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
142
+ {
143
+ event = mkl_blas::dot (DPNP_QUEUE,
144
+ size,
145
+ array_1,
146
+ 1 , // array_1 stride
147
+ array_2,
148
+ 1 , // array_2 stride
149
+ result);
150
+ event.wait ();
151
+ }
152
+ else
153
+ {
154
+ _DataType* local_mem = reinterpret_cast <_DataType*>(dpnp_memory_alloc_c (size * sizeof (_DataType)));
149
155
150
- event.wait ();
156
+ // what about reduction??
157
+ cl::sycl::range<1 > gws (size);
151
158
152
- auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel <_DataType>>(DPNP_QUEUE);
159
+ auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
160
+ const size_t index = global_id[0 ];
161
+ local_mem[index] = array_1[index] * array_2[index];
162
+ };
163
+
164
+ auto kernel_func = [&](cl::sycl::handler& cgh) {
165
+ cgh.parallel_for <class custom_blas_dot_c_kernel <_DataType>>(gws, kernel_parallel_for_func);
166
+ };
167
+
168
+ event = DPNP_QUEUE.submit (kernel_func);
153
169
154
- _DataType accumulator = 0 ;
155
- accumulator = std::reduce (policy, local_mem, local_mem + size, _DataType (0 ), std::plus<_DataType>());
156
- policy.queue ().wait ();
170
+ event.wait ();
171
+
172
+ auto policy = oneapi::dpl::execution::make_device_policy<class custom_blas_dot_c_kernel <_DataType>>(DPNP_QUEUE);
173
+
174
+ _DataType accumulator = 0 ;
175
+ accumulator = std::reduce (policy, local_mem, local_mem + size, _DataType (0 ), std::plus<_DataType>());
176
+ policy.queue ().wait ();
157
177
158
- result[0 ] = accumulator;
178
+ result[0 ] = accumulator;
159
179
160
- free (local_mem, DPNP_QUEUE);
180
+ free (local_mem, DPNP_QUEUE);
181
+ }
161
182
}
162
183
163
// Explicit instantiations of custom_blas_dot_c, one per element type listed here.
// This hunk moves the `long` instantiation after `int` and adds `float` and `double`,
// the two types that the `if constexpr` check in the function body routes to
// mkl_blas::dot; `int` and `long` use the parallel_for + std::reduce fallback branch.
- template void custom_blas_dot_c<long >(void * array1_in, void * array2_in, void * result1, size_t size);
164
184
template void custom_blas_dot_c<int >(void * array1_in, void * array2_in, void * result1, size_t size);
185
+ template void custom_blas_dot_c<long >(void * array1_in, void * array2_in, void * result1, size_t size);
186
+ template void custom_blas_dot_c<float >(void * array1_in, void * array2_in, void * result1, size_t size);
187
+ template void custom_blas_dot_c<double >(void * array1_in, void * array2_in, void * result1, size_t size);
165
188
166
189
template <typename _DataType>
167
190
void custom_lapack_syevd_c (void * array_in, void * result1, size_t size)
0 commit comments