Skip to content

[SYCL][Joint Matrix] Fix error in TF32 test, refactoring #11576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions sycl/test-e2e/Matrix/SG32/joint_matrix_tf32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

// XFAIL:cpu

#include <iostream>
#include <random>
#include <sycl/sycl.hpp>
#include "../common.hpp"

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
Expand Down
30 changes: 22 additions & 8 deletions sycl/test-e2e/Matrix/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,15 @@ void matrix_multiply_ref(Ta *A, Ta *B, Tc *C, int M, int N, int K,
for (unsigned int n = 0; n < N; n++) {
for (unsigned int k = 0; k < K; k++) {
int c_ind = transpose_c ? (n * M + m) : m * N + n;
if (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)
if constexpr (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)
C[c_ind] += make_fp32(A[m * K + k]) * make_fp32(B[k * N + n]);
if (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>)
else if constexpr (std::is_same_v<Ta, float> &&
std::is_same_v<Tc, float> ||
std::is_same_v<Ta, int8_t> &&
std::is_same_v<Tc, int32_t>)
C[c_ind] += A[m * K + k] * B[k * N + n];
else
assert(false && "Unsupported type in matrix_multiply_ref.");
}
}
}
Expand Down Expand Up @@ -63,6 +68,15 @@ void matrix_fill(unsigned int rows, unsigned int cols, T *src, T val) {
}
}

/// Populates a row-major rows x cols matrix by evaluating a generator.
///
/// @param rows number of rows to visit
/// @param cols number of columns per row; also used as the row stride
/// @param src  destination buffer, written at src[row * cols + col]
/// @param op   callable invoked as op(row, col) with unsigned int indices;
///             its result is converted to T before being stored
template <typename T, typename F>
void matrix_fill(unsigned int rows, unsigned int cols, T *src, F op) {
  for (unsigned int row = 0; row != rows; ++row) {
    // Offset of the first element of this row in the flat buffer.
    const unsigned int row_start = row * cols;
    for (unsigned int col = 0; col != cols; ++col) {
      src[row_start + col] = T(op(row, col));
    }
  }
}

template <typename T>
void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) {
std::random_device dev;
Expand All @@ -89,17 +103,17 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
for (int j = 0; j < cols; j++) {
if constexpr (std::is_same_v<T1, float> || std::is_same_v<T1, bfloat16>) {
float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]);
if (diff > BF16_EPSILON) {
if (std::is_same_v<T1, float> && diff > FLOAT_EPSILON ||
std::is_same_v<T1, bfloat16> && diff > BF16_EPSILON) {
std::cout << "Incorrect result in matrix. Ref: "
<< (T1)ref[i * cols + j] << ", Val:" << src[i * cols + j]
<< ", Diff: " << diff << ", Epsilon: " << BF16_EPSILON
<< "\n";
<< (T1)ref[i * cols + j] << ", Val: " << src[i * cols + j]
<< ", Diff: " << diff << "\n";
return false;
}
} else if (std::is_same_v<T1, int32_t>) {
} else if constexpr (std::is_same_v<T1, int32_t>) {
if (src[i * cols + j] != ref[i * cols + j]) {
std::cout << "Incorrect result in matrix. Ref: " << ref[i * cols + j]
<< ", Val:" << src[i * cols + j] << "\n";
<< ", Val: " << src[i * cols + j] << "\n";
return false;
}
} else {
Expand Down
4 changes: 1 addition & 3 deletions sycl/test-e2e/Matrix/joint_matrix_tf32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

// XFAIL:cpu

#include <iostream>
#include <random>
#include <sycl/sycl.hpp>
#include "common.hpp"

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
Expand Down
78 changes: 19 additions & 59 deletions sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
#define TM 8
#define TK 8

/// Non-owning handle around a matrix buffer. The row and column counts are
/// carried only as template parameters; nothing in the struct reads them.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  /// Wraps an existing buffer; only the pointer is stored.
  big_matrix(T *data) : mat(data) {}

  /// Returns the currently wrapped pointer.
  T *get_data() { return mat; }

  /// Points this handle at a different buffer.
  void set_data(T *data) { mat = data; }

  /// The wrapped pointer (publicly accessible).
  T *mat;
};
constexpr size_t TM = 8;
constexpr size_t TK = 8;

template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
Expand Down Expand Up @@ -60,7 +50,6 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
accC.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
N, layout::row_major);
joint_matrix_fill(sg, sub_a, 42);
for (int k = 0; k < K; k += TK) {
joint_matrix_load(
sg, sub_a,
Expand All @@ -75,13 +64,12 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
// Unless a round-to-tf32 function is called first, joint_matrix_mad
// will operate on truncated floats.
joint_matrix_apply(sg, sub_a,
[=](float x) { x = round_to_tf32(x); });
[=](float &x) { x = round_to_tf32(x); });
joint_matrix_apply(sg, sub_b,
[=](float x) { x = round_to_tf32(x); });
[=](float &x) { x = round_to_tf32(x); });
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
}

joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; });
joint_matrix_store(
sg, sub_c,
accC.template get_multi_ptr<access::decorated::no>() +
Expand All @@ -91,43 +79,21 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
}).wait();
}

// Full matrix dimensions: each one is twice the matching TM / TN / TK constant.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
// Statically sized test matrices: A is M x K, B is K x N, and C and D are
// both M x N. main() compares C against D to decide pass/fail.
float A[MATRIX_M][MATRIX_K];
float B[MATRIX_K][MATRIX_N];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

/// Host-side reference multiply on row-major float buffers.
///
/// Accumulates A_mem (M x K) times B_mem (K x N) into C_mem (M x N); the
/// existing contents of C_mem are added to, not overwritten.
void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N,
                         int K) {
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      const int c_idx = row * N + col;
      for (int depth = 0; depth < K; ++depth) {
        // Add each product directly into C, walking k in increasing order.
        const float lhs = A_mem[row * K + depth];
        const float rhs = B_mem[depth * N + col];
        C_mem[c_idx] += lhs * rhs;
      }
    }
  }
}

int main() {
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_K; j++) {
A[i][j] = 1.0f * (i + j);
}
}
for (int i = 0; i < MATRIX_K; i++) {
for (int j = 0; j < MATRIX_N; j++) {
B[i][j] = 2.0f * i + 3.0f * j;
}
}
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_N; j++) {
C[i][j] = 1.0;
D[i][j] = 1.0;
}
}
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
float A[MATRIX_M][MATRIX_K];
float B[MATRIX_K][MATRIX_N];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

matrix_fill(MATRIX_M, MATRIX_K, (float *)A,
[](int i, int j) { return 1.0f * (i + j); });
matrix_fill(MATRIX_K, MATRIX_N, (float *)B,
[](int i, int j) { return 2.0f * i + 3.0f * j; });
matrix_fill(MATRIX_M, MATRIX_N, (float *)C, 1.0f);
matrix_fill(MATRIX_M, MATRIX_N, (float *)D, 1.0f);

big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
Expand All @@ -137,13 +103,7 @@ int main() {
matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N,
MATRIX_K);

bool res = true;
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_N; j++) {
if (C[i][j] != D[i][j])
res = false;
}
}
bool res = matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D);
std::cout << (res ? "passed" : "failed") << std::endl;
return !res;
}