Skip to content

[mlir][Math] Add pass to legalize math functions to f32-or-higher #78361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ namespace math {
#define GEN_PASS_DECL
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#define GEN_PASS_DECL_MATHLEGALIZETOF32
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace math

class ConversionTarget;
class RewritePatternSet;
class TypeConverter;

void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanPattern(RewritePatternSet &patterns);
Expand All @@ -48,6 +51,13 @@ void populateMathPolynomialApproximationPatterns(

void populateUpliftToFMAPatterns(RewritePatternSet &patterns);

namespace math {
/// Registers the type conversions used by the legalize-to-f32 pass on
/// `typeConverter`: types are kept as-is by default, floating-point types
/// narrower than 32 bits are mapped to f32, and values are brought to the
/// converted type with `arith.extf`.
void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
/// Configures `target` so that math dialect ops are legal only when
/// `typeConverter` reports them as legal, while `math.fma`, `arith.extf` and
/// `arith.truncf` are always legal. `typeConverter` is captured by reference
/// and must outlive `target`.
void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
TypeConverter &typeConverter);
/// Adds the rewrite pattern that re-creates illegal ops on the converted
/// types and truncates their results back to the original types.
void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
TypeConverter &typeConverter);
} // namespace math
} // namespace mlir

#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,21 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
let dependentDialects = ["math::MathDialect"];
}

// Declares the `math-legalize-to-f32` pass. It depends on both the math and
// arith dialects because the rewrite inserts `arith.extf` / `arith.truncf`
// around math ops.
def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
let summary = "Legalize floating-point math ops on low-precision floats";
let description = [{
On many targets, the math functions are not implemented for floating-point
types less precise than IEEE single-precision (aka f32), such as half-floats,
bfloat16, or 8-bit floats.

This pass explicitly legalizes these math functions by inserting
`arith.extf` and `arith.truncf` pairs around said op, which preserves
the original semantics while enabling lowering.

As an exception, this pass does not legalize `math.fma`, because
that is an operation frequently implemented at low precisions.
}];
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}

#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp

Expand Down
118 changes: 118 additions & 0 deletions mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements legalizing math operations on small floating-point
// types through arith.extf and arith.truncf.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir::math {
#define GEN_PASS_DEF_MATHLEGALIZETOF32
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace mlir::math

using namespace mlir;
namespace {
/// Conversion pattern that is not tied to a single op type: it is registered
/// with `MatchAnyOpTypeTag` and decides inside `matchAndRewrite` whether the
/// op needs to be re-created on converted types.
struct LegalizeToF32RewritePattern final : ConversionPattern {
  LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
      : ConversionPattern(converter, MatchAnyOpTypeTag(), PatternBenefit(1),
                          context) {}

  /// Re-creates `op` on the converted types; defined out of line below.
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

struct LegalizeToF32Pass final
: mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
void runOnOperation() override;
};
} // namespace

/// Populates `typeConverter` with the rules used to legalize math ops on
/// low-precision floats.
///
/// Rules registered, in order:
///   - every type converts to itself (fallback);
///   - scalar float types narrower than 32 bits convert to f32;
///   - shaped types whose element type is a float narrower than 32 bits
///     convert to the same shaped type with an f32 element type;
///   - a target materialization that produces the converted value with
///     `arith.extf`.
///
/// Floats that are already 32 bits or wider, scalar or shaped, are never
/// changed, so the materialization only ever has to extend.
///
/// \param typeConverter the converter to populate.
void mlir::math::populateLegalizeToF32TypeConverter(
    TypeConverter &typeConverter) {
  // Identity fallback. It is registered first so the more specific rules
  // below are consulted before it (TypeConverter tries later-registered
  // conversions first) and fall through to it by returning std::nullopt.
  typeConverter.addConversion(
      [](Type type) -> std::optional<Type> { return type; });
  typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
    if (type.getWidth() < 32)
      return Float32Type::get(type.getContext());
    return std::nullopt;
  });
  typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
    auto elemTy = dyn_cast<FloatType>(type.getElementType());
    // Mirror the scalar rule: only widen element types narrower than f32.
    // Without the width check, e.g. vector<2xf64> would be mapped to
    // vector<2xf32> and then handed to arith.extf, which cannot narrow.
    if (elemTy && elemTy.getWidth() < 32)
      return type.clone(Float32Type::get(type.getContext()));
    return std::nullopt;
  });
  typeConverter.addTargetMaterialization(
      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
        return b.create<arith::ExtFOp>(loc, target, input);
      });
}

/// Sets up the legality rules for the legalize-to-f32 conversion.
///
/// Math dialect ops are legal exactly when `typeConverter.isLegal(op)` holds.
/// `math.fma`, `arith.extf` and `arith.truncf` are marked legal
/// unconditionally. The legality callback holds a reference to
/// `typeConverter`, so the converter must stay alive as long as `target` is
/// used.
void mlir::math::populateLegalizeToF32ConversionTarget(
    ConversionTarget &target, TypeConverter &typeConverter) {
  auto isLegalForConverter = [&typeConverter](Operation *candidate) -> bool {
    return typeConverter.isLegal(candidate);
  };
  target.addDynamicallyLegalDialect<MathDialect>(isLegalForConverter);
  // Registered after the dialect-wide rule, as in the original ordering.
  target.addLegalOp<FmaOp>();
  target.addLegalOp<arith::ExtFOp>();
  target.addLegalOp<arith::TruncFOp>();
}

/// Re-creates `op` with the operands supplied by the conversion framework and
/// with result types run through the pattern's type converter, then replaces
/// `op` with the new results (truncated back to the original types where the
/// converted type differs).
///
/// Fails to match when the type converter already considers `op` legal or
/// when the result types cannot be converted.
LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
const TypeConverter *converter = getTypeConverter();
// The pattern matches any op, so bail out on ops that need no change.
if (converter->isLegal(op))
return rewriter.notifyMatchFailure(loc, "op already legal");
// Describe a new op with the same name as `op`, using the operands handed to
// this pattern instead of the original ones.
OperationState newOp(loc, op->getName());
newOp.addOperands(operands);

SmallVector<Type> newResultTypes;
if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
// Carry over the attributes of the original op unchanged.
newOp.addAttributes(op->getAttrs());
Operation *legalized = rewriter.create(newOp);
Copy link
Contributor

@Hardcode84 Hardcode84 Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering: instead of using the low-level API to construct the op, can we just clone the original op and update the args/return types on it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went for your suggestion - minus cloning the op, just mutating in place - what do you think?

Copy link
Contributor

@Hardcode84 Hardcode84 Jan 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure it's allowed in dialect conversion framework without cloning (or even with cloning, actually) as it should be able to 'revert' the changes if necessary. But with cloning we can at least use rewriter.replaceOp instead of RAU. Ping @joker-eph I guess?

Regarding the comment "Intintionally (typo) don't tell the rewriter we're doing this to prevent spurious attempts to legalize the consumer": I think you can avoid this by using addDynamicallyLegalOp and always returning true for non-math ops.

Also, I think updateRootInPlace was recently renamed, can you rebase?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, I had the OperationState version so I could copy over attributes, properties, etc. while creating a new operation.

I'm not sure clone() + mutations really buys us anything over just the mutations. And looking at dialect conversion's resets, the only rollback you'd want to be able to do is cancelOpModification which will reset the operands (though not result types, apparently)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I thought the clone version would be shorter and would help avoid the OperationState builder, which is unconventional outside of builders, but it opens more questions about what is actually allowed in the conversion framework. You can revert to the old version if you like; at least I'm sure that one works with the conversion framework. Sorry.

// Start from the results of the re-created op and truncate each one whose
// converted type differs from the original type back to that original type.
SmallVector<Value> results = legalized->getResults();
for (auto [result, newType, origType] :
llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
// The assignment is meant to update the matching entry of `results`
// (zip_equal yields references to the zipped elements).
if (newType != origType)
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
}
// Users of the original op now see values of the original result types.
rewriter.replaceOp(op, results);
return success();
}

/// Appends the generic legalize-to-f32 rewrite pattern to `patterns`, bound to
/// `typeConverter` and to the context owned by the pattern set.
void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LegalizeToF32RewritePattern>(typeConverter, context);
}

void LegalizeToF32Pass::runOnOperation() {
Operation *op = getOperation();
MLIRContext &ctx = getContext();

TypeConverter typeConverter;
math::populateLegalizeToF32TypeConverter(typeConverter);
ConversionTarget target(ctx);
math::populateLegalizeToF32ConversionTarget(target, typeConverter);
RewritePatternSet patterns(&ctx);
math::populateLegalizeToF32Patterns(patterns, typeConverter);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
return signalPassFailure();
}
85 changes: 85 additions & 0 deletions mlir/test/Dialect/Math/legalize-to-f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s

// COM: Scalar f16 case: the operand is extended, math.sin runs on the
// COM: extended value, and the result is truncated back to f16.
// CHECK-LABEL: @sin
// CHECK-SAME: ([[ARG0:%.+]]: f16)
func.func @sin(%arg0: f16) -> f16 {
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : f16
%0 = math.sin %arg0 : f16
return %0 : f16
}

// COM: Mixed operand types: only the f16 base is extended; the i32 exponent
// COM: is passed to the legalized math.fpowi unchanged.
// CHECK-LABEL: @fpowi
// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: i32)
func.func @fpowi(%arg0: f16, %arg1: i32) -> f16 {
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[FPOWI:%.+]] = math.fpowi [[EXTF]], [[ARG1]]
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[FPOWI]]
// CHECK: return [[TRUNCF]] : f16
%0 = math.fpowi %arg0, %arg1 : f16, i32
return %0 : f16
}

// COM: Verify that the pass leaves `math.fma` untouched, since it is often
// COM: implemented on small data types.
// COM: The conversion target marks math.fma as unconditionally legal, so no
// COM: extension or truncation is expected around it.
// CHECK-LABEL: @fma
// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: f16, [[ARG2:%.+]]: f16)
// CHECK: [[FMA:%.+]] = math.fma [[ARG0]], [[ARG1]], [[ARG2]]
// CHECK: return [[FMA]] : f16
func.func @fma(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
%0 = math.fma %arg0, %arg1, %arg2 : f16
return %0 : f16
}

// COM: f32 is already wide enough: the op must consume the function argument
// COM: directly and return its own result.
// CHECK-LABEL: @absf_f32
// CHECK-SAME: ([[ARG0:%.+]]: f32)
// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
// CHECK: return [[ABSF]] : f32
func.func @absf_f32(%arg0: f32) -> f32 {
%0 = math.absf %arg0 : f32
return %0 : f32
}

// COM: Types wider than f32 must not be touched either: math.absf on f64
// COM: stays on f64.
// CHECK-LABEL: @absf_f64
// CHECK-SAME: ([[ARG0:%.+]]: f64)
// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
// CHECK: return [[ABSF]] : f64
func.func @absf_f64(%arg0: f64) -> f64 {
%0 = math.absf %arg0 : f64
return %0 : f64
}

// COM: Shaped-type case: a vector of bf16 is extended, processed, and
// COM: truncated back to vector<2xbf16> just like the scalar case.
// CHECK-LABEL: @sin_vector
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xbf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : vector<2xbf16>
func.func @sin_vector(%arg0: vector<2xbf16>) -> vector<2xbf16> {
%0 = math.sin %arg0 : vector<2xbf16>
return %0 : vector<2xbf16>
}

// COM: The fastmath flags of the original op must still be present on the
// COM: legalized math.sin.
// CHECK-LABEL: @fastmath
// CHECK: math.sin %{{.+}} fastmath<nsz>
func.func @fastmath(%arg0: f16) -> f16 {
%0 = math.sin %arg0 fastmath<nsz> : f16
return %0 : f16
}

// COM: Back-to-back math ops are legalized independently: the first result is
// COM: truncated to f16 and then extended again for the second op, so a
// COM: truncf/extf pair remains between them.
// CHECK-LABEL: @sequences
// CHECK-SAME: ([[ARG0:%.+]]: f16)
// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF1]] : f16
func.func @sequences(%arg0: f16) -> f16 {
%0 = math.absf %arg0 : f16
%1 = math.sin %0 : f16
return %1 : f16
}