@@ -65,14 +65,27 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
65
65
layout Layout>
66
66
struct joint_matrix ;
67
67
68
- template <typename T, size_t NumRows, size_t NumCols, use Use,
69
- layout Layout = layout::dynamic, typename Group = sycl::sub_group>
68
+ } // namespace matrix
69
+ } // namespace experimental
70
+ } // namespace oneapi
71
+
72
+ namespace intel ::experimental::matrix {
73
+
74
+ // Begin wi_element definition
75
+
76
+ template <typename T, size_t NumRows, size_t NumCols,
77
+ sycl::ext::oneapi::experimental::matrix::use Use,
78
+ sycl::ext::oneapi::experimental::matrix::layout Layout =
79
+ sycl::ext::oneapi::experimental::matrix::layout::dynamic,
80
+ typename Group = sycl::sub_group>
70
81
class wi_element {
71
- joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &M;
82
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, NumRows,
83
+ NumCols, Layout> &M;
72
84
std::size_t idx;
73
85
74
86
public:
75
- wi_element (joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Mat,
87
+ wi_element (sycl::ext::oneapi::experimental::matrix::joint_matrix<
88
+ Group, T, Use, NumRows, NumCols, Layout> &Mat,
76
89
std::size_t i)
77
90
: M(Mat), idx(i) {}
78
91
operator T () {
@@ -142,17 +155,20 @@ class wi_element {
142
155
#undef OP
143
156
};
144
157
145
- template <size_t NumRows, size_t NumCols, use Use, layout Layout,
158
+ template <size_t NumRows, size_t NumCols,
159
+ sycl::ext::oneapi::experimental::matrix::use Use,
160
+ sycl::ext::oneapi::experimental::matrix::layout Layout,
146
161
typename Group>
147
162
class wi_element <sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
148
163
Group> {
149
- joint_matrix<Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
150
- Layout> &M;
164
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
165
+ Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
151
166
std::size_t idx;
152
167
153
168
public:
154
- wi_element (joint_matrix<Group, sycl::ext::oneapi::bfloat16, Use, NumRows,
155
- NumCols, Layout> &Mat,
169
+ wi_element (sycl::ext::oneapi::experimental::matrix::joint_matrix<
170
+ Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
171
+ Layout> &Mat,
156
172
std::size_t i)
157
173
: M(Mat), idx(i) {}
158
174
operator sycl::ext::oneapi::bfloat16 () {
@@ -290,11 +306,57 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
290
306
#endif // __SYCL_DEVICE_ONLY__
291
307
};
292
308
293
- } // namespace matrix
294
- } // namespace experimental
295
- } // namespace oneapi
309
+ // End wi_element definition
310
+
311
+ // Begin wi_data definition
312
+
313
+ template <typename Group, typename T,
314
+ sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
315
+ size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout>
316
+ class wi_data {
317
+
318
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, Rows,
319
+ Cols, Layout> &jm;
320
+
321
+ wi_data (sycl::ext::oneapi::experimental::matrix::joint_matrix<
322
+ Group, T, Use, Rows, Cols, Layout> &_jm)
323
+ : jm(_jm){};
324
+
325
+ template <typename Grp, typename Type,
326
+ sycl::ext::oneapi::experimental::matrix::use UseJm, size_t NumRows,
327
+ size_t NumCols,
328
+ sycl::ext::oneapi::experimental::matrix::layout LayoutJm>
329
+ friend decltype (auto )
330
+ get_wi_data(Grp, sycl::ext::oneapi::experimental::matrix::joint_matrix<
331
+ Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);
332
+
333
+ public:
334
+ size_t length () {
335
+ #if __SYCL_DEVICE_ONLY__
336
+ return __spirv_JointMatrixWorkItemLengthINTEL (jm.spvm );
337
+ #else
338
+ throw runtime_error (" joint matrix is not supported on host device." ,
339
+ PI_ERROR_INVALID_DEVICE);
340
+ #endif
341
+ };
342
+
343
+ decltype (auto ) operator [](size_t i) {
344
+ return wi_element<T, Rows, Cols, Use, Layout, Group>(jm, i);
345
+ };
346
+ };
347
+
348
+ template <typename Group, typename T,
349
+ sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
350
+ size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout>
351
+ inline __SYCL_ALWAYS_INLINE decltype (auto )
352
+ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
353
+ Group, T, Use, Rows, Cols, Layout> &jm) {
354
+ std::ignore = sg;
355
+ return wi_data (jm);
356
+ }
357
+
358
+ // End wi_data definition
296
359
297
- namespace intel ::experimental::matrix {
298
360
template <
299
361
typename Group, typename T,
300
362
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
0 commit comments