Skip to content

Commit b14d243

Browse files
committed
[MLIR][ROCDL] Add Scale Convert Packed FP4 <-> F32/BF16/F16 Support for GFX950
1 parent 575f66c commit b14d243

File tree

3 files changed

+156
-0
lines changed

3 files changed

+156
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,10 @@ def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
713713
BuildableType<"::mlir::VectorType::get("
714714
"{2},$_builder.getI16Type())">;
715715

716+
def ROCDL_V2F32Type : FixedVectorOfLengthAndType<[2], [F32]>,
717+
BuildableType<"::mlir::VectorType::get("
718+
"{2},$_builder.getF32Type())">;
719+
716720
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
717721
BuildableType<"::mlir::VectorType::get("
718722
"{2},$_builder.getF16Type())">;
@@ -1005,6 +1009,120 @@ def ROCDL_CvtScaleF32SrBf8F32Op :
10051009
}];
10061010
}
10071011

1012+
//===---------------------------------------------------------------------===//
1013+
// 4-bit float scale intrinsics
1014+
//===---------------------------------------------------------------------===//
1015+
def ROCDL_CvtScaleF32PkFp4F32Op :
1016+
ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f32", [], [], [Pure], 1>,
1017+
Arguments<(ins I32:$old, F32:$src0, F32:$src1, F32: $scale, I32:$byteSel)> {
1018+
let summary = "Convert f32 to packed fp4 and scale";
1019+
let description = [{ Convert `src` based on $$byteSe to packed fp4, then scale
1020+
the packed values by the exponent in `scale`.
1021+
}];
1022+
let assemblyFormat = [{
1023+
attr-dict $src0 `,` $src1 `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
1024+
}];
1025+
}
1026+
1027+
def ROCDL_CvtScaleF32PkFp4F16Op :
1028+
ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f16", [], [], [Pure], 1>,
1029+
Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, F32: $scale, I32:$byteSel)> {
1030+
let summary = "Convert f16 to packed fp4 and scale";
1031+
let description = [{ Convert `src` based on $byteSel to packed fp4, then scale
1032+
the packed values by the exponent in `scale`.
1033+
}];
1034+
let assemblyFormat = [{
1035+
attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
1036+
}];
1037+
}
1038+
1039+
def ROCDL_CvtScaleF32PkFp4Bf16Op :
1040+
ROCDL_IntrOp<"cvt.scalef32.pk.fp4.bf16", [], [], [Pure], 1>,
1041+
Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, F32: $scale, I32:$byteSel)> {
1042+
let summary = "Convert bf16 to packed fp4 and scale";
1043+
let description = [{ Convert `src` based on $byteSel to packed fp4, then scale
1044+
the packed values by the exponent in `scale`.
1045+
}];
1046+
let assemblyFormat = [{
1047+
attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
1048+
}];
1049+
}
1050+
1051+
def ROCDL_CvtScaleF32SrPkFp4F32Op :
1052+
ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f32", [], [], [Pure], 1>,
1053+
Arguments<(ins I32:$old, ROCDL_V2F32Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
1054+
let summary = "Scale and convert f32 to packed fp4 using stochastic rounding";
1055+
let description = [{
1056+
Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
1057+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
1058+
}];
1059+
let assemblyFormat = [{
1060+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
1061+
}];
1062+
}
1063+
1064+
def ROCDL_CvtScaleF32SrPkFp4F16Op :
1065+
ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f16", [], [], [Pure], 1>,
1066+
Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
1067+
let summary = "Scale and convert f16 to packed fp4 using stochastic rounding";
1068+
let description = [{
1069+
Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
1070+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
1071+
}];
1072+
let assemblyFormat = [{
1073+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
1074+
}];
1075+
}
1076+
1077+
def ROCDL_CvtScaleF32SrPkFp4Bf16Op :
1078+
ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.bf16", [], [], [Pure], 1>,
1079+
Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
1080+
let summary = "Scale and convert bf16 to packed fp4 using stochastic rounding";
1081+
let description = [{
1082+
Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
1083+
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
1084+
}];
1085+
let assemblyFormat = [{
1086+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
1087+
}];
1088+
}
1089+
1090+
def ROCDL_CvtScaleF32PkF32Fp4Op :
1091+
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp4", [], [], [Pure], 1>,
1092+
Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
1093+
let summary = "Convert fp4 to packed f32 and scale";
1094+
let description = [{ Convert `src` based on $byteSel to packed f32, then scale
1095+
the packed values by the exponent in `scale`.
1096+
}];
1097+
let assemblyFormat = [{
1098+
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
1099+
}];
1100+
}
1101+
1102+
1103+
def ROCDL_CvtScaleF32PkF16Fp4Op :
1104+
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp4", [], [], [Pure], 1>,
1105+
Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
1106+
let summary = "Convert fp4 to packed f16 and scale";
1107+
let description = [{ Convert `src` based on $byteSel to packed f16, then scale
1108+
the packed values by the exponent in `scale`.
1109+
}];
1110+
let assemblyFormat = [{
1111+
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
1112+
}];
1113+
}
1114+
1115+
def ROCDL_CvtScaleF32PkBf16Fp4Op :
1116+
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp4", [], [], [Pure], 1>,
1117+
Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
1118+
let summary = "Convert fp4 to packed bf16 and scale";
1119+
let description = [{ Convert `src` based on $byteSel to packed bf16, then scale
1120+
the packed values by the exponent in `scale`.
1121+
}];
1122+
let assemblyFormat = [{
1123+
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
1124+
}];
1125+
}
10081126
//===---------------------------------------------------------------------===//
10091127
// 8-bit float intrinsics
10101128
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,25 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
844844
llvm.return %source_scaled : vector<2xi16>
845845
}
846846

847+
llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
848+
// CHECK-LABEL: @rocdl_4bit_packed_floats
849+
// CHECK: rocdl.cvt.scalef32.pk.fp4.f32
850+
// CHECK: rocdl.cvt.scalef32.pk.fp4.f16
851+
// CHECK: rocdl.cvt.scalef32.pk.fp4.bf16
852+
// CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f32
853+
// CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f16
854+
// CHECK: rocdl.cvt.scalef32.sr.pk.fp4.bf16
855+
%c0 = llvm.mlir.constant(0 : i32) : i32
856+
%scale = llvm.mlir.constant(1.0 : f32) : f32
857+
%pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
858+
%pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
859+
%pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
860+
%sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
861+
%sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
862+
%sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
863+
llvm.return %sr3 : i32
864+
}
865+
847866
llvm.func @rocdl.s.waitcnt() {
848867
// CHECK-LABEL: rocdl.s.waitcnt
849868
// CHECK: rocdl.s.waitcnt 0

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,25 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
11451145
llvm.return %source : vector<2xf16>
11461146
}
11471147

1148+
llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
1149+
// CHECK-LABEL: @rocdl_4bit_packed_floats
1150+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %0, float %1, float %2, float 1.000000e+00, i32 0)
1151+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %8, <2 x half> %4, float 1.000000e+00, i32 0)
1152+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %9, <2 x bfloat> %5, float 1.000000e+00, i32 0)
1153+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %10, <2 x float> %3, i32 %6, float 1.000000e+00, i32 0)
1154+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %11, <2 x half> %4, i32 %6, float 1.000000e+00, i32 0)
1155+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %12, <2 x bfloat> %5, i32 %6, float 1.000000e+00, i32 0)
1156+
%c0 = llvm.mlir.constant(0 : i32) : i32
1157+
%scale = llvm.mlir.constant(1.0 : f32) : f32
1158+
%pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
1159+
%pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
1160+
%pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
1161+
%sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
1162+
%sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
1163+
%sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
1164+
llvm.return %sr3 : i32
1165+
}
1166+
11481167
llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
11491168
// CHECK-LABEL: @rocdl_atomic_attrs
11501169
// CHECK: atomicrmw

0 commit comments

Comments
 (0)