Skip to content

Commit 687f579

Browse files
[SYCL][Matrix] syntax changes as preparation before moving joint matrix from experimental namespace (#11215)
As part of the effort to move joint matrix from experimental namespace to supported. A review of the API is being done as part of #7964. This results in the following changes in the syntax: 1- Add Td to joint_matrix_mad as Tc can be different from Td on the GPU, Now, we make D as an input argument to mad. 2- Change “packed” to ext_intel_packed: 3- Move EWOps (get_wi_data, wi_element, get_coord) to detail namespace) 4- add const to joint_matrix in store and mad 5 - add joint_matrix_copy/assignment function 6- add apply with coordination (change existing tests) 7- change get_coord vector type from int32_t to size_t 8- delete explicitly both = and copy ctor.
1 parent f605df6 commit 687f579

File tree

53 files changed

+375
-646
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+375
-646
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@
2929
namespace sycl {
3030
inline namespace _V1 {
3131
namespace ext {
32-
namespace intel::experimental::matrix::layout {
33-
constexpr sycl::ext::oneapi::experimental::matrix::layout packed =
34-
static_cast<sycl::ext::oneapi::experimental::matrix::layout>(2);
35-
}
3632
namespace oneapi {
3733
namespace experimental {
3834
namespace matrix {
@@ -48,8 +44,7 @@ template <layout Layout> struct spv_matrix_layout_traits {
4844

4945
SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor)
5046
SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor)
51-
SPV_MATRIX_LAYOUT_TRAITS(sycl::ext::intel::experimental::matrix::layout::packed,
52-
__spv::MatrixLayout::Packed)
47+
SPV_MATRIX_LAYOUT_TRAITS(layout::ext_intel_packed, __spv::MatrixLayout::Packed)
5348
SPV_MATRIX_LAYOUT_TRAITS(layout::dynamic, __spv::MatrixLayout::Dynamic)
5449

5550
template <use Use> struct spv_matrix_use_traits {
@@ -94,10 +89,6 @@ struct jm_type_interpretation_helper_trait<
9489
using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32;
9590
using storage_element_type = float;
9691
};
97-
} // namespace detail
98-
} // namespace oneapi
99-
100-
namespace intel::experimental::matrix {
10192

10293
using namespace sycl::ext::oneapi::experimental::matrix;
10394
// Begin wi_element definition
@@ -121,12 +112,12 @@ class wi_element {
121112
std::size_t i)
122113
: M(Mat), idx(i) {}
123114

124-
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
115+
inline __SYCL_ALWAYS_INLINE std::tuple<size_t, size_t> get_coord() {
125116
#if defined(__SYCL_DEVICE_ONLY__)
126117
__ocl_vec_t<uint32_t, 2> coord =
127118
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
128-
const uint32_t row = coord[0];
129-
const uint32_t col = coord[1];
119+
const size_t row = coord[0];
120+
const size_t col = coord[1];
130121
return std::make_tuple(row, col);
131122
#else
132123
throw runtime_error("joint matrix is not supported on host device.",
@@ -196,7 +187,7 @@ class wi_element {
196187

197188
#if __SYCL_DEVICE_ONLY__
198189
#define OP(op) \
199-
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
190+
template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
200191
M.spvm = __spirv_VectorInsertDynamic( \
201192
M.spvm, \
202193
static_cast<storage_element_type>( \
@@ -211,7 +202,7 @@ class wi_element {
211202
}
212203
#else // __SYCL_DEVICE_ONLY__
213204
#define OP(op) \
214-
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
205+
template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
215206
(void)rhs; \
216207
throw runtime_error("joint matrix is not supported on host device.", \
217208
PI_ERROR_INVALID_DEVICE); \
@@ -315,7 +306,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
315306

316307
#if __SYCL_DEVICE_ONLY__
317308
#define OP(opassign, op) \
318-
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
309+
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
319310
M.spvm = __spirv_VectorInsertDynamic( \
320311
M.spvm, \
321312
__spirv_VectorExtractDynamic< \
@@ -328,7 +319,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
328319
}
329320
#else // __SYCL_DEVICE_ONLY__
330321
#define OP(opassign, op) \
331-
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
322+
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
332323
(void)rhs; \
333324
throw runtime_error("joint matrix is not supported on host device.", \
334325
PI_ERROR_INVALID_DEVICE); \
@@ -479,7 +470,10 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
479470
}
480471

481472
// End wi_data definition
473+
} // namespace detail
474+
} // namespace oneapi
482475

476+
namespace intel::experimental::matrix {
483477
template <
484478
typename Group, typename T, typename Tp,
485479
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
@@ -490,7 +484,7 @@ template <
490484
bool> = true>
491485
inline __SYCL_ALWAYS_INLINE void
492486
joint_matrix_store(Group,
493-
sycl::ext::oneapi::experimental::matrix::joint_matrix<
487+
const sycl::ext::oneapi::experimental::matrix::joint_matrix<
494488
Group, Tp, Use, NumRows, NumCols, Layout> &src,
495489
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
496490
#if defined(__SYCL_DEVICE_ONLY__)
@@ -526,6 +520,43 @@ joint_matrix_store(Group,
526520
PI_ERROR_INVALID_DEVICE);
527521
#endif // defined(__SYCL_DEVICE_ONLY__)
528522
}
523+
524+
template <typename Group, typename T,
525+
sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
526+
size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout,
527+
typename F>
528+
inline __SYCL_ALWAYS_INLINE void joint_matrix_apply(
529+
Group sg,
530+
sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, Rows,
531+
Cols, Layout> &jm,
532+
F &&lambda) {
533+
#if defined(__SYCL_DEVICE_ONLY__)
534+
#if defined(__NVPTX__)
535+
std::ignore = sg;
536+
for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) {
537+
lambda(jm.cuda_impl.wi_marray[i]);
538+
}
539+
#else // NVPTX
540+
using storage_element_type =
541+
typename oneapi::detail::jm_type_interpretation_helper_trait<
542+
T>::storage_element_type;
543+
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
544+
for (int i = 0; i < wi_data_c.length(); i++) {
545+
storage_element_type element = wi_data_c[i];
546+
auto [row, col] = wi_data_c[i].get_coord();
547+
lambda(element, row, col);
548+
wi_data_c[i] = element;
549+
}
550+
#endif
551+
#else
552+
std::ignore = sg;
553+
std::ignore = jm;
554+
std::ignore = lambda;
555+
throw runtime_error("joint matrix is not supported on host device.",
556+
PI_ERROR_INVALID_DEVICE);
557+
#endif
558+
}
559+
529560
} // namespace intel::experimental::matrix
530561

531562
} // namespace ext

sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ namespace matrix {
1616

1717
enum class use { a, b, accumulator };
1818

19-
enum class layout { row_major = 0, col_major = 1, dynamic = 3 };
19+
enum class layout {
20+
row_major = 0,
21+
col_major = 1,
22+
ext_intel_packed = 2,
23+
dynamic = 3
24+
};
2025

2126
namespace precision {
2227
class tf32 {

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 60 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ struct joint_matrix {
4040

4141
#if defined(__SYCL_DEVICE_ONLY__)
4242
#if defined(__NVPTX__)
43-
sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols, Layout>
43+
mutable sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols,
44+
Layout>
4445
cuda_impl;
4546
#elif defined(__SPIR__)
4647
__spv::__spirv_JointMatrixINTEL<
@@ -61,19 +62,8 @@ struct joint_matrix {
6162
}
6263
#ifdef __SYCL_DEVICE_ONLY__
6364
#if defined(__SPIR__)
64-
// Generate a non-trivial assignment operator and copy c'tor that prevents
65-
// memcpy from being generated.
66-
// TODO: to remove, when either IGC can handle alloca JointMatrix or
67-
// combination of InstCombine + SROA + mem2reg can remove it
68-
joint_matrix(const joint_matrix &other) {
69-
spvm = other.spvm;
70-
return *this;
71-
}
72-
73-
joint_matrix &operator=(const joint_matrix &rhs) {
74-
spvm = rhs.spvm;
75-
return *this;
76-
}
65+
joint_matrix(const joint_matrix &other) = delete;
66+
joint_matrix &operator=(const joint_matrix &rhs) = delete;
7767
#endif // defined(__SPIR__)
7868
#endif
7969
};
@@ -97,10 +87,6 @@ class wi_data {
9787
size_t length() {
9888
#if defined(__NVPTX__)
9989
return jm.cuda_impl.wi_marray.size();
100-
#else
101-
throw runtime_error("get_wi_data is available using: "
102-
"ext::intel::experimental::matrix::get_wi_data.",
103-
PI_ERROR_INVALID_DEVICE);
10490
#endif
10591
};
10692

@@ -109,9 +95,6 @@ class wi_data {
10995
return (jm.cuda_impl.wi_marray[i]);
11096
#else
11197
std::ignore = i;
112-
throw runtime_error("get_wi_data is available using: "
113-
"ext::intel::experimental::matrix::get_wi_data.",
114-
PI_ERROR_INVALID_DEVICE);
11598
#endif
11699
};
117100
};
@@ -139,9 +122,8 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
139122
__SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please "
140123
"use joint_matrix_apply() instead.")
141124
#else
142-
__attribute__((unavailable(
143-
"get_wi_data can't be used on intel device, please use "
144-
"sycl::ext::intel::experimental::matrix::get_wi_data instead!")))
125+
__attribute__((unavailable("get_wi_data() has been removed from the API and "
126+
"replaced with joint_matrix_apply!")))
145127
#endif
146128
#endif
147129
inline __SYCL_ALWAYS_INLINE decltype(auto)
@@ -177,7 +159,7 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
177159
using storage_element_type =
178160
typename oneapi::detail::jm_type_interpretation_helper_trait<
179161
T>::storage_element_type;
180-
auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm);
162+
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
181163
for (int i = 0; i < wi_data_c.length(); i++) {
182164
storage_element_type element = wi_data_c[i];
183165
lambda(element);
@@ -260,7 +242,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
260242
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
261243
spv_scope_traits<Group>::value);
262244
break;
263-
case sycl::ext::intel::experimental::matrix::layout::packed:
245+
case layout::ext_intel_packed:
264246
res.spvm = __spirv_JointMatrixLoadINTEL<
265247
DecorT, S, NumRows, NumCols,
266248
spv_matrix_use_traits<use::accumulator>::value,
@@ -322,8 +304,9 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols,
322304
access::address_space Space, access::decorated IsDecorated>
323305
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
324306
Group,
325-
joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
326-
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
307+
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
308+
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
309+
&src,
327310
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
328311
sycl::ext::oneapi::experimental::matrix::layout Layout) {
329312
#if defined(__SYCL_DEVICE_ONLY__)
@@ -355,7 +338,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
355338
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
356339
spv_scope_traits<Group>::value);
357340
break;
358-
case sycl::ext::intel::experimental::matrix::layout::packed:
341+
case layout::ext_intel_packed:
359342
__spirv_JointMatrixStoreINTEL<
360343
DecorT, T, NumRows, NumCols,
361344
spv_matrix_use_traits<use::accumulator>::value,
@@ -375,51 +358,77 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
375358
#endif // defined(__SYCL_DEVICE_ONLY__)
376359
}
377360

378-
template <typename Group, typename Ta, typename Tb, typename Tc, std::size_t M,
379-
std::size_t K, std::size_t N, layout LayoutA, layout LayoutB>
380-
inline __SYCL_ALWAYS_INLINE
381-
joint_matrix<Group, Tc, use::accumulator, M, N,
382-
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
383-
joint_matrix_mad(
384-
Group, joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
385-
joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
386-
joint_matrix<Group, Tc, use::accumulator, M, N,
387-
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
388-
&C) {
361+
template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
362+
std::size_t M, std::size_t K, std::size_t N, layout LayoutA,
363+
layout LayoutB>
364+
inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
365+
Group,
366+
joint_matrix<Group, Td, use::accumulator, M, N,
367+
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
368+
const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
369+
const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
370+
const joint_matrix<Group, Tc, use::accumulator, M, N,
371+
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
372+
&C) {
389373
#if defined(__SYCL_DEVICE_ONLY__)
390374
#if defined(__NVPTX__)
391375
if constexpr (std::is_same<Ta, Tb>::value) {
392-
joint_matrix<Group, Tc, use::accumulator, M, N,
393-
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
394-
D;
395376
sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, M, K, N, LayoutA,
396377
LayoutB>(
397378
D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl);
398-
return D;
399379
} else {
400380
assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
401381
"requires that joint_matrix data types Ta and Tb match");
402382
}
403383
#else
404-
joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> res;
405384
if constexpr (std::is_same<Ta, uint16_t>::value &&
406385
std::is_same<Tb, uint16_t>::value &&
407386
std::is_same<Tc, float>::value)
408-
res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
387+
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
409388
else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
410-
res.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
389+
D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
411390
else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
412-
res.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
391+
D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
413392
else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
414-
res.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
393+
D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
415394
else
416-
res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
417-
return res;
395+
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
418396
#endif // defined(__NVPTX__)
419397
#else
420398
std::ignore = A;
421399
std::ignore = B;
422400
std::ignore = C;
401+
std::ignore = D;
402+
throw runtime_error("joint matrix is not supported on host device.",
403+
PI_ERROR_INVALID_DEVICE);
404+
#endif // defined(__SYCL_DEVICE_ONLY__)
405+
}
406+
407+
template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
408+
use Use1, use Use2, layout Layout1, layout Layout2>
409+
void joint_matrix_copy(
410+
Group sg, joint_matrix<Group, T1, Use1, Rows, Cols, Layout1> &src,
411+
joint_matrix<Group, T2, Use2, Rows, Cols, Layout2> &dst) {
412+
#if defined(__SYCL_DEVICE_ONLY__)
413+
#if defined(__NVPTX__)
414+
std::ignore = sg;
415+
for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) {
416+
dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i];
417+
}
418+
#else
419+
using storage_element_type =
420+
typename oneapi::detail::jm_type_interpretation_helper_trait<
421+
T2>::storage_element_type;
422+
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
423+
auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
424+
for (int i = 0; i < wi_data_c.length(); i++) {
425+
wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
426+
}
427+
#endif // defined(__NVPTX__)
428+
#else
429+
std::ignore = sg;
430+
std::ignore = dst;
431+
std::ignore = src;
423432
throw runtime_error("joint matrix is not supported on host device.",
424433
PI_ERROR_INVALID_DEVICE);
425434
#endif // defined(__SYCL_DEVICE_ONLY__)

sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)