Skip to content

Commit 4881d6d

Browse files
[SYCL][Joint Matrix][E2E] Uncomment the Joint Matrix tests for combination 32x32x16 (#16191)
Description: The support for combination 32x32x16 on DG2 has been implemented in commit intel/intel-graphics-compiler@6a06c93. This PR re-enables the Joint Matrix tests for the combination 32x32x16, which were previously unsupported. This PR also adds XFAIL: gpu to two tests, joint_matrix_bf16_fill_k_cache_arg_dim.cpp and joint_matrix_bf16_fill_k_cache_runtime_dim.cpp, because the driver has not yet been updated to include the necessary changes in IGC to support these tests. As a result, these tests should not be tested in CI at this time.
1 parent 74d4e9d commit 4881d6d

6 files changed

+18
-9
lines changed

sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,9 @@ int main() {
242242
test_ewops_ab<bfloat16, 8, 16, use::a, layout::row_major, 1>();
243243
test_ewops_ab<bfloat16, 16, 8, use::b, layout::ext_intel_packed, 2>();
244244
test_ewops_c<float, 8, 8>();
245-
// test_ewops_ab<bfloat16, 32, 16, use::a, layout::row_major, 1>();
246-
// test_ewops_ab<bfloat16, 16, 32, use::b, layout::ext_intel_packed, 2>();
247-
// test_ewops_c<float, 32, 32>();
245+
test_ewops_ab<bfloat16, 32, 16, use::a, layout::row_major, 1>();
246+
test_ewops_ab<bfloat16, 16, 32, use::b, layout::ext_intel_packed, 2>();
247+
test_ewops_c<float, 32, 32>();
248248
break;
249249
}
250250
}

sycl/test-e2e/Matrix/element_wise_ops_impl.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ int main() {
142142
passed &= test<uint8_t, int32_t, 8, 8, 32, 4, class dg2_uint_8x8x32>();
143143
passed &= test<int8_t, int32_t, 8, 8, 32, 4, class dg2_sint_8x8x32>();
144144
passed &= test<bfloat16, float, 8, 8, 16, 2, class dg2_bf16_8x16x16>();
145-
// passed &= test<bfloat16, float, 32, 32, 16, 2, class
146-
// dg2_bf16_32x32x16>();
145+
passed &= test<bfloat16, float, 32, 32, 16, 2, class dg2_bf16_32x32x16>();
147146
break;
148147
}
149148
}

sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_arg_dim.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,10 @@
1212

1313
// -ffp-model=precise is added to not depend on compiler defaults.
1414

15+
// Waiting for the commit in IGC to be pulled into the driver to resolve the
16+
// test.
17+
// XFAIL: gpu
18+
// XFAIL-TRACKER: CMPLRLLVM-63710
19+
1520
#include "common.hpp"
1621
#include "joint_matrix_bf16_fill_k_cache_impl.hpp"

sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,8 +493,8 @@ size_t matrix_size = -1;
493493

494494
test<bfloat16, float, VnniFactor, /*TM*/ 8, /*TN*/ 8, /*TK*/ 16, MCache1,
495495
NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
496-
// test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16, MCache1,
497-
// NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
496+
test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16,
497+
MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
498498
break;
499499
}
500500
}

sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,10 @@
1212

1313
// -ffp-model=precise is added to not depend on compiler defaults.
1414

15+
// Waiting for the commit in IGC to be pulled into the driver to resolve the
16+
// test.
17+
// XFAIL: gpu
18+
// XFAIL-TRACKER: CMPLRLLVM-63710
19+
1520
#include "common.hpp"
1621
#include "joint_matrix_bf16_fill_k_cache_impl.hpp"

sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ int main() {
144144
gemm_row_major<8, 8, 32, class su_8x8x32, int8_t, uint8_t, int32_t>();
145145
res += gemm_row_major<8, 8, 32, class uu_8x8x32, uint8_t, uint8_t,
146146
int32_t>();
147-
// res += gemm_row_major<32, 32, 16, class dg2_bf16_32x32x16, bfloat16,
148-
// bfloat16, float>();
147+
res += gemm_row_major<32, 32, 16, class dg2_bf16_32x32x16, bfloat16,
148+
bfloat16, float>();
149149
break;
150150
}
151151
}

0 commit comments

Comments
 (0)