Skip to content

[SYCL][CUDA][DOC] Added Tensor Cores supported param combinations table to joint_matrix extension doc #9019

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,59 @@ for (int i = 0; i < data.length; ++i) {
}
```

=== Appendix: Supported Parameter Combinations Per Hardware

The tables below provide a list of the parameter combinations that
`joint_matrix` implementations support on each supported vendors hardware type.

==== Nvidia Tensor Cores Supported Combinations

The complete set of matrix data types and shapes that are supported by the `ext_oneapi_cuda` backend are represented in the following table. Tm indicates the matrix element data type held by a "multiplicand" `joint_matrix`: i.e requiring `use::a` or `use::b`. Tc indicates the matrix element data type held by an "accumulator" `joint_matrix`: i.e requiring `use::accumulator`.

IMPORTANT: When compiling for the `ext_oneapi_cuda` backend the target arch backend flag, `-Xsycl-target-backend --cuda-gpu-arch=sm_xx`, must be used, where `sm_xx` must be a Compute Capability that is equal to or greater than the appropriate Minimum Compute Capability. When an executable has been compiled for `sm_xx`, if the executable is run on a device with compute capability less than `sm_xx` then an error will be thrown. The mapping to Minimum Compute Capability from each supported parameter combination is specified in the following table.

--
[.center]
|======================
|Tm (`use::a` or `use::b`) |Tc (`use::accumulator`) |M |N |K | Minimum Compute Capability
.3+|half .3+|float
|16 |16 |16 .6+| sm_70
|8 |32 |16
|32 |8 |16
.3+|half .3+|half
|16 |16 |16
|8 |32 |16
|32 |8 |16
.3+|int8_t .3+|int32_t
|16 |16 |16 .6+| sm_72
|8 |32 |16
|32 |8 |16
.3+|uint8_t .3+|int32_t
|16 |16 |16
|8 |32 |16
|32 |8 |16
|precision::tf32 |float |16 |16 |8 .5+| sm_80
.3+|bfloat16 .3+|float
|16 |16 |16
|8 |32 |16
|32 |8 |16
|double |double |8 |8 |4
|======================
--

The M, N, K triple from the above table defines the complete set of matrix shapes constructible:
--
[.center]
|======================
|use |NumRows | NumCols
|a |M |K
|b |K |N
|accumulator | M| N
|======================
--

IMPORTANT: The `stride` argument to `joint_matrix_load` and `joint_matrix_store` must be a multiple of 8 when `T` is `half`, and a multiple of 4 when `T` is `float`; where `T` is the type of the `joint_matrix` elements. When `T` is not `half` or `float` there are no restrictions to `stride`.

## TODO List
- Add WI data to joint matrix mapping coordinates information for piece-wise operations. This will be added as part of the query or new methods to the 'get_wi_data' class.
- Add a more realistic and complete example that shows the value of the general query.
Expand Down