Skip to content

Commit 96e28fe

Browse files
authored
[SYCL][Matrix tests] improve the transpose C test to include more cases (#11938)
1 parent ff69048 commit 96e28fe

File tree

2 files changed

+88
-45
lines changed

2 files changed

+88
-45
lines changed

sycl/test-e2e/Matrix/common.hpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ void matrix_vnni(unsigned int rows, unsigned int cols, T *src, T *dest,
8181
}
8282
}
8383

84+
template <typename T>
85+
void matrix_transpose(unsigned int rows, unsigned int cols, T *dst, T *src) {
86+
for (unsigned int i = 0; i < rows; i++) {
87+
for (unsigned int j = 0; j < cols; j++) {
88+
dst[i + j * rows] = src[i * cols + j];
89+
}
90+
}
91+
}
92+
8493
template <typename T>
8594
void matrix_fill(unsigned int rows, unsigned int cols, T *src, T val) {
8695
for (unsigned int i = 0; i < rows; i++) {
@@ -128,11 +137,12 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) {
128137
}
129138
}
130139

131-
template <typename T1, typename T2>
140+
template <typename T1, typename T2, bool exact = false>
132141
bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
133142
for (int i = 0; i < rows; i++) {
134143
for (int j = 0; j < cols; j++) {
135-
if constexpr (std::is_same_v<T1, float> || std::is_same_v<T1, bfloat16>) {
144+
if constexpr (!exact && (std::is_same_v<T1, float> ||
145+
std::is_same_v<T1, bfloat16>)) {
136146
float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]);
137147
if (diff > FLOAT_EPSILON || std::isnan(src[i * cols + j])) {
138148
std::cout << "Incorrect result in matrix. "
@@ -142,9 +152,10 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
142152
<< ", Epsilon: " << FLOAT_EPSILON << "\n";
143153
return false;
144154
}
145-
} else if constexpr (std::is_same_v<T1, int32_t>) {
155+
} else if constexpr (exact || std::is_same_v<T1, int32_t>) {
146156
if (src[i * cols + j] != ref[i * cols + j]) {
147-
std::cout << "Incorrect result in matrix. Ref: " << ref[i * cols + j]
157+
std::cout << "Incorrect result in matrix." << "i: " << i
158+
<< ", j: " << j << ", Ref: " << ref[i * cols + j]
148159
<< ", Val: " << src[i * cols + j] << "\n";
149160
return false;
150161
}
Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,102 @@
1-
#include <iostream>
2-
#include <random>
3-
41
using namespace sycl;
52
using namespace sycl::ext::oneapi::experimental::matrix;
63

7-
constexpr size_t TM = 8;
8-
constexpr size_t TK = 16;
4+
template <size_t TM, size_t TN, typename T1, size_t NUM_ROWS, size_t NUM_COLS>
5+
void matrix_load_and_store(T1 *input, T1 *out_col_major, T1 *out_row_major,
6+
queue q) {
7+
size_t M = NUM_ROWS;
8+
size_t N = NUM_COLS;
99

10-
template <typename T1, size_t NUM_ROWS_C, size_t NUM_COLS_C>
11-
void matrix_load_store(T1 *C, queue q) {
12-
size_t M = NUM_ROWS_C;
13-
size_t N = NUM_COLS_C;
10+
static_assert((NUM_ROWS % TM) == 0);
11+
static_assert((NUM_COLS % TN) == 0);
1412

1513
size_t NDRangeM = M / TM;
1614
size_t NDRangeN = N / TN;
1715

18-
auto pC = address_space_cast<sycl::access::address_space::global_space,
19-
sycl::access::decorated::no>(C);
16+
auto p_input = address_space_cast<sycl::access::address_space::global_space,
17+
sycl::access::decorated::no>(input);
18+
19+
auto p_out_col_major =
20+
address_space_cast<sycl::access::address_space::global_space,
21+
sycl::access::decorated::no>(out_col_major);
22+
auto p_out_row_major =
23+
address_space_cast<sycl::access::address_space::global_space,
24+
sycl::access::decorated::no>(out_row_major);
2025

2126
q.submit([&](handler &cgh) {
2227
cgh.parallel_for(
2328
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
24-
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
25-
26-
{
27-
// The submatrix API has to be accessed by all the workitems in
28-
// a subgroup these functions will be called once by the
29-
// subgroup no code divergence between the workitems
29+
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
3030
const auto global_idx = spmd_item.get_global_id(0);
3131
const auto global_idy = spmd_item.get_global_id(1);
3232
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
3333
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
3434

3535
sub_group sg = spmd_item.get_sub_group();
36-
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
37-
// for transposeC
38-
// which TN x TM in N x M:
39-
// M x N => TM x N => TM x TN => TN x TM
40-
// m=sg_startx
41-
// sg_starty/SG_SZ
42-
// linear_index = M * (sg_starty/SG_SZ *TN) + TM *sg_startx
43-
joint_matrix_load(sg, sub_c,
44-
pC + M * (sg_starty / SG_SZ * TN) + TM * sg_startx,
45-
M, layout::col_major);
46-
joint_matrix_store(
47-
sg, sub_c, pC + M * (sg_starty / SG_SZ * TN) + TM * sg_startx, M,
48-
layout::col_major);
36+
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_matrix;
37+
38+
auto row_major_offset =
39+
(sg_startx * TM) * N + (sg_starty / SG_SZ * TN);
40+
auto col_major_offset =
41+
(sg_startx * TM) + (sg_starty / SG_SZ * TN) * M;
42+
43+
joint_matrix_load(sg, sub_matrix, p_input + col_major_offset, M,
44+
layout::col_major);
45+
46+
joint_matrix_store(sg, sub_matrix,
47+
p_out_col_major + row_major_offset, N,
48+
layout::row_major);
49+
50+
joint_matrix_store(sg, sub_matrix,
51+
p_out_row_major + col_major_offset, M,
52+
layout::col_major);
4953
}); // parallel for
5054
}).wait();
5155
}
5256

53-
int main() {
54-
static constexpr size_t MATRIX_M = 1024;
55-
static constexpr size_t MATRIX_N = 1024;
57+
template <size_t TM> void run_matrix_test() {
58+
static constexpr size_t MATRIX_M = TM * 16;
59+
static constexpr size_t MATRIX_N = TN * 16;
5660

5761
queue q;
58-
float *C = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
59-
float *D = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
62+
float *input = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
63+
float *out_col_major = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
64+
float *out_row_major = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
65+
float *ref_col_major = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
6066

61-
matrix_rand(MATRIX_M, MATRIX_N, C, (float)5.0);
62-
matrix_copy(MATRIX_M, MATRIX_N, C, D);
67+
// input is column majot matrix so it is of NxM shape
68+
matrix_rand(MATRIX_N, MATRIX_M, input, (float)5.0);
69+
matrix_fill(MATRIX_M, MATRIX_N, out_col_major, (float)0);
70+
matrix_fill(MATRIX_N, MATRIX_M, out_row_major, (float)0);
71+
matrix_transpose(MATRIX_N, MATRIX_M, ref_col_major, input);
6372

64-
matrix_load_store<float, MATRIX_M, MATRIX_N>(C, q);
73+
matrix_load_and_store<TM, TN, float, MATRIX_M, MATRIX_N>(input, out_col_major,
74+
out_row_major, q);
6575

66-
bool res = matrix_compare(MATRIX_M, MATRIX_N, C, D);
76+
// we use exact comparison as no low precision calculation is used in this
77+
// test
78+
std::cout << "compare results for TM " << TM << "\n";
79+
bool res = matrix_compare<float, float, true>(MATRIX_M, MATRIX_N,
80+
out_col_major, ref_col_major) &&
81+
matrix_compare<float, float, true>(MATRIX_N, MATRIX_M,
82+
out_row_major, input);
83+
free(input, q);
84+
free(out_col_major, q);
85+
free(out_row_major, q);
86+
free(ref_col_major, q);
87+
assert(res);
88+
}
89+
90+
int main() {
91+
run_matrix_test<8>();
92+
run_matrix_test<7>();
93+
run_matrix_test<6>();
94+
run_matrix_test<5>();
95+
run_matrix_test<4>();
96+
run_matrix_test<3>();
97+
run_matrix_test<2>();
98+
run_matrix_test<1>();
6799

68-
std::cout << (res ? "passed" : "failed") << std::endl;
69-
return !res;
100+
std::cout << "Passed\n";
101+
return 0;
70102
}

0 commit comments

Comments
 (0)