
Commit f113249

[mlir][ArithToAMDGPU] Add option for saturating truncation to fp8
Many machine-learning applications (and most software written at AMD) expect the operation that truncates floats to 8-bit floats to be saturating. That is, they expect `truncf 256.0 : f32 to f8E4M3FNUZ` to yield `240.0`, not `NaN`, and similarly for negative numbers. However, the underlying hardware instruction that can be used for this truncation implements overflow-to-NaN semantics.

To enable handling this use case, we add the saturate-fp8-truncf option to ArithToAMDGPU (off by default), which causes the requisite clamping code to be emitted. Said clamping code ensures that Inf and NaN are passed through exactly (and thus truncate to NaN).

Per review feedback, this commit refactors createScalarOrSplatConstant() to the Arith dialect utilities and uses it in this code. It also fixes the naming of existing patterns and switches from vector.extractelement/insertelement to vector.extract/insert.
1 parent eabddf2 commit f113249
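
For illustration only, here is a minimal standalone C++ sketch of the saturating semantics described above. It is not the MLIR pattern itself and the helper name is hypothetical; it assumes 240.0 as the largest finite f8E4M3FNUZ value, clamps out-of-range finite inputs, and leaves Inf/NaN alone so the later hardware conversion turns them into NaN.

// Minimal sketch of saturating truncation semantics (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstdio>

static double clampForF8E4M3FNUZ(double x) {
  const double kMaxF8 = 240.0; // largest finite f8E4M3FNUZ value
  if (std::isinf(x) || std::isnan(x))
    return x; // non-finite values pass through and become NaN in fp8
  return std::min(std::max(x, -kMaxF8), kMaxF8);
}

int main() {
  std::printf("%g\n", clampForF8E4M3FNUZ(256.0));  // 240
  std::printf("%g\n", clampForF8E4M3FNUZ(-300.0)); // -240
  std::printf("%g\n", clampForF8E4M3FNUZ(1.5));    // 1.5 (unchanged)
}

With this behavior, truncating 256.0 saturates to 240.0 instead of overflowing to NaN.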


9 files changed: +208 -77 lines changed


mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h

Lines changed: 7 additions & 1 deletion
@@ -20,7 +20,13 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"

 namespace arith {
-void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+/// Add patterns for rewriting `arith.extf` and `arith.truncf` on FP8 types
+/// to wrappers around AMDGPU-specific intrinsics. If `saturateFP8TruncF`
+/// is set, values outside the range of the destination type are clamped
+/// to the largest value of that type instead of being rewritten to Inf (aka
+/// NaN).
+void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
+                                             bool saturateFP8TruncF);
 } // namespace arith
 } // namespace mlir

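As a usage sketch (not part of this commit), a downstream pass could populate and apply these patterns roughly as follows; the wrapper function name is hypothetical.

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper, not part of this commit.
static void lowerFp8Arith(mlir::Operation *root, bool saturate) {
  mlir::RewritePatternSet patterns(root->getContext());
  // `saturate` maps to the new saturateFP8TruncF parameter added here.
  mlir::arith::populateArithToAMDGPUConversionPatterns(patterns, saturate);
  if (mlir::failed(
          mlir::applyPatternsAndFoldGreedily(root, std::move(patterns))))
    root->emitError("ArithToAMDGPU patterns did not converge");
}
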
mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 0 deletions
@@ -125,6 +125,12 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   }];

   let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+
+  let options = [
+    Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
+           /*default=*/"false",
+           "Use saturating truncation for 8-bit float types">,
+  ];
 }

 //===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 11 additions & 0 deletions
@@ -53,6 +53,17 @@ Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
 Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
                            Type toType, bool isUnsignedCast);

+/// Create a constant of type `type` at location `loc` whose value is `value`
+/// (an APInt or APFloat whose type must match the element type of `type`).
+/// If `type` is a shaped type, create a splat constant of the given value.
+/// Constants are folded if possible.
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+                                  const APInt &value);
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+                                  int64_t value);
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+                                  const APFloat &value);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {

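A usage sketch of the new helpers (illustrative only; the wrapper function and values below are made up): building a splat constant for a vector type from an APFloat.

#include "llvm/ADT/APFloat.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

// Hypothetical wrapper for illustration; `builder` and `loc` would normally
// come from the surrounding rewrite pattern.
static mlir::Value makeHalfSplat(mlir::OpBuilder &builder, mlir::Location loc) {
  // vector<4xf32> filled with 0.5, folded into a single arith.constant.
  auto vecTy = mlir::VectorType::get(4, builder.getF32Type());
  llvm::APFloat half(0.5f);
  return mlir::createScalarOrSplatConstant(builder, loc, vecTy, half);
}
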
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 80 additions & 28 deletions
@@ -10,6 +10,7 @@

 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -34,17 +35,17 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };

-struct ExtfOnFloat8RewritePattern final
-    : public OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;

   LogicalResult match(arith::ExtFOp op) const override;
   void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
 };

-struct TruncfToFloat8RewritePattern final
-    : public OpRewritePattern<arith::TruncFOp> {
-  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
+  bool saturateFP8 = false;
+  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
+      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}

   LogicalResult match(arith::TruncFOp op) const override;
   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -62,7 +63,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
   llvm_unreachable("The only 32-bit float type is f32");
 }

-LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   Type inType = op.getIn().getType();
   if (auto inVecType = inType.dyn_cast<VectorType>()) {
     if (inVecType.isScalable())
@@ -75,7 +76,7 @@ LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
 }

-void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
+void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
                                          PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
@@ -93,11 +94,13 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
   Value result =
       rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
   if (inType.getShape().empty()) {
-    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    Value scalarIn =
+        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarExt =
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
+    result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+                                               ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
   for (int64_t i = 0; i < numElements; i += 4) {
@@ -108,9 +111,7 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
       Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
           loc, rewriter.getF32Type(), inSlice, j);
       Value asType = castF32To(outElemType, asFloat, loc, rewriter);
-      result = rewriter.create<vector::InsertElementOp>(
-          loc, asType, result,
-          rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
+      result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
     }
   }
   rewriter.replaceOp(op, result);
@@ -127,7 +128,53 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
   llvm_unreachable("The only 32-bit float type is f32");
 }

-LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+                        Type outElemType, Value source) {
+  Type sourceType = source.getType();
+  const llvm::fltSemantics &sourceSem =
+      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+  const llvm::fltSemantics &targetSem =
+      cast<FloatType>(outElemType).getFloatSemantics();
+
+  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+  bool ignoredLosesInfo = false;
+  // We can ignore conversion failures here because this conversion promotes
+  // from a smaller type to a larger one - ex. there can be no loss of precision
+  // when casting fp8 to f16.
+  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
+  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
+
+  Value inf = createScalarOrSplatConstant(
+      rewriter, loc, sourceType,
+      APFloat::getInf(sourceSem, /*Negative=*/false));
+  Value negInf = createScalarOrSplatConstant(
+      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
+  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, inf);
+  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, negInf);
+  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::UNO, source, source);
+  Value isNonFinite = rewriter.create<arith::OrIOp>(
+      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+
+  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
+  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+  Value res =
+      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+  return res;
+}
+
+LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   Type outType = op.getOut().getType();
   if (auto outVecType = outType.dyn_cast<VectorType>()) {
     if (outVecType.isScalable())
@@ -137,22 +184,27 @@ LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
       return failure();
     outType = outVecType.getElementType();
   }
+  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
+  if (inType && inType.getWidth() <= 8 && saturateFP8)
+    // Conversion between 8-bit floats is not supported with truncation enabled.
+    return failure();
   return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
 }

-void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
+void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  if (saturateFP8)
+    in = clampInput(rewriter, loc, outElemType, in);
   VectorType truncResType = VectorType::get(4, outElemType);
   if (!in.getType().isa<VectorType>()) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
         /*existing=*/nullptr);
-    Value result = rewriter.create<vector::ExtractElementOp>(
-        loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
     return rewriter.replaceOp(op, result);
   }
   VectorType outType = op.getOut().getType().cast<VectorType>();
@@ -161,26 +213,25 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
   if (outType.getShape().empty()) {
-    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    Value scalarIn =
+        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarTrunc =
         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
+    result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+                                               ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }

   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value thisResult = nullptr;
     for (int64_t j = 0; j < elemsThisOp; j += 2) {
-      Value elemA = rewriter.create<vector::ExtractElementOp>(
-          loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
+      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
       Value asFloatA = castToF32(elemA, loc, rewriter);
       Value asFloatB = nullptr;
       if (j + 1 < elemsThisOp) {
-        Value elemB = rewriter.create<vector::ExtractElementOp>(
-            loc, in,
-            rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
+        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
         asFloatB = castToF32(elemB, loc, rewriter);
       }
       thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -196,15 +247,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
 }

 void mlir::arith::populateArithToAMDGPUConversionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
-      patterns.getContext());
+    RewritePatternSet &patterns, bool saturateFP8TruncF) {
+  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+                                             saturateFP8TruncF);
 }

 void ArithToAMDGPUConversionPass::runOnOperation() {
   Operation *op = getOperation();
   RewritePatternSet patterns(op->getContext());
-  arith::populateArithToAMDGPUConversionPatterns(patterns);
+  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }

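For reference, a small standalone sketch (assuming LLVM's APFloat API) of how clampInput derives its clamp bound: take the target type's largest finite value and promote it into the source semantics before comparing; for f8E4M3FNUZ this yields 240.0.

#include "llvm/ADT/APFloat.h"
#include <cstdio>

int main() {
  const llvm::fltSemantics &f8Sem = llvm::APFloat::Float8E4M3FNUZ();
  const llvm::fltSemantics &f32Sem = llvm::APFloat::IEEEsingle();

  llvm::APFloat max = llvm::APFloat::getLargest(f8Sem, /*Negative=*/false);
  bool losesInfo = false;
  // Promoting fp8 -> f32 is exact, so the conversion status can be ignored.
  (void)max.convert(f32Sem, llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  std::printf("largest finite f8E4M3FNUZ: %g\n", max.convertToFloat()); // 240
  return 0;
}
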
mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
   LINK_LIBS PUBLIC
   MLIRAMDGPUDialect
   MLIRArithDialect
+  MLIRArithUtils
   MLIRVectorDialect
   MLIRPass
   MLIRTransforms

mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp

Lines changed: 1 addition & 29 deletions
@@ -10,6 +10,7 @@

 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -58,35 +59,6 @@ static Type reduceInnermostDim(VectorType type)
   return VectorType::get(newShape, type.getElementType());
 }

-/// Returns a constant of integer of vector type filled with (repeated) `value`.
-static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
-                                         Location loc, Type type,
-                                         const APInt &value) {
-  TypedAttr attr;
-  if (dyn_cast<IntegerType>(type)) {
-    attr = rewriter.getIntegerAttr(type, value);
-  } else {
-    auto vecTy = cast<VectorType>(type);
-    attr = SplatElementsAttr::get(vecTy, value);
-  }
-
-  return rewriter.create<arith::ConstantOp>(loc, attr);
-}
-
-/// Returns a constant of integer of vector type filled with (repeated) `value`.
-static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
-                                         Location loc, Type type,
-                                         int64_t value) {
-  unsigned elementBitWidth = 0;
-  if (auto intTy = dyn_cast<IntegerType>(type))
-    elementBitWidth = intTy.getWidth();
-  else
-    elementBitWidth = cast<VectorType>(type).getElementTypeBitWidth();
-
-  return createScalarOrSplatConstant(rewriter, loc, type,
-                                     APInt(elementBitWidth, value));
-}
-
 /// Extracts the `input` vector slice with elements at the last dimension offset
 /// by `lastOffset`. Returns a value of vector type with the last dimension
 /// reduced to x1 or fully scalarized, e.g.:

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 34 additions & 0 deletions
@@ -197,6 +197,40 @@ mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
       }));
 }

+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+                                        Type type, const APInt &value) {
+  TypedAttr attr;
+  if (isa<IntegerType>(type)) {
+    attr = builder.getIntegerAttr(type, value);
+  } else {
+    auto vecTy = cast<ShapedType>(type);
+    attr = SplatElementsAttr::get(vecTy, value);
+  }
+
+  return builder.create<arith::ConstantOp>(loc, attr);
+}
+
+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+                                        Type type, int64_t value) {
+  unsigned elementBitWidth = 0;
+  if (auto intTy = dyn_cast<IntegerType>(type))
+    elementBitWidth = intTy.getWidth();
+  else
+    elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
+
+  return createScalarOrSplatConstant(builder, loc, type,
+                                     APInt(elementBitWidth, value));
+}
+
+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+                                        Type type, const APFloat &value) {
+  if (isa<FloatType>(type))
+    return builder.createOrFold<arith::ConstantOp>(
+        loc, type, builder.getFloatAttr(type, value));
+  TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
+  return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
+}
+
 Value ArithBuilder::_and(Value lhs, Value rhs) {
   return b.create<arith::AndIOp>(loc, lhs, rhs);
 }
