-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Math] Add pass to legalize math functions to f32-or-higher #78361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements legalizing math operations on small floating-point | ||
// types through arith.extf and arith.truncf. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Dialect/Math/Transforms/Passes.h" | ||
#include "mlir/IR/Diagnostics.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
|
||
namespace mlir::math { | ||
#define GEN_PASS_DEF_MATHLEGALIZETOF32 | ||
#include "mlir/Dialect/Math/Transforms/Passes.h.inc" | ||
} // namespace mlir::math | ||
|
||
using namespace mlir; | ||
namespace { | ||
struct LegalizeToF32RewritePattern final : ConversionPattern { | ||
LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context) | ||
: ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {} | ||
LogicalResult | ||
matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const override; | ||
}; | ||
|
||
struct LegalizeToF32Pass final | ||
: mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> { | ||
void runOnOperation() override; | ||
}; | ||
} // namespace | ||
|
||
void mlir::math::populateLegalizeToF32TypeConverter( | ||
TypeConverter &typeConverter) { | ||
typeConverter.addConversion( | ||
[](Type type) -> std::optional<Type> { return type; }); | ||
typeConverter.addConversion([](FloatType type) -> std::optional<Type> { | ||
if (type.getWidth() < 32) | ||
return Float32Type::get(type.getContext()); | ||
return std::nullopt; | ||
}); | ||
typeConverter.addConversion([](ShapedType type) -> std::optional<Type> { | ||
if (auto elemTy = dyn_cast<FloatType>(type.getElementType())) | ||
return type.clone(Float32Type::get(type.getContext())); | ||
return std::nullopt; | ||
}); | ||
typeConverter.addTargetMaterialization( | ||
[](OpBuilder &b, Type target, ValueRange input, Location loc) { | ||
return b.create<arith::ExtFOp>(loc, target, input); | ||
}); | ||
} | ||
|
||
void mlir::math::populateLegalizeToF32ConversionTarget( | ||
ConversionTarget &target, TypeConverter &typeConverter) { | ||
target.addDynamicallyLegalDialect<MathDialect>( | ||
[&typeConverter](Operation *op) -> bool { | ||
return typeConverter.isLegal(op); | ||
}); | ||
target.addLegalOp<FmaOp>(); | ||
target.addLegalOp<arith::ExtFOp, arith::TruncFOp>(); | ||
} | ||
|
||
LogicalResult LegalizeToF32RewritePattern::matchAndRewrite( | ||
Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const { | ||
Location loc = op->getLoc(); | ||
const TypeConverter *converter = getTypeConverter(); | ||
if (converter->isLegal(op)) | ||
return rewriter.notifyMatchFailure(loc, "op already legal"); | ||
OperationState newOp(loc, op->getName()); | ||
newOp.addOperands(operands); | ||
|
||
SmallVector<Type> newResultTypes; | ||
if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) | ||
return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); | ||
newOp.addTypes(newResultTypes); | ||
newOp.addAttributes(op->getAttrs()); | ||
Operation *legalized = rewriter.create(newOp); | ||
SmallVector<Value> results = legalized->getResults(); | ||
for (auto [result, newType, origType] : | ||
llvm::zip_equal(results, newResultTypes, op->getResultTypes())) { | ||
if (newType != origType) | ||
result = rewriter.create<arith::TruncFOp>(loc, origType, result); | ||
} | ||
rewriter.replaceOp(op, results); | ||
return success(); | ||
} | ||
|
||
void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns, | ||
TypeConverter &typeConverter) { | ||
patterns.add<LegalizeToF32RewritePattern>(typeConverter, | ||
patterns.getContext()); | ||
} | ||
|
||
void LegalizeToF32Pass::runOnOperation() { | ||
Operation *op = getOperation(); | ||
MLIRContext &ctx = getContext(); | ||
|
||
TypeConverter typeConverter; | ||
math::populateLegalizeToF32TypeConverter(typeConverter); | ||
ConversionTarget target(ctx); | ||
math::populateLegalizeToF32ConversionTarget(target, typeConverter); | ||
RewritePatternSet patterns(&ctx); | ||
math::populateLegalizeToF32Patterns(patterns, typeConverter); | ||
if (failed(applyPartialConversion(op, target, std::move(patterns)))) | ||
return signalPassFailure(); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s | ||
|
||
// CHECK-LABEL: @sin | ||
// CHECK-SAME: ([[ARG0:%.+]]: f16) | ||
func.func @sin(%arg0: f16) -> f16 { | ||
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] | ||
// CHECK: [[SIN:%.+]] = math.sin [[EXTF]] | ||
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] | ||
// CHECK: return [[TRUNCF]] : f16 | ||
%0 = math.sin %arg0 : f16 | ||
return %0 : f16 | ||
} | ||
|
||
// CHECK-LABEL: @fpowi | ||
// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: i32) | ||
func.func @fpowi(%arg0: f16, %arg1: i32) -> f16 { | ||
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] | ||
// CHECK: [[FPOWI:%.+]] = math.fpowi [[EXTF]], [[ARG1]] | ||
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[FPOWI]] | ||
// CHECK: return [[TRUNCF]] : f16 | ||
%0 = math.fpowi %arg0, %arg1 : f16, i32 | ||
return %0 : f16 | ||
} | ||
|
||
// COM: Verify that the pass leaves `math.fma` untouched, since it is often | ||
// COM: implemented on small data types. | ||
// CHECK-LABEL: @fma | ||
krzysz00 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: f16, [[ARG2:%.+]]: f16) | ||
// CHECK: [[FMA:%.+]] = math.fma [[ARG0]], [[ARG1]], [[ARG2]] | ||
// CHECK: return [[FMA]] : f16 | ||
func.func @fma(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 { | ||
%0 = math.fma %arg0, %arg1, %arg2 : f16 | ||
return %0 : f16 | ||
} | ||
|
||
// CHECK-LABEL: @absf_f32 | ||
// CHECK-SAME: ([[ARG0:%.+]]: f32) | ||
// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]] | ||
// CHECK: return [[ABSF]] : f32 | ||
func.func @absf_f32(%arg0: f32) -> f32 { | ||
%0 = math.absf %arg0 : f32 | ||
return %0 : f32 | ||
} | ||
|
||
// CHECK-LABEL: @absf_f64 | ||
// CHECK-SAME: ([[ARG0:%.+]]: f64) | ||
// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]] | ||
// CHECK: return [[ABSF]] : f64 | ||
func.func @absf_f64(%arg0: f64) -> f64 { | ||
%0 = math.absf %arg0 : f64 | ||
return %0 : f64 | ||
} | ||
|
||
// CHECK-LABEL: @sin_vector | ||
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xbf16>) | ||
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] | ||
// CHECK: [[SIN:%.+]] = math.sin [[EXTF]] | ||
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] | ||
// CHECK: return [[TRUNCF]] : vector<2xbf16> | ||
func.func @sin_vector(%arg0: vector<2xbf16>) -> vector<2xbf16> { | ||
%0 = math.sin %arg0 : vector<2xbf16> | ||
return %0 : vector<2xbf16> | ||
} | ||
|
||
// CHECK-LABEL: @fastmath | ||
// CHECK: math.sin %{{.+}} fastmath<nsz> | ||
func.func @fastmath(%arg0: f16) -> f16 { | ||
%0 = math.sin %arg0 fastmath<nsz> : f16 | ||
return %0 : f16 | ||
} | ||
|
||
// CHECK-LABEL: @sequences | ||
// CHECK-SAME: ([[ARG0:%.+]]: f16) | ||
// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]] | ||
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]] | ||
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]] | ||
// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]] | ||
// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]] | ||
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]] | ||
// CHECK: return [[TRUNCF1]] : f16 | ||
func.func @sequences(%arg0: f16) -> f16 { | ||
%0 = math.absf %arg0 : f16 | ||
%1 = math.sin %0 : f16 | ||
return %1 : f16 | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wondering, can we, instead of using low-level API to construct op, just clone original op and update args/return types on it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went for your suggestion - minus cloning the op, just mutating in place - what do you think?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not 100% sure it's allowed in dialect conversion framework without cloning (or even with cloning, actually) as it should be able to 'revert' the changes if necessary. But with cloning we can at least use
rewriter.replaceOp
instead of RAU. Ping @joker-eph I guess?Regarding
Intintionally(typo) don't tell the rewriter we're doing this to prevent spurious attempts to legalize the consumer
comment, I think you can avoid this by usingaddDynamicallyLegalOp
and always return true for non-math ops.Also, I think
updateRootInPlace
was recently renamed, can you rebase?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean, I had the
OperationState
version so I could copy over attributes, properties, etc. while creating a new operation.I'm not sure
clone()
+ mutations really buys us anything over just the mutations. And looking at dialect conversion's resets, the only rollback you'd want to be able to do iscancelOpModification
which will reset the operands (though not result types, apparently)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I thought clone version will be shorter and will help to avoid using
OperationState
builder, which is unconventional outside the builders, but it opens more questions on what is actually allowed in conversion framework. You can revert to old version if you like, at least I'm sure it should work with conversion framework, sorry.