[SYCL][CUDA] joint_matrix required changes following #11215 #11563

Merged
12 changes: 6 additions & 6 deletions clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -2545,22 +2545,22 @@ TARGET_BUILTIN(__hmma_m16n16k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX60))
TARGET_BUILTIN(__hmma_m16n16k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX60))
TARGET_BUILTIN(__hmma_m16n16k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX60))
TARGET_BUILTIN(__hmma_m16n16k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX60))
TARGET_BUILTIN(__hmma_m16n16k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX60))
TARGET_BUILTIN(__hmma_m16n16k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX60))
TARGET_BUILTIN(__hmma_m16n16k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX60))
TARGET_BUILTIN(__hmma_m16n16k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX60))

TARGET_BUILTIN(__hmma_m32n8k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m32n8k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m32n8k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m32n8k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m32n8k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m32n8k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m32n8k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m32n8k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61))

TARGET_BUILTIN(__hmma_m8n32k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m8n32k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m8n32k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m8n32k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m8n32k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m8n32k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m8n32k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61))
TARGET_BUILTIN(__hmma_m8n32k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61))

TARGET_BUILTIN(__hmma_m16n16k16_mma_f16f16, "vi*iC*iC*iC*IiIi", "", AND(SM_70,PTX60))
TARGET_BUILTIN(__hmma_m16n16k16_mma_f32f16, "vf*iC*iC*iC*IiIi", "", AND(SM_70,PTX60))
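
The only change in each `st_c` entry above is the added `C` in the type string, which marks the source-fragment pointer as const to match the now-const `src` matrix in `store_layoutT` below. As a rough decoding of one updated string (a sketch based on Clang's builtin type-string conventions, not a prototype copied from any header):

```cpp
/* "vi*iC*UiIi" decodes as:
 *   v    -> void return type
 *   i*   -> int *           destination fragment pointer
 *   iC*  -> const int *     source fragment pointer (newly const)
 *   Ui   -> unsigned int    stride / leading dimension
 *   Ii   -> int, required to be a constant expression (layout id)
 */
void __hmma_m16n16k16_st_c_f16(int *dst, const int *src,
                               unsigned int stride, int layout);
```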
@@ -918,7 +918,6 @@ The complete set of matrix data types and shapes that are supported by
the `ext_oneapi_cuda` backend are represented in the following
table. In this architecture's implementation,
the type of the A matrix must be the same as the type of the B
matrix. Also, the type of the C matrix must be the same as the type of the D
matrix.

IMPORTANT: When compiling for the `ext_oneapi_cuda` backend the target
@@ -933,29 +932,37 @@ supported parameter combination is specified in the following table.

[frame="none",options="header"]
|======================
| A and B type | C and D type | M | N | K | Minimum Compute Capability
.3+| `matrix_type::fp16` .3+| `matrix_type::fp32`
|16 |16 |16 .6+| sm_70
| A and B type | C type | D type | M | N | K | Minimum Compute Capability
.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp32`
|16 |16 |16 .12+| sm_70
|8 |32 |16
|32 |8 |16
.3+| `matrix_type::fp16` .3+| `matrix_type::fp16`
.3+| `matrix_type::fp16` .3+| `matrix_type::fp16` .3+| `matrix_type::fp16`
|16 |16 |16
|8 |32 |16
|32 |8 |16
.3+| `matrix_type::sint8` .3+| `matrix_type::sint32`
.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp16`
|16 |16 |16
|8 |32 |16
|32 |8 |16
.3+| `matrix_type::fp16` .3+| `matrix_type::fp16` .3+| `matrix_type::fp32`
Contributor: BTW, where are the bfloat16 combinations?

Contributor Author: They are in the same table at the bottom, between tf32 and fp64.

|16 |16 |16
|8 |32 |16
|32 |8 |16
.3+| `matrix_type::sint8` .3+| `matrix_type::sint32` .3+| `matrix_type::sint32`
|16 |16 |16 .6+| sm_72
|8 |32 |16
|32 |8 |16
.3+|`matrix_type::uint8` .3+|`matrix_type::sint32`
.3+|`matrix_type::uint8` .3+|`matrix_type::sint32` .3+|`matrix_type::sint32`
|16 |16 |16
|8 |32 |16
|32 |8 |16
| `matrix_type::tf32` | `matrix_type::fp32` |16 |16 |8 .5+| sm_80
.3+|`matrix_type::bf16` .3+| `matrix_type::fp32`
| `matrix_type::tf32` | `matrix_type::fp32` | `matrix_type::fp32` |16 |16 |8 .5+| sm_80
.3+|`matrix_type::bf16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp32`
|16 |16 |16
|8 |32 |16
|32 |8 |16
| `matrix_type::fp64` | `matrix_type::fp64` |8 |8 |4
| `matrix_type::fp64` | `matrix_type::fp64` | `matrix_type::fp64` |8 |8 |4
|======================
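
To illustrate what the separate D-type column permits: with the `joint_matrix_mad` interface from #11215, the result matrix `D` is a distinct argument, so its element type may differ from that of the accumulator input `C` (for example the new sm_70 row with fp32 C and fp16 D). The following is a minimal sketch only; it assumes a 16x16x16 tile, a sub-group size of 32 on the CUDA backend, and the extension's free-function names as used elsewhere in this PR, and is not taken from the extension's own samples.

```cpp
#include <sycl/sycl.hpp>
using namespace sycl;
namespace jm = ext::oneapi::experimental::matrix;

int main() {
  constexpr size_t M = 16, N = 16, K = 16;
  queue q{gpu_selector_v};
  auto *A = malloc_shared<half>(M * K, q);   // fp16 inputs
  auto *B = malloc_shared<half>(K * N, q);
  auto *C = malloc_shared<float>(M * N, q);  // fp32 accumulator input
  auto *D = malloc_shared<half>(M * N, q);   // fp16 result
  // ... initialize A, B and C here ...

  // One sub-group (32 work-items on the CUDA backend) handles the whole tile.
  q.parallel_for(nd_range<1>{32, 32}, [=](nd_item<1> it) {
     sub_group sg = it.get_sub_group();
     jm::joint_matrix<sub_group, half, jm::use::a, M, K, jm::layout::row_major> mA;
     jm::joint_matrix<sub_group, half, jm::use::b, K, N, jm::layout::row_major> mB;
     jm::joint_matrix<sub_group, float, jm::use::accumulator, M, N> mC; // C type
     jm::joint_matrix<sub_group, half, jm::use::accumulator, M, N> mD;  // D type

     auto gA = address_space_cast<access::address_space::global_space,
                                  access::decorated::no>(A);
     auto gB = address_space_cast<access::address_space::global_space,
                                  access::decorated::no>(B);
     auto gC = address_space_cast<access::address_space::global_space,
                                  access::decorated::no>(C);
     auto gD = address_space_cast<access::address_space::global_space,
                                  access::decorated::no>(D);

     jm::joint_matrix_load(sg, mA, gA, K);
     jm::joint_matrix_load(sg, mB, gB, N);
     jm::joint_matrix_load(sg, mC, gC, N, jm::layout::row_major);
     jm::joint_matrix_mad(sg, mD, mA, mB, mC); // D = A * B + C with Td != Tc
     jm::joint_matrix_store(sg, mD, gD, N, jm::layout::row_major);
   }).wait();

  free(A, q); free(B, q); free(C, q); free(D, q);
}
```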

IMPORTANT: The `stride` argument to `joint_matrix_load` and
172 changes: 105 additions & 67 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp
@@ -357,63 +357,59 @@ template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename T,
size_t NumRows, size_t NumCols, access::address_space Space,
access::decorated IsDecorated>
void store_layoutT(
joint_matrix_cuda<
const joint_matrix_cuda<
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (std::is_same_v<T, float>) {
__hmma_m16n16k16_st_c_f32(dst.get(),
reinterpret_cast<float *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__hmma_m16n16k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same_v<T, int32_t>) {
__imma_m16n16k16_st_c_i32(dst.get(),
reinterpret_cast<int32_t *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__imma_m16n16k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same_v<T, half>) {
__hmma_m16n16k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
reinterpret_cast<int32_t *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__hmma_m16n16k16_st_c_f16(
reinterpret_cast<int32_t *>(dst.get()),
reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
get_layout_id<Layout>());
}
} else if constexpr (NumRows == 8 && NumCols == 32) {
if constexpr (std::is_same_v<T, float>) {
__hmma_m8n32k16_st_c_f32(dst.get(),
reinterpret_cast<float *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__hmma_m8n32k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same_v<T, int32_t>) {
__imma_m8n32k16_st_c_i32(dst.get(),
reinterpret_cast<int32_t *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__imma_m8n32k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same_v<T, half>) {
__hmma_m8n32k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
reinterpret_cast<int32_t *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__hmma_m8n32k16_st_c_f16(
reinterpret_cast<int32_t *>(dst.get()),
reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
get_layout_id<Layout>());
}
} else if constexpr (NumRows == 32 && NumCols == 8) {
if constexpr (std::is_same_v<T, float>) {
__hmma_m32n8k16_st_c_f32(dst.get(),
reinterpret_cast<float *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__hmma_m32n8k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same_v<T, int32_t>) {
__imma_m32n8k16_st_c_i32(dst.get(),
reinterpret_cast<int32_t *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__imma_m32n8k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same_v<T, half>) {
__hmma_m32n8k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
reinterpret_cast<int32_t *>(&src.wi_marray),
stride, get_layout_id<Layout>());
__hmma_m32n8k16_st_c_f16(
reinterpret_cast<int32_t *>(dst.get()),
reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same_v<T, double>) {
__dmma_m8n8k4_st_c_f64(dst.get(),
reinterpret_cast<double *>(&src.wi_marray), stride,
__dmma_m8n8k4_st_c_f64(dst.get(), &src.wi_marray[0], stride,
get_layout_id<Layout>());
}
}

template <typename T, size_t NumRows, size_t NumCols,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_store_cuda(
joint_matrix_cuda<
const joint_matrix_cuda<
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
@@ -465,8 +461,8 @@ constexpr int get_layout_pair_id<
}

template <
typename Tm, typename Tc, std::size_t M, std::size_t K, std::size_t N,
sycl::ext::oneapi::experimental::matrix::layout LayoutA,
typename Tm, typename Tc, typename Td, std::size_t M, std::size_t K,
std::size_t N, sycl::ext::oneapi::experimental::matrix::layout LayoutA,
sycl::ext::oneapi::experimental::matrix::layout LayoutB,
std::enable_if_t<
(LayoutA ==
@@ -480,13 +476,13 @@ template <
bool> = true>
void joint_matrix_mad_cuda(
joint_matrix_cuda<
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
Td, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K,
LayoutA> &A,
joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N,
LayoutB> &B,
joint_matrix_cuda<
const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a,
M, K, LayoutA> &A,
const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b,
K, N, LayoutB> &B,
const joint_matrix_cuda<
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
if constexpr (M == 16 && N == 16 && K == 16) {
@@ -506,16 +502,29 @@ void joint_matrix_mad_cuda(
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same_v<Tc, float>) {
__hmma_m16n16k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);

if constexpr (std::is_same<Td, float>::value) {
__hmma_m16n16k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else {
__hmma_m16n16k16_mma_f16f32(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same_v<Tc, half>) {
__hmma_m16n16k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
if constexpr (std::is_same<Td, float>::value) {
__hmma_m16n16k16_mma_f32f16(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else {
__hmma_m16n16k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
} else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
__mma_bf16_m16n16k16_mma_f32(
@@ -542,15 +551,29 @@ void joint_matrix_mad_cuda(
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same_v<Tc, float>) {
__hmma_m8n32k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
if constexpr (std::is_same<Td, float>::value) {
__hmma_m8n32k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else {
__hmma_m8n32k16_mma_f16f32(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same_v<Tc, half>) {
__hmma_m8n32k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
if constexpr (std::is_same<Td, float>::value) {
__hmma_m8n32k16_mma_f32f16(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else {
__hmma_m8n32k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
} else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
__mma_bf16_m8n32k16_mma_f32(
@@ -581,25 +604,40 @@ void joint_matrix_mad_cuda(
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same_v<Tm, half>) {

auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same_v<Tc, float>) {
__hmma_m32n8k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
if constexpr (std::is_same<Td, float>::value) {
__hmma_m32n8k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else {
__hmma_m32n8k16_mma_f16f32(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same_v<Tc, half>) {
__hmma_m32n8k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
if constexpr (std::is_same<Td, float>::value) {
__hmma_m32n8k16_mma_f32f16(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else {
__hmma_m32n8k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
}
} else if constexpr (M == 16 && N == 16 && K == 8) {
__mma_tf32_m16n16k8_mma_f32(reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t *>(&A.wi_marray),
reinterpret_cast<int32_t *>(&B.wi_marray),
reinterpret_cast<float *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&A.wi_marray),
reinterpret_cast<const int32_t *>(&B.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same_v<Tm, double>) {
__dmma_m8n8k4_mma_f64(reinterpret_cast<double *>(&D.wi_marray),