Skip to content

Commit 52f34fd

Browse files
authored
[SYCL] Change the query API to match the new changes made to the matrix API (#6981)
1 parent 10d50ed commit 52f34fd

File tree

2 files changed

+218
-82
lines changed

2 files changed

+218
-82
lines changed

sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp

Lines changed: 75 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -56,25 +56,25 @@ enum class matrix_type {
5656
// Scoped enumeration with exactly two enumerators: sub_group and work_group.
// NOTE(review): presumably selects the group of work-items a matrix query
// applies to -- confirm against the users of scope_t outside this excerpt.
enum class scope_t {
  sub_group,
  work_group
};
5757

5858
template <tpu u, typename Ta = void, typename Tb = void, typename Tc = void,
59-
int M = 0, int N = 0, int K = 0, typename Enabled = void>
59+
int sM = 0, int sN = 0, int sK = 0, typename Enabled = void>
6060
struct tpu_params;
6161

6262
#if __cplusplus >= 201703L
6363
template <typename Ta, typename Tb, typename Tc>
64-
constexpr bool is_combination_valid_amx(int M, int N, int K) {
64+
constexpr bool is_combination_valid_amx(int sM, int sN, int sK) {
6565
// is_same_v is a C++17 feature
6666
if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
67-
std::is_same_v<Tc, int> && M <= 16 && N <= 16 && K <= 64) ||
67+
std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
6868
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
69-
std::is_same_v<Tc, int> && M <= 16 && N <= 16 && K <= 64) ||
69+
std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
7070
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
71-
std::is_same_v<Tc, int> && M <= 16 && N <= 16 && K <= 64) ||
71+
std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
7272
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
73-
std::is_same_v<Tc, int> && M <= 16 && N <= 16 && K <= 64) ||
73+
std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
7474
// bf16
7575
(std::is_same_v<Ta, unsigned short> &&
7676
std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
77-
M <= 16 && N <= 16 && K <= 32))
77+
sM <= 16 && sN <= 16 && sK <= 32))
7878
return true;
7979
else
8080
return false;
@@ -100,11 +100,11 @@ constexpr bool are_types_valid_amx() {
100100

101101
// General query:
102102
// types are not given, no default sizes and no implicit matrix construction
103-
template <int M, int N, int K>
104-
struct tpu_params<tpu::amx, void, void, void, M, N, K> {
105-
static constexpr std::size_t defaultM = -1; // depends on the type
106-
static constexpr std::size_t defaultN = -1;
107-
static constexpr std::size_t defaultK = -1;
103+
template <int sM, int sN, int sK>
104+
struct tpu_params<tpu::amx, void, void, void, sM, sN, sK> {
105+
static constexpr std::size_t M = -1; // depends on the type
106+
static constexpr std::size_t N = -1;
107+
static constexpr std::size_t K = -1;
108108

109109
bool dynamic_p = false; // should be true in future implementations because
110110
// AMX hardware supports dynamic sizes
@@ -116,7 +116,7 @@ struct tpu_params<tpu::amx, void, void, void, M, N, K> {
116116
uint32_t max_ksize;
117117
matrix_type atype;
118118
matrix_type btype;
119-
matrix_type ctype;
119+
matrix_type accumulatortype;
120120
uint32_t msize;
121121
uint32_t nsize;
122122
uint32_t ksize;
@@ -146,19 +146,17 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
146146
"DPC++ code to implement bf16) ");
147147

148148
// construct the matrices using the default sizes
149-
static constexpr std::size_t defaultM = 16;
150-
static constexpr std::size_t defaultN = 16;
151-
static constexpr std::size_t defaultK = ((sizeof(Ta) == 1) ? 64 : 32);
149+
static constexpr std::size_t M = 16;
150+
static constexpr std::size_t N = 16;
151+
static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 64 : 32);
152152

153153
template <typename Group>
154-
using joint_matrix_a =
155-
joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
154+
using joint_matrix_a = joint_matrix<Ta, M, K, use::a, layout::unused, Group>;
156155
template <typename Group>
157-
using joint_matrix_b =
158-
joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
156+
using joint_matrix_b = joint_matrix<Tb, K, N, use::b, layout::unused, Group>;
159157
template <typename Group>
160-
using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
161-
layout::row_major, Group>;
158+
using joint_matrix_accumulator =
159+
joint_matrix<Tc, M, N, use::accumulator, layout::unused, Group>;
162160

163161
bool dynamic_p = false; // should be true in future implementations because
164162
// AMX hardware supports dynamic sizes
@@ -170,7 +168,7 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
170168
uint32_t max_ksize;
171169
matrix_type atype;
172170
matrix_type btype;
173-
matrix_type ctype;
171+
matrix_type accumulatortype;
174172
uint32_t msize;
175173
uint32_t nsize;
176174
uint32_t ksize;
@@ -183,36 +181,34 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
183181

184182
// Valid or not:
185183
// Specialization when both types and sizes are given
186-
template <typename Ta, typename Tb, typename Tc, int M, int N, int K>
184+
template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
187185
struct tpu_params<
188-
tpu::amx, Ta, Tb, Tc, M, N, K,
186+
tpu::amx, Ta, Tb, Tc, sM, sN, sK,
189187
typename std::enable_if<(
190188
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
191-
!std::is_same_v<Tc, void> && M != 0 && N != 0 && K != 0)>::type> {
189+
!std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
192190
// Validate that parameters are supported
193191
static_assert(
194-
(M == 0 && N == 0 && K == 0) ||
195-
(is_combination_valid_amx<Ta, Tb, Tc>(M, N, K)),
192+
(sM == 0 && sN == 0 && sK == 0) ||
193+
(is_combination_valid_amx<Ta, Tb, Tc>(sM, sN, sK)),
196194
"Invalid parameters for AMX, query valid types and maximum sizes "
197195
"using: tpu_params<tpu::amx> myparams; and then check out "
198196
"myparams.combinations array");
199197

200198
// if combination is valid, construct the matrices
201199

202-
static constexpr std::size_t defaultM = (M != 0) ? M : 16;
203-
static constexpr std::size_t defaultN = (N != 0) ? N : 16;
204-
static constexpr std::size_t defaultK =
205-
(K != 0) ? K : ((sizeof(Ta) == 1) ? 64 : 32);
200+
static constexpr std::size_t M = (sM != 0) ? sM : 16;
201+
static constexpr std::size_t N = (sN != 0) ? sN : 16;
202+
static constexpr std::size_t K =
203+
(sK != 0) ? sK : ((sizeof(Ta) == 1) ? 64 : 32);
206204

207205
template <typename Group>
208-
using joint_matrix_a =
209-
joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
206+
using joint_matrix_a = joint_matrix<Ta, M, K, use::a, layout::unused, Group>;
210207
template <typename Group>
211-
using joint_matrix_b =
212-
joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
208+
using joint_matrix_b = joint_matrix<Tb, K, N, use::b, layout::unused, Group>;
213209
template <typename Group>
214-
using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
215-
layout::row_major, Group>;
210+
using joint_matrix_accumulator =
211+
joint_matrix<Tc, M, N, use::accumulator, layout::unused, Group>;
216212

217213
bool dynamic_p = false; // should be true in future implementations
218214
// because AMX hardware supports dynamic sizes
@@ -226,25 +222,25 @@ struct tpu_params<
226222
// capabilities of the DPAS hardware.
227223

228224
template <typename Ta, typename Tb, typename Tc>
229-
constexpr bool is_combination_valid_dpas(int M, int N, int K) {
225+
constexpr bool is_combination_valid_dpas(int sM, int sN, int sK) {
230226
if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
231-
std::is_same_v<Tc, int> && (M == 1 || M == 2 || M == 4 || M == 8) &&
232-
N == 8 && K == 32) ||
227+
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
228+
sN == 8 && sK == 32) ||
233229
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
234-
std::is_same_v<Tc, int> && (M == 1 || M == 2 || M == 4 || M == 8) &&
235-
N == 8 && K == 32) ||
230+
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
231+
sN == 8 && sK == 32) ||
236232
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
237-
std::is_same_v<Tc, int> && (M == 1 || M == 2 || M == 4 || M == 8) &&
238-
N == 8 && K == 32) ||
233+
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
234+
sN == 8 && sK == 32) ||
239235
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
240-
std::is_same_v<Tc, int> && (M == 1 || M == 2 || M == 4 || M == 8) &&
241-
N == 8 && K == 32) ||
236+
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
237+
sN == 8 && sK == 32) ||
242238
(std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
243-
std::is_same_v<Tc, float> && (M == 1 || M == 2 || M == 4 || M == 8) &&
244-
N == 8 && K == 16) ||
239+
std::is_same_v<Tc, float> &&
240+
(sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 8 && sK == 16) ||
245241
(std::is_same_v<Ta, unsigned short> &&
246242
std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
247-
(M == 1 || M == 2 || M == 4 || M == 8) && N == 8 && K == 16))
243+
(sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 8 && sK == 16))
248244
return true;
249245
else
250246
return false;
@@ -272,11 +268,11 @@ constexpr bool are_types_valid_dpas() {
272268

273269
// General Query
274270
// specialization for when types are not given --> no default values
275-
template <int M, int N, int K>
276-
struct tpu_params<tpu::dpas, void, void, void, M, N, K> {
277-
static constexpr std::size_t defaultM = -1; // depends on the type
278-
static constexpr std::size_t defaultN = -1;
279-
static constexpr std::size_t defaultK = -1;
271+
template <int sM, int sN, int sK>
272+
struct tpu_params<tpu::dpas, void, void, void, sM, sN, sK> {
273+
static constexpr std::size_t M = -1; // depends on the type
274+
static constexpr std::size_t N = -1;
275+
static constexpr std::size_t K = -1;
280276

281277
bool dynamic_p = false; // no dynamic allocation on the GPU
282278
uint32_t numtiles = -1; // does not apply for DPAS
@@ -288,7 +284,7 @@ struct tpu_params<tpu::dpas, void, void, void, M, N, K> {
288284
uint32_t max_ksize;
289285
matrix_type atype;
290286
matrix_type btype;
291-
matrix_type ctype;
287+
matrix_type accumulatortype;
292288
uint32_t msize;
293289
uint32_t nsize;
294290
uint32_t ksize;
@@ -340,19 +336,17 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
340336

341337
// construct the matrices using the default sizes
342338

343-
static constexpr std::size_t defaultM = 8;
344-
static constexpr std::size_t defaultN = 8;
345-
static constexpr std::size_t defaultK = ((sizeof(Ta) == 1) ? 32 : 16);
339+
static constexpr std::size_t M = 8;
340+
static constexpr std::size_t N = 8;
341+
static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16);
346342

347343
template <typename Group>
348-
using joint_matrix_a =
349-
joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
344+
using joint_matrix_a = joint_matrix<Ta, M, K, use::a, layout::unused, Group>;
350345
template <typename Group>
351-
using joint_matrix_b =
352-
joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
346+
using joint_matrix_b = joint_matrix<Tb, K, N, use::b, layout::unused, Group>;
353347
template <typename Group>
354-
using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
355-
layout::row_major, Group>;
348+
using joint_matrix_accumulator =
349+
joint_matrix<Tc, M, N, use::accumulator, layout::unused, Group>;
356350

357351
bool dynamic_p = false; // no dynamic allocation on the GPU
358352
uint32_t numtiles = -1; // does not apply for DPAS
@@ -363,15 +357,16 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
363357
uint32_t max_ksize;
364358
matrix_type atype;
365359
matrix_type btype;
366-
matrix_type ctype;
360+
matrix_type accumulatortype;
367361
uint32_t msize;
368362
uint32_t nsize;
369363
uint32_t ksize;
370364
};
371365
using mt = matrix_type;
372366
static constexpr combination combinations[] = {
373367
// The types used in the initialization below are fake and not used. In
374-
// this case, users already chose the types, they are only looking for the
368+
// this case, users already chose the types; they are only looking for
369+
// the
375370
// sizes
376371
{0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 1, 8, (sizeof(Ta) == 1) ? 32 : 16},
377372
{0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 2, 8, (sizeof(Ta) == 1) ? 32 : 16},
@@ -384,32 +379,30 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
384379

385380
// Valid or not:
386381
// Specialization when both types and sizes are given
387-
template <typename Ta, typename Tb, typename Tc, int M, int N, int K>
382+
template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
388383
struct tpu_params<
389-
tpu::dpas, Ta, Tb, Tc, M, N, K,
390-
typename std::enable_if<((!std::is_same_v<Ta, void> && M != 0))>::type> {
384+
tpu::dpas, Ta, Tb, Tc, sM, sN, sK,
385+
typename std::enable_if<((!std::is_same_v<Ta, void> && sM != 0))>::type> {
391386
// Validate that parameters are supported
392-
static_assert((M == 0 && N == 0 && K == 0) ||
393-
(is_combination_valid_dpas<Ta, Tb, Tc>(M, N, K)),
387+
static_assert((sM == 0 && sN == 0 && sK == 0) ||
388+
(is_combination_valid_dpas<Ta, Tb, Tc>(sM, sN, sK)),
394389
"Invalid parameters for DPAS, query valid combinations "
395390
"using: tpu_params<tpu::dpas> myparams; and then check out "
396391
"myparams.combinations array");
397392

398393
// if combination is valid, construct the matrices
399-
static constexpr std::size_t defaultM = (M != 0) ? M : 8;
400-
static constexpr std::size_t defaultN = (N != 0) ? N : 8;
401-
static constexpr std::size_t defaultK =
402-
(K != 0) ? K : ((sizeof(Ta) == 1) ? 32 : 16);
394+
static constexpr std::size_t M = (sM != 0) ? sM : 8;
395+
static constexpr std::size_t N = (sN != 0) ? sN : 8;
396+
static constexpr std::size_t K =
397+
(sK != 0) ? sK : ((sizeof(Ta) == 1) ? 32 : 16);
403398

404399
template <typename Group>
405-
using joint_matrix_a =
406-
joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
400+
using joint_matrix_a = joint_matrix<Ta, M, K, use::a, layout::unused, Group>;
407401
template <typename Group>
408-
using joint_matrix_b =
409-
joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
402+
using joint_matrix_b = joint_matrix<Tb, K, N, use::b, layout::unused, Group>;
410403
template <typename Group>
411-
using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
412-
layout::row_major, Group>;
404+
using joint_matrix_accumulator =
405+
joint_matrix<Tc, M, N, use::accumulator, layout::unused, Group>;
413406

414407
bool dynamic_p = false; // no dynamic allocation on the GPU
415408
uint32_t numtiles = -1; // does not apply for DPAS

0 commit comments

Comments
 (0)