@@ -59,26 +59,27 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
59
59
q.submit ([&](handler &cgh) {
60
60
auto accA = bufA.get_access <access::mode::read_write>(cgh);
61
61
62
- cgh.parallel_for <class add_matrix >(r, [accA](nd_item<2 > spmd_item) {
63
- const auto global_idx = spmd_item.get_global_id (0 );
64
- const auto global_idy = spmd_item.get_global_id (1 );
65
- const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
66
- const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
67
-
68
- ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
69
- joint_matrix<T, TM, TK> sub_a (sg);
70
-
71
- joint_matrix_fill (sg, sub_a, 5.0 );
72
-
73
- auto wi_slice_a = sub_a.get_wi_data ();
74
- for (int i = 0 ; i < wi_slice_a.length (); i++) {
75
- wi_slice_a[i] = wi_slice_a[i] + 2 ;
76
- }
77
- joint_matrix_store (sg, sub_a,
78
- accA.get_pointer () + (sg_startx * TM) * N +
79
- sg_starty / SG_SZ * TN,
80
- N, matrix_layout::row_major);
81
- }); // parallel for
62
+ cgh.parallel_for <class add_matrix >(
63
+ r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
64
+ const auto global_idx = spmd_item.get_global_id (0 );
65
+ const auto global_idy = spmd_item.get_global_id (1 );
66
+ const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
67
+ const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
68
+
69
+ ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
70
+ joint_matrix<T, TM, TK> sub_a (sg);
71
+
72
+ joint_matrix_fill (sg, sub_a, 5.0 );
73
+
74
+ auto wi_slice_a = sub_a.get_wi_data ();
75
+ for (int i = 0 ; i < wi_slice_a.length (); i++) {
76
+ wi_slice_a[i] = wi_slice_a[i] + 2 ;
77
+ }
78
+ joint_matrix_store (sg, sub_a,
79
+ accA.get_pointer () + (sg_startx * TM) * N +
80
+ sg_starty / SG_SZ * TN,
81
+ N, matrix_layout::row_major);
82
+ }); // parallel for
82
83
}).wait ();
83
84
assert_ops_ref<T, M, N>(bufA.get_access <access::mode::read>(), ref);
84
85
}
@@ -91,26 +92,27 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
91
92
q.submit ([&](handler &cgh) {
92
93
auto accA = bufA.get_access <access::mode::read_write>(cgh);
93
94
94
- cgh.parallel_for <class sub_matrix >(r, [accA](nd_item<2 > spmd_item) {
95
- const auto global_idx = spmd_item.get_global_id (0 );
96
- const auto global_idy = spmd_item.get_global_id (1 );
97
- const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
98
- const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
99
-
100
- ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
101
- joint_matrix<T, TM, TK> sub_a (sg);
102
-
103
- joint_matrix_fill (sg, sub_a, 5.0 );
104
-
105
- auto wi_slice_a = sub_a.get_wi_data ();
106
- for (int i = 0 ; i < wi_slice_a.length (); i++) {
107
- wi_slice_a[i] = wi_slice_a[i] - 2 ;
108
- }
109
- joint_matrix_store (sg, sub_a,
110
- accA.get_pointer () + (sg_startx * TM) * N +
111
- sg_starty / SG_SZ * TN,
112
- N, matrix_layout::row_major);
113
- }); // parallel for
95
+ cgh.parallel_for <class sub_matrix >(
96
+ r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
97
+ const auto global_idx = spmd_item.get_global_id (0 );
98
+ const auto global_idy = spmd_item.get_global_id (1 );
99
+ const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
100
+ const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
101
+
102
+ ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
103
+ joint_matrix<T, TM, TK> sub_a (sg);
104
+
105
+ joint_matrix_fill (sg, sub_a, 5.0 );
106
+
107
+ auto wi_slice_a = sub_a.get_wi_data ();
108
+ for (int i = 0 ; i < wi_slice_a.length (); i++) {
109
+ wi_slice_a[i] = wi_slice_a[i] - 2 ;
110
+ }
111
+ joint_matrix_store (sg, sub_a,
112
+ accA.get_pointer () + (sg_startx * TM) * N +
113
+ sg_starty / SG_SZ * TN,
114
+ N, matrix_layout::row_major);
115
+ }); // parallel for
114
116
}).wait ();
115
117
assert_ops_ref<T, M, N>(bufA.get_access <access::mode::read>(), ref);
116
118
}
@@ -123,26 +125,27 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
123
125
q.submit ([&](handler &cgh) {
124
126
auto accA = bufA.get_access <access::mode::read_write>(cgh);
125
127
126
- cgh.parallel_for <class mul_matrix >(r, [accA](nd_item<2 > spmd_item) {
127
- const auto global_idx = spmd_item.get_global_id (0 );
128
- const auto global_idy = spmd_item.get_global_id (1 );
129
- const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
130
- const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
131
-
132
- ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
133
- joint_matrix<T, TM, TK> sub_a (sg);
134
-
135
- joint_matrix_fill (sg, sub_a, 5.0 );
136
-
137
- auto wi_slice_a = sub_a.get_wi_data ();
138
- for (int i = 0 ; i < wi_slice_a.length (); i++) {
139
- wi_slice_a[i] = wi_slice_a[i] * 3.0 ;
140
- }
141
- joint_matrix_store (sg, sub_a,
142
- accA.get_pointer () + (sg_startx * TM) * N +
143
- sg_starty / SG_SZ * TN,
144
- N, matrix_layout::row_major);
145
- }); // parallel for
128
+ cgh.parallel_for <class mul_matrix >(
129
+ r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
130
+ const auto global_idx = spmd_item.get_global_id (0 );
131
+ const auto global_idy = spmd_item.get_global_id (1 );
132
+ const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
133
+ const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
134
+
135
+ ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
136
+ joint_matrix<T, TM, TK> sub_a (sg);
137
+
138
+ joint_matrix_fill (sg, sub_a, 5.0 );
139
+
140
+ auto wi_slice_a = sub_a.get_wi_data ();
141
+ for (int i = 0 ; i < wi_slice_a.length (); i++) {
142
+ wi_slice_a[i] = wi_slice_a[i] * 3.0 ;
143
+ }
144
+ joint_matrix_store (sg, sub_a,
145
+ accA.get_pointer () + (sg_startx * TM) * N +
146
+ sg_starty / SG_SZ * TN,
147
+ N, matrix_layout::row_major);
148
+ }); // parallel for
146
149
}).wait ();
147
150
assert_ops_ref<T, M, N>(bufA.get_access <access::mode::read>(), ref);
148
151
}
@@ -155,26 +158,27 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
155
158
q.submit ([&](handler &cgh) {
156
159
auto accA = bufA.get_access <access::mode::read_write>(cgh);
157
160
158
- cgh.parallel_for <class div_matrix >(r, [accA](nd_item<2 > spmd_item) {
159
- const auto global_idx = spmd_item.get_global_id (0 );
160
- const auto global_idy = spmd_item.get_global_id (1 );
161
- const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
162
- const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
163
-
164
- ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
165
- joint_matrix<T, TM, TK> sub_a (sg);
166
-
167
- joint_matrix_fill (sg, sub_a, 4.0 );
168
-
169
- auto wi_slice_a = sub_a.get_wi_data ();
170
- for (int i = 0 ; i < wi_slice_a.length (); i++) {
171
- wi_slice_a[i] = wi_slice_a[i] / 2.0 ;
172
- }
173
- joint_matrix_store (sg, sub_a,
174
- accA.get_pointer () + (sg_startx * TM) * N +
175
- sg_starty / SG_SZ * TN,
176
- N, matrix_layout::row_major);
177
- }); // parallel for
161
+ cgh.parallel_for <class div_matrix >(
162
+ r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
163
+ const auto global_idx = spmd_item.get_global_id (0 );
164
+ const auto global_idy = spmd_item.get_global_id (1 );
165
+ const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
166
+ const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
167
+
168
+ ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
169
+ joint_matrix<T, TM, TK> sub_a (sg);
170
+
171
+ joint_matrix_fill (sg, sub_a, 4.0 );
172
+
173
+ auto wi_slice_a = sub_a.get_wi_data ();
174
+ for (int i = 0 ; i < wi_slice_a.length (); i++) {
175
+ wi_slice_a[i] = wi_slice_a[i] / 2.0 ;
176
+ }
177
+ joint_matrix_store (sg, sub_a,
178
+ accA.get_pointer () + (sg_startx * TM) * N +
179
+ sg_starty / SG_SZ * TN,
180
+ N, matrix_layout::row_major);
181
+ }); // parallel for
178
182
}).wait ();
179
183
assert_ops_ref<T, M, N>(bufA.get_access <access::mode::read>(), ref);
180
184
}
@@ -187,42 +191,43 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
187
191
q.submit ([&](handler &cgh) {
188
192
auto accA = bufA.get_access <access::mode::read_write>(cgh);
189
193
190
- cgh.parallel_for <class logic_matrix >(r, [accA](nd_item<2 > spmd_item) {
191
- const auto global_idx = spmd_item.get_global_id (0 );
192
- const auto global_idy = spmd_item.get_global_id (1 );
193
- const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
194
- const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
195
-
196
- ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
197
- joint_matrix<T, TM, TK> sub_a (sg);
198
-
199
- joint_matrix_fill (sg, sub_a, 5.0 );
200
-
201
- auto wi_slice_a = sub_a.get_wi_data ();
202
- for (int i = 0 ; i < wi_slice_a.length (); i++) {
203
- if (wi_slice_a[i]) {
204
- if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
205
- wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0 ) {
206
- T val = (wi_slice_a[i] != 2.0 ) ? wi_slice_a[i]
207
- : static_cast <half>(2.0 );
208
- val--;
209
- val++;
210
- if (wi_slice_a[i] == 2.0 ) {
211
- val -= 2 ;
212
- val *= 3.0 ;
213
- val /= 2.0 ;
214
- } else {
215
- val += 2 ;
194
+ cgh.parallel_for <class logic_matrix >(
195
+ r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
196
+ const auto global_idx = spmd_item.get_global_id (0 );
197
+ const auto global_idy = spmd_item.get_global_id (1 );
198
+ const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
199
+ const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
200
+
201
+ ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
202
+ joint_matrix<T, TM, TK> sub_a (sg);
203
+
204
+ joint_matrix_fill (sg, sub_a, 5.0 );
205
+
206
+ auto wi_slice_a = sub_a.get_wi_data ();
207
+ for (int i = 0 ; i < wi_slice_a.length (); i++) {
208
+ if (wi_slice_a[i]) {
209
+ if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
210
+ wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0 ) {
211
+ T val = (wi_slice_a[i] != 2.0 ) ? wi_slice_a[i]
212
+ : static_cast <half>(2.0 );
213
+ val--;
214
+ val++;
215
+ if (wi_slice_a[i] == 2.0 ) {
216
+ val -= 2 ;
217
+ val *= 3.0 ;
218
+ val /= 2.0 ;
219
+ } else {
220
+ val += 2 ;
221
+ }
222
+ wi_slice_a[i] = val;
223
+ }
216
224
}
217
- wi_slice_a[i] = val;
218
225
}
219
- }
220
- }
221
- joint_matrix_store (sg, sub_a,
222
- accA.get_pointer () + (sg_startx * TM) * N +
223
- sg_starty / SG_SZ * TN,
224
- N, matrix_layout::row_major);
225
- }); // parallel for
226
+ joint_matrix_store (sg, sub_a,
227
+ accA.get_pointer () + (sg_startx * TM) * N +
228
+ sg_starty / SG_SZ * TN,
229
+ N, matrix_layout::row_major);
230
+ }); // parallel for
226
231
}).wait ();
227
232
assert_ops_ref<T, M, N>(bufA.get_access <access::mode::read>(), ref);
228
233
}
0 commit comments