Skip to content

Commit f995f55

Browse files
[SYCL][Joint Matrix][E2E] Add Joint Matrix support and tests for combination 32x32x16 for DG2 (#14753)
Description: 1. Tests are added for joint_matrix_store, joint_matrix_load, joint_matrix_apply and joint_matrix_mad for the 32x32x16 combination of SYCL Joint Matrix. These include Matrix A bfloat16 32x16 PackedA_RowMajor, Matrix B bfloat16 16x32 PackedB_RowMajor, Matrix B bfloat16 16x32 PackedB_PackedB, and Matrix C float 32x32 Accumulator_RowMajor. 2. Modify sycl/source/detail/device_info.hpp and sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc to include support for the 32x32x16 combination on DG2 devices.
1 parent 6dd3892 commit f995f55

File tree

6 files changed

+15
-3
lines changed

6 files changed

+15
-3
lines changed

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,12 +1015,13 @@ architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
10151015
`architecture::intel_gpu_pvc`
10161016
|8| `architecture::intel_gpu_dg2_g10,
10171017
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
1018-
.5+| `matrix_type::bf16` .5+| `matrix_type::bf16` .5+|
1018+
.6+| `matrix_type::bf16` .6+| `matrix_type::bf16` .6+|
10191019
`matrix_type::fp32` | 16 | 16 | 16 .4+|`architecture::intel_gpu_pvc`
10201020
| 1 | 64 | 16 | 32 | 64 | 16
10211021
.2+| +<=+ 8 | 16 .2+| 16
1022-
|8 | `architecture::intel_gpu_dg2_g10,
1022+
|8 .2+| `architecture::intel_gpu_dg2_g10,
10231023
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
1024+
.1+| 32 .1+| 32 .1+| 16
10241025
| `matrix_type::tf32` | `matrix_type::tf32` |
10251026
`matrix_type::fp32` | +<=+ 8 | 16 | 8 |
10261027
`architecture::intel_gpu_pvc`
@@ -1147,4 +1148,4 @@ load/store overloads
11471148
|11 |2024-04-29 |Yury Plyakhin | Add 1x64x16 supported combination for
11481149
Intel XMX (intel_gpu_pvc)
11491150
|12 |2024-06-14 |Jack Kirk | Add note on sm version device matching issue.
1150-
|======================
1151+
|======================

sycl/source/detail/device_info.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,8 @@ struct get_device_info_impl<
839839
matrix_type::fp32, matrix_type::fp32},
840840
{8, 0, 0, 0, 8, 16, matrix_type::bf16, matrix_type::bf16,
841841
matrix_type::fp32, matrix_type::fp32},
842+
{0, 0, 0, 32, 32, 16, matrix_type::bf16, matrix_type::bf16,
843+
matrix_type::fp32, matrix_type::fp32},
842844
};
843845
else if (architecture::amd_gpu_gfx90a == DeviceArch)
844846
return {

sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ int main() {
240240
test_ewops_ab<bfloat16, 8, 16, use::a, layout::row_major, 1>();
241241
test_ewops_ab<bfloat16, 16, 8, use::b, layout::ext_intel_packed, 2>();
242242
test_ewops_c<float, 8, 8>();
243+
// test_ewops_ab<bfloat16, 32, 16, use::a, layout::row_major, 1>();
244+
// test_ewops_ab<bfloat16, 16, 32, use::b, layout::ext_intel_packed, 2>();
245+
// test_ewops_c<float, 32, 32>();
243246
break;
244247
}
245248
}

sycl/test-e2e/Matrix/element_wise_ops_impl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ int main() {
142142
passed &= test<uint8_t, int32_t, 8, 8, 32, 4, class dg2_uint_8x8x32>();
143143
passed &= test<int8_t, int32_t, 8, 8, 32, 4, class dg2_sint_8x8x32>();
144144
passed &= test<bfloat16, float, 8, 8, 16, 2, class dg2_bf16_8x16x16>();
145+
// passed &= test<bfloat16, float, 32, 32, 16, 2, class
146+
// dg2_bf16_32x32x16>();
145147
break;
146148
}
147149
}

sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,8 @@ int main() {
444444

445445
test<bfloat16, float, 2, /*TM*/ 8, /*TN*/ 8, /*TK*/ 16, MCache1, NCache1,
446446
KCache1, MCache2, NCache2, KCache2>();
447+
// test<bfloat16, float, 2, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16, MCache1,
448+
// NCache1, KCache1, MCache2, NCache2, KCache2>();
447449
break;
448450
}
449451
}

sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ int main() {
144144
gemm_row_major<8, 8, 32, class su_8x8x32, int8_t, uint8_t, int32_t>();
145145
res += gemm_row_major<8, 8, 32, class uu_8x8x32, uint8_t, uint8_t,
146146
int32_t>();
147+
// res += gemm_row_major<32, 32, 16, class dg2_bf16_32x32x16, bfloat16,
148+
// bfloat16, float>();
147149
break;
148150
}
149151
}

0 commit comments

Comments
 (0)