41
41
using dpctl::tensor::kernels::alignment_utils::is_aligned;
42
42
using dpctl::tensor::kernels::alignment_utils::required_alignment;
43
43
44
- using sycl::ext::oneapi::experimental::group_load;
45
- using sycl::ext::oneapi::experimental::group_store;
44
+ namespace syclex = sycl::ext::oneapi::experimental;
45
+ using syclex::group_load;
46
+ using syclex::group_store;
47
+
48
+ constexpr auto striped = syclex::properties{syclex::data_placement_striped};
46
49
47
50
template <typename T>
48
51
constexpr T dispatch_erf_op (T elem)
@@ -529,8 +532,8 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
529
532
sycl::vec<_DataType_input1, vec_sz> x1{}; \
530
533
sycl::vec<_DataType_input2, vec_sz> x2{}; \
531
534
\
532
- group_load (sg, input1_multi_ptr, x1); \
533
- group_load (sg, input2_multi_ptr, x2); \
535
+ group_load (sg, input1_multi_ptr, x1, striped); \
536
+ group_load (sg, input2_multi_ptr, x2, striped); \
534
537
\
535
538
res_vec = __vec_operation__; \
536
539
} \
@@ -540,8 +543,10 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
540
543
sycl::vec<_DataType_input1, vec_sz> tmp_x1{}; \
541
544
sycl::vec<_DataType_input2, vec_sz> tmp_x2{}; \
542
545
\
543
- group_load (sg, input1_multi_ptr, tmp_x1); \
544
- group_load (sg, input2_multi_ptr, tmp_x2); \
546
+ group_load (sg, input1_multi_ptr, tmp_x1, \
547
+ striped); \
548
+ group_load (sg, input2_multi_ptr, tmp_x2, \
549
+ striped); \
545
550
\
546
551
sycl::vec<_DataType_output, vec_sz> x1 = \
547
552
dpnp_vec_cast<_DataType_output, \
@@ -559,16 +564,16 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
559
564
sycl::vec<_DataType_input1, vec_sz> x1{}; \
560
565
sycl::vec<_DataType_input2, vec_sz> x2{}; \
561
566
\
562
- group_load (sg, input1_multi_ptr, x1); \
563
- group_load (sg, input2_multi_ptr, x2); \
567
+ group_load (sg, input1_multi_ptr, x1, striped); \
568
+ group_load (sg, input2_multi_ptr, x2, striped); \
564
569
\
565
570
for (size_t k = 0 ; k < vec_sz; ++k) { \
566
571
const _DataType_output input1_elem = x1[k]; \
567
572
const _DataType_output input2_elem = x2[k]; \
568
573
res_vec[k] = __operation__; \
569
574
} \
570
575
} \
571
- group_store (sg, res_vec, result_multi_ptr); \
576
+ group_store (sg, res_vec, result_multi_ptr, striped); \
572
577
} \
573
578
else { \
574
579
for (size_t k = start + sg.get_local_id ()[0 ]; \
0 commit comments