-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][hlfir][NFC] Fix mlir misuse in LowerHLFIRIntrinsics #83293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
In llvm#83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp should not be using rewrite patterns with the dialect conversion driver. The intention of this pass is to lower HLFIR intrinsic operations into FIR and so I think this best fits dialect conversion and so I have changed all of these into conversion patterns. Taking this approach also avoids test suite churn because GreedyPatternRewriter also performs canonicalization. One remaining misuse of the MLIR API is that we replace values of one type with a different (although safe) type e.g. !hlfir.expr<2xi32> -> !hlfir.expr<?xi32>. There isn't a convenient way to perform this conversion in IR at the moment because fir.convert does not accept !hlfir.expr.
@llvm/pr-subscribers-flang-fir-hlfir Author: Tom Eccles (tblah) ChangesIn #83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp should not be using rewrite patterns with the dialect conversion driver. The intention of this pass is to lower HLFIR intrinsic operations into FIR and so I think this best fits dialect conversion and so I have changed all of these into conversion patterns. Taking this approach also avoids test suite churn because GreedyPatternRewriter also performs canonicalization. One remaining misuse of the MLIR API is that we replace values of one type with a different (although safe) type e.g. Full diff: https://github.com/llvm/llvm-project/pull/83293.diff 1 Files Affected:
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 377cc44392028f..b2e02376599636 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -36,10 +36,10 @@ namespace {
/// Base class for passes converting transformational intrinsic operations into
/// runtime calls
template <class OP>
-class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
+class HlfirIntrinsicConversion : public mlir::OpConversionPattern<OP> {
public:
explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
- : mlir::OpRewritePattern<OP>{ctx} {
+ : mlir::OpConversionPattern<OP>{ctx} {
// required for cases where intrinsics are chained together e.g.
// matmul(matmul(a, b), c)
// because converting the inner operation then invalidates the
@@ -145,7 +145,7 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
void processReturnValue(mlir::Operation *op,
const fir::ExtendedValue &resultExv, bool mustBeFreed,
fir::FirOpBuilder &builder,
- mlir::PatternRewriter &rewriter) const {
+ mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = op->getLoc();
mlir::Value firBase = fir::getBase(resultExv);
@@ -176,13 +176,9 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
rewriter.eraseOp(use);
}
}
- // TODO: This entire pass should be a greedy pattern rewrite or a manual
- // IR traversal. A dialect conversion cannot be used here because
- // `replaceAllUsesWith` is not supported. Similarly, `replaceOp` is not
- // suitable because "op->getResult(0)" and "base" can have different types.
- // In such a case, the dialect conversion will attempt to convert the type,
- // but no type converter is specified in this pass. Also note that all
- // patterns in this pass are actually rewrite patterns.
+ // the types might not match exactly (but are safe)
+ // e.g. !hlfir.expr<?xi32> vs !hlfir.expr<2xi32>
+ // TODO: is this allowed by MLIR?
op->getResult(0).replaceAllUsesWith(base);
rewriter.replaceOp(op, base);
}
@@ -203,48 +199,53 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
using HlfirIntrinsicConversion<OP>::lowerArguments;
using HlfirIntrinsicConversion<OP>::processReturnValue;
+ using Adaptor = typename OP::Adaptor;
protected:
- auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+ auto buildNumericalArgs(mlir::Operation *operation, Adaptor adaptor,
+ mlir::Type i32, mlir::Type logicalType,
mlir::PatternRewriter &rewriter,
std::string opName) const {
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
- inArgs.push_back({operation.getArray(), operation.getArray().getType()});
- inArgs.push_back({operation.getDim(), i32});
- inArgs.push_back({operation.getMask(), logicalType});
+ inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()});
+ inArgs.push_back({adaptor.getDim(), i32});
+ inArgs.push_back({adaptor.getMask(), logicalType});
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
return lowerArguments(operation, inArgs, rewriter, argLowering);
};
- auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+ auto buildMinMaxLocArgs(mlir::Operation *operation, Adaptor adaptor,
+ mlir::Type i32, mlir::Type logicalType,
mlir::PatternRewriter &rewriter, std::string opName,
fir::FirOpBuilder builder) const {
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
- inArgs.push_back({operation.getArray(), operation.getArray().getType()});
- inArgs.push_back({operation.getDim(), i32});
- inArgs.push_back({operation.getMask(), logicalType});
+ inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()});
+ inArgs.push_back({adaptor.getDim(), i32});
+ inArgs.push_back({adaptor.getMask(), logicalType});
mlir::Value kind = builder.createIntegerConstant(
- operation->getLoc(), i32, getKindForType(operation.getType()));
+ operation->getLoc(), i32,
+ getKindForType(operation->getResult(0).getType()));
inArgs.push_back({kind, i32});
- inArgs.push_back({operation.getBack(), i32});
+ inArgs.push_back({adaptor.getBack(), i32});
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
return lowerArguments(operation, inArgs, rewriter, argLowering);
};
- auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+ auto buildLogicalArgs(mlir::Operation *operation, Adaptor adaptor,
+ mlir::Type i32, mlir::Type logicalType,
mlir::PatternRewriter &rewriter,
std::string opName) const {
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
- inArgs.push_back({operation.getMask(), logicalType});
- inArgs.push_back({operation.getDim(), i32});
+ inArgs.push_back({adaptor.getMask(), logicalType});
+ inArgs.push_back({adaptor.getDim(), i32});
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
return lowerArguments(operation, inArgs, rewriter, argLowering);
};
public:
mlir::LogicalResult
- matchAndRewrite(OP operation,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(OP operation, Adaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
std::string opName;
if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
opName = "sum";
@@ -279,13 +280,15 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
std::is_same_v<OP, hlfir::ProductOp> ||
std::is_same_v<OP, hlfir::MaxvalOp> ||
std::is_same_v<OP, hlfir::MinvalOp>) {
- args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
+ args = buildNumericalArgs(operation, adaptor, i32, logicalType, rewriter,
+ opName);
} else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
std::is_same_v<OP, hlfir::MaxlocOp>) {
- args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
- builder);
+ args = buildMinMaxLocArgs(operation, adaptor, i32, logicalType, rewriter,
+ opName, builder);
} else {
- args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
+ args = buildLogicalArgs(operation, adaptor, i32, logicalType, rewriter,
+ opName);
}
mlir::Type scalarResultType =
@@ -319,8 +322,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;
mlir::LogicalResult
- matchAndRewrite(hlfir::CountOp count,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(hlfir::CountOp count, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, count.getOperation()};
const mlir::Location &loc = count->getLoc();
@@ -329,8 +332,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
builder.getContext(), builder.getKindMap().defaultLogicalKind());
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
- inArgs.push_back({count.getMask(), logicalType});
- inArgs.push_back({count.getDim(), i32});
+ inArgs.push_back({adaptor.getMask(), logicalType});
+ inArgs.push_back({adaptor.getDim(), i32});
mlir::Value kind = builder.createIntegerConstant(
count->getLoc(), i32, getKindForType(count.getType()));
inArgs.push_back({kind, i32});
@@ -353,13 +356,13 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;
mlir::LogicalResult
- matchAndRewrite(hlfir::MatmulOp matmul,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(hlfir::MatmulOp matmul, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
const mlir::Location &loc = matmul->getLoc();
- mlir::Value lhs = matmul.getLhs();
- mlir::Value rhs = matmul.getRhs();
+ mlir::Value lhs = adaptor.getLhs();
+ mlir::Value rhs = adaptor.getRhs();
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
inArgs.push_back({lhs, lhs.getType()});
inArgs.push_back({rhs, rhs.getType()});
@@ -384,13 +387,13 @@ struct DotProductOpConversion
using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
mlir::LogicalResult
- matchAndRewrite(hlfir::DotProductOp dotProduct,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(hlfir::DotProductOp dotProduct, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
const mlir::Location &loc = dotProduct->getLoc();
- mlir::Value lhs = dotProduct.getLhs();
- mlir::Value rhs = dotProduct.getRhs();
+ mlir::Value lhs = adaptor.getLhs();
+ mlir::Value rhs = adaptor.getRhs();
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
inArgs.push_back({lhs, lhs.getType()});
inArgs.push_back({rhs, rhs.getType()});
@@ -415,12 +418,12 @@ class TransposeOpConversion
using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
mlir::LogicalResult
- matchAndRewrite(hlfir::TransposeOp transpose,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(hlfir::TransposeOp transpose, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
const mlir::Location &loc = transpose->getLoc();
- mlir::Value arg = transpose.getArray();
+ mlir::Value arg = adaptor.getArray();
llvm::SmallVector<IntrinsicArgument, 1> inArgs;
inArgs.push_back({arg, arg.getType()});
@@ -445,13 +448,13 @@ struct MatmulTransposeOpConversion
hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
mlir::LogicalResult
- matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(hlfir::MatmulTransposeOp multranspose, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
const mlir::Location &loc = multranspose->getLoc();
- mlir::Value lhs = multranspose.getLhs();
- mlir::Value rhs = multranspose.getRhs();
+ mlir::Value lhs = adaptor.getLhs();
+ mlir::Value rhs = adaptor.getRhs();
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
inArgs.push_back({lhs, lhs.getType()});
inArgs.push_back({rhs, rhs.getType()});
|
@matthias-springer MLIR allows these conversions without further changes. Is this intentional or am I still misusing the API? They are safe to perform in place in the context of the dialects involved. |
// patterns in this pass are actually rewrite patterns. | ||
// the types might not match exactly (but are safe) | ||
// e.g. !hlfir.expr<?xi32> vs !hlfir.expr<2xi32> | ||
// TODO: is this allowed by MLIR? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think this is allowed. (As long as you don't produce an invalid operation, everything's fine.)
The problem in the current implementation is that op->getResult(0).replaceAllUsesWith(base);
bypasses the rewriter API, i.e., making an IR modification without going through the rewriter. The rewriter equivalent would be RewriterBase::replaceAllUsesWith
. But that function is not supported by the conversion pattern rewriter yet.
But it should also not be necessary:
op->getResult(0).replaceAllUsesWith(base);
rewriter.replaceOp(op, base);
rewriter.replaceOp
already performs the replacement, so the previous replaceAllUsesWith
is redundant. I tried removing the replaceAllUsesWith
but then the dialect conversion no longer succeeds. I think the problem is that whenever the type changes during a rewriter.replaceOp
, the dialect conversion tries to build a type conversion op with the type converter. In this case here, no type conversion is needed, but that's not something that the dialect conversion considers. The error that I saw was due to the fact that no type converter is specified. (Take a look at the ConversionOpPattern
constructor, you can pass a type converter.) One thing you could try is passing a type converter that always returns the same SSA value, i.e., does not convert anything. Not sure if it will work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your help.
I don't think it works. After following your suggestion I am getting null operand errors from operations referring to the replaced value. It looks like rewriter.replaceOp(op, base)
is not replacing cases where the type changed (or maybe the old versions of the operation usage aren't removed before this check runs?).
I guess the right solution will be to use greedy pattern rewriter after all, even though this pass is conceptually performing a partial dialect conversion.
Closing this to re-do with a different solution |
New approach using GreedyPatternRewriter: #83438 |
In #83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp should not be using rewrite patterns with the dialect conversion driver.
The intention of this pass is to lower HLFIR intrinsic operations into FIR and so I think this best fits dialect conversion and so I have changed all of these into conversion patterns. Taking this approach also avoids test suite churn because GreedyPatternRewriter also performs canonicalization.
One remaining misuse of the MLIR API is that we replace values of one type with a different (although safe) type e.g.
!hlfir.expr<2xi32> -> !hlfir.expr<?xi32>. There isn't a convenient way to perform this conversion in IR at the moment because fir.convert does not accept !hlfir.expr.