Skip to content

Commit 33102b6

Browse files
authored
[SYCL][JM tests] Fix apply test on SPR (#13321)
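The test was failing on SPR. This commit fixes it by giving matrix A its own K dimension (allocation size, load/store offsets, and stride) instead of reusing N. It also changes the TM, TN, and TK values used on AMX, setting them to the maximum supported sizes to improve test coverage, since they no longer need to match the PVC sizes.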
1 parent cb28e09 commit 33102b6

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

sycl/test-e2e/Matrix/joint_matrix_apply_two_matrices_impl.hpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ bool apply_verify(Tc *C, Tc *D, Ta *A, Ta *Ar) {
2828
return true;
2929
}
3030
template <typename Tc, typename Ta, size_t TM, size_t TN, size_t TK, size_t M,
31-
size_t N, class kernel_name>
31+
size_t N, size_t K, class kernel_name>
3232
bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
3333
size_t NDRangeM = M / TM;
3434
size_t NDRangeN = N / TN;
@@ -76,13 +76,13 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
7676
sg, sub_d, pD + (sg_startx * TM) * N + sg_starty / sg_size * TN,
7777
N, layout::row_major);
7878
joint_matrix_load(
79-
sg, sub_a, pA + (sg_startx * TM) * N + sg_starty / sg_size * TK,
80-
N);
79+
sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
80+
K);
8181
joint_matrix_apply(sg, sub_a, sub_ar,
8282
[](const Ta &x, Ta &y) { y = x + 42; });
8383
ext::intel::experimental::matrix::joint_matrix_store(
8484
sg, sub_ar,
85-
pAr + (sg_startx * TM) * N + sg_starty / sg_size * TK, N);
85+
pAr + (sg_startx * TM) * K + sg_starty / sg_size * TK, K);
8686
}); // parallel for
8787
}).wait();
8888
return apply_verify<Tc, Ta, M, N>(C, D, A, Ar);
@@ -91,27 +91,27 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
9191
template <typename Ta, typename Tc, size_t TM, size_t TN, size_t TK,
9292
class kernel_name>
9393
bool test() {
94-
9594
static constexpr size_t M = TM * 2;
9695
static constexpr size_t N = TN * 2;
96+
static constexpr size_t K = TK * 2;
9797
queue q;
9898

9999
Tc *C = malloc_shared<Tc>(M * N, q);
100100
Tc *D = malloc_shared<Tc>(M * N, q);
101-
Ta *A = malloc_shared<Ta>(M * N, q);
102-
Ta *Ar = malloc_shared<Ta>(M * N, q);
101+
Ta *A = malloc_shared<Ta>(M * K, q);
102+
Ta *Ar = malloc_shared<Ta>(M * K, q);
103103

104104
matrix_rand(M, N, (Tc *)C, (Tc)100);
105-
matrix_rand(M, N, (Ta *)A, (Ta)100);
105+
matrix_rand(M, K, (Ta *)A, (Ta)100);
106106

107-
bool res =
108-
apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, kernel_name>(C, D, A, Ar, q);
107+
bool res = apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, K, kernel_name>(
108+
C, D, A, Ar, q);
109109

110110
if constexpr (std::is_same_v<Ta, bfloat16>)
111-
std::cout << "bfloat16 " << TM << "x" << TN << ": "
111+
std::cout << "bfloat16 " << TM << "x" << TN << "x" << TK << ": "
112112
<< (res ? "passed" : "failed") << std::endl;
113113
else if constexpr (std::is_same_v<Ta, int8_t>)
114-
std::cout << "int8_t " << TM << "x" << TN << ": "
114+
std::cout << "int8_t " << TM << "x" << TN << "x" << TK << ": "
115115
<< (res ? "passed" : "failed") << std::endl;
116116
return res;
117117
}
@@ -126,8 +126,8 @@ int main() {
126126
bool passed = true;
127127
for (unsigned int i = 0; i < combinations.size(); i++) {
128128
if (combinations[i].nsize == 0) { // Intel AMX
129-
passed &= test<int8_t, int32_t, 8, 16, 32, class amx_int_8x16x32>();
130-
passed &= test<bfloat16, float, 8, 16, 32, class amx_bf16_8x16x32>();
129+
passed &= test<int8_t, int32_t, 16, 16, 64, class amx_int_16x16x64>();
130+
passed &= test<bfloat16, float, 16, 16, 32, class amx_bf16_16x16x32>();
131131
break;
132132
}
133133

0 commit comments

Comments
 (0)