Skip to content

Commit 3c213d2

Browse files
authored
[SYCL][Joint Matrix] Fix error in TF32 test, refactoring (#11576)
1 parent 6a98330 commit 3c213d2

File tree

4 files changed

+43
-73
lines changed

4 files changed

+43
-73
lines changed

sycl/test-e2e/Matrix/SG32/joint_matrix_tf32.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212

1313
// XFAIL:cpu
1414

15-
#include <iostream>
16-
#include <random>
17-
#include <sycl/sycl.hpp>
15+
#include "../common.hpp"
1816

1917
using namespace sycl;
2018
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/common.hpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,15 @@ void matrix_multiply_ref(Ta *A, Ta *B, Tc *C, int M, int N, int K,
3232
for (unsigned int n = 0; n < N; n++) {
3333
for (unsigned int k = 0; k < K; k++) {
3434
int c_ind = transpose_c ? (n * M + m) : m * N + n;
35-
if (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)
35+
if constexpr (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)
3636
C[c_ind] += make_fp32(A[m * K + k]) * make_fp32(B[k * N + n]);
37-
if (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>)
37+
else if constexpr (std::is_same_v<Ta, float> &&
38+
std::is_same_v<Tc, float> ||
39+
std::is_same_v<Ta, int8_t> &&
40+
std::is_same_v<Tc, int32_t>)
3841
C[c_ind] += A[m * K + k] * B[k * N + n];
42+
else
43+
assert(false && "Unsupported type in matrix_multiply_ref.");
3944
}
4045
}
4146
}
@@ -63,6 +68,15 @@ void matrix_fill(unsigned int rows, unsigned int cols, T *src, T val) {
6368
}
6469
}
6570

71+
template <typename T, typename F>
72+
void matrix_fill(unsigned int rows, unsigned int cols, T *src, F op) {
73+
for (unsigned int i = 0; i < rows; i++) {
74+
for (unsigned int j = 0; j < cols; j++) {
75+
src[i * cols + j] = T(op(i, j));
76+
}
77+
}
78+
}
79+
6680
template <typename T>
6781
void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) {
6882
std::random_device dev;
@@ -89,17 +103,17 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
89103
for (int j = 0; j < cols; j++) {
90104
if constexpr (std::is_same_v<T1, float> || std::is_same_v<T1, bfloat16>) {
91105
float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]);
92-
if (diff > BF16_EPSILON) {
106+
if (std::is_same_v<T1, float> && diff > FLOAT_EPSILON ||
107+
std::is_same_v<T1, bfloat16> && diff > BF16_EPSILON) {
93108
std::cout << "Incorrect result in matrix. Ref: "
94-
<< (T1)ref[i * cols + j] << ", Val:" << src[i * cols + j]
95-
<< ", Diff: " << diff << ", Epsilon: " << BF16_EPSILON
96-
<< "\n";
109+
<< (T1)ref[i * cols + j] << ", Val: " << src[i * cols + j]
110+
<< ", Diff: " << diff << "\n";
97111
return false;
98112
}
99-
} else if (std::is_same_v<T1, int32_t>) {
113+
} else if constexpr (std::is_same_v<T1, int32_t>) {
100114
if (src[i * cols + j] != ref[i * cols + j]) {
101115
std::cout << "Incorrect result in matrix. Ref: " << ref[i * cols + j]
102-
<< ", Val:" << src[i * cols + j] << "\n";
116+
<< ", Val: " << src[i * cols + j] << "\n";
103117
return false;
104118
}
105119
} else {

sycl/test-e2e/Matrix/joint_matrix_tf32.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212

1313
// XFAIL:cpu
1414

15-
#include <iostream>
16-
#include <random>
17-
#include <sycl/sycl.hpp>
15+
#include "common.hpp"
1816

1917
using namespace sycl;
2018
using namespace sycl::ext::oneapi::experimental::matrix;
sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp

Lines changed: 19 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,5 @@
1-
#define TM 8
2-
#define TK 8
3-
4-
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
5-
public:
6-
T *mat;
7-
8-
public:
9-
T *get_data() { return mat; }
10-
void set_data(T *data) { mat = data; }
11-
big_matrix(T *data) : mat(data) {}
12-
};
1+
constexpr size_t TM = 8;
2+
constexpr size_t TK = 8;
133

144
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
155
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
@@ -60,7 +50,6 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
6050
accC.template get_multi_ptr<access::decorated::no>() +
6151
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
6252
N, layout::row_major);
63-
joint_matrix_fill(sg, sub_a, 42);
6453
for (int k = 0; k < K; k += TK) {
6554
joint_matrix_load(
6655
sg, sub_a,
@@ -75,13 +64,12 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
7564
// If no rounding to tf32 function is called, joint_matrix_mad
7665
// function will work on truncated floats.
7766
joint_matrix_apply(sg, sub_a,
78-
[=](float x) { x = round_to_tf32(x); });
67+
[=](float &x) { x = round_to_tf32(x); });
7968
joint_matrix_apply(sg, sub_b,
80-
[=](float x) { x = round_to_tf32(x); });
69+
[=](float &x) { x = round_to_tf32(x); });
8170
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
8271
}
8372

84-
joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; });
8573
joint_matrix_store(
8674
sg, sub_c,
8775
accC.template get_multi_ptr<access::decorated::no>() +
@@ -91,43 +79,21 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
9179
}).wait();
9280
}
9381

94-
static constexpr size_t MATRIX_M = TM * 2;
95-
static constexpr size_t MATRIX_N = TN * 2;
96-
static constexpr size_t MATRIX_K = TK * 2;
97-
float A[MATRIX_M][MATRIX_K];
98-
float B[MATRIX_K][MATRIX_N];
99-
float C[MATRIX_M][MATRIX_N];
100-
float D[MATRIX_M][MATRIX_N];
101-
102-
void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N,
103-
int K) {
104-
for (int m = 0; m < M; m++)
105-
for (int n = 0; n < N; n++) {
106-
for (int k = 0; k < K; k++) {
107-
float va = A_mem[m * K + k];
108-
float vb = B_mem[k * N + n];
109-
C_mem[m * N + n] += va * vb;
110-
}
111-
}
112-
}
113-
11482
int main() {
115-
for (int i = 0; i < MATRIX_M; i++) {
116-
for (int j = 0; j < MATRIX_K; j++) {
117-
A[i][j] = 1.0f * (i + j);
118-
}
119-
}
120-
for (int i = 0; i < MATRIX_K; i++) {
121-
for (int j = 0; j < MATRIX_N; j++) {
122-
B[i][j] = 2.0f * i + 3.0f * j;
123-
}
124-
}
125-
for (int i = 0; i < MATRIX_M; i++) {
126-
for (int j = 0; j < MATRIX_N; j++) {
127-
C[i][j] = 1.0;
128-
D[i][j] = 1.0;
129-
}
130-
}
83+
static constexpr size_t MATRIX_M = TM * 2;
84+
static constexpr size_t MATRIX_N = TN * 2;
85+
static constexpr size_t MATRIX_K = TK * 2;
86+
float A[MATRIX_M][MATRIX_K];
87+
float B[MATRIX_K][MATRIX_N];
88+
float C[MATRIX_M][MATRIX_N];
89+
float D[MATRIX_M][MATRIX_N];
90+
91+
matrix_fill(MATRIX_M, MATRIX_K, (float *)A,
92+
[](int i, int j) { return 1.0f * (i + j); });
93+
matrix_fill(MATRIX_K, MATRIX_N, (float *)B,
94+
[](int i, int j) { return 2.0f * i + 3.0f * j; });
95+
matrix_fill(MATRIX_M, MATRIX_N, (float *)C, 1.0f);
96+
matrix_fill(MATRIX_M, MATRIX_N, (float *)D, 1.0f);
13197

13298
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
13399
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
@@ -137,13 +103,7 @@ int main() {
137103
matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N,
138104
MATRIX_K);
139105

140-
bool res = true;
141-
for (int i = 0; i < MATRIX_M; i++) {
142-
for (int j = 0; j < MATRIX_N; j++) {
143-
if (C[i][j] != D[i][j])
144-
res = false;
145-
}
146-
}
106+
bool res = matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D);
147107
std::cout << (res ? "passed" : "failed") << std::endl;
148108
return !res;
149109
}

0 commit comments

Comments
 (0)