#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+ #include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -34,17 +35,17 @@ struct ArithToAMDGPUConversionPass final
  void runOnOperation() override;
};

- struct ExtfOnFloat8RewritePattern final
-     : public OpRewritePattern<arith::ExtFOp> {
-   using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
+ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+   using OpRewritePattern::OpRewritePattern;

  LogicalResult match(arith::ExtFOp op) const override;
  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};

- struct TruncfToFloat8RewritePattern final
-     : public OpRewritePattern<arith::TruncFOp> {
-   using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+ struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
+   bool saturateFP8 = false;
+   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
+       : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}

  LogicalResult match(arith::TruncFOp op) const override;
  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -62,7 +63,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
  llvm_unreachable("The only 32-bit float type is f32");
}

- LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
  Type inType = op.getIn().getType();
  if (auto inVecType = inType.dyn_cast<VectorType>()) {
    if (inVecType.isScalable())
@@ -75,7 +76,7 @@ LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
}

- void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
+ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
                                          PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value in = op.getIn();
@@ -93,11 +94,13 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
  Value result =
      rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
  if (inType.getShape().empty()) {
-     Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+     Value scalarIn =
+         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    // Recurse to send the 0-D vector case to the 1-D vector case
    Value scalarExt =
        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-     result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
+     result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+                                                ArrayRef<int64_t>{});
    return rewriter.replaceOp(op, result);
  }
  for (int64_t i = 0; i < numElements; i += 4) {
@@ -108,9 +111,7 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
      Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
          loc, rewriter.getF32Type(), inSlice, j);
      Value asType = castF32To(outElemType, asFloat, loc, rewriter);
-       result = rewriter.create<vector::InsertElementOp>(
-           loc, asType, result,
-           rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
+       result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
    }
  }
  rewriter.replaceOp(op, result);
@@ -127,7 +128,53 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  llvm_unreachable("The only 32-bit float type is f32");
}

- LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
+ // If `in` is a finite value, clamp it between the maximum and minimum values
+ // of `outElemType` so that subsequent conversion instructions don't
+ // overflow those out-of-range values to NaN. These semantics are commonly
+ // used in machine-learning contexts where failure to clamp would lead to
+ // excessive NaN production.
+ static Value clampInput(PatternRewriter &rewriter, Location loc,
+                         Type outElemType, Value source) {
+   Type sourceType = source.getType();
+   const llvm::fltSemantics &sourceSem =
+       cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+   const llvm::fltSemantics &targetSem =
+       cast<FloatType>(outElemType).getFloatSemantics();
+
+   APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+   APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+   bool ignoredLosesInfo = false;
+   // We can ignore conversion failures here because this conversion promotes
+   // from a smaller type to a larger one - ex. there can be no loss of precision
+   // when casting fp8 to f16.
+   (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+   (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+   Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
+   Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
+
+   Value inf = createScalarOrSplatConstant(
+       rewriter, loc, sourceType,
+       APFloat::getInf(sourceSem, /*Negative=*/false));
+   Value negInf = createScalarOrSplatConstant(
+       rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
+   Value isInf = rewriter.createOrFold<arith::CmpFOp>(
+       loc, arith::CmpFPredicate::OEQ, source, inf);
+   Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
+       loc, arith::CmpFPredicate::OEQ, source, negInf);
+   Value isNan = rewriter.createOrFold<arith::CmpFOp>(
+       loc, arith::CmpFPredicate::UNO, source, source);
+   Value isNonFinite = rewriter.create<arith::OrIOp>(
+       loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+
+   Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
+   Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+   Value res =
+       rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+   return res;
+ }
+
+ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
  Type outType = op.getOut().getType();
  if (auto outVecType = outType.dyn_cast<VectorType>()) {
    if (outVecType.isScalable())
@@ -137,22 +184,27 @@ LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
      return failure();
    outType = outVecType.getElementType();
  }
+   auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
+   if (inType && inType.getWidth() <= 8 && saturateFP8)
+     // Conversion between 8-bit floats is not supported with truncation enabled.
+     return failure();
  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
}

- void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
+ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                           PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+   if (saturateFP8)
+     in = clampInput(rewriter, loc, outElemType, in);
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!in.getType().isa<VectorType>()) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
        loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
        /*existing=*/nullptr);
-     Value result = rewriter.create<vector::ExtractElementOp>(
-         loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+     Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
    return rewriter.replaceOp(op, result);
  }
  VectorType outType = op.getOut().getType().cast<VectorType>();
@@ -161,26 +213,25 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
  if (outType.getShape().empty()) {
-     Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+     Value scalarIn =
+         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    // Recurse to send the 0-D vector case to the 1-D vector case
    Value scalarTrunc =
        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-     result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
+     result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+                                                ArrayRef<int64_t>{});
    return rewriter.replaceOp(op, result);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
-       Value elemA = rewriter.create<vector::ExtractElementOp>(
-           loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
+       Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
-         Value elemB = rewriter.create<vector::ExtractElementOp>(
-             loc, in,
-             rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
+         Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -196,15 +247,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
}

void mlir::arith::populateArithToAMDGPUConversionPatterns(
-     RewritePatternSet &patterns) {
-   patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
-       patterns.getContext());
+     RewritePatternSet &patterns, bool saturateFP8TruncF) {
+   patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+   patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+                                              saturateFP8TruncF);
}

void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  RewritePatternSet patterns(op->getContext());
-   arith::populateArithToAMDGPUConversionPatterns(patterns);
+   arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
    return signalPassFailure();
}
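
For reference, a minimal sketch of how a downstream pass could drive the updated entry point with saturation enabled. Only the two-argument populateArithToAMDGPUConversionPatterns signature comes from this change; the header paths, the helper name runWithSaturation, and the surrounding boilerplate are assumptions for illustration, not part of the commit:

// Sketch only: exercises the new saturating fp8 truncation path. Everything
// here except the two-argument populateArithToAMDGPUConversionPatterns
// overload is assumed boilerplate rather than part of this commit.
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runWithSaturation(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  // Request clamping of out-of-range finite values to the fp8 min/max instead
  // of letting the hardware conversion turn them into NaN.
  mlir::arith::populateArithToAMDGPUConversionPatterns(
      patterns, /*saturateFP8TruncF=*/true);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}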