@@ -28,7 +28,7 @@ bool apply_verify(Tc *C, Tc *D, Ta *A, Ta *Ar) {
28
28
return true ;
29
29
}
30
30
template <typename Tc, typename Ta, size_t TM, size_t TN, size_t TK, size_t M,
31
- size_t N, class kernel_name >
31
+ size_t N, size_t K, class kernel_name >
32
32
bool apply_two_matrices (Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
33
33
size_t NDRangeM = M / TM;
34
34
size_t NDRangeN = N / TN;
@@ -76,13 +76,13 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
76
76
sg, sub_d, pD + (sg_startx * TM) * N + sg_starty / sg_size * TN,
77
77
N, layout::row_major);
78
78
joint_matrix_load (
79
- sg, sub_a, pA + (sg_startx * TM) * N + sg_starty / sg_size * TK,
80
- N );
79
+ sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
80
+ K );
81
81
joint_matrix_apply (sg, sub_a, sub_ar,
82
82
[](const Ta &x, Ta &y) { y = x + 42 ; });
83
83
ext::intel::experimental::matrix::joint_matrix_store (
84
84
sg, sub_ar,
85
- pAr + (sg_startx * TM) * N + sg_starty / sg_size * TK, N );
85
+ pAr + (sg_startx * TM) * K + sg_starty / sg_size * TK, K );
86
86
}); // parallel for
87
87
}).wait ();
88
88
return apply_verify<Tc, Ta, M, N>(C, D, A, Ar);
@@ -91,27 +91,27 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
91
91
template <typename Ta, typename Tc, size_t TM, size_t TN, size_t TK,
92
92
class kernel_name >
93
93
bool test () {
94
-
95
94
static constexpr size_t M = TM * 2 ;
96
95
static constexpr size_t N = TN * 2 ;
96
+ static constexpr size_t K = TK * 2 ;
97
97
queue q;
98
98
99
99
Tc *C = malloc_shared<Tc>(M * N, q);
100
100
Tc *D = malloc_shared<Tc>(M * N, q);
101
- Ta *A = malloc_shared<Ta>(M * N , q);
102
- Ta *Ar = malloc_shared<Ta>(M * N , q);
101
+ Ta *A = malloc_shared<Ta>(M * K , q);
102
+ Ta *Ar = malloc_shared<Ta>(M * K , q);
103
103
104
104
matrix_rand (M, N, (Tc *)C, (Tc)100 );
105
- matrix_rand (M, N , (Ta *)A, (Ta)100 );
105
+ matrix_rand (M, K , (Ta *)A, (Ta)100 );
106
106
107
- bool res =
108
- apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, kernel_name>( C, D, A, Ar, q);
107
+ bool res = apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, K, kernel_name>(
108
+ C, D, A, Ar, q);
109
109
110
110
if constexpr (std::is_same_v<Ta, bfloat16>)
111
- std::cout << " bfloat16 " << TM << " x" << TN << " : "
111
+ std::cout << " bfloat16 " << TM << " x" << TN << " x " << TK << " : "
112
112
<< (res ? " passed" : " failed" ) << std::endl;
113
113
else if constexpr (std::is_same_v<Ta, int8_t >)
114
- std::cout << " int8_t " << TM << " x" << TN << " : "
114
+ std::cout << " int8_t " << TM << " x" << TN << " x " << TK << " : "
115
115
<< (res ? " passed" : " failed" ) << std::endl;
116
116
return res;
117
117
}
@@ -126,8 +126,8 @@ int main() {
126
126
bool passed = true ;
127
127
for (unsigned int i = 0 ; i < combinations.size (); i++) {
128
128
if (combinations[i].nsize == 0 ) { // Intel AMX
129
- passed &= test<int8_t , int32_t , 8 , 16 , 32 , class amx_int_8x16x32 >();
130
- passed &= test<bfloat16, float , 8 , 16 , 32 , class amx_bf16_8x16x32 >();
129
+ passed &= test<int8_t , int32_t , 16 , 16 , 64 , class amx_int_16x16x64 >();
130
+ passed &= test<bfloat16, float , 16 , 16 , 32 , class amx_bf16_16x16x32 >();
131
131
break ;
132
132
}
133
133
0 commit comments