Skip to content

Commit 32e052d

Browse files
pcf000, mirza-halilcevic
authored and committed
[MLIR][AMDGPU] Add OCP FP8 support for new hardware
Upcoming hardware (gfx12 and some future gfx9) will support the OCP 8-bit float formats for their matrix multiplication intrinsics and conversion operations, retaining existing opcodes and compiler builtins. This commit adds support for these types to the MLIR wrappers around such operations. It ensures that the OCP types aren't used to generate those builtins on hardware that doesn't expect that format and, conversely, that the pre-OCP formats aren't used on new hardware.
1 parent be5c66d commit 32e052d

File tree

6 files changed

+47
-30
lines changed

6 files changed

+47
-30
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
4141

4242
def AMDGPU_ExtPackedFp8Op :
4343
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
44-
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
45-
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
44+
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
45+
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
4646
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
4747
Results<(outs F32:$res)> {
4848
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
6868
Arguments<(ins F32:$sourceA,
6969
Optional<F32>:$sourceB,
7070
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
71-
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
72-
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
71+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
72+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
7373
let summary = "Round two floats into a packed vector of 8-bit floats";
7474
let description = [{
7575
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
9595
Arguments<(ins F32:$source,
9696
I32:$stochiasticParam,
9797
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
98-
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
99-
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
98+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
99+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
100100
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
101101
let description = [{
102102
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
546546
VectorOfLengthAndType<[4], [F16]>,
547547
VectorOfLengthAndType<[2, 4], [BF16]>,
548548
VectorOfLengthAndType<[4, 8], [I8]>,
549-
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
549+
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
550550
def MFMAOutTypes : AnyTypeOf<[F64,
551551
VectorOfLengthAndType<[4, 16, 32], [F32]>,
552552
VectorOfLengthAndType<[4, 16, 32], [I32]>,

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ struct Chipset {
4747
DEFINE_COMP_OPERATOR(>)
4848
DEFINE_COMP_OPERATOR(>=)
4949
#undef DEFINE_COMP_OPERATOR
50+
51+
bool isGfx940() const {
52+
return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
53+
}
54+
bool hasOcpFp8() const {
55+
return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
56+
}
5057
};
5158

5259
} // namespace mlir::amdgpu

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -570,40 +570,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
570570
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
571571
}
572572

573-
if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
574-
chipset >= kGfx942) {
573+
if (destElem.isF32() &&
574+
((isa<Float8E5M2FNUZType>(sourceElem) && chipset >= kGfx942) ||
575+
(isa<Float8E5M2Type>(sourceElem) && chipset.hasOcpFp8()))) {
575576
// Known to be correct because there are no scalar f8 instructions and
576577
// because a length mismatch will have been caught by the verifier.
577578
Type sourceBElem =
578579
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
579580
if (m == 16 && n == 16 && k == 32 && b == 1) {
580-
if (isa<Float8E5M2FNUZType>(sourceBElem))
581+
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
581582
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
582-
if (isa<Float8E4M3FNUZType>(sourceBElem))
583+
if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
583584
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
584585
}
585586
if (m == 32 && n == 32 && k == 16 && b == 1) {
586-
if (isa<Float8E5M2FNUZType>(sourceBElem))
587+
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
587588
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
588-
if (isa<Float8E4M3FNUZType>(sourceBElem))
589+
if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
589590
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
590591
}
591592
}
592593

593-
if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
594-
chipset >= kGfx942) {
594+
if (destElem.isF32() &&
595+
((isa<Float8E4M3FNUZType>(sourceElem) && chipset >= kGfx942) ||
596+
(isa<Float8E4M3FNType>(sourceElem) && chipset.hasOcpFp8()))) {
595597
Type sourceBElem =
596598
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
597599
if (m == 16 && n == 16 && k == 32 && b == 1) {
598-
if (isa<Float8E5M2FNUZType>(sourceBElem))
600+
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
599601
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
600-
if (isa<Float8E4M3FNUZType>(sourceBElem))
602+
if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
601603
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
602604
}
603605
if (m == 32 && n == 32 && k == 16 && b == 1) {
604-
if (isa<Float8E5M2FNUZType>(sourceBElem))
606+
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
605607
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
606-
if (isa<Float8E4M3FNUZType>(sourceBElem))
608+
if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
607609
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
608610
}
609611
}
@@ -811,10 +813,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
811813
}
812814
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
813815
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
814-
if (isa<Float8E5M2FNUZType>(sourceElemType)) {
816+
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceElemType)) {
815817
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
816818
wordSel);
817-
} else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
819+
} else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceElemType)) {
818820
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
819821
wordSel);
820822
}
@@ -846,10 +848,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
846848
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
847849

848850
Value result;
849-
if (isa<Float8E5M2FNUZType>(resultElemType))
851+
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(resultElemType))
850852
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
851853
existing, wordSel);
852-
else if (isa<Float8E4M3FNUZType>(resultElemType))
854+
else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(resultElemType))
853855
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
854856
existing, wordSel);
855857

@@ -881,10 +883,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
881883
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
882884

883885
Value result;
884-
if (isa<Float8E5M2FNUZType>(resultElemType))
886+
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(resultElemType))
885887
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
886888
existing, byteSel);
887-
else if (isa<Float8E4M3FNUZType>(resultElemType))
889+
else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(resultElemType))
888890
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
889891
existing, byteSel);
890892

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8686
return failure();
8787
inType = inVecType.getElementType();
8888
}
89-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
89+
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
90+
Float8E4M3FNType>(inType));
9091
}
9192

9293
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -219,7 +220,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
219220
if (inType && inType.getWidth() <= 8 && saturateFP8)
220221
// Conversion between 8-bit floats is not supported with truncation enabled.
221222
return failure();
222-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
223+
224+
return success((
225+
(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType) &&
226+
chipset.isGfx940()) ||
227+
(isa<Float8E5M2Type, Float8E4M3FNType>(outType) && chipset.hasOcpFp8())));
223228
}
224229

225230
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,16 @@ LogicalResult MFMAOp::verify() {
272272
}
273273

274274
Type sourceBType = getSourceB().getType();
275-
if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
275+
if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
276+
Float8E4M3FNType>(sourceElem)) {
276277
int64_t sourceBLen = 1;
277278
Type sourceBElem = sourceBType;
278279
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
279280
sourceBLen = sourceBVector.getNumElements();
280281
sourceBElem = sourceBVector.getElementType();
281282
}
282-
if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
283+
if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
284+
Float8E4M3FNType>(sourceBElem))
283285
return emitOpError("expected both source operands to have f8 elements");
284286
if (sourceLen != sourceBLen)
285287
return emitOpError(

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,8 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
696696

697697
bool TosaValidation::isValidElementType(Type type) {
698698
if (isa<FloatType>(type)) {
699-
return type.isF32() || type.isF16() || type.isBF16();
699+
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNUZType,
700+
Float8E5M2FNUZType, Float8E4M3FNType, Float8E5M2Type>(type);
700701
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
701702
if (intTy.isSignless()) {
702703
switch (intTy.getWidth()) {

0 commit comments

Comments
 (0)