@@ -172,13 +172,13 @@ get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
172
172
173
173
template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename S,
174
174
typename T, size_t NumRows, size_t NumCols,
175
- access::address_space Space>
175
+ access::address_space Space, access::decorated IsDecorated >
176
176
void load_accumulator_layoutT (
177
177
sycl::ext::oneapi::experimental::matrix::joint_matrix<
178
178
S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
179
179
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
180
180
sycl::sub_group> &res,
181
- multi_ptr<T, Space> src, size_t stride) {
181
+ multi_ptr<T, Space, IsDecorated > src, size_t stride) {
182
182
if constexpr (std::is_same_v<S, int32_t >) {
183
183
auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
184
184
if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -221,13 +221,13 @@ void load_accumulator_layoutT(
221
221
};
222
222
223
223
template <typename S, typename T, size_t NumRows, size_t NumCols,
224
- access::address_space Space>
224
+ access::address_space Space, access::decorated IsDecorated >
225
225
void load_accumulator_cuda (
226
226
sycl::ext::oneapi::experimental::matrix::joint_matrix<
227
227
S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
228
228
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
229
229
sycl::sub_group> &res,
230
- multi_ptr<T, Space> src, size_t stride,
230
+ multi_ptr<T, Space, IsDecorated > src, size_t stride,
231
231
sycl::ext::oneapi::experimental::matrix::layout Layout) {
232
232
switch (Layout) {
233
233
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
@@ -249,7 +249,7 @@ template <
249
249
typename S, typename T, size_t NumRows, size_t NumCols,
250
250
sycl::ext::oneapi::experimental::matrix::use Use,
251
251
sycl::ext::oneapi::experimental::matrix::layout Layout,
252
- access::address_space Space,
252
+ access::address_space Space, access::decorated IsDecorated,
253
253
std::enable_if_t <
254
254
Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major ||
255
255
Layout ==
@@ -258,7 +258,7 @@ template <
258
258
void load_multiplicand_cuda (
259
259
sycl::ext::oneapi::experimental::matrix::joint_matrix<
260
260
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
261
- multi_ptr<T, Space> src, size_t stride) {
261
+ multi_ptr<T, Space, IsDecorated > src, size_t stride) {
262
262
if constexpr (std::is_same_v<S, sycl::ext::oneapi::experimental::bfloat16>) {
263
263
auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
264
264
auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
@@ -377,13 +377,14 @@ void load_multiplicand_cuda(
377
377
}
378
378
379
379
template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename T,
380
- size_t NumRows, size_t NumCols, access::address_space Space>
380
+ size_t NumRows, size_t NumCols, access::address_space Space,
381
+ access::decorated IsDecorated>
381
382
void store_layoutT (
382
383
sycl::ext::oneapi::experimental::matrix::joint_matrix<
383
384
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
384
385
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
385
386
sycl::sub_group> &src,
386
- multi_ptr<T, Space> dst, size_t stride) {
387
+ multi_ptr<T, Space, IsDecorated > dst, size_t stride) {
387
388
if constexpr (NumRows == 16 && NumCols == 16 ) {
388
389
if constexpr (std::is_same_v<T, float >) {
389
390
__hmma_m16n16k16_st_c_f32 (dst.get (),
@@ -434,13 +435,13 @@ void store_layoutT(
434
435
}
435
436
436
437
template <typename T, size_t NumRows, size_t NumCols,
437
- access::address_space Space>
438
+ access::address_space Space, access::decorated IsDecorated >
438
439
void joint_matrix_store_cuda (
439
440
sycl::ext::oneapi::experimental::matrix::joint_matrix<
440
441
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
441
442
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
442
443
sycl::sub_group> &src,
443
- multi_ptr<T, Space> dst, size_t stride,
444
+ multi_ptr<T, Space, IsDecorated > dst, size_t stride,
444
445
sycl::ext::oneapi::experimental::matrix::layout Layout) {
445
446
switch (Layout) {
446
447
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
0 commit comments