Skip to content

Commit 5e7de51

Browse files
authored
[SYCL][Joint Matrix] Update apply to make both matrices read/write (#16155)
The spec change was added in #13153. It states that the overload of joint_matrix_apply that takes two matrices can modify both matrices. The test is also updated to reflect the change.
1 parent ada8c86 commit 5e7de51

File tree

3 files changed

+70
-45
lines changed

3 files changed

+70
-45
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -112,35 +112,39 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
112112
return;
113113
}
114114

115-
template <typename Group, typename T, use Use, size_t M, size_t N,
115+
template <typename Group, typename T0, typename T1, use Use, size_t M, size_t N,
116116
layout Layout, typename F>
117117
inline __SYCL_ALWAYS_INLINE void
118-
joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jmsrc,
119-
joint_matrix<Group, T, Use, M, N, Layout> &jmdest,
118+
joint_matrix_apply(Group sg, joint_matrix<Group, T0, Use, M, N, Layout> &jm0,
119+
joint_matrix<Group, T1, Use, M, N, Layout> &jm1,
120120
F &&lambda) {
121121
#if defined(__SYCL_DEVICE_ONLY__)
122122
#if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
123123
std::ignore = sg;
124-
for (int i = 0; i < jmsrc.matrix_impl.wi_marray.size(); i++) {
125-
lambda(jmsrc.matrix_impl.wi_marray[i], jmdest.matrix_impl.wi_marray[i]);
124+
for (int i = 0; i < jm0.matrix_impl.wi_marray.size(); i++) {
125+
lambda(jm0.matrix_impl.wi_marray[i], jm1.matrix_impl.wi_marray[i]);
126126
}
127127
#else // NVPTX
128-
using storage_element_type =
128+
using storage_element_type0 =
129129
typename oneapi::detail::jm_type_interpretation_helper_trait<
130-
T>::storage_element_type;
131-
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jmsrc);
132-
auto wi_data_d = sycl::ext::oneapi::detail::get_wi_data(sg, jmdest);
133-
for (int i = 0; i < wi_data_c.length(); i++) {
134-
storage_element_type elementsrc = wi_data_c[i];
135-
storage_element_type elementdest = wi_data_d[i];
136-
lambda(elementsrc, elementdest);
137-
wi_data_d[i] = elementdest;
130+
T0>::storage_element_type;
131+
using storage_element_type1 =
132+
typename oneapi::detail::jm_type_interpretation_helper_trait<
133+
T1>::storage_element_type;
134+
auto wi_data_0 = sycl::ext::oneapi::detail::get_wi_data(sg, jm0);
135+
auto wi_data_1 = sycl::ext::oneapi::detail::get_wi_data(sg, jm1);
136+
for (int i = 0; i < wi_data_0.length(); i++) {
137+
storage_element_type0 element0 = wi_data_0[i];
138+
storage_element_type1 element1 = wi_data_1[i];
139+
lambda(element0, element1);
140+
wi_data_0[i] = element0;
141+
wi_data_1[i] = element1;
138142
}
139143
#endif
140144
#else
141145
std::ignore = sg;
142-
std::ignore = jmsrc;
143-
std::ignore = jmdest;
146+
std::ignore = jm0;
147+
std::ignore = jm1;
144148
std::ignore = lambda;
145149
throw exception(make_error_code(errc::runtime),
146150
"joint matrix is not supported on host.");

sycl/test-e2e/Matrix/common.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,13 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) {
157157
}
158158
}
159159

160+
template <typename F, typename T>
161+
void matrix_apply(unsigned int rows, unsigned int cols, T *mat, F op) {
162+
for (unsigned int i = 0; i < rows; i++)
163+
for (unsigned int j = 0; j < cols; j++)
164+
mat[i * cols + j] = op(mat[i * cols + j]);
165+
}
166+
160167
template <typename T1, typename T2, bool exact = false>
161168
bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
162169
for (int i = 0; i < rows; i++) {
@@ -174,7 +181,7 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
174181
<< ", Epsilon: " << FLOAT_EPSILON << "\n";
175182
return false;
176183
}
177-
} else if constexpr (exact || std::is_same_v<T1, int32_t>) {
184+
} else if constexpr (exact || std::is_integral_v<T1>) {
178185
if (src[i * cols + j] != ref[i * cols + j]) {
179186
std::cout << "Incorrect result in matrix."
180187
<< "i: " << i << ", j: " << j

sycl/test-e2e/Matrix/joint_matrix_apply_two_matrices_impl.hpp

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,26 @@
77
//===----------------------------------------------------------------------===//
88
#include <sycl/usm.hpp>
99

10-
template <typename Tc, typename Ta, size_t M, size_t N>
11-
bool apply_verify(Tc *C, Tc *D, Ta *A, Ta *Ar) {
12-
for (size_t i = 0; i < M; i++)
13-
for (size_t j = 0; j < N; j++) {
14-
Tc diffc = D[i * N + j] - C[i * N + j] * 2;
15-
Ta diffa = Ar[i * N + j] - (A[i * N + j] + 42);
16-
if constexpr (std::is_same_v<Ta, bfloat16>) {
17-
if (std::fabs(diffc) > FLOAT_EPSILON ||
18-
std::fabs(diffa) > FLOAT_EPSILON || std::isnan(C[i * N + j]) ||
19-
std::isnan(A[i * N + j])) {
20-
return false;
21-
}
22-
} else {
23-
if (std::abs(diffc) > 0 || std::abs(diffa) > 0) {
24-
return false;
25-
}
26-
}
27-
}
28-
return true;
10+
template <typename T> T mul2(T x) { return x * 2; }
11+
12+
template <typename T> T add5(T x) { return x + 5; }
13+
14+
template <typename Tc, size_t M, size_t N>
15+
bool apply_verify(Tc *C, Tc *D, Tc *ref) {
16+
Tc *refcopy = (Tc *)std::malloc(M * N * sizeof(Tc));
17+
memcpy(refcopy, ref, M * N * sizeof(Tc));
18+
matrix_apply(M, N, ref, mul2<Tc>);
19+
bool res = matrix_compare(M, N, D, ref);
20+
21+
matrix_apply(M, N, refcopy, add5<Tc>);
22+
res &= matrix_compare(M, N, C, refcopy);
23+
return res;
2924
}
25+
3026
template <typename Tc, typename Ta, size_t TM, size_t TN, size_t TK, size_t M,
3127
size_t N, size_t K, class kernel_name>
32-
bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
28+
bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, Tc *Cref, Ta *Aref,
29+
queue q) {
3330
size_t NDRangeM = M / TM;
3431
size_t NDRangeN = N / TN;
3532

@@ -70,22 +67,33 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
7067
joint_matrix_load(
7168
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
7269
N, layout::row_major);
73-
joint_matrix_apply(sg, sub_c, sub_d,
74-
[](const Tc &x, Tc &y) { y = x * 2; });
70+
joint_matrix_apply(sg, sub_c, sub_d, [](Tc &x, Tc &y) {
71+
y = mul2(x);
72+
x = add5(x);
73+
});
7574
joint_matrix_store(
7675
sg, sub_d, pD + (sg_startx * TM) * N + sg_starty / sg_size * TN,
7776
N, layout::row_major);
77+
joint_matrix_store(
78+
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
79+
N, layout::row_major);
7880
joint_matrix_load(
7981
sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
8082
K);
81-
joint_matrix_apply(sg, sub_a, sub_ar,
82-
[](const Ta &x, Ta &y) { y = x + 42; });
83+
joint_matrix_apply(sg, sub_a, sub_ar, [](Ta &x, Ta &y) {
84+
y = mul2(x);
85+
x = add5(x);
86+
});
8387
ext::intel::experimental::matrix::joint_matrix_store(
8488
sg, sub_ar,
8589
pAr + (sg_startx * TM) * K + sg_starty / sg_size * TK, K);
90+
ext::intel::experimental::matrix::joint_matrix_store(
91+
sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
92+
K);
8693
}); // parallel for
8794
}).wait();
88-
return apply_verify<Tc, Ta, M, N>(C, D, A, Ar);
95+
return apply_verify<Tc, M, N>(C, D, Cref) &&
96+
apply_verify<Ta, M, N>(A, Ar, Aref);
8997
}
9098

9199
template <typename Ta, typename Tc, size_t TM, size_t TN, size_t TK,
@@ -96,16 +104,20 @@ bool test() {
96104
static constexpr size_t K = TK * 2;
97105
queue q;
98106

107+
Tc *Cref = malloc_shared<Tc>(M * N, q);
108+
Ta *Aref = malloc_shared<Ta>(M * K, q);
99109
Tc *C = malloc_shared<Tc>(M * N, q);
100110
Tc *D = malloc_shared<Tc>(M * N, q);
101111
Ta *A = malloc_shared<Ta>(M * K, q);
102112
Ta *Ar = malloc_shared<Ta>(M * K, q);
103113

104-
matrix_rand(M, N, (Tc *)C, (Tc)100);
105-
matrix_rand(M, K, (Ta *)A, (Ta)100);
114+
matrix_rand(M, N, (Tc *)Cref, (Tc)100);
115+
matrix_rand(M, K, (Ta *)Aref, (Ta)100);
116+
matrix_copy(M, N, Cref, C);
117+
matrix_copy(M, K, Aref, A);
106118

107119
bool res = apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, K, kernel_name>(
108-
C, D, A, Ar, q);
120+
C, D, A, Ar, Cref, Aref, q);
109121

110122
if constexpr (std::is_same_v<Ta, bfloat16>)
111123
std::cout << "bfloat16 " << TM << "x" << TN << "x" << TK << ": "
@@ -117,6 +129,8 @@ bool test() {
117129
free(D, q);
118130
free(A, q);
119131
free(Ar, q);
132+
free(Cref, q);
133+
free(Aref, q);
120134

121135
return res;
122136
}

0 commit comments

Comments
 (0)