Skip to content

Commit f22d94c

Browse files
authored
[SYCL][Matrix] Update element_wise_all_ops_tf32_impl.hpp to use consistent dimension names (e.g. M, K, TM, TK). (#10598)
1 parent 5d220bc commit f22d94c

File tree

1 file changed

+42
-42
lines changed

1 file changed

+42
-42
lines changed

sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> C,
2323
std::numeric_limits<float>::epsilon());
2424
}
2525
}
26-
template <typename T, typename Ts, size_t M, size_t N>
27-
void matrix_verify_add(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
26+
template <typename T, typename Ts, size_t M, size_t K>
27+
void matrix_verify_add(queue q, big_matrix<Ts, M, K> &A, nd_range<2> &r,
2828
const float ref) {
29-
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, N));
29+
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, K));
3030

3131
q.submit([&](handler &cgh) {
3232
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -52,17 +52,17 @@ void matrix_verify_add(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
5252
ext::intel::experimental::matrix::joint_matrix_store(
5353
sg, sub_a,
5454
accA.template get_multi_ptr<access::decorated::no>() +
55-
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
56-
N);
55+
(sg_startx * TM) * K + sg_starty / SG_SZ * TK,
56+
K);
5757
}); // parallel for
5858
}).wait();
59-
assert_ops_ref<Ts, M, N>(bufA.get_host_access(sycl::read_only), ref);
59+
assert_ops_ref<Ts, M, K>(bufA.get_host_access(sycl::read_only), ref);
6060
}
6161

62-
template <typename T, typename Ts, size_t M, size_t N>
63-
void matrix_verify_sub(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
62+
template <typename T, typename Ts, size_t M, size_t K>
63+
void matrix_verify_sub(queue q, big_matrix<Ts, M, K> &A, nd_range<2> &r,
6464
const float ref) {
65-
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, N));
65+
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, K));
6666

6767
q.submit([&](handler &cgh) {
6868
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -87,17 +87,17 @@ void matrix_verify_sub(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
8787
ext::intel::experimental::matrix::joint_matrix_store(
8888
sg, sub_a,
8989
accA.template get_multi_ptr<access::decorated::no>() +
90-
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
91-
N);
90+
(sg_startx * TM) * K + sg_starty / SG_SZ * TK,
91+
K);
9292
}); // parallel for
9393
}).wait();
94-
assert_ops_ref<Ts, M, N>(bufA.get_host_access(sycl::read_only), ref);
94+
assert_ops_ref<Ts, M, K>(bufA.get_host_access(sycl::read_only), ref);
9595
}
9696

97-
template <typename T, typename Ts, size_t M, size_t N>
98-
void matrix_verify_mul(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
97+
template <typename T, typename Ts, size_t M, size_t K>
98+
void matrix_verify_mul(queue q, big_matrix<Ts, M, K> &A, nd_range<2> &r,
9999
const float ref) {
100-
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, N));
100+
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, K));
101101

102102
q.submit([&](handler &cgh) {
103103
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -121,17 +121,17 @@ void matrix_verify_mul(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
121121
ext::intel::experimental::matrix::joint_matrix_store(
122122
sg, sub_a,
123123
accA.template get_multi_ptr<access::decorated::no>() +
124-
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
125-
N);
124+
(sg_startx * TM) * K + sg_starty / SG_SZ * TK,
125+
K);
126126
}); // parallel for
127127
}).wait();
128-
assert_ops_ref<Ts, M, N>(bufA.get_host_access(sycl::read_only), ref);
128+
assert_ops_ref<Ts, M, K>(bufA.get_host_access(sycl::read_only), ref);
129129
}
130130

131-
template <typename T, typename Ts, size_t M, size_t N>
132-
void matrix_verify_div(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
131+
template <typename T, typename Ts, size_t M, size_t K>
132+
void matrix_verify_div(queue q, big_matrix<Ts, M, K> &A, nd_range<2> &r,
133133
const float ref) {
134-
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, N));
134+
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, K));
135135

136136
q.submit([&](handler &cgh) {
137137
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -156,17 +156,17 @@ void matrix_verify_div(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
156156
ext::intel::experimental::matrix::joint_matrix_store(
157157
sg, sub_a,
158158
accA.template get_multi_ptr<access::decorated::no>() +
159-
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
160-
N);
159+
(sg_startx * TM) * K + sg_starty / SG_SZ * TK,
160+
K);
161161
}); // parallel for
162162
}).wait();
163-
assert_ops_ref<Ts, M, N>(bufA.get_host_access(sycl::read_only), ref);
163+
assert_ops_ref<Ts, M, K>(bufA.get_host_access(sycl::read_only), ref);
164164
}
165165

166-
template <typename T, typename Ts, size_t M, size_t N>
167-
void matrix_verify_logic(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
166+
template <typename T, typename Ts, size_t M, size_t K>
167+
void matrix_verify_logic(queue q, big_matrix<Ts, M, K> &A, nd_range<2> &r,
168168
const float ref) {
169-
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, N));
169+
buffer<Ts, 2> bufA(A.get_data(), range<2>(M, K));
170170

171171
q.submit([&](handler &cgh) {
172172
sycl::accessor accA{bufA, cgh, sycl::read_write};
@@ -206,33 +206,33 @@ void matrix_verify_logic(queue q, big_matrix<Ts, M, N> &A, nd_range<2> &r,
206206
ext::intel::experimental::matrix::joint_matrix_store(
207207
sg, sub_a,
208208
accA.template get_multi_ptr<access::decorated::no>() +
209-
(sg_startx * TM) * N + sg_starty / SG_SZ * TK,
210-
N);
209+
(sg_startx * TM) * K + sg_starty / SG_SZ * TK,
210+
K);
211211
}); // parallel for
212212
}).wait();
213-
assert_ops_ref<Ts, M, N>(bufA.get_host_access(sycl::read_only), ref);
213+
assert_ops_ref<Ts, M, K>(bufA.get_host_access(sycl::read_only), ref);
214214
}
215215

216216
static constexpr size_t MATRIX_M = TM * 2;
217-
static constexpr size_t MATRIX_N = TN * 2;
218-
float A[MATRIX_M][MATRIX_N];
219-
float D[MATRIX_M][MATRIX_N];
217+
static constexpr size_t MATRIX_K = TK * 2;
218+
float A[MATRIX_M][MATRIX_K];
219+
float D[MATRIX_M][MATRIX_K];
220220

221221
int main() {
222222

223-
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
224-
big_matrix<float, MATRIX_M, MATRIX_N> MA((float *)&A);
223+
big_matrix<float, MATRIX_M, MATRIX_K> MD((float *)&D);
224+
big_matrix<float, MATRIX_M, MATRIX_K> MA((float *)&A);
225225

226226
size_t NDRangeM = MATRIX_M / TM;
227-
size_t NDRangeN = MATRIX_N / TK;
227+
size_t NDRangeK = MATRIX_K / TK;
228228
queue q;
229-
nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});
229+
nd_range<2> r({NDRangeM, NDRangeK * SG_SZ}, {1, 1 * SG_SZ});
230230

231-
matrix_verify_add<precision::tf32, float, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
232-
matrix_verify_sub<precision::tf32, float, MATRIX_M, MATRIX_N>(q, MA, r, 3.0);
233-
matrix_verify_mul<precision::tf32, float, MATRIX_M, MATRIX_N>(q, MA, r, 15.0);
234-
matrix_verify_div<precision::tf32, float, MATRIX_M, MATRIX_N>(q, MA, r, 2.0);
235-
matrix_verify_logic<precision::tf32, float, MATRIX_M, MATRIX_N>(q, MA, r,
231+
matrix_verify_add<precision::tf32, float, MATRIX_M, MATRIX_K>(q, MA, r, 7.0);
232+
matrix_verify_sub<precision::tf32, float, MATRIX_M, MATRIX_K>(q, MA, r, 3.0);
233+
matrix_verify_mul<precision::tf32, float, MATRIX_M, MATRIX_K>(q, MA, r, 15.0);
234+
matrix_verify_div<precision::tf32, float, MATRIX_M, MATRIX_K>(q, MA, r, 2.0);
235+
matrix_verify_logic<precision::tf32, float, MATRIX_M, MATRIX_K>(q, MA, r,
236236
7.0);
237237

238238
return 0;

0 commit comments

Comments
 (0)