Skip to content

Commit 8cae553

Browse files
authored
[SYCL][CUDA] joint_matrix required changes following #11215 (#11563)
As discussed in #11215, this patch:
- removed `mutable` from `joint_matrix_cuda` (this change requires an upstream LLVM patch: https://reviews.llvm.org/rGb781c7ab574f)
- removed `get_wi_data()`

I also added back the cases that the change in the `joint_matrix_mad` interface allows: namely, the cases where the types of the C and D matrices differ. I correspondingly updated the tests to cover the newly supported cases, and updated the CUDA support matrix in the spec doc for the newly supported combinations.
---------
Signed-off-by: JackAKirk <[email protected]>
1 parent ef70eeb commit 8cae553

File tree

10 files changed

+270
-394
lines changed

10 files changed

+270
-394
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.def

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2545,22 +2545,22 @@ TARGET_BUILTIN(__hmma_m16n16k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX60))
25452545
TARGET_BUILTIN(__hmma_m16n16k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX60))
25462546
TARGET_BUILTIN(__hmma_m16n16k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX60))
25472547
TARGET_BUILTIN(__hmma_m16n16k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX60))
2548-
TARGET_BUILTIN(__hmma_m16n16k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX60))
2549-
TARGET_BUILTIN(__hmma_m16n16k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX60))
2548+
TARGET_BUILTIN(__hmma_m16n16k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX60))
2549+
TARGET_BUILTIN(__hmma_m16n16k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX60))
25502550

25512551
TARGET_BUILTIN(__hmma_m32n8k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX61))
25522552
TARGET_BUILTIN(__hmma_m32n8k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX61))
25532553
TARGET_BUILTIN(__hmma_m32n8k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61))
25542554
TARGET_BUILTIN(__hmma_m32n8k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61))
2555-
TARGET_BUILTIN(__hmma_m32n8k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX61))
2556-
TARGET_BUILTIN(__hmma_m32n8k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX61))
2555+
TARGET_BUILTIN(__hmma_m32n8k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61))
2556+
TARGET_BUILTIN(__hmma_m32n8k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61))
25572557

25582558
TARGET_BUILTIN(__hmma_m8n32k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX61))
25592559
TARGET_BUILTIN(__hmma_m8n32k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX61))
25602560
TARGET_BUILTIN(__hmma_m8n32k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61))
25612561
TARGET_BUILTIN(__hmma_m8n32k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61))
2562-
TARGET_BUILTIN(__hmma_m8n32k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX61))
2563-
TARGET_BUILTIN(__hmma_m8n32k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX61))
2562+
TARGET_BUILTIN(__hmma_m8n32k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61))
2563+
TARGET_BUILTIN(__hmma_m8n32k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61))
25642564

25652565
TARGET_BUILTIN(__hmma_m16n16k16_mma_f16f16, "vi*iC*iC*iC*IiIi", "", AND(SM_70,PTX60))
25662566
TARGET_BUILTIN(__hmma_m16n16k16_mma_f32f16, "vf*iC*iC*iC*IiIi", "", AND(SM_70,PTX60))

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,6 @@ The complete set of matrix data types and shapes that are supported by
918918
the `ext_oneapi_cuda` backend are represented in the following
919919
table. In this architecture's implementation,
920920
the type of the A matrix must be the same as the type of the B
921-
matrix. Also, the type of the C matrix must be the same as the type of the D
922921
matrix.
923922

924923
IMPORTANT: When compiling for the `ext_oneapi_cuda` backend the target
@@ -933,29 +932,37 @@ supported parameter combination is specified in the following table.
933932

934933
[frame="none",options="header"]
935934
|======================
936-
| A and B type | C and D type | M | N | K | Minimum Compute Capability
937-
.3+| `matrix_type::fp16` .3+| `matrix_type::fp32`
938-
|16 |16 |16 .6+| sm_70
935+
| A and B type | C type | D type | M | N | K | Minimum Compute Capability
936+
.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp32`
937+
|16 |16 |16 .12+| sm_70
939938
|8 |32 |16
940939
|32 |8 |16
941-
.3+| `matrix_type::fp16` .3+| `matrix_type::fp16`
940+
.3+| `matrix_type::fp16` .3+| `matrix_type::fp16` .3+| `matrix_type::fp16`
942941
|16 |16 |16
943942
|8 |32 |16
944943
|32 |8 |16
945-
.3+| `matrix_type::sint8` .3+| `matrix_type::sint32`
944+
.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp16`
945+
|16 |16 |16
946+
|8 |32 |16
947+
|32 |8 |16
948+
.3+| `matrix_type::fp16` .3+| `matrix_type::fp16` .3+| `matrix_type::fp32`
949+
|16 |16 |16
950+
|8 |32 |16
951+
|32 |8 |16
952+
.3+| `matrix_type::sint8` .3+| `matrix_type::sint32` .3+| `matrix_type::sint32`
946953
|16 |16 |16 .6+| sm_72
947954
|8 |32 |16
948955
|32 |8 |16
949-
.3+|`matrix_type::uint8` .3+|`matrix_type::sint32`
956+
.3+|`matrix_type::uint8` .3+|`matrix_type::sint32` .3+|`matrix_type::sint32`
950957
|16 |16 |16
951958
|8 |32 |16
952959
|32 |8 |16
953-
| `matrix_type::tf32` | `matrix_type::fp32` |16 |16 |8 .5+| sm_80
954-
.3+|`matrix_type::bf16` .3+| `matrix_type::fp32`
960+
| `matrix_type::tf32` | `matrix_type::fp32` | `matrix_type::fp32` |16 |16 |8 .5+| sm_80
961+
.3+|`matrix_type::bf16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp32`
955962
|16 |16 |16
956963
|8 |32 |16
957964
|32 |8 |16
958-
| `matrix_type::fp64` | `matrix_type::fp64` |8 |8 |4
965+
| `matrix_type::fp64` | `matrix_type::fp64` | `matrix_type::fp64` |8 |8 |4
959966
|======================
960967

961968
IMPORTANT: The `stride` argument to `joint_matrix_load` and

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp

Lines changed: 105 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -357,63 +357,59 @@ template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename T,
357357
size_t NumRows, size_t NumCols, access::address_space Space,
358358
access::decorated IsDecorated>
359359
void store_layoutT(
360-
joint_matrix_cuda<
360+
const joint_matrix_cuda<
361361
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
362362
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
363363
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
364364
if constexpr (NumRows == 16 && NumCols == 16) {
365365
if constexpr (std::is_same_v<T, float>) {
366-
__hmma_m16n16k16_st_c_f32(dst.get(),
367-
reinterpret_cast<float *>(&src.wi_marray),
368-
stride, get_layout_id<Layout>());
366+
__hmma_m16n16k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
367+
get_layout_id<Layout>());
369368
} else if constexpr (std::is_same_v<T, int32_t>) {
370-
__imma_m16n16k16_st_c_i32(dst.get(),
371-
reinterpret_cast<int32_t *>(&src.wi_marray),
372-
stride, get_layout_id<Layout>());
369+
__imma_m16n16k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
370+
get_layout_id<Layout>());
373371
} else if constexpr (std::is_same_v<T, half>) {
374-
__hmma_m16n16k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
375-
reinterpret_cast<int32_t *>(&src.wi_marray),
376-
stride, get_layout_id<Layout>());
372+
__hmma_m16n16k16_st_c_f16(
373+
reinterpret_cast<int32_t *>(dst.get()),
374+
reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
375+
get_layout_id<Layout>());
377376
}
378377
} else if constexpr (NumRows == 8 && NumCols == 32) {
379378
if constexpr (std::is_same_v<T, float>) {
380-
__hmma_m8n32k16_st_c_f32(dst.get(),
381-
reinterpret_cast<float *>(&src.wi_marray),
382-
stride, get_layout_id<Layout>());
379+
__hmma_m8n32k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
380+
get_layout_id<Layout>());
383381
} else if constexpr (std::is_same_v<T, int32_t>) {
384-
__imma_m8n32k16_st_c_i32(dst.get(),
385-
reinterpret_cast<int32_t *>(&src.wi_marray),
386-
stride, get_layout_id<Layout>());
382+
__imma_m8n32k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
383+
get_layout_id<Layout>());
387384
} else if constexpr (std::is_same_v<T, half>) {
388-
__hmma_m8n32k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
389-
reinterpret_cast<int32_t *>(&src.wi_marray),
390-
stride, get_layout_id<Layout>());
385+
__hmma_m8n32k16_st_c_f16(
386+
reinterpret_cast<int32_t *>(dst.get()),
387+
reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
388+
get_layout_id<Layout>());
391389
}
392390
} else if constexpr (NumRows == 32 && NumCols == 8) {
393391
if constexpr (std::is_same_v<T, float>) {
394-
__hmma_m32n8k16_st_c_f32(dst.get(),
395-
reinterpret_cast<float *>(&src.wi_marray),
396-
stride, get_layout_id<Layout>());
392+
__hmma_m32n8k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
393+
get_layout_id<Layout>());
397394
} else if constexpr (std::is_same_v<T, int32_t>) {
398-
__imma_m32n8k16_st_c_i32(dst.get(),
399-
reinterpret_cast<int32_t *>(&src.wi_marray),
400-
stride, get_layout_id<Layout>());
395+
__imma_m32n8k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
396+
get_layout_id<Layout>());
401397
} else if constexpr (std::is_same_v<T, half>) {
402-
__hmma_m32n8k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
403-
reinterpret_cast<int32_t *>(&src.wi_marray),
404-
stride, get_layout_id<Layout>());
398+
__hmma_m32n8k16_st_c_f16(
399+
reinterpret_cast<int32_t *>(dst.get()),
400+
reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
401+
get_layout_id<Layout>());
405402
}
406403
} else if constexpr (std::is_same_v<T, double>) {
407-
__dmma_m8n8k4_st_c_f64(dst.get(),
408-
reinterpret_cast<double *>(&src.wi_marray), stride,
404+
__dmma_m8n8k4_st_c_f64(dst.get(), &src.wi_marray[0], stride,
409405
get_layout_id<Layout>());
410406
}
411407
}
412408

413409
template <typename T, size_t NumRows, size_t NumCols,
414410
access::address_space Space, access::decorated IsDecorated>
415411
void joint_matrix_store_cuda(
416-
joint_matrix_cuda<
412+
const joint_matrix_cuda<
417413
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
418414
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
419415
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
@@ -465,8 +461,8 @@ constexpr int get_layout_pair_id<
465461
}
466462

467463
template <
468-
typename Tm, typename Tc, std::size_t M, std::size_t K, std::size_t N,
469-
sycl::ext::oneapi::experimental::matrix::layout LayoutA,
464+
typename Tm, typename Tc, typename Td, std::size_t M, std::size_t K,
465+
std::size_t N, sycl::ext::oneapi::experimental::matrix::layout LayoutA,
470466
sycl::ext::oneapi::experimental::matrix::layout LayoutB,
471467
std::enable_if_t<
472468
(LayoutA ==
@@ -480,13 +476,13 @@ template <
480476
bool> = true>
481477
void joint_matrix_mad_cuda(
482478
joint_matrix_cuda<
483-
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
479+
Td, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
484480
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
485-
joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K,
486-
LayoutA> &A,
487-
joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N,
488-
LayoutB> &B,
489-
joint_matrix_cuda<
481+
const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a,
482+
M, K, LayoutA> &A,
483+
const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b,
484+
K, N, LayoutB> &B,
485+
const joint_matrix_cuda<
490486
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
491487
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
492488
if constexpr (M == 16 && N == 16 && K == 16) {
@@ -506,16 +502,29 @@ void joint_matrix_mad_cuda(
506502
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
507503
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
508504
if constexpr (std::is_same_v<Tc, float>) {
509-
__hmma_m16n16k16_mma_f32f32(
510-
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
511-
reinterpret_cast<const float *>(&C.wi_marray),
512-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
513-
505+
if constexpr (std::is_same<Td, float>::value) {
506+
__hmma_m16n16k16_mma_f32f32(
507+
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
508+
reinterpret_cast<const float *>(&C.wi_marray),
509+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
510+
} else {
511+
__hmma_m16n16k16_mma_f16f32(
512+
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
513+
reinterpret_cast<const float *>(&C.wi_marray),
514+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
515+
}
514516
} else if constexpr (std::is_same_v<Tc, half>) {
515-
__hmma_m16n16k16_mma_f16f16(
516-
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
517-
reinterpret_cast<const int32_t *>(&C.wi_marray),
518-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
517+
if constexpr (std::is_same<Td, float>::value) {
518+
__hmma_m16n16k16_mma_f32f16(
519+
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
520+
reinterpret_cast<const int32_t *>(&C.wi_marray),
521+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
522+
} else {
523+
__hmma_m16n16k16_mma_f16f16(
524+
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
525+
reinterpret_cast<const int32_t *>(&C.wi_marray),
526+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
527+
}
519528
}
520529
} else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
521530
__mma_bf16_m16n16k16_mma_f32(
@@ -542,15 +551,29 @@ void joint_matrix_mad_cuda(
542551
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
543552
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
544553
if constexpr (std::is_same_v<Tc, float>) {
545-
__hmma_m8n32k16_mma_f32f32(
546-
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
547-
reinterpret_cast<const float *>(&C.wi_marray),
548-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
554+
if constexpr (std::is_same<Td, float>::value) {
555+
__hmma_m8n32k16_mma_f32f32(
556+
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
557+
reinterpret_cast<const float *>(&C.wi_marray),
558+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
559+
} else {
560+
__hmma_m8n32k16_mma_f16f32(
561+
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
562+
reinterpret_cast<const float *>(&C.wi_marray),
563+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
564+
}
549565
} else if constexpr (std::is_same_v<Tc, half>) {
550-
__hmma_m8n32k16_mma_f16f16(
551-
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
552-
reinterpret_cast<const int32_t *>(&C.wi_marray),
553-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
566+
if constexpr (std::is_same<Td, float>::value) {
567+
__hmma_m8n32k16_mma_f32f16(
568+
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
569+
reinterpret_cast<const int32_t *>(&C.wi_marray),
570+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
571+
} else {
572+
__hmma_m8n32k16_mma_f16f16(
573+
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
574+
reinterpret_cast<const int32_t *>(&C.wi_marray),
575+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
576+
}
554577
}
555578
} else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
556579
__mma_bf16_m8n32k16_mma_f32(
@@ -581,25 +604,40 @@ void joint_matrix_mad_cuda(
581604
reinterpret_cast<const float *>(&C.wi_marray),
582605
get_layout_pair_id<LayoutA, LayoutB>(), 0);
583606
} else if constexpr (std::is_same_v<Tm, half>) {
607+
584608
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
585609
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
586610
if constexpr (std::is_same_v<Tc, float>) {
587-
__hmma_m32n8k16_mma_f32f32(
588-
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
589-
reinterpret_cast<const float *>(&C.wi_marray),
590-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
611+
if constexpr (std::is_same<Td, float>::value) {
612+
__hmma_m32n8k16_mma_f32f32(
613+
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
614+
reinterpret_cast<const float *>(&C.wi_marray),
615+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
616+
} else {
617+
__hmma_m32n8k16_mma_f16f32(
618+
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
619+
reinterpret_cast<const float *>(&C.wi_marray),
620+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
621+
}
591622
} else if constexpr (std::is_same_v<Tc, half>) {
592-
__hmma_m32n8k16_mma_f16f16(
593-
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
594-
reinterpret_cast<const int32_t *>(&C.wi_marray),
595-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
623+
if constexpr (std::is_same<Td, float>::value) {
624+
__hmma_m32n8k16_mma_f32f16(
625+
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
626+
reinterpret_cast<const int32_t *>(&C.wi_marray),
627+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
628+
} else {
629+
__hmma_m32n8k16_mma_f16f16(
630+
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
631+
reinterpret_cast<const int32_t *>(&C.wi_marray),
632+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
633+
}
596634
}
597635
}
598636
} else if constexpr (M == 16 && N == 16 && K == 8) {
599637
__mma_tf32_m16n16k8_mma_f32(reinterpret_cast<float *>(&D.wi_marray),
600-
reinterpret_cast<int32_t *>(&A.wi_marray),
601-
reinterpret_cast<int32_t *>(&B.wi_marray),
602-
reinterpret_cast<float *>(&C.wi_marray),
638+
reinterpret_cast<const int32_t *>(&A.wi_marray),
639+
reinterpret_cast<const int32_t *>(&B.wi_marray),
640+
reinterpret_cast<const float *>(&C.wi_marray),
603641
get_layout_pair_id<LayoutA, LayoutB>(), 0);
604642
} else if constexpr (std::is_same_v<Tm, double>) {
605643
__dmma_m8n8k4_mma_f64(reinterpret_cast<double *>(&D.wi_marray),

0 commit comments

Comments
 (0)