@@ -252,11 +252,12 @@ layout.
```c++
namespace sycl::ext::oneapi::experimental::matrix {

- template <typename Group, typename T, size_t Rows, size_t Cols,
+ // T1 must be the same as T2
+ template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
          access::address_space Space, access::decorated IsDecorated>
void joint_matrix_store(Group g,
-     const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
-     multi_ptr<T, Space, IsDecorated> dest, size_t stride, layout Layout);
+     const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
+     multi_ptr<T2, Space, IsDecorated> dest, size_t stride, layout Layout);

} // namespace sycl::ext::oneapi::experimental::matrix
```
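
+ For example, a minimal sketch of a store where `T1` and `T2` are both
+ `float` (the names `sg`, `q`, `tM`, `tN`, `M`, `N`, `Offset`, and
+ `Stride` are illustrative, mirroring the examples later in this
+ document):
+
+ ```c++
+ joint_matrix<sub_group, float, use::accumulator, tM, tN> tC;
+
+ float *buf = malloc_shared<float>(M*N, q);
+ auto pBuf = address_space_cast<sycl::access::address_space::global_space,
+                                sycl::access::decorated::no>(buf);
+
+ // T1 (the element type of tC) and T2 (the pointee type of pBuf) are
+ // both float, as required.
+ joint_matrix_store(sg, tC, pBuf + Offset, Stride, layout::row_major);
+ ```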
@@ -371,15 +372,17 @@ joint_matrix_apply(sg, C, [=](T &x) {
});
```

- === Support for the TF32 Data Type
- Some devices support the TF32 floating point type for matrix
- elements. This type has a 19 bit format with one sign bit, 8 exponent
- bits (offering the same range as float), and 10 mantissa bits
- (offering the same precision as sycl::half). Use of this type can
- accelerate the joint_matrix_mad operation by reducing its
- precision. In order to declare a `joint_matrix` object with this
- element type, use `matrix::precision::tf32` in place of the `T`
- template parameter.
+ === Support for Machine Learning Types
+ Some devices support special matrix element types that are commonly
+ used in machine learning algorithms.
+ These types are unusual because the type of a matrix element differs
+ from the format in which the data is stored in memory. As a result,
+ each of these elements has two types: an abstract identifier for the
+ element type, which is an incomplete type defined in the
+ `sycl::ext::oneapi::experimental::matrix::precision` namespace, and a
+ corresponding storage format type. The following synopsis lists the
+ abstract types, and the table shows the associated storage format
+ type.

```c++
namespace sycl::ext::oneapi::experimental::matrix::precision {
@@ -389,94 +392,85 @@ class tf32;
} // namespace sycl::ext::oneapi::experimental::matrix::precision
```

- For example:
+ [frame="none",options="header",cols="20%,20%,60%"]
+ |======================
+ | `joint_matrix` element type | Storage format type | Description
+ |`precision::tf32` |`float` | The TF32 type has a 19-bit format with one
+ sign bit, 8 exponent bits (offering the same range as float), and 10
+ mantissa bits (offering the same precision as sycl::half).
+ |======================
+
+ In order to declare a `joint_matrix` with one of these element types,
+ use the abstract type like so:

```c++
joint_matrix<sub_group, precision::tf32, use::a, tM, tK,
             layout::row_major> tA;
```

- Whenever the application loads, stores, fills, or accesses the
- elements of a TF32 matrix, the application sees the elements as
- float. There are special overloads of these functions for TF32 for
- this purpose.
+ Operations on these matrices use the functions described above, but
+ there are different constraints on the template parameters as
+ described below.

- ==== TF32 load
- These overloads of `joint_matrix_load` load float values into a TF32
- matrix. It is unspecified whether the implementation loads all 32 bits
- into the joint matrix or if it only loads the relevant 19 bits.
+ ==== load
+ The template parameter `T2` must either be the storage format type
+ that corresponds to the abstract type `T1` or it must be a
+ const-qualified version of that storage format type. For example:

```c++
- namespace sycl::ext::oneapi::experimental::matrix {
+ joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;

- template <typename Group, size_t Rows, size_t Cols,
-           access::address_space Space, access::decorated IsDecorated>
- void joint_matrix_load(Group g,
-     joint_matrix<Group, precision::tf32, use::accumulator, Rows, Cols,
-                  layout::dynamic> &res,
-     multi_ptr<const float, Space, IsDecorated> src, size_t stride, layout Layout);
+ float *buf = malloc_shared<float>(M*K, q);
+ auto pBuf = address_space_cast<sycl::access::address_space::global_space,
+                                sycl::access::decorated::no>(buf);

- template <typename Group, size_t Rows, size_t Cols,
-           access::address_space Space, access::decorated IsDecorated>
- void joint_matrix_load(Group g,
-     joint_matrix<Group, precision::tf32, use::accumulator, Rows, Cols,
-                  layout::dynamic> &res,
-     multi_ptr<float, Space, IsDecorated> src, size_t stride, layout Layout);
+ joint_matrix_load(sg, tA, pBuf + Offset, Stride);
+ ```
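
+ Since `T2` may also be const-qualified, the same matrix can be loaded
+ through a pointer to `const float`; a minimal sketch, reusing the
+ names from the example above:
+
+ ```c++
+ const float *cbuf = buf;
+ auto pCBuf = address_space_cast<sycl::access::address_space::global_space,
+                                 sycl::access::decorated::no>(cbuf);
+
+ // T2 is deduced as const float, the const-qualified storage format type.
+ joint_matrix_load(sg, tA, pCBuf + Offset, Stride);
+ ```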

- // Only available when Layout != layout::dynamic
- template <typename Group, size_t Rows, size_t Cols,
-           use Use, layout Layout,
-           access::address_space Space, access::decorated IsDecorated>
- void joint_matrix_load(Group g,
-     joint_matrix<Group, precision::tf32, Use, Rows, Cols, Layout> &res,
-     multi_ptr<const float, Space, IsDecorated> src, size_t stride);
+ ==== store
+ The template parameter `T2` must be the storage format type that
+ corresponds to the abstract type `T1`. For example:

- // Only available when Layout != layout::dynamic
- template <typename Group, size_t Rows, size_t Cols,
-           use Use, layout Layout,
-           access::address_space Space, access::decorated IsDecorated>
- void joint_matrix_load(Group g,
-     joint_matrix<Group, precision::tf32, Use, Rows, Cols, Layout> &res,
-     multi_ptr<float, Space, IsDecorated> src, size_t stride);
+ ```c++
+ joint_matrix<sub_group, precision::tf32, use::accumulator, tM, tK> tC;

- } // namespace sycl::ext::oneapi::experimental::matrix
+ float *buf = malloc_shared<float>(M*K, q);
+ auto pBuf = address_space_cast<sycl::access::address_space::global_space,
+                                sycl::access::decorated::no>(buf);
+
+ joint_matrix_store(sg, tC, pBuf + Offset, Stride, layout::row_major);
```
445
- ==== TF32 store
446
- This overload of joint_matrix_store stores float values from a TF32
447
- matrix.
444
+ ==== fill
445
+ The template parameter `Tv` must be implicitly convertible to the
446
+ storage format type that corresponds to the abstract type `T`. For example:
448
447
449
448
```c++
450
- namespace sycl::ext::oneapi::experimental::matrix {
451
-
452
- template <typename Group, size_t Rows, size_t Cols,
453
- access::address_space Space, access::decorated IsDecorated>
454
- void joint_matrix_store(Group g,
455
- const joint_matrix<Group, precision::tf32, use::accumulator, Rows,
456
- Cols, layout::dynamic> &res,
457
- multi_ptr<float, Space, IsDecorated> dest, size_t stride, layout Layout);
458
-
459
- } // namespace sycl::ext::oneapi::experimental::matrix
449
+ joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
450
+ float v = 42.0;
451
+ joint_matrix_fill(sg, tA, v);
460
452
```

- ==== TF32 fill
- When `joint_matrix_fill` is called for a TF32 matrix, the type `Tv`
- (the type of the fill value) must be implicitly convertible to
- `float`. It is unspecified whether the implementation writes all 32
- bits of the value into the joint matrix or if it only writes the
- relevant 19 bits.
+ ==== copy
+ There is no special constraint for the `joint_matrix_copy`
+ function. The template parameters `T1` and `T2` correspond to the
+ element types of the `src` and `dest` matrices.

- ==== TF32 element-wise operations
- When `joint_matrix_apply` is called for a TF32 matrix, the Callable
- object func is called with a single argument of type `float &`. When the
- application changes this value, it is unspecified whether the
- implementation writes back all 32 bits of the element into the joint
- matrix or if it only write the relevant 19 bits.
+ ```c++
+ joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
+ joint_matrix<sub_group, float, use::accumulator, tM, tK> tC;
+ joint_matrix_copy(sg, tC, tA);
+ ```

- In the example below, `C` is a joint matrix of type `precision::tf32`.
+ ==== Element-wise operations
+ The Callable function type `F` must be invocable with a single argument
+ whose type is a reference to the storage format type that corresponds
+ to the abstract type `T`. For example, in the case where `tC` is a
+ joint matrix of type `precision::tf32`:

```c++
- joint_matrix_apply(sg, C, [=](float &x) {
+ joint_matrix<sub_group, precision::tf32, use::accumulator, tM, tK> tC;
+ joint_matrix_apply(sg, tC, [=](float &x) {
  x *= alpha;
});
```
@@ -887,7 +881,8 @@ is shown in a single column in the table below.
This is currently available in devices with the architecture
`architecture::intel_gpu_pvc`, `architecture::intel_gpu_dg2_g10`,
`architecture::intel_gpu_dg2_g11`, and
- `architecture::intel_gpu_dg2_g12`. In these architectures'
+ `architecture::intel_gpu_dg2_g12`.
+ In these architectures'
implementation, the type of the C matrix must be the same as the type
of the D matrix. Therefore, that common type is shown in a single
column in the table below.
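
+ For example, a sketch of a multiply-add that follows the
+ `matrix_type::bf16` row of the table below (M=16, N=16, K=16 on
+ `architecture::intel_gpu_pvc`; the name `sg` and the layouts are
+ illustrative), where the C and D matrices share the element type
+ `float`:
+
+ ```c++
+ using bf16 = sycl::ext::oneapi::bfloat16;
+ joint_matrix<sub_group, bf16, use::a, 16, 16, layout::row_major> tA;
+ joint_matrix<sub_group, bf16, use::b, 16, 16, layout::row_major> tB;
+ joint_matrix<sub_group, float, use::accumulator, 16, 16> tC;
+ joint_matrix<sub_group, float, use::accumulator, 16, 16> tD;
+
+ // C and D have the same element type, as these architectures require.
+ joint_matrix_mad(sg, tD, tA, tB, tC);
+ ```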
@@ -897,27 +892,32 @@ column in the table below.
| A type | B type | C and D type | M | N | K | device
.2+| `matrix_type::uint8` .2+| `matrix_type::uint8` .2+|
`matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32
- |`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
+ |`architecture::intel_gpu_pvc`
+ |8|`architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.2+| `matrix_type::uint8` .2+| `matrix_type::sint8` .2+|
`matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
- `architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
+ `architecture::intel_gpu_pvc`
+ |8|`architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.2+| `matrix_type::sint8` .2+| `matrix_type::uint8` .2+|
`matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
- `architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
+ `architecture::intel_gpu_pvc`
+ |8|`architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.2+| `matrix_type::sint8` .2+| `matrix_type::sint8` .2+|
`matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 |
- `architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10,
+ `architecture::intel_gpu_pvc`
+ |8|`architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.2+|`matrix_type::fp16` .2+| `matrix_type::fp16` .2+|
`matrix_type::fp32` .2+| +<=+ 8 | 16 .2+| 16 |
- `architecture::intel_gpu_pvc`|8| `architecture::intel_gpu_dg2_g10,
+ `architecture::intel_gpu_pvc`
+ |8| `architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`
.4+| `matrix_type::bf16` .4+| `matrix_type::bf16` .4+|
- `matrix_type::fp32` | 16 | 16 | 16 .3+|`architecture::intel_gpu_pvc` |
- 32 | 64 | 16
+ `matrix_type::fp32` | 16 | 16 | 16 .3+|`architecture::intel_gpu_pvc`
+ | 32 | 64 | 16
.2+| +<=+ 8 | 16 .2+| 16
|8 | `architecture::intel_gpu_dg2_g10,
architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`