Skip to content

Commit 6f35a9e

Browse files
authored
[MLIR][ROCDL] Add Scale Convert Packed FP8 <-> F32 Support for GFX950 (#125564)
Add Rocdl support for the following GFX950 instructions: CVT_SCALE_PK_FP8_F32 CVT_SCALE_PK_BF8_F32 CVT_SCALE_SR_FP8_F32 CVT_SCALE_SR_BF8_F32 CVT_SCALE_PK_F32_FP8 CVT_SCALE_PK_F32_BF8 CVT_SCALE_F32_FP8 CVT_SCALE_F32_BF8
1 parent 51d0ad7 commit 6f35a9e

File tree

3 files changed

+170
-0
lines changed

3 files changed

+170
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,93 @@ def ROCDL_CvtPkRtz:
736736
}];
737737
}
738738

739+
//===---------------------------------------------------------------------===//
740+
// 32-bit float intrinsics
741+
//===---------------------------------------------------------------------===//
742+
def ROCDL_CvtScalePkF32Fp8 :
743+
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
744+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
745+
let summary = "Scale and convert packed fp8 to packed f32";
746+
let description = [{
747+
Scale `src` by the exponent in `scale` then convert to packed fp32.
748+
The two fp8 values are read from the low/high word of `src` selected by `wordSel`.
749+
}];
750+
let assemblyFormat = [{
751+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
752+
}];
753+
}
754+
def ROCDL_CvtScalePkF32Bf8 :
755+
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
756+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
757+
let summary = "Scale and convert packed bf8 to packed f32";
758+
let description = [{
759+
Scale `src` by the exponent in `scale` then convert to packed fp32.
760+
The two bf8 values are read from the low/high word of `src` selected by `wordSel`.
761+
}];
762+
let assemblyFormat = [{
763+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
764+
}];
765+
}
766+
//===---------------------------------------------------------------------===//
767+
// 8-bit float scale intrinsics
768+
//===---------------------------------------------------------------------===//
769+
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
770+
BuildableType<"::mlir::VectorType::get("
771+
"{2},$_builder.getI16Type())">;
772+
773+
def ROCDL_CvtScaleF32PkFp8F32:
774+
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
775+
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
776+
let summary = "Scale and convert two f32's to packed fp8";
777+
let description = [{
778+
Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed fp8
779+
and store into the low/high word of `old`, preserving the other word.
780+
}];
781+
let assemblyFormat = [{
782+
attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
783+
}];
784+
}
785+
786+
def ROCDL_CvtScaleF32PkBf8F32:
787+
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
788+
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
789+
let summary = "Scale and convert two f32's to packed bf8";
790+
let description = [{
791+
Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed bf8
792+
and store into the low/high word of `old`, preserving the other word.
793+
}];
794+
let assemblyFormat = [{
795+
attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
796+
}];
797+
}
798+
799+
def ROCDL_CvtScaleF32SrFp8F32:
800+
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
801+
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
802+
let summary = "Scale and convert f32 to fp8 using stochastic rounding";
803+
let description = [{
804+
Scale `src` by the exponent in `scale` then convert to fp8 with stochastic rounding
805+
using seed data in `seed`. Store the result into the `byteSel`th byte of `old`, preserving the others.
806+
}];
807+
let assemblyFormat = [{
808+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
809+
}];
810+
}
811+
812+
813+
def ROCDL_CvtScaleF32SrBf8F32:
814+
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
815+
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
816+
let summary = "Scale and convert f32 to bf8 using stochastic rounding";
817+
let description = [{
818+
Scale `src` by the exponent in `scale` then convert to bf8 with stochastic rounding
819+
using seed data in `seed`. Store the result into the `byteSel`th byte of `old`, preserving the others.
820+
}];
821+
let assemblyFormat = [{
822+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
823+
}];
824+
}
825+
739826
//===---------------------------------------------------------------------===//
740827
// 8-bit float intrinsics
741828
//===---------------------------------------------------------------------===//
@@ -751,6 +838,20 @@ def ROCDL_CvtF32Bf8Op :
751838
}];
752839
}
753840

841+
def ROCDL_CvtScaleF32Bf8Op :
842+
ROCDL_IntrOp<"cvt.scalef32.f32.bf8", [], [], [Pure], 1>,
843+
Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
844+
let summary = "Scale and convert bf8 to f32";
845+
let description = [{
846+
Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
847+
from the `byteSel`th byte of `src` to fp32.
848+
}];
849+
let assemblyFormat = [{
850+
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
851+
}];
852+
}
853+
854+
754855
def ROCDL_CvtF32Fp8Op :
755856
ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
756857
Arguments<(ins I32:$srcA, I32:$byteSel)> {
@@ -763,6 +864,22 @@ def ROCDL_CvtF32Fp8Op :
763864
}];
764865
}
765866

867+
868+
def ROCDL_CvtScaleF32Fp8Op :
869+
ROCDL_IntrOp<"cvt.scalef32.f32.fp8", [], [], [Pure], 1>,
870+
Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
871+
let summary = "Scale and convert fp8 to f32";
872+
let description = [{
873+
Scale `src` by the exponent in `scale` then convert 8-bit fp8 value
874+
from the `byteSel`th byte of `src` to fp32.
875+
876+
}];
877+
let assemblyFormat = [{
878+
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
879+
}];
880+
}
881+
882+
766883
def ROCDL_CvtPkBf8F32Op :
767884
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
768885
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,23 +754,49 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
754754
// CHECK-LABEL: @rocdl_8bit_floats
755755
// CHECK: rocdl.cvt.f32.bf8
756756
// CHECK: rocdl.cvt.f32.fp8
757+
// CHECK: rocdl.cvt.scalef32.f32.bf8
758+
// CHECK: rocdl.cvt.scalef32.f32.fp8
757759
// CHECK: rocdl.cvt.pk.bf8.f32
758760
// CHECK: rocdl.cvt.pk.fp8.f32
759761
// CHECK: rocdl.cvt.sr.bf8.f32
760762
// CHECK: rocdl.cvt.sr.fp8.f32
763+
// CHECK: rocdl.cvt.scalef32.sr.fp8.f32
764+
// CHECK: rocdl.cvt.sr.bf8.f32
765+
// CHECK: rocdl.cvt.scalef32.sr.bf8.f32
766+
// CHECK: rocdl.cvt.scalef32.pk.f32.fp8
767+
// CHECK: rocdl.cvt.scalef32.pk.f32.bf8
761768
%c0 = llvm.mlir.constant(0 : i32) : i32
762769
%c2 = llvm.mlir.constant(2 : i32) : i32
763770
%c3 = llvm.mlir.constant(3 : i32) : i32
771+
%c4 = llvm.mlir.constant(1.0 : f32) : f32
764772
%false = llvm.mlir.constant(false) : i1
765773
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
766774
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
775+
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
776+
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
767777
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
768778
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
769779
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
770780
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
781+
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
782+
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
783+
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
784+
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
785+
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
771786
llvm.return %source5 : i32
772787
}
773788

789+
llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2xi16>) -> vector<2xi16> {
790+
// CHECK-LABEL: @rocdl_8bit_packed_v2i16
791+
// CHECK: rocdl.cvt.scalef32.pk.fp8.f32
792+
%c0 = llvm.mlir.constant(1.0 : f32) : f32
793+
%false = llvm.mlir.constant(false) : i1
794+
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
795+
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
796+
llvm.return %source_scaled : vector<2xi16>
797+
}
798+
799+
774800
llvm.func @rocdl.waitcnt() {
775801
// CHECK-LABEL: rocdl.waitcnt
776802
// CHECK: rocdl.waitcnt 0

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,24 +1002,51 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
10021002
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
10031003
// CHECK-LABEL: @rocdl_8bit_floats
10041004
// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
1005+
// CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i32 0)
10051006
// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
1007+
// CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i32 0)
10061008
// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
10071009
// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
10081010
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
10091011
// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
1012+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1013+
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
1014+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
1015+
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
1016+
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
1017+
10101018
%c0 = llvm.mlir.constant(0 : i32) : i32
10111019
%c2 = llvm.mlir.constant(2 : i32) : i32
10121020
%c3 = llvm.mlir.constant(3 : i32) : i32
1021+
%c4 = llvm.mlir.constant(1.0 : f32) : f32
10131022
%false = llvm.mlir.constant(false) : i1
10141023
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
1024+
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
10151025
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
1026+
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
10161027
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
10171028
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
10181029
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
10191030
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
1031+
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
1032+
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
1033+
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
1034+
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
1035+
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
10201036
llvm.return %source5 : i32
10211037
}
10221038

1039+
llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2xi16>) -> vector<2xi16> {
1040+
// CHECK-LABEL: @rocdl_8bit_packed_v2i16
1041+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false)
1042+
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false)
1043+
%c0 = llvm.mlir.constant(1.0 : f32) : f32
1044+
%false = llvm.mlir.constant(false) : i1
1045+
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
1046+
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
1047+
llvm.return %source_scaled : vector<2xi16>
1048+
}
1049+
10231050
llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
10241051
// CHECK-LABEL: @rocdl_16bit_packed_floats
10251052
// CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})

0 commit comments

Comments
 (0)