@@ -232,6 +232,79 @@ void dpnp_modf_c(void* array1_in, void* result1_out, void* result2_out, size_t s
232
232
event.wait ();
233
233
}
234
234
235
+ template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
236
+ class dpnp_multiply_c_kernel ;
237
+
238
+ template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
239
+ void dpnp_multiply_c (void * result_out,
240
+ const void * input1_in,
241
+ const size_t input1_size,
242
+ const size_t * input1_shape,
243
+ const size_t input1_shape_ndim,
244
+ const void * input2_in,
245
+ const size_t input2_size,
246
+ const size_t * input2_shape,
247
+ const size_t input2_shape_ndim,
248
+ const size_t * where)
249
+ {
250
+ // avoid warning unused variable
251
+ (void )input1_shape;
252
+ (void )input1_shape_ndim;
253
+ (void )input2_shape;
254
+ (void )input2_shape_ndim;
255
+ (void )where;
256
+
257
+ if (!input1_size || !input2_size)
258
+ {
259
+ return ;
260
+ }
261
+
262
+ const size_t result_size = (input2_size > input1_size) ? input2_size : input1_size;
263
+
264
+ const _DataType_input1* input1_data = reinterpret_cast <const _DataType_input1*>(input1_in);
265
+ const _DataType_input2* input2_data = reinterpret_cast <const _DataType_input2*>(input2_in);
266
+ _DataType_output* result = reinterpret_cast <_DataType_output*>(result_out);
267
+
268
+ cl::sycl::range<1 > gws (result_size);
269
+ auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
270
+ size_t i = global_id[0 ]; /* for (size_t i = 0; i < result_size; ++i)*/
271
+ {
272
+ const _DataType_input1 input1_elem = (input1_size == 1 ) ? input1_data[0 ] : input1_data[i];
273
+ const _DataType_input2 input2_elem = (input2_size == 1 ) ? input2_data[0 ] : input2_data[i];
274
+ result[i] = input1_elem * input2_elem;
275
+ }
276
+ };
277
+ auto kernel_func = [&](cl::sycl::handler& cgh) {
278
+ cgh.parallel_for <class dpnp_multiply_c_kernel <_DataType_output, _DataType_input1,
279
+ _DataType_input2>>(gws, kernel_parallel_for_func);
280
+ };
281
+
282
+ cl::sycl::event event;
283
+
284
+ if (input1_size == input2_size)
285
+ {
286
+ if constexpr ((std::is_same<_DataType_input1, double >::value ||
287
+ std::is_same<_DataType_input1, float >::value) &&
288
+ std::is_same<_DataType_input2, _DataType_input1>::value)
289
+ {
290
+ _DataType_input1* input1 = const_cast <_DataType_input1*>(input1_data);
291
+ _DataType_input2* input2 = const_cast <_DataType_input2*>(input2_data);
292
+ // https://docs.oneapi.com/versions/latest/onemkl/mul.html
293
+ event = oneapi::mkl::vm::mul (DPNP_QUEUE, result_size, input1, input2, result);
294
+ }
295
+ else
296
+ {
297
+ event = DPNP_QUEUE.submit (kernel_func);
298
+ }
299
+ }
300
+ else
301
+ {
302
+ event = DPNP_QUEUE.submit (kernel_func);
303
+ }
304
+
305
+ event.wait ();
306
+ }
307
+
235
308
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
236
309
class dpnp_remainder_c_kernel ;
237
310
@@ -411,6 +484,34 @@ void func_map_init_mathematical(func_map_t& fmap)
411
484
fmap[DPNPFuncName::DPNP_FN_MODF][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_modf_c<float , float >};
412
485
fmap[DPNPFuncName::DPNP_FN_MODF][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_modf_c<double , double >};
413
486
487
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_BLN] = {eft_BLN, (void *)dpnp_multiply_c<bool , bool , bool >};
488
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_INT] = {eft_INT, (void *)dpnp_multiply_c<int , bool , int >};
489
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_LNG] = {eft_LNG, (void *)dpnp_multiply_c<long , bool , long >};
490
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_FLT] = {eft_FLT, (void *)dpnp_multiply_c<float , bool , float >};
491
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_BLN][eft_DBL] = {eft_DBL, (void *)dpnp_multiply_c<double , bool , double >};
492
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_BLN] = {eft_INT, (void *)dpnp_multiply_c<int , int , bool >};
493
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_multiply_c<int , int , int >};
494
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_LNG] = {eft_LNG, (void *)dpnp_multiply_c<long , int , long >};
495
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_FLT] = {eft_DBL, (void *)dpnp_multiply_c<double , int , float >};
496
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_INT][eft_DBL] = {eft_DBL, (void *)dpnp_multiply_c<double , int , double >};
497
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_BLN] = {eft_LNG, (void *)dpnp_multiply_c<long , long , bool >};
498
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_INT] = {eft_LNG, (void *)dpnp_multiply_c<long , long , int >};
499
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_multiply_c<long , long , long >};
500
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_FLT] = {eft_DBL, (void *)dpnp_multiply_c<double , long , float >};
501
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_LNG][eft_DBL] = {eft_DBL, (void *)dpnp_multiply_c<double , long , double >};
502
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_BLN] = {eft_FLT, (void *)dpnp_multiply_c<float , float , bool >};
503
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_INT] = {eft_DBL, (void *)dpnp_multiply_c<double , float , int >};
504
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_LNG] = {eft_DBL, (void *)dpnp_multiply_c<double , float , long >};
505
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_multiply_c<float , float , float >};
506
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_FLT][eft_DBL] = {eft_DBL, (void *)dpnp_multiply_c<double , float , double >};
507
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_BLN] = {eft_DBL, (void *)dpnp_multiply_c<double , double , bool >};
508
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_INT] = {eft_DBL, (void *)dpnp_multiply_c<double , double , int >};
509
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_LNG] = {eft_DBL, (void *)dpnp_multiply_c<double , double , long >};
510
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_FLT] = {eft_DBL, (void *)dpnp_multiply_c<double , double , float >};
511
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_multiply_c<double , double , double >};
512
+ fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_C128] = {
513
+ eft_C128, (void *)dpnp_multiply_c<std::complex<double >, std::complex<double >, std::complex<double >>};
514
+
414
515
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_remainder_c<int , int , int >};
415
516
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_INT][eft_LNG] = {eft_LNG, (void *)dpnp_remainder_c<int , long , long >};
416
517
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_INT][eft_FLT] = {eft_DBL, (void *)dpnp_remainder_c<int , float , double >};
0 commit comments