Skip to content

Commit 64fb410

Browse files
authored
[SYCL][Joint Matrix] Remove duplicated matrix_multiply_ref in tests (#11609)
1 parent f0031fa commit 64fb410

35 files changed

Lines changed: 264 additions & 671 deletions

sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,9 @@
1616
// XFAIL: gpu
1717

1818
#include "../common.hpp"
19-
#include <iostream>
20-
#include <sycl/sycl.hpp>
2119

2220
using namespace sycl;
2321
using namespace sycl::ext::oneapi::experimental::matrix;
24-
using bfloat16 = sycl::ext::oneapi::bfloat16;
2522

2623
constexpr size_t SG_SZ = 32;
2724
constexpr size_t TN = 16;

sycl/test-e2e/Matrix/SG32/joint_matrix_half.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,7 @@
1111
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1212
// RUN: %{run} %t.out
1313

14-
#include <iostream>
15-
#include <sycl/sycl.hpp>
14+
#include "../common.hpp"
1615

1716
using namespace sycl;
1817
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/SG32/joint_matrix_int8_colmajorA_colmajorB.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,7 @@
1515

1616
// XFAIL: gpu
1717

18-
#include <iostream>
19-
#include <sycl/sycl.hpp>
18+
#include "../common.hpp"
2019

2120
using namespace sycl;
2221
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/SG32/joint_matrix_int8_vnni.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,7 @@
1212

1313
// XFAIL: gpu
1414

15-
#include <iostream>
16-
#include <sycl/sycl.hpp>
15+
#include "../common.hpp"
1716

1817
using namespace sycl;
1918
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/SG32/joint_matrix_ss_int8.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
13+
#include "../common.hpp"
1514

1615
using namespace sycl;
1716
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/SG32/joint_matrix_su_int8.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
13+
#include "../common.hpp"
1514

1615
using namespace sycl;
1716
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/SG32/joint_matrix_us_int8.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
13+
#include "../common.hpp"
1514

1615
using namespace sycl;
1716
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/SG32/joint_matrix_uu_int8.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
13+
#include "../common.hpp"
1514

1615
using namespace sycl;
1716
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/XMX8/joint_matrix_bfloat16_32x64.cpp

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,10 @@
1212

1313
// XFAIL: *
1414

15-
#include <iostream>
16-
#include <sycl/sycl.hpp>
15+
#include "../common.hpp"
1716

1817
using namespace sycl;
1918
using namespace sycl::ext::oneapi::experimental::matrix;
20-
using bfloat16 = sycl::ext::oneapi::bfloat16;
2119

2220
#define SG_SZ 8
2321
constexpr size_t TN = 8;

sycl/test-e2e/Matrix/XMX8/joint_matrix_half.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,7 @@
1111
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1212
// RUN: %{run} %t.out
1313

14-
#include <iostream>
15-
#include <sycl/sycl.hpp>
14+
#include "../common.hpp"
1615

1716
using namespace sycl;
1817
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/XMX8/joint_matrix_int8_vnni.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,7 @@
1212

1313
// XFAIL: *
1414

15-
#include <iostream>
16-
#include <sycl/sycl.hpp>
15+
#include "../common.hpp"
1716

1817
using namespace sycl;
1918
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/XMX8/joint_matrix_ss_int8.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
13+
#include "../common.hpp"
1514

1615
using namespace sycl;
1716
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/XMX8/joint_matrix_su_int8.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
13+
#include "../common.hpp"
1514

1615
using namespace sycl;
1716
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/XMX8/joint_matrix_us_int8.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
13+
#include "../common.hpp"
1514

1615
using namespace sycl;
1716
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/XMX8/joint_matrix_uu_int8.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
13+
#include "../common.hpp"
1514

1615
using namespace sycl;
1716
using namespace sycl::ext::oneapi::experimental::matrix;

sycl/test-e2e/Matrix/common.hpp

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -31,22 +31,38 @@ float make_fp32(bfloat16 x) {
3131
return *res;
3232
}
3333

34-
template <typename Ta, typename Tc>
35-
void matrix_multiply_ref(Ta *A, Ta *B, Tc *C, int M, int N, int K,
36-
bool transpose_c = false) {
34+
template <typename Ta, typename Tb, typename Tc, uint VF = 1>
35+
void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
36+
bool transpose_c = false, bool colmajor_a = false,
37+
bool colmajor_b = false) {
3738
for (unsigned int m = 0; m < M; m++) {
3839
for (unsigned int n = 0; n < N; n++) {
3940
for (unsigned int k = 0; k < K; k++) {
41+
42+
int a_ind = colmajor_a ? (k * M + m) : m * K + k;
43+
int b_ind = colmajor_b ? (n * K + k) : k * N + n;
4044
int c_ind = transpose_c ? (n * M + m) : m * N + n;
41-
if constexpr (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)
42-
C[c_ind] += make_fp32(A[m * K + k]) * make_fp32(B[k * N + n]);
43-
else if constexpr (std::is_same_v<Ta, float> &&
44-
std::is_same_v<Tc, float> ||
45-
std::is_same_v<Ta, int8_t> &&
46-
std::is_same_v<Tc, int32_t>)
47-
C[c_ind] += A[m * K + k] * B[k * N + n];
48-
else
49-
assert(false && "Unsupported type in matrix_multiply_ref.");
45+
46+
Ta *va = (Ta *)(A + a_ind * VF);
47+
Tb *vb = (Tb *)(B + b_ind * VF);
48+
Tc acc = *(C + c_ind);
49+
50+
for (uint i = 0; i < VF; i++) {
51+
if constexpr (std::is_same_v<Ta, bfloat16> &&
52+
std::is_same_v<Tc, float>)
53+
acc += make_fp32(va[i]) * make_fp32(vb[i]);
54+
else if constexpr (std::is_same_v<Ta, float> &&
55+
std::is_same_v<Tc, float> ||
56+
std::is_integral_v<Ta> && std::is_integral_v<Tc>)
57+
acc += va[i] * vb[i];
58+
else if constexpr (std::is_same_v<Ta, sycl::half> &&
59+
std::is_same_v<Tc, float>)
60+
acc += (float)va[i] * (float)vb[i];
61+
else
62+
assert(false && "Unsupported type in matrix_multiply_ref.");
63+
}
64+
65+
*(C + c_ind) = acc;
5066
}
5167
}
5268
}

sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp

Lines changed: 17 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -67,55 +67,30 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
6767
}).wait();
6868
}
6969

70-
static constexpr size_t MATRIX_M = TM * 2;
71-
static constexpr size_t MATRIX_N = TN * 2;
72-
static constexpr size_t MATRIX_K = TK * 2;
73-
bfloat16 A[MATRIX_M][MATRIX_K];
74-
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
75-
float C[MATRIX_M][MATRIX_N];
76-
float D[MATRIX_M][MATRIX_N];
77-
78-
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
79-
int K) {
80-
for (int m = 0; m < M; m++)
81-
for (int n = 0; n < N; n++) {
82-
for (int k = 0; k < K; k++) {
83-
bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
84-
bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
85-
float acc = *((float *)(C_mem + m * N + n));
86-
for (int i = 0; i < 2; i++) {
87-
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
88-
}
89-
*((float *)(C_mem + m * N + n)) = acc;
90-
}
91-
}
92-
}
93-
9470
int main() {
95-
for (int i = 0; i < MATRIX_M; i++) {
96-
for (int j = 0; j < MATRIX_K; j++) {
97-
A[i][j] = bfloat16(1.0f * (i + j));
98-
}
99-
}
100-
for (int i = 0; i < MATRIX_K / 2; i++) {
101-
for (int j = 0; j < MATRIX_N * 2; j++) {
102-
B[i][j] = bfloat16(2.0f * i + 3.0f * j);
103-
}
104-
}
105-
for (int i = 0; i < MATRIX_M; i++) {
106-
for (int j = 0; j < MATRIX_N; j++) {
107-
C[i][j] = 1.0;
108-
D[i][j] = 1.0;
109-
}
110-
}
71+
static constexpr size_t MATRIX_M = TM * 2;
72+
static constexpr size_t MATRIX_N = TN * 2;
73+
static constexpr size_t MATRIX_K = TK * 2;
74+
bfloat16 A[MATRIX_M][MATRIX_K];
75+
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
76+
float C[MATRIX_M][MATRIX_N];
77+
float D[MATRIX_M][MATRIX_N];
78+
79+
matrix_fill(MATRIX_M, MATRIX_K, (bfloat16 *)A,
80+
[](int i, int j) { return 1.0f * (i + j); });
81+
matrix_fill(MATRIX_K / 2, MATRIX_N * 2, (bfloat16 *)B,
82+
[](int i, int j) { return 2.0f * i + 3.0f * j; });
83+
matrix_fill(MATRIX_M, MATRIX_N, (float *)C, 1.0f);
84+
matrix_fill(MATRIX_M, MATRIX_N, (float *)D, 1.0f);
11185

11286
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
11387
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
11488
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
11589
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
11690
matrix_multiply(MC, MA, MB);
117-
matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
118-
MATRIX_N, MATRIX_K / 2);
91+
matrix_multiply_ref<bfloat16, bfloat16, float, 2>(
92+
(bfloat16 *)A, (bfloat16 *)B, (float *)D, MATRIX_M, MATRIX_N,
93+
MATRIX_K / 2);
11994

12095
bool res = matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D);
12196
std::cout << (res ? "passed" : "failed") << std::endl;

sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp

Lines changed: 19 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -81,59 +81,33 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
8181
}).wait();
8282
}
8383

84-
static constexpr size_t MATRIX_M = TM * 2;
85-
static constexpr size_t MATRIX_N = TN * 2;
86-
static constexpr size_t MATRIX_K = TK * 2;
87-
88-
bfloat16 A[MATRIX_M][MATRIX_K];
89-
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
90-
91-
float C[MATRIX_M][MATRIX_N];
92-
float D[MATRIX_M][MATRIX_N];
93-
94-
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
95-
int K) {
96-
for (int m = 0; m < M; m++)
97-
for (int n = 0; n < N; n++) {
98-
for (int k = 0; k < K; k++) {
99-
// Because B was assumed VNNIed
100-
bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
101-
bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
102-
float acc = *((float *)(C_mem + m * N + n));
103-
for (int i = 0; i < 2; i++) {
104-
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
105-
}
106-
*((float *)(C_mem + m * N + n)) = acc;
107-
}
108-
}
109-
}
110-
11184
int main() {
112-
for (int i = 0; i < MATRIX_M; i++) {
113-
for (int j = 0; j < MATRIX_K; j++) {
114-
A[i][j] = bfloat16(1.0f * (i + j));
115-
}
116-
}
117-
for (int i = 0; i < MATRIX_K / 2; i++) {
118-
for (int j = 0; j < MATRIX_N * 2; j++) {
119-
B[i][j] = bfloat16(2.0f * i + 3.0f * j);
120-
}
121-
}
122-
for (int i = 0; i < MATRIX_M; i++) {
123-
for (int j = 0; j < MATRIX_N; j++) {
124-
C[i][j] = 1.0;
125-
D[i][j] = 1.0;
126-
}
127-
}
85+
static constexpr size_t MATRIX_M = TM * 2;
86+
static constexpr size_t MATRIX_N = TN * 2;
87+
static constexpr size_t MATRIX_K = TK * 2;
88+
89+
bfloat16 A[MATRIX_M][MATRIX_K];
90+
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
91+
92+
float C[MATRIX_M][MATRIX_N];
93+
float D[MATRIX_M][MATRIX_N];
94+
95+
matrix_fill(MATRIX_M, MATRIX_K, (bfloat16 *)A,
96+
[](int i, int j) { return 1.0f * (i + j); });
97+
matrix_fill(MATRIX_K / 2, MATRIX_N * 2, (bfloat16 *)B,
98+
[](int i, int j) { return 2.0f * i + 3.0f * j; });
99+
matrix_fill(MATRIX_M, MATRIX_N, (float *)C, 1.0f);
100+
matrix_fill(MATRIX_M, MATRIX_N, (float *)D, 1.0f);
128101

129102
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
130103
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
131104
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
132105
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
133106

134107
matrix_multiply(MC, MA, MB);
135-
matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
136-
MATRIX_N, MATRIX_K / 2);
108+
matrix_multiply_ref<bfloat16, bfloat16, float, 2>(
109+
(bfloat16 *)A, (bfloat16 *)B, (float *)D, MATRIX_M, MATRIX_N,
110+
MATRIX_K / 2);
137111

138112
bool res = matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D);
139113
std::cout << (res ? "passed" : "failed") << std::endl;

0 commit comments

Comments (0)