Skip to content

Commit 1fd1f65

Browse files
authored
[mlir] Refactor LegalizeToF32 to specify extra supported float types and target type as arguments (#108815)
Instead of hardcoding all fp smaller than 32 bits are unsupported we provide a way to pass supported floating point types as well as the target type. fp64 and fp32 are implicitly supported. CC: @krzysz00 @manupak
1 parent 84a0a3d commit 1fd1f65

File tree

10 files changed

+359
-153
lines changed

10 files changed

+359
-153
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ namespace arith {
130130
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
131131
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
132132
Type resultType);
133+
134+
// Map strings to float types.
135+
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
136+
133137
} // namespace arith
134138
} // namespace mlir
135139

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,13 @@ void populateMathPolynomialApproximationPatterns(
5656
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
5757

5858
namespace math {
59-
void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
60-
void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
61-
TypeConverter &typeConverter);
62-
void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
63-
TypeConverter &typeConverter);
59+
void populateExtendToSupportedTypesTypeConverter(
60+
TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
61+
Type targetType);
62+
void populateExtendToSupportedTypesConversionTarget(
63+
ConversionTarget &target, TypeConverter &typeConverter);
64+
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
65+
TypeConverter &typeConverter);
6466
} // namespace math
6567
} // namespace mlir
6668

mlir/include/mlir/Dialect/Math/Transforms/Passes.td

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
1919
let dependentDialects = ["math::MathDialect"];
2020
}
2121

22-
def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
22+
def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
2323
let summary = "Legalize floating-point math ops on low-precision floats";
2424
let description = [{
2525
On many targets, the math functions are not implemented for floating-point
@@ -28,11 +28,19 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
2828

2929
This pass explicitly legalizes these math functions by inserting
3030
`arith.extf` and `arith.truncf` pairs around said op, which preserves
31-
the original semantics while enabling lowering.
31+
the original semantics while enabling lowering. The extra supported floating-point
32+
types for the target are passed as arguments. Types f64 and f32 are implicitly
33+
supported.
3234

3335
As an exception, this pass does not legalize `math.fma`, because
3436
that is an operation frequently implemented at low precisions.
3537
}];
38+
let options = [
39+
ListOption<"extraTypeStrs", "extra-types", "std::string",
40+
"MLIR types with arithmetic support on a given target (f64 and f32 are implicitly supported)">,
41+
Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
42+
"MLIR type to convert the unsupported source types to">,
43+
];
3644
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
3745
}
3846

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1617
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1718
#include "mlir/IR/BuiltinTypes.h"
1819
#include "mlir/IR/Location.h"
@@ -49,30 +50,6 @@ struct EmulateFloatPattern final : ConversionPattern {
4950
};
5051
} // end namespace
5152

52-
/// Map strings to float types. This function is here because no one else needs
53-
/// it yet, feel free to abstract it out.
54-
static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
55-
StringRef name) {
56-
Builder b(ctx);
57-
return llvm::StringSwitch<std::optional<FloatType>>(name)
58-
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
59-
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
60-
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
61-
.Case("f8E5M2", b.getFloat8E5M2Type())
62-
.Case("f8E4M3", b.getFloat8E4M3Type())
63-
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
64-
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
65-
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
66-
.Case("f8E3M4", b.getFloat8E3M4Type())
67-
.Case("bf16", b.getBF16Type())
68-
.Case("f16", b.getF16Type())
69-
.Case("f32", b.getF32Type())
70-
.Case("f64", b.getF64Type())
71-
.Case("f80", b.getF80Type())
72-
.Case("f128", b.getF128Type())
73-
.Default(std::nullopt);
74-
}
75-
7653
LogicalResult EmulateFloatPattern::match(Operation *op) const {
7754
if (getTypeConverter()->isLegal(op))
7855
return failure();
@@ -156,7 +133,8 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
156133
SmallVector<Type> sourceTypes;
157134
Type targetType;
158135

159-
std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
136+
std::optional<FloatType> maybeTargetType =
137+
arith::parseFloatType(ctx, targetTypeStr);
160138
if (!maybeTargetType) {
161139
emitError(UnknownLoc::get(ctx), "could not map target type '" +
162140
targetTypeStr +
@@ -166,7 +144,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
166144
targetType = *maybeTargetType;
167145
for (StringRef sourceTypeStr : sourceTypeStrs) {
168146
std::optional<FloatType> maybeSourceType =
169-
parseFloatType(ctx, sourceTypeStr);
147+
arith::parseFloatType(ctx, sourceTypeStr);
170148
if (!maybeSourceType) {
171149
emitError(UnknownLoc::get(ctx), "could not map source type '" +
172150
sourceTypeStr +

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,4 +357,26 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
357357
[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
358358
}
359359

360+
/// Map strings to float types.
361+
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
362+
Builder b(ctx);
363+
return llvm::StringSwitch<std::optional<FloatType>>(name)
364+
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
365+
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
366+
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
367+
.Case("f8E5M2", b.getFloat8E5M2Type())
368+
.Case("f8E4M3", b.getFloat8E4M3Type())
369+
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
370+
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
371+
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
372+
.Case("f8E3M4", b.getFloat8E3M4Type())
373+
.Case("bf16", b.getBF16Type())
374+
.Case("f16", b.getF16Type())
375+
.Case("f32", b.getF32Type())
376+
.Case("f64", b.getF64Type())
377+
.Case("f80", b.getF80Type())
378+
.Case("f128", b.getF128Type())
379+
.Default(std::nullopt);
380+
}
381+
360382
} // namespace mlir::arith

mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
add_mlir_dialect_library(MLIRMathTransforms
22
AlgebraicSimplification.cpp
33
ExpandPatterns.cpp
4-
LegalizeToF32.cpp
4+
ExtendToSupportedTypes.cpp
55
PolynomialApproximation.cpp
66
UpliftToFMA.cpp
77

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
//===- ExtendToSupportedTypes.cpp - Legalize functions on unsupported floats
2+
//----------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file implements legalizing math operations on unsupported floating-point
11+
// types through arith.extf and arith.truncf.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Arith/Utils/Utils.h"
17+
#include "mlir/Dialect/Math/IR/Math.h"
18+
#include "mlir/Dialect/Math/Transforms/Passes.h"
19+
#include "mlir/IR/Diagnostics.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/IR/TypeUtilities.h"
22+
#include "mlir/Transforms/DialectConversion.h"
23+
#include "llvm/ADT/STLExtras.h"
24+
#include "llvm/ADT/SetVector.h"
25+
26+
namespace mlir::math {
27+
#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
28+
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
29+
} // namespace mlir::math
30+
31+
using namespace mlir;
32+
33+
namespace {
34+
struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
35+
ExtendToSupportedTypesRewritePattern(TypeConverter &converter,
36+
MLIRContext *context)
37+
: ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
38+
LogicalResult
39+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
40+
ConversionPatternRewriter &rewriter) const override;
41+
};
42+
43+
struct ExtendToSupportedTypesPass
44+
: mlir::math::impl::MathExtendToSupportedTypesBase<
45+
ExtendToSupportedTypesPass> {
46+
using math::impl::MathExtendToSupportedTypesBase<
47+
ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
48+
49+
void runOnOperation() override;
50+
};
51+
} // namespace
52+
53+
void mlir::math::populateExtendToSupportedTypesTypeConverter(
54+
TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
55+
Type targetType) {
56+
57+
typeConverter.addConversion(
58+
[](Type type) -> std::optional<Type> { return type; });
59+
typeConverter.addConversion(
60+
[&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
61+
if (!sourceTypes.contains(type))
62+
return targetType;
63+
64+
return std::nullopt;
65+
});
66+
typeConverter.addConversion(
67+
[&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
68+
if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
69+
if (!sourceTypes.contains(elemTy))
70+
return type.clone(targetType);
71+
72+
return std::nullopt;
73+
});
74+
typeConverter.addTargetMaterialization(
75+
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
76+
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
77+
extFOp.setFastmath(arith::FastMathFlags::contract);
78+
return extFOp;
79+
});
80+
}
81+
82+
void mlir::math::populateExtendToSupportedTypesConversionTarget(
83+
ConversionTarget &target, TypeConverter &typeConverter) {
84+
target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
85+
if (isa<MathDialect>(op->getDialect()))
86+
return typeConverter.isLegal(op);
87+
return true;
88+
});
89+
target.addLegalOp<FmaOp>();
90+
target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
91+
}
92+
93+
LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
94+
Operation *op, ArrayRef<Value> operands,
95+
ConversionPatternRewriter &rewriter) const {
96+
Location loc = op->getLoc();
97+
const TypeConverter *converter = getTypeConverter();
98+
FailureOr<Operation *> legalized =
99+
convertOpResultTypes(op, operands, *converter, rewriter);
100+
if (failed(legalized))
101+
return failure();
102+
103+
SmallVector<Value> results = (*legalized)->getResults();
104+
for (auto [result, newType, origType] : llvm::zip_equal(
105+
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
106+
if (newType != origType) {
107+
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
108+
truncFOp.setFastmath(arith::FastMathFlags::contract);
109+
result = truncFOp.getResult();
110+
}
111+
}
112+
rewriter.replaceOp(op, results);
113+
return success();
114+
}
115+
116+
void mlir::math::populateExtendToSupportedTypesPatterns(
117+
RewritePatternSet &patterns, TypeConverter &typeConverter) {
118+
patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
119+
patterns.getContext());
120+
}
121+
122+
void ExtendToSupportedTypesPass::runOnOperation() {
123+
Operation *op = getOperation();
124+
MLIRContext *ctx = &getContext();
125+
126+
// Parse target type
127+
std::optional<Type> maybeTargetType =
128+
arith::parseFloatType(ctx, targetTypeStr);
129+
if (!maybeTargetType.has_value()) {
130+
emitError(UnknownLoc::get(ctx), "could not map target type '" +
131+
targetTypeStr +
132+
"' to a known floating-point type");
133+
return signalPassFailure();
134+
}
135+
Type targetType = maybeTargetType.value();
136+
137+
// Parse source types
138+
llvm::SetVector<Type> sourceTypes;
139+
for (const auto &extraTypeStr : extraTypeStrs) {
140+
std::optional<FloatType> maybeExtraType =
141+
arith::parseFloatType(ctx, extraTypeStr);
142+
if (!maybeExtraType.has_value()) {
143+
emitError(UnknownLoc::get(ctx), "could not map source type '" +
144+
extraTypeStr +
145+
"' to a known floating-point type");
146+
return signalPassFailure();
147+
}
148+
sourceTypes.insert(maybeExtraType.value());
149+
}
150+
// f64 and f32 are implicitly supported
151+
Builder b(ctx);
152+
sourceTypes.insert(b.getF64Type());
153+
sourceTypes.insert(b.getF32Type());
154+
155+
TypeConverter typeConverter;
156+
math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
157+
targetType);
158+
ConversionTarget target(*ctx);
159+
math::populateExtendToSupportedTypesConversionTarget(target, typeConverter);
160+
RewritePatternSet patterns(ctx);
161+
math::populateExtendToSupportedTypesPatterns(patterns, typeConverter);
162+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
163+
return signalPassFailure();
164+
}

0 commit comments

Comments
 (0)