[MLIR][ROCDL] Add Scale Convert Packed FP4 <-> F32/BF16/F16 Support for GFX950 #140676

Closed
wants to merge 2 commits into from

Conversation

tgymnich
Member

@tgymnich tgymnich commented May 20, 2025

ROCDL support for the following GFX950 instructions:

  • rocdl.cvt.scalef32.pk.fp4.f32
  • rocdl.cvt.scalef32.pk.fp4.f16
  • rocdl.cvt.scalef32.pk.fp4.bf16
  • rocdl.cvt.scalef32.sr.pk.fp4.f32
  • rocdl.cvt.scalef32.sr.pk.fp4.f16
  • rocdl.cvt.scalef32.sr.pk.fp4.bf16
  • rocdl.cvt.scalef32.pk.f32.fp4
  • rocdl.cvt.scalef32.pk.f16.fp4
  • rocdl.cvt.scalef32.pk.bf16.fp4
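
The `old` and `byteSel` operands on the fp4-producing ops can be read as a byte insert: the op descriptions in the diff say the packed result is stored into the `byteSel`th byte of `old`, preserving the others. Below is a minimal Python sketch of just that byte-select step. It treats each fp4 value as an opaque 4-bit code and does not model the conversion, rounding, or scaling. The function names are hypothetical, and the nibble order (first element in the low nibble) is an assumption, not something this PR specifies.

```python
from typing import Tuple


def insert_fp4_pair(old: int, lo_code: int, hi_code: int, byte_sel: int) -> int:
    """Return `old` with byte `byte_sel` replaced by two packed 4-bit codes.

    Assumption: the first code goes in the low nibble of the selected byte.
    """
    if not 0 <= byte_sel <= 3:
        raise ValueError("byte_sel must be in 0..3 for a 32-bit word")
    if not (0 <= lo_code <= 0xF and 0 <= hi_code <= 0xF):
        raise ValueError("fp4 codes are 4 bits wide")
    shift = 8 * byte_sel
    packed = (hi_code << 4) | lo_code
    # Clear the selected byte, keep the other three untouched.
    cleared = old & ~(0xFF << shift) & 0xFFFFFFFF
    return cleared | (packed << shift)


def extract_fp4_pair(src: int, byte_sel: int) -> Tuple[int, int]:
    """Return the (low, high) 4-bit codes held in byte `byte_sel` of `src`."""
    if not 0 <= byte_sel <= 3:
        raise ValueError("byte_sel must be in 0..3 for a 32-bit word")
    byte = (src >> (8 * byte_sel)) & 0xFF
    return byte & 0xF, byte >> 4
```

Under these assumptions, chaining four inserts with `byte_sel` 0 through 3 would fill one 32-bit word with eight fp4 codes, which is the pattern the `-> %old[%byteSel]` assembly syntax suggests.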

@tgymnich tgymnich requested review from CRobeck and krzysz00 May 20, 2025 05:04
@tgymnich tgymnich marked this pull request as ready for review May 20, 2025 05:04
@llvmbot
Member

llvmbot commented May 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Tim Gymnich (tgymnich)

Changes
  • Add Scale Convert Packed FP4 to rocdl

Full diff: https://github.com/llvm/llvm-project/pull/140676.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+118)
  • (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+25)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+25)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6fb9e3aba1f0a..f52c2c391fbba 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -713,6 +713,10 @@ def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
                         BuildableType<"::mlir::VectorType::get("
                           "{2},$_builder.getI16Type())">;
 
+def ROCDL_V2F32Type : FixedVectorOfLengthAndType<[2], [F32]>,
+                        BuildableType<"::mlir::VectorType::get("
+                          "{2},$_builder.getF32Type())">;
+
 def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
                         BuildableType<"::mlir::VectorType::get("
                           "{2},$_builder.getF16Type())">;
@@ -1005,6 +1009,120 @@ def ROCDL_CvtScaleF32SrBf8F32Op :
   }];
 }
 
+//===---------------------------------------------------------------------===//
+// 4-bit float scale intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtScaleF32PkFp4F32Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F32:$src0, F32:$src1, F32: $scale, I32:$byteSel)> {
+    let summary = "Convert f32 to packed fp4 and scale";
+    let description = [{ Convert `src0` and `src1` based on $byteSel to packed fp4, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src0 `,` $src1 `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkFp4F16Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, F32: $scale, I32:$byteSel)> {
+    let summary = "Convert f16 to packed fp4 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed fp4, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkFp4Bf16Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, F32: $scale, I32:$byteSel)> {
+    let summary = "Convert bf16 to packed fp4 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed fp4, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4F32Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F32Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert f32 to packed fp4 using stochastic rounding";
+    let description = [{
+       Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+       using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
+    }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4F16Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert f16 to packed fp4 using stochastic rounding";
+    let description = [{
+       Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+       using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
+    }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4Bf16Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert bf16 to packed fp4 using stochastic rounding";
+    let description = [{
+       Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+       using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
+    }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkF32Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+    let summary = "Convert fp4 to packed f32 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed f32, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+
+def ROCDL_CvtScaleF32PkF16Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+    let summary = "Convert fp4 to packed f16 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed f16, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkBf16Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+    let summary = "Convert fp4 to packed bf16 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed bf16, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
 //===---------------------------------------------------------------------===//
 // 8-bit float intrinsics
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index fbde993891342..84fa29ee2d8a1 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -844,6 +844,31 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
   llvm.return %source_scaled : vector<2xi16>
 }
 
+llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
+  // CHECK-LABEL: @rocdl_4bit_packed_floats
+  // CHECK: rocdl.cvt.scalef32.pk.fp4.f32
+  // CHECK: rocdl.cvt.scalef32.pk.fp4.f16
+  // CHECK: rocdl.cvt.scalef32.pk.fp4.bf16
+  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f32
+  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f16 
+  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.bf16
+  // CHECK: rocdl.cvt.scalef32.pk.f32.fp4
+  // CHECK: rocdl.cvt.scalef32.pk.f16.fp4
+  // CHECK: rocdl.cvt.scalef32.pk.bf16.fp4
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  %scale = llvm.mlir.constant(1.0 : f32) : f32
+  %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
+  %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
+  %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
+  %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0]  : i32
+  %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
+  %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+  %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
+  %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
+  %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
+  llvm.return %sr3 : i32
+}
+
 llvm.func @rocdl.s.waitcnt() {
   // CHECK-LABEL: rocdl.s.waitcnt
   // CHECK: rocdl.s.waitcnt 0
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index b37f0da361950..16efce4a80908 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1145,6 +1145,31 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
   llvm.return %source : vector<2xf16>
 }
 
+llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
+  // CHECK-LABEL: @rocdl_4bit_packed_floats
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %0, float %1, float %2, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %8, <2 x half> %4, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %9, <2 x bfloat> %5, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %10, <2 x float> %3, i32 %6, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %11, <2 x half> %4, i32 %6, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %12, <2 x bfloat> %5, i32 %6, float 1.000000e+00, i32 0)
+  // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %0, float 1.000000e+00, i32 0)
+  // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %0, float 1.000000e+00, i32 0)
+  // CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %0, float 1.000000e+00, i32 0)
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  %scale = llvm.mlir.constant(1.0 : f32) : f32
+  %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
+  %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
+  %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
+  %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0]  : i32
+  %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
+  %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+  %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
+  %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
+  %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
+  llvm.return %sr3 : i32
+}
+
 llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
   // CHECK-LABEL: @rocdl_atomic_attrs
   // CHECK: atomicrmw

@CRobeck
Member

CRobeck commented May 20, 2025

These will throw an unsupported-arch error at the LLVM IR level if compiled for arch < GFX950, but is it worth having some MLIR op verifiers for the arch in the ROCDL op definition? (I realize that none of the existing ops have arch-based verifiers.)

@krzysz00
Contributor

... I'll have a competing PR up in a few hours that also does a bit of a refactor of these intrinsics / has immarg handled correctly

@krzysz00
Contributor

@CRobeck Such verification happens, if applicable, while rewriting higher-level operations to intrinsics. Also, since MLIR doesn't have a coherent way to check which architecture a piece of code is targeting, there's nowhere to put the verifier you're interested in.

Also, competing PR at #140801
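
To illustrate the point about checking the architecture during lowering rather than in an op verifier, here is a rough Python sketch. It is not MLIR code and does not reflect the actual conversion pass. The function name, the `SUPPORTED_CHIPSETS` set, and the string-based chipset argument are all hypothetical. Only the op names are taken from this PR.

```python
# Hypothetical: chipsets on which the scaled fp4 conversions are assumed available.
SUPPORTED_CHIPSETS = {"gfx950"}

# Source element types covered by the packing ops listed in this PR.
SOURCE_TYPES = {"f32", "f16", "bf16"}


def lower_fp4_pack(chipset: str, src_type: str) -> str:
    """Pick the ROCDL op name for a packed fp4 conversion, or fail the lowering.

    The target is known here because the lowering is configured with it,
    which is what makes this a workable place for the architecture check.
    """
    if chipset not in SUPPORTED_CHIPSETS:
        raise ValueError(f"scaled fp4 conversion is not available on {chipset}")
    if src_type not in SOURCE_TYPES:
        raise ValueError(f"unsupported source type: {src_type}")
    return f"rocdl.cvt.scalef32.pk.fp4.{src_type}"
```

In this sketch a caller targeting an unsupported chipset gets an error at rewrite time, so the intrinsic-level op itself never needs to know which architecture it is being compiled for.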

@tgymnich
Member Author

tgymnich commented May 20, 2025

closing in favor of #140801

@tgymnich tgymnich closed this May 20, 2025
@tgymnich tgymnich deleted the tim/rocdl-fp4 branch May 21, 2025 05:03