10
10
#include " common.hpp"
11
11
#define SG_SZ 16
12
12
13
- using namespace sycl ;
14
- using namespace sycl ::ext::oneapi::experimental::matrix;
15
-
16
- template <typename T1, typename T2, size_t TM, size_t TN, size_t TK>
17
- void matrix_fill_store (big_matrix<T1, TM, TN> &C, big_matrix<T2, TM, TK> &A,
18
- big_matrix<T2, TK / 2 , TN * 2 > &B) {
13
+ template <typename TC, typename Tab, size_t TM, size_t TN, size_t TK>
14
+ void matrix_fill_store (big_matrix<TC, TM, TN> &C, big_matrix<Tab, TM, TK> &A,
15
+ big_matrix<Tab, TK / 2 , TN * 2 > &B) {
19
16
buffer<bfloat16, 2 > bufA (A.get_data (), range<2 >(TM, TK));
20
17
buffer<bfloat16, 2 > bufB (B.get_data (), range<2 >(TK / 2 , TN * 2 ));
21
18
buffer<float , 2 > bufC ((float *)C.get_data (), range<2 >(TM, TN));
@@ -35,14 +32,14 @@ void matrix_fill_store(big_matrix<T1, TM, TN> &C, big_matrix<T2, TM, TK> &A,
35
32
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
36
33
37
34
sub_group sg = spmd_item.get_sub_group ();
38
- joint_matrix<sub_group, bfloat16 , use::a, TM, TK, layout::row_major>
35
+ joint_matrix<sub_group, Tab , use::a, TM, TK, layout::row_major>
39
36
sub_a;
40
37
41
38
// For B, we assume B has been already VNNIed.
42
- joint_matrix<sub_group, bfloat16 , use::b, TK, TN,
39
+ joint_matrix<sub_group, Tab , use::b, TK, TN,
43
40
layout::ext_intel_packed>
44
41
sub_b;
45
- joint_matrix<sub_group, float , use::accumulator, TM, TN> sub_c;
42
+ joint_matrix<sub_group, TC , use::accumulator, TM, TN> sub_c;
46
43
47
44
// TODO: uncomment these calls to add testing for other types of
48
45
// matrices
@@ -100,6 +97,7 @@ int main() {
100
97
// TODO: add all supported size and types combinations
101
98
bool res = run_test<8 , 16 , 16 >();
102
99
res &= run_test<32 , 64 , 16 >();
100
+ res &= run_test<16 , 16 , 16 >();
103
101
std::cout << (res ? " passed" : " failed" ) << std::endl;
104
102
return !res;
105
103
}
0 commit comments