Skip to content

[SYCL][CUDA] Allow joint_matrix to be loaded from const T #6532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 55 additions & 53 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,11 @@ struct joint_matrix_load_impl<
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
multi_ptr<T, Space> src, size_t stride) {
if constexpr (std::is_same<T, uint16_t>::value ||
if constexpr (std::is_same<std::remove_const_t<T>, uint16_t>::value ||
std::is_same<
T, sycl::ext::oneapi::experimental::bfloat16>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
std::remove_const_t<T>,
sycl::ext::oneapi::experimental::bfloat16>::value) {
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (Use ==
Expand All @@ -246,8 +247,8 @@ struct joint_matrix_load_impl<
__mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, uint8_t>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
} else if constexpr (std::is_same<std::remove_const_t<T>, uint8_t>::value) {
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (Use ==
Expand All @@ -272,8 +273,8 @@ struct joint_matrix_load_impl<
__imma_m32n8k16_ld_b_u8(destptr, tileptr, stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, int8_t>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
} else if constexpr (std::is_same<std::remove_const_t<T>, int8_t>::value) {
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (Use ==
Expand All @@ -298,8 +299,8 @@ struct joint_matrix_load_impl<
__imma_m32n8k16_ld_b_s8(destptr, tileptr, stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, half>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
} else if constexpr (std::is_same<std::remove_const_t<T>, half>::value) {
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (Use ==
Expand Down Expand Up @@ -331,7 +332,7 @@ struct joint_matrix_load_impl<
get_layout_id<Layout>());
}

} else if constexpr (std::is_same<T, int32_t>::value) {
} else if constexpr (std::is_same<std::remove_const_t<T>, int32_t>::value) {
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
__imma_m16n16k16_ld_c(destptr, src.get(), stride,
Expand All @@ -343,7 +344,7 @@ struct joint_matrix_load_impl<
__imma_m32n8k16_ld_c(destptr, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, float>::value) {
} else if constexpr (std::is_same<std::remove_const_t<T>, float>::value) {
if constexpr (std::is_same<S, float>::value) {
auto dstptr = reinterpret_cast<float *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
Expand All @@ -359,7 +360,7 @@ struct joint_matrix_load_impl<
} else if constexpr (std::is_same<S,
sycl::ext::oneapi::experimental::
matrix::precision::tf32>::value) {
auto tileptr = reinterpret_cast<int32_t *>(src.get());
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 8) {
__mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride,
Expand All @@ -369,7 +370,7 @@ struct joint_matrix_load_impl<
get_layout_id<Layout>());
}
}
} else if constexpr (std::is_same<T, double>::value) {
} else if constexpr (std::is_same<std::remove_const_t<T>, double>::value) {
auto dstptr = reinterpret_cast<double *>(&res.wi_marray);
if constexpr (Use ==
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
Expand Down Expand Up @@ -559,9 +560,9 @@ struct joint_matrix_mad_impl<
D;
if constexpr (M == 16 && N == 16 && K == 16) {
if constexpr (std::is_same<T2, int32_t>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
if constexpr (std::is_same<T1, int8_t>::value) {
__imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
Expand All @@ -571,34 +572,34 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, half>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same<T2, float>::value) {
__hmma_m16n16k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T2, half>::value) {
__hmma_m16n16k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<int32_t const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, uint16_t>::value ||
std::is_same<T1, sycl::ext::oneapi::experimental::
bfloat16>::value) {
__mma_bf16_m16n16k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
reinterpret_cast<int32_t const *>(&B.wi_marray),
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&A.wi_marray),
reinterpret_cast<const int32_t *>(&B.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (M == 8 && N == 32 && K == 16) {
if constexpr (std::is_same<T2, int32_t>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
if constexpr (std::is_same<T1, int8_t>::value) {
__imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
Expand All @@ -608,34 +609,34 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, half>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same<T2, float>::value) {
__hmma_m8n32k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T2, half>::value) {
__hmma_m8n32k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<int32_t const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, uint16_t>::value ||
std::is_same<T1, sycl::ext::oneapi::experimental::
bfloat16>::value) {
__mma_bf16_m8n32k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
reinterpret_cast<int32_t const *>(&B.wi_marray),
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&A.wi_marray),
reinterpret_cast<const int32_t *>(&B.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (M == 32 && N == 8 && K == 16) {
if constexpr (std::is_same<T2, int32_t>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
if constexpr (std::is_same<T1, int8_t>::value) {
__imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
Expand All @@ -649,22 +650,22 @@ struct joint_matrix_mad_impl<
bfloat16>::value) {
__mma_bf16_m32n8k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
reinterpret_cast<int32_t const *>(&B.wi_marray),
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&A.wi_marray),
reinterpret_cast<const int32_t *>(&B.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T1, half>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same<T2, float>::value) {
__hmma_m32n8k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T2, half>::value) {
__hmma_m32n8k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<int32_t const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
Expand All @@ -676,9 +677,9 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T1, double>::value) {
__dmma_m8n8k4_mma_f64(reinterpret_cast<double *>(&D.wi_marray),
reinterpret_cast<double const *>(&A.wi_marray),
reinterpret_cast<double const *>(&B.wi_marray),
reinterpret_cast<double const *>(&C.wi_marray),
reinterpret_cast<const double *>(&A.wi_marray),
reinterpret_cast<const double *>(&B.wi_marray),
reinterpret_cast<const double *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
return D;
Expand All @@ -691,13 +692,14 @@ struct joint_matrix_mad_impl<
namespace experimental {
namespace matrix {

template <typename Group, typename S, typename T, matrix_use Use,
size_t NumRows, size_t NumCols, matrix_layout Layout,
access::address_space Space,
std::enable_if_t<std::is_same<S, T>::value ||
(std::is_same<S, precision::tf32>::value &&
std::is_same<T, float>::value),
bool> = true>
template <
typename Group, typename S, typename T, matrix_use Use, size_t NumRows,
size_t NumCols, matrix_layout Layout, access::address_space Space,
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
(std::is_same<S, precision::tf32>::value &&

std::is_same<std::remove_const_t<T>, float>::value),
bool> = true>
void joint_matrix_load(
Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
multi_ptr<T, Space> src, size_t stride) {
Expand Down