Skip to content

Commit e2db2f4

Browse files
committed
Refactor LegalizeToF32 to specify extra supported float types and target type as arguments
Instead of hardcoding all fp smaller than 32 bits are unsupported we provide a way to pass supported floating point types as well as the target type.
1 parent 49a754a commit e2db2f4

File tree

10 files changed

+358
-152
lines changed

10 files changed

+358
-152
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ namespace arith {
130130
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
131131
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
132132
Type resultType);
133+
134+
// Map strings to float types.
135+
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
136+
133137
} // namespace arith
134138
} // namespace mlir
135139

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,13 @@ void populateMathPolynomialApproximationPatterns(
5656
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
5757

5858
namespace math {
59-
void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
60-
void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
61-
TypeConverter &typeConverter);
62-
void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
63-
TypeConverter &typeConverter);
59+
void populateExtendToSupportedTypesTypeConverter(
60+
TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
61+
Type targetType);
62+
void populateExtendToSupportedTypesConversionTarget(
63+
ConversionTarget &target, TypeConverter &typeConverter);
64+
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
65+
TypeConverter &typeConverter);
6466
} // namespace math
6567
} // namespace mlir
6668

mlir/include/mlir/Dialect/Math/Transforms/Passes.td

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
1919
let dependentDialects = ["math::MathDialect"];
2020
}
2121

22-
def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
22+
def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
2323
let summary = "Legalize floating-point math ops on low-precision floats";
2424
let description = [{
2525
On many targets, the math functions are not implemented for floating-point
@@ -28,11 +28,19 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
2828

2929
This pass explicitly legalizes these math functions by inserting
3030
`arith.extf` and `arith.truncf` pairs around said op, which preserves
31-
the original semantics while enabling lowering.
31+
the original semantics while enabling lowering. The extra supported floating-point
32+
types for the target are passed as arguments. Types f64 and f32 are implicitly
33+
supported.
3234

3335
As an exception, this pass does not legalize `math.fma`, because
3436
that is an operation frequently implemented at low precisions.
3537
}];
38+
let options = [
39+
ListOption<"extraTypeStrs", "extra-types", "std::string",
40+
"MLIR types with arithmetic support on a given target (f64 and f32 are implicitly supported)">,
41+
Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
42+
"MLIR type to convert the unsupported source types to">,
43+
];
3644
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
3745
}
3846

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1617
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1718
#include "mlir/IR/BuiltinTypes.h"
1819
#include "mlir/IR/Location.h"
@@ -49,29 +50,6 @@ struct EmulateFloatPattern final : ConversionPattern {
4950
};
5051
} // end namespace
5152

52-
/// Map strings to float types. This function is here because no one else needs
53-
/// it yet, feel free to abstract it out.
54-
static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
55-
StringRef name) {
56-
Builder b(ctx);
57-
return llvm::StringSwitch<std::optional<FloatType>>(name)
58-
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
59-
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
60-
.Case("f8E5M2", b.getFloat8E5M2Type())
61-
.Case("f8E4M3", b.getFloat8E4M3Type())
62-
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
63-
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
64-
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
65-
.Case("f8E3M4", b.getFloat8E3M4Type())
66-
.Case("bf16", b.getBF16Type())
67-
.Case("f16", b.getF16Type())
68-
.Case("f32", b.getF32Type())
69-
.Case("f64", b.getF64Type())
70-
.Case("f80", b.getF80Type())
71-
.Case("f128", b.getF128Type())
72-
.Default(std::nullopt);
73-
}
74-
7553
LogicalResult EmulateFloatPattern::match(Operation *op) const {
7654
if (getTypeConverter()->isLegal(op))
7755
return failure();
@@ -155,7 +133,8 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
155133
SmallVector<Type> sourceTypes;
156134
Type targetType;
157135

158-
std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
136+
std::optional<FloatType> maybeTargetType =
137+
arith::parseFloatType(ctx, targetTypeStr);
159138
if (!maybeTargetType) {
160139
emitError(UnknownLoc::get(ctx), "could not map target type '" +
161140
targetTypeStr +
@@ -165,7 +144,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
165144
targetType = *maybeTargetType;
166145
for (StringRef sourceTypeStr : sourceTypeStrs) {
167146
std::optional<FloatType> maybeSourceType =
168-
parseFloatType(ctx, sourceTypeStr);
147+
arith::parseFloatType(ctx, sourceTypeStr);
169148
if (!maybeSourceType) {
170149
emitError(UnknownLoc::get(ctx), "could not map source type '" +
171150
sourceTypeStr +

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,4 +357,25 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
357357
[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
358358
}
359359

360+
/// Map strings to float types.
361+
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
362+
Builder b(ctx);
363+
return llvm::StringSwitch<std::optional<FloatType>>(name)
364+
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
365+
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
366+
.Case("f8E5M2", b.getFloat8E5M2Type())
367+
.Case("f8E4M3", b.getFloat8E4M3Type())
368+
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
369+
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
370+
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
371+
.Case("f8E3M4", b.getFloat8E3M4Type())
372+
.Case("bf16", b.getBF16Type())
373+
.Case("f16", b.getF16Type())
374+
.Case("f32", b.getF32Type())
375+
.Case("f64", b.getF64Type())
376+
.Case("f80", b.getF80Type())
377+
.Case("f128", b.getF128Type())
378+
.Default(std::nullopt);
379+
}
380+
360381
} // namespace mlir::arith

mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
add_mlir_dialect_library(MLIRMathTransforms
22
AlgebraicSimplification.cpp
33
ExpandPatterns.cpp
4-
LegalizeToF32.cpp
4+
ExtendToSupportedTypes.cpp
55
PolynomialApproximation.cpp
66
UpliftToFMA.cpp
77

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
//===- ExtendToSupportedTypes.cpp - Legalize functions on unsupported floats
2+
//----------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file implements legalizing math operations on unsupported floating-point
11+
// types through arith.extf and arith.truncf.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Arith/Utils/Utils.h"
17+
#include "mlir/Dialect/Math/IR/Math.h"
18+
#include "mlir/Dialect/Math/Transforms/Passes.h"
19+
#include "mlir/IR/Diagnostics.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/IR/TypeUtilities.h"
22+
#include "mlir/Transforms/DialectConversion.h"
23+
#include "llvm/ADT/STLExtras.h"
24+
#include "llvm/ADT/SetVector.h"
25+
26+
namespace mlir::math {
27+
#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
28+
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
29+
} // namespace mlir::math
30+
31+
using namespace mlir;
32+
33+
namespace {
34+
struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
35+
ExtendToSupportedTypesRewritePattern(TypeConverter &converter,
36+
MLIRContext *context)
37+
: ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
38+
LogicalResult
39+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
40+
ConversionPatternRewriter &rewriter) const override;
41+
};
42+
43+
struct ExtendToSupportedTypesPass
44+
: mlir::math::impl::MathExtendToSupportedTypesBase<
45+
ExtendToSupportedTypesPass> {
46+
using math::impl::MathExtendToSupportedTypesBase<
47+
ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
48+
49+
void runOnOperation() override;
50+
};
51+
} // namespace
52+
53+
void mlir::math::populateExtendToSupportedTypesTypeConverter(
54+
TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
55+
Type targetType) {
56+
57+
typeConverter.addConversion(
58+
[](Type type) -> std::optional<Type> { return type; });
59+
typeConverter.addConversion(
60+
[&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
61+
if (!sourceTypes.contains(type))
62+
return targetType;
63+
64+
return std::nullopt;
65+
});
66+
typeConverter.addConversion(
67+
[&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
68+
if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
69+
if (!sourceTypes.contains(elemTy))
70+
return type.clone(targetType);
71+
72+
return std::nullopt;
73+
});
74+
typeConverter.addTargetMaterialization(
75+
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
76+
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
77+
extFOp.setFastmath(arith::FastMathFlags::contract);
78+
return extFOp;
79+
});
80+
}
81+
82+
void mlir::math::populateExtendToSupportedTypesConversionTarget(
83+
ConversionTarget &target, TypeConverter &typeConverter) {
84+
target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
85+
if (isa<MathDialect>(op->getDialect()))
86+
return typeConverter.isLegal(op);
87+
return true;
88+
});
89+
target.addLegalOp<FmaOp>();
90+
target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
91+
}
92+
93+
LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
94+
Operation *op, ArrayRef<Value> operands,
95+
ConversionPatternRewriter &rewriter) const {
96+
Location loc = op->getLoc();
97+
const TypeConverter *converter = getTypeConverter();
98+
FailureOr<Operation *> legalized =
99+
convertOpResultTypes(op, operands, *converter, rewriter);
100+
if (failed(legalized))
101+
return failure();
102+
103+
SmallVector<Value> results = (*legalized)->getResults();
104+
for (auto [result, newType, origType] : llvm::zip_equal(
105+
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
106+
if (newType != origType) {
107+
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
108+
truncFOp.setFastmath(arith::FastMathFlags::contract);
109+
result = truncFOp.getResult();
110+
}
111+
}
112+
rewriter.replaceOp(op, results);
113+
return success();
114+
}
115+
116+
void mlir::math::populateExtendToSupportedTypesPatterns(
117+
RewritePatternSet &patterns, TypeConverter &typeConverter) {
118+
patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
119+
patterns.getContext());
120+
}
121+
122+
void ExtendToSupportedTypesPass::runOnOperation() {
123+
Operation *op = getOperation();
124+
MLIRContext *ctx = &getContext();
125+
126+
// Parse target type
127+
std::optional<Type> maybeTargetType =
128+
arith::parseFloatType(ctx, targetTypeStr);
129+
if (!maybeTargetType.has_value()) {
130+
emitError(UnknownLoc::get(ctx), "could not map target type '" +
131+
targetTypeStr +
132+
"' to a known floating-point type");
133+
return signalPassFailure();
134+
}
135+
Type targetType = maybeTargetType.value();
136+
137+
// Parse source types
138+
llvm::SetVector<Type> sourceTypes;
139+
for (const auto &extraTypeStr : extraTypeStrs) {
140+
std::optional<FloatType> maybeExtraType =
141+
arith::parseFloatType(ctx, extraTypeStr);
142+
if (!maybeExtraType.has_value()) {
143+
emitError(UnknownLoc::get(ctx), "could not map source type '" +
144+
extraTypeStr +
145+
"' to a known floating-point type");
146+
return signalPassFailure();
147+
}
148+
sourceTypes.insert(maybeExtraType.value());
149+
}
150+
// f64 and f32 are implicitly supported
151+
Builder b(ctx);
152+
sourceTypes.insert(b.getF64Type());
153+
sourceTypes.insert(b.getF32Type());
154+
155+
TypeConverter typeConverter;
156+
math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
157+
targetType);
158+
ConversionTarget target(*ctx);
159+
math::populateExtendToSupportedTypesConversionTarget(target, typeConverter);
160+
RewritePatternSet patterns(ctx);
161+
math::populateExtendToSupportedTypesPatterns(patterns, typeConverter);
162+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
163+
return signalPassFailure();
164+
}

0 commit comments

Comments
 (0)