Skip to content

[MLIR][AMDGPU] Add support for fp8 ops on gfx12 #106388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 3, 2024

Conversation

giuseros
Copy link
Contributor

This PR is adding support for fp8 and bfp8 on gfx12

@llvmbot
Copy link
Member

llvmbot commented Aug 28, 2024

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-backend-amdgpu

Author: Giuseppe Rossini (giuseros)

Changes

This PR is adding support for fp8 and bfp8 on gfx12


Full diff: https://github.com/llvm/llvm-project/pull/106388.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+5-2)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+22-15)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+3-1)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+9)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..35789984c92212 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -504,7 +504,7 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..bbb6e666d82956 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
   let summary = "Vote across thread group";
 
   let description = [{
-      Ballot provides a bit mask containing the 1-bit predicate value from each lane. 
+      Ballot provides a bit mask containing the 1-bit predicate value from each lane.
       The nth bit of the result contains the 1 bit contributed by the nth warp lane.
   }];
 
@@ -328,13 +328,16 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
     "$args attr-dict `:` functional-type($args, $res)";
 }
 
-// Available on RDNA3
+// Available from gfx11
 def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
 def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
 def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
 def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
 def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
 def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+// Available from gfx12
+def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
 
 //===---------------------------------------------------------------------===//
 // Operations on raw buffer resources (stride of 0, bounds checks either off or in
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b808738804030f..45c5070333b527 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  bool isUnsigned, Value llvmInput,
+                                 Value mlirInput,
                                  SmallVector<Value, 4> &operands) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,25 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     return;
   }
 
+  auto mlirInputType = dyn_cast<VectorType>(mlirInput.getType());
+  if (mlirInputType.getElementType().isInteger(8)) {
+    // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
+    bool localIsUnsigned = isUnsigned;
+    if (elemType.isUnsignedInteger(8)) {
+      localIsUnsigned = true;
+    } else if (elemType.isSignedInteger(8)) {
+      localIsUnsigned = false;
+    }
+    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+    operands.push_back(sign);
+  }
+
   int64_t numBytes = vectorType.getNumElements();
   Type i32 = rewriter.getI32Type();
   VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
   auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
-
   Value result = rewriter.createOrFold<LLVM::BitcastOp>(
       loc, llvmVectorType32bits, llvmInput);
-
-  // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
-  bool localIsUnsigned = isUnsigned;
-  if (elemType.isUnsignedInteger(8)) {
-    localIsUnsigned = true;
-  } else if (elemType.isSignedInteger(8)) {
-    localIsUnsigned = false;
-  }
-  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
-  operands.push_back(sign);
   operands.push_back(result);
 }
 
@@ -601,6 +604,10 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+  } else if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
+  } else if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
   }
   return std::nullopt;
 }
@@ -662,8 +669,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     Location loc = op.getLoc();
     Type outType = typeConverter->convertType(op.getDestD().getType());
 
-    if (chipset.majorVersion != 11)
-      return op->emitOpError("WMMA only supported on gfx11");
+    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
+      return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
@@ -675,9 +682,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
 
     SmallVector<Value, 4> operands;
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
-                         adaptor.getSourceA(), operands);
+                         adaptor.getSourceA(), op.getSourceA(), operands);
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
-                         adaptor.getSourceB(), operands);
+                         adaptor.getSourceB(), op.getSourceB(), operands);
     wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                           op.getSubwordOffset(), op.getClamp(), operands);
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..a8d6ccdc1a471e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -235,7 +235,9 @@ LogicalResult WMMAOp::verify() {
 
   bool isDestFloat =
       (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
-  bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+  bool isSrcFloat =
+      (sourceAElemType.isF16() || sourceAElemType.isBF16() ||
+       sourceAElemType.isFloat8E4M3FN() || sourceAElemType.isFloat8E5M2());
 
   if (isDestFloat && !isSrcFloat) {
     return emitOpError("Expected float sources with float destination");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
new file mode 100644
index 00000000000000..7b2b524d4af426
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>,  %arg2 : vector<8xf32>) {
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+  func.return
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..79f5c133503d44 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -363,6 +363,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
   llvm.return %rsrc : !llvm.ptr<8>
 }
 
+llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+  %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+  %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+  llvm.return %r0 : vector<8 x f32>
+}
+
 llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>,
                         %offset : i32, %soffset : i32,
                         %vdata1 : i32,

@llvmbot
Copy link
Member

llvmbot commented Aug 28, 2024

@llvm/pr-subscribers-mlir-amdgpu

Author: Giuseppe Rossini (giuseros)

Changes

This PR is adding support for fp8 and bfp8 on gfx12


Full diff: https://github.com/llvm/llvm-project/pull/106388.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+5-2)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+22-15)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+3-1)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+9)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..35789984c92212 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -504,7 +504,7 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..bbb6e666d82956 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
   let summary = "Vote across thread group";
 
   let description = [{
-      Ballot provides a bit mask containing the 1-bit predicate value from each lane. 
+      Ballot provides a bit mask containing the 1-bit predicate value from each lane.
       The nth bit of the result contains the 1 bit contributed by the nth warp lane.
   }];
 
@@ -328,13 +328,16 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
     "$args attr-dict `:` functional-type($args, $res)";
 }
 
-// Available on RDNA3
+// Available from gfx11
 def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
 def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
 def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
 def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
 def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
 def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+// Available from gfx12
+def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
 
 //===---------------------------------------------------------------------===//
 // Operations on raw buffer resources (stride of 0, bounds checks either off or in
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b808738804030f..45c5070333b527 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  bool isUnsigned, Value llvmInput,
+                                 Value mlirInput,
                                  SmallVector<Value, 4> &operands) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,25 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     return;
   }
 
+  auto mlirInputType = dyn_cast<VectorType>(mlirInput.getType());
+  if (mlirInputType.getElementType().isInteger(8)) {
+    // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
+    bool localIsUnsigned = isUnsigned;
+    if (elemType.isUnsignedInteger(8)) {
+      localIsUnsigned = true;
+    } else if (elemType.isSignedInteger(8)) {
+      localIsUnsigned = false;
+    }
+    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+    operands.push_back(sign);
+  }
+
   int64_t numBytes = vectorType.getNumElements();
   Type i32 = rewriter.getI32Type();
   VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
   auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
-
   Value result = rewriter.createOrFold<LLVM::BitcastOp>(
       loc, llvmVectorType32bits, llvmInput);
-
-  // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
-  bool localIsUnsigned = isUnsigned;
-  if (elemType.isUnsignedInteger(8)) {
-    localIsUnsigned = true;
-  } else if (elemType.isSignedInteger(8)) {
-    localIsUnsigned = false;
-  }
-  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
-  operands.push_back(sign);
   operands.push_back(result);
 }
 
@@ -601,6 +604,10 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+  } else if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
+  } else if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
   }
   return std::nullopt;
 }
@@ -662,8 +669,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     Location loc = op.getLoc();
     Type outType = typeConverter->convertType(op.getDestD().getType());
 
-    if (chipset.majorVersion != 11)
-      return op->emitOpError("WMMA only supported on gfx11");
+    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
+      return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
@@ -675,9 +682,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
 
     SmallVector<Value, 4> operands;
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
-                         adaptor.getSourceA(), operands);
+                         adaptor.getSourceA(), op.getSourceA(), operands);
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
-                         adaptor.getSourceB(), operands);
+                         adaptor.getSourceB(), op.getSourceB(), operands);
     wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                           op.getSubwordOffset(), op.getClamp(), operands);
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..a8d6ccdc1a471e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -235,7 +235,9 @@ LogicalResult WMMAOp::verify() {
 
   bool isDestFloat =
       (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
-  bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+  bool isSrcFloat =
+      (sourceAElemType.isF16() || sourceAElemType.isBF16() ||
+       sourceAElemType.isFloat8E4M3FN() || sourceAElemType.isFloat8E5M2());
 
   if (isDestFloat && !isSrcFloat) {
     return emitOpError("Expected float sources with float destination");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
new file mode 100644
index 00000000000000..7b2b524d4af426
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>,  %arg2 : vector<8xf32>) {
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+  func.return
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..79f5c133503d44 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -363,6 +363,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
   llvm.return %rsrc : !llvm.ptr<8>
 }
 
+llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+  %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+  %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+  llvm.return %r0 : vector<8 x f32>
+}
+
 llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>,
                         %offset : i32, %soffset : i32,
                         %vdata1 : i32,

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % one suggestion

@giuseros giuseros merged commit a8e1c6f into llvm:main Sep 3, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants