Skip to content

Commit f4a9ef1

Browse files
[SYCL] Implement matrix extension using new unified interface (#7413)
1 parent 3c8cf0b commit f4a9ef1

18 files changed

+931
-863
lines changed

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,13 +108,22 @@ enum class GroupOperation : uint32_t {
108108
ExclusiveScan = 2
109109
};
110110

111+
// Matrix memory-layout tags. RowMajor and ColumnMajor are common to both
// variants; the remaining enumerators depend on
// SYCL_EXT_ONEAPI_MATRIX_VERSION. An undefined macro evaluates to 0 inside
// #if, which selects the PackedA/PackedB/Unused set.
enum class MatrixLayout : uint32_t {
  RowMajor = 0,
  ColumnMajor = 1,
#if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)
  Packed = 2,
  Dynamic = 3
#else
  PackedA = 2,
  PackedB = 3,
  Unused = 4
#endif
};
118127

119128
// Matrix-use tags; every enumerator has an explicitly fixed value.
enum class MatrixUse : uint32_t {
  MatrixA = 0,
  MatrixB = 1,
  Accumulator = 2
};
120129

Lines changed: 347 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,347 @@
1+
//==------------------ matrix-intel.hpp - SYCL matrix ----------*- C++ -*---==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// ===--------------------------------------------------------------------=== //
8+
9+
#pragma once
10+
11+
#include "matrix-unified-utils.hpp"
12+
#include <CL/__spirv/spirv_ops.hpp>
13+
#include <sycl/detail/defines_elementary.hpp>
14+
#include <sycl/feature_test.hpp>
15+
16+
namespace sycl {
17+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
18+
namespace ext {
19+
// Intel-specific addition to the oneAPI matrix `layout` values.
namespace intel::experimental::matrix::layout {
20+
// `packed` is the oneAPI `layout` value whose underlying value is 2; it is
// produced by a cast rather than by naming an enumerator of `layout`.
// NOTE(review): the literal 2 presumably has to stay distinct from the values
// of layout::row_major / col_major / dynamic - confirm against the definition
// of `layout`, which is not visible here.
constexpr sycl::ext::oneapi::experimental::matrix::layout packed =
21+
static_cast<sycl::ext::oneapi::experimental::matrix::layout>(2);
22+
} // namespace intel::experimental::matrix::layout
23+
namespace oneapi {
24+
namespace experimental {
25+
namespace matrix {
26+
27+
// Compile-time mapping from a SYCL `layout` value to the corresponding
// __spv::MatrixLayout enumerator. The primary template is the fallback: any
// layout without an explicit specialization maps to
// __spv::MatrixLayout::Dynamic.
template <layout Layout> struct spv_matrix_layout_traits {
28+
static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::Dynamic;
29+
};
30+
31+
// Defines the explicit specialization spv_matrix_layout_traits<LAYOUT> whose
// `value` is SPV_LAYOUT. The macro is not #undef'd in the code visible here.
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \
32+
template <> struct spv_matrix_layout_traits<LAYOUT> { \
33+
static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \
34+
};
35+
36+
// One specialization per supported layout. The layout::dynamic entry repeats
// the primary template's fallback value explicitly.
SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor)
37+
SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor)
38+
SPV_MATRIX_LAYOUT_TRAITS(sycl::ext::intel::experimental::matrix::layout::packed,
39+
__spv::MatrixLayout::Packed)
40+
SPV_MATRIX_LAYOUT_TRAITS(layout::dynamic, __spv::MatrixLayout::Dynamic)
41+
42+
// Compile-time mapping from a SYCL `use` value to the corresponding
// __spv::MatrixUse enumerator. The primary template is the fallback: any
// `use` value without an explicit specialization maps to
// __spv::MatrixUse::MatrixA.
template <use Use> struct spv_matrix_use_traits {
43+
static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA;
44+
};
45+
46+
// Defines the explicit specialization spv_matrix_use_traits<USE> whose
// `value` is SPV_USE. The macro is not #undef'd in the code visible here.
#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE) \
47+
template <> struct spv_matrix_use_traits<USE> { \
48+
static constexpr __spv::MatrixUse value = SPV_USE; \
49+
};
50+
51+
// One specialization per `use` value handled here: a, b and accumulator.
SPV_MATRIX_USE_TRAITS(use::a, __spv::MatrixUse::MatrixA)
52+
SPV_MATRIX_USE_TRAITS(use::b, __spv::MatrixUse::MatrixB)
53+
SPV_MATRIX_USE_TRAITS(use::accumulator, __spv::MatrixUse::Accumulator)
54+
55+
// Compile-time mapping from a SYCL group type to a __spv::Scope value.
// The primary template is intentionally empty: it has no `value` member, so
// using spv_scope_traits<G>::value with an unsupported group type is a
// compile error.
template <typename G> struct spv_scope_traits {};
56+
// sycl::sub_group maps to __spv::Scope::Subgroup.
template <> struct spv_scope_traits<sycl::sub_group> {
57+
constexpr static auto value = __spv::Scope::Subgroup;
58+
};
59+
// sycl::group<D> of any dimensionality D maps to __spv::Scope::Workgroup.
template <int D> struct spv_scope_traits<sycl::group<D>> {
60+
constexpr static auto value = __spv::Scope::Workgroup;
61+
};
62+
63+
// Forward declaration of joint_matrix so that wi_element can hold a reference
// to it; the definition is not part of this header section.
64+
template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
65+
layout Layout>
66+
struct joint_matrix;
67+
68+
template <typename T, size_t NumRows, size_t NumCols, use Use,
69+
layout Layout = layout::dynamic, typename Group = sycl::sub_group>
70+
class wi_element {
71+
joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &M;
72+
std::size_t idx;
73+
74+
public:
75+
wi_element(joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Mat,
76+
std::size_t i)
77+
: M(Mat), idx(i) {}
78+
operator T() {
79+
#ifdef __SYCL_DEVICE_ONLY__
80+
return __spirv_VectorExtractDynamic(M.spvm, idx);
81+
#else
82+
throw runtime_error("joint matrix is not supported on host device.",
83+
PI_ERROR_INVALID_DEVICE);
84+
#endif // __SYCL_DEVICE_ONLY__
85+
}
86+
87+
explicit operator bool() {
88+
#ifdef __SYCL_DEVICE_ONLY__
89+
return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
90+
#else
91+
throw runtime_error("joint matrix is not supported on host device.",
92+
PI_ERROR_INVALID_DEVICE);
93+
#endif // __SYCL_DEVICE_ONLY__
94+
}
95+
96+
template <typename T2> wi_element &operator=(const T2 &rhs) {
97+
#ifdef __SYCL_DEVICE_ONLY__
98+
M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
99+
return *this;
100+
#else
101+
(void)rhs;
102+
throw runtime_error("joint matrix is not supported on host device.",
103+
PI_ERROR_INVALID_DEVICE);
104+
#endif // __SYCL_DEVICE_ONLY__
105+
}
106+
107+
wi_element &
108+
operator=(const wi_element<T, NumRows, NumCols, Use, Layout, Group> &rhs) {
109+
#ifdef __SYCL_DEVICE_ONLY__
110+
M.spvm = __spirv_VectorInsertDynamic(
111+
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
112+
return *this;
113+
#else
114+
(void)rhs;
115+
throw runtime_error("joint matrix is not supported on host device.",
116+
PI_ERROR_INVALID_DEVICE);
117+
#endif // __SYCL_DEVICE_ONLY__
118+
}
119+
120+
#if __SYCL_DEVICE_ONLY__
121+
#define OP(op) \
122+
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
123+
M.spvm = __spirv_VectorInsertDynamic( \
124+
M.spvm, \
125+
static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
126+
op static_cast<T>(rhs)), \
127+
idx); \
128+
return *this; \
129+
}
130+
#else // __SYCL_DEVICE_ONLY__
131+
#define OP(op) \
132+
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
133+
(void)rhs; \
134+
throw runtime_error("joint matrix is not supported on host device.", \
135+
PI_ERROR_INVALID_DEVICE); \
136+
}
137+
#endif // __SYCL_DEVICE_ONLY__
138+
OP(+)
139+
OP(-)
140+
OP(*)
141+
OP(/)
142+
#undef OP
143+
};
144+
145+
template <size_t NumRows, size_t NumCols, use Use, layout Layout,
146+
typename Group>
147+
class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
148+
Group> {
149+
joint_matrix<Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
150+
Layout> &M;
151+
std::size_t idx;
152+
153+
public:
154+
wi_element(joint_matrix<Group, sycl::ext::oneapi::bfloat16, Use, NumRows,
155+
NumCols, Layout> &Mat,
156+
std::size_t i)
157+
: M(Mat), idx(i) {}
158+
operator sycl::ext::oneapi::bfloat16() {
159+
#ifdef __SYCL_DEVICE_ONLY__
160+
return __spirv_VectorExtractDynamic(M.spvm, idx);
161+
#else
162+
throw runtime_error("joint matrix is not supported on host device.",
163+
PI_ERROR_INVALID_DEVICE);
164+
#endif // __SYCL_DEVICE_ONLY__
165+
}
166+
167+
explicit operator bool() {
168+
#ifdef __SYCL_DEVICE_ONLY__
169+
return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic(
170+
M.spvm, idx))) >= std::numeric_limits<float>::epsilon();
171+
#else
172+
throw runtime_error("joint matrix is not supported on host device.",
173+
PI_ERROR_INVALID_DEVICE);
174+
#endif // __SYCL_DEVICE_ONLY__
175+
}
176+
177+
wi_element &operator=(const sycl::ext::oneapi::bfloat16 &rhs) {
178+
#ifdef __SYCL_DEVICE_ONLY__
179+
M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
180+
return *this;
181+
#else
182+
(void)rhs;
183+
throw runtime_error("joint matrix is not supported on host device.",
184+
PI_ERROR_INVALID_DEVICE);
185+
#endif // __SYCL_DEVICE_ONLY__
186+
}
187+
188+
wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
189+
NumCols, Use, Layout, Group> &rhs) {
190+
#ifdef __SYCL_DEVICE_ONLY__
191+
M.spvm = __spirv_VectorInsertDynamic(
192+
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
193+
return *this;
194+
#else
195+
(void)rhs;
196+
throw runtime_error("joint matrix is not supported on host device.",
197+
PI_ERROR_INVALID_DEVICE);
198+
#endif // __SYCL_DEVICE_ONLY__
199+
}
200+
201+
#if __SYCL_DEVICE_ONLY__
202+
#define OP(opassign, op) \
203+
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
204+
M.spvm = __spirv_VectorInsertDynamic( \
205+
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
206+
return *this; \
207+
}
208+
#else // __SYCL_DEVICE_ONLY__
209+
#define OP(opassign, op) \
210+
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
211+
(void)rhs; \
212+
throw runtime_error("joint matrix is not supported on host device.", \
213+
PI_ERROR_INVALID_DEVICE); \
214+
}
215+
#endif // __SYCL_DEVICE_ONLY__
216+
OP(+=, +)
217+
OP(-=, -)
218+
OP(*=, *)
219+
OP(/=, /)
220+
#undef OP
221+
222+
#if __SYCL_DEVICE_ONLY__
223+
#define OP(type, op) \
224+
friend type operator op( \
225+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
226+
Layout, Group> &lhs, \
227+
const sycl::ext::oneapi::bfloat16 &rhs) { \
228+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
229+
} \
230+
friend type operator op( \
231+
const sycl::ext::oneapi::bfloat16 &lhs, \
232+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
233+
Layout, Group> &rhs) { \
234+
return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \
235+
}
236+
OP(sycl::ext::oneapi::bfloat16, +)
237+
OP(sycl::ext::oneapi::bfloat16, -)
238+
OP(sycl::ext::oneapi::bfloat16, *)
239+
OP(sycl::ext::oneapi::bfloat16, /)
240+
#undef OP
241+
#define OP(type, op) \
242+
friend type operator op( \
243+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
244+
Layout, Group> &lhs, \
245+
const sycl::ext::oneapi::bfloat16 &rhs) { \
246+
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
247+
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
248+
} \
249+
friend type operator op( \
250+
const sycl::ext::oneapi::bfloat16 &lhs, \
251+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
252+
Layout, Group> &rhs) { \
253+
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
254+
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
255+
}
256+
OP(bool, ==)
257+
OP(bool, !=)
258+
OP(bool, <)
259+
OP(bool, >)
260+
OP(bool, <=)
261+
OP(bool, >=)
262+
#undef OP
263+
#else // __SYCL_DEVICE_ONLY__
264+
#define OP(type, op) \
265+
friend type operator op( \
266+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
267+
Layout, Group> &, \
268+
const sycl::ext::oneapi::bfloat16 &) { \
269+
throw runtime_error("joint matrix is not supported on host device.", \
270+
PI_ERROR_INVALID_DEVICE); \
271+
} \
272+
friend type operator op( \
273+
const sycl::ext::oneapi::bfloat16 &, \
274+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
275+
Layout, Group> &) { \
276+
throw runtime_error("joint matrix is not supported on host device.", \
277+
PI_ERROR_INVALID_DEVICE); \
278+
}
279+
OP(sycl::ext::oneapi::bfloat16, +)
280+
OP(sycl::ext::oneapi::bfloat16, -)
281+
OP(sycl::ext::oneapi::bfloat16, *)
282+
OP(sycl::ext::oneapi::bfloat16, /)
283+
OP(bool, ==)
284+
OP(bool, !=)
285+
OP(bool, <)
286+
OP(bool, >)
287+
OP(bool, <=)
288+
OP(bool, >=)
289+
#undef OP
290+
#endif // __SYCL_DEVICE_ONLY__
291+
};
292+
293+
} // namespace matrix
294+
} // namespace experimental
295+
} // namespace oneapi
296+
297+
namespace intel::experimental::matrix {
298+
template <
299+
typename Group, typename T,
300+
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
301+
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
302+
access::address_space Space, access::decorated IsDecorated,
303+
std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
304+
Use == sycl::ext::oneapi::experimental::matrix::use::b,
305+
bool> = true>
306+
inline __SYCL_ALWAYS_INLINE void
307+
joint_matrix_store(Group sg,
308+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
309+
Group, T, Use, NumRows, NumCols, Layout> &src,
310+
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
311+
#if defined(__SYCL_DEVICE_ONLY__)
312+
#if defined(__NVPTX__)
313+
std::ignore = sg;
314+
std::ignore = src;
315+
std::ignore = dst;
316+
std::ignore = stride;
317+
throw runtime_error(
318+
"This version of the matrix extension is only currently supported on "
319+
"intel devices",
320+
PI_ERROR_INVALID_DEVICE);
321+
#else
322+
// intel's impl
323+
T *Ptr = dst.get();
324+
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
325+
sycl::ext::oneapi::experimental::matrix::
326+
spv_matrix_use_traits<Use>::value,
327+
sycl::ext::oneapi::experimental::matrix::
328+
spv_matrix_layout_traits<Layout>::value>(
329+
Ptr, src.spvm, stride,
330+
sycl::ext::oneapi::experimental::matrix::spv_matrix_layout_traits<
331+
Layout>::value,
332+
sycl::ext::oneapi::experimental::matrix::spv_scope_traits<Group>::value);
333+
#endif // defined(__NVPTX__)
334+
#else
335+
std::ignore = sg;
336+
std::ignore = src;
337+
std::ignore = dst;
338+
std::ignore = stride;
339+
throw runtime_error("joint matrix is not supported on host device.",
340+
PI_ERROR_INVALID_DEVICE);
341+
#endif // defined(__SYCL_DEVICE_ONLY__)
342+
}
343+
} // namespace intel::experimental::matrix
344+
345+
} // namespace ext
346+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
347+
} // namespace sycl

0 commit comments

Comments
 (0)