Skip to content

Commit f8a16f7

Browse files
authored
[SYCL][Matrix] Move elementwise operation under intel namespace and add joint_matrix_apply. (#8417)
This patch moves the wi_data and wi_element class (and corresponding operations) in the matrix-intel.hpp file, under the intel::experimental::matrix namespace. The original implementation is kept (but soon will be deprecated) to make the existing CUDA test cases work.
1 parent 5b183b4 commit f8a16f7

File tree

3 files changed

+120
-17
lines changed

3 files changed

+120
-17
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,27 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
6565
layout Layout>
6666
struct joint_matrix;
6767

68-
template <typename T, size_t NumRows, size_t NumCols, use Use,
69-
layout Layout = layout::dynamic, typename Group = sycl::sub_group>
68+
} // namespace matrix
69+
} // namespace experimental
70+
} // namespace oneapi
71+
72+
namespace intel::experimental::matrix {
73+
74+
// Begin wi_element definition
75+
76+
template <typename T, size_t NumRows, size_t NumCols,
77+
sycl::ext::oneapi::experimental::matrix::use Use,
78+
sycl::ext::oneapi::experimental::matrix::layout Layout =
79+
sycl::ext::oneapi::experimental::matrix::layout::dynamic,
80+
typename Group = sycl::sub_group>
7081
class wi_element {
71-
joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &M;
82+
sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, NumRows,
83+
NumCols, Layout> &M;
7284
std::size_t idx;
7385

7486
public:
75-
wi_element(joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Mat,
87+
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
88+
Group, T, Use, NumRows, NumCols, Layout> &Mat,
7689
std::size_t i)
7790
: M(Mat), idx(i) {}
7891
operator T() {
@@ -142,17 +155,20 @@ class wi_element {
142155
#undef OP
143156
};
144157

145-
template <size_t NumRows, size_t NumCols, use Use, layout Layout,
158+
template <size_t NumRows, size_t NumCols,
159+
sycl::ext::oneapi::experimental::matrix::use Use,
160+
sycl::ext::oneapi::experimental::matrix::layout Layout,
146161
typename Group>
147162
class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
148163
Group> {
149-
joint_matrix<Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
150-
Layout> &M;
164+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
165+
Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
151166
std::size_t idx;
152167

153168
public:
154-
wi_element(joint_matrix<Group, sycl::ext::oneapi::bfloat16, Use, NumRows,
155-
NumCols, Layout> &Mat,
169+
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
170+
Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
171+
Layout> &Mat,
156172
std::size_t i)
157173
: M(Mat), idx(i) {}
158174
operator sycl::ext::oneapi::bfloat16() {
@@ -290,11 +306,57 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
290306
#endif // __SYCL_DEVICE_ONLY__
291307
};
292308

293-
} // namespace matrix
294-
} // namespace experimental
295-
} // namespace oneapi
309+
// End wi_element definition
310+
311+
// Begin wi_data definition
312+
313+
template <typename Group, typename T,
314+
sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
315+
size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout>
316+
class wi_data {
317+
318+
sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, Rows,
319+
Cols, Layout> &jm;
320+
321+
wi_data(sycl::ext::oneapi::experimental::matrix::joint_matrix<
322+
Group, T, Use, Rows, Cols, Layout> &_jm)
323+
: jm(_jm){};
324+
325+
template <typename Grp, typename Type,
326+
sycl::ext::oneapi::experimental::matrix::use UseJm, size_t NumRows,
327+
size_t NumCols,
328+
sycl::ext::oneapi::experimental::matrix::layout LayoutJm>
329+
friend decltype(auto)
330+
get_wi_data(Grp, sycl::ext::oneapi::experimental::matrix::joint_matrix<
331+
Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);
332+
333+
public:
334+
size_t length() {
335+
#if __SYCL_DEVICE_ONLY__
336+
return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
337+
#else
338+
throw runtime_error("joint matrix is not supported on host device.",
339+
PI_ERROR_INVALID_DEVICE);
340+
#endif
341+
};
342+
343+
decltype(auto) operator[](size_t i) {
344+
return wi_element<T, Rows, Cols, Use, Layout, Group>(jm, i);
345+
};
346+
};
347+
348+
template <typename Group, typename T,
349+
sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
350+
size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout>
351+
inline __SYCL_ALWAYS_INLINE decltype(auto)
352+
get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
353+
Group, T, Use, Rows, Cols, Layout> &jm) {
354+
std::ignore = sg;
355+
return wi_data(jm);
356+
}
357+
358+
// End wi_data definition
296359

297-
namespace intel::experimental::matrix {
298360
template <
299361
typename Group, typename T,
300362
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,19 @@ class wi_data {
6363
#if defined(__NVPTX__)
6464
return jm.cuda_impl.wi_marray.size();
6565
#else
66-
return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
66+
throw runtime_error("get_wi_data is available using: "
67+
"ext::intel::experimental::matrix::get_wi_data.",
68+
PI_ERROR_INVALID_DEVICE);
6769
#endif
6870
};
6971

7072
decltype(auto) operator[](size_t i) {
7173
#if defined(__NVPTX__)
7274
return (jm.cuda_impl.wi_marray[i]);
7375
#else
74-
return wi_element<T, Rows, Cols, Use, Layout, Group>(jm, i);
76+
throw runtime_error("get_wi_data is available using: "
77+
"ext::intel::experimental::matrix::get_wi_data.",
78+
PI_ERROR_INVALID_DEVICE);
7579
#endif
7680
};
7781
};
@@ -94,8 +98,18 @@ template <typename type, size_t size> class wi_data {
9498

9599
template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
96100
layout Layout>
101+
#if defined(__SYCL_DEVICE_ONLY__)
102+
#if defined(__NVPTX__)
103+
__SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please "
104+
"use joint_matrix_apply() instead.")
105+
#else
106+
__attribute__((unavailable(
107+
"get_wi_data can't be used on intel device, please use "
108+
"sycl::ext::intel::experimental::matrix::get_wi_data instead!")))
109+
#endif
110+
#endif
97111
inline __SYCL_ALWAYS_INLINE decltype(auto)
98-
get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
112+
get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
99113
#if defined(__SYCL_DEVICE_ONLY__)
100114
std::ignore = sg;
101115
return wi_data(jm);
@@ -112,6 +126,32 @@ get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
112126
#endif // defined(__SYCL_DEVICE_ONLY__)
113127
}
114128

129+
template <typename Group, typename T, use Use, size_t M, size_t N,
130+
layout Layout, typename F>
131+
inline __SYCL_ALWAYS_INLINE void
132+
joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
133+
F &&lambda) {
134+
#if defined(__SYCL_DEVICE_ONLY__)
135+
#if defined(__NVPTX__)
136+
std::ignore = sg;
137+
for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) {
138+
lambda(jm.cuda_impl.wi_marray[i]);
139+
}
140+
#else // NVPTX
141+
auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm);
142+
for (int i = 0; i < wi_data_c.length(); i++) {
143+
T element = wi_data_c[i];
144+
lambda(element);
145+
wi_data_c[i] = element;
146+
}
147+
#endif
148+
#else
149+
throw runtime_error("joint matrix is not supported on host device.",
150+
PI_ERROR_INVALID_DEVICE);
151+
#endif
152+
return;
153+
}
154+
115155
template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
116156
layout Layout, typename T2>
117157
inline __SYCL_ALWAYS_INLINE void

sycl/test/matrix/matrix-elemwise-ops.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
9191
N * 4);
9292
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
9393
}
94-
auto wi_data_c = get_wi_data(sg, sub_c);
94+
auto wi_data_c =
95+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
9596
for (int i = 0; i < wi_data_c.length(); i++) {
9697
wi_data_c[i] *= 2;
9798
}

0 commit comments

Comments
 (0)