16
16
#include < CL/sycl.hpp>
17
17
#include < limits>
18
18
#include < numeric>
19
- template <typename T> class sycl_subgr ;
19
+
20
+ template <typename T, bool UseNewSyntax> class sycl_subgr ;
20
21
using namespace cl ::sycl;
21
- template <typename T> void check (queue &Queue, size_t G = 240 , size_t L = 60 ) {
22
+ template <typename T, bool UseNewSyntax = false >
23
+ void check (queue &Queue, size_t G = 240 , size_t L = 60 ) {
22
24
try {
23
25
nd_range<1 > NdRange (G, L);
24
26
std::vector<T> data (G);
@@ -29,21 +31,26 @@ template <typename T> void check(queue &Queue, size_t G = 240, size_t L = 60) {
29
31
auto addacc = addbuf.template get_access <access::mode::read_write>(cgh);
30
32
auto sgsizeacc = sgsizebuf.get_access <access::mode::read_write>(cgh);
31
33
32
- cgh.parallel_for <sycl_subgr<T>>(NdRange, [=](nd_item<1 > NdItem) {
33
- ONEAPI::sub_group SG = NdItem.get_sub_group ();
34
- size_t lid = SG.get_local_id ().get (0 );
35
- size_t gid = NdItem.get_global_id (0 );
36
- size_t SGoff = gid - lid;
34
+ cgh.parallel_for <sycl_subgr<T, UseNewSyntax>>(
35
+ NdRange, [=](nd_item<1 > NdItem) {
36
+ ONEAPI::sub_group SG = NdItem.get_sub_group ();
37
+ size_t lid = SG.get_local_id ().get (0 );
38
+ size_t gid = NdItem.get_global_id (0 );
39
+ size_t SGoff = gid - lid;
37
40
38
- T res = 0 ;
39
- for (size_t i = 0 ; i <= lid; i++) {
40
- res += addacc[SGoff + i];
41
- }
42
- SG.barrier (access::fence_space::global_space);
43
- addacc[gid] = res;
44
- if (NdItem.get_global_id (0 ) == 0 )
45
- sgsizeacc[0 ] = SG.get_max_local_range ()[0 ];
46
- });
41
+ T res = 0 ;
42
+ for (size_t i = 0 ; i <= lid; i++) {
43
+ res += addacc[SGoff + i];
44
+ }
45
+ if constexpr (UseNewSyntax) {
46
+ group_barrier (SG);
47
+ } else {
48
+ SG.barrier (access::fence_space::global_space);
49
+ }
50
+ addacc[gid] = res;
51
+ if (NdItem.get_global_id (0 ) == 0 )
52
+ sgsizeacc[0 ] = SG.get_max_local_range ()[0 ];
53
+ });
47
54
});
48
55
auto addacc = addbuf.template get_access <access::mode::read_write>();
49
56
auto sgsizeacc = sgsizebuf.get_access <access::mode::read_write>();
@@ -79,8 +86,14 @@ int main() {
79
86
check<long >(Queue);
80
87
check<unsigned long >(Queue);
81
88
check<float >(Queue);
89
+ check<int , true >(Queue);
90
+ check<unsigned int , true >(Queue);
91
+ check<long , true >(Queue);
92
+ check<unsigned long , true >(Queue);
93
+ check<float , true >(Queue);
82
94
if (Queue.get_device ().has_extension (" cl_khr_fp64" )) {
83
95
check<double >(Queue);
96
+ check<double , true >(Queue);
84
97
}
85
98
std::cout << " Test passed." << std::endl;
86
99
return 0 ;
0 commit comments