@@ -23,10 +23,10 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> C,
23
23
std::numeric_limits<float >::epsilon ());
24
24
}
25
25
}
26
- template <typename T, typename Ts, size_t M, size_t N >
27
- void matrix_verify_add (queue q, big_matrix<Ts, M, N > &A, nd_range<2 > &r,
26
+ template <typename T, typename Ts, size_t M, size_t K >
27
+ void matrix_verify_add (queue q, big_matrix<Ts, M, K > &A, nd_range<2 > &r,
28
28
const float ref) {
29
- buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, N ));
29
+ buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, K ));
30
30
31
31
q.submit ([&](handler &cgh) {
32
32
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -52,17 +52,17 @@ void matrix_verify_add(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
52
52
ext::intel::experimental::matrix::joint_matrix_store (
53
53
sg, sub_a,
54
54
accA.template get_multi_ptr <access::decorated::no>() +
55
- (sg_startx * TM) * N + sg_starty / SG_SZ * TK,
56
- N );
55
+ (sg_startx * TM) * K + sg_starty / SG_SZ * TK,
56
+ K );
57
57
}); // parallel for
58
58
}).wait ();
59
- assert_ops_ref<Ts, M, N >(bufA.get_host_access (sycl::read_only), ref);
59
+ assert_ops_ref<Ts, M, K >(bufA.get_host_access (sycl::read_only), ref);
60
60
}
61
61
62
- template <typename T, typename Ts, size_t M, size_t N >
63
- void matrix_verify_sub (queue q, big_matrix<Ts, M, N > &A, nd_range<2 > &r,
62
+ template <typename T, typename Ts, size_t M, size_t K >
63
+ void matrix_verify_sub (queue q, big_matrix<Ts, M, K > &A, nd_range<2 > &r,
64
64
const float ref) {
65
- buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, N ));
65
+ buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, K ));
66
66
67
67
q.submit ([&](handler &cgh) {
68
68
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -87,17 +87,17 @@ void matrix_verify_sub(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
87
87
ext::intel::experimental::matrix::joint_matrix_store (
88
88
sg, sub_a,
89
89
accA.template get_multi_ptr <access::decorated::no>() +
90
- (sg_startx * TM) * N + sg_starty / SG_SZ * TK,
91
- N );
90
+ (sg_startx * TM) * K + sg_starty / SG_SZ * TK,
91
+ K );
92
92
}); // parallel for
93
93
}).wait ();
94
- assert_ops_ref<Ts, M, N >(bufA.get_host_access (sycl::read_only), ref);
94
+ assert_ops_ref<Ts, M, K >(bufA.get_host_access (sycl::read_only), ref);
95
95
}
96
96
97
- template <typename T, typename Ts, size_t M, size_t N >
98
- void matrix_verify_mul (queue q, big_matrix<Ts, M, N > &A, nd_range<2 > &r,
97
+ template <typename T, typename Ts, size_t M, size_t K >
98
+ void matrix_verify_mul (queue q, big_matrix<Ts, M, K > &A, nd_range<2 > &r,
99
99
const float ref) {
100
- buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, N ));
100
+ buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, K ));
101
101
102
102
q.submit ([&](handler &cgh) {
103
103
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -121,17 +121,17 @@ void matrix_verify_mul(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
121
121
ext::intel::experimental::matrix::joint_matrix_store (
122
122
sg, sub_a,
123
123
accA.template get_multi_ptr <access::decorated::no>() +
124
- (sg_startx * TM) * N + sg_starty / SG_SZ * TK,
125
- N );
124
+ (sg_startx * TM) * K + sg_starty / SG_SZ * TK,
125
+ K );
126
126
}); // parallel for
127
127
}).wait ();
128
- assert_ops_ref<Ts, M, N >(bufA.get_host_access (sycl::read_only), ref);
128
+ assert_ops_ref<Ts, M, K >(bufA.get_host_access (sycl::read_only), ref);
129
129
}
130
130
131
- template <typename T, typename Ts, size_t M, size_t N >
132
- void matrix_verify_div (queue q, big_matrix<Ts, M, N > &A, nd_range<2 > &r,
131
+ template <typename T, typename Ts, size_t M, size_t K >
132
+ void matrix_verify_div (queue q, big_matrix<Ts, M, K > &A, nd_range<2 > &r,
133
133
const float ref) {
134
- buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, N ));
134
+ buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, K ));
135
135
136
136
q.submit ([&](handler &cgh) {
137
137
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -156,17 +156,17 @@ void matrix_verify_div(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
156
156
ext::intel::experimental::matrix::joint_matrix_store (
157
157
sg, sub_a,
158
158
accA.template get_multi_ptr <access::decorated::no>() +
159
- (sg_startx * TM) * N + sg_starty / SG_SZ * TK,
160
- N );
159
+ (sg_startx * TM) * K + sg_starty / SG_SZ * TK,
160
+ K );
161
161
}); // parallel for
162
162
}).wait ();
163
- assert_ops_ref<Ts, M, N >(bufA.get_host_access (sycl::read_only), ref);
163
+ assert_ops_ref<Ts, M, K >(bufA.get_host_access (sycl::read_only), ref);
164
164
}
165
165
166
- template <typename T, typename Ts, size_t M, size_t N >
167
- void matrix_verify_logic (queue q, big_matrix<Ts, M, N > &A, nd_range<2 > &r,
166
+ template <typename T, typename Ts, size_t M, size_t K >
167
+ void matrix_verify_logic (queue q, big_matrix<Ts, M, K > &A, nd_range<2 > &r,
168
168
const float ref) {
169
- buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, N ));
169
+ buffer<Ts, 2 > bufA (A.get_data (), range<2 >(M, K ));
170
170
171
171
q.submit ([&](handler &cgh) {
172
172
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -206,33 +206,33 @@ void matrix_verify_logic(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
206
206
ext::intel::experimental::matrix::joint_matrix_store (
207
207
sg, sub_a,
208
208
accA.template get_multi_ptr <access::decorated::no>() +
209
- (sg_startx * TM) * N + sg_starty / SG_SZ * TK,
210
- N );
209
+ (sg_startx * TM) * K + sg_starty / SG_SZ * TK,
210
+ K );
211
211
}); // parallel for
212
212
}).wait ();
213
- assert_ops_ref<Ts, M, N >(bufA.get_host_access (sycl::read_only), ref);
213
+ assert_ops_ref<Ts, M, K >(bufA.get_host_access (sycl::read_only), ref);
214
214
}
215
215
216
216
static constexpr size_t MATRIX_M = TM * 2 ;
217
- static constexpr size_t MATRIX_N = TN * 2 ;
218
- float A[MATRIX_M][MATRIX_N ];
219
- float D[MATRIX_M][MATRIX_N ];
217
+ static constexpr size_t MATRIX_K = TK * 2 ;
218
+ float A[MATRIX_M][MATRIX_K ];
219
+ float D[MATRIX_M][MATRIX_K ];
220
220
221
221
int main () {
222
222
223
- big_matrix<float , MATRIX_M, MATRIX_N > MD ((float *)&D);
224
- big_matrix<float , MATRIX_M, MATRIX_N > MA ((float *)&A);
223
+ big_matrix<float , MATRIX_M, MATRIX_K > MD ((float *)&D);
224
+ big_matrix<float , MATRIX_M, MATRIX_K > MA ((float *)&A);
225
225
226
226
size_t NDRangeM = MATRIX_M / TM;
227
- size_t NDRangeN = MATRIX_N / TK;
227
+ size_t NDRangeK = MATRIX_K / TK;
228
228
queue q;
229
- nd_range<2 > r ({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ});
229
+ nd_range<2 > r ({NDRangeM, NDRangeK * SG_SZ}, {1 , 1 * SG_SZ});
230
230
231
- matrix_verify_add<precision::tf32, float , MATRIX_M, MATRIX_N >(q, MA, r, 7.0 );
232
- matrix_verify_sub<precision::tf32, float , MATRIX_M, MATRIX_N >(q, MA, r, 3.0 );
233
- matrix_verify_mul<precision::tf32, float , MATRIX_M, MATRIX_N >(q, MA, r, 15.0 );
234
- matrix_verify_div<precision::tf32, float , MATRIX_M, MATRIX_N >(q, MA, r, 2.0 );
235
- matrix_verify_logic<precision::tf32, float , MATRIX_M, MATRIX_N >(q, MA, r,
231
+ matrix_verify_add<precision::tf32, float , MATRIX_M, MATRIX_K >(q, MA, r, 7.0 );
232
+ matrix_verify_sub<precision::tf32, float , MATRIX_M, MATRIX_K >(q, MA, r, 3.0 );
233
+ matrix_verify_mul<precision::tf32, float , MATRIX_M, MATRIX_K >(q, MA, r, 15.0 );
234
+ matrix_verify_div<precision::tf32, float , MATRIX_M, MATRIX_K >(q, MA, r, 2.0 );
235
+ matrix_verify_logic<precision::tf32, float , MATRIX_M, MATRIX_K >(q, MA, r,
236
236
7.0 );
237
237
238
238
return 0 ;
0 commit comments