Skip to content

Commit a0911cc

Browse files
pcf000 authored and mirza-halilcevic committed
[MLIR][AMDGPU] Clean up and redo after other recent patches here.
1 parent b87a0a0 commit a0911cc

File tree

3 files changed: 21 additions (+21) and 13 deletions (-13).

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,10 +49,10 @@ struct Chipset {
4949
#undef DEFINE_COMP_OPERATOR
5050

5151
bool isGfx940() const {
52-
return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
52+
return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5;
5353
}
5454
bool hasOcpFp8() const {
55-
return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
55+
return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12;
5656
}
5757
};
5858

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -793,7 +793,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
793793
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
794794
ConversionPatternRewriter &rewriter) const {
795795
Location loc = op.getLoc();
796-
if (chipset.majorVersion != 9 || chipset < kGfx942)
796+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
797797
return rewriter.notifyMatchFailure(
798798
loc, "Fp8 conversion instructions are not available on target "
799799
"architecture and their emulation is not implemented");
@@ -837,7 +837,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
837837
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
838838
ConversionPatternRewriter &rewriter) const {
839839
Location loc = op.getLoc();
840-
if (chipset.majorVersion != 9 || chipset < kGfx942)
840+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
841841
return rewriter.notifyMatchFailure(
842842
loc, "Fp8 conversion instructions are not available on target "
843843
"architecture and their emulation is not implemented");
@@ -874,7 +874,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
874874
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
875875
ConversionPatternRewriter &rewriter) const {
876876
Location loc = op.getLoc();
877-
if (chipset.majorVersion != 9 || chipset < kGfx942)
877+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
878878
return rewriter.notifyMatchFailure(
879879
loc, "Fp8 conversion instructions are not available on target "
880880
"architecture and their emulation is not implemented");

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
4141
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
4242
using OpRewritePattern::OpRewritePattern;
4343

44+
Chipset chipset;
45+
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
46+
: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
47+
4448
LogicalResult match(arith::ExtFOp op) const override;
4549
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
4650
};
@@ -68,6 +72,14 @@ struct TruncfToFloat16RewritePattern final
6872

6973
} // end namespace
7074

75+
static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
76+
if (chipset.isGfx940())
77+
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(elementType));
78+
if (chipset.hasOcpFp8())
79+
return success(isa<Float8E5M2Type, Float8E4M3FNType>(elementType));
80+
return failure();
81+
}
82+
7183
static Value castF32To(Type elementType, Value f32, Location loc,
7284
PatternRewriter &rewriter) {
7385
if (elementType.isF32())
@@ -86,8 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8698
return failure();
8799
inType = inVecType.getElementType();
88100
}
89-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
90-
Float8E4M3FNType>(inType));
101+
return isSupportedFp8(inType, chipset);
91102
}
92103

93104
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -221,10 +232,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
221232
// Conversion between 8-bit floats is not supported with truncation enabled.
222233
return failure();
223234

224-
return success((
225-
(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType) &&
226-
chipset.isGfx940()) ||
227-
(isa<Float8E5M2Type, Float8E4M3FNType>(outType) && chipset.hasOcpFp8())));
235+
return isSupportedFp8(outType, chipset);
228236
}
229237

230238
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -370,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
370378
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
371379

372380
if (convertFP8Arithmetic) {
373-
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
381+
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
374382
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
375383
saturateFP8Truncf, chipset);
376384
}
@@ -389,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
389397
}
390398

391399
bool convertFP8Arithmetic =
392-
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 2);
400+
maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
393401
arith::populateArithToAMDGPUConversionPatterns(
394402
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
395403
*maybeChipset);

0 commit comments

Comments
 (0)