Skip to content

[MLIR][ROCDL] Add Scale Convert Packed (B)FP8 <-> (B)F16 Support for GFX950 #130300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2025

Conversation

CRobeck
Copy link
Member

@CRobeck CRobeck commented Mar 7, 2025

Add Rocdl support for the following GFX950 instructions:

CVT_SCALE_PK_FP8_F16
CVT_SCALE_PK_BF8_F16
CVT_SCALE_PK_FP8_BF16
CVT_SCALE_PK_BF8_BF16
CVT_SCALE_SR_FP8_F16
CVT_SCALE_SR_BF8_F16
CVT_SCALE_SR_FP8_BF16
CVT_SCALE_SR_BF8_BF16
CVT_SCALE_PK_F16_FP8
CVT_SCALE_PK_F16_BF8
CVT_SCALE_F16_FP8
CVT_SCALE_F16_BF8

@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Corbin Robeck (CRobeck)

Changes

Add Rocdl support for the following GFX950 instructions:

CVT_SCALE_PK_FP8_F16
CVT_SCALE_PK_BF8_F16
CVT_SCALE_PK_FP8_BF16
CVT_SCALE_PK_BF8_BF16
CVT_SCALE_SR_FP8_F16
CVT_SCALE_SR_BF8_F16
CVT_SCALE_SR_FP8_BF16
CVT_SCALE_SR_BF8_BF16
CVT_SCALE_PK_F16_FP8
CVT_SCALE_PK_F16_BF8
CVT_SCALE_F16_FP8
CVT_SCALE_F16_BF8


Full diff: https://github.com/llvm/llvm-project/pull/130300.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+170-4)
  • (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+29-1)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+32-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 18fec95f700c4..5435c32000bb5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -652,6 +652,18 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
   }];
 }
 
+def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
+                        BuildableType<"::mlir::VectorType::get("
+                          "{2},$_builder.getI16Type())">;
+
+def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
+                        BuildableType<"::mlir::VectorType::get("
+                          "{2},$_builder.getF16Type())">;
+
+def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>,
+                        BuildableType<"::mlir::VectorType::get("
+                          "{2},$_builder.getBF16Type())">;
+
 //===---------------------------------------------------------------------===//
 // 16-bit float intrinsics
 //===---------------------------------------------------------------------===//
@@ -667,6 +679,164 @@ def ROCDL_CvtPkRtz:
   }];
 }
 
+def ROCDL_CvtScalePkFp8F16 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
+    let summary = "Scale and convert f16 to packed fp8";
+    let description = [{
+    Scale `src` by the exponent in `scale` then convert to packed fp8.
+    Store the result in low/high word based on $wordSel, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScalePkFp8Bf16 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
+    let summary = "Scale and convert packed bf16 to packed fp8";
+    let description = [{
+    Scale `src` by the exponent in `scale` then convert to packed fp8.
+    Store the result in low/high word based on $wordSel, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+
+def ROCDL_CvtScalePkBf8F16 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
+    let summary = "Scale and convert f16 to packed bf8";
+    let description = [{
+    Scale `src` by the exponent in `scale` then convert to packed bf8.
+    Store the result in low/high word based on $wordSel, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+
+def ROCDL_CvtScalePkBf8Bf16 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
+    let summary = "Scale and convert bf16 to packed bf8";
+    let description = [{
+    Scale `src` by the exponent in `scale` then convert to packed bf8.
+    Store the result in low/high word based on $wordSel, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleSrFp8F16 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
+    let description = [{
+    Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
+    using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleSrBf8F16 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
+    let description = [{
+    Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
+    using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleSrFp8Bf16 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
+    let description = [{
+    Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
+    using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleSrBf8Bf16:
+    ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
+    let description = [{
+    Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
+    using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScalePkF16Fp8 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+    let summary = "Scale and convert fp8 to packed f16";
+    let description = [{ Scale `src` based on $wordSel by the exponent in `scale` 
+    then convert to packed f16.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScalePkF16Bf8 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+    let summary = "Scale and convert bf8 to packed f16";
+    let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
+    then convert to packed f16.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF16Fp8 :
+    ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
+    let summary = "Scale and convert fp8 to f16";
+    let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
+    then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF16Bf8 :
+    ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
+    let summary = "Scale and convert fp8 to f16";
+    let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
+    then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
 //===---------------------------------------------------------------------===//
 // 32-bit float intrinsics
 //===---------------------------------------------------------------------===//
@@ -697,10 +867,6 @@ def ROCDL_CvtScalePkF32Bf8 :
 //===---------------------------------------------------------------------===//
 // 8-bit float scale intrinsics
 //===---------------------------------------------------------------------===//
-def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
-                        BuildableType<"::mlir::VectorType::get("
-                          "{2},$_builder.getI16Type())">;
-
 def ROCDL_CvtScaleF32PkFp8F32:
     ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 05469b64a8083..bc917041998d8 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -759,19 +759,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
   llvm.return
 }
 
-llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
+llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 {
 // CHECK-LABEL: @rocdl_8bit_floats
 // CHECK: rocdl.cvt.f32.bf8
 // CHECK: rocdl.cvt.f32.fp8
 // CHECK: rocdl.cvt.scalef32.f32.bf8
 // CHECK: rocdl.cvt.scalef32.f32.fp8
+// CHECK: rocdl.cvt.scalef32.pk.f16.bf8 
+// CHECK: rocdl.cvt.scalef32.pk.f16.fp8 
+// CHECK: rocdl.cvt.scalef32.f16.fp8
+// CHECK: rocdl.cvt.scalef32.f16.bf8
 // CHECK: rocdl.cvt.pk.bf8.f32
 // CHECK: rocdl.cvt.pk.fp8.f32
 // CHECK: rocdl.cvt.sr.bf8.f32
 // CHECK: rocdl.cvt.sr.fp8.f32
 // CHECK: rocdl.cvt.scalef32.sr.fp8.f32
+// CHECK: rocdl.cvt.scalef32.sr.fp8.f16
+// CHECK: rocdl.cvt.scalef32.sr.fp8.bf16
 // CHECK: rocdl.cvt.sr.bf8.f32
 // CHECK: rocdl.cvt.scalef32.sr.bf8.f32
+// CHECK: rocdl.cvt.scalef32.sr.bf8.f16
+// CHECK: rocdl.cvt.scalef32.sr.bf8.bf16
 // CHECK: rocdl.cvt.scalef32.pk.f32.fp8
 // CHECK: rocdl.cvt.scalef32.pk.f32.bf8
   %c0 = llvm.mlir.constant(0 : i32) : i32
@@ -783,13 +791,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
   %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
   %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
   %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
+  %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
+  %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
+  %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
+  %v6  = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
   %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
   %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
   %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
   %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
   %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
+  %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
+  %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
   %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
   %source6_scaled  = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
+  %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
+  %source6_scaled_bfloat =  rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
   %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
   %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
   llvm.return %source5 : i32
@@ -805,6 +821,18 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
   llvm.return %source_scaled : vector<2xi16>
 }
 
+llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
+// CHECK-LABEL: @rocdl_v2f16_v2i16
+// CHECK: rocdl.cvt.scalef32.pk.fp8.f16
+  %c0 = llvm.mlir.constant(1.0 : f32) : f32
+  %false = llvm.mlir.constant(false) : i1
+  %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
+  %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+  %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
+  %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+  llvm.return %source_scaled : vector<2xi16>
+}
+
 llvm.func @rocdl.s.waitcnt() {
   // CHECK-LABEL: rocdl.s.waitcnt
   // CHECK: rocdl.s.waitcnt 0
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 981ef4848535c..11f2faa2761ff 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1032,22 +1032,29 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
   llvm.return %val : i32
 }
 
-llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
+llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 {
 // CHECK-LABEL: @rocdl_8bit_floats
 // CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
 // CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i32 0)
 // CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
 // CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i32 0)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
 // CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
 // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
 // CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
 // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
 // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
 // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
-
   %c0 = llvm.mlir.constant(0 : i32) : i32
   %c2 = llvm.mlir.constant(2 : i32) : i32
   %c3 = llvm.mlir.constant(3 : i32) : i32
@@ -1057,13 +1064,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
   %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
   %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
   %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
+  %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : i32
+  %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
+  %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
+  %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
   %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
   %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
   %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
   %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
   %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
+  %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
+  %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
   %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
   %source6_scaled  = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
+  %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
+  %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
   %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
   %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
   llvm.return %source5 : i32
@@ -1080,6 +1095,21 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
   llvm.return %source_scaled : vector<2xi16>
 }
 
+llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
+// CHECK-LABEL: @rocdl_v2f16_v2i16
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
+  %c0 = llvm.mlir.constant(1.0 : f32) : f32
+  %false = llvm.mlir.constant(false) : i1
+  %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
+  %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+  %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
+  %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+  llvm.return %source_scaled : vector<2xi16>
+}
+
 llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
   // CHECK-LABEL: @rocdl_16bit_packed_floats
   // CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The word and byte selectors are immarg in LLVM and so should be attributes in MLIR

I don't have other objections that I could see

@CRobeck
Copy link
Member Author

CRobeck commented Mar 7, 2025

While I don't disagree that would require a pretty involved rewrite of the ROCDLOps.td tablegen file since almost every existing conversion op includes them in the op definition. If we decide to do that it should be a stand alone PR.

@krzysz00 krzysz00 self-requested a review March 7, 2025 17:50
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, yeah, you're right

Agreed that someone should fix that at some point in a followup PR

Approved

@CRobeck CRobeck force-pushed the users/crobeck/gfx-950-f16 branch from 7a39341 to 08e177d Compare March 7, 2025 18:06
@CRobeck CRobeck merged commit 50cfdf5 into main Mar 7, 2025
11 checks passed
@CRobeck CRobeck deleted the users/crobeck/gfx-950-f16 branch March 7, 2025 19:00
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Mar 7, 2025
Pulls in llvm/llvm-project#130300 for gfx950
type conversion ops.
ThomasRaoux pushed a commit to triton-lang/triton that referenced this pull request Mar 10, 2025
Pulls in llvm/llvm-project#130300 for gfx950
type conversion ops.

This update includes changes that store the parsed Triple in the module
and deprecate `match` and `rewrite` functions.
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
…GFX950 (llvm#130300)

Add Rocdl support for the following GFX950 instructions:

CVT_SCALE_PK_FP8_F16
CVT_SCALE_PK_BF8_F16 
CVT_SCALE_PK_FP8_BF16 
CVT_SCALE_PK_BF8_BF16 
CVT_SCALE_SR_FP8_F16 
CVT_SCALE_SR_BF8_F16 
CVT_SCALE_SR_FP8_BF16 
CVT_SCALE_SR_BF8_BF16 
CVT_SCALE_PK_F16_FP8 
CVT_SCALE_PK_F16_BF8 
CVT_SCALE_F16_FP8 
CVT_SCALE_F16_BF8
makslevental pushed a commit to makslevental/triton that referenced this pull request Mar 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants