@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
41
41
// Rewrite pattern anchored on arith::ExtFOp. It carries the target Chipset so
// that match() can decide, per chipset, which 8-bit float source types it
// accepts (match() returns isSupportedFp8(inType, chipset)).
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
42
42
// Keep the base-class constructors available alongside the chipset-taking
// constructor declared below.
using OpRewritePattern::OpRewritePattern;
43
43
44
// Target chipset, captured at construction time.
+ Chipset chipset;
45
// Forwards `ctx` to the OpRewritePattern base and stores `chipset`.
+ ExtFOnFloat8RewritePattern (MLIRContext *ctx, Chipset chipset)
46
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
47
+
44
48
// Split match/rewrite interface; both are defined out of line.
LogicalResult match (arith::ExtFOp op) const override ;
45
49
void rewrite (arith::ExtFOp op, PatternRewriter &rewriter) const override ;
46
50
};
@@ -68,6 +72,14 @@ struct TruncfToFloat16RewritePattern final
68
72
69
73
} // end namespace
70
74
75
// Reports whether `elementType` is an 8-bit float type accepted for `chipset`.
//
// Returns success when either:
//   * chipset.isGfx940() holds and the type is Float8E5M2FNUZ or
//     Float8E4M3FNUZ, or
//   * chipset.isGfx940() does not hold, chipset.hasOcpFp8() holds, and the
//     type is Float8E5M2 or Float8E4M3FN.
// Returns failure in every other case, including chipsets for which neither
// predicate holds.
//
// The gfx940 check is tested first and returns unconditionally, so the
// hasOcpFp8() branch is only reached when isGfx940() is false.
+ static LogicalResult isSupportedFp8 (Type elementType, Chipset chipset) {
76
+ if (chipset.isGfx940 ())
77
// success(false) is a failure result, so a non-FNUZ type is rejected here.
+ return success (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(elementType));
78
+ if (chipset.hasOcpFp8 ())
79
+ return success (isa<Float8E5M2Type, Float8E4M3FNType>(elementType));
80
// Neither chipset predicate holds: no 8-bit float type is accepted.
+ return failure ();
81
+ }
82
+
71
83
static Value castF32To (Type elementType, Value f32 , Location loc,
72
84
PatternRewriter &rewriter) {
73
85
if (elementType.isF32 ())
@@ -86,8 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
86
98
return failure ();
87
99
inType = inVecType.getElementType ();
88
100
}
89
- return success (isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
90
- Float8E4M3FNType>(inType));
101
+ return isSupportedFp8 (inType, chipset);
91
102
}
92
103
93
104
void ExtFOnFloat8RewritePattern::rewrite (arith::ExtFOp op,
@@ -221,10 +232,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
221
232
// Conversion between 8-bit floats is not supported with truncation enabled.
222
233
return failure ();
223
234
224
- return success ((
225
- (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType) &&
226
- chipset.isGfx940 ()) ||
227
- (isa<Float8E5M2Type, Float8E4M3FNType>(outType) && chipset.hasOcpFp8 ())));
235
+ return isSupportedFp8 (outType, chipset);
228
236
}
229
237
230
238
void TruncFToFloat8RewritePattern::rewrite (arith::TruncFOp op,
@@ -370,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
370
378
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
371
379
372
380
if (convertFP8Arithmetic) {
373
- patterns.add <ExtFOnFloat8RewritePattern>(patterns.getContext ());
381
+ patterns.add <ExtFOnFloat8RewritePattern>(patterns.getContext (), chipset );
374
382
patterns.add <TruncFToFloat8RewritePattern>(patterns.getContext (),
375
383
saturateFP8Truncf, chipset);
376
384
}
@@ -389,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
389
397
}
390
398
391
399
bool convertFP8Arithmetic =
392
- maybeChipset->majorVersion == 9 && * maybeChipset >= Chipset ( 9 , 4 , 2 );
400
+ maybeChipset->isGfx940 () || maybeChipset-> hasOcpFp8 ( );
393
401
arith::populateArithToAMDGPUConversionPatterns (
394
402
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
395
403
*maybeChipset);
0 commit comments