[SYCL][Doc][Joint matrix] Joint matrix tf32 type spec restructuring #12276

Merged · 5 commits · Jan 22, 2024
@@ -252,11 +252,12 @@ layout.
```c++
namespace sycl::ext::oneapi::experimental::matrix {

template <typename Group, typename T, size_t Rows, size_t Cols,
// T1 must be the same as T2
template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
multi_ptr<T, Space, IsDecorated> dest, size_t stride, layout Layout);
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
multi_ptr<T2, Space, IsDecorated> dest, size_t stride, layout Layout);

} // namespace sycl::ext::oneapi::experimental::matrix
```
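
For illustration, a call that satisfies the requirement that `T1` and `T2` be the same might look like the following sketch (the names `tM`, `tN`, `M`, `N`, `Offset`, `Stride`, `q`, and `sg` are assumed, following the conventions of the examples later in this document):

```c++
// Sketch: T1 (the accumulator element type) and T2 (the pointee type of
// the destination multi_ptr) are both float, satisfying the constraint.
joint_matrix<sub_group, float, use::accumulator, tM, tN> tC;

float *buf = malloc_shared<float>(M*N, q);
auto pBuf = address_space_cast<sycl::access::address_space::global_space,
                               sycl::access::decorated::no>(buf);

joint_matrix_store(sg, tC, pBuf + Offset, Stride, layout::row_major);
```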
@@ -371,15 +372,17 @@ joint_matrix_apply(sg, C, [=](T &x) {
});
```

=== Support for the TF32 Data Type
Some devices support the TF32 floating point type for matrix
elements. This type has a 19 bit format with one sign bit, 8 exponent
bits (offering the same range as float), and 10 mantissa bits
(offering the same precision as sycl::half). Use of this type can
accelerate the joint_matrix_mad operation by reducing its
precision. In order to declare a `joint_matrix` object with this
element type, use `matrix::precision::tf32` in place of the `T`
template parameter.
=== Support for Machine Learning Types
Some devices support special matrix element types that are commonly
used in machine learning algorithms.
These types are unusual because the type of the matrix element differs
from the type used to store the data in memory. As a result, each
of these elements has two types. There is an abstract identifier for
the element type, which is an incomplete type defined in the
`sycl::ext::oneapi::experimental::matrix::precision` namespace, and
there is a corresponding storage format type. The following synopsis
lists the abstract types and the table shows the associated storage
format type.

```c++
namespace sycl::ext::oneapi::experimental::matrix::precision {
@@ -389,94 +392,85 @@ class tf32;
} // namespace sycl::ext::oneapi::experimental::matrix::precision
```

For example:
[frame="none",options="header",cols="20%,20%,60%"]
|======================
| `joint_matrix` element type | Storage type | Description
|precision::tf32 | float | The TF32 type has a 19-bit format with one
sign bit, 8 exponent bits (offering the same range as float), and 10
mantissa bits (offering the same precision as sycl::half).
|======================

In order to declare a `joint_matrix` with one of these element types,
use the abstract type like so:

```c++
joint_matrix<sub_group, precision::tf32, use::a, tM, tK,
layout::row_major> tA;
```

Whenever the application loads, stores, fills, or accesses the
elements of a TF32 matrix, the application sees the elements as
float. There are special overloads of these functions for TF32 for
this purpose.
Operations on these matrices use the functions described above, but
there are different constraints on the template parameters as
described below.

==== TF32 load
These overloads of `joint_matrix_load` load float values into a TF32
matrix. It is unspecified whether the implementation loads all 32 bits
into the joint matrix or if it only loads the relevant 19 bits.
==== load
The template parameter `T2` must either be the storage format type
that corresponds to the abstract type `T1` or it must be a
const-qualified version of that storage format type. For example:

```c++
namespace sycl::ext::oneapi::experimental::matrix {
joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;

template <typename Group, size_t Rows, size_t Cols,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_load(Group g,
joint_matrix<Group, precision::tf32, use::accumulator, Rows, Cols,
layout::dynamic> &res,
multi_ptr<const float, Space, IsDecorated> src, size_t stride, layout Layout);
float *buf = malloc_shared<float>(M*K, q);
auto pBuf = address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(buf);

template <typename Group, size_t Rows, size_t Cols,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_load(Group g,
joint_matrix<Group, precision::tf32, use::accumulator, Rows, Cols,
layout::dynamic> &res,
multi_ptr<float, Space, IsDecorated> src, size_t stride, layout Layout);
joint_matrix_load(sg, tA, pBuf + Offset, Stride);
```
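
Because `T2` may also be const-qualified, loading through a pointer to
`const float` satisfies the constraint as well; a minimal sketch reusing
the declarations above:

```c++
// Sketch: T2 = const float is also accepted for a precision::tf32 matrix.
const float *cbuf = buf;
auto pCBuf = address_space_cast<sycl::access::address_space::global_space,
                                sycl::access::decorated::no>(cbuf);

joint_matrix_load(sg, tA, pCBuf + Offset, Stride);
```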

// Only available when Layout != layout::dynamic
template <typename Group, size_t Rows, size_t Cols,
use Use, layout Layout,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_load(Group g,
joint_matrix<Group, precision::tf32, Use, Rows, Cols, Layout> &res,
multi_ptr<const float, Space, IsDecorated> src, size_t stride);
==== store
The template parameter `T2` must be the storage format type that
corresponds to the abstract type `T1`. For example:

// Only available when Layout != layout::dynamic
template <typename Group, size_t Rows, size_t Cols,
use Use, layout Layout,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_load(Group g,
joint_matrix<Group, precision::tf32, Use, Rows, Cols, Layout> &res,
multi_ptr<float, Space, IsDecorated> src, size_t stride);
```c++
joint_matrix<sub_group, precision::tf32, use::accumulator, tM, tK> tC;

} // namespace sycl::ext::oneapi::experimental::matrix
float *buf = malloc_shared<float>(M*K, q);
auto pBuf = address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(buf);

joint_matrix_store(sg, tC, pBuf + Offset, Stride, layout::row_major);
```

==== TF32 store
This overload of joint_matrix_store stores float values from a TF32
matrix.
==== fill
The template parameter `Tv` must be implicitly convertible to the
storage format type that corresponds to the abstract type `T`. For example:

```c++
namespace sycl::ext::oneapi::experimental::matrix {

template <typename Group, size_t Rows, size_t Cols,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_store(Group g,
const joint_matrix<Group, precision::tf32, use::accumulator, Rows,
Cols, layout::dynamic> &res,
multi_ptr<float, Space, IsDecorated> dest, size_t stride, layout Layout);

} // namespace sycl::ext::oneapi::experimental::matrix
joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
float v = 42.0;
joint_matrix_fill(sg, tA, v);
```
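
Since the only requirement is implicit convertibility to the storage
format type, the fill value need not be a `float` variable; for example,
an integer constant converts implicitly as well (a minimal sketch):

```c++
// Sketch: int converts implicitly to float, the storage format type
// that corresponds to precision::tf32.
joint_matrix_fill(sg, tA, 1);
```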

==== TF32 fill
When `joint_matrix_fill` is called for a TF32 matrix, the type `Tv`
(the type of the fill value) must be implicitly convertible to
`float`. It is unspecified whether the implementation writes all 32
bits of the value into the joint matrix or if it only writes the
relevant 19 bits.
==== copy
There is no special constraint for the `joint_matrix_copy`
function. The template parameters `T1` and `T2` correspond to the
element types of the `src` and `dest` matrices.

==== TF32 element-wise operations
When `joint_matrix_apply` is called for a TF32 matrix, the Callable
object func is called with a single argument of type `float &`. When the
application changes this value, it is unspecified whether the
implementation writes back all 32 bits of the element into the joint
matrix or if it only write the relevant 19 bits.
```c++
joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
joint_matrix<sub_group, float, use::accumulator, tM, tK> tC;
joint_matrix_copy(sg, tC, tA);
```
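
The copy may also go in the opposite direction; in the following sketch
`T1` is the abstract type `precision::tf32` and `T2` is `float`, reusing
the declarations above:

```c++
// Sketch: copy the tf32 matrix back into the float accumulator matrix.
joint_matrix_copy(sg, tA, tC);
```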

In the example below, `C` is a joint matrix of type `precision::tf32`.
==== Element-wise operations
The Callable function type `F` must be invocable with a single argument
whose type is a reference to the storage format type that corresponds
to the abstract type `T`. For example, in the case where `C` is a
joint matrix of type `precision::tf32`:

```c++
joint_matrix_apply(sg, C, [=](float &x) {
joint_matrix<sub_group, precision::tf32, use::accumulator, tM, tK> tC;
joint_matrix_apply(sg, tC, [=](float &x) {
x *= alpha;
});
```
@@ -887,7 +881,8 @@ is shown in a single column in the table below.
This is currently available in devices with the architecture
`architecture::intel_gpu_pvc`, `architecture::intel_gpu_dg2_g10`,
`architecture::intel_gpu_dg2_g11`, and
`architecture::intel_gpu_dg2_g12`. In these architectures'
`architecture::intel_gpu_dg2_g12`.
In these architectures'
implementation, the type of the C matrix must be the same as the type
of the D matrix. Therefore, that common type is shown in a single
column in the table below.
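
For illustration, a `joint_matrix_mad` call that respects this
restriction on these devices might look like the following sketch,
assuming `bfloat16` inputs and the names `tM`, `tN`, `tK`, and `sg` from
the earlier examples:

```c++
// Sketch for bf16 inputs on these architectures: C and D share the type
// float, so the same accumulator matrix can serve as both C and D.
joint_matrix<sub_group, bfloat16, use::a, tM, tK, layout::row_major> tA;
joint_matrix<sub_group, bfloat16, use::b, tK, tN, layout::ext_intel_packed> tB;
joint_matrix<sub_group, float, use::accumulator, tM, tN> tC;

joint_matrix_mad(sg, tC, tA, tB, tC);  // D = A x B + C
```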
@@ -897,27 +892,32 @@ column in the table below.
| A type | B type | C and D type | M | N | K | device
.2+| `matrix_type::uint8` .2+| `matrix_type::uint8` .2+|
`matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32
|`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
|`architecture::intel_gpu_pvc`
|8|`architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.2+| `matrix_type::uint8` .2+| `matrix_type::sint8` .2+|
`matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
`architecture::intel_gpu_pvc`
|8|`architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.2+| `matrix_type::sint8` .2+| `matrix_type::uint8` .2+|
`matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
`architecture::intel_gpu_pvc`
|8|`architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.2+| `matrix_type::sint8` .2+| `matrix_type::sint8` .2+|
`matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
`architecture::intel_gpu_pvc`
|8|`architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.2+|`matrix_type::fp16` .2+| `matrix_type::fp16` .2+|
`matrix_type::fp32` .2+| +<=+ 8 | 16 .2+| 16 |
`architecture::intel_gpu_pvc`|8| `architecture::intel_gpu_dg2_g10,
`architecture::intel_gpu_pvc`
|8| `architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.4+| `matrix_type::bf16` .4+| `matrix_type::bf16` .4+|
`matrix_type::fp32` | 16 | 16 | 16 .3+|`architecture::intel_gpu_pvc` |
32 | 64 | 16
`matrix_type::fp32` | 16 | 16 | 16 .3+|`architecture::intel_gpu_pvc`
|32 | 64 | 16
.2+| +<=+ 8 | 16 .2+| 16
|8 | `architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`