Skip to content

Commit d46d3d6

Browse files
authored
[SYCL][CUDA][DOC] Added Tensor Cores supported param combinations table to joint_matrix extension doc (#9019)
This PR documents the supported joint_matrix API parameters sets when using `ext_oneapi_cuda`, similar to the XMX, AMX tables added here: #7964 This will allow us to point people who would like to use `joint_matrix` on a specific architecture to the extension document. E.g. #8795 --------- Signed-off-by: JackAKirk <[email protected]>
1 parent c0ab9f8 commit d46d3d6

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_matrix/sycl_ext_oneapi_matrix.asciidoc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,59 @@ for (int i = 0; i < data.length; ++i) {
579579
}
580580
```
581581

582+
=== Appendix: Supported Parameter Combinations Per Hardware
583+
584+
The tables below provide a list of the parameter combinations that
585+
`joint_matrix` implementations support on each supported vendors hardware type.
586+
587+
==== Nvidia Tensor Cores Supported Combinations
588+
589+
The complete set of matrix data types and shapes that are supported by the `ext_oneapi_cuda` backend are represented in the following table. Tm indicates the matrix element data type held by a "multiplicand" `joint_matrix`: i.e requiring `use::a` or `use::b`. Tc indicates the matrix element data type held by an "accumulator" `joint_matrix`: i.e requiring `use::accumulator`.
590+
591+
IMPORTANT: When compiling for the `ext_oneapi_cuda` backend the target arch backend flag, `-Xsycl-target-backend --cuda-gpu-arch=sm_xx`, must be used, where `sm_xx` must be a Compute Capability that is equal to or greater than the appropriate Minimum Compute Capability. When an executable has been compiled for `sm_xx`, if the executable is run on a device with compute capability less than `sm_xx` then an error will be thrown. The mapping to Minimum Compute Capability from each supported parameter combination is specified in the following table.
592+
593+
--
594+
[.center]
595+
|======================
596+
|Tm (`use::a` or `use::b`) |Tc (`use::accumulator`) |M |N |K | Minimum Compute Capability
597+
.3+|half .3+|float
598+
|16 |16 |16 .6+| sm_70
599+
|8 |32 |16
600+
|32 |8 |16
601+
.3+|half .3+|half
602+
|16 |16 |16
603+
|8 |32 |16
604+
|32 |8 |16
605+
.3+|int8_t .3+|int32_t
606+
|16 |16 |16 .6+| sm_72
607+
|8 |32 |16
608+
|32 |8 |16
609+
.3+|uint8_t .3+|int32_t
610+
|16 |16 |16
611+
|8 |32 |16
612+
|32 |8 |16
613+
|precision::tf32 |float |16 |16 |8 .5+| sm_80
614+
.3+|bfloat16 .3+|float
615+
|16 |16 |16
616+
|8 |32 |16
617+
|32 |8 |16
618+
|double |double |8 |8 |4
619+
|======================
620+
--
621+
622+
The M, N, K triple from the above table defines the complete set of matrix shapes constructible:
623+
--
624+
[.center]
625+
|======================
626+
|use |NumRows | NumCols
627+
|a |M |K
628+
|b |K |N
629+
|accumulator | M| N
630+
|======================
631+
--
632+
633+
IMPORTANT: The `stride` argument to `joint_matrix_load` and `joint_matrix_store` must be a multiple of 8 when `T` is `half`, and a multiple of 4 when `T` is `float`; where `T` is the type of the `joint_matrix` elements. When `T` is not `half` or `float` there are no restrictions to `stride`.
634+
582635
## TODO List
583636
- Add WI data to joint matrix mapping coordinates information for piece-wise operations. This will be added as part of the query or new methods to the 'get_wi_data' class.
584637
- Add a more realistic and complete example that shows the value of the general query.

0 commit comments

Comments
 (0)