-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][ROCDL] Add Scale Convert Packed (B)FP8 <-> (B)F16 Support for GFX950 #130300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Corbin Robeck (CRobeck) ChangesAdd Rocdl support for the following GFX950 instructions: CVT_SCALE_PK_FP8_F16 Full diff: https://github.com/llvm/llvm-project/pull/130300.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 18fec95f700c4..5435c32000bb5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -652,6 +652,18 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
}];
}
+def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
+ BuildableType<"::mlir::VectorType::get("
+ "{2},$_builder.getI16Type())">;
+
+def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
+ BuildableType<"::mlir::VectorType::get("
+ "{2},$_builder.getF16Type())">;
+
+def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>,
+ BuildableType<"::mlir::VectorType::get("
+ "{2},$_builder.getBF16Type())">;
+
//===---------------------------------------------------------------------===//
// 16-bit float intrinsics
//===---------------------------------------------------------------------===//
@@ -667,6 +679,164 @@ def ROCDL_CvtPkRtz:
}];
}
+def ROCDL_CvtScalePkFp8F16 :
+ ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
+ Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
+ let summary = "Scale and convert f16 to packed fp8";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed fp8.
+ Store the result in low/high word based on $wordSel, preserving the other word.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScalePkFp8Bf16 :
+ ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
+ Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
+ let summary = "Scale and convert packed bf16 to packed fp8";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed fp8.
+ Store the result in low/high word based on $wordSel, preserving the other word.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+ }];
+}
+
+
+def ROCDL_CvtScalePkBf8F16 :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
+ Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
+ let summary = "Scale and convert f16 to packed bf8";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed bf8.
+ Store the result in low/high word based on $wordSel, preserving the other word.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+ }];
+}
+
+
+def ROCDL_CvtScalePkBf8Bf16 :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
+ Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
+ let summary = "Scale and convert bf16 to packed bf8";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed bf8.
+ Store the result in low/high word based on $wordSel, preserving the other word.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleSrFp8F16 :
+ ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+ let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
+ using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleSrBf8F16 :
+ ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+ let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
+ using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleSrFp8Bf16 :
+ ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+ let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
+ using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleSrBf8Bf16:
+ ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
+ Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+ let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
+ let description = [{
+ Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
+ using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScalePkF16Fp8 :
+ ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+ let summary = "Scale and convert fp8 to packed f16";
+ let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
+ then convert to packed f16.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScalePkF16Bf8 :
+ ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+ let summary = "Scale and convert bf8 to packed f16";
+ let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
+ then convert to packed f16.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF16Fp8 :
+ ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
+ Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
+ let summary = "Scale and convert fp8 to f16";
+ let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
+ then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF16Bf8 :
+ ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
+ Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
+ let summary = "Scale and convert fp8 to f16";
+ let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
+ then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+ }];
+}
+
//===---------------------------------------------------------------------===//
// 32-bit float intrinsics
//===---------------------------------------------------------------------===//
@@ -697,10 +867,6 @@ def ROCDL_CvtScalePkF32Bf8 :
//===---------------------------------------------------------------------===//
// 8-bit float scale intrinsics
//===---------------------------------------------------------------------===//
-def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
- BuildableType<"::mlir::VectorType::get("
- "{2},$_builder.getI16Type())">;
-
def ROCDL_CvtScaleF32PkFp8F32:
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 05469b64a8083..bc917041998d8 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -759,19 +759,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
llvm.return
}
-llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
+llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 {
// CHECK-LABEL: @rocdl_8bit_floats
// CHECK: rocdl.cvt.f32.bf8
// CHECK: rocdl.cvt.f32.fp8
// CHECK: rocdl.cvt.scalef32.f32.bf8
// CHECK: rocdl.cvt.scalef32.f32.fp8
+// CHECK: rocdl.cvt.scalef32.pk.f16.bf8
+// CHECK: rocdl.cvt.scalef32.pk.f16.fp8
+// CHECK: rocdl.cvt.scalef32.f16.fp8
+// CHECK: rocdl.cvt.scalef32.f16.bf8
// CHECK: rocdl.cvt.pk.bf8.f32
// CHECK: rocdl.cvt.pk.fp8.f32
// CHECK: rocdl.cvt.sr.bf8.f32
// CHECK: rocdl.cvt.sr.fp8.f32
// CHECK: rocdl.cvt.scalef32.sr.fp8.f32
+// CHECK: rocdl.cvt.scalef32.sr.fp8.f16
+// CHECK: rocdl.cvt.scalef32.sr.fp8.bf16
// CHECK: rocdl.cvt.sr.bf8.f32
// CHECK: rocdl.cvt.scalef32.sr.bf8.f32
+// CHECK: rocdl.cvt.scalef32.sr.bf8.f16
+// CHECK: rocdl.cvt.scalef32.sr.bf8.bf16
// CHECK: rocdl.cvt.scalef32.pk.f32.fp8
// CHECK: rocdl.cvt.scalef32.pk.f32.bf8
%c0 = llvm.mlir.constant(0 : i32) : i32
@@ -783,13 +791,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
+ %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
+ %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
+ %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
+ %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
+ %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
+ %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
+ %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
+ %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
llvm.return %source5 : i32
@@ -805,6 +821,18 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
llvm.return %source_scaled : vector<2xi16>
}
+llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
+// CHECK-LABEL: @rocdl_v2f16_v2i16
+// CHECK: rocdl.cvt.scalef32.pk.fp8.f16
+ %c0 = llvm.mlir.constant(1.0 : f32) : f32
+ %false = llvm.mlir.constant(false) : i1
+ %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
+ %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+ %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
+ %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+ llvm.return %source_scaled : vector<2xi16>
+}
+
llvm.func @rocdl.s.waitcnt() {
// CHECK-LABEL: rocdl.s.waitcnt
// CHECK: rocdl.s.waitcnt 0
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 981ef4848535c..11f2faa2761ff 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1032,22 +1032,29 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
llvm.return %val : i32
}
-llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
+llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 {
// CHECK-LABEL: @rocdl_8bit_floats
// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
// CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i32 0)
// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
// CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i32 0)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
-
%c0 = llvm.mlir.constant(0 : i32) : i32
%c2 = llvm.mlir.constant(2 : i32) : i32
%c3 = llvm.mlir.constant(3 : i32) : i32
@@ -1057,13 +1064,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
+ %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : i32
+ %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
+ %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
+ %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
+ %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
+ %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
+ %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
+ %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
llvm.return %source5 : i32
@@ -1080,6 +1095,21 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
llvm.return %source_scaled : vector<2xi16>
}
+llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
+// CHECK-LABEL: @rocdl_v2f16_v2i16
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
+ %c0 = llvm.mlir.constant(1.0 : f32) : f32
+ %false = llvm.mlir.constant(false) : i1
+ %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
+ %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+ %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
+ %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+ llvm.return %source_scaled : vector<2xi16>
+}
+
llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
// CHECK-LABEL: @rocdl_16bit_packed_floats
// CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The word and byte selectors are immarg
in LLVM and so should be attributes in MLIR
I don't have other objections that I could see
While I don't disagree that would require a pretty involved rewrite of the ROCDLOps.td tablegen file since almost every existing conversion op includes them in the op definition. If we decide to do that it should be a stand alone PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh, yeah, you're right
Agreed that someone should fix that at some point in a followup PR
Approved
7a39341
to
08e177d
Compare
Pulls in llvm/llvm-project#130300 for gfx950 type conversion ops.
Pulls in llvm/llvm-project#130300 for gfx950 type conversion ops. This update includes changes that store the parsed Triple in the module and deprecate `match` and `rewrite` functions.
…GFX950 (llvm#130300) Add Rocdl support for the following GFX950 instructions: CVT_SCALE_PK_FP8_F16 CVT_SCALE_PK_BF8_F16 CVT_SCALE_PK_FP8_BF16 CVT_SCALE_PK_BF8_BF16 CVT_SCALE_SR_FP8_F16 CVT_SCALE_SR_BF8_F16 CVT_SCALE_SR_FP8_BF16 CVT_SCALE_SR_BF8_BF16 CVT_SCALE_PK_F16_FP8 CVT_SCALE_PK_F16_BF8 CVT_SCALE_F16_FP8 CVT_SCALE_F16_BF8
Pulls in llvm/llvm-project#130300 for gfx950 type conversion ops.
Add Rocdl support for the following GFX950 instructions:
CVT_SCALE_PK_FP8_F16
CVT_SCALE_PK_BF8_F16
CVT_SCALE_PK_FP8_BF16
CVT_SCALE_PK_BF8_BF16
CVT_SCALE_SR_FP8_F16
CVT_SCALE_SR_BF8_F16
CVT_SCALE_SR_FP8_BF16
CVT_SCALE_SR_BF8_BF16
CVT_SCALE_PK_F16_FP8
CVT_SCALE_PK_F16_BF8
CVT_SCALE_F16_FP8
CVT_SCALE_F16_BF8