[mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 #74153

Merged · 1 commit · Jan 23, 2024
8 changes: 7 additions & 1 deletion mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -20,7 +20,13 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"

namespace arith {
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
/// Add patterns for rewriting `arith.extf` and `arith.truncf` on FP8 types
/// to wrappers around AMDGPU-specific intrinsics. If `saturateFP8TruncF`
/// is set, values outside the range of the destination type are clamped
/// to the largest finite value of that type instead of being rewritten to
/// Inf (which the FNUZ FP8 types have no encoding for, so it becomes NaN).
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
bool saturateFP8TruncF);
} // namespace arith
} // namespace mlir

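A minimal sketch of how a downstream pass body might enable the saturating behaviour through this entry point (the helper name is illustrative and not part of the patch; the greedy driver call mirrors the one used by the pass below):

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Populate and apply the FP8 rewrite patterns with saturating truncation.
void runFp8RewritesWithSaturation(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::arith::populateArithToAMDGPUConversionPatterns(
      patterns, /*saturateFP8TruncF=*/true);
  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(op, std::move(patterns))))
    op->emitError("FP8 rewrite patterns failed to converge");
}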
6 changes: 6 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
@@ -125,6 +125,12 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
}];

let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];

let options = [
Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
/*default=*/"false",
"Use saturating truncation for 8-bit float types">,
];
}

//===----------------------------------------------------------------------===//
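Because the flag is surfaced as a standard option on the existing `convert-arith-to-amdgpu` pass, it can also be enabled from the textual pipeline. A representative invocation (the driver command and input file name are shown only for illustration, not taken from this patch):

mlir-opt --convert-arith-to-amdgpu="saturate-fp8-truncf=true" input.mlir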
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -53,6 +53,17 @@ Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
Type toType, bool isUnsignedCast);

/// Create a constant of type `type` at location `loc` whose value is `value`
/// (an APInt or APFloat whose type must match the element type of `type`).
/// If `type` is a shaped type, create a splat constant of the given value.
/// Constants are folded if possible.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
const APInt &value);
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
int64_t value);
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
const APFloat &value);

/// Helper struct to build simple arithmetic quantities with minimal type
/// inference support.
struct ArithBuilder {
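These helpers were previously private to the wide-integer emulation pass (see the deletion in EmulateWideInt.cpp further down) and are now shared utilities. A small usage sketch, assuming a builder and location are in scope; the function name and constant values are made up for illustration:

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"

// Build a scalar and a splat constant through the shared helper.
static void buildExampleConstants(mlir::OpBuilder &b, mlir::Location loc) {
  // Scalar i32 constant 42.
  mlir::Value scalar =
      mlir::createScalarOrSplatConstant(b, loc, b.getI32Type(), /*value=*/42);
  // Splat vector<4xf32> constant holding 1.5 in every lane.
  auto vecTy = mlir::VectorType::get(4, b.getF32Type());
  mlir::Value splat =
      mlir::createScalarOrSplatConstant(b, loc, vecTy, llvm::APFloat(1.5f));
  (void)scalar;
  (void)splat;
}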
108 changes: 80 additions & 28 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -10,6 +10,7 @@

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -34,17 +35,17 @@ struct ArithToAMDGPUConversionPass final
void runOnOperation() override;
};

struct ExtfOnFloat8RewritePattern final
: public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};

struct TruncfToFloat8RewritePattern final
: public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
bool saturateFP8 = false;
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
: OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}

LogicalResult match(arith::TruncFOp op) const override;
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -62,7 +63,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
llvm_unreachable("The only 32-bit float type is f32");
}

LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
Type inType = op.getIn().getType();
if (auto inVecType = inType.dyn_cast<VectorType>()) {
if (inVecType.isScalable())
@@ -75,7 +76,7 @@ LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
}

void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
@@ -93,11 +94,13 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Value result =
rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
if (inType.getShape().empty()) {
Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
for (int64_t i = 0; i < numElements; i += 4) {
@@ -108,9 +111,7 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), inSlice, j);
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
result = rewriter.create<vector::InsertElementOp>(
loc, asType, result,
rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
}
}
rewriter.replaceOp(op, result);
@@ -127,7 +128,53 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
llvm_unreachable("The only 32-bit float type is f32");
}

LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
// If `in` is a finite value, clamp it between the maximum and minimum values
// of `outElemType` so that subsequent conversion instructions don't
// overflow those out-of-range values to NaN. These semantics are commonly
// used in machine-learning contexts where failure to clamp would lead to
// excessive NaN production.
static Value clampInput(PatternRewriter &rewriter, Location loc,
Type outElemType, Value source) {
Type sourceType = source.getType();
const llvm::fltSemantics &sourceSem =
cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
const llvm::fltSemantics &targetSem =
cast<FloatType>(outElemType).getFloatSemantics();

APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
bool ignoredLosesInfo = false;
// We can ignore conversion failures here because this conversion promotes
// from a smaller type to a larger one; for example, there is no loss of
// precision when casting fp8 to f16.
(void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
(void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

Value inf = createScalarOrSplatConstant(
rewriter, loc, sourceType,
APFloat::getInf(sourceSem, /*Negative=*/false));
Value negInf = createScalarOrSplatConstant(
rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
Value isInf = rewriter.createOrFold<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, source, inf);
Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, source, negInf);
Value isNan = rewriter.createOrFold<arith::CmpFOp>(
loc, arith::CmpFPredicate::UNO, source, source);
Value isNonFinite = rewriter.create<arith::OrIOp>(
loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);

Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
Value res =
rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
return res;
}
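// Illustration, not part of this patch: APFloat::getLargest yields the clamp
// bounds used above, i.e. 240.0 for f8E4M3FNUZ and 57344.0 for f8E5M2FNUZ.
// With saturation enabled, a truncation such as
//   %r = arith.truncf %x : f32 to f8E4M3FNUZ
// therefore maps an out-of-range finite input like 500.0 to 240.0 before the
// packed conversion, while NaN and +/-Inf inputs are passed through unchanged
// by the select at the end of clampInput.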

LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
if (outVecType.isScalable())
@@ -137,22 +184,27 @@ LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
return failure();
outType = outVecType.getElementType();
}
auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with saturation enabled.
return failure();
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
}

void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
VectorType truncResType = VectorType::get(4, outElemType);
if (!in.getType().isa<VectorType>()) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
Value result = rewriter.create<vector::ExtractElementOp>(
loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
return rewriter.replaceOp(op, result);
}
VectorType outType = op.getOut().getType().cast<VectorType>();
@@ -161,26 +213,25 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
if (outType.getShape().empty()) {
Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}

for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
for (int64_t j = 0; j < elemsThisOp; j += 2) {
Value elemA = rewriter.create<vector::ExtractElementOp>(
loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
Value asFloatA = castToF32(elemA, loc, rewriter);
Value asFloatB = nullptr;
if (j + 1 < elemsThisOp) {
Value elemB = rewriter.create<vector::ExtractElementOp>(
loc, in,
rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
asFloatB = castToF32(elemB, loc, rewriter);
}
thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -196,15 +247,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
}

void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
patterns.getContext());
RewritePatternSet &patterns, bool saturateFP8TruncF) {
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
saturateFP8TruncF);
}

void ArithToAMDGPUConversionPass::runOnOperation() {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
arith::populateArithToAMDGPUConversionPatterns(patterns);
arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
LINK_LIBS PUBLIC
MLIRAMDGPUDialect
MLIRArithDialect
MLIRArithUtils
MLIRVectorDialect
MLIRPass
MLIRTransforms
30 changes: 1 addition & 29 deletions mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -10,6 +10,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -58,35 +59,6 @@ static Type reduceInnermostDim(VectorType type) {
return VectorType::get(newShape, type.getElementType());
}

/// Returns a constant of integer or vector type filled with (repeated) `value`.
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
const APInt &value) {
TypedAttr attr;
if (dyn_cast<IntegerType>(type)) {
attr = rewriter.getIntegerAttr(type, value);
} else {
auto vecTy = cast<VectorType>(type);
attr = SplatElementsAttr::get(vecTy, value);
}

return rewriter.create<arith::ConstantOp>(loc, attr);
}

/// Returns a constant of integer or vector type filled with (repeated) `value`.
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
int64_t value) {
unsigned elementBitWidth = 0;
if (auto intTy = dyn_cast<IntegerType>(type))
elementBitWidth = intTy.getWidth();
else
elementBitWidth = cast<VectorType>(type).getElementTypeBitWidth();

return createScalarOrSplatConstant(rewriter, loc, type,
APInt(elementBitWidth, value));
}

/// Extracts the `input` vector slice with elements at the last dimension offset
/// by `lastOffset`. Returns a value of vector type with the last dimension
/// reduced to x1 or fully scalarized, e.g.:
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -197,6 +197,40 @@ mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
}));
}

Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
Type type, const APInt &value) {
TypedAttr attr;
if (isa<IntegerType>(type)) {
attr = builder.getIntegerAttr(type, value);
} else {
auto vecTy = cast<ShapedType>(type);
attr = SplatElementsAttr::get(vecTy, value);
}

return builder.create<arith::ConstantOp>(loc, attr);
}

Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
Type type, int64_t value) {
unsigned elementBitWidth = 0;
if (auto intTy = dyn_cast<IntegerType>(type))
elementBitWidth = intTy.getWidth();
else
elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();

return createScalarOrSplatConstant(builder, loc, type,
APInt(elementBitWidth, value));
}

Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
Type type, const APFloat &value) {
if (isa<FloatType>(type))
return builder.createOrFold<arith::ConstantOp>(
loc, type, builder.getFloatAttr(type, value));
TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
}

Value ArithBuilder::_and(Value lhs, Value rhs) {
return b.create<arith::AndIOp>(loc, lhs, rhs);
}