Skip to content

Commit 05e85e4

Browse files
authored
[mlir][Math] Add pass to legalize math functions to f32-or-higher (#78361)
Since most of the operations in the `math` dialect don't have low-precision implementations, add the -math-legalize-to-f32 pass that goes through and brackets low-precision math functions (like `math.sin %0 : f16`) with `arith.extf` and `arith.truncf`. This preserves the original semantics of the math operation but allows lowering to proceed. Versions of this lowering are already implicitly present in some passes, like ConvertGPUToROCDL. However, because those are implicit rewrites, they hide the floating-point extension and truncation, preventing anyone from writing passes that operate on those implicit extf/truncf pairs. Exposing this legalization explicitly is needed to allow lowering 8-bit floats on AMD GPUs, as the implementation of extf and truncf on that platform requires the complex logic found in ArithToAMDGPU, which runs before the GPU to ROCDL lowering.
1 parent 2286789 commit 05e85e4

File tree

5 files changed

+231
-0
lines changed

5 files changed

+231
-0
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ namespace math {
1616
#define GEN_PASS_DECL
1717
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
1818
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
19+
#define GEN_PASS_DECL_MATHLEGALIZETOF32
1920
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
2021
#define GEN_PASS_REGISTRATION
2122
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
2223
} // namespace math
2324

25+
class ConversionTarget;
2426
class RewritePatternSet;
27+
class TypeConverter;
2528

2629
void populateExpandCtlzPattern(RewritePatternSet &patterns);
2730
void populateExpandTanPattern(RewritePatternSet &patterns);
@@ -48,6 +51,13 @@ void populateMathPolynomialApproximationPatterns(
4851

4952
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
5053

54+
namespace math {
55+
/// Registers on `typeConverter` the type conversions used by the
/// -math-legalize-to-f32 pass (low-precision floats are widened to f32, for
/// scalars and shaped types) together with a target materialization that
/// builds `arith.extf`.
void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
56+
/// Configures `target` so that math dialect ops are legal only when
/// `typeConverter` considers them legal, while `math.fma`, `arith.extf` and
/// `arith.truncf` are always legal. `typeConverter` is captured by reference
/// and must outlive `target`.
void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
57+
TypeConverter &typeConverter);
58+
/// Adds to `patterns` the conversion pattern that re-creates an illegal op on
/// the converted types and truncates its results back to the original types.
void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
59+
TypeConverter &typeConverter);
60+
} // namespace math
5161
} // namespace mlir
5262

5363
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_

mlir/include/mlir/Dialect/Math/Transforms/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,21 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
1919
let dependentDialects = ["math::MathDialect"];
2020
}
2121

22+
def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
23+
let summary = "Legalize floating-point math ops on low-precision floats";
24+
let description = [{
25+
On many targets, the math functions are not implemented for floating-point
26+
types less precise than IEEE single-precision (aka f32), such as half-floats,
27+
bfloat16, or 8-bit floats.
28+
29+
This pass explicitly legalizes these math functions by inserting
30+
`arith.extf` and `arith.truncf` pairs around said op, which preserves
31+
the original semantics while enabling lowering.
32+
33+
As an exception, this pass does not legalize `math.fma`, because
34+
that is an operation frequently implemented at low precisions.
35+
}];
36+
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
37+
}
38+
2239
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES

mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRMathTransforms
22
AlgebraicSimplification.cpp
33
ExpandPatterns.cpp
4+
LegalizeToF32.cpp
45
PolynomialApproximation.cpp
56
UpliftToFMA.cpp
67

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements legalizing math operations on small floating-point
10+
// types through arith.extf and arith.truncf.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Math/IR/Math.h"
16+
#include "mlir/Dialect/Math/Transforms/Passes.h"
17+
#include "mlir/IR/Diagnostics.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/IR/TypeUtilities.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
#include "llvm/ADT/STLExtras.h"
22+
23+
namespace mlir::math {
24+
#define GEN_PASS_DEF_MATHLEGALIZETOF32
25+
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
26+
} // namespace mlir::math
27+
28+
using namespace mlir;
29+
namespace {
30+
struct LegalizeToF32RewritePattern final : ConversionPattern {
31+
LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
32+
: ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
33+
LogicalResult
34+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
35+
ConversionPatternRewriter &rewriter) const override;
36+
};
37+
38+
struct LegalizeToF32Pass final
39+
: mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
40+
void runOnOperation() override;
41+
};
42+
} // namespace
43+
44+
/// Populates `typeConverter` with the conversions used to legalize
/// low-precision math ops:
///   * floating-point types narrower than 32 bits become f32;
///   * shaped types (vectors, tensors, ...) whose element type is such a
///     narrow float become the same shape with f32 elements;
///   * every other type is left unchanged.
/// A target materialization is also registered so that values of the
/// original narrow type are widened with `arith.extf`.
void mlir::math::populateLegalizeToF32TypeConverter(
    TypeConverter &typeConverter) {
  // Identity fallback. Conversions are tried most-recently-registered first,
  // so this only applies when the callbacks below return std::nullopt.
  typeConverter.addConversion(
      [](Type type) -> std::optional<Type> { return type; });
  // Scalars: only widen floats that are narrower than f32.
  typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
    if (type.getWidth() < 32)
      return Float32Type::get(type.getContext());
    return std::nullopt;
  });
  // Shaped types: apply the same width check to the element type. Without it
  // an f64 (or wider) element type would be rewritten to f32, which would both
  // lose precision and require an invalid narrowing `arith.extf`.
  typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
    auto elemTy = dyn_cast<FloatType>(type.getElementType());
    if (elemTy && elemTy.getWidth() < 32)
      return type.clone(Float32Type::get(type.getContext()));
    return std::nullopt;
  });
  // Operands still carrying the original narrow type are extended to the
  // converted type right before the legalized op.
  typeConverter.addTargetMaterialization(
      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
        return b.create<arith::ExtFOp>(loc, target, input);
      });
}
63+
64+
/// Sets up `target` for the f32 legalization: a math dialect op is legal
/// exactly when `typeConverter` reports it as legal, while `math.fma` and the
/// `arith.extf` / `arith.truncf` casts are unconditionally legal.
///
/// `typeConverter` is captured by reference, so it has to stay alive for as
/// long as `target` is in use.
void mlir::math::populateLegalizeToF32ConversionTarget(
    ConversionTarget &target, TypeConverter &typeConverter) {
  auto hasOnlyLegalTypes = [&typeConverter](Operation *mathOp) -> bool {
    return typeConverter.isLegal(mathOp);
  };
  target.addDynamicallyLegalDialect<MathDialect>(hasOnlyLegalTypes);
  // fma is commonly available at low precision, and the casts are what this
  // legalization itself emits.
  target.addLegalOp<FmaOp, arith::ExtFOp, arith::TruncFOp>();
}
73+
74+
/// Rewrites `op` into an op of the same name that consumes the
/// already-converted `operands` and produces converted result types, then
/// truncates each widened result back to the type `op` originally returned.
/// Fails to match when `op` is already legal for the type converter or when
/// its result types cannot be converted.
LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  const Location loc = op->getLoc();
  const TypeConverter *typeConverter = getTypeConverter();

  if (typeConverter->isLegal(op))
    return rewriter.notifyMatchFailure(loc, "op already legal");

  SmallVector<Type> widenedTypes;
  if (failed(typeConverter->convertTypes(op->getResultTypes(), widenedTypes)))
    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");

  // Same op name and attributes as the original, but on the converted
  // operands and result types.
  OperationState state(loc, op->getName());
  state.addOperands(operands);
  state.addTypes(widenedTypes);
  state.addAttributes(op->getAttrs());
  Operation *widenedOp = rewriter.create(state);

  // Results whose type changed are truncated back so that existing users of
  // `op` keep seeing the original types.
  SmallVector<Value> replacements;
  for (auto [widened, widenedType, originalType] : llvm::zip_equal(
           widenedOp->getResults(), widenedTypes, op->getResultTypes())) {
    Value replacement = widened;
    if (widenedType != originalType)
      replacement =
          rewriter.create<arith::TruncFOp>(loc, originalType, replacement);
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
  return success();
}
99+
100+
/// Registers the generic legalize-to-f32 conversion pattern in `patterns`,
/// bound to `typeConverter` and to the context that `patterns` belongs to.
void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<LegalizeToF32RewritePattern>(typeConverter, ctx);
}
105+
106+
void LegalizeToF32Pass::runOnOperation() {
107+
Operation *op = getOperation();
108+
MLIRContext &ctx = getContext();
109+
110+
TypeConverter typeConverter;
111+
math::populateLegalizeToF32TypeConverter(typeConverter);
112+
ConversionTarget target(ctx);
113+
math::populateLegalizeToF32ConversionTarget(target, typeConverter);
114+
RewritePatternSet patterns(&ctx);
115+
math::populateLegalizeToF32Patterns(patterns, typeConverter);
116+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
117+
return signalPassFailure();
118+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s
2+
3+
// CHECK-LABEL: @sin
4+
// CHECK-SAME: ([[ARG0:%.+]]: f16)
5+
func.func @sin(%arg0: f16) -> f16 {
6+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
7+
// CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
8+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
9+
// CHECK: return [[TRUNCF]] : f16
10+
%0 = math.sin %arg0 : f16
11+
return %0 : f16
12+
}
13+
14+
// CHECK-LABEL: @fpowi
15+
// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: i32)
16+
func.func @fpowi(%arg0: f16, %arg1: i32) -> f16 {
17+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
18+
// CHECK: [[FPOWI:%.+]] = math.fpowi [[EXTF]], [[ARG1]]
19+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[FPOWI]]
20+
// CHECK: return [[TRUNCF]] : f16
21+
%0 = math.fpowi %arg0, %arg1 : f16, i32
22+
return %0 : f16
23+
}
24+
25+
// COM: Verify that the pass leaves `math.fma` untouched, since it is often
26+
// COM: implemented on small data types.
27+
// CHECK-LABEL: @fma
28+
// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: f16, [[ARG2:%.+]]: f16)
29+
// CHECK: [[FMA:%.+]] = math.fma [[ARG0]], [[ARG1]], [[ARG2]]
30+
// CHECK: return [[FMA]] : f16
31+
func.func @fma(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
32+
%0 = math.fma %arg0, %arg1, %arg2 : f16
33+
return %0 : f16
34+
}
35+
36+
// CHECK-LABEL: @absf_f32
37+
// CHECK-SAME: ([[ARG0:%.+]]: f32)
38+
// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
39+
// CHECK: return [[ABSF]] : f32
40+
func.func @absf_f32(%arg0: f32) -> f32 {
41+
%0 = math.absf %arg0 : f32
42+
return %0 : f32
43+
}
44+
45+
// CHECK-LABEL: @absf_f64
46+
// CHECK-SAME: ([[ARG0:%.+]]: f64)
47+
// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
48+
// CHECK: return [[ABSF]] : f64
49+
func.func @absf_f64(%arg0: f64) -> f64 {
50+
%0 = math.absf %arg0 : f64
51+
return %0 : f64
52+
}
53+
54+
// CHECK-LABEL: @sin_vector
55+
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xbf16>)
56+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
57+
// CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
58+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
59+
// CHECK: return [[TRUNCF]] : vector<2xbf16>
60+
func.func @sin_vector(%arg0: vector<2xbf16>) -> vector<2xbf16> {
61+
%0 = math.sin %arg0 : vector<2xbf16>
62+
return %0 : vector<2xbf16>
63+
}
64+
65+
// CHECK-LABEL: @fastmath
66+
// CHECK: math.sin %{{.+}} fastmath<nsz>
67+
func.func @fastmath(%arg0: f16) -> f16 {
68+
%0 = math.sin %arg0 fastmath<nsz> : f16
69+
return %0 : f16
70+
}
71+
72+
// CHECK-LABEL: @sequences
73+
// CHECK-SAME: ([[ARG0:%.+]]: f16)
74+
// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
75+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
76+
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
77+
// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
78+
// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
79+
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
80+
// CHECK: return [[TRUNCF1]] : f16
81+
func.func @sequences(%arg0: f16) -> f16 {
82+
%0 = math.absf %arg0 : f16
83+
%1 = math.sin %0 : f16
84+
return %1 : f16
85+
}

0 commit comments

Comments
 (0)