Skip to content

Commit 15fbefc

Browse files
authored
[SYCL][Matrix] Use KHR cooperative matrix instructions instead of Intel's (#13817)
The usage is currently guarded by the __SPIRV_USE_COOPERATIVE_MATRIX macro. This change is split out from #13316. --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 8407960 commit 15fbefc

File tree

85 files changed

+2107
-2
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+2107
-2
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
// Declares __spirv_RoundFToTF32INTEL: takes a float and returns a float.
// NOTE(review): presumably rounds the value to TF32 precision (from the name);
// confirm against the SPIR-V extension this maps to.
extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);
2929

30+
#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
3031
template <typename T, typename Tp, std::size_t R, std::size_t C,
3132
__spv::MatrixUse U,
3233
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
@@ -174,6 +175,136 @@ template <typename Ts, typename T, std::size_t R, std::size_t C,
174175
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
175176
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
176177
Ts val, size_t i);
178+
#else // __SPIRV_USE_COOPERATIVE_MATRIX
179+
// Declares __spirv_CooperativeMatrixLoadKHR. This declaration is only visible
// when __SPIRV_USE_COOPERATIVE_MATRIX is defined.
// Parameters: Ptr is the source pointer; Layout defaults to the template
// argument L; Stride and MemOperand both default to 0.
// Returns a pointer to __spirv_CooperativeMatrixKHR<Tp, S, R, C, U>; the
// pointee type T of Ptr may differ from the matrix element type Tp.
template <typename T, typename Tp, std::size_t R, std::size_t C,
180+
__spv::MatrixUse U,
181+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
182+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
183+
extern __DPCPP_SYCL_EXTERNAL
184+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
185+
__spirv_CooperativeMatrixLoadKHR(T *Ptr, __spv::MatrixLayout Layout = L,
186+
std::size_t Stride = 0,
187+
int MemOperand = 0);
188+
189+
// Declares __spirv_CooperativeMatrixStoreKHR (returns void).
// Parameters: Ptr is the destination pointer; Object is a pointer to the
// __spirv_CooperativeMatrixKHR<Tp, S, R, C, U> to store; Layout defaults to
// the template argument L; Stride and MemOperand both default to 0.
template <typename T, typename Tp, std::size_t R, std::size_t C,
190+
__spv::MatrixUse U,
191+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
192+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
193+
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreKHR(
194+
T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
195+
__spv::MatrixLayout Layout = L, std::size_t Stride = 0, int MemOperand = 0);
196+
197+
// Declares __spirv_CooperativeMatrixLengthKHR: takes a pointer to a
// __spirv_CooperativeMatrixKHR<T, S, R, C, U> and returns a size_t.
// The template parameter L has a default and does not appear in the function
// signature.
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
198+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
199+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
200+
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_CooperativeMatrixLengthKHR(
201+
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *);
202+
203+
// Declares the six-argument, size_t-based overload of
// __spirv_CooperativeMatrixConstructCheckedINTEL.
// Argument order: Value, Height, Stride, Width, CoordX, CoordY.
// Returns a pointer to __spirv_CooperativeMatrixKHR<Tp, S, R, C, U>.
// NOTE(review): a five-argument overload taking (CoordX, CoordY, Height,
// Width, Value) is declared later in this same #else branch; confirm which
// one callers are expected to use.
template <typename T, typename Tp, std::size_t R, std::size_t C,
204+
__spv::MatrixUse U,
205+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
206+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
207+
extern __DPCPP_SYCL_EXTERNAL
208+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
209+
__spirv_CooperativeMatrixConstructCheckedINTEL(const T Value, size_t Height,
210+
size_t Stride, size_t Width,
211+
size_t CoordX,
212+
size_t CoordY);
213+
214+
// Declares the size_t-based overload of
// __spirv_CooperativeMatrixLoadCheckedINTEL.
// Argument order: Ptr, Stride, Height, Width, CoordX, CoordY, then the
// optional Layout (defaults to template argument L) and MemOperand
// (defaults to 0).
// Returns a pointer to __spirv_CooperativeMatrixKHR<Tp, S, R, C, U>.
// NOTE(review): another overload with the order (Ptr, CoordX, CoordY, Layout,
// Height, Width, Stride, MemOperand) is declared later in this #else branch.
template <typename T, typename Tp, std::size_t R, std::size_t C,
215+
__spv::MatrixUse U,
216+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
217+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
218+
extern __DPCPP_SYCL_EXTERNAL
219+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
220+
__spirv_CooperativeMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride,
221+
size_t Height, size_t Width,
222+
size_t CoordX, size_t CoordY,
223+
__spv::MatrixLayout Layout = L,
224+
int MemOperand = 0);
225+
226+
// Declares the size_t-based overload of
// __spirv_CooperativeMatrixStoreCheckedINTEL (returns void).
// Argument order: Ptr, Object (pointer to the matrix being stored), Stride,
// Height, Width, CoordX, CoordY, then the optional Layout (defaults to
// template argument L) and MemOperand (defaults to 0).
// NOTE(review): another overload with the order (Ptr, CoordX, CoordY, Object,
// Layout, Height, Width, Stride, MemOperand) is declared later in this #else
// branch.
template <typename T, typename Tp, std::size_t R, std::size_t C,
227+
__spv::MatrixUse U,
228+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
229+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
230+
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
231+
T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
232+
std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
233+
size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);
234+
235+
// Declares __spirv_CooperativeMatrixMulAddKHR.
// Operand types as spelled in the signature:
//   A: __spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *
//   B: __spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *
//   C: __spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
// The result has the same type as C. Operands defaults to 0.
// The layout template parameters LA, LB and LC have defaults and do not
// appear in the function signature.
template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
236+
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
237+
__spv::MatrixUse UC,
238+
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
239+
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
240+
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
241+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
242+
extern __DPCPP_SYCL_EXTERNAL
243+
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
244+
__spirv_CooperativeMatrixMulAddKHR(
245+
__spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
246+
__spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,
247+
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *C,
248+
size_t Operands = 0);
249+
250+
// Declares the __spirv_CompositeConstruct overload for the KHR cooperative
// matrix type: takes a single value v of type T and returns a pointer to
// __spirv_CooperativeMatrixKHR<Tp, S, R, C, U>.
// NOTE(review): presumably every element is initialised to v; confirm
// against the SPIR-V OpCompositeConstruct semantics for cooperative matrices.
template <typename T, typename Tp, std::size_t R, std::size_t C,
251+
__spv::MatrixUse U,
252+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
253+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
254+
extern __DPCPP_SYCL_EXTERNAL
255+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
256+
__spirv_CompositeConstruct(const T v);
257+
258+
// Declares __spirv_JointMatrixGetElementCoordINTEL for the KHR cooperative
// matrix type: takes a matrix pointer and an index i, and returns a
// two-component __ocl_vec_t<uint32_t, 2>.
// The INTEL joint-matrix name is still used here even though the operand is a
// __spirv_CooperativeMatrixKHR; see the TODO below.
// TODO: replace with __spirv_CooperativeMatrixGetElementCoordINTEL when ready
259+
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
260+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
261+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
262+
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
263+
__spirv_JointMatrixGetElementCoordINTEL(
264+
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *, size_t i);
265+
266+
// AccessChain followed by a load/store serves to extract/insert an element
267+
// from/to the matrix.
// Takes a pointer to a matrix pointer and an index i; returns a Ts * through
// which the caller performs the load or store.
268+
template <typename Ts, typename T, std::size_t R, std::size_t C,
269+
__spv::MatrixUse U,
270+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
271+
extern __DPCPP_SYCL_EXTERNAL Ts *
272+
__spirv_AccessChain(__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> **,
273+
size_t i);
274+
275+
// Declares the five-argument overload of
// __spirv_CooperativeMatrixConstructCheckedINTEL.
// Argument order: CoordX, CoordY (int32_t), Height, Width (uint32_t), Value.
// Returns a pointer to __spirv_CooperativeMatrixKHR<Tp, S, R, C, U>.
// Unlike the six-argument size_t overload declared earlier in this #else
// branch, this one has no Stride parameter and takes Value last.
template <typename T, typename Tp, std::size_t R, std::size_t C,
276+
__spv::MatrixUse U,
277+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
278+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
279+
extern __DPCPP_SYCL_EXTERNAL
280+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
281+
__spirv_CooperativeMatrixConstructCheckedINTEL(int32_t CoordX,
282+
int32_t CoordY,
283+
uint32_t Height,
284+
uint32_t Width,
285+
const T Value);
286+
287+
// Declares the int32_t/uint32_t-based overload of
// __spirv_CooperativeMatrixLoadCheckedINTEL.
// Required arguments: Ptr, CoordX, CoordY. Optional arguments, in order:
// Layout (defaults to template argument L), Height, Width, Stride and
// MemOperand (all default to 0).
// Returns a pointer to __spirv_CooperativeMatrixKHR<Tp, S, R, C, U>.
// NOTE(review): the size_t-based overload declared earlier takes
// (Ptr, Stride, Height, Width, CoordX, CoordY, ...); confirm calls with
// integer arguments resolve to the intended overload.
template <typename T, typename Tp, std::size_t R, std::size_t C,
288+
__spv::MatrixUse U,
289+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
290+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
291+
extern __DPCPP_SYCL_EXTERNAL
292+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
293+
__spirv_CooperativeMatrixLoadCheckedINTEL(
294+
T *Ptr, int32_t CoordX, int32_t CoordY, __spv::MatrixLayout Layout = L,
295+
uint32_t Height = 0, uint32_t Width = 0, std::size_t Stride = 0,
296+
int MemOperand = 0);
297+
298+
// Declares the int32_t/uint32_t-based overload of
// __spirv_CooperativeMatrixStoreCheckedINTEL (returns void).
// Required arguments: Ptr, CoordX, CoordY, Object (pointer to the matrix
// being stored). Optional arguments, in order: Layout (defaults to template
// argument L), Height, Width, Stride and MemOperand (all default to 0).
template <typename T, typename Tp, std::size_t R, std::size_t C,
299+
__spv::MatrixUse U,
300+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
301+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
302+
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
303+
T *Ptr, int32_t CoordX, int32_t CoordY,
304+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
305+
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
306+
std::size_t Stride = 0, int MemOperand = 0);
307+
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
177308

178309
template <typename T>
179310
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,34 @@ enum class MatrixLayout : uint32_t {
118118

119119
// Role tag for a matrix type: MatrixA (0), MatrixB (1) or Accumulator (2).
// Used as a template argument of the matrix types declared below.
enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
120120

121+
#ifdef __SPIRV_USE_COOPERATIVE_MATRIX
122+
// Matrix operand flags, only defined when __SPIRV_USE_COOPERATIVE_MATRIX is
// set. Apart from NoneKHR (0), each enumerator is a distinct single bit, so
// values can be combined as a bit mask in a uint32_t.
// NOTE(review): presumably these are passed as the Operands argument of
// __spirv_CooperativeMatrixMulAddKHR; confirm at the call sites.
enum class MatrixOperands : uint32_t {
123+
// SPV_KHR_cooperative_matrix operands
124+
NoneKHR = 0,
125+
MatrixASignedComponentsKHR = 0x1,
126+
MatrixBSignedComponentsKHR = 0x2,
127+
MatrixCSignedComponentsKHR = 0x4,
128+
MatrixResultSignedComponentsKHR = 0x8,
129+
SaturatingAccumulationKHR = 0x10,
130+
// SPV_INTEL_joint_matrix operands
131+
MatrixAAndBTF32ComponentsINTEL = 0x20,
132+
MatrixAAndBBFloat16ComponentsINTEL = 0x40,
133+
MatrixCBFloat16ComponentsINTEL = 0x80,
134+
MatrixResultBFloat16ComponentsINTEL = 0x100
135+
};
136+
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
137+
138+
#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
139+
121140
// Forward declaration of the INTEL joint matrix type, visible only when
// __SPIRV_USE_COOPERATIVE_MATRIX is NOT defined.
// Template parameter order: element type T, sizes R and C, layout L,
// scope S (default Subgroup), use U (default MatrixA).
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
122141
Scope::Flag S = Scope::Flag::Subgroup,
123142
MatrixUse U = MatrixUse::MatrixA>
124143
struct __spirv_JointMatrixINTEL;
144+
#else
145+
// Forward declaration of the KHR cooperative matrix type, visible only when
// __SPIRV_USE_COOPERATIVE_MATRIX is defined.
// Template parameter order differs from __spirv_JointMatrixINTEL: element
// type T, scope S (default Subgroup), sizes R and C (both default 1),
// use U (default MatrixA). There is no layout parameter.
template <typename T, Scope::Flag S = Scope::Flag::Subgroup, std::size_t R = 1,
146+
std::size_t C = 1, MatrixUse U = MatrixUse::MatrixA>
147+
struct __spirv_CooperativeMatrixKHR;
148+
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
125149

126150
struct __spirv_TaskSequenceINTEL;
127151

0 commit comments

Comments
 (0)