-
Notifications
You must be signed in to change notification settings - Fork 788
[SYCL][Matrix] Add support for tf32 type #5920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
9b6ed22
[SYCL][Matrix] Add support for tf32 type
dkhaldi 016523f
[SYCL][Matrix] Add a comment about conversion function
dkhaldi 108f04a
[SYCL][Matrix] minor formatting
dkhaldi 98480ab
tf32 cannot be constructed, change load,store, and slicing signatures
dkhaldi 4d7c661
Change the signatures of extract and insert dynamic to return storage…
dkhaldi e349b71
make it illegal to construct tf32 class type
dkhaldi a910a25
formatting
dkhaldi 682dff9
update branch
dkhaldi 41d07cd
Merge remote-tracking branch 'intel_llvm/sycl' into tf32
dkhaldi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,175 @@ | ||||||
// RUN: %clangxx -fsycl -O2 %s -o %t.out | ||||||
|
||||||
#include <sycl/sycl.hpp>
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
#include <cstdint>
#include <cstring>
#include <iostream>
|
||||||
using namespace sycl; | ||||||
using namespace sycl::ext::oneapi::experimental::matrix; | ||||||
|
||||||
auto constexpr SG_SZ = 8; | ||||||
|
||||||
#define TM 8 | ||||||
#define TN SG_SZ | ||||||
#define TK 16 | ||||||
|
||||||
// Thin non-owning view of a row-major NUM_ROWS x NUM_COLS matrix stored in
// caller-provided memory. The caller retains ownership of the storage.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat; // non-owning pointer to the first element

  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  // explicit: prevent accidental implicit conversion from a raw pointer.
  explicit big_matrix(T *data) : mat(data) {}
};
|
||||||
// Round a float to tf32 precision: round-to-nearest-up on the 13th-from-last
// mantissa bit, then zero the bottom 13 bits so only the 10-bit tf32 mantissa
// survives. TODO: replace with dedicated DPC++ / SPIR-V conversion functions
// once they are available.
float round_to_tf32(float a) {
  // std::memcpy instead of reinterpret_cast<uint32_t &> — casting a float's
  // storage through an unrelated reference type is a strict-aliasing
  // violation (UB); memcpy is the well-defined way to type-pun.
  uint32_t tmp_uint;
  std::memcpy(&tmp_uint, &a, sizeof(tmp_uint));
  tmp_uint += 0x1000u;     // Round up on the 13th last bit
  tmp_uint &= 0xFFFFE000u; // Zero out the bottom 13 bits
  float ret;
  std::memcpy(&ret, &tmp_uint, sizeof(ret));
  return ret;
}
|
||||||
// Device-side tiled multiply using the SYCL oneAPI joint_matrix extension:
// C += A * B, where A and B tiles are handled as tf32 and C as float.
// A/B/C wrap row-major host arrays; B's data is expected already packed
// (packed_b / VNNI layout). Dimensions must be multiples of TM/TN/TK.
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;

  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B);
  // One subgroup computes one TM x TN tile of C per ND-range point.
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<float, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<float, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The matrix API has to be accessed by all the workitems in a
           // subgroup; these functions will be called once by the subgroup —
           // no code divergence between the workitems.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           // Top-left corner of this subgroup's output tile in C.
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           sub_group sg = spmd_item.get_sub_group();
           joint_matrix<precision::tf32, TM, TK> sub_a(sg);
           joint_matrix<precision::tf32, TK, TN, matrix_layout::packed_b> sub_b(
               sg);
           joint_matrix<float, TM, TN> sub_c(sg);
           // Seed the accumulator tile with the current contents of C.
           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           // Walk the K dimension one TK-wide panel at a time.
           for (int k = 0; k < K; k += TK) {
             joint_matrix_load(sg, sub_a,
                               accA.get_pointer() + (sg_startx * TM) * K + k, K,
                               matrix_layout::row_major);
             // Assume we are already in VNNI (packed_b) format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k) * (N) +
                                   sg_starty / SG_SZ * TN,
                               N, matrix_layout::packed_b);
             // If no rounding to tf32 function is called, the mad function will
             // work on truncated floats.
             // TODO: change signature of __spirv_VectorInsertDynamic to have
             // two types: matrix type can be different from value type
             for (int i = 0; i < sub_a.get_wi_data().length(); i++) {
               sub_a.get_wi_data()[i] = round_to_tf32(sub_a.get_wi_data()[i]);
             }
             for (int i = 0; i < sub_b.get_wi_data().length(); i++) {
               sub_b.get_wi_data()[i] = round_to_tf32(sub_b.get_wi_data()[i]);
             }
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           // NOTE(review): this scales the per-work-item slice of sub_a after
           // the mad loop, and sub_a is never stored afterwards — presumably
           // it only exercises the wi_data slicing API. 'elem' is unused.
           auto wi_slice_a = sub_a.get_wi_data();
           for (int i = 0; i < wi_slice_a.length(); i++) {
             float elem = wi_slice_a[i];
             wi_slice_a[i] *= 2;
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   })
      .wait();
}
|
||||||
// Global test matrices: the device multiply writes its result into C, the
// scalar reference writes into D; main() compares the two element-wise.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
float A[MATRIX_M][MATRIX_K];
float B[MATRIX_K][MATRIX_N];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];
|
||||||
// Naive reference GEMM on row-major dense float matrices: C += A * B.
// A_mem is M x K, B_mem is K x N, C_mem is M x N; C_mem is updated in place
// starting from its existing contents.
// Fix: the previous version assigned `C_mem[m*N+n] = va * vb;` inside the
// k-loop, overwriting the accumulator each iteration (only the last product
// survived, and the loaded `acc` was never used). A reference multiply must
// accumulate across k.
void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N,
                         int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      float acc = C_mem[m * N + n]; // start from the existing C value
      for (int k = 0; k < K; k++) {
        acc += A_mem[m * K + k] * B_mem[k * N + n];
      }
      C_mem[m * N + n] = acc;
    }
}
|
||||||
// Host driver: initialize A/B/C/D, run the device joint_matrix multiply into
// C and the scalar reference into D, then compare the two element-wise.
int main() {
  // A[i][j] = i + j.
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      A[i][j] = 1.0f * (i + j);
    }
  }
  // B is written through a [MATRIX_K/2][MATRIX_N*2] index pattern. The total
  // element count equals B's declared [MATRIX_K][MATRIX_N] shape (rows are
  // indexed past MATRIX_N) — presumably to lay the data out in the packed
  // (VNNI) format the kernel's packed_b load expects. TODO(review): confirm.
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B[i][j] = 2.0f * i + 3.0f * j;
    }
  }
  // C (device result) and D (reference result) start from identical values.
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }

  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<float, MATRIX_M, MATRIX_K> MA((float *)&A);
  big_matrix<float, MATRIX_K, MATRIX_N> MB((float *)&B);
  matrix_multiply(MC, MA, MB);
  // NOTE(review): the reference is invoked with K = MATRIX_K / 2 —
  // presumably to match the packed interpretation of B above; verify this
  // is intentional.
  matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N,
                      MATRIX_K / 2);

  // Element-wise comparison of the device result against the reference.
  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
}
#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps it is a good opportunity to remove all these macros?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yubingex007-a11y, I remember you changed this code to use macros and made it more compact.
The code was before expanded for each of the ops. Bing changed it to remove the redundancy.
@keryell what do you suggest we should use instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sometimes macros are the best or only reasonable solution.
In that case, use reserved names like
__DPC_SYCL_OP
or similar, to avoid a clash when a user happens to define the same name in their own program. :-)