Skip to content

Commit 8316a9c

Browse files
authored
[SYCL] Add DPAS of PVC query and other improvements: (#7562)
- Remove dynamic_p from the query since dynamic extent of the sizes is not supported - Add xmx16 to the list of TPUs. This is the PVC TPU - Change float19 to tf32 - Extend scope to return more than one value
1 parent ffda344 commit 8316a9c

File tree

2 files changed

+262
-61
lines changed

2 files changed

+262
-61
lines changed

sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp

Lines changed: 233 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,16 @@
66
//
77
// ===--------------------------------------------------------------------=== //
88
// This file implements the static query interface for the joint_matrix
9-
// experimental extension. AMX, DPAS and different other TPUs support different
10-
// logical sizes and types. The query interface is used to validate user code
11-
// and inform them about supported types, sizes, scope, and layouts by the
12-
// current implementation. Note that this query interface is a compile-time
13-
// query, so there will be no runtime errors. The query interface provides
14-
// three functionalities:
15-
// 1- At compile time, inform the user whether a specific
16-
// combination is valid or not.
17-
// 2- Construct the matrices using a default shape
18-
// if user does not provide a combination
19-
// 3- General query interface for sizes, types,
20-
// static/dynamic, scope. This is needed to void padding by the user,
21-
// for tuning, and efficient code generation if used by a library.
9+
// experimental extension. Intel AMX, Intel XMX, and Nvidia Tensor Cores support
10+
// different logical sizes and types. The query interface is used to validate
11+
// user code and inform them about supported types, sizes, scopes, and layouts
12+
// by the current implementation. Note that this query interface is a
13+
// compile-time query, so there will be no runtime errors. The query interface
14+
// provides three functionalities: 1- At compile time, inform the user whether a
15+
// specific combination is valid or not. 2- Construct the matrices using a
16+
// default shape if user does not provide a combination 3- General query
17+
// interface for sizes, types, scopes. This is needed to avoid padding by the
18+
// user, for tuning, and efficient code generation if used by a library.
2219

2320
#pragma once
2421

@@ -29,14 +26,15 @@ namespace oneapi {
2926
namespace experimental::matrix {
3027

3128
enum class tpu {
32-
dpas,
29+
xmx8,
30+
xmx16,
3331
amx,
3432
};
3533
enum class matrix_type {
3634
bf8,
3735
bf16,
3836
fp16,
39-
fp19, // tfloat32
37+
tf32,
4038
fp32,
4139
fp64,
4240
sint2,
@@ -104,10 +102,9 @@ struct tpu_params<tpu::amx, void, void, void, sM, sN, sK> {
104102
static constexpr std::size_t N = -1;
105103
static constexpr std::size_t K = -1;
106104

107-
bool dynamic_p = false; // should be true in future implementations because
108-
// AMX hardware supports dynamic sizes
109105
uint32_t numtiles = 8;
110-
scope_t scope = scope_t::sub_group;
106+
static constexpr scope_t scopes[] = {scope_t::sub_group};
107+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
111108
struct combination {
112109
uint32_t max_msize;
113110
uint32_t max_nsize;
@@ -155,10 +152,9 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
155152
using joint_matrix_accumulator =
156153
joint_matrix<Group, Tc, use::accumulator, M, N>;
157154

158-
bool dynamic_p = false; // should be true in future implementations because
159-
// AMX hardware supports dynamic sizes
160155
uint32_t numtiles = 8;
161-
scope_t scope = scope_t::sub_group;
156+
static constexpr scope_t scopes[] = {scope_t::sub_group};
157+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
162158
struct combination {
163159
uint32_t max_msize;
164160
uint32_t max_nsize;
@@ -207,19 +203,18 @@ struct tpu_params<
207203
using joint_matrix_accumulator =
208204
joint_matrix<Group, Tc, use::accumulator, M, N>;
209205

210-
bool dynamic_p = false; // should be true in future implementations
211-
// because AMX hardware supports dynamic sizes
212206
uint32_t numtiles = 8;
213-
scope_t scope = scope_t::sub_group;
207+
static constexpr scope_t scopes[] = {scope_t::sub_group};
208+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
214209
};
215210

216-
// DPAS case
217-
// The DPAS implementation supports the logical capability support of the HW
218-
// So in this case, M, N, K sizes returned by the query represent the logical
219-
// capabilities of the DPAS hardware.
211+
// Intel XMX with SIMD8 capability
212+
// The Intel XMX implementation supports the logical capability support of the
213+
// HW. So in this case, M, N, K sizes returned by the query represent the logical
214+
// capabilities of the Intel XMX hardware.
220215

221216
template <typename Ta, typename Tb, typename Tc>
222-
constexpr bool is_combination_valid_dpas(int sM, int sN, int sK) {
217+
constexpr bool is_combination_valid_xmx8(int sM, int sN, int sK) {
223218
if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
224219
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
225220
sN == 8 && sK == 32) ||
@@ -244,7 +239,7 @@ constexpr bool is_combination_valid_dpas(int sM, int sN, int sK) {
244239
}
245240

246241
template <typename Ta, typename Tb, typename Tc>
247-
constexpr bool are_types_valid_dpas() {
242+
constexpr bool are_types_valid_xmx8() {
248243
if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
249244
std::is_same_v<Tc, int>) ||
250245
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
@@ -265,14 +260,14 @@ constexpr bool are_types_valid_dpas() {
265260
// General Query
266261
// specialization for when types are not given --> no default values
267262
template <int sM, int sN, int sK>
268-
struct tpu_params<tpu::dpas, void, void, void, sM, sN, sK> {
263+
struct tpu_params<tpu::xmx8, void, void, void, sM, sN, sK> {
269264
static constexpr std::size_t M = -1; // depends on the type
270265
static constexpr std::size_t N = -1;
271266
static constexpr std::size_t K = -1;
272267

273-
bool dynamic_p = false; // no dynamic allocation on the GPU
274-
uint32_t numtiles = -1; // does not apply for DPAS
275-
scope_t scope = scope_t::sub_group;
268+
uint32_t numtiles = -1; // does not apply for XMX8
269+
static constexpr scope_t scopes[] = {scope_t::sub_group};
270+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
276271

277272
struct combination {
278273
uint32_t max_msize;
@@ -320,12 +315,12 @@ struct tpu_params<tpu::dpas, void, void, void, sM, sN, sK> {
320315
// Specialization for when only types are given, need to query only sizes
321316

322317
template <typename Ta, typename Tb, typename Tc>
323-
struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
318+
struct tpu_params<tpu::xmx8, Ta, Tb, Tc, 0, 0, 0,
324319
typename std::enable_if<(!std::is_same_v<Ta, void> &&
325320
!std::is_same_v<Tb, void> &&
326321
!std::is_same_v<Tc, void>)>::type> {
327-
static_assert((are_types_valid_dpas<Ta, Tb, Tc>()),
328-
"Invalid types for DPAS, supported types are int8_t, uint8_t, "
322+
static_assert((are_types_valid_xmx8<Ta, Tb, Tc>()),
323+
"Invalid types for XMX8, supported types are int8_t, uint8_t, "
329324
"half, and bf16 (Note that unsigned short should be used in the"
330325
"DPC++ code to implement bf16)");
331326

@@ -343,9 +338,9 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
343338
using joint_matrix_accumulator =
344339
joint_matrix<Group, Tc, use::accumulator, M, N>;
345340

346-
bool dynamic_p = false; // no dynamic allocation on the GPU
347-
uint32_t numtiles = -1; // does not apply for DPAS
348-
scope_t scope = scope_t::sub_group;
341+
uint32_t numtiles = -1; // does not apply for XMX8
342+
static constexpr scope_t scopes[] = {scope_t::sub_group};
343+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
349344
struct combination {
350345
uint32_t max_msize;
351346
uint32_t max_nsize;
@@ -376,13 +371,13 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
376371
// Specialization when both types and sizes are given
377372
template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
378373
struct tpu_params<
379-
tpu::dpas, Ta, Tb, Tc, sM, sN, sK,
374+
tpu::xmx8, Ta, Tb, Tc, sM, sN, sK,
380375
typename std::enable_if<((!std::is_same_v<Ta, void> && sM != 0))>::type> {
381376
// Validate that parameters are supported
382377
static_assert((sM == 0 && sN == 0 && sK == 0) ||
383-
(is_combination_valid_dpas<Ta, Tb, Tc>(sM, sN, sK)),
384-
"Invalid parameters for DPAS, query valid combinations "
385-
"using: tpu_params<tpu::dpas> myparams; and then check out "
378+
(is_combination_valid_xmx8<Ta, Tb, Tc>(sM, sN, sK)),
379+
"Invalid parameters for XMX8, query valid combinations "
380+
"using: tpu_params<tpu::xmx8> myparams; and then check out "
386381
"myparams.combinations array");
387382

388383
// if combination is valid, construct the matrices
@@ -399,9 +394,200 @@ struct tpu_params<
399394
using joint_matrix_accumulator =
400395
joint_matrix<Group, Tc, use::accumulator, M, N>;
401396

402-
bool dynamic_p = false; // no dynamic allocation on the GPU
403-
uint32_t numtiles = -1; // does not apply for DPAS
404-
scope_t scope = scope_t::sub_group;
397+
uint32_t numtiles = -1; // does not apply for XMX8
398+
static constexpr scope_t scopes[] = {scope_t::sub_group};
399+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
400+
};
401+
402+
// Intel XMX with SIMD16 capability
403+
// The Intel XMX implementation supports the logical capability support of the
404+
// HW. So in this case, M, N, K sizes returned by the query represent the logical
405+
// capabilities of the Intel XMX hardware.
406+
407+
template <typename Ta, typename Tb, typename Tc>
408+
constexpr bool is_combination_valid_xmx16(int sM, int sN, int sK) {
409+
if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
410+
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
411+
sN == 16 && sK == 32) ||
412+
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
413+
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
414+
sN == 16 && sK == 32) ||
415+
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
416+
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
417+
sN == 16 && sK == 32) ||
418+
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
419+
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
420+
sN == 16 && sK == 32) ||
421+
(std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
422+
std::is_same_v<Tc, float> &&
423+
(sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 16 && sK == 16) ||
424+
(std::is_same_v<Ta, unsigned short> &&
425+
std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
426+
(sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 16 && sK == 16))
427+
return true;
428+
else
429+
return false;
430+
}
431+
432+
template <typename Ta, typename Tb, typename Tc>
433+
constexpr bool are_types_valid_xmx16() {
434+
if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
435+
std::is_same_v<Tc, int>) ||
436+
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
437+
std::is_same_v<Tc, int>) ||
438+
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
439+
std::is_same_v<Tc, int>) ||
440+
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
441+
std::is_same_v<Tc, int>) ||
442+
(std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
443+
std::is_same_v<Tc, float>) ||
444+
(std::is_same_v<Ta, unsigned short> &&
445+
std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float>))
446+
return true;
447+
else
448+
return false;
449+
}
450+
451+
// General Query
452+
// specialization for when types are not given --> no default values
453+
template <int sM, int sN, int sK>
454+
struct tpu_params<tpu::xmx16, void, void, void, sM, sN, sK> {
455+
static constexpr std::size_t M = -1; // depends on the type
456+
static constexpr std::size_t N = -1;
457+
static constexpr std::size_t K = -1;
458+
459+
uint32_t numtiles = -1; // does not apply for XMX
460+
static constexpr scope_t scopes[] = {scope_t::sub_group};
461+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
462+
463+
struct combination {
464+
uint32_t max_msize;
465+
uint32_t max_nsize;
466+
uint32_t max_ksize;
467+
matrix_type atype;
468+
matrix_type btype;
469+
matrix_type accumulatortype;
470+
uint32_t msize;
471+
uint32_t nsize;
472+
uint32_t ksize;
473+
};
474+
using mt = matrix_type;
475+
static constexpr combination combinations[] = {
476+
{0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 1, 16, 32},
477+
{0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 2, 16, 32},
478+
{0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 4, 16, 32},
479+
{0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 8, 16, 32},
480+
{0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 1, 16, 32},
481+
{0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 2, 16, 32},
482+
{0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 4, 16, 32},
483+
{0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 8, 16, 32},
484+
{0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 1, 16, 32},
485+
{0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 2, 16, 32},
486+
{0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 4, 16, 32},
487+
{0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 8, 16, 32},
488+
{0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 1, 16, 32},
489+
{0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 2, 16, 32},
490+
{0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 4, 16, 32},
491+
{0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 8, 16, 32},
492+
{0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 1, 16, 16},
493+
{0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 2, 16, 16},
494+
{0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 4, 16, 16},
495+
{0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 8, 16, 16},
496+
{0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 1, 16, 16},
497+
{0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 2, 16, 16},
498+
{0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 4, 16, 16},
499+
{0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 8, 16, 16},
500+
};
501+
static constexpr int num_combinations =
502+
sizeof(combinations) / sizeof(combination);
503+
};
504+
505+
// Sizes-only query:
506+
// Specialization for when only types are given, need to query only sizes
507+
508+
template <typename Ta, typename Tb, typename Tc>
509+
struct tpu_params<tpu::xmx16, Ta, Tb, Tc, 0, 0, 0,
510+
typename std::enable_if<(!std::is_same_v<Ta, void> &&
511+
!std::is_same_v<Tb, void> &&
512+
!std::is_same_v<Tc, void>)>::type> {
513+
static_assert((are_types_valid_xmx16<Ta, Tb, Tc>()),
514+
"Invalid types for XMX16, supported types are int8_t, uint8_t, "
515+
"half, and bf16 (Note that unsigned short should be used in the"
516+
"DPC++ code to implement bf16)");
517+
518+
// construct the matrices using the default sizes
519+
520+
static constexpr std::size_t M = 8;
521+
static constexpr std::size_t N = 16;
522+
static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16);
523+
524+
template <typename Group, layout Layout>
525+
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
526+
template <typename Group, layout Layout>
527+
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
528+
template <typename Group>
529+
using joint_matrix_accumulator =
530+
joint_matrix<Group, Tc, use::accumulator, M, N>;
531+
532+
uint32_t numtiles = -1; // does not apply for XMX
533+
static constexpr scope_t scopes[] = {scope_t::sub_group};
534+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
535+
struct combination {
536+
uint32_t max_msize;
537+
uint32_t max_nsize;
538+
uint32_t max_ksize;
539+
matrix_type atype;
540+
matrix_type btype;
541+
matrix_type accumulatortype;
542+
uint32_t msize;
543+
uint32_t nsize;
544+
uint32_t ksize;
545+
};
546+
using mt = matrix_type;
547+
static constexpr combination combinations[] = {
548+
// The types used in the initialization below are fake and not used. In
549+
// this case, users already chose the types, they are only looking for
550+
// the
551+
// sizes
552+
{0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 1, 16, (sizeof(Ta) == 1) ? 32 : 16},
553+
{0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 2, 16, (sizeof(Ta) == 1) ? 32 : 16},
554+
{0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 4, 16, (sizeof(Ta) == 1) ? 32 : 16},
555+
{0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 8, 16, (sizeof(Ta) == 1) ? 32 : 16},
556+
};
557+
static constexpr int num_combinations =
558+
sizeof(combinations) / sizeof(combination);
559+
};
560+
561+
// Valid or not:
562+
// Specialization when both types and sizes are given
563+
template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
564+
struct tpu_params<
565+
tpu::xmx16, Ta, Tb, Tc, sM, sN, sK,
566+
typename std::enable_if<((!std::is_same_v<Ta, void> && sM != 0))>::type> {
567+
// Validate that parameters are supported
568+
static_assert((sM == 0 && sN == 0 && sK == 0) ||
569+
(is_combination_valid_xmx16<Ta, Tb, Tc>(sM, sN, sK)),
570+
"Invalid parameters for XMX16, query valid combinations "
571+
"using: tpu_params<tpu::xmx16> myparams; and then check out "
572+
"myparams.combinations array");
573+
574+
// if combination is valid, construct the matrices
575+
static constexpr std::size_t M = (sM != 0) ? sM : 8;
576+
static constexpr std::size_t N = (sN != 0) ? sN : 8;
577+
static constexpr std::size_t K =
578+
(sK != 0) ? sK : ((sizeof(Ta) == 1) ? 32 : 16);
579+
580+
template <typename Group, layout Layout>
581+
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
582+
template <typename Group, layout Layout>
583+
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
584+
template <typename Group>
585+
using joint_matrix_accumulator =
586+
joint_matrix<Group, Tc, use::accumulator, M, N>;
587+
588+
uint32_t numtiles = -1; // does not apply for XMX16
589+
static constexpr scope_t scopes[] = {scope_t::sub_group};
590+
static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
405591
};
406592
} // namespace experimental::matrix
407593
} // namespace oneapi

0 commit comments

Comments
 (0)