Skip to content

Commit d59d1da

Browse files
authored
[SYCL][Joint Matrix][E2E] Add test for B row_major bfloat16 32x64x16, 1x64x16 (#13554)
1 parent e1119d9 commit d59d1da

File tree

1 file changed

+38
-38
lines changed

1 file changed

+38
-38
lines changed

sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,9 @@ void matrix_multiply(big_matrix<TC, M, N> &C, big_matrix<TA, M, K> &A,
7070
}).wait();
7171
}
7272

73-
template <size_t TN, size_t TK, class kernel_name, typename TA, typename TB,
74-
typename TC>
73+
template <size_t TM, size_t TN, size_t TK, class kernel_name, typename TA,
74+
typename TB, typename TC>
7575
int gemm_row_major() {
76-
static constexpr size_t TM = 8;
77-
7876
static constexpr size_t MATRIX_M = TM * 2;
7977
static constexpr size_t MATRIX_N = TN * 2;
8078
static constexpr size_t MATRIX_K = TK * 2;
@@ -98,6 +96,7 @@ int gemm_row_major() {
9896
matrix_multiply_ref((TA *)A, (TB *)B, (TC *)D, MATRIX_M, MATRIX_N, MATRIX_K);
9997

10098
bool res = matrix_compare(MATRIX_M, MATRIX_N, (TC *)C, (TC *)D);
99+
std::cout << TM << "x" << TN << "x" << TK << ": ";
101100
std::cout << (res ? "passed" : "failed") << std::endl;
102101
return !res;
103102
}
@@ -108,42 +107,43 @@ int main() {
108107
q.get_device()
109108
.get_info<sycl::ext::oneapi::experimental::info::device::
110109
matrix_combinations>();
111-
for (unsigned int i = 0; i < combinations.size(); i++) {
112-
if (combinations[i].atype == matrix_type::bf16) {
113-
if (combinations[i].nsize == 0 ||
114-
(combinations[i].nsize == 16 && combinations[i].max_msize == 8 &&
115-
combinations[i].ksize == 16)) {
116-
gemm_row_major<16, 16, class gemm_bfloat16_16, bfloat16, bfloat16,
117-
float>();
118-
}
119-
if (combinations[i].nsize == 8 && combinations[i].max_msize == 8 &&
120-
combinations[i].ksize == 16) {
121-
gemm_row_major<8, 16, class gemm_bfloat16_8, bfloat16, bfloat16,
122-
float>();
110+
int res = 0;
111+
for (auto &combination : combinations) {
112+
if (combination.nsize == 0 ||
113+
combination.nsize == 16) { // Intel AMX or architecture::intel_gpu_pvc
114+
res += gemm_row_major<8, 16, 16, class bf16_8x16x16, bfloat16, bfloat16,
115+
float>();
116+
res += gemm_row_major<8, 16, 32, class ss_8x16x32, int8_t, int8_t,
117+
int32_t>();
118+
res += gemm_row_major<8, 16, 32, class us_8x16x32, uint8_t, int8_t,
119+
int32_t>();
120+
res += gemm_row_major<8, 16, 32, class su_8x16x32, int8_t, uint8_t,
121+
int32_t>();
122+
res += gemm_row_major<8, 16, 32, class uu_8x16x32, uint8_t, uint8_t,
123+
int32_t>();
124+
125+
if (combination.nsize == 16) { // architecture::intel_gpu_pvc
126+
res += gemm_row_major<1, 64, 16, class bf16_1x64x16, bfloat16, bfloat16,
127+
float>();
128+
res += gemm_row_major<32, 64, 16, class bf16_32x64x16, bfloat16,
129+
bfloat16, float>();
123130
}
131+
break;
124132
}
125-
if (combinations[i].atype == matrix_type::sint8 &&
126-
combinations[i].btype == matrix_type::sint8) {
127-
if (combinations[i].nsize == 0 ||
128-
(combinations[i].nsize == 16 && combinations[i].max_msize == 8 &&
129-
combinations[i].ksize == 32)) {
130-
gemm_row_major<16, 32, class gemm_int8_16, int8_t, int8_t, int32_t>();
131-
gemm_row_major<16, 32, class gemm_us_int8_16, uint8_t, int8_t,
132-
int32_t>();
133-
gemm_row_major<16, 32, class gemm_su_int8_16, int8_t, uint8_t,
134-
int32_t>();
135-
gemm_row_major<16, 32, class gemm_uu_int8_16, uint8_t, uint8_t,
136-
int32_t>();
137-
}
138-
if (combinations[i].nsize == 8 && combinations[i].max_msize == 8 &&
139-
combinations[i].ksize == 32) {
140-
gemm_row_major<8, 32, class gemm_int8_8, int8_t, int8_t, int32_t>();
141-
gemm_row_major<8, 32, class gemm_us_int8_8, uint8_t, int8_t, int32_t>();
142-
gemm_row_major<8, 32, class gemm_su_int8_8, int8_t, uint8_t, int32_t>();
143-
gemm_row_major<8, 32, class gemm_uu_int8_8, uint8_t, uint8_t,
144-
int32_t>();
145-
}
133+
134+
if (combination.nsize == 8) { // architecture::intel_gpu_dg2*
135+
res += gemm_row_major<8, 8, 16, class bf16_8x8x16, bfloat16, bfloat16,
136+
float>();
137+
res +=
138+
gemm_row_major<8, 8, 32, class ss_8x8x32, int8_t, int8_t, int32_t>();
139+
res +=
140+
gemm_row_major<8, 8, 32, class us_8x8x32, uint8_t, int8_t, int32_t>();
141+
res +=
142+
gemm_row_major<8, 8, 32, class su_8x8x32, int8_t, uint8_t, int32_t>();
143+
res += gemm_row_major<8, 8, 32, class uu_8x8x32, uint8_t, uint8_t,
144+
int32_t>();
145+
break;
146146
}
147147
}
148-
return 0;
148+
return res;
149149
}

0 commit comments

Comments
 (0)