Skip to content

Commit ad18dd1

Browse files
committed
Added a note that bf16 uses uint16_t.
Signed-off-by: jack.kirk <[email protected]>
1 parent: 4d756a4 · commit: ad18dd1

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1)
3333
__SYCL_JOINT_MATRIX_OVERLOAD(double, accumulator, 8, 8, double, 2)
3434

3535
// m8n32k16
36+
// bf16 data format uses uint16_t data type
3637
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2)
3738
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8)
3839
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8)
@@ -62,10 +63,10 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1)
6263
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 32, 8, int32_t, 8)
6364

6465
// m16n16k16
65-
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8)
66-
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8)
6766
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4)
6867
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4)
68+
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8)
69+
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8)
6970
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8)
7071
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 16, 16, int32_t, 4)
7172

@@ -508,7 +509,7 @@ void joint_matrix_load(
508509
multi_ptr<T, Space> src, size_t stride) {
509510
#ifdef __SYCL_DEVICE_ONLY__
510511
#ifdef __NVPTX__
511-
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
512+
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
512513
Layout, Space>{}
513514
.load(res, src, stride);
514515
#endif
@@ -530,7 +531,7 @@ void joint_matrix_store(Group sg,
530531
multi_ptr<T, Space> dst, size_t stride) {
531532
#ifdef __SYCL_DEVICE_ONLY__
532533
#ifdef __NVPTX__
533-
sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
534+
sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
534535
Layout, Space>{}
535536
.store(src, dst, stride);
536537
#endif

0 commit comments

Comments (0)