Skip to content

Commit bb1121c

Browse files
authored
[SYCL][joint matrix] Make SPIRV mad call consistent with the SYCL call (#12369)
1 parent a920b53 commit bb1121c

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,19 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreCheckedINTEL(
7878
size_t CoordY, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S,
7979
int MemOperand = 0);
8080

81-
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
82-
__spv::MatrixUse UA, __spv::MatrixUse UB, __spv::MatrixUse UC,
81+
template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
82+
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
83+
__spv::MatrixUse UC,
8384
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
8485
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
8586
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
8687
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
8788
extern __DPCPP_SYCL_EXTERNAL
88-
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S, UC> *
89+
__spv::__spirv_JointMatrixINTEL<TC, M, N, LC, S, UC> *
8990
__spirv_JointMatrixMadINTEL(
90-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
91-
__spv::__spirv_JointMatrixINTEL<T1, K, N, LB, S, UB> *B,
92-
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S, UC> *C,
91+
__spv::__spirv_JointMatrixINTEL<TA, M, K, LA, S, UA> *A,
92+
__spv::__spirv_JointMatrixINTEL<TB, K, N, LB, S, UB> *B,
93+
__spv::__spirv_JointMatrixINTEL<TC, M, N, LC, S, UC> *C,
9394
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
9495

9596
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,

0 commit comments

Comments
 (0)