Skip to content

Commit 50cfdf5

Browse files
authored
[MLIR][ROCDL] Add Scale Convert Packed (B)FP8 <-> (B)F16 Support for GFX950 (#130300)
Add Rocdl support for the following GFX950 instructions: CVT_SCALE_PK_FP8_F16 CVT_SCALE_PK_BF8_F16 CVT_SCALE_PK_FP8_BF16 CVT_SCALE_PK_BF8_BF16 CVT_SCALE_SR_FP8_F16 CVT_SCALE_SR_BF8_F16 CVT_SCALE_SR_FP8_BF16 CVT_SCALE_SR_BF8_BF16 CVT_SCALE_PK_F16_FP8 CVT_SCALE_PK_F16_BF8 CVT_SCALE_F16_FP8 CVT_SCALE_F16_BF8
1 parent 8a43bc2 commit 50cfdf5

File tree

3 files changed

+235
-9
lines changed

3 files changed

+235
-9
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 174 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,20 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
652652
}];
653653
}
654654

655+
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
656+
BuildableType<"::mlir::VectorType::get("
657+
"{2},$_builder.getI16Type())">;
658+
659+
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
660+
BuildableType<"::mlir::VectorType::get("
661+
"{2},$_builder.getF16Type())">;
662+
663+
def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>,
664+
BuildableType<"::mlir::VectorType::get("
665+
"{2},$_builder.getBF16Type())">;
666+
667+
// TODO: The word and byte selectors are immarg in LLVM
668+
// update to be attributes in MLIR
655669
//===---------------------------------------------------------------------===//
656670
// 16-bit float intrinsics
657671
//===---------------------------------------------------------------------===//
@@ -667,10 +681,168 @@ def ROCDL_CvtPkRtz:
667681
}];
668682
}
669683

684+
def ROCDL_CvtScaleF32PkFp8F16 :
685+
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
686+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
687+
let summary = "Scale and convert f16 to packed fp8";
688+
let description = [{
689+
Scale `src` by the exponent in `scale` then convert to packed fp8.
690+
Store the result in low/high word based on $wordSel, preserving the other word.
691+
}];
692+
let assemblyFormat = [{
693+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
694+
}];
695+
}
696+
697+
def ROCDL_CvtScaleF32PkFp8Bf16 :
698+
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
699+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
700+
let summary = "Scale and convert packed bf16 to packed fp8";
701+
let description = [{
702+
Scale `src` by the exponent in `scale` then convert to packed fp8.
703+
Store the result in low/high word based on $wordSel, preserving the other word.
704+
}];
705+
let assemblyFormat = [{
706+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
707+
}];
708+
}
709+
710+
711+
def ROCDL_CvtScaleF32PkBf8F16 :
712+
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
713+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
714+
let summary = "Scale and convert f16 to packed bf8";
715+
let description = [{
716+
Scale `src` by the exponent in `scale` then convert to packed bf8.
717+
Store the result in low/high word based on $wordSel, preserving the other word.
718+
}];
719+
let assemblyFormat = [{
720+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
721+
}];
722+
}
723+
724+
725+
def ROCDL_CvtScaleF32PkBf8Bf16 :
726+
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
727+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
728+
let summary = "Scale and convert bf16 to packed bf8";
729+
let description = [{
730+
Scale `src` by the exponent in `scale` then convert to packed bf8.
731+
Store the result in low/high word based on $wordSel, preserving the other word.
732+
}];
733+
let assemblyFormat = [{
734+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
735+
}];
736+
}
737+
738+
def ROCDL_CvtScaleF32SrFp8F16 :
739+
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
740+
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
741+
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
742+
let description = [{
743+
Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
744+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
745+
746+
}];
747+
let assemblyFormat = [{
748+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
749+
}];
750+
}
751+
752+
def ROCDL_CvtScaleF32SrBf8F16 :
753+
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
754+
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
755+
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
756+
let description = [{
757+
Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
758+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
759+
760+
}];
761+
let assemblyFormat = [{
762+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
763+
}];
764+
}
765+
766+
def ROCDL_CvtScaleF32SrFp8Bf16 :
767+
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
768+
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
769+
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
770+
let description = [{
771+
Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
772+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
773+
774+
}];
775+
let assemblyFormat = [{
776+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
777+
}];
778+
}
779+
780+
def ROCDL_CvtScaleF32SrBf8Bf16:
781+
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
782+
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
783+
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
784+
let description = [{
785+
Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
786+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
787+
788+
}];
789+
let assemblyFormat = [{
790+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
791+
}];
792+
}
793+
794+
def ROCDL_CvtScaleF32PkF16Fp8 :
795+
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
796+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
797+
let summary = "Scale and convert fp8 to packed f16";
798+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
799+
then convert to packed f16.
800+
}];
801+
let assemblyFormat = [{
802+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
803+
}];
804+
}
805+
806+
def ROCDL_CvtScaleF32PkF16Bf8 :
807+
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
808+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
809+
let summary = "Scale and convert bf8 to packed f16";
810+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
811+
then convert to packed f16.
812+
}];
813+
let assemblyFormat = [{
814+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
815+
}];
816+
}
817+
818+
def ROCDL_CvtScaleF16Fp8 :
819+
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
820+
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
821+
let summary = "Scale and convert fp8 to f16";
822+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
823+
then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
824+
}];
825+
let assemblyFormat = [{
826+
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
827+
}];
828+
}
829+
830+
def ROCDL_CvtScaleF16Bf8 :
831+
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
832+
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
833+
let summary = "Scale and convert fp8 to f16";
834+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
835+
then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
836+
}];
837+
let assemblyFormat = [{
838+
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
839+
}];
840+
}
841+
670842
//===---------------------------------------------------------------------===//
671843
// 32-bit float intrinsics
672844
//===---------------------------------------------------------------------===//
673-
def ROCDL_CvtScalePkF32Fp8 :
845+
def ROCDL_CvtScale32PkF32Fp8 :
674846
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
675847
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
676848
let summary = "Scale and convert packed fp8 to packed f32";
@@ -682,7 +854,7 @@ def ROCDL_CvtScalePkF32Fp8 :
682854
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
683855
}];
684856
}
685-
def ROCDL_CvtScalePkF32Bf8 :
857+
def ROCDL_CvtScale32PkF32Bf8 :
686858
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
687859
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
688860
let summary = "Scale and convert packed bf8 to packed f32";
@@ -697,10 +869,6 @@ def ROCDL_CvtScalePkF32Bf8 :
697869
//===---------------------------------------------------------------------===//
698870
// 8-bit float scale intrinsics
699871
//===---------------------------------------------------------------------===//
700-
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
701-
BuildableType<"::mlir::VectorType::get("
702-
"{2},$_builder.getI16Type())">;
703-
704872
def ROCDL_CvtScaleF32PkFp8F32:
705873
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
706874
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,19 +759,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
759759
llvm.return
760760
}
761761

762-
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
762+
llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 {
763763
// CHECK-LABEL: @rocdl_8bit_floats
764764
// CHECK: rocdl.cvt.f32.bf8
765765
// CHECK: rocdl.cvt.f32.fp8
766766
// CHECK: rocdl.cvt.scalef32.f32.bf8
767767
// CHECK: rocdl.cvt.scalef32.f32.fp8
768+
// CHECK: rocdl.cvt.scalef32.pk.f16.bf8
769+
// CHECK: rocdl.cvt.scalef32.pk.f16.fp8
770+
// CHECK: rocdl.cvt.scalef32.f16.fp8
771+
// CHECK: rocdl.cvt.scalef32.f16.bf8
768772
// CHECK: rocdl.cvt.pk.bf8.f32
769773
// CHECK: rocdl.cvt.pk.fp8.f32
770774
// CHECK: rocdl.cvt.sr.bf8.f32
771775
// CHECK: rocdl.cvt.sr.fp8.f32
772776
// CHECK: rocdl.cvt.scalef32.sr.fp8.f32
777+
// CHECK: rocdl.cvt.scalef32.sr.fp8.f16
778+
// CHECK: rocdl.cvt.scalef32.sr.fp8.bf16
773779
// CHECK: rocdl.cvt.sr.bf8.f32
774780
// CHECK: rocdl.cvt.scalef32.sr.bf8.f32
781+
// CHECK: rocdl.cvt.scalef32.sr.bf8.f16
782+
// CHECK: rocdl.cvt.scalef32.sr.bf8.bf16
775783
// CHECK: rocdl.cvt.scalef32.pk.f32.fp8
776784
// CHECK: rocdl.cvt.scalef32.pk.f32.bf8
777785
%c0 = llvm.mlir.constant(0 : i32) : i32
@@ -783,13 +791,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
783791
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
784792
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
785793
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
794+
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
795+
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
796+
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
797+
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
786798
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
787799
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
788800
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
789801
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
790802
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
803+
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
804+
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
791805
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
792806
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
807+
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
808+
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
793809
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
794810
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
795811
llvm.return %source5 : i32
@@ -805,6 +821,18 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
805821
llvm.return %source_scaled : vector<2xi16>
806822
}
807823

824+
llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
825+
// CHECK-LABEL: @rocdl_v2f16_v2i16
826+
// CHECK: rocdl.cvt.scalef32.pk.fp8.f16
827+
%c0 = llvm.mlir.constant(1.0 : f32) : f32
828+
%false = llvm.mlir.constant(false) : i1
829+
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
830+
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
831+
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
832+
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
833+
llvm.return %source_scaled : vector<2xi16>
834+
}
835+
808836
llvm.func @rocdl.s.waitcnt() {
809837
// CHECK-LABEL: rocdl.s.waitcnt
810838
// CHECK: rocdl.s.waitcnt 0

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,22 +1032,29 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
10321032
llvm.return %val : i32
10331033
}
10341034

1035-
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
1035+
llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 {
10361036
// CHECK-LABEL: @rocdl_8bit_floats
10371037
// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
10381038
// CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i32 0)
10391039
// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
10401040
// CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i32 0)
1041+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
1042+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
1043+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
1044+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
10411045
// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
10421046
// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
10431047
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
10441048
// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
10451049
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1050+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1051+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
10461052
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
10471053
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1054+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1055+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
10481056
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
10491057
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
1050-
10511058
%c0 = llvm.mlir.constant(0 : i32) : i32
10521059
%c2 = llvm.mlir.constant(2 : i32) : i32
10531060
%c3 = llvm.mlir.constant(3 : i32) : i32
@@ -1057,13 +1064,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
10571064
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
10581065
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
10591066
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
1067+
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : i32
1068+
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
1069+
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
1070+
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
10601071
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
10611072
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
10621073
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
10631074
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
10641075
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
1076+
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
1077+
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
10651078
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
10661079
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
1080+
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
1081+
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
10671082
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
10681083
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
10691084
llvm.return %source5 : i32
@@ -1080,6 +1095,21 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
10801095
llvm.return %source_scaled : vector<2xi16>
10811096
}
10821097

1098+
llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
1099+
// CHECK-LABEL: @rocdl_v2f16_v2i16
1100+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
1101+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
1102+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
1103+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
1104+
%c0 = llvm.mlir.constant(1.0 : f32) : f32
1105+
%false = llvm.mlir.constant(false) : i1
1106+
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
1107+
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
1108+
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
1109+
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
1110+
llvm.return %source_scaled : vector<2xi16>
1111+
}
1112+
10831113
llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
10841114
// CHECK-LABEL: @rocdl_16bit_packed_floats
10851115
// CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})

0 commit comments

Comments
 (0)