Skip to content

Commit c834d5a

Browse files
pcf000mirza-halilcevic
authored andcommitted
[MLIR][AMDGPU] Clean up and redo after other recent patches here.
1 parent ff47308 commit c834d5a

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ struct Chipset {
4949
#undef DEFINE_COMP_OPERATOR
5050

5151
bool isGfx940() const {
52-
return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
52+
return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5;
5353
}
5454
bool hasOcpFp8() const {
55-
return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
55+
return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12;
5656
}
5757
};
5858

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
787787
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
788788
ConversionPatternRewriter &rewriter) const {
789789
Location loc = op.getLoc();
790-
if (chipset.majorVersion != 9 || chipset < kGfx942)
790+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
791791
return rewriter.notifyMatchFailure(
792792
loc, "Fp8 conversion instructions are not available on target "
793793
"architecture and their emulation is not implemented");
@@ -831,7 +831,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
831831
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
832832
ConversionPatternRewriter &rewriter) const {
833833
Location loc = op.getLoc();
834-
if (chipset.majorVersion != 9 || chipset < kGfx942)
834+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
835835
return rewriter.notifyMatchFailure(
836836
loc, "Fp8 conversion instructions are not available on target "
837837
"architecture and their emulation is not implemented");
@@ -868,7 +868,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
868868
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
869869
ConversionPatternRewriter &rewriter) const {
870870
Location loc = op.getLoc();
871-
if (chipset.majorVersion != 9 || chipset < kGfx942)
871+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
872872
return rewriter.notifyMatchFailure(
873873
loc, "Fp8 conversion instructions are not available on target "
874874
"architecture and their emulation is not implemented");

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
4141
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
4242
using OpRewritePattern::OpRewritePattern;
4343

44+
Chipset chipset;
45+
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
46+
: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
47+
4448
LogicalResult match(arith::ExtFOp op) const override;
4549
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
4650
};
@@ -68,6 +72,14 @@ struct TruncfToFloat16RewritePattern final
6872

6973
} // end namespace
7074

75+
static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
76+
if (chipset.isGfx940())
77+
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(elementType));
78+
if (chipset.hasOcpFp8())
79+
return success(isa<Float8E5M2Type, Float8E4M3FNType>(elementType));
80+
return failure();
81+
}
82+
7183
static Value castF32To(Type elementType, Value f32, Location loc,
7284
PatternRewriter &rewriter) {
7385
if (elementType.isF32())
@@ -86,8 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8698
return failure();
8799
inType = inVecType.getElementType();
88100
}
89-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
90-
Float8E4M3FNType>(inType));
101+
return isSupportedFp8(inType, chipset);
91102
}
92103

93104
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -221,10 +232,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
221232
// Conversion between 8-bit floats is not supported with truncation enabled.
222233
return failure();
223234

224-
return success((
225-
(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType) &&
226-
chipset.isGfx940()) ||
227-
(isa<Float8E5M2Type, Float8E4M3FNType>(outType) && chipset.hasOcpFp8())));
235+
return isSupportedFp8(outType, chipset);
228236
}
229237

230238
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -370,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
370378
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
371379

372380
if (convertFP8Arithmetic) {
373-
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
381+
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
374382
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
375383
saturateFP8Truncf, chipset);
376384
}
@@ -389,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
389397
}
390398

391399
bool convertFP8Arithmetic =
392-
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 2);
400+
maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
393401
arith::populateArithToAMDGPUConversionPatterns(
394402
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
395403
*maybeChipset);

0 commit comments

Comments
 (0)