@@ -70,11 +70,9 @@ void matrix_multiply(big_matrix<TC, M, N> &C, big_matrix<TA, M, K> &A,
70
70
}).wait ();
71
71
}
72
72
73
- template <size_t TN, size_t TK, class kernel_name , typename TA, typename TB ,
74
- typename TC>
73
+ template <size_t TM, size_t TN, size_t TK, class kernel_name , typename TA,
74
+ typename TB, typename TC>
75
75
int gemm_row_major () {
76
- static constexpr size_t TM = 8 ;
77
-
78
76
static constexpr size_t MATRIX_M = TM * 2 ;
79
77
static constexpr size_t MATRIX_N = TN * 2 ;
80
78
static constexpr size_t MATRIX_K = TK * 2 ;
@@ -98,6 +96,7 @@ int gemm_row_major() {
98
96
matrix_multiply_ref ((TA *)A, (TB *)B, (TC *)D, MATRIX_M, MATRIX_N, MATRIX_K);
99
97
100
98
bool res = matrix_compare (MATRIX_M, MATRIX_N, (TC *)C, (TC *)D);
99
+ std::cout << TM << " x" << TN << " x" << TK << " : " ;
101
100
std::cout << (res ? " passed" : " failed" ) << std::endl;
102
101
return !res;
103
102
}
@@ -108,42 +107,43 @@ int main() {
108
107
q.get_device ()
109
108
.get_info <sycl::ext::oneapi::experimental::info::device::
110
109
matrix_combinations>();
111
- for (unsigned int i = 0 ; i < combinations.size (); i++) {
112
- if (combinations[i].atype == matrix_type::bf16 ) {
113
- if (combinations[i].nsize == 0 ||
114
- (combinations[i].nsize == 16 && combinations[i].max_msize == 8 &&
115
- combinations[i].ksize == 16 )) {
116
- gemm_row_major<16 , 16 , class gemm_bfloat16_16 , bfloat16, bfloat16,
117
- float >();
118
- }
119
- if (combinations[i].nsize == 8 && combinations[i].max_msize == 8 &&
120
- combinations[i].ksize == 16 ) {
121
- gemm_row_major<8 , 16 , class gemm_bfloat16_8 , bfloat16, bfloat16,
122
- float >();
110
+ int res = 0 ;
111
+ for (auto &combination : combinations) {
112
+ if (combination.nsize == 0 ||
113
+ combination.nsize == 16 ) { // Intel AMX or architecture::intel_gpu_pvc
114
+ res += gemm_row_major<8 , 16 , 16 , class bf16_8x16x16 , bfloat16, bfloat16,
115
+ float >();
116
+ res += gemm_row_major<8 , 16 , 32 , class ss_8x16x32 , int8_t , int8_t ,
117
+ int32_t >();
118
+ res += gemm_row_major<8 , 16 , 32 , class us_8x16x32 , uint8_t , int8_t ,
119
+ int32_t >();
120
+ res += gemm_row_major<8 , 16 , 32 , class su_8x16x32 , int8_t , uint8_t ,
121
+ int32_t >();
122
+ res += gemm_row_major<8 , 16 , 32 , class uu_8x16x32 , uint8_t , uint8_t ,
123
+ int32_t >();
124
+
125
+ if (combination.nsize == 16 ) { // architecture::intel_gpu_pvc
126
+ res += gemm_row_major<1 , 64 , 16 , class bf16_1x64x16 , bfloat16, bfloat16,
127
+ float >();
128
+ res += gemm_row_major<32 , 64 , 16 , class bf16_32x64x16 , bfloat16,
129
+ bfloat16, float >();
123
130
}
131
+ break ;
124
132
}
125
- if (combinations[i].atype == matrix_type::sint8 &&
126
- combinations[i].btype == matrix_type::sint8) {
127
- if (combinations[i].nsize == 0 ||
128
- (combinations[i].nsize == 16 && combinations[i].max_msize == 8 &&
129
- combinations[i].ksize == 32 )) {
130
- gemm_row_major<16 , 32 , class gemm_int8_16 , int8_t , int8_t , int32_t >();
131
- gemm_row_major<16 , 32 , class gemm_us_int8_16 , uint8_t , int8_t ,
132
- int32_t >();
133
- gemm_row_major<16 , 32 , class gemm_su_int8_16 , int8_t , uint8_t ,
134
- int32_t >();
135
- gemm_row_major<16 , 32 , class gemm_uu_int8_16 , uint8_t , uint8_t ,
136
- int32_t >();
137
- }
138
- if (combinations[i].nsize == 8 && combinations[i].max_msize == 8 &&
139
- combinations[i].ksize == 32 ) {
140
- gemm_row_major<8 , 32 , class gemm_int8_8 , int8_t , int8_t , int32_t >();
141
- gemm_row_major<8 , 32 , class gemm_us_int8_8 , uint8_t , int8_t , int32_t >();
142
- gemm_row_major<8 , 32 , class gemm_su_int8_8 , int8_t , uint8_t , int32_t >();
143
- gemm_row_major<8 , 32 , class gemm_uu_int8_8 , uint8_t , uint8_t ,
144
- int32_t >();
145
- }
133
+
134
+ if (combination.nsize == 8 ) { // architecture::intel_gpu_dg2*
135
+ res += gemm_row_major<8 , 8 , 16 , class bf16_8x8x16 , bfloat16, bfloat16,
136
+ float >();
137
+ res +=
138
+ gemm_row_major<8 , 8 , 32 , class ss_8x8x32 , int8_t , int8_t , int32_t >();
139
+ res +=
140
+ gemm_row_major<8 , 8 , 32 , class us_8x8x32 , uint8_t , int8_t , int32_t >();
141
+ res +=
142
+ gemm_row_major<8 , 8 , 32 , class su_8x8x32 , int8_t , uint8_t , int32_t >();
143
+ res += gemm_row_major<8 , 8 , 32 , class uu_8x8x32 , uint8_t , uint8_t ,
144
+ int32_t >();
145
+ break ;
146
146
}
147
147
}
148
- return 0 ;
148
+ return res ;
149
149
}
0 commit comments