Skip to content

Commit 0a761c0

Browse files
pcf000mirza-halilcevic
authored andcommitted
[MLIR][AMDGPU] Changes from the review.
1 parent a0911cc commit 0a761c0

File tree

6 files changed

+37
-26
lines changed

6 files changed

+37
-26
lines changed

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,16 @@ struct Chipset {
4747
DEFINE_COMP_OPERATOR(>)
4848
DEFINE_COMP_OPERATOR(>=)
4949
#undef DEFINE_COMP_OPERATOR
50-
51-
bool isGfx940() const {
52-
return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5;
53-
}
54-
bool hasOcpFp8() const {
55-
return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12;
56-
}
5750
};
5851

52+
inline bool isGfx940Series(const Chipset &chipset) {
53+
return chipset.majorVersion == 9 && chipset.minorVersion == 4;
54+
}
55+
inline bool hasOcpFp8(const Chipset &chipset) {
56+
return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
57+
chipset.majorVersion >= 12;
58+
}
59+
5960
} // namespace mlir::amdgpu
6061

6162
#endif

mlir/include/mlir/IR/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ class Type {
132132
bool isF64() const;
133133
bool isF80() const;
134134
bool isF128() const;
135+
/// Return true if this is an float type (with the specified width).
136+
bool isFloat() const;
137+
bool isFloat(unsigned width) const;
135138

136139
/// Return true if this is an integer type (with the specified width).
137140
bool isInteger() const;

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -477,15 +477,15 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
477477
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
478478
/// supported by the `_bf8` instructions on the given `chipset`.
479479
static bool isNativeBf8(Chipset chipset, Type type) {
480-
return (chipset.isGfx940() && isa<Float8E5M2FNUZType>(type)) ||
481-
(chipset.hasOcpFp8() && isa<Float8E5M2Type>(type));
480+
return (isGfx940Series(chipset) && isa<Float8E5M2FNUZType>(type)) ||
481+
(hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
482482
}
483483

484484
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
485485
/// supported by the `_fp8` instructions on the given `chipset`.
486486
static bool isNativeFp8(Chipset chipset, Type type) {
487-
return (chipset.isGfx940() && isa<Float8E4M3FNUZType>(type)) ||
488-
(chipset.hasOcpFp8() && isa<Float8E4M3FNType>(type));
487+
return (isGfx940Series(chipset) && isa<Float8E4M3FNUZType>(type)) ||
488+
(hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
489489
}
490490

491491
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
@@ -793,7 +793,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
793793
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
794794
ConversionPatternRewriter &rewriter) const {
795795
Location loc = op.getLoc();
796-
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
796+
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
797797
return rewriter.notifyMatchFailure(
798798
loc, "Fp8 conversion instructions are not available on target "
799799
"architecture and their emulation is not implemented");
@@ -837,7 +837,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
837837
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
838838
ConversionPatternRewriter &rewriter) const {
839839
Location loc = op.getLoc();
840-
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
840+
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
841841
return rewriter.notifyMatchFailure(
842842
loc, "Fp8 conversion instructions are not available on target "
843843
"architecture and their emulation is not implemented");
@@ -874,7 +874,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
874874
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
875875
ConversionPatternRewriter &rewriter) const {
876876
Location loc = op.getLoc();
877-
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
877+
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
878878
return rewriter.notifyMatchFailure(
879879
loc, "Fp8 conversion instructions are not available on target "
880880
"architecture and their emulation is not implemented");

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ struct TruncfToFloat16RewritePattern final
7272

7373
} // end namespace
7474

75-
static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
76-
if (chipset.isGfx940())
77-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(elementType));
78-
if (chipset.hasOcpFp8())
79-
return success(isa<Float8E5M2Type, Float8E4M3FNType>(elementType));
75+
static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
76+
if (isGfx940Series(chipset))
77+
return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
78+
if (hasOcpFp8(chipset))
79+
return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
8080
return failure();
8181
}
8282

@@ -98,7 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
9898
return failure();
9999
inType = inVecType.getElementType();
100100
}
101-
return isSupportedFp8(inType, chipset);
101+
return isSupportedF8(inType, chipset);
102102
}
103103

104104
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -232,7 +232,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
232232
// Conversion between 8-bit floats is not supported with truncation enabled.
233233
return failure();
234234

235-
return isSupportedFp8(outType, chipset);
235+
return isSupportedF8(outType, chipset);
236236
}
237237

238238
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -397,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
397397
}
398398

399399
bool convertFP8Arithmetic =
400-
maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
400+
isGfx940Series(*maybeChipset) || hasOcpFp8(*maybeChipset);
401401
arith::populateArithToAMDGPUConversionPatterns(
402402
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
403403
*maybeChipset);

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,16 +272,14 @@ LogicalResult MFMAOp::verify() {
272272
}
273273

274274
Type sourceBType = getSourceB().getType();
275-
if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
276-
Float8E4M3FNType>(sourceElem)) {
275+
if (sourceElem.isFloat(8)) {
277276
int64_t sourceBLen = 1;
278277
Type sourceBElem = sourceBType;
279278
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
280279
sourceBLen = sourceBVector.getNumElements();
281280
sourceBElem = sourceBVector.getElementType();
282281
}
283-
if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
284-
Float8E4M3FNType>(sourceBElem))
282+
if (!sourceBElem.isFloat(8))
285283
return emitOpError("expected both source operands to have f8 elements");
286284
if (sourceLen != sourceBLen)
287285
return emitOpError(

mlir/lib/IR/Types.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
4242
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
4343
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
4444

45+
bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
46+
47+
/// Return true if this is an integer type with the specified width.
48+
bool Type::isFloat(unsigned width) const {
49+
if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
50+
return fltTy.getWidth() == width;
51+
return false;
52+
}
53+
4554
bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
4655

4756
bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }

0 commit comments

Comments
 (0)