-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][AMDGPU] Add support for fp8 ops on gfx12 #106388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-backend-amdgpu Author: Giuseppe Rossini (giuseros) Changes: This PR is adding support for fp8 and bfp8 on gfx12. Full diff: https://github.com/llvm/llvm-project/pull/106388.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..35789984c92212 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -504,7 +504,7 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
// wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..bbb6e666d82956 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";
let description = [{
- Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+ Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];
@@ -328,13 +328,16 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
"$args attr-dict `:` functional-type($args, $res)";
}
-// Available on RDNA3
+// Available from gfx11
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+// Available from gfx12
+def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b808738804030f..45c5070333b527 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
bool isUnsigned, Value llvmInput,
+ Value mlirInput,
SmallVector<Value, 4> &operands) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,25 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
return;
}
+ auto mlirInputType = dyn_cast<VectorType>(mlirInput.getType());
+ if (mlirInputType.getElementType().isInteger(8)) {
+ // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
+ bool localIsUnsigned = isUnsigned;
+ if (elemType.isUnsignedInteger(8)) {
+ localIsUnsigned = true;
+ } else if (elemType.isSignedInteger(8)) {
+ localIsUnsigned = false;
+ }
+ Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+ operands.push_back(sign);
+ }
+
int64_t numBytes = vectorType.getNumElements();
Type i32 = rewriter.getI32Type();
VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
-
Value result = rewriter.createOrFold<LLVM::BitcastOp>(
loc, llvmVectorType32bits, llvmInput);
-
- // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
- bool localIsUnsigned = isUnsigned;
- if (elemType.isUnsignedInteger(8)) {
- localIsUnsigned = true;
- } else if (elemType.isSignedInteger(8)) {
- localIsUnsigned = false;
- }
- Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
- operands.push_back(sign);
operands.push_back(result);
}
@@ -601,6 +604,10 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
} else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+ } else if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) {
+ return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
+ } else if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) {
+ return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
}
return std::nullopt;
}
@@ -662,8 +669,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Location loc = op.getLoc();
Type outType = typeConverter->convertType(op.getDestD().getType());
- if (chipset.majorVersion != 11)
- return op->emitOpError("WMMA only supported on gfx11");
+ if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
+ return op->emitOpError("WMMA only supported on gfx11 and gfx12");
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
@@ -675,9 +682,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
SmallVector<Value, 4> operands;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
- adaptor.getSourceA(), operands);
+ adaptor.getSourceA(), op.getSourceA(), operands);
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
- adaptor.getSourceB(), operands);
+ adaptor.getSourceB(), op.getSourceB(), operands);
wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
op.getSubwordOffset(), op.getClamp(), operands);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..a8d6ccdc1a471e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -235,7 +235,9 @@ LogicalResult WMMAOp::verify() {
bool isDestFloat =
(destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
- bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+ bool isSrcFloat =
+ (sourceAElemType.isF16() || sourceAElemType.isBF16() ||
+ sourceAElemType.isFloat8E4M3FN() || sourceAElemType.isFloat8E5M2());
if (isDestFloat && !isSrcFloat) {
return emitOpError("Expected float sources with float destination");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
new file mode 100644
index 00000000000000..7b2b524d4af426
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) {
+ // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+ func.return
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..79f5c133503d44 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -363,6 +363,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
llvm.return %rsrc : !llvm.ptr<8>
}
+llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+ %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+ %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+ llvm.return %r0 : vector<8 x f32>
+}
+
llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>,
%offset : i32, %soffset : i32,
%vdata1 : i32,
|
@llvm/pr-subscribers-mlir-amdgpu Author: Giuseppe Rossini (giuseros) Changes: This PR is adding support for fp8 and bfp8 on gfx12. Full diff: https://github.com/llvm/llvm-project/pull/106388.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..35789984c92212 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -504,7 +504,7 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
// wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..bbb6e666d82956 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";
let description = [{
- Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+ Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];
@@ -328,13 +328,16 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
"$args attr-dict `:` functional-type($args, $res)";
}
-// Available on RDNA3
+// Available from gfx11
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+// Available from gfx12
+def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b808738804030f..45c5070333b527 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
bool isUnsigned, Value llvmInput,
+ Value mlirInput,
SmallVector<Value, 4> &operands) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,25 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
return;
}
+ auto mlirInputType = dyn_cast<VectorType>(mlirInput.getType());
+ if (mlirInputType.getElementType().isInteger(8)) {
+ // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
+ bool localIsUnsigned = isUnsigned;
+ if (elemType.isUnsignedInteger(8)) {
+ localIsUnsigned = true;
+ } else if (elemType.isSignedInteger(8)) {
+ localIsUnsigned = false;
+ }
+ Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+ operands.push_back(sign);
+ }
+
int64_t numBytes = vectorType.getNumElements();
Type i32 = rewriter.getI32Type();
VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
-
Value result = rewriter.createOrFold<LLVM::BitcastOp>(
loc, llvmVectorType32bits, llvmInput);
-
- // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
- bool localIsUnsigned = isUnsigned;
- if (elemType.isUnsignedInteger(8)) {
- localIsUnsigned = true;
- } else if (elemType.isSignedInteger(8)) {
- localIsUnsigned = false;
- }
- Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
- operands.push_back(sign);
operands.push_back(result);
}
@@ -601,6 +604,10 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
} else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+ } else if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) {
+ return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
+ } else if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) {
+ return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
}
return std::nullopt;
}
@@ -662,8 +669,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Location loc = op.getLoc();
Type outType = typeConverter->convertType(op.getDestD().getType());
- if (chipset.majorVersion != 11)
- return op->emitOpError("WMMA only supported on gfx11");
+ if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
+ return op->emitOpError("WMMA only supported on gfx11 and gfx12");
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
@@ -675,9 +682,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
SmallVector<Value, 4> operands;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
- adaptor.getSourceA(), operands);
+ adaptor.getSourceA(), op.getSourceA(), operands);
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
- adaptor.getSourceB(), operands);
+ adaptor.getSourceB(), op.getSourceB(), operands);
wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
op.getSubwordOffset(), op.getClamp(), operands);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..a8d6ccdc1a471e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -235,7 +235,9 @@ LogicalResult WMMAOp::verify() {
bool isDestFloat =
(destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
- bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+ bool isSrcFloat =
+ (sourceAElemType.isF16() || sourceAElemType.isBF16() ||
+ sourceAElemType.isFloat8E4M3FN() || sourceAElemType.isFloat8E5M2());
if (isDestFloat && !isSrcFloat) {
return emitOpError("Expected float sources with float destination");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
new file mode 100644
index 00000000000000..7b2b524d4af426
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) {
+ // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+ func.return
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..79f5c133503d44 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -363,6 +363,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
llvm.return %rsrc : !llvm.ptr<8>
}
+llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+ %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+ %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+ llvm.return %r0 : vector<8 x f32>
+}
+
llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>,
%offset : i32, %soffset : i32,
%vdata1 : i32,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % one suggestion
This PR is adding support for fp8 and bfp8 on gfx12.