Skip to content

[mlir] Refactor LegalizeToF32 to specify extra supported float types and target type as arguments #108815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ namespace arith {
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
Type resultType);

// Map strings to float types.
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);

} // namespace arith
} // namespace mlir

Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@ void populateMathPolynomialApproximationPatterns(
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);

namespace math {
void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
TypeConverter &typeConverter);
void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
TypeConverter &typeConverter);
void populateExtendToSupportedTypesTypeConverter(
TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
Type targetType);
void populateExtendToSupportedTypesConversionTarget(
ConversionTarget &target, TypeConverter &typeConverter);
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter);
} // namespace math
} // namespace mlir

Expand Down
12 changes: 10 additions & 2 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
let dependentDialects = ["math::MathDialect"];
}

def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
let summary = "Legalize floating-point math ops on low-precision floats";
let description = [{
On many targets, the math functions are not implemented for floating-point
Expand All @@ -28,11 +28,19 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {

This pass explicitly legalizes these math functions by inserting
`arith.extf` and `arith.truncf` pairs around said op, which preserves
the original semantics while enabling lowering.
the original semantics while enabling lowering. The extra supported floating-point
types for the target are passed as arguments. Types f64 and f32 are implicitly
supported.

As an exception, this pass does not legalize `math.fma`, because
that is an operation frequently implemented at low precisions.
}];
let options = [
ListOption<"extraTypeStrs", "extra-types", "std::string",
"MLIR types with arithmetic support on a given target (f64 and f32 are implicitly supported)">,
Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
"MLIR type to convert the unsupported source types to">,
];
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}

Expand Down
30 changes: 4 additions & 26 deletions mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
Expand Down Expand Up @@ -49,30 +50,6 @@ struct EmulateFloatPattern final : ConversionPattern {
};
} // end namespace

/// Map strings to float types. This function is here because no one else needs
/// it yet, feel free to abstract it out.
static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
.Case("f8E5M2", b.getFloat8E5M2Type())
.Case("f8E4M3", b.getFloat8E4M3Type())
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
.Case("f8E3M4", b.getFloat8E3M4Type())
.Case("bf16", b.getBF16Type())
.Case("f16", b.getF16Type())
.Case("f32", b.getF32Type())
.Case("f64", b.getF64Type())
.Case("f80", b.getF80Type())
.Case("f128", b.getF128Type())
.Default(std::nullopt);
}

LogicalResult EmulateFloatPattern::match(Operation *op) const {
if (getTypeConverter()->isLegal(op))
return failure();
Expand Down Expand Up @@ -156,7 +133,8 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
SmallVector<Type> sourceTypes;
Type targetType;

std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
std::optional<FloatType> maybeTargetType =
arith::parseFloatType(ctx, targetTypeStr);
if (!maybeTargetType) {
emitError(UnknownLoc::get(ctx), "could not map target type '" +
targetTypeStr +
Expand All @@ -166,7 +144,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
targetType = *maybeTargetType;
for (StringRef sourceTypeStr : sourceTypeStrs) {
std::optional<FloatType> maybeSourceType =
parseFloatType(ctx, sourceTypeStr);
arith::parseFloatType(ctx, sourceTypeStr);
if (!maybeSourceType) {
emitError(UnknownLoc::get(ctx), "could not map source type '" +
sourceTypeStr +
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,26 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
}

/// Map strings to float types.
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
.Case("f8E5M2", b.getFloat8E5M2Type())
.Case("f8E4M3", b.getFloat8E4M3Type())
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
.Case("f8E3M4", b.getFloat8E3M4Type())
.Case("bf16", b.getBF16Type())
.Case("f16", b.getF16Type())
.Case("f32", b.getF32Type())
.Case("f64", b.getF64Type())
.Case("f80", b.getF80Type())
.Case("f128", b.getF128Type())
.Default(std::nullopt);
}

} // namespace mlir::arith
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp

Expand Down
164 changes: 164 additions & 0 deletions mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//===- ExtendToSupportedTypes.cpp - Legalize functions on unsupported floats
//----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements legalizing math operations on unsupported floating-point
// types through arith.extf and arith.truncf.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"

namespace mlir::math {
#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace mlir::math

using namespace mlir;

namespace {
struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
ExtendToSupportedTypesRewritePattern(TypeConverter &converter,
MLIRContext *context)
: ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};

struct ExtendToSupportedTypesPass
: mlir::math::impl::MathExtendToSupportedTypesBase<
ExtendToSupportedTypesPass> {
using math::impl::MathExtendToSupportedTypesBase<
ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;

void runOnOperation() override;
};
} // namespace

void mlir::math::populateExtendToSupportedTypesTypeConverter(
TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
Type targetType) {

typeConverter.addConversion(
[](Type type) -> std::optional<Type> { return type; });
typeConverter.addConversion(
[&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
if (!sourceTypes.contains(type))
return targetType;

return std::nullopt;
});
typeConverter.addConversion(
[&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
if (!sourceTypes.contains(elemTy))
return type.clone(targetType);

return std::nullopt;
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
extFOp.setFastmath(arith::FastMathFlags::contract);
return extFOp;
});
}

void mlir::math::populateExtendToSupportedTypesConversionTarget(
ConversionTarget &target, TypeConverter &typeConverter) {
target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
if (isa<MathDialect>(op->getDialect()))
return typeConverter.isLegal(op);
return true;
});
target.addLegalOp<FmaOp>();
target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
}

LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
const TypeConverter *converter = getTypeConverter();
FailureOr<Operation *> legalized =
convertOpResultTypes(op, operands, *converter, rewriter);
if (failed(legalized))
return failure();

SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
if (newType != origType) {
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
truncFOp.setFastmath(arith::FastMathFlags::contract);
result = truncFOp.getResult();
}
}
rewriter.replaceOp(op, results);
return success();
}

void mlir::math::populateExtendToSupportedTypesPatterns(
RewritePatternSet &patterns, TypeConverter &typeConverter) {
patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
patterns.getContext());
}

void ExtendToSupportedTypesPass::runOnOperation() {
Operation *op = getOperation();
MLIRContext *ctx = &getContext();

// Parse target type
std::optional<Type> maybeTargetType =
arith::parseFloatType(ctx, targetTypeStr);
if (!maybeTargetType.has_value()) {
emitError(UnknownLoc::get(ctx), "could not map target type '" +
targetTypeStr +
"' to a known floating-point type");
return signalPassFailure();
}
Type targetType = maybeTargetType.value();

// Parse source types
llvm::SetVector<Type> sourceTypes;
for (const auto &extraTypeStr : extraTypeStrs) {
std::optional<FloatType> maybeExtraType =
arith::parseFloatType(ctx, extraTypeStr);
if (!maybeExtraType.has_value()) {
emitError(UnknownLoc::get(ctx), "could not map source type '" +
extraTypeStr +
"' to a known floating-point type");
return signalPassFailure();
}
sourceTypes.insert(maybeExtraType.value());
}
// f64 and f32 are implicitly supported
Builder b(ctx);
sourceTypes.insert(b.getF64Type());
sourceTypes.insert(b.getF32Type());

TypeConverter typeConverter;
math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
targetType);
ConversionTarget target(*ctx);
math::populateExtendToSupportedTypesConversionTarget(target, typeConverter);
RewritePatternSet patterns(ctx);
math::populateExtendToSupportedTypesPatterns(patterns, typeConverter);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
return signalPassFailure();
}
Loading
Loading