Skip to content

[mlir][AMDGPU] Add gfx950 MFMAs to the amdgpu.mfma op #133553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 1, 2025

Conversation

krzysz00
Copy link
Contributor

This commit extends the lowering of amdgpu.mfma to handle the new double-rate MFMAs in gfx950 and adds tests for these operations.

It also adds support for MFMAs on small floats (f6 and f4), which are implented using the "scaled" MFMA intrinsic with a scale value of 0 in order to have an unscaled MFMA.

This commit does not add a amdgpu.scaled_mfma operation, as that is future work.

This commit extends the lowering of amdgpu.mfma to handle the new
double-rate MFMAs in gfx950 and adds tests for these operations.

It also adds support for MFMAs on small floats (f6 and f4), which are
implented using the "scaled" MFMA intrinsic with a scale value of 0 in
order to have an unscaled MFMA.

This commit does not add a `amdgpu.scaled_mfma` operation, as that is
future work.
@llvmbot
Copy link
Member

llvmbot commented Mar 29, 2025

@llvm/pr-subscribers-mlir-amdgpu

@llvm/pr-subscribers-mlir-gpu

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit extends the lowering of amdgpu.mfma to handle the new double-rate MFMAs in gfx950 and adds tests for these operations.

It also adds support for MFMAs on small floats (f6 and f4), which are implented using the "scaled" MFMA intrinsic with a scale value of 0 in order to have an unscaled MFMA.

This commit does not add a amdgpu.scaled_mfma operation, as that is future work.


Patch is 23.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133553.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+6-4)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+135-29)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+8-6)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+53)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c0b3e5540b1df..9cdd961d96ff5 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -650,10 +650,12 @@ def AMDGPU_MFMAPermBAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_MFMAPermB,
 // mfma
 def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
                              VectorOfLengthAndType<[2], [F32]>,
-                             VectorOfLengthAndType<[4], [F16]>,
-                             VectorOfLengthAndType<[2, 4], [BF16]>,
-                             VectorOfLengthAndType<[4, 8], [I8]>,
-                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
+                             VectorOfLengthAndType<[4, 8], [F16]>,
+                             VectorOfLengthAndType<[2, 4, 8], [BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8]>,
+                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>,
+                             VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>,
+                             VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470cff7f5..77823fd2c52bf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -22,6 +22,7 @@
 #include "../LLVMCommon/MemRefDescriptor.h"
 
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <optional>
 
 namespace mlir {
@@ -36,6 +37,7 @@ using namespace mlir::amdgpu;
 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
 constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
 
 /// Convert an unsigned number `val` to i32.
 static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -494,8 +496,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
-/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
-/// 2. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
+/// instead, which is what the f8f6f4 intrinsics use.
+/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
+/// integer.
 static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                       Location loc, Value input) {
   Type inputType = input.getType();
@@ -503,10 +508,19 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-    if (vectorType.getElementType().isInteger(8)) {
+    if (vectorType.getElementType().isInteger(8) &&
+        vectorType.getNumElements() <= 8)
       return rewriter.create<LLVM::BitcastOp>(
           loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
-    }
+    if (isa<IntegerType>(vectorType.getElementType()) &&
+        vectorType.getElementTypeBitWidth() <= 8)
+      return rewriter.create<LLVM::BitcastOp>(
+          loc,
+          VectorType::get((vectorType.getNumElements() *
+                           vectorType.getElementTypeBitWidth()) /
+                              32,
+                          rewriter.getI32Type()),
+          input);
   }
   return input;
 }
@@ -622,12 +636,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                   Chipset chipset) {
   uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
            b = mfma.getBlocks();
-  Type sourceElem = mfma.getSourceA().getType();
-  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
-    sourceElem = sourceType.getElementType();
-  Type destElem = mfma.getDestC().getType();
-  if (auto destType = dyn_cast<VectorType>(destElem))
-    destElem = destType.getElementType();
+  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
+  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
 
   if (sourceElem.isF32() && destElem.isF32()) {
     if (mfma.getReducePrecision() && chipset >= kGfx942) {
@@ -649,6 +659,12 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   }
 
   if (sourceElem.isF16() && destElem.isF32()) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
+      if (m == 16 && n == 16 && k == 32 && b == 1)
+        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_f32_32x32x4f16::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -661,20 +677,25 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f32_16x16x16f16::getOperationName();
   }
 
-  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
-    if (m == 32 && n == 32 && k == 4 && b == 2)
-      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
-    if (m == 16 && n == 16 && k == 4 && b == 4)
-      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
-    if (m == 4 && n == 4 && k == 4 && b == 16)
-      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
-    if (m == 32 && n == 32 && k == 8 && b == 1)
-      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
-    if (m == 16 && n == 16 && k == 16 && b == 1)
-      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
-  }
-
   if (sourceElem.isBF16() && destElem.isF32()) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
+      if (m == 16 && n == 16 && k == 32 && b == 1)
+        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
+    }
+    if (chipset >= kGfx90a) {
+      if (m == 32 && n == 32 && k == 4 && b == 2)
+        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
+      if (m == 16 && n == 16 && k == 4 && b == 4)
+        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
+      if (m == 4 && n == 4 && k == 4 && b == 16)
+        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
+      if (m == 32 && n == 32 && k == 8 && b == 1)
+        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
+      if (m == 16 && n == 16 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 2 && b == 2)
       return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
     if (m == 16 && n == 16 && k == 2 && b == 4)
@@ -687,7 +708,14 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
   }
 
-  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
+  if (isa<IntegerType>(sourceElem) && sourceElem.getIntOrFloatBitWidth() >= 8 &&
+      destElem.isInteger(32)) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 32 && b == 1)
+        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
+      if (m == 16 && n == 16 && k == 64 && b == 1)
+        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_i32_32x32x4i8::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -750,6 +778,59 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   return std::nullopt;
 }
 
+static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
+      .Case([](Float8E4M3FNType) { return 0u; })
+      .Case([](Float8E5M2Type) { return 1u; })
+      .Case([](Float6E2M3FNType) { return 2u; })
+      .Case([](Float6E3M2FNType) { return 3u; })
+      .Case([](Float4E2M1FNType) { return 4u; })
+      .Default([](Type) { return std::nullopt; });
+}
+
+/// If there is a scaled MFMA intsruction for the input element types `aType`
+/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
+/// blocks) on the given `chipset`, return a tuple consisting of the
+/// OperationName of the intrinsic and the type codes that need to be passed to
+/// that intrinsic. Note that this is also used to implement some un-scaled
+/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
+/// MFMA with a scale of 0.
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
+                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
+  aType = getElementTypeOrSelf(aType);
+  bType = getElementTypeOrSelf(bType);
+  destType = getElementTypeOrSelf(destType);
+
+  if (chipset < kGfx950)
+    return std::nullopt;
+  if (!isa<Float32Type>(destType))
+    return std::nullopt;
+
+  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
+  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+  if (!aTypeCode || !bTypeCode)
+    return std::nullopt;
+
+  if (m == 32 && n == 32 && k == 64 && b == 1)
+    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
+                      *aTypeCode, *bTypeCode};
+  if (m == 16 && n == 16 && k == 128 && b == 1)
+    return std::tuple{
+        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
+        *bTypeCode};
+
+  return std::nullopt;
+}
+
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
+  return mfmaOpToScaledIntrinsic(
+      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
+      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
+      mfma.getBlocks(), chipset);
+}
+
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -829,16 +910,41 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
           op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
     }
     std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
-    if (!maybeIntrinsic.has_value())
+    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching MFMA size on given chipset");
-    OperationState loweredOp(loc, *maybeIntrinsic);
+
+    bool isScaled =
+        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
+    if (isScaled &&
+        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
+      return op.emitOpError(
+          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
+          "be scaled as those fields are used for type information");
+    }
+
+    StringRef intrinsicName =
+        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+    OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
         {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
          convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
-         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
-         createI32Constant(rewriter, loc, op.getAbid()),
-         createI32Constant(rewriter, loc, getBlgpField)});
+         adaptor.getDestC()});
+    if (isScaled) {
+      Value zero = createI32Constant(rewriter, loc, 0);
+      auto [scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+      std::ignore = scaledName;
+      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
+                             createI32Constant(rewriter, loc, bTypeCode),
+                             /*scale A byte=*/zero, /*scale A=*/zero,
+                             /*scale B byte=*/zero, /*scale B=*/zero});
+    } else {
+      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
+                             createI32Constant(rewriter, loc, op.getAbid()),
+                             createI32Constant(rewriter, loc, getBlgpField)});
+    };
     Value lowered = rewriter.create(loweredOp)->getResult(0);
     if (outType != intrinsicOutType)
       lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c82..1e482515a4ee0 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -341,22 +341,24 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat(8)) {
+  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat(8))
-      return emitOpError("expected both source operands to have f8 elements");
+    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
+        !sourceBElem.isFloat(4))
+      return emitOpError("expected both source operands to have small-float "
+                         "elements if one does");
     if (sourceLen != sourceBLen)
       return emitOpError(
-          "expected both f8 source vectors to have the same length");
+          "expected both small-float source vectors to have the same length");
   } else {
     if (sourceType != sourceBType)
-      return emitOpError(
-          "expected both non-f8 source operand types to match exactly");
+      return emitOpError("expected both non-small-float source operand types "
+                         "to match exactly");
   }
   // Normalize the wider integer types the compiler expects to i8
   if (sourceElem.isInteger(32)) {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
new file mode 100644
index 0000000000000..de63f249bb530
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 -cse | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
+                    %arg2 : vector<4xf32>, %arg3 : vector<8xbf16>,
+                    %arg4 : vector<16xi8>, %arg5 : vector<16xi32>,
+                    %arg6 : vector<4xi32>, %arg7 : vector<32xf8E4M3FN>,
+                    %arg8 : vector<32xf8E5M2>, %arg9 : vector<32xf6E2M3FN>,
+                    %arg10 : vector<32xf6E3M2FN>, %arg11 : vector<32xf4E2M1FN>) {
+  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+  // CHECK: rocdl.mfma.f32.32x32x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
+  // CHECK: rocdl.mfma.i32.16x16x64.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.mfma %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 29, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit extends the lowering of amdgpu.mfma to handle the new double-rate MFMAs in gfx950 and adds tests for these operations.

It also adds support for MFMAs on small floats (f6 and f4), which are implented using the "scaled" MFMA intrinsic with a scale value of 0 in order to have an unscaled MFMA.

This commit does not add a amdgpu.scaled_mfma operation, as that is future work.


Patch is 23.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133553.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+6-4)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+135-29)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+8-6)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+53)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c0b3e5540b1df..9cdd961d96ff5 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -650,10 +650,12 @@ def AMDGPU_MFMAPermBAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_MFMAPermB,
 // mfma
 def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
                              VectorOfLengthAndType<[2], [F32]>,
-                             VectorOfLengthAndType<[4], [F16]>,
-                             VectorOfLengthAndType<[2, 4], [BF16]>,
-                             VectorOfLengthAndType<[4, 8], [I8]>,
-                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
+                             VectorOfLengthAndType<[4, 8], [F16]>,
+                             VectorOfLengthAndType<[2, 4, 8], [BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8]>,
+                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>,
+                             VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>,
+                             VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470cff7f5..77823fd2c52bf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -22,6 +22,7 @@
 #include "../LLVMCommon/MemRefDescriptor.h"
 
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <optional>
 
 namespace mlir {
@@ -36,6 +37,7 @@ using namespace mlir::amdgpu;
 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
 constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
 
 /// Convert an unsigned number `val` to i32.
 static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -494,8 +496,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
-/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
-/// 2. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
+/// instead, which is what the f8f6f4 intrinsics use.
+/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
+/// integer.
 static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                       Location loc, Value input) {
   Type inputType = input.getType();
@@ -503,10 +508,19 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-    if (vectorType.getElementType().isInteger(8)) {
+    if (vectorType.getElementType().isInteger(8) &&
+        vectorType.getNumElements() <= 8)
       return rewriter.create<LLVM::BitcastOp>(
           loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
-    }
+    if (isa<IntegerType>(vectorType.getElementType()) &&
+        vectorType.getElementTypeBitWidth() <= 8)
+      return rewriter.create<LLVM::BitcastOp>(
+          loc,
+          VectorType::get((vectorType.getNumElements() *
+                           vectorType.getElementTypeBitWidth()) /
+                              32,
+                          rewriter.getI32Type()),
+          input);
   }
   return input;
 }
@@ -622,12 +636,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                   Chipset chipset) {
   uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
            b = mfma.getBlocks();
-  Type sourceElem = mfma.getSourceA().getType();
-  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
-    sourceElem = sourceType.getElementType();
-  Type destElem = mfma.getDestC().getType();
-  if (auto destType = dyn_cast<VectorType>(destElem))
-    destElem = destType.getElementType();
+  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
+  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
 
   if (sourceElem.isF32() && destElem.isF32()) {
     if (mfma.getReducePrecision() && chipset >= kGfx942) {
@@ -649,6 +659,12 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   }
 
   if (sourceElem.isF16() && destElem.isF32()) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
+      if (m == 16 && n == 16 && k == 32 && b == 1)
+        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_f32_32x32x4f16::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -661,20 +677,25 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f32_16x16x16f16::getOperationName();
   }
 
-  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
-    if (m == 32 && n == 32 && k == 4 && b == 2)
-      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
-    if (m == 16 && n == 16 && k == 4 && b == 4)
-      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
-    if (m == 4 && n == 4 && k == 4 && b == 16)
-      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
-    if (m == 32 && n == 32 && k == 8 && b == 1)
-      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
-    if (m == 16 && n == 16 && k == 16 && b == 1)
-      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
-  }
-
   if (sourceElem.isBF16() && destElem.isF32()) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
+      if (m == 16 && n == 16 && k == 32 && b == 1)
+        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
+    }
+    if (chipset >= kGfx90a) {
+      if (m == 32 && n == 32 && k == 4 && b == 2)
+        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
+      if (m == 16 && n == 16 && k == 4 && b == 4)
+        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
+      if (m == 4 && n == 4 && k == 4 && b == 16)
+        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
+      if (m == 32 && n == 32 && k == 8 && b == 1)
+        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
+      if (m == 16 && n == 16 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 2 && b == 2)
       return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
     if (m == 16 && n == 16 && k == 2 && b == 4)
@@ -687,7 +708,14 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
   }
 
-  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
+  if (isa<IntegerType>(sourceElem) && sourceElem.getIntOrFloatBitWidth() >= 8 &&
+      destElem.isInteger(32)) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 32 && b == 1)
+        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
+      if (m == 16 && n == 16 && k == 64 && b == 1)
+        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_i32_32x32x4i8::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -750,6 +778,59 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   return std::nullopt;
 }
 
+static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
+      .Case([](Float8E4M3FNType) { return 0u; })
+      .Case([](Float8E5M2Type) { return 1u; })
+      .Case([](Float6E2M3FNType) { return 2u; })
+      .Case([](Float6E3M2FNType) { return 3u; })
+      .Case([](Float4E2M1FNType) { return 4u; })
+      .Default([](Type) { return std::nullopt; });
+}
+
+/// If there is a scaled MFMA intsruction for the input element types `aType`
+/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
+/// blocks) on the given `chipset`, return a tuple consisting of the
+/// OperationName of the intrinsic and the type codes that need to be passed to
+/// that intrinsic. Note that this is also used to implement some un-scaled
+/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
+/// MFMA with a scale of 0.
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
+                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
+  aType = getElementTypeOrSelf(aType);
+  bType = getElementTypeOrSelf(bType);
+  destType = getElementTypeOrSelf(destType);
+
+  if (chipset < kGfx950)
+    return std::nullopt;
+  if (!isa<Float32Type>(destType))
+    return std::nullopt;
+
+  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
+  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+  if (!aTypeCode || !bTypeCode)
+    return std::nullopt;
+
+  if (m == 32 && n == 32 && k == 64 && b == 1)
+    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
+                      *aTypeCode, *bTypeCode};
+  if (m == 16 && n == 16 && k == 128 && b == 1)
+    return std::tuple{
+        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
+        *bTypeCode};
+
+  return std::nullopt;
+}
+
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
+  return mfmaOpToScaledIntrinsic(
+      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
+      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
+      mfma.getBlocks(), chipset);
+}
+
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -829,16 +910,41 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
           op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
     }
     std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
-    if (!maybeIntrinsic.has_value())
+    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching MFMA size on given chipset");
-    OperationState loweredOp(loc, *maybeIntrinsic);
+
+    bool isScaled =
+        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
+    if (isScaled &&
+        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
+      return op.emitOpError(
+          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
+          "be scaled as those fields are used for type information");
+    }
+
+    StringRef intrinsicName =
+        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+    OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
         {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
          convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
-         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
-         createI32Constant(rewriter, loc, op.getAbid()),
-         createI32Constant(rewriter, loc, getBlgpField)});
+         adaptor.getDestC()});
+    if (isScaled) {
+      Value zero = createI32Constant(rewriter, loc, 0);
+      auto [scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+      std::ignore = scaledName;
+      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
+                             createI32Constant(rewriter, loc, bTypeCode),
+                             /*scale A byte=*/zero, /*scale A=*/zero,
+                             /*scale B byte=*/zero, /*scale B=*/zero});
+    } else {
+      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
+                             createI32Constant(rewriter, loc, op.getAbid()),
+                             createI32Constant(rewriter, loc, getBlgpField)});
+    };
     Value lowered = rewriter.create(loweredOp)->getResult(0);
     if (outType != intrinsicOutType)
       lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c82..1e482515a4ee0 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -341,22 +341,24 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat(8)) {
+  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat(8))
-      return emitOpError("expected both source operands to have f8 elements");
+    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
+        !sourceBElem.isFloat(4))
+      return emitOpError("expected both source operands to have small-float "
+                         "elements if one does");
     if (sourceLen != sourceBLen)
       return emitOpError(
-          "expected both f8 source vectors to have the same length");
+          "expected both small-float source vectors to have the same length");
   } else {
     if (sourceType != sourceBType)
-      return emitOpError(
-          "expected both non-f8 source operand types to match exactly");
+      return emitOpError("expected both non-small-float source operand types "
+                         "to match exactly");
   }
   // Normalize the wider integer types the compiler expects to i8
   if (sourceElem.isInteger(32)) {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
new file mode 100644
index 0000000000000..de63f249bb530
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 -cse | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
+                    %arg2 : vector<4xf32>, %arg3 : vector<8xbf16>,
+                    %arg4 : vector<16xi8>, %arg5 : vector<16xi32>,
+                    %arg6 : vector<4xi32>, %arg7 : vector<32xf8E4M3FN>,
+                    %arg8 : vector<32xf8E5M2>, %arg9 : vector<32xf6E2M3FN>,
+                    %arg10 : vector<32xf6E3M2FN>, %arg11 : vector<32xf4E2M1FN>) {
+  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+  // CHECK: rocdl.mfma.f32.32x32x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
+  // CHECK: rocdl.mfma.i32.16x16x64.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.mfma %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 29, 2025

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit extends the lowering of amdgpu.mfma to handle the new double-rate MFMAs in gfx950 and adds tests for these operations.

It also adds support for MFMAs on small floats (f6 and f4), which are implented using the "scaled" MFMA intrinsic with a scale value of 0 in order to have an unscaled MFMA.

This commit does not add a amdgpu.scaled_mfma operation, as that is future work.


Patch is 23.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133553.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+6-4)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+135-29)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+8-6)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+53)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c0b3e5540b1df..9cdd961d96ff5 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -650,10 +650,12 @@ def AMDGPU_MFMAPermBAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_MFMAPermB,
 // mfma
 def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
                              VectorOfLengthAndType<[2], [F32]>,
-                             VectorOfLengthAndType<[4], [F16]>,
-                             VectorOfLengthAndType<[2, 4], [BF16]>,
-                             VectorOfLengthAndType<[4, 8], [I8]>,
-                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
+                             VectorOfLengthAndType<[4, 8], [F16]>,
+                             VectorOfLengthAndType<[2, 4, 8], [BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8]>,
+                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>,
+                             VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>,
+                             VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470cff7f5..77823fd2c52bf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -22,6 +22,7 @@
 #include "../LLVMCommon/MemRefDescriptor.h"
 
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <optional>
 
 namespace mlir {
@@ -36,6 +37,7 @@ using namespace mlir::amdgpu;
 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
 constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
 
 /// Convert an unsigned number `val` to i32.
 static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -494,8 +496,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
-/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
-/// 2. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
+/// instead, which is what the f8f6f4 intrinsics use.
+/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
+/// integer.
 static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                       Location loc, Value input) {
   Type inputType = input.getType();
@@ -503,10 +508,19 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-    if (vectorType.getElementType().isInteger(8)) {
+    if (vectorType.getElementType().isInteger(8) &&
+        vectorType.getNumElements() <= 8)
       return rewriter.create<LLVM::BitcastOp>(
           loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
-    }
+    if (isa<IntegerType>(vectorType.getElementType()) &&
+        vectorType.getElementTypeBitWidth() <= 8)
+      return rewriter.create<LLVM::BitcastOp>(
+          loc,
+          VectorType::get((vectorType.getNumElements() *
+                           vectorType.getElementTypeBitWidth()) /
+                              32,
+                          rewriter.getI32Type()),
+          input);
   }
   return input;
 }
@@ -622,12 +636,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                   Chipset chipset) {
   uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
            b = mfma.getBlocks();
-  Type sourceElem = mfma.getSourceA().getType();
-  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
-    sourceElem = sourceType.getElementType();
-  Type destElem = mfma.getDestC().getType();
-  if (auto destType = dyn_cast<VectorType>(destElem))
-    destElem = destType.getElementType();
+  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
+  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
 
   if (sourceElem.isF32() && destElem.isF32()) {
     if (mfma.getReducePrecision() && chipset >= kGfx942) {
@@ -649,6 +659,12 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   }
 
   if (sourceElem.isF16() && destElem.isF32()) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
+      if (m == 16 && n == 16 && k == 32 && b == 1)
+        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_f32_32x32x4f16::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -661,20 +677,25 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f32_16x16x16f16::getOperationName();
   }
 
-  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
-    if (m == 32 && n == 32 && k == 4 && b == 2)
-      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
-    if (m == 16 && n == 16 && k == 4 && b == 4)
-      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
-    if (m == 4 && n == 4 && k == 4 && b == 16)
-      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
-    if (m == 32 && n == 32 && k == 8 && b == 1)
-      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
-    if (m == 16 && n == 16 && k == 16 && b == 1)
-      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
-  }
-
   if (sourceElem.isBF16() && destElem.isF32()) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
+      if (m == 16 && n == 16 && k == 32 && b == 1)
+        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
+    }
+    if (chipset >= kGfx90a) {
+      if (m == 32 && n == 32 && k == 4 && b == 2)
+        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
+      if (m == 16 && n == 16 && k == 4 && b == 4)
+        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
+      if (m == 4 && n == 4 && k == 4 && b == 16)
+        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
+      if (m == 32 && n == 32 && k == 8 && b == 1)
+        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
+      if (m == 16 && n == 16 && k == 16 && b == 1)
+        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 2 && b == 2)
       return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
     if (m == 16 && n == 16 && k == 2 && b == 4)
@@ -687,7 +708,14 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
   }
 
-  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
+  if (isa<IntegerType>(sourceElem) && sourceElem.getIntOrFloatBitWidth() >= 8 &&
+      destElem.isInteger(32)) {
+    if (chipset >= kGfx950) {
+      if (m == 32 && n == 32 && k == 32 && b == 1)
+        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
+      if (m == 16 && n == 16 && k == 64 && b == 1)
+        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
+    }
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_i32_32x32x4i8::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -750,6 +778,59 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   return std::nullopt;
 }
 
+static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
+      .Case([](Float8E4M3FNType) { return 0u; })
+      .Case([](Float8E5M2Type) { return 1u; })
+      .Case([](Float6E2M3FNType) { return 2u; })
+      .Case([](Float6E3M2FNType) { return 3u; })
+      .Case([](Float4E2M1FNType) { return 4u; })
+      .Default([](Type) { return std::nullopt; });
+}
+
+/// If there is a scaled MFMA intsruction for the input element types `aType`
+/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
+/// blocks) on the given `chipset`, return a tuple consisting of the
+/// OperationName of the intrinsic and the type codes that need to be passed to
+/// that intrinsic. Note that this is also used to implement some un-scaled
+/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
+/// MFMA with a scale of 0.
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
+                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
+  aType = getElementTypeOrSelf(aType);
+  bType = getElementTypeOrSelf(bType);
+  destType = getElementTypeOrSelf(destType);
+
+  if (chipset < kGfx950)
+    return std::nullopt;
+  if (!isa<Float32Type>(destType))
+    return std::nullopt;
+
+  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
+  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+  if (!aTypeCode || !bTypeCode)
+    return std::nullopt;
+
+  if (m == 32 && n == 32 && k == 64 && b == 1)
+    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
+                      *aTypeCode, *bTypeCode};
+  if (m == 16 && n == 16 && k == 128 && b == 1)
+    return std::tuple{
+        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
+        *bTypeCode};
+
+  return std::nullopt;
+}
+
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
+  return mfmaOpToScaledIntrinsic(
+      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
+      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
+      mfma.getBlocks(), chipset);
+}
+
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -829,16 +910,41 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
           op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
     }
     std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
-    if (!maybeIntrinsic.has_value())
+    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching MFMA size on given chipset");
-    OperationState loweredOp(loc, *maybeIntrinsic);
+
+    bool isScaled =
+        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
+    if (isScaled &&
+        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
+      return op.emitOpError(
+          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
+          "be scaled as those fields are used for type information");
+    }
+
+    StringRef intrinsicName =
+        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+    OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
         {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
          convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
-         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
-         createI32Constant(rewriter, loc, op.getAbid()),
-         createI32Constant(rewriter, loc, getBlgpField)});
+         adaptor.getDestC()});
+    if (isScaled) {
+      Value zero = createI32Constant(rewriter, loc, 0);
+      auto [scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+      std::ignore = scaledName;
+      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
+                             createI32Constant(rewriter, loc, bTypeCode),
+                             /*scale A byte=*/zero, /*scale A=*/zero,
+                             /*scale B byte=*/zero, /*scale B=*/zero});
+    } else {
+      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
+                             createI32Constant(rewriter, loc, op.getAbid()),
+                             createI32Constant(rewriter, loc, getBlgpField)});
+    };
     Value lowered = rewriter.create(loweredOp)->getResult(0);
     if (outType != intrinsicOutType)
       lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c82..1e482515a4ee0 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -341,22 +341,24 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat(8)) {
+  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat(8))
-      return emitOpError("expected both source operands to have f8 elements");
+    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
+        !sourceBElem.isFloat(4))
+      return emitOpError("expected both source operands to have small-float "
+                         "elements if one does");
     if (sourceLen != sourceBLen)
       return emitOpError(
-          "expected both f8 source vectors to have the same length");
+          "expected both small-float source vectors to have the same length");
   } else {
     if (sourceType != sourceBType)
-      return emitOpError(
-          "expected both non-f8 source operand types to match exactly");
+      return emitOpError("expected both non-small-float source operand types "
+                         "to match exactly");
   }
   // Normalize the wider integer types the compiler expects to i8
   if (sourceElem.isInteger(32)) {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
new file mode 100644
index 0000000000000..de63f249bb530
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 -cse | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
+                    %arg2 : vector<4xf32>, %arg3 : vector<8xbf16>,
+                    %arg4 : vector<16xi8>, %arg5 : vector<16xi32>,
+                    %arg6 : vector<4xi32>, %arg7 : vector<32xf8E4M3FN>,
+                    %arg8 : vector<32xf8E5M2>, %arg9 : vector<32xf6E2M3FN>,
+                    %arg10 : vector<32xf6E3M2FN>, %arg11 : vector<32xf4E2M1FN>) {
+  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+  // CHECK: rocdl.mfma.f32.32x32x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
+  // CHECK: rocdl.mfma.i32.16x16x64.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.mfma %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32...
[truncated]

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nits

@krzysz00 krzysz00 merged commit 25622aa into llvm:main Apr 1, 2025
11 checks passed
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request Apr 2, 2025
This commit extends the lowering of amdgpu.mfma to handle the new
double-rate MFMAs in gfx950 and adds tests for these operations.

It also adds support for MFMAs on small floats (f6 and f4), which are
implented using the "scaled" MFMA intrinsic with a scale value of 0 in
order to have an unscaled MFMA.

This commit does not add a `amdgpu.scaled_mfma` operation, as that is
future work.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants