Commit 68f6d92

[SYCL][Doc][Joint matrix] Joint matrix tf32 type spec restructuring (#12276)
1 parent f62def0 commit 68f6d92

File tree

1 file changed: +84 −84 lines changed

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc

Lines changed: 84 additions & 84 deletions
@@ -252,11 +252,12 @@ layout.
 ```c++
 namespace sycl::ext::oneapi::experimental::matrix {
 
-template <typename Group, typename T, size_t Rows, size_t Cols,
+// T1 must be the same as T2
+template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
           access::address_space Space, access::decorated IsDecorated>
 void joint_matrix_store(Group g,
-    const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
-    multi_ptr<T, Space, IsDecorated> dest, size_t stride, layout Layout);
+    const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
+    multi_ptr<T2, Space, IsDecorated> dest, size_t stride, layout Layout);
 
 } // namespace sycl::ext::oneapi::experimental::matrix
 ```
@@ -371,15 +372,17 @@ joint_matrix_apply(sg, C, [=](T &x) {
 });
 ```
 
-=== Support for the TF32 Data Type
-Some devices support the TF32 floating point type for matrix
-elements. This type has a 19 bit format with one sign bit, 8 exponent
-bits (offering the same range as float), and 10 mantissa bits
-(offering the same precision as sycl::half). Use of this type can
-accelerate the joint_matrix_mad operation by reducing its
-precision. In order to declare a `joint_matrix` object with this
-element type, use `matrix::precision::tf32` in place of the `T`
-template parameter.
+=== Support for Machine Learning Types
+Some devices support special matrix element types that are commonly
+used in machine learning algorithms.
+These types are unusual because the type of the matrix element is
+different from the way the data is stored in memory. As a result, each
+of these elements has two types. There is an abstract identifier for
+the element type, which is an incomplete type defined in the
+`sycl::ext::oneapi::experimental::matrix::precision` namespace, and
+there is a corresponding storage format type. The following synopsis
+lists the abstract types and the table shows the associated storage
+format type.
 
 ```c++
 namespace sycl::ext::oneapi::experimental::matrix::precision {
@@ -389,94 +392,85 @@ class tf32;
 } // namespace sycl::ext::oneapi::experimental::matrix::precision
 ```
 
-For example:
+[frame="none",options="header",cols="20%,20%,60%"]
+|======================
+| `joint_matrix` element type | Storage type | Description
+|precision::tf32 | float | The TF32 type has a 19 bit format with one
+sign bit, 8 exponent bits (offering the same range as float), and 10
+mantissa bits (offering the same precision as sycl::half).
+|======================
+
+In order to declare a `joint_matrix` with one of these element types,
+use the abstract type like so:
 
 ```c++
 joint_matrix<sub_group, precision::tf32, use::a, tM, tK,
              layout::row_major> tA;
 ```
 
-Whenever the application loads, stores, fills, or accesses the
-elements of a TF32 matrix, the application sees the elements as
-float. There are special overloads of these functions for TF32 for
-this purpose.
+Operations on these matrices use the functions described above, but
+there are different constraints on the template parameters as
+described below.
 
-==== TF32 load
-These overloads of `joint_matrix_load` load float values into a TF32
-matrix. It is unspecified whether the implementation loads all 32 bits
-into the joint matrix or if it only loads the relevant 19 bits.
+==== load
+The template parameter `T2` must either be the storage format type
+that corresponds to the abstract type `T1` or it must be a
+const-qualified version of that storage format type. For example:
 
 ```c++
-namespace sycl::ext::oneapi::experimental::matrix {
+joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
 
-template <typename Group, size_t Rows, size_t Cols,
-          access::address_space Space, access::decorated IsDecorated>
-void joint_matrix_load(Group g,
-    joint_matrix<Group, precision::tf32, use::accumulator, Rows, Cols,
-                 layout::dynamic> &res,
-    multi_ptr<const float, Space, IsDecorated> src, size_t stride, layout Layout);
+float *buf = malloc_shared<float>(M*K, q);
+auto pBuf = address_space_cast<sycl::access::address_space::global_space,
+                               sycl::access::decorated::no>(buf);
 
-template <typename Group, size_t Rows, size_t Cols,
-          access::address_space Space, access::decorated IsDecorated>
-void joint_matrix_load(Group g,
-    joint_matrix<Group, precision::tf32, use::accumulator, Rows, Cols,
-                 layout::dynamic> &res,
-    multi_ptr<float, Space, IsDecorated> src, size_t stride, layout Layout);
+joint_matrix_load(sg, tA, pBuf + Offset, Stride);
+```
 
-// Only available when Layout != layout::dynamic
-template <typename Group, size_t Rows, size_t Cols,
-          use Use, layout Layout,
-          access::address_space Space, access::decorated IsDecorated>
-void joint_matrix_load(Group g,
-    joint_matrix<Group, precision::tf32, Use, Rows, Cols, Layout> &res,
-    multi_ptr<const float, Space, IsDecorated> src, size_t stride);
+==== store
+The template parameter `T2` must be the storage format type that
+corresponds to the abstract type `T1`. For example:
 
-// Only available when Layout != layout::dynamic
-template <typename Group, size_t Rows, size_t Cols,
-          use Use, layout Layout,
-          access::address_space Space, access::decorated IsDecorated>
-void joint_matrix_load(Group g,
-    joint_matrix<Group, precision::tf32, Use, Rows, Cols, Layout> &res,
-    multi_ptr<float, Space, IsDecorated> src, size_t stride);
+```c++
+joint_matrix<sub_group, precision::tf32, use::accumulator, tM, tK> tC;
 
-} // namespace sycl::ext::oneapi::experimental::matrix
+float *buf = malloc_shared<float>(M*K, q);
+auto pBuf = address_space_cast<sycl::access::address_space::global_space,
+                               sycl::access::decorated::no>(buf);
+
+joint_matrix_store(sg, tC, pBuf + Offset, Stride, layout::row_major);
 ```
 
-==== TF32 store
-This overload of joint_matrix_store stores float values from a TF32
-matrix.
+==== fill
+The template parameter `Tv` must be implicitly convertible to the
+storage format type that corresponds to the abstract type `T`. For example:
 
 ```c++
-namespace sycl::ext::oneapi::experimental::matrix {
-
-template <typename Group, size_t Rows, size_t Cols,
-          access::address_space Space, access::decorated IsDecorated>
-void joint_matrix_store(Group g,
-    const joint_matrix<Group, precision::tf32, use::accumulator, Rows,
-                       Cols, layout::dynamic> &res,
-    multi_ptr<float, Space, IsDecorated> dest, size_t stride, layout Layout);
-
-} // namespace sycl::ext::oneapi::experimental::matrix
+joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
+float v = 42.0;
+joint_matrix_fill(sg, tA, v);
 ```
 
-==== TF32 fill
-When `joint_matrix_fill` is called for a TF32 matrix, the type `Tv`
-(the type of the fill value) must be implicitly convertible to
-`float`. It is unspecified whether the implementation writes all 32
-bits of the value into the joint matrix or if it only writes the
-relevant 19 bits.
+==== copy
+There is no special constraint for the `joint_matrix_copy`
+function. The template parameters `T1` and `T2` correspond to the
+element types of the `src` and `dest` matrices.
 
-==== TF32 element-wise operations
-When `joint_matrix_apply` is called for a TF32 matrix, the Callable
-object func is called with a single argument of type `float &`. When the
-application changes this value, it is unspecified whether the
-implementation writes back all 32 bits of the element into the joint
-matrix or if it only write the relevant 19 bits.
+```c++
+joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
+joint_matrix<sub_group, float, use::accumulator, tM, tK> tC;
+joint_matrix_copy(sg, tC, tA);
+```
 
-In the example below, `C` is a joint matrix of type `precision::tf32`.
+==== Element-wise operations
+The Callable function type `F` must be invocable with a single argument
+whose type is a reference to the storage format type that corresponds
+to the abstract type `T`. For example, in the case where `C` is a
+joint matrix of type `precision::tf32`:
 
 ```c++
-joint_matrix_apply(sg, C, [=](float &x) {
+joint_matrix<sub_group, precision::tf32, use::accumulator, tM, tK> tC;
+joint_matrix_apply(sg, tC, [=](float &x) {
   x *= alpha;
 });
 ```
@@ -887,7 +881,8 @@ is shown in a single column in the table below.
 This is currently available in devices with the architecture
 `architecture::intel_gpu_pvc`, `architecture::intel_gpu_dg2_g10`,
 `architecture::intel_gpu_dg2_g11`, and
-`architecture::intel_gpu_dg2_g12`. In these architectures'
+`architecture::intel_gpu_dg2_g12`.
+In these architectures'
 implementation, the type of the C matrix must be the same as the type
 of the D matrix. Therefore, that common type is shown in a single
 column in the table below.
@@ -897,27 +892,32 @@ column in the table below.
 | A type | B type | C and D type | M | N | K | device
 .2+| `matrix_type::uint8` .2+| `matrix_type::uint8` .2+|
 `matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32
-|`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
+|`architecture::intel_gpu_pvc`
+|8|`architecture::intel_gpu_dg2_g10,
 architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
 .2+| `matrix_type::uint8` .2+| `matrix_type::sint8` .2+|
 `matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
-`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
+`architecture::intel_gpu_pvc`
+|8|`architecture::intel_gpu_dg2_g10,
 architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
 .2+| `matrix_type::sint8` .2+| `matrix_type::uint8` .2+|
 `matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
-`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
+`architecture::intel_gpu_pvc`
+|8|`architecture::intel_gpu_dg2_g10,
 architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
 .2+| `matrix_type::sint8` .2+| `matrix_type::sint8` .2+|
 `matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
-`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
+`architecture::intel_gpu_pvc`
+|8|`architecture::intel_gpu_dg2_g10,
 architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
 .2+|`matrix_type::fp16` .2+| `matrix_type::fp16` .2+|
 `matrix_type::fp32` .2+| +<=+ 8 | 16 .2+| 16 |
-`architecture::intel_gpu_pvc`|8| `architecture::intel_gpu_dg2_g10,
+`architecture::intel_gpu_pvc`
+|8| `architecture::intel_gpu_dg2_g10,
 architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
 .4+| `matrix_type::bf16` .4+| `matrix_type::bf16` .4+|
-`matrix_type::fp32` | 16 | 16 | 16 .3+|`architecture::intel_gpu_pvc` |
-32 | 64 | 16
+`matrix_type::fp32` | 16 | 16 | 16 .3+|`architecture::intel_gpu_pvc`
+|32 | 64 | 16
 .2+| +<=+ 8 | 16 .2+| 16
 |8 | `architecture::intel_gpu_dg2_g10,
 architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
