@@ -33,6 +33,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1)
33
33
__SYCL_JOINT_MATRIX_OVERLOAD (double , accumulator, 8 , 8 , double , 2 )
34
34
35
35
// m8n32k16
36
+ // bf16 data format uses uint16_t data type
36
37
__SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 8 , 16 , int32_t , 2 )
37
38
__SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 32 , int32_t , 8 )
38
39
__SYCL_JOINT_MATRIX_OVERLOAD (half, a, 8 , 16 , int32_t , 8 )
@@ -62,10 +63,10 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1)
62
63
__SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 32 , 8 , int32_t , 8 )
63
64
64
65
// m16n16k16
65
- __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 16 , 16 , int32_t , 8 )
66
- __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 16 , int32_t , 8 )
67
66
__SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 16 , 16 , int32_t , 4 )
68
67
__SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 16 , int32_t , 4 )
68
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 16 , 16 , int32_t , 8 )
69
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 16 , int32_t , 8 )
69
70
__SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 16 , 16 , float , 8 )
70
71
__SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 16 , 16 , int32_t , 4 )
71
72
@@ -508,7 +509,7 @@ void joint_matrix_load(
508
509
multi_ptr<T, Space> src, size_t stride) {
509
510
#ifdef __SYCL_DEVICE_ONLY__
510
511
#ifdef __NVPTX__
511
- sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
512
+ sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
512
513
Layout, Space>{}
513
514
.load (res, src, stride);
514
515
#endif
@@ -530,7 +531,7 @@ void joint_matrix_store(Group sg,
530
531
multi_ptr<T, Space> dst, size_t stride) {
531
532
#ifdef __SYCL_DEVICE_ONLY__
532
533
#ifdef __NVPTX__
533
- sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
534
+ sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
534
535
Layout, Space>{}
535
536
.store (src, dst, stride);
536
537
#endif
0 commit comments