Skip to content

Commit 7fc254a

Browse files
committed
[SYCL][Joint Matrix] Add new test for fill and store operations
1 parent 1dd8ea1 commit 7fc254a

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//==-- joint_matrix_fill_store.cpp = Test for Joint Matrix fill and store --==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-------------------------------------------------------------------------===//
8+
9+
#include "joint_matrix_fill_store_impl.hpp"
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
//==----------------------------------------------------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-------------------------------------------------------------------------===//
8+
9+
// TODO: add this test to XMX8 and SG32 folders
10+
#include "common.hpp"
11+
#define SG_SZ 16
12+
13+
using namespace sycl;
14+
using namespace sycl::ext::oneapi::experimental::matrix;
15+
16+
template <typename T1, typename T2, size_t TM, size_t TN, size_t TK>
17+
void matrix_fill_store(big_matrix<T1, TM, TN> &C, big_matrix<T2, TM, TK> &A,
18+
big_matrix<T2, TK / 2, TN * 2> &B) {
19+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(TM, TK));
20+
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(TK / 2, TN * 2));
21+
buffer<float, 2> bufC((float *)C.get_data(), range<2>(TM, TN));
22+
23+
queue q;
24+
q.submit([&](handler &cgh) {
25+
auto accC = bufC.get_access<access::mode::read_write>(cgh);
26+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
27+
auto accB = bufB.get_access<access::mode::read_write>(cgh);
28+
29+
cgh.parallel_for(
30+
nd_range<2>({1, 1 * SG_SZ}, {1, 1 * SG_SZ}),
31+
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
32+
const auto global_idx = spmd_item.get_global_id(0);
33+
const auto global_idy = spmd_item.get_global_id(1);
34+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
35+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
36+
37+
sub_group sg = spmd_item.get_sub_group();
38+
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
39+
sub_a;
40+
41+
// For B, we assume B has been already VNNIed.
42+
joint_matrix<sub_group, bfloat16, use::b, TK, TN,
43+
layout::ext_intel_packed>
44+
sub_b;
45+
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
46+
47+
// TODO: uncomment these calls to add testing for other types of
48+
// matrices
49+
// joint_matrix_fill(sg, sub_a, 5.0);
50+
// joint_matrix_fill(sg, sub_b, 5.0);
51+
joint_matrix_fill(sg, sub_c, 5.0);
52+
53+
ext::intel::experimental::matrix::joint_matrix_store(
54+
sg, sub_a, accA.template get_multi_ptr<access::decorated::no>(),
55+
TK);
56+
57+
ext::intel::experimental::matrix::joint_matrix_store(
58+
sg, sub_b, accB.template get_multi_ptr<access::decorated::no>(),
59+
TN * 2);
60+
61+
joint_matrix_store(
62+
sg, sub_c, accC.template get_multi_ptr<access::decorated::no>(),
63+
TN, layout::row_major);
64+
}); // parallel for
65+
}).wait();
66+
}
67+
68+
template <size_t TM, size_t TN, size_t TK> bool run_test() {
69+
70+
bfloat16 A[TM][TK];
71+
bfloat16 A_ref[TM][TK];
72+
bfloat16 B[TK / 2][TN * 2];
73+
bfloat16 B_ref[TK / 2][TN * 2];
74+
float C[TM][TN];
75+
float C_ref[TM][TN];
76+
77+
matrix_fill(TM, TK, (bfloat16 *)A, (bfloat16)0);
78+
matrix_fill(TK / 2, TN * 2, (bfloat16 *)B, (bfloat16)0);
79+
matrix_fill(TM, TN, (float *)C, 0.0f);
80+
81+
matrix_fill(TM, TK, (bfloat16 *)A_ref, (bfloat16)5);
82+
matrix_fill(TK / 2, TN * 2, (bfloat16 *)B_ref, (bfloat16)5);
83+
matrix_fill(TM, TN, (float *)C_ref, 5.0f);
84+
85+
big_matrix<float, TM, TN> MC((float *)&C);
86+
big_matrix<bfloat16, TM, TK> MA((bfloat16 *)&A);
87+
big_matrix<bfloat16, TK / 2, TN * 2> MB((bfloat16 *)&B);
88+
89+
matrix_fill_store(MC, MA, MB);
90+
91+
// TODO: uncomment these calls to verify other types of matrices
92+
// bool res = matrix_compare(TM, TK, (bfloat16 *)A, (bfloat16 *)A_ref);
93+
// res &= matrix_compare(TK / 2, TN * 2, (bfloat16 *)B, (bfloat16 *)B_ref);
94+
// TODO later
95+
bool res = matrix_compare(TM, TN, (float *)C, (float *)C_ref);
96+
97+
return res;
98+
}
99+
100+
int main() {
101+
// TODO: add all supported size and types combinations
102+
bool res = run_test<8, 16, 16>();
103+
res &= run_test<32, 64, 16>();
104+
std::cout << (res ? "passed" : "failed") << std::endl;
105+
return !res;
106+
}

0 commit comments

Comments
 (0)