7
7
// ===----------------------------------------------------------------------===//
8
8
#include < sycl/usm.hpp>
9
9
10
- template <typename Tc, typename Ta, size_t M, size_t N>
11
- bool apply_verify (Tc *C, Tc *D, Ta *A, Ta *Ar) {
12
- for (size_t i = 0 ; i < M; i++)
13
- for (size_t j = 0 ; j < N; j++) {
14
- Tc diffc = D[i * N + j] - C[i * N + j] * 2 ;
15
- Ta diffa = Ar[i * N + j] - (A[i * N + j] + 42 );
16
- if constexpr (std::is_same_v<Ta, bfloat16>) {
17
- if (std::fabs (diffc) > FLOAT_EPSILON ||
18
- std::fabs (diffa) > FLOAT_EPSILON || std::isnan (C[i * N + j]) ||
19
- std::isnan (A[i * N + j])) {
20
- return false ;
21
- }
22
- } else {
23
- if (std::abs (diffc) > 0 || std::abs (diffa) > 0 ) {
24
- return false ;
25
- }
26
- }
27
- }
28
- return true ;
10
/// Element-wise helper: yields `x * 2`, converted back to the argument type T.
template <typename T>
T mul2(T x) {
  T doubled = x * 2;
  return doubled;
}
11
+
12
/// Element-wise helper: yields `x + 5`, converted back to the argument type T.
template <typename T>
T add5(T x) {
  T shifted = x + 5;
  return shifted;
}
13
+
14
+ template <typename Tc, size_t M, size_t N>
15
+ bool apply_verify (Tc *C, Tc *D, Tc *ref) {
16
+ Tc *refcopy = (Tc *)std::malloc (M * N * sizeof (Tc));
17
+ memcpy (refcopy, ref, M * N * sizeof (Tc));
18
+ matrix_apply (M, N, ref, mul2<Tc>);
19
+ bool res = matrix_compare (M, N, D, ref);
20
+
21
+ matrix_apply (M, N, refcopy, add5<Tc>);
22
+ res &= matrix_compare (M, N, C, refcopy);
23
+ return res;
29
24
}
25
+
30
26
template <typename Tc, typename Ta, size_t TM, size_t TN, size_t TK, size_t M,
31
27
size_t N, size_t K, class kernel_name >
32
- bool apply_two_matrices (Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
28
+ bool apply_two_matrices (Tc *C, Tc *D, Ta *A, Ta *Ar, Tc *Cref, Ta *Aref,
29
+ queue q) {
33
30
size_t NDRangeM = M / TM;
34
31
size_t NDRangeN = N / TN;
35
32
@@ -70,22 +67,33 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
70
67
joint_matrix_load (
71
68
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
72
69
N, layout::row_major);
73
- joint_matrix_apply (sg, sub_c, sub_d,
74
- [](const Tc &x, Tc &y) { y = x * 2 ; });
70
+ joint_matrix_apply (sg, sub_c, sub_d, [](Tc &x, Tc &y) {
71
+ y = mul2 (x);
72
+ x = add5 (x);
73
+ });
75
74
joint_matrix_store (
76
75
sg, sub_d, pD + (sg_startx * TM) * N + sg_starty / sg_size * TN,
77
76
N, layout::row_major);
77
+ joint_matrix_store (
78
+ sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
79
+ N, layout::row_major);
78
80
joint_matrix_load (
79
81
sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
80
82
K);
81
- joint_matrix_apply (sg, sub_a, sub_ar,
82
- [](const Ta &x, Ta &y) { y = x + 42 ; });
83
+ joint_matrix_apply (sg, sub_a, sub_ar, [](Ta &x, Ta &y) {
84
+ y = mul2 (x);
85
+ x = add5 (x);
86
+ });
83
87
ext::intel::experimental::matrix::joint_matrix_store (
84
88
sg, sub_ar,
85
89
pAr + (sg_startx * TM) * K + sg_starty / sg_size * TK, K);
90
+ ext::intel::experimental::matrix::joint_matrix_store (
91
+ sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
92
+ K);
86
93
}); // parallel for
87
94
}).wait ();
88
- return apply_verify<Tc, Ta, M, N>(C, D, A, Ar);
95
+ return apply_verify<Tc, M, N>(C, D, Cref) &&
96
+ apply_verify<Ta, M, N>(A, Ar, Aref);
89
97
}
90
98
91
99
template <typename Ta, typename Tc, size_t TM, size_t TN, size_t TK,
@@ -96,16 +104,20 @@ bool test() {
96
104
static constexpr size_t K = TK * 2 ;
97
105
queue q;
98
106
107
+ Tc *Cref = malloc_shared<Tc>(M * N, q);
108
+ Ta *Aref = malloc_shared<Ta>(M * K, q);
99
109
Tc *C = malloc_shared<Tc>(M * N, q);
100
110
Tc *D = malloc_shared<Tc>(M * N, q);
101
111
Ta *A = malloc_shared<Ta>(M * K, q);
102
112
Ta *Ar = malloc_shared<Ta>(M * K, q);
103
113
104
- matrix_rand (M, N, (Tc *)C, (Tc)100 );
105
- matrix_rand (M, K, (Ta *)A, (Ta)100 );
114
+ matrix_rand (M, N, (Tc *)Cref, (Tc)100 );
115
+ matrix_rand (M, K, (Ta *)Aref, (Ta)100 );
116
+ matrix_copy (M, N, Cref, C);
117
+ matrix_copy (M, K, Aref, A);
106
118
107
119
bool res = apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, K, kernel_name>(
108
- C, D, A, Ar, q);
120
+ C, D, A, Ar, Cref, Aref, q);
109
121
110
122
if constexpr (std::is_same_v<Ta, bfloat16>)
111
123
std::cout << " bfloat16 " << TM << " x" << TN << " x" << TK << " : "
@@ -117,6 +129,8 @@ bool test() {
117
129
free (D, q);
118
130
free (A, q);
119
131
free (Ar, q);
132
+ free (Cref, q);
133
+ free (Aref, q);
120
134
121
135
return res;
122
136
}
0 commit comments