Skip to content

Commit 36004a0

Browse files
committed
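Added an access::decorated template parameter (IsDecorated) to the CUDA matrix load/store helpers and forwarded it to their multi_ptr arguments.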
Signed-off-by: JackAKirk <[email protected]>
1 parent 8da0aa7 commit 36004a0

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -172,13 +172,13 @@ get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
172172

173173
template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename S,
174174
typename T, size_t NumRows, size_t NumCols,
175-
access::address_space Space>
175+
access::address_space Space, access::decorated IsDecorated>
176176
void load_accumulator_layoutT(
177177
sycl::ext::oneapi::experimental::matrix::joint_matrix<
178178
S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
179179
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
180180
sycl::sub_group> &res,
181-
multi_ptr<T, Space> src, size_t stride) {
181+
multi_ptr<T, Space, IsDecorated> src, size_t stride) {
182182
if constexpr (std::is_same_v<S, int32_t>) {
183183
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
184184
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -221,13 +221,13 @@ void load_accumulator_layoutT(
221221
};
222222

223223
template <typename S, typename T, size_t NumRows, size_t NumCols,
224-
access::address_space Space>
224+
access::address_space Space, access::decorated IsDecorated>
225225
void load_accumulator_cuda(
226226
sycl::ext::oneapi::experimental::matrix::joint_matrix<
227227
S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
228228
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
229229
sycl::sub_group> &res,
230-
multi_ptr<T, Space> src, size_t stride,
230+
multi_ptr<T, Space, IsDecorated> src, size_t stride,
231231
sycl::ext::oneapi::experimental::matrix::layout Layout) {
232232
switch (Layout) {
233233
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
@@ -249,7 +249,7 @@ template <
249249
typename S, typename T, size_t NumRows, size_t NumCols,
250250
sycl::ext::oneapi::experimental::matrix::use Use,
251251
sycl::ext::oneapi::experimental::matrix::layout Layout,
252-
access::address_space Space,
252+
access::address_space Space, access::decorated IsDecorated,
253253
std::enable_if_t<
254254
Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major ||
255255
Layout ==
@@ -258,7 +258,7 @@ template <
258258
void load_multiplicand_cuda(
259259
sycl::ext::oneapi::experimental::matrix::joint_matrix<
260260
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
261-
multi_ptr<T, Space> src, size_t stride) {
261+
multi_ptr<T, Space, IsDecorated> src, size_t stride) {
262262
if constexpr (std::is_same_v<S, sycl::ext::oneapi::experimental::bfloat16>) {
263263
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
264264
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
@@ -377,13 +377,14 @@ void load_multiplicand_cuda(
377377
}
378378

379379
template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename T,
380-
size_t NumRows, size_t NumCols, access::address_space Space>
380+
size_t NumRows, size_t NumCols, access::address_space Space,
381+
access::decorated IsDecorated>
381382
void store_layoutT(
382383
sycl::ext::oneapi::experimental::matrix::joint_matrix<
383384
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
384385
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
385386
sycl::sub_group> &src,
386-
multi_ptr<T, Space> dst, size_t stride) {
387+
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
387388
if constexpr (NumRows == 16 && NumCols == 16) {
388389
if constexpr (std::is_same_v<T, float>) {
389390
__hmma_m16n16k16_st_c_f32(dst.get(),
@@ -434,13 +435,13 @@ void store_layoutT(
434435
}
435436

436437
template <typename T, size_t NumRows, size_t NumCols,
437-
access::address_space Space>
438+
access::address_space Space, access::decorated IsDecorated>
438439
void joint_matrix_store_cuda(
439440
sycl::ext::oneapi::experimental::matrix::joint_matrix<
440441
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
441442
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
442443
sycl::sub_group> &src,
443-
multi_ptr<T, Space> dst, size_t stride,
444+
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
444445
sycl::ext::oneapi::experimental::matrix::layout Layout) {
445446
switch (Layout) {
446447
case sycl::ext::oneapi::experimental::matrix::layout::row_major:

0 commit comments

Comments
 (0)