Skip to content

Commit 4b56feb

Browse files
committed
update
1 parent 30f5f7e commit 4b56feb

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

sycl/test-e2e/Matrix/joint_matrix_fill_store_impl.hpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,9 @@
1010
#include "common.hpp"
1111
#define SG_SZ 16
1212

13-
using namespace sycl;
14-
using namespace sycl::ext::oneapi::experimental::matrix;
15-
16-
template <typename T1, typename T2, size_t TM, size_t TN, size_t TK>
17-
void matrix_fill_store(big_matrix<T1, TM, TN> &C, big_matrix<T2, TM, TK> &A,
18-
big_matrix<T2, TK / 2, TN * 2> &B) {
13+
template <typename TC, typename Tab, size_t TM, size_t TN, size_t TK>
14+
void matrix_fill_store(big_matrix<TC, TM, TN> &C, big_matrix<Tab, TM, TK> &A,
15+
big_matrix<Tab, TK / 2, TN * 2> &B) {
1916
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(TM, TK));
2017
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(TK / 2, TN * 2));
2118
buffer<float, 2> bufC((float *)C.get_data(), range<2>(TM, TN));
@@ -35,14 +32,14 @@ void matrix_fill_store(big_matrix<T1, TM, TN> &C, big_matrix<T2, TM, TK> &A,
3532
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
3633

3734
sub_group sg = spmd_item.get_sub_group();
38-
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
35+
joint_matrix<sub_group, Tab, use::a, TM, TK, layout::row_major>
3936
sub_a;
4037

4138
// For B, we assume B has been already VNNIed.
42-
joint_matrix<sub_group, bfloat16, use::b, TK, TN,
39+
joint_matrix<sub_group, Tab, use::b, TK, TN,
4340
layout::ext_intel_packed>
4441
sub_b;
45-
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
42+
joint_matrix<sub_group, TC, use::accumulator, TM, TN> sub_c;
4643

4744
// TODO: uncomment these calls to add testing for other types of
4845
// matrices
@@ -100,6 +97,7 @@ int main() {
10097
// TODO: add all supported size and types combinations
10198
bool res = run_test<8, 16, 16>();
10299
res &= run_test<32, 64, 16>();
100+
res &= run_test<16, 16, 16>();
103101
std::cout << (res ? "passed" : "failed") << std::endl;
104102
return !res;
105103
}

0 commit comments

Comments
 (0)