|
41 | 41 | using dpctl::tensor::kernels::alignment_utils::is_aligned;
|
42 | 42 | using dpctl::tensor::kernels::alignment_utils::required_alignment;
|
43 | 43 |
|
| 44 | +using sycl::ext::oneapi::experimental::group_load; |
| 45 | +using sycl::ext::oneapi::experimental::group_store; |
| 46 | + |
44 | 47 | template <typename T>
|
45 | 48 | constexpr T dispatch_erf_op(T elem)
|
46 | 49 | {
|
@@ -523,41 +526,49 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
|
523 | 526 | _DataType_input2, \
|
524 | 527 | _DataType_output>) \
|
525 | 528 | { \
|
526 |
| - sycl::vec<_DataType_input1, vec_sz> x1 = \ |
527 |
| - sg.load<vec_sz>(input1_multi_ptr); \ |
528 |
| - sycl::vec<_DataType_input2, vec_sz> x2 = \ |
529 |
| - sg.load<vec_sz>(input2_multi_ptr); \ |
| 529 | + sycl::vec<_DataType_input1, vec_sz> x1{}; \ |
| 530 | + sycl::vec<_DataType_input2, vec_sz> x2{}; \ |
| 531 | + \ |
| 532 | + group_load(sg, input1_multi_ptr, x1); \ |
| 533 | + group_load(sg, input2_multi_ptr, x2); \ |
530 | 534 | \
|
531 | 535 | res_vec = __vec_operation__; \
|
532 | 536 | } \
|
533 | 537 | else /* input types don't match result type, so \
|
534 | 538 | explicit casting is required */ \
|
535 | 539 | { \
|
| 540 | + sycl::vec<_DataType_input1, vec_sz> tmp_x1{}; \ |
| 541 | + sycl::vec<_DataType_input2, vec_sz> tmp_x2{}; \ |
| 542 | + \ |
| 543 | + group_load(sg, input1_multi_ptr, tmp_x1); \ |
| 544 | + group_load(sg, input2_multi_ptr, tmp_x2); \ |
| 545 | + \ |
536 | 546 | sycl::vec<_DataType_output, vec_sz> x1 = \
|
537 | 547 | dpnp_vec_cast<_DataType_output, \
|
538 | 548 | _DataType_input1, vec_sz>( \
|
539 |
| - sg.load<vec_sz>(input1_multi_ptr)); \ |
| 549 | + tmp_x1); \ |
540 | 550 | sycl::vec<_DataType_output, vec_sz> x2 = \
|
541 | 551 | dpnp_vec_cast<_DataType_output, \
|
542 | 552 | _DataType_input2, vec_sz>( \
|
543 |
| - sg.load<vec_sz>(input2_multi_ptr)); \ |
| 553 | + tmp_x2); \ |
544 | 554 | \
|
545 | 555 | res_vec = __vec_operation__; \
|
546 | 556 | } \
|
547 | 557 | } \
|
548 | 558 | else { \
|
549 |
| - sycl::vec<_DataType_input1, vec_sz> x1 = \ |
550 |
| - sg.load<vec_sz>(input1_multi_ptr); \ |
551 |
| - sycl::vec<_DataType_input2, vec_sz> x2 = \ |
552 |
| - sg.load<vec_sz>(input2_multi_ptr); \ |
| 559 | + sycl::vec<_DataType_input1, vec_sz> x1{}; \ |
| 560 | + sycl::vec<_DataType_input2, vec_sz> x2{}; \ |
| 561 | + \ |
| 562 | + group_load(sg, input1_multi_ptr, x1); \ |
| 563 | + group_load(sg, input2_multi_ptr, x2); \ |
553 | 564 | \
|
554 | 565 | for (size_t k = 0; k < vec_sz; ++k) { \
|
555 | 566 | const _DataType_output input1_elem = x1[k]; \
|
556 | 567 | const _DataType_output input2_elem = x2[k]; \
|
557 | 568 | res_vec[k] = __operation__; \
|
558 | 569 | } \
|
559 | 570 | } \
|
560 |
| - sg.store<vec_sz>(result_multi_ptr, res_vec); \ |
| 571 | + group_store(sg, res_vec, result_multi_ptr); \ |
561 | 572 | } \
|
562 | 573 | else { \
|
563 | 574 | for (size_t k = start + sg.get_local_id()[0]; \
|
|
0 commit comments