@@ -78,18 +78,19 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreCheckedINTEL(
78
78
size_t CoordY, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S,
79
79
int MemOperand = 0 );
80
80
81
- template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
82
- __spv::MatrixUse UA, __spv::MatrixUse UB, __spv::MatrixUse UC,
81
+ template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
82
+ std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
83
+ __spv::MatrixUse UC,
83
84
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
84
85
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
85
86
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
86
87
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
87
88
extern __DPCPP_SYCL_EXTERNAL
88
- __spv::__spirv_JointMatrixINTEL<T2 , M, N, LC, S, UC> *
89
+ __spv::__spirv_JointMatrixINTEL<TC , M, N, LC, S, UC> *
89
90
__spirv_JointMatrixMadINTEL (
90
- __spv::__spirv_JointMatrixINTEL<T1 , M, K, LA, S, UA> *A,
91
- __spv::__spirv_JointMatrixINTEL<T1 , K, N, LB, S, UB> *B,
92
- __spv::__spirv_JointMatrixINTEL<T2 , M, N, LC, S, UC> *C,
91
+ __spv::__spirv_JointMatrixINTEL<TA , M, K, LA, S, UA> *A,
92
+ __spv::__spirv_JointMatrixINTEL<TB , K, N, LB, S, UB> *B,
93
+ __spv::__spirv_JointMatrixINTEL<TC , M, N, LC, S, UC> *C,
93
94
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
94
95
95
96
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
0 commit comments