@@ -11,7 +11,7 @@ using namespace sycl::ext::oneapi;
using namespace sycl::ext::oneapi::experimental::matrix;
constexpr float bf16_eps = 0.00390625;

- // Example usage of Nvidia matrix multiply.
+ // Example usage of joint_matrix matrix multiply.
// Optimizations such as memory paddings for avoiding bank conflicts are not
// included in this test which aids clarity for what is going on. This example
// forms a "Big matrix" corresponding to a single "TILE" using cuda example
@@ -30,37 +30,47 @@ constexpr float bf16_eps = 0.00390625;
constexpr int N_THREADS_PER_MATRIX_OP = 32;

// number of submatrices per row of accumulator ("C", "D") matrices.
- constexpr int SUB_TILES_M = 3;
+ constexpr int SUB_TILES_M = 2;
// number of submatrices per col of accumulator matrices.
constexpr int SUB_TILES_N = 2;
// number of submatrices per col of "A"/per row of "B", matrices.
- constexpr int SUB_TILES_K = 1;
+ constexpr int SUB_TILES_K = 2;

- template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N>
+ template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N,
+           layout layout_A, layout layout_B, layout layout_C>
class TypeHelper;

- template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N>
- using KernelName = class TypeHelper<Tm, Tc, Td, M, K, N>;
+ template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N,
+           layout layout_A, layout layout_B, layout layout_C>
+ using KernelName =
+     class TypeHelper<Tm, Tc, Td, M, K, N, layout_A, layout_B, layout_C>;

- template <size_t Big_N, size_t Big_K, typename Tm, typename Tc>
+ template <size_t Big_N, size_t Big_K, size_t Big_M, layout layout_A,
+           layout layout_B, typename Tm, typename Tc>
Tc matrix_ref_mn(const int &m, const int &n, Tm *A, Tm *B, Tc *C) {
  Tc res = C[m * Big_N + n];

-   if constexpr (std::is_same<Tm, bfloat16>::value) {
-     for (int k = 0; k < Big_K; k++)
-       res += A[m * Big_K + k] * B[k * Big_N + n];
-   } else {
-     for (int k = 0; k < Big_K; k++)
-       res +=
-           static_cast<Tc>(A[m * Big_K + k]) * static_cast<Tc>(B[k * Big_N + n]);
+   for (int k = 0; k < Big_K; k++) {
+     auto index_a =
+         layout_A == layout::row_major ? m * Big_K + k : m + k * Big_M;
+     auto index_b =
+         layout_B == layout::row_major ? k * Big_N + n : k + n * Big_K;
+
+     if constexpr (std::is_same<Tm, bfloat16>::value) {
+       res += A[index_a] * B[index_b];
+     } else {
+       res += static_cast<Tc>(A[index_a]) * static_cast<Tc>(B[index_b]);
+     }
  }

  return res;
}

- template <typename Tm, typename Tc, typename Td, size_t Sub_Tiles_M,
-           size_t Sub_Tiles_K, size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
-           typename T3 = std::remove_const_t<Tm>>
+ template <
+     typename Tm, typename Tc, typename Td, size_t Sub_Tiles_M,
+     size_t Sub_Tiles_K, size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
+     layout layout_A = layout::row_major, layout layout_B = layout::row_major,
+     layout layout_C = layout::row_major, typename T3 = std::remove_const_t<Tm>>
void test(queue &q) {
  // total number of M dimension matrix elements for the "Big matrix".
  constexpr auto Big_M = Sub_Tiles_M * M;
@@ -97,7 +107,8 @@ void test(queue &q) {
      accessor<bfloat16, 1, access::mode::write, target::device> accA(bufA,
          cgh);

-       cgh.parallel_for<KernelName<Tm, Tc, class copyA, M, K, N>>(
+       cgh.parallel_for<KernelName<Tm, Tc, class copyA, M, K, N, layout_A,
+           layout_B, layout_C>>(
          range<1>(Big_M * Big_K), [=](item<1> item) {
            auto i = item.get_linear_id();
            accA[i] = 0.1f * (i % 10);
@@ -107,7 +118,8 @@ void test(queue &q) {
      accessor<bfloat16, 1, access::mode::write, target::device> accB(bufB,
          cgh);

-       cgh.parallel_for<KernelName<Tm, Tc, class copyB, M, K, N>>(
+       cgh.parallel_for<KernelName<Tm, Tc, class copyB, M, K, N, layout_A,
+           layout_B, layout_C>>(
          range<1>(Big_K * Big_N), [=](item<1> item) {
            auto i = item.get_linear_id();
            accB[i] = 0.1f * (i % 10);
@@ -130,41 +142,55 @@ void test(queue &q) {
      range<2> GlobalRange = {Sub_Tiles_M,
          Sub_Tiles_N * N_THREADS_PER_MATRIX_OP};

-       cgh.parallel_for<KernelName<Tm, Tc, Td, M, K, N>>(
+       cgh.parallel_for<
+           KernelName<Tm, Tc, Td, M, K, N, layout_A, layout_B, layout_C>>(
          nd_range<2>(GlobalRange, LocalRange), [=](nd_item<2> item) {
            sycl::sub_group sg = item.get_sub_group();
            // row id of current submatrix of BIG C matrix
            const auto m = item.get_group().get_group_id()[0];
            // column id of current submatrix of BIG C matrix
            const auto n = item.get_group().get_group_id()[1];

-             joint_matrix<sycl::sub_group, T3, use::a, M, K, layout::row_major>
-                 sub_a;
-             joint_matrix<sycl::sub_group, T3, use::b, K, N, layout::row_major>
-                 sub_b;
+             joint_matrix<sycl::sub_group, T3, use::a, M, K, layout_A> sub_a;
+             joint_matrix<sycl::sub_group, T3, use::b, K, N, layout_B> sub_b;
            joint_matrix<sycl::sub_group, std::remove_const_t<Tc>,
                use::accumulator, M, N>
                sub_c;
            joint_matrix<sycl::sub_group, Td, use::accumulator, M, N> sub_d;
+             auto stride_C = layout_C == layout::row_major ? Big_N : Big_M;
+             auto load_stride_C = layout_C == layout::row_major
+                 ? (m * M) * Big_N + n * N
+                 : (m * M) + n * N * Big_M;

            joint_matrix_load(
                sg, sub_c,
                accC.template get_multi_ptr<access::decorated::no>() +
-                   (m * M) * Big_N + n * N,
-               Big_N, layout::row_major);
+                   load_stride_C,
+               stride_C, layout_C);
+
+             auto stride_A = layout_A == layout::row_major ? Big_K : Big_M;
+             auto stride_B = layout_B == layout::row_major ? Big_N : Big_K;
+
            // k = row/col id of current submatrix of BIG A/B matrices
            for (int k = 0; k < Sub_Tiles_K; k++) {
+               auto load_stride_A = layout_A == layout::row_major
+                   ? (k * K) + (m * M * Big_K)
+                   : (k * K * Big_M) + (m * M);
+               auto load_stride_B = layout_B == layout::row_major
+                   ? (k * K * Big_N) + (n * N)
+                   : (k * K) + (n * N * Big_K);
+
              joint_matrix_load(
                  sg, sub_a,
                  accA.template get_multi_ptr<access::decorated::no>() +
-                     (k * K) + (m * M * Big_K),
-                 Big_K);
+                     load_stride_A,
+                 stride_A);

              joint_matrix_load(
                  sg, sub_b,
                  accB.template get_multi_ptr<access::decorated::no>() +
-                     (k * K * Big_N) + (n * N),
-                 Big_N);
+                     load_stride_B,
+                 stride_B);

              // round values to correct precision if using tf32
              if constexpr (std::is_same<T3, precision::tf32>::value) {
@@ -174,27 +200,32 @@ void test(queue &q) {
              }

              joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);
+               joint_matrix_copy(sg, sub_d, sub_c);
            }
            joint_matrix_store(
                sg, sub_d,
                accD.template get_multi_ptr<access::decorated::no>() +
-                   (m * M) * Big_N + n * N,
-               Big_N, layout::row_major);
+                   load_stride_C,
+               stride_C, layout_C);
          });
    });
    q.wait();
  }

  for (int m = 0; m < Big_M; m++) {
    for (int n = 0; n < Big_N; n++) {
+       auto index_D =
+           layout_C == layout::row_major ? m * Big_N + n : m + n * Big_M;
      if constexpr (std::is_same<std::remove_const_t<Tm>, bfloat16>::value) {
-         auto res_device = matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C);
-         assert(fabs(2 * (D[m * Big_N + n] - res_device)) /
-                    (D[m * Big_N + n] + res_device) <
+         auto res_device =
+             matrix_ref_mn<Big_N, Big_K, Big_M, layout_A, layout_B>(m, n, A, B,
+                 C);
+         assert(fabs(2 * (D[index_D] - res_device)) / (D[index_D] + res_device) <
            bf16_eps * 2);
      } else {
-         assert(
-             (D[m * Big_N + n] == matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C)));
+         assert((D[index_D] ==
+                 matrix_ref_mn<Big_N, Big_K, Big_M, layout_A, layout_B>(m, n, A,
+                     B, C)));
      }
    }
  }
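For context, a minimal sketch of how the extended test template could be driven once the layout parameters land. The 16x16x16 bfloat16 tile shape, the main() wrapper, and the specific layout combinations below are illustrative assumptions, not part of this patch; since layout_A/layout_B/layout_C default to row_major, existing callers keep compiling unchanged.

// Hypothetical driver (assumption, not from the patch): relies on the header
// under review being included and on a device supporting 16x16x16 bfloat16.
#include <sycl/sycl.hpp>

int main() {
  sycl::queue q;

  // Defaulted layout parameters: all operands row-major, i.e. the old behaviour.
  test<bfloat16, float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16,
       16>(q);

  // Same test exercising mixed layouts via the new template parameters,
  // e.g. column-major B and C operands.
  test<bfloat16, float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16,
       16, layout::row_major, layout::col_major, layout::col_major>(q);

  return 0;
}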