-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][ROCDL] Add Scale Convert Packed FP4 <-> F32/BF16/F16 Support for GFX950 #140676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Tim Gymnich (tgymnich) Changes
Full diff: https://github.com/llvm/llvm-project/pull/140676.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6fb9e3aba1f0a..f52c2c391fbba 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -713,6 +713,10 @@ def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getI16Type())">;
+def ROCDL_V2F32Type : FixedVectorOfLengthAndType<[2], [F32]>,
+ BuildableType<"::mlir::VectorType::get("
+ "{2},$_builder.getF32Type())">;
+
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getF16Type())">;
@@ -1005,6 +1009,120 @@ def ROCDL_CvtScaleF32SrBf8F32Op :
}];
}
+//===---------------------------------------------------------------------===//
+// 4-bit float scale intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtScaleF32PkFp4F32Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f32", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, F32:$src0, F32:$src1, F32: $scale, I32:$byteSel)> {
+ let summary = "Convert f32 to packed fp4 and scale";
+ let description = [{ Convert `src0` and `src1` based on $byteSel to packed fp4, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src0 `,` $src1 `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32PkFp4F16Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f16", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, F32: $scale, I32:$byteSel)> {
+ let summary = "Convert f16 to packed fp4 and scale";
+ let description = [{ Convert `src` based on $byteSel to packed fp4, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32PkFp4Bf16Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.fp4.bf16", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, F32: $scale, I32:$byteSel)> {
+ let summary = "Convert bf16 to packed fp4 and scale";
+ let description = [{ Convert `src` based on $byteSel to packed fp4, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4F32Op :
+ ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f32", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, ROCDL_V2F32Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+ let summary = "Scale and convert f32 to packed fp4 using stochastic rounding";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4F16Op :
+ ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f16", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+ let summary = "Scale and convert f16 to packed fp4 using stochastic rounding";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4Bf16Op :
+ ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.bf16", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+ let summary = "Scale and convert bf16 to packed fp4 using stochastic rounding";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32PkF32Fp4Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp4", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+ let summary = "Convert fp4 to packed f32 and scale";
+ let description = [{ Convert `src` based on $byteSel to packed f32, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+
+def ROCDL_CvtScaleF32PkF16Fp4Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp4", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+ let summary = "Convert fp4 to packed f16 and scale";
+ let description = [{ Convert `src` based on $byteSel to packed f16, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32PkBf16Fp4Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp4", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+ let summary = "Convert fp4 to packed bf16 and scale";
+ let description = [{ Convert `src` based on $byteSel to packed bf16, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+ }];
+}
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index fbde993891342..84fa29ee2d8a1 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -844,6 +844,31 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
llvm.return %source_scaled : vector<2xi16>
}
+llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
+ // CHECK-LABEL: @rocdl_4bit_packed_floats
+ // CHECK: rocdl.cvt.scalef32.pk.fp4.f32
+ // CHECK: rocdl.cvt.scalef32.pk.fp4.f16
+ // CHECK: rocdl.cvt.scalef32.pk.fp4.bf16
+ // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f32
+ // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f16
+ // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.bf16
+ // CHECK: rocdl.cvt.scalef32.pk.f32.fp4
+ // CHECK: rocdl.cvt.scalef32.pk.f16.fp4
+ // CHECK: rocdl.cvt.scalef32.pk.bf16.fp4
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ %scale = llvm.mlir.constant(1.0 : f32) : f32
+ %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
+ %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
+ %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
+ %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
+ %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
+ %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+ %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
+ %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
+ %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
+ llvm.return %sr3 : i32
+}
+
llvm.func @rocdl.s.waitcnt() {
// CHECK-LABEL: rocdl.s.waitcnt
// CHECK: rocdl.s.waitcnt 0
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index b37f0da361950..16efce4a80908 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1145,6 +1145,31 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
llvm.return %source : vector<2xf16>
}
+llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
+ // CHECK-LABEL: @rocdl_4bit_packed_floats
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %0, float %1, float %2, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %8, <2 x half> %4, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %9, <2 x bfloat> %5, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %10, <2 x float> %3, i32 %6, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %11, <2 x half> %4, i32 %6, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %12, <2 x bfloat> %5, i32 %6, float 1.000000e+00, i32 0)
+ // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %0, float 1.000000e+00, i32 0)
+ // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %0, float 1.000000e+00, i32 0)
+ // CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %0, float 1.000000e+00, i32 0)
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ %scale = llvm.mlir.constant(1.0 : f32) : f32
+ %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
+ %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
+ %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
+ %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
+ %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
+ %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+ %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
+ %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
+ %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
+ llvm.return %sr3 : i32
+}
+
llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
// CHECK-LABEL: @rocdl_atomic_attrs
// CHECK: atomicrmw
|
These will throw an unsupported-arch error at the LLVM IR level if compiled for arch < GFX950, but is it worth having some MLIR op verifiers for the arch in the ROCDL op definitions? (I realize that none of the other, existing ops have any arch-based verifiers.) |
... I'll have a competing PR up in a few hours that also does a bit of a refactor of these intrinsics and handles immarg correctly |
closing in favor of #140801 |
ROCDL support for the following GFX950 instructions:
rocdl.cvt.scalef32.pk.fp4.f32
rocdl.cvt.scalef32.pk.fp4.f16
rocdl.cvt.scalef32.pk.fp4.bf16
rocdl.cvt.scalef32.sr.pk.fp4.f32
rocdl.cvt.scalef32.sr.pk.fp4.f16
rocdl.cvt.scalef32.sr.pk.fp4.bf16
rocdl.cvt.scalef32.pk.f32.fp4
rocdl.cvt.scalef32.pk.f16.fp4
rocdl.cvt.scalef32.pk.bf16.fp4