Skip to content

Commit 42e0c62

Browse files
author
mmoadeli
committed
Update the call sites of joint_matrix_mad so that the variable holding the result is passed to the function as a parameter.
1 parent 3c460af commit 42e0c62

File tree

5 files changed: 21 additions and 22 deletions.

sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -265,7 +265,7 @@ template <typename Group,
265265
size_t M, size_t N, access::address_space Space,
266266
access::decorated IsDecorated>
267267
void store_layoutT(
268-
joint_matrix_hip<
268+
const joint_matrix_hip<
269269
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
270270
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
271271
multi_ptr<T, Space, IsDecorated> dst, size_t stride, Group &sg) {
@@ -333,7 +333,7 @@ void store_layoutT(
333333
template <typename Group, typename T, size_t M, size_t N,
334334
access::address_space Space, access::decorated IsDecorated>
335335
void joint_matrix_store_hip(
336-
joint_matrix_hip<
336+
const joint_matrix_hip<
337337
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
338338
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
339339
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
@@ -356,11 +356,11 @@ void joint_matrix_mad_hip(
356356
joint_matrix_hip<
357357
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
358358
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
359-
joint_matrix_hip<Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K,
360-
LayoutA> &A,
361-
joint_matrix_hip<Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N,
362-
LayoutB> &B,
363-
joint_matrix_hip<
359+
const joint_matrix_hip<Tm, sycl::ext::oneapi::experimental::matrix::use::a,
360+
M, K, LayoutA> &A,
361+
const joint_matrix_hip<Tm, sycl::ext::oneapi::experimental::matrix::use::b,
362+
K, N, LayoutB> &B,
363+
const joint_matrix_hip<
364364
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
365365
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
366366
if constexpr (std::is_same_v<Tm, sycl::half>) {
@@ -387,12 +387,12 @@ void joint_matrix_mad_hip(
387387
} else if constexpr (std::is_same_v<Tm, int8_t>) {
388388
if constexpr (M == 16 && N == 16) {
389389
D.data = __builtin_amdgcn_mfma_i32_16x16x16i8(
390-
*reinterpret_cast<int32_t *>(A.data),
391-
*reinterpret_cast<int32_t *>(B.data), C.data, 0, 0, 0);
390+
*reinterpret_cast<const Tc *>(A.data),
391+
*reinterpret_cast<const Tc *>(B.data), C.data, 0, 0, 0);
392392
} else if constexpr (M == 32 && N == 32) {
393393
D.data = __builtin_amdgcn_mfma_i32_32x32x8i8(
394-
*reinterpret_cast<int32_t *>(A.data),
395-
*reinterpret_cast<int32_t *>(B.data), C.data, 0, 0, 0);
394+
*reinterpret_cast<const Tc *>(A.data),
395+
*reinterpret_cast<const Tc *>(B.data), C.data, 0, 0, 0);
396396
}
397397
} else {
398398
static_assert(false && "Invalid configuration!");

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,6 @@ joint_matrix_fill(Group,
192192
#if defined(__NVPTX__)
193193
res.cuda_impl.wi_marray = v;
194194
#elif defined(__HIP_PLATFORM_AMD_MFMA__)
195-
std::ignore = sg;
196195
sycl::ext::oneapi::detail::joint_matrix_apply(res.hip_impl,
197196
[=](T) { return v; });
198197
#else
@@ -219,7 +218,7 @@ template <
219218
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
220219
true>
221220
inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
222-
Group,
221+
Group &sg,
223222
joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
224223
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
225224
multi_ptr<T, Space, IsDecorated> src, size_t stride,
@@ -228,6 +227,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
228227
static_assert(Space != access::address_space::private_space,
229228
"Joint Matrix doesn't support load from private memory!");
230229
#if defined(__NVPTX__)
230+
std::ignore = sg;
231231
sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride,
232232
Layout);
233233
#elif defined(__HIP_PLATFORM_AMD_MFMA__)
@@ -266,6 +266,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
266266
}
267267
#endif // defined(__NVPTX__)
268268
#else
269+
std::ignore = sg;
269270
std::ignore = res;
270271
std::ignore = src;
271272
std::ignore = stride;
@@ -284,13 +285,14 @@ template <
284285
std::is_same<std::remove_const_t<T>, float>::value),
285286
bool> = true>
286287
inline __SYCL_ALWAYS_INLINE void
287-
joint_matrix_load(Group,
288+
joint_matrix_load(Group &sg,
288289
joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &res,
289290
multi_ptr<T, Space, IsDecorated> src, size_t stride) {
290291
#if defined(__SYCL_DEVICE_ONLY__)
291292
static_assert(Space != access::address_space::private_space,
292293
"Joint Matrix doesn't support load from private memory!");
293294
#if defined(__NVPTX__)
295+
std::ignore = sg;
294296
sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols, Use,
295297
Layout, Space>(
296298
res.cuda_impl, src, stride);
@@ -320,7 +322,7 @@ joint_matrix_load(Group,
320322
template <typename Group, typename T, size_t NumRows, size_t NumCols,
321323
access::address_space Space, access::decorated IsDecorated>
322324
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
323-
Group,
325+
Group &sg,
324326
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
325327
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
326328
&src,
@@ -330,6 +332,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
330332
static_assert(Space != access::address_space::private_space,
331333
"Joint Matrix doesn't support store to private memory!");
332334
#if defined(__NVPTX__)
335+
std::ignore = sg;
333336
sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
334337
Space>(src.cuda_impl, dst,
335338
stride, Layout);
@@ -403,13 +406,9 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
403406
}
404407
#elif defined(__HIP_PLATFORM_AMD_MFMA__)
405408
if constexpr (std::is_same<Ta, Tb>::value) {
406-
joint_matrix<Group, Tc, use::accumulator, M, N,
407-
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
408-
D;
409409
sycl::ext::oneapi::detail::joint_matrix_mad_hip<Ta, Tc, M, K, N, LayoutA,
410410
LayoutB>(
411411
D.hip_impl, A.hip_impl, B.hip_impl, C.hip_impl);
412-
return D;
413412
} else {
414413
assert(false && "Ta != Tb : In the HIP backend joint_matrix_mad "
415414
"requires that joint_matrix data types Ta and Tb match");

sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ void hip_matrix_apply() {
7070
joint_matrix_apply(sg, sub_b, [=](InType v) { return v * 3; });
7171
joint_matrix_apply(sg, sub_c, [=](OutType v) { return v * 4; });
7272

73-
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
73+
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
7474

7575
joint_matrix_store(
7676
sg, sub_c,

sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ void hip_matrix_fill() {
5959
joint_matrix_fill(sg, sub_b, 2);
6060
joint_matrix_fill(sg, sub_c, 3);
6161

62-
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
62+
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
6363

6464
joint_matrix_store(
6565
sg, sub_c,

sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ void hip_matrix_mfma() {
8282
accC.template get_multi_ptr<access::decorated::yes>(), N,
8383
layout::row_major);
8484

85-
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
85+
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
8686

8787
joint_matrix_store(
8888
sg, sub_c,

0 commit comments

Comments
 (0)