Skip to content

Commit 7a39341

Browse files
committed
add f16 convert instructions
1 parent c8dd852 commit 7a39341

File tree

3 files changed

+231
-7
lines changed

3 files changed

+231
-7
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 170 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,18 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
652652
}];
653653
}
654654

655+
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
656+
BuildableType<"::mlir::VectorType::get("
657+
"{2},$_builder.getI16Type())">;
658+
659+
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
660+
BuildableType<"::mlir::VectorType::get("
661+
"{2},$_builder.getF16Type())">;
662+
663+
def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>,
664+
BuildableType<"::mlir::VectorType::get("
665+
"{2},$_builder.getBF16Type())">;
666+
655667
//===---------------------------------------------------------------------===//
656668
// 16-bit float intrinsics
657669
//===---------------------------------------------------------------------===//
@@ -667,6 +679,164 @@ def ROCDL_CvtPkRtz:
667679
}];
668680
}
669681

682+
def ROCDL_CvtScalePkFp8F16 :
683+
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
684+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
685+
let summary = "Scale and convert f16 to packed fp8";
686+
let description = [{
687+
Scale `src` by the exponent in `scale` then convert to packed fp8.
688+
Store the result in low/high word based on $wordSel, preserving the other word.
689+
}];
690+
let assemblyFormat = [{
691+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
692+
}];
693+
}
694+
695+
def ROCDL_CvtScalePkFp8Bf16 :
696+
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
697+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
698+
let summary = "Scale and convert packed bf16 to packed fp8";
699+
let description = [{
700+
Scale `src` by the exponent in `scale` then convert to packed fp8.
701+
Store the result in low/high word based on $wordSel, preserving the other word.
702+
}];
703+
let assemblyFormat = [{
704+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
705+
}];
706+
}
707+
708+
709+
def ROCDL_CvtScalePkBf8F16 :
710+
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
711+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
712+
let summary = "Scale and convert f16 to packed bf8";
713+
let description = [{
714+
Scale `src` by the exponent in `scale` then convert to packed bf8.
715+
Store the result in low/high word based on $wordSel, preserving the other word.
716+
}];
717+
let assemblyFormat = [{
718+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
719+
}];
720+
}
721+
722+
723+
def ROCDL_CvtScalePkBf8Bf16 :
724+
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
725+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
726+
let summary = "Scale and convert bf16 to packed bf8";
727+
let description = [{
728+
Scale `src` by the exponent in `scale` then convert to packed bf8.
729+
Store the result in low/high word based on $wordSel, preserving the other word.
730+
}];
731+
let assemblyFormat = [{
732+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
733+
}];
734+
}
735+
736+
def ROCDL_CvtScaleSrFp8F16 :
737+
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
738+
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
739+
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
740+
let description = [{
741+
Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
742+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
743+
744+
}];
745+
let assemblyFormat = [{
746+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
747+
}];
748+
}
749+
750+
def ROCDL_CvtScaleSrBf8F16 :
751+
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
752+
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
753+
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
754+
let description = [{
755+
Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
756+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
757+
758+
}];
759+
let assemblyFormat = [{
760+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
761+
}];
762+
}
763+
764+
def ROCDL_CvtScaleSrFp8Bf16 :
765+
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
766+
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
767+
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
768+
let description = [{
769+
Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
770+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
771+
772+
}];
773+
let assemblyFormat = [{
774+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
775+
}];
776+
}
777+
778+
def ROCDL_CvtScaleSrBf8Bf16:
779+
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
780+
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
781+
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
782+
let description = [{
783+
Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
784+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
785+
786+
}];
787+
let assemblyFormat = [{
788+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
789+
}];
790+
}
791+
792+
def ROCDL_CvtScalePkF16Fp8 :
793+
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
794+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
795+
let summary = "Scale and convert fp8 to packed f16";
796+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
797+
then convert to packed f16.
798+
}];
799+
let assemblyFormat = [{
800+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
801+
}];
802+
}
803+
804+
def ROCDL_CvtScalePkF16Bf8 :
805+
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
806+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
807+
let summary = "Scale and convert bf8 to packed f16";
808+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
809+
then convert to packed f16.
810+
}];
811+
let assemblyFormat = [{
812+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
813+
}];
814+
}
815+
816+
def ROCDL_CvtScaleF16Fp8 :
817+
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
818+
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
819+
let summary = "Scale and convert fp8 to f16";
820+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
821+
then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
822+
}];
823+
let assemblyFormat = [{
824+
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
825+
}];
826+
}
827+
828+
def ROCDL_CvtScaleF16Bf8 :
829+
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
830+
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
831+
let summary = "Scale and convert fp8 to f16";
832+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
833+
then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
834+
}];
835+
let assemblyFormat = [{
836+
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
837+
}];
838+
}
839+
670840
//===---------------------------------------------------------------------===//
671841
// 32-bit float intrinsics
672842
//===---------------------------------------------------------------------===//
@@ -697,10 +867,6 @@ def ROCDL_CvtScalePkF32Bf8 :
697867
//===---------------------------------------------------------------------===//
698868
// 8-bit float scale intrinsics
699869
//===---------------------------------------------------------------------===//
700-
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
701-
BuildableType<"::mlir::VectorType::get("
702-
"{2},$_builder.getI16Type())">;
703-
704870
def ROCDL_CvtScaleF32PkFp8F32:
705871
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
706872
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,19 +759,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
759759
llvm.return
760760
}
761761

762-
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
762+
llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 {
763763
// CHECK-LABEL: @rocdl_8bit_floats
764764
// CHECK: rocdl.cvt.f32.bf8
765765
// CHECK: rocdl.cvt.f32.fp8
766766
// CHECK: rocdl.cvt.scalef32.f32.bf8
767767
// CHECK: rocdl.cvt.scalef32.f32.fp8
768+
// CHECK: rocdl.cvt.scalef32.pk.f16.bf8
769+
// CHECK: rocdl.cvt.scalef32.pk.f16.fp8
770+
// CHECK: rocdl.cvt.scalef32.f16.fp8
771+
// CHECK: rocdl.cvt.scalef32.f16.bf8
768772
// CHECK: rocdl.cvt.pk.bf8.f32
769773
// CHECK: rocdl.cvt.pk.fp8.f32
770774
// CHECK: rocdl.cvt.sr.bf8.f32
771775
// CHECK: rocdl.cvt.sr.fp8.f32
772776
// CHECK: rocdl.cvt.scalef32.sr.fp8.f32
777+
// CHECK: rocdl.cvt.scalef32.sr.fp8.f16
778+
// CHECK: rocdl.cvt.scalef32.sr.fp8.bf16
773779
// CHECK: rocdl.cvt.sr.bf8.f32
774780
// CHECK: rocdl.cvt.scalef32.sr.bf8.f32
781+
// CHECK: rocdl.cvt.scalef32.sr.bf8.f16
782+
// CHECK: rocdl.cvt.scalef32.sr.bf8.bf16
775783
// CHECK: rocdl.cvt.scalef32.pk.f32.fp8
776784
// CHECK: rocdl.cvt.scalef32.pk.f32.bf8
777785
%c0 = llvm.mlir.constant(0 : i32) : i32
@@ -783,13 +791,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
783791
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
784792
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
785793
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
794+
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
795+
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
796+
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
797+
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
786798
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
787799
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
788800
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
789801
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
790802
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
803+
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
804+
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
791805
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
792806
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
807+
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
808+
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
793809
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
794810
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
795811
llvm.return %source5 : i32
@@ -805,6 +821,18 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
805821
llvm.return %source_scaled : vector<2xi16>
806822
}
807823

824+
llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
825+
// CHECK-LABEL: @rocdl_v2f16_v2i16
826+
// CHECK: rocdl.cvt.scalef32.pk.fp8.f16
827+
%c0 = llvm.mlir.constant(1.0 : f32) : f32
828+
%false = llvm.mlir.constant(false) : i1
829+
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
830+
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
831+
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
832+
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
833+
llvm.return %source_scaled : vector<2xi16>
834+
}
835+
808836
llvm.func @rocdl.s.waitcnt() {
809837
// CHECK-LABEL: rocdl.s.waitcnt
810838
// CHECK: rocdl.s.waitcnt 0

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,22 +1032,29 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
10321032
llvm.return %val : i32
10331033
}
10341034

1035-
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
1035+
llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 {
10361036
// CHECK-LABEL: @rocdl_8bit_floats
10371037
// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
10381038
// CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i32 0)
10391039
// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
10401040
// CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i32 0)
1041+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
1042+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
1043+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
1044+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
10411045
// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
10421046
// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
10431047
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
10441048
// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
10451049
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1050+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1051+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
10461052
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
10471053
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1054+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1055+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
10481056
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
10491057
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
1050-
10511058
%c0 = llvm.mlir.constant(0 : i32) : i32
10521059
%c2 = llvm.mlir.constant(2 : i32) : i32
10531060
%c3 = llvm.mlir.constant(3 : i32) : i32
@@ -1057,13 +1064,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
10571064
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
10581065
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
10591066
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
1067+
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : i32
1068+
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
1069+
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
1070+
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
10601071
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
10611072
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
10621073
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
10631074
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
10641075
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
1076+
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
1077+
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
10651078
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
10661079
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
1080+
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
1081+
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
10671082
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
10681083
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
10691084
llvm.return %source5 : i32
@@ -1080,6 +1095,21 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
10801095
llvm.return %source_scaled : vector<2xi16>
10811096
}
10821097

1098+
llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
1099+
// CHECK-LABEL: @rocdl_v2f16_v2i16
1100+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
1101+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
1102+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
1103+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
1104+
%c0 = llvm.mlir.constant(1.0 : f32) : f32
1105+
%false = llvm.mlir.constant(false) : i1
1106+
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
1107+
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
1108+
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
1109+
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
1110+
llvm.return %source_scaled : vector<2xi16>
1111+
}
1112+
10831113
llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
10841114
// CHECK-LABEL: @rocdl_16bit_packed_floats
10851115
// CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})

0 commit comments

Comments
 (0)