Skip to content

Commit 76ec14b

Browse files
[Matrix][SYCL] Add use argument for joint_matrix and add another feat… (#5835)
…ure macro for it
1 parent 26d5d98 commit 76ec14b

16 files changed

+1313
-644
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,89 +23,107 @@
2323

2424
#ifdef __SYCL_DEVICE_ONLY__
2525
template <typename T, std::size_t R, std::size_t C,
26+
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
2627
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
2728
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
28-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
29+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
2930
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
3031
__spv::MatrixLayout Layout = L,
3132
__spv::Scope::Flag Sc = S, int MemOperand = 0);
3233

3334
template <typename T, std::size_t R, std::size_t C,
35+
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
3436
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
3537
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
3638
extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
37-
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *Object,
39+
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *Object,
3840
std::size_t Stride, __spv::MatrixLayout Layout = L,
3941
__spv::Scope::Flag Sc = S, int MemOperand = 0);
4042

4143
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
44+
__spv::MatrixUse UA = __spv::MatrixUse::Unnecessary,
45+
__spv::MatrixUse UB = __spv::MatrixUse::Unnecessary,
46+
__spv::MatrixUse UC = __spv::MatrixUse::Unnecessary,
4247
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
4348
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
4449
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
4550
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
46-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *
51+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S, UC> *
4752
__spirv_JointMatrixMadINTEL(
48-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
49-
__spv::__spirv_JointMatrixINTEL<T1, K, N, LB, S> *B,
50-
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *C,
53+
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
54+
__spv::__spirv_JointMatrixINTEL<T1, K, N, LB, S, UB> *B,
55+
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S, UC> *C,
5156
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
5257

5358
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
54-
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
59+
std::size_t N, __spv::MatrixUse UA = __spv::MatrixUse::Unnecessary,
60+
__spv::MatrixUse UB = __spv::MatrixUse::Unnecessary,
61+
__spv::MatrixUse UC = __spv::MatrixUse::Unnecessary,
62+
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
5563
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
5664
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
5765
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
58-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
66+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
5967
__spirv_JointMatrixUUMadINTEL(
60-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
61-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
62-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
68+
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
69+
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
70+
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
6371
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
6472

6573
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
66-
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
74+
std::size_t N, __spv::MatrixUse UA = __spv::MatrixUse::Unnecessary,
75+
__spv::MatrixUse UB = __spv::MatrixUse::Unnecessary,
76+
__spv::MatrixUse UC = __spv::MatrixUse::Unnecessary,
77+
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
6778
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
6879
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
6980
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
70-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
81+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
7182
__spirv_JointMatrixUSMadINTEL(
72-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
73-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
74-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
83+
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
84+
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
85+
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
7586
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
7687

7788
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
78-
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
89+
std::size_t N, __spv::MatrixUse UA = __spv::MatrixUse::Unnecessary,
90+
__spv::MatrixUse UB = __spv::MatrixUse::Unnecessary,
91+
__spv::MatrixUse UC = __spv::MatrixUse::Unnecessary,
92+
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
7993
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
8094
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
8195
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
82-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
96+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
8397
__spirv_JointMatrixSUMadINTEL(
84-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
85-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
86-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
98+
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
99+
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
100+
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
87101
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
88102

89103
template <typename T, std::size_t R, std::size_t C,
104+
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
90105
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
91106
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
92-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
107+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
93108
__spirv_CompositeConstruct(const T v);
94109

95-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
110+
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
111+
__spv::MatrixLayout L,
96112
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
97113
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
98-
__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *);
114+
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);
99115

100-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
116+
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
117+
__spv::MatrixLayout L,
101118
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
102119
extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic(
103-
__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *, size_t i);
120+
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);
104121

105-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
122+
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
123+
__spv::MatrixLayout L,
106124
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
107-
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *
108-
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *,
125+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
126+
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
109127
T val, size_t i);
110128

111129
#ifndef __SPIRV_BUILTIN_DECLARATIONS__

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,15 @@ enum class MatrixLayout : uint32_t {
112112
RowMajor = 0,
113113
ColumnMajor = 1,
114114
PackedA = 2,
115-
PackedB = 3
115+
PackedB = 3,
116+
Unused = 4
117+
};
118+
119+
enum class MatrixUse : uint32_t {
120+
MatrixA = 0,
121+
MatrixB = 1,
122+
Accumulator = 2,
123+
Unnecessary = 3
116124
};
117125

118126
// TODO: replace the following W/A with a better solution when we have it.
@@ -129,10 +137,13 @@ enum class MatrixLayout : uint32_t {
129137
// information to SPIRV translator.
130138
// The long term solution would be to introduce a matrix type in Clang and use
131139
// it instead of this member.
132-
template <typename T, std::size_t R, std::size_t C, MatrixLayout U,
133-
Scope::Flag S = Scope::Flag::Subgroup>
140+
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
141+
Scope::Flag S = Scope::Flag::Subgroup,
142+
MatrixUse U = MatrixUse::Unnecessary>
134143
struct __spirv_JointMatrixINTEL {
135-
T (*Value)[R][C][static_cast<size_t>(U) + 1][static_cast<size_t>(S) + 1];
144+
T(*Value)
145+
[R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1]
146+
[static_cast<size_t>(U) + 1];
136147
};
137148

138149
} // namespace __spv

0 commit comments

Comments
 (0)