[SYCL][Matrix spec] Add joint_matrix_prefetch and overloads of load/store with annotated_ptr #11473

Merged
merged 16 commits into from
Feb 21, 2024
@@ -148,14 +148,28 @@ template <typename Group, typename T, size_t Rows, size_t Cols,
access::decorated IsDecorated>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
multi_ptr<T, Space, IsDecorated> src, size_t stride);
multi_ptr<T, Space, IsDecorated> dest, size_t stride);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, access::address_space Space,
access::decorated IsDecorated>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
multi_ptr<T, Space, IsDecorated> src, size_t stride);
multi_ptr<T, Space, IsDecorated> dest, size_t stride);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, typename PropertyListT>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dest,
size_t stride);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, typename PropertyListT>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dest,
size_t stride);

} // namespace sycl::ext::intel::experimental::matrix
```
@@ -327,6 +341,7 @@ q.submit([&](sycl::handler& cgh) {
});
q.wait();
```

== Revision History

[frame="none",options="header"]
@@ -228,6 +228,23 @@ void joint_matrix_load(Group g,
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
multi_ptr<T2, Space, IsDecorated> src, size_t stride);

// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
size_t Rows, size_t Cols,
typename PropertyListT>
void joint_matrix_load(Group g,
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
annotated_ptr<T2, PropertyListT> src, size_t stride, layout Layout);

// Only available when Layout != layout::dynamic
// and when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
size_t Rows, size_t Cols, use Use, layout Layout,
typename PropertyListT>
void joint_matrix_load(Group g,
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
annotated_ptr<T2, PropertyListT> src, size_t stride);

} // namespace sycl::ext::oneapi::experimental::matrix
```

@@ -248,6 +265,33 @@ fashion. `stride` describes the number of elements between consecutive
rows for the row major layout, or between columns for the column major
layout.
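
The meaning of `stride` described above can be pictured with a small address computation. The following is a minimal sketch in plain C++, not part of the extension API: the `layout` enum here is a hypothetical stand-in for the extension's enum, and `element_offset` is an illustrative helper name.

```c++
#include <cstddef>

// Hypothetical stand-in for the extension's `layout` enum (illustration only).
enum class layout { row_major, col_major };

// Offset, in elements, of matrix element (row, col) relative to the base
// pointer. `stride` is the distance between consecutive rows (row major)
// or between consecutive columns (column major).
constexpr std::size_t element_offset(std::size_t row, std::size_t col,
                                     std::size_t stride, layout l) {
  return l == layout::row_major ? row * stride + col : col * stride + row;
}

// Row major, 64 elements per row: stepping one row advances by the stride.
static_assert(element_offset(1, 0, 64, layout::row_major) == 64, "row step");
static_assert(element_offset(2, 5, 64, layout::row_major) == 133, "row major");
// Column major, 32 elements per column: stepping one column advances by the stride.
static_assert(element_offset(2, 5, 32, layout::col_major) == 162, "col major");
```

In this picture, passing `K` as the stride for a row-major `M x K` matrix means that consecutive rows of a tile are `K` elements apart in memory.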

The last two overloads of `joint_matrix_load` take
`sycl::ext::oneapi::experimental::annotated_ptr` as an argument instead
of `sycl::multi_ptr`. The property list associated with the
`annotated_ptr` argument represents the compile-time constant
properties for cache control included in the SYCL extension
link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls],
as illustrated in the example below.

```c++
namespace syclex = sycl::ext::oneapi::experimental;
namespace syclintelex = sycl::ext::intel::experimental;

auto A_ptr = syclex::annotated_ptr{A,
syclex::properties{syclintelex::read_hint<
syclintelex::cache_control<syclintelex::cache_mode::cached,
syclex::cache_level::L2>>}};
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
sub_group sg = it.get_sub_group();
joint_matrix<sub_group, bfloat16, use::a, tM, tK, layout::row_major> tA;
for (int k = 0; k < K; k += tK) {
// User specifies that this load will be cached to L2
joint_matrix_load(sg, tA, A_ptr + sg_startx * tM * K + k, K);
...
}
});
```
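
The pointer arithmetic `A_ptr + sg_startx * tM * K + k` in the example above selects the tile that starts at tile row `sg_startx` and column `k` of a row-major matrix with `K` elements per row. A minimal sketch of that arithmetic in plain C++ follows; `tile_offset` is a hypothetical helper and the sizes are made-up illustration values, not taken from the specification.

```c++
#include <cstddef>

// Offset, in elements, of the first element of a tile in a row-major matrix.
//   tile_row : index of the tile along the row dimension (sg_startx above)
//   tM       : number of rows per tile
//   K        : number of elements per matrix row (also passed as the stride)
//   k        : starting column of the tile
constexpr std::size_t tile_offset(std::size_t tile_row, std::size_t tM,
                                  std::size_t K, std::size_t k) {
  return tile_row * tM * K + k;
}

// Assumed sizes for illustration: K = 64, tM = 8, tiles 16 columns wide.
static_assert(tile_offset(0, 8, 64, 0) == 0, "first tile");
static_assert(tile_offset(0, 8, 64, 16) == 16, "next tile along k");
static_assert(tile_offset(2, 8, 64, 16) == 1040, "third tile row");
```

Each iteration of the `k` loop therefore moves the tile base along the row by the tile width, while the stride passed to `joint_matrix_load` stays `K`.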

==== Store
```c++
namespace sycl::ext::oneapi::experimental::matrix {
@@ -259,6 +303,12 @@ void joint_matrix_store(Group g,
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
multi_ptr<T2, Space, IsDecorated> dest, size_t stride, layout Layout);

template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
typename PropertyListT>
void joint_matrix_store(Group g,
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
annotated_ptr<T2, PropertyListT> dest, size_t stride, layout Layout);

} // namespace sycl::ext::oneapi::experimental::matrix
```
This function stores the data in the accumulator matrix from the
@@ -270,6 +320,11 @@ written in a row (`row_major`), column major (`col_major`)
fashion. `stride` describes the number of elements between consecutive
rows for the row major layout, or between columns for the column major layout.

The second overload of `joint_matrix_store` takes
`sycl::ext::oneapi::experimental::annotated_ptr` as an argument instead
of `sycl::multi_ptr`. The property list associated with the
`annotated_ptr` argument represents the compile-time constant
properties for cache control included in the SYCL extension
link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls].

==== Multiply and Add

@@ -372,6 +427,47 @@ joint_matrix_apply(sg, C, [=](T &x) {
});
```

==== Prefetch

```c++
namespace sycl::ext::oneapi::experimental::matrix {

template <size_t Rows, size_t Cols, typename Group, typename T,
typename Properties = empty_properties_t>
void joint_matrix_prefetch(Group g, T* ptr, size_t stride, layout Layout,
Properties properties = {});

} // namespace sycl::ext::oneapi::experimental::matrix
```

`joint_matrix_prefetch` allows groups of work-items to cooperatively
prefetch a 2D region of `Rows x Cols` elements. This function is a group
function, as defined in Section 4.17.3 of the core SYCL
specification.

The cache level targeted by `joint_matrix_prefetch` is specified
through its last argument, using the compile-time properties defined
in the SYCL extension
link:../../proposed/sycl_ext_oneapi_prefetch.asciidoc[sycl_ext_oneapi_prefetch],
as illustrated in the example below. When no cache level is
specified, the default behavior is to prefetch into the lowest level
cache (i.e. L1).

```c++
namespace syclex = sycl::ext::oneapi::experimental;

bfloat16 *memA = malloc_shared<bfloat16>(M*K, q);
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
sub_group sg = it.get_sub_group();
for (int k = 0; k < K; k += tK) {
syclex::joint_matrix_prefetch<tM, tK>(sg, memA + tM * K + tK, K,
layout::row_major,
syclex::properties{syclex::prefetch_hint_L2});
...
}
});
```
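
One way to picture the region covered by a `Rows x Cols` 2D prefetch is as a set of contiguous runs of elements separated by `stride`. The sketch below is plain C++ and makes no claim about how any device implements the prefetch: the `layout` enum is a hypothetical stand-in for the extension's enum, `run` and `tile_runs` are illustrative names, and the check that the stride is at least one run long is an assumption of this sketch rather than a rule stated in the specification.

```c++
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the extension's `layout` enum (illustration only).
enum class layout { row_major, col_major };

// A contiguous run of `count` elements starting `first` elements past the base.
struct run {
  std::size_t first;
  std::size_t count;
};

// Lists the contiguous runs that make up a rows x cols tile.
// Row major: one run per row, each `cols` long.
// Column major: one run per column, each `rows` long.
std::vector<run> tile_runs(std::size_t rows, std::size_t cols,
                           std::size_t stride, layout l) {
  const std::size_t n_runs = (l == layout::row_major) ? rows : cols;
  const std::size_t run_len = (l == layout::row_major) ? cols : rows;
  // Assumption of this sketch: runs do not overlap.
  assert(stride >= run_len);
  std::vector<run> runs;
  runs.reserve(n_runs);
  for (std::size_t i = 0; i < n_runs; ++i) {
    runs.push_back({i * stride, run_len});
  }
  return runs;
}
```

For example, `tile_runs(8, 16, 64, layout::row_major)` yields eight runs of 16 elements whose starting offsets are 64 elements apart.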

=== Support for Machine Learning Types
Some devices support special matrix element types that are commonly
used in machine learning algorithms.
@@ -1035,4 +1131,6 @@ and Intel XMX
|8 |2023-10-05 |Mahmoud Moadeli |Add AMD Matrix Core supported combinations
|9 |2023-11-13 |Dounia Khaldi |Add Granite Rapids Intel AMX
supported combinations
|10 |2023-12-04 |Dounia Khaldi |Add prefetch and `annotated_ptr`
load/store overloads
|======================