[MLIR][AMDGPU] Add OCP FP8 support for new hardware #127728
Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-gpu

Author: Mirza Halilčević (mirza-halilcevic)

Changes

(Continuing from #106160)

This PR addresses remaining review comments from the original PR.

Original PR Description

Upcoming hardware (gfx12 and some future gfx9) will support the OCP 8-bit float formats for their matrix multiplication intrinsics and conversion operations, retaining existing opcodes and compiler builtins. This commit adds support for these types to the MLIR wrappers around such operations, ensuring that the OCP types aren't used to generate those builtins on hardware that doesn't expect that format and, conversely, that the pre-OCP formats aren't used on new hardware.

Patch is 35.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127728.diff

11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index cba35bbca1f83..484cea84f669b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
- Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
- VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+ Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+ VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
Arguments<(ins F32:$sourceA,
Optional<F32>:$sourceB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
- Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
- Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+ Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round two floats into a packed vector of 8-bit floats";
let description = [{
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
Arguments<(ins F32:$source,
I32:$stochiasticParam,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
- Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
- Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+ Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
let description = [{
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
VectorOfLengthAndType<[4], [F16]>,
VectorOfLengthAndType<[2, 4], [BF16]>,
VectorOfLengthAndType<[4, 8], [I8]>,
- VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+ VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index a5dab1ab89630..768b390ed5381 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -49,6 +49,14 @@ struct Chipset {
#undef DEFINE_COMP_OPERATOR
};
+inline bool isGfx940Series(const Chipset &chipset) {
+ return chipset.majorVersion == 9 && chipset.minorVersion == 4;
+}
+inline bool hasOcpFp8(const Chipset &chipset) {
+ return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
+ chipset.majorVersion >= 12;
+}
+
} // namespace mlir::amdgpu
#endif
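To see how the two predicates added to `Chipset.h` partition concrete targets, here is a small standalone sketch. It does not use MLIR: `Chipset` below is a simplified stand-in for `mlir::amdgpu::Chipset` that models only the two fields the predicates read, and `Fp8Flavor` / `fp8FlavorFor` are hypothetical names introduced for illustration, not part of the patch.

```cpp
#include <cassert>

// Simplified stand-in for mlir::amdgpu::Chipset (assumption: only the
// fields read by the two predicates are modeled).
struct Chipset {
  unsigned majorVersion;
  unsigned minorVersion;
};

// Same conditions as the helpers added in Chipset.h.
inline bool isGfx940Series(const Chipset &chipset) {
  return chipset.majorVersion == 9 && chipset.minorVersion == 4;
}
inline bool hasOcpFp8(const Chipset &chipset) {
  return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
         chipset.majorVersion >= 12;
}

// Hypothetical helper: which 8-bit float family, if any, a target expects.
enum class Fp8Flavor { None, Fnuz, Ocp };

inline Fp8Flavor fp8FlavorFor(const Chipset &chipset) {
  if (isGfx940Series(chipset))
    return Fp8Flavor::Fnuz;
  if (hasOcpFp8(chipset))
    return Fp8Flavor::Ocp;
  return Fp8Flavor::None;
}
```

Under this model, a gfx942 target (`Chipset{9, 4}`) maps to the FNUZ family, gfx950 (`Chipset{9, 5}`) and gfx1200 (`Chipset{12, 0}`) map to the OCP family, and gfx1100 (`Chipset{11, 0}`) maps to neither. The two predicates can never both hold for the same chipset, which is what lets the lowering treat them as alternatives.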
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index b6a307fd7cb0f..7feab4d966d59 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -140,6 +140,9 @@ class Type {
bool isF64() const;
bool isF80() const;
bool isF128() const;
+ /// Return true if this is an float type (with the specified width).

+ bool isFloat() const;
+ bool isFloat(unsigned width) const;
/// Return true if this is an integer type (with the specified width).
bool isInteger() const;
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f80d2793eaef5..4a76739c7a06a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
}
}
+/// Return true if `type` is the E5M2 variant of an 8-bit float that is
+/// supported by the `_bf8` instructions on the given `chipset`.
+static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
+ return (isGfx940Series(chipset) && type.isFloat8E5M2FNUZ()) ||
+ (hasOcpFp8(chipset) && type.isFloat8E5M2());
+}
+
+/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
+/// supported by the `_fp8` instructions on the given `chipset`.
+static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
+ return (isGfx940Series(chipset) && type.isFloat8E4M3FNUZ()) ||
+ (hasOcpFp8(chipset) && type.isFloat8E4M3FN());
+}
+
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -550,38 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+ if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
- if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+ if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
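The selection logic in the hunk above can be modeled outside MLIR as follows. This is a sketch under stated assumptions: `F8Type`, `operandSuffix`, and `mfmaF8Name` are hypothetical names, the returned string mirrors the `ROCDL::mfma_f32_*` C++ class names that appear in the diff (not the printed op name), and the block count check (`b == 1`) and any other availability checks in `mfmaOpToIntrinsic` are not modeled.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Simplified stand-ins; none of these are the real MLIR types.
struct Chipset {
  unsigned majorVersion;
  unsigned minorVersion;
};
enum class F8Type { E5M2FNUZ, E4M3FNUZ, E5M2, E4M3FN };

inline bool isGfx940Series(const Chipset &c) {
  return c.majorVersion == 9 && c.minorVersion == 4;
}
inline bool hasOcpFp8(const Chipset &c) {
  return (c.majorVersion == 9 && c.minorVersion >= 5) || c.majorVersion >= 12;
}

// Same shape as the two static helpers added in AMDGPUToROCDL.cpp.
inline bool typeIsExpectedBf8ForChipset(Chipset chipset, F8Type type) {
  return (isGfx940Series(chipset) && type == F8Type::E5M2FNUZ) ||
         (hasOcpFp8(chipset) && type == F8Type::E5M2);
}
inline bool typeIsExpectedFp8ForChipset(Chipset chipset, F8Type type) {
  return (isGfx940Series(chipset) && type == F8Type::E4M3FNUZ) ||
         (hasOcpFp8(chipset) && type == F8Type::E4M3FN);
}

// Hypothetical: "bf8" or "fp8" if the type fits the chipset, else nothing.
inline std::optional<std::string> operandSuffix(Chipset chipset, F8Type type) {
  if (typeIsExpectedBf8ForChipset(chipset, type))
    return std::string("bf8");
  if (typeIsExpectedFp8ForChipset(chipset, type))
    return std::string("fp8");
  return std::nullopt;
}

// Hypothetical: compose the intrinsic class name for the two f8 MFMA shapes.
inline std::optional<std::string> mfmaF8Name(Chipset chipset, int m, int n,
                                             int k, F8Type a, F8Type b) {
  std::optional<std::string> suffixA = operandSuffix(chipset, a);
  std::optional<std::string> suffixB = operandSuffix(chipset, b);
  if (!suffixA || !suffixB)
    return std::nullopt; // operand type does not match this chipset's family
  std::string shape;
  if (m == 16 && n == 16 && k == 32)
    shape = "16x16x32";
  else if (m == 32 && n == 32 && k == 16)
    shape = "32x32x16";
  else
    return std::nullopt;
  return "mfma_f32_" + shape + "_" + *suffixA + "_" + *suffixB;
}
```

The point of routing both operands through the chipset-aware helpers is that a mismatched family produces no intrinsic at all: in this model, FNUZ operands on a gfx950-like chipset, or OCP operands on a gfx942-like chipset, yield an empty result rather than silently selecting an instruction that would reinterpret the bits.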
@@ -757,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset < kGfx940)
+ if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -787,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (sourceElemType.isFloat8E5M2FNUZ()) {
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
- } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -801,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset < kGfx940)
+ if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -822,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
@@ -838,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset < kGfx940)
+ if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -857,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
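The extension lowering above bitcasts a four-element byte vector to `i32` and passes the element index as a selector to the conversion intrinsic. A minimal model of that packing is sketched below. The lane order (element 0 in the low byte) is an assumption about little-endian vector bitcasts, and `packF8Lanes` / `selectF8Lane` are hypothetical names; unused lanes are zeroed here, whereas the lowering leaves them undefined.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Pack up to four 8-bit float bit patterns into one 32-bit word, with
// element i stored in byte i (assumed little-endian lane order).
inline std::uint32_t packF8Lanes(const std::uint8_t *lanes, std::size_t count) {
  std::uint32_t word = 0;
  for (std::size_t i = 0; i < count && i < 4; ++i)
    word |= static_cast<std::uint32_t>(lanes[i]) << (8 * i);
  return word;
}

// Extract the byte that a selector value of `index` (0..3) would read.
inline std::uint8_t selectF8Lane(std::uint32_t word, unsigned index) {
  return static_cast<std::uint8_t>((word >> (8 * (index & 3u))) & 0xFFu);
}
```

For a two-element source such as the `vector<2xf8E4M3FN>` case in the tests, only the first two bytes carry data, so an index of 0 or 1 is the only meaningful selector. The packing itself is independent of whether the bytes are OCP or FNUZ encoded; only the choice of `bf8` versus `fp8` conversion depends on the type and chipset.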
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0..e16f9f65cc919 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
+ Chipset chipset;
+ ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
@@ -68,6 +72,14 @@ struct TruncfToFloat16RewritePattern final
} // end namespace
+static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
+ if (isGfx940Series(chipset))
+ return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
+ if (hasOcpFp8(chipset))
+ return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
+ return failure();
+}
+
static Value castF32To(Type elementType, Value f32, Location loc,
PatternRewriter &rewriter) {
if (elementType.isF32())
@@ -86,7 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
- return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+ return isSupportedF8(inType, chipset);
}
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +228,8 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+
+ return isSupportedF8(outType, chipset);
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -365,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
if (convertFP8Arithmetic) {
- patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
saturateFP8Truncf, chipset);
}
@@ -384,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}
bool convertFP8Arithmetic =
- maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
+ isGfx940Series(*maybeChipset) || hasOcpFp8(*maybeChipset);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
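The change to `convertFP8Arithmetic` is easiest to see by comparing the removed and added conditions on a few targets. The sketch below is standalone: `Chipset` is a simplified stand-in, its `operator>=` is assumed to compare (major, minor, stepping) lexicographically, and `oldConvertFp8Gate` / `newConvertFp8Gate` are hypothetical names for the two expressions in the diff.

```cpp
#include <cassert>
#include <tuple>

// Simplified stand-in for mlir::amdgpu::Chipset.
struct Chipset {
  unsigned majorVersion;
  unsigned minorVersion;
  unsigned steppingVersion;
};

// Assumption: ordering is lexicographic over (major, minor, stepping).
inline bool operator>=(const Chipset &a, const Chipset &b) {
  return std::tie(a.majorVersion, a.minorVersion, a.steppingVersion) >=
         std::tie(b.majorVersion, b.minorVersion, b.steppingVersion);
}

inline bool isGfx940Series(const Chipset &c) {
  return c.majorVersion == 9 && c.minorVersion == 4;
}
inline bool hasOcpFp8(const Chipset &c) {
  return (c.majorVersion == 9 && c.minorVersion >= 5) || c.majorVersion >= 12;
}

// Condition removed by the patch.
inline bool oldConvertFp8Gate(const Chipset &c) {
  return c.majorVersion == 9 && c >= Chipset{9, 4, 0};
}

// Condition added by the patch.
inline bool newConvertFp8Gate(const Chipset &c) {
  return isGfx940Series(c) || hasOcpFp8(c);
}
```

In this model both gates accept gfx942 and gfx950, but only the new gate accepts gfx1200. The gate alone does not distinguish gfx942 from gfx950; that distinction comes from `isSupportedF8`, which accepts only the FNUZ types on the gfx940 series and only the OCP types on chipsets where `hasOcpFp8` holds.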
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 63447baa31eb0..48fb1dfb0a003 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
}
Type sourceBType = getSourceB().getType();
- if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+ if (sourceElem.isFloat(8)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
- if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+ if (!sourceBElem.isFloat(8))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e6..c345dd5883ead 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,8 @@ bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
if (profile == TosaProfileEnum::BaseInference)
return false;
- return type.isF32() || type.isF16() || type.isBF16();
+ return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
+ Float8E5M2Type>(type);
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fa093664cf77f..ca09c4aed14cf 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -56,6 +56,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
+bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
+
+/// Return true if this is a float type with the specified width.
+bool Type::isFloat(unsigned width) const {
+ if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
+ return fltTy.getWidth() == width;
+ return false;
+}
+
bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }
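The width-based query is what lets `MFMAOp::verify` replace its per-type checks with `isFloat(8)`. The sketch below models the idea without MLIR: types are keyed by their textual spelling, and `floatWidth` is a hypothetical lookup covering only a handful of builtin float types. The real method asks `FloatType` for its width instead.

```cpp
#include <cassert>
#include <string_view>

// Hypothetical lookup: bit width of a few builtin float types by their
// MLIR spelling. Returns 0 for anything this sketch does not treat as a
// float type.
inline unsigned floatWidth(std::string_view name) {
  if (name == "f8E5M2" || name == "f8E4M3FN" || name == "f8E5M2FNUZ" ||
      name == "f8E4M3FNUZ")
    return 8;
  if (name == "f16" || name == "bf16")
    return 16;
  if (name == "f32")
    return 32;
  if (name == "f64")
    return 64;
  return 0;
}

// Mirrors the two overloads added to Type: any float, or a float of a
// specific width.
inline bool isFloat(std::string_view name) { return floatWidth(name) != 0; }
inline bool isFloat(std::string_view name, unsigned width) {
  return isFloat(name) && floatWidth(name) == width;
}
```

One consequence worth noting is that a width check accepts every 8-bit float type, OCP and FNUZ alike, so the verifier no longer encodes which family is valid. Rejecting the wrong family for a given chipset is left to the lowering, which only selects an intrinsic when the type matches the target.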
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
new file mode 100644
index 0000000000000..70775a603e54d
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 | FileCheck %s
+
+// CHECK-LABEL: func @ext_scalar
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_scalar(%v: f8E5M2) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_short_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc
+// CHECK-SAME: ([[V:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to...
[truncated]
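The description's requirement that OCP and pre-OCP (FNUZ) types never reach the wrong hardware follows from the encodings themselves: the same byte decodes to different values in the two families. The decoder below is a sketch written from the published format definitions, not code from this PR; `F8Format` and `decodeF8` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

enum class F8Format { E4M3FN, E4M3FNUZ, E5M2, E5M2FNUZ };

// Decode one 8-bit float bit pattern to a 32-bit float.
inline float decodeF8(std::uint8_t bits, F8Format fmt) {
  const bool isE4M3 = fmt == F8Format::E4M3FN || fmt == F8Format::E4M3FNUZ;
  const bool isFnuz = fmt == F8Format::E4M3FNUZ || fmt == F8Format::E5M2FNUZ;
  const int mantBits = isE4M3 ? 3 : 2;
  const int expBits = 7 - mantBits;
  // OCP bias is 2^(expBits-1) - 1; the FNUZ formats use a bias one larger.
  const int bias = (1 << (expBits - 1)) - 1 + (isFnuz ? 1 : 0);

  const bool negative = (bits & 0x80) != 0;
  const int exp = (bits >> mantBits) & ((1 << expBits) - 1);
  const int mant = bits & ((1 << mantBits) - 1);
  const int allOnesExp = (1 << expBits) - 1;

  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float inf = std::numeric_limits<float>::infinity();

  if (isFnuz) {
    // FNUZ: the only NaN is 0x80; no infinity and no negative zero.
    if (bits == 0x80)
      return nan;
  } else if (fmt == F8Format::E4M3FN) {
    // E4M3FN: no infinity; NaN is exponent and mantissa all ones.
    if (exp == allOnesExp && mant == 7)
      return nan;
  } else if (exp == allOnesExp) {
    // E5M2: IEEE-style infinity and NaN.
    if (mant != 0)
      return nan;
    return negative ? -inf : inf;
  }

  // Subnormals (exp == 0) have no implicit leading one.
  const float magnitude =
      exp == 0 ? std::ldexp(static_cast<float>(mant), 1 - bias - mantBits)
               : std::ldexp(static_cast<float>((1 << mantBits) | mant),
                            exp - bias - mantBits);
  return negative ? -magnitude : magnitude;
}
```

With this decoder the byte `0x38` is 1.0 as `E4M3FN` but 0.5 as `E4M3FNUZ`, and `0x7C` is infinity as `E5M2` but a finite value as `E5M2FNUZ`. Feeding one family's bytes to an instruction that expects the other would therefore scale ordinary values by two and misread the special encodings, which is why the lowering refuses to match rather than guess.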
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+
+ return isSupportedF8(outType, chipset);
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -365,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
if (convertFP8Arithmetic) {
- patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
saturateFP8Truncf, chipset);
}
@@ -384,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}
bool convertFP8Arithmetic =
- maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
+ isGfx940Series(*maybeChipset) || hasOcpFp8(*maybeChipset);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 63447baa31eb0..48fb1dfb0a003 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
}
Type sourceBType = getSourceB().getType();
- if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+ if (sourceElem.isFloat(8)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
- if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+ if (!sourceBElem.isFloat(8))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e6..c345dd5883ead 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,8 @@ bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
if (profile == TosaProfileEnum::BaseInference)
return false;
- return type.isF32() || type.isF16() || type.isBF16();
+ return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
+ Float8E5M2Type>(type);
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fa093664cf77f..ca09c4aed14cf 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -56,6 +56,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
+bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
+
+/// Return true if this is a float type with the specified width.
+bool Type::isFloat(unsigned width) const {
+ if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
+ return fltTy.getWidth() == width;
+ return false;
+}
+
bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
new file mode 100644
index 0000000000000..70775a603e54d
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 | FileCheck %s
+
+// CHECK-LABEL: func @ext_scalar
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_scalar(%v: f8E5M2) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_short_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc
+// CHECK-SAME: ([[V:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to...
[truncated]
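The `ext_short_vec` and `ext_full_vec` checks above pin down the shape of the lowering: a short vector is widened to four bytes, bitcast to an `i32`, and the element index is passed to the conversion intrinsic as a byte selector. A minimal host-side sketch of that packing, assuming a little-endian layout in which vector element `i` becomes bits `[8i, 8i+7]` of the word; `packBytes` and `selectByte` are hypothetical names used only for illustration, and unused lanes are zero here whereas the real lowering leaves them undefined:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Pack up to four 8-bit float payloads into one 32-bit word.
// Element i lands in bits [8i, 8i+7] (little-endian assumption).
// Unused lanes are zero here; the real lowering leaves them undef.
constexpr uint32_t packBytes(const uint8_t *bytes, size_t count) {
  uint32_t word = 0;
  for (size_t i = 0; i < count && i < 4; ++i)
    word |= static_cast<uint32_t>(bytes[i]) << (8 * i);
  return word;
}

// Model of the byte selector operand: pick byte `index` (0..3) of the word.
constexpr uint8_t selectByte(uint32_t word, unsigned index) {
  return static_cast<uint8_t>((word >> (8 * (index & 3u))) & 0xFFu);
}
```

With a two-element input `{0x38, 0x40}` and index 1, the selected byte is `0x40`, which matches the `%v[1]` case in `ext_short_vec`.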
Seems like this needs to be rebased. Also, can we close the original PR?
A few changes, hopefully not contradicting previous reviews.
Also, if you look in your downstream, there's a fix to `wmmaPushInputOperand` and probably the `amdgpu.wmma` op generally - can you see if those were a separate PR?
@@ -49,6 +49,14 @@ struct Chipset {
#undef DEFINE_COMP_OPERATOR
};
+inline bool isGfx940Series(const Chipset &chipset) {
Per recent codebase cleanup, `isGfx942Series` if we don't want to straight-up `==` it.
Looks like this is supposed to match with all the gfx94_ chipsets, hence the "Series" in the function name. But if gfx940 and gfx941 are no longer supported, I can rename this to `isGfx942` and just check `chipset == Chipset(9, 4, 2)`.
For that, there's just a `kGfx942` running around in a bunch of files - take equality to that
And I'm not sure if "where needed" wasn't supposed to mean "static method in the `.cpp`".
That makes sense.
@kuhar What do you think?
In this instance we are pretty sure gfx942 is the only gfxip version in the gfx940 series. Note that this wasn't the case a couple of months ago.
I'd go with @krzysz00's suggestion.
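A minimal sketch of where this thread lands, assuming a simplified stand-in for `Chipset` (the real struct in `Chipset.h` has more to it): the gfx940-series helper is replaced by equality against a `kGfx942` constant, and the remaining predicate is a file-local `static` function rather than an `inline` one in the header. The struct layout, `operator==`, and the name `hasFnuzFp8` are illustrative assumptions, not the upstream definitions.

```cpp
#include <cassert>

// Simplified stand-in for mlir::amdgpu::Chipset (hypothetical layout).
struct Chipset {
  unsigned majorVersion;
  unsigned minorVersion;
  unsigned steppingVersion;
};

constexpr bool operator==(const Chipset &a, const Chipset &b) {
  return a.majorVersion == b.majorVersion &&
         a.minorVersion == b.minorVersion &&
         a.steppingVersion == b.steppingVersion;
}

constexpr Chipset kGfx942{9, 4, 2};

// File-local helper, as it would appear in the .cpp that needs it.
// Same condition as hasOcpFp8 in the patch: gfx95x and gfx12+.
static constexpr bool hasOcpFp8(const Chipset &chipset) {
  return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
         chipset.majorVersion >= 12;
}

// Hypothetical name: the pre-OCP (FNUZ) formats are tied to gfx942 only.
static constexpr bool hasFnuzFp8(const Chipset &chipset) {
  return chipset == kGfx942;
}

static_assert(hasFnuzFp8(Chipset{9, 4, 2}), "gfx942 uses the FNUZ formats");
static_assert(!hasFnuzFp8(Chipset{9, 5, 0}), "gfx950 is not in that set");
```

Equality against a single constant also avoids the kind of range bug mentioned in the commit message, where a gfx940 check accidentally held for gfx950.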
e79cb88 to 348cd7d
Ok so main thing - and we can wait on Jakub with this - is that inline utility floating around in a header like that ... but maybe it's fine.
Otherwise, seems fine to me
Upcoming hardware (gfx12 and some future gfx9) will support the OCP 8-bit float formats for their matrix multiplication intrinsics and conversion operations, retaining existing opcodes and compiler builtins. This commit adds support for these types to the MLIR wrappers around such operations, ensuring that the OCP types aren't used to generate those builtins on hardware that doesn't expect that format and, conversely, to ensure that the pre-OCP formats aren't used on new hardware.
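To see why the same opcode must not be fed the wrong format, it helps to decode one bit pattern under both conventions. The sketch below is a host-side decoder, not code from the patch; `F8Format` and `decodeF8` are hypothetical names, and the format parameters (bias 7 and 15 for the OCP types, one larger for the FNUZ types, `0x80` as the only FNUZ NaN) reflect my reading of the OCP 8-bit float definitions and LLVM's `APFloat` semantics, so treat them as assumptions to verify.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

enum class F8Format { E4M3FN, E5M2, E4M3FNUZ, E5M2FNUZ };

// Decode an 8-bit float payload to a host float under the given format.
float decodeF8(uint8_t bits, F8Format fmt) {
  const bool isE4M3 = fmt == F8Format::E4M3FN || fmt == F8Format::E4M3FNUZ;
  const bool isFnuz = fmt == F8Format::E4M3FNUZ || fmt == F8Format::E5M2FNUZ;
  const int mantBits = isE4M3 ? 3 : 2;
  const int expBits = 7 - mantBits;
  // Assumed biases: 7 (E4M3FN), 15 (E5M2), and one more for FNUZ variants.
  const int bias = (isE4M3 ? 7 : 15) + (isFnuz ? 1 : 0);
  const bool negative = (bits & 0x80) != 0;
  const int expField = (bits >> mantBits) & ((1 << expBits) - 1);
  const int mantField = bits & ((1 << mantBits) - 1);
  const int expMax = (1 << expBits) - 1;
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float inf = std::numeric_limits<float>::infinity();

  if (isFnuz) {
    // FNUZ: no infinities, no negative zero; 0x80 is the single NaN.
    if (bits == 0x80)
      return nan;
  } else if (fmt == F8Format::E5M2) {
    // IEEE-like: an all-ones exponent encodes infinity or NaN.
    if (expField == expMax)
      return mantField == 0 ? (negative ? -inf : inf) : nan;
  } else if (expField == expMax && mantField == 7) {
    // E4M3FN: no infinities; only S.1111.111 is NaN.
    return nan;
  }

  // Subnormals have no implicit leading one and use exponent 1 - bias.
  const float magnitude =
      expField == 0
          ? std::ldexp(static_cast<float>(mantField), 1 - bias - mantBits)
          : std::ldexp(static_cast<float>((1 << mantBits) | mantField),
                       expField - bias - mantBits);
  return negative ? -magnitude : magnitude;
}
```

Under these assumptions the byte `0x38` reads as 1.0 in E4M3FN but 0.5 in E4M3FNUZ, and `0x80` is negative zero in E4M3FN but NaN in E4M3FNUZ, which is the mismatch the chipset checks in this patch guard against.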
those operations were not being converted to the LLVM intrinsics they correspond to because the rewrite patterns were still checking for gfx940+. As part of this, factor out the type-match tests into isNativeFp8() and isNativeBf8() functions in the AMDGPUToRocdl rewrites. Also, fix a typo in isGfx940() that caused it to be true for gfx950. Finally, test all these OCP format conversions by duplicating the gfx940 tests.
Signed-off-by: Mirza Halilcevic <[email protected]>
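As a rough illustration of the isNativeFp8() / isNativeBf8() split mentioned in the commit message above, here is a minimal sketch. Everything in it is a hypothetical stand-in: `Fp8Kind` replaces the MLIR float types, `Chipset` is simplified, and the chipset ranges in `hasOcpFp8` (gfx950 and later gfx9, plus gfx12+) are an assumption based on the PR description, not the actual patch.

```cpp
// Hypothetical stand-ins for the MLIR 8-bit float types.
// FNUZ variants are the pre-OCP formats; FN / E5M2 are the OCP formats.
enum class Fp8Kind { E4M3FNUZ, E5M2FNUZ, E4M3FN, E5M2 };

// Simplified, hypothetical chipset description.
struct Chipset {
  unsigned majorVersion;
  unsigned minorVersion;
  unsigned steppingVersion;
};

// Assumption: gfx942 is the only chipset using the pre-OCP formats.
constexpr bool isGfx942(const Chipset &c) {
  return c.majorVersion == 9 && c.minorVersion == 4 && c.steppingVersion == 2;
}

// Assumption: OCP FP8 is available on gfx12+ and on gfx950 and later gfx9.
constexpr bool hasOcpFp8(const Chipset &c) {
  return c.majorVersion >= 12 || (c.majorVersion == 9 && c.minorVersion >= 5);
}

// True when `kind` is the 4-bit-exponent format the hardware natively uses.
constexpr bool isNativeFp8(const Chipset &c, Fp8Kind kind) {
  return (isGfx942(c) && kind == Fp8Kind::E4M3FNUZ) ||
         (hasOcpFp8(c) && kind == Fp8Kind::E4M3FN);
}

// True when `kind` is the 5-bit-exponent format the hardware natively uses.
constexpr bool isNativeBf8(const Chipset &c, Fp8Kind kind) {
  return (isGfx942(c) && kind == Fp8Kind::E5M2FNUZ) ||
         (hasOcpFp8(c) && kind == Fp8Kind::E5M2);
}
```

The point of factoring the checks out this way is that each rewrite pattern asks one question ("is this type native on this chipset?") instead of repeating chipset comparisons, so OCP types are rejected on pre-OCP hardware and vice versa.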
since gfx940 and gfx941 are no longer supported. Signed-off-by: Mirza Halilcevic <[email protected]>
348cd7d to 19b23a2
@krzysz00 Are you referring to this? Looks like this was supposed to be a part of #106388, but maybe got unintentionally left out.
@@ -696,7 +696,8 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {

 bool TosaValidation::isValidElementType(Type type) {
   if (isa<FloatType>(type)) {
-    return type.isF32() || type.isF16() || type.isBF16();
+    return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
+               Float8E5M2Type>(type);
you are missing two f8 types
Seems like the TOSA spec doesn't support them #106160 (comment)
This was addressed in 5bace46
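To make the outcome of this thread concrete, here is a small sketch of the resulting float element-type check. `FloatKind` and `isValidFloatElementType` are hypothetical stand-ins for the MLIR type classes and the `isa<...>` call in the diff; the accepted set mirrors the diff, while the rejected kinds are listed only as examples of types not in that set.

```cpp
// Hypothetical stand-in for the MLIR float type classes.
enum class FloatKind {
  F32,
  F16,
  BF16,
  F8E4M3FN,   // OCP 8-bit float, accepted by the check in the diff
  F8E5M2,     // OCP 8-bit float, accepted by the check in the diff
  F8E4M3FNUZ, // example of a type not in the accepted list
  F8E5M2FNUZ, // example of a type not in the accepted list
  F64         // example of a type not in the accepted list
};

// Mirrors the accepted set from the diff: f32, f16, bf16, and the two OCP
// 8-bit float types. Anything else is rejected.
constexpr bool isValidFloatElementType(FloatKind kind) {
  return kind == FloatKind::F32 || kind == FloatKind::F16 ||
         kind == FloatKind::BF16 || kind == FloatKind::F8E4M3FN ||
         kind == FloatKind::F8E5M2;
}
```

With this shape, widening or narrowing the accepted set is a one-line change in a single place, which matches how the variadic `isa<...>` form reads in the diff.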
I'm pretty happy with the PR and don't have any further objections. I figure we can land this if @kuhar's good.
@mirza-halilcevic Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done!
(Continuing from llvm#106160)
This PR addresses remaining review comments from the original PR.
Original PR Description
---
Upcoming hardware (gfx12 and some future gfx9) will support the OCP 8-bit float formats for their matrix multiplication intrinsics and conversion operations, retaining existing opcodes and compiler builtins.
This commit adds support for these types to the MLIR wrappers around such operations, ensuring that the OCP types aren't used to generate those builtins on hardware that doesn't expect that format and, conversely, to ensure that the pre-OCP formats aren't used on new hardware.
---------
Signed-off-by: Mirza Halilcevic <[email protected]>
Co-authored-by: Paul Fuqua <[email protected]>
Co-authored-by: Krzysztof Drewniak <[email protected]>