[flang][hlfir][NFC] Fix mlir misuse in LowerHLFIRIntrinsics #83293

Closed

@tblah tblah commented Feb 28, 2024

In #83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp should not be using rewrite patterns with the dialect conversion driver.

This pass lowers HLFIR intrinsic operations into FIR, which I think best fits dialect conversion, so I have changed all of these patterns into conversion patterns. This approach also avoids test suite churn: moving to the greedy driver would change test output because GreedyPatternRewriter also performs canonicalization.

One remaining misuse of the MLIR API is that we replace values of one type with values of a different (although safe) type, e.g. !hlfir.expr<2xi32> -> !hlfir.expr<?xi32>. There isn't a convenient way to perform this conversion in IR at the moment because fir.convert does not accept !hlfir.expr.
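
As a rough illustration of what "safe" could mean for such a replacement, the sketch below models a shape as a list of extents and treats an unknown extent (the `?` in `!hlfir.expr<?xi32>`) as compatible with any static extent. This is a self-contained toy: `kDynamic` and `shapesCompatible` are hypothetical names that do not exist in HLFIR, FIR, or MLIR, and the compatibility rule is an assumption, not the dialect's actual verifier logic.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical marker for an extent that is unknown at compile time ('?').
constexpr std::int64_t kDynamic = -1;

// Assumed rule: two shapes are compatible when they have the same rank and
// every pair of extents either matches or has a dynamic extent on one side.
inline bool shapesCompatible(const std::vector<std::int64_t> &from,
                             const std::vector<std::int64_t> &to) {
  if (from.size() != to.size())
    return false; // different rank
  for (std::size_t i = 0; i < from.size(); ++i) {
    if (from[i] == kDynamic || to[i] == kDynamic)
      continue; // an unknown extent is compatible with anything
    if (from[i] != to[i])
      return false; // two static extents that disagree
  }
  return true;
}
```

Under this rule, a shape of `{2}` is compatible with `{kDynamic}`, while `{2}` and `{3}` are not.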

@llvmbot added the flang and flang:fir-hlfir labels Feb 28, 2024

llvmbot commented Feb 28, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Tom Eccles (tblah)

Full diff: https://github.com/llvm/llvm-project/pull/83293.diff

1 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp (+51-48)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 377cc44392028f..b2e02376599636 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -36,10 +36,10 @@ namespace {
 /// Base class for passes converting transformational intrinsic operations into
 /// runtime calls
 template <class OP>
-class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
+class HlfirIntrinsicConversion : public mlir::OpConversionPattern<OP> {
 public:
   explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
-      : mlir::OpRewritePattern<OP>{ctx} {
+      : mlir::OpConversionPattern<OP>{ctx} {
     // required for cases where intrinsics are chained together e.g.
     // matmul(matmul(a, b), c)
     // because converting the inner operation then invalidates the
@@ -145,7 +145,7 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
   void processReturnValue(mlir::Operation *op,
                           const fir::ExtendedValue &resultExv, bool mustBeFreed,
                           fir::FirOpBuilder &builder,
-                          mlir::PatternRewriter &rewriter) const {
+                          mlir::ConversionPatternRewriter &rewriter) const {
     mlir::Location loc = op->getLoc();
 
     mlir::Value firBase = fir::getBase(resultExv);
@@ -176,13 +176,9 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
           rewriter.eraseOp(use);
       }
     }
-    // TODO: This entire pass should be a greedy pattern rewrite or a manual
-    // IR traversal. A dialect conversion cannot be used here because
-    // `replaceAllUsesWith` is not supported. Similarly, `replaceOp` is not
-    // suitable because "op->getResult(0)" and "base" can have different types.
-    // In such a case, the dialect conversion will attempt to convert the type,
-    // but no type converter is specified in this pass. Also note that all
-    // patterns in this pass are actually rewrite patterns.
+    // the types might not match exactly (but are safe)
+    // e.g. !hlfir.expr<?xi32> vs !hlfir.expr<2xi32>
+    // TODO: is this allowed by MLIR?
     op->getResult(0).replaceAllUsesWith(base);
     rewriter.replaceOp(op, base);
   }
@@ -203,48 +199,53 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
       typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
   using HlfirIntrinsicConversion<OP>::lowerArguments;
   using HlfirIntrinsicConversion<OP>::processReturnValue;
+  using Adaptor = typename OP::Adaptor;
 
 protected:
-  auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+  auto buildNumericalArgs(mlir::Operation *operation, Adaptor adaptor,
+                          mlir::Type i32, mlir::Type logicalType,
                           mlir::PatternRewriter &rewriter,
                           std::string opName) const {
     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
-    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
-    inArgs.push_back({operation.getDim(), i32});
-    inArgs.push_back({operation.getMask(), logicalType});
+    inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()});
+    inArgs.push_back({adaptor.getDim(), i32});
+    inArgs.push_back({adaptor.getMask(), logicalType});
     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
     return lowerArguments(operation, inArgs, rewriter, argLowering);
   };
 
-  auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+  auto buildMinMaxLocArgs(mlir::Operation *operation, Adaptor adaptor,
+                          mlir::Type i32, mlir::Type logicalType,
                           mlir::PatternRewriter &rewriter, std::string opName,
                           fir::FirOpBuilder builder) const {
     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
-    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
-    inArgs.push_back({operation.getDim(), i32});
-    inArgs.push_back({operation.getMask(), logicalType});
+    inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()});
+    inArgs.push_back({adaptor.getDim(), i32});
+    inArgs.push_back({adaptor.getMask(), logicalType});
     mlir::Value kind = builder.createIntegerConstant(
-        operation->getLoc(), i32, getKindForType(operation.getType()));
+        operation->getLoc(), i32,
+        getKindForType(operation->getResult(0).getType()));
     inArgs.push_back({kind, i32});
-    inArgs.push_back({operation.getBack(), i32});
+    inArgs.push_back({adaptor.getBack(), i32});
     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
     return lowerArguments(operation, inArgs, rewriter, argLowering);
   };
 
-  auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+  auto buildLogicalArgs(mlir::Operation *operation, Adaptor adaptor,
+                        mlir::Type i32, mlir::Type logicalType,
                         mlir::PatternRewriter &rewriter,
                         std::string opName) const {
     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
-    inArgs.push_back({operation.getMask(), logicalType});
-    inArgs.push_back({operation.getDim(), i32});
+    inArgs.push_back({adaptor.getMask(), logicalType});
+    inArgs.push_back({adaptor.getDim(), i32});
     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
     return lowerArguments(operation, inArgs, rewriter, argLowering);
   };
 
 public:
   mlir::LogicalResult
-  matchAndRewrite(OP operation,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(OP operation, Adaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     std::string opName;
     if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
       opName = "sum";
@@ -279,13 +280,15 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
                   std::is_same_v<OP, hlfir::ProductOp> ||
                   std::is_same_v<OP, hlfir::MaxvalOp> ||
                   std::is_same_v<OP, hlfir::MinvalOp>) {
-      args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
+      args = buildNumericalArgs(operation, adaptor, i32, logicalType, rewriter,
+                                opName);
     } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
                          std::is_same_v<OP, hlfir::MaxlocOp>) {
-      args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
-                                builder);
+      args = buildMinMaxLocArgs(operation, adaptor, i32, logicalType, rewriter,
+                                opName, builder);
     } else {
-      args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
+      args = buildLogicalArgs(operation, adaptor, i32, logicalType, rewriter,
+                              opName);
     }
 
     mlir::Type scalarResultType =
@@ -319,8 +322,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
   using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::CountOp count,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::CountOp count, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, count.getOperation()};
     const mlir::Location &loc = count->getLoc();
 
@@ -329,8 +332,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
         builder.getContext(), builder.getKindMap().defaultLogicalKind());
 
     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
-    inArgs.push_back({count.getMask(), logicalType});
-    inArgs.push_back({count.getDim(), i32});
+    inArgs.push_back({adaptor.getMask(), logicalType});
+    inArgs.push_back({adaptor.getDim(), i32});
     mlir::Value kind = builder.createIntegerConstant(
         count->getLoc(), i32, getKindForType(count.getType()));
     inArgs.push_back({kind, i32});
@@ -353,13 +356,13 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
   using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::MatmulOp matmul,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::MatmulOp matmul, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
     const mlir::Location &loc = matmul->getLoc();
 
-    mlir::Value lhs = matmul.getLhs();
-    mlir::Value rhs = matmul.getRhs();
+    mlir::Value lhs = adaptor.getLhs();
+    mlir::Value rhs = adaptor.getRhs();
     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
     inArgs.push_back({lhs, lhs.getType()});
     inArgs.push_back({rhs, rhs.getType()});
@@ -384,13 +387,13 @@ struct DotProductOpConversion
   using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::DotProductOp dotProduct,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::DotProductOp dotProduct, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
     const mlir::Location &loc = dotProduct->getLoc();
 
-    mlir::Value lhs = dotProduct.getLhs();
-    mlir::Value rhs = dotProduct.getRhs();
+    mlir::Value lhs = adaptor.getLhs();
+    mlir::Value rhs = adaptor.getRhs();
     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
     inArgs.push_back({lhs, lhs.getType()});
     inArgs.push_back({rhs, rhs.getType()});
@@ -415,12 +418,12 @@ class TransposeOpConversion
   using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::TransposeOp transpose,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::TransposeOp transpose, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
     const mlir::Location &loc = transpose->getLoc();
 
-    mlir::Value arg = transpose.getArray();
+    mlir::Value arg = adaptor.getArray();
     llvm::SmallVector<IntrinsicArgument, 1> inArgs;
     inArgs.push_back({arg, arg.getType()});
 
@@ -445,13 +448,13 @@ struct MatmulTransposeOpConversion
       hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::MatmulTransposeOp multranspose, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
     const mlir::Location &loc = multranspose->getLoc();
 
-    mlir::Value lhs = multranspose.getLhs();
-    mlir::Value rhs = multranspose.getRhs();
+    mlir::Value lhs = adaptor.getLhs();
+    mlir::Value rhs = adaptor.getRhs();
     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
     inArgs.push_back({lhs, lhs.getType()});
     inArgs.push_back({rhs, rhs.getType()});

tblah commented Feb 28, 2024

@matthias-springer MLIR allows these conversions without further changes. Is this intentional or am I still misusing the API? They are safe to perform in place in the context of the dialects involved.

-    // patterns in this pass are actually rewrite patterns.
+    // the types might not match exactly (but are safe)
+    // e.g. !hlfir.expr<?xi32> vs !hlfir.expr<2xi32>
+    // TODO: is this allowed by MLIR?
Member

Yes, I think this is allowed. (As long as you don't produce an invalid operation, everything's fine.)

The problem in the current implementation is that op->getResult(0).replaceAllUsesWith(base); bypasses the rewriter API, i.e., it modifies the IR without going through the rewriter. The rewriter equivalent would be RewriterBase::replaceAllUsesWith, but that function is not supported by the conversion pattern rewriter yet.

But it should also not be necessary:

    op->getResult(0).replaceAllUsesWith(base);
    rewriter.replaceOp(op, base);

rewriter.replaceOp already performs the replacement, so the preceding replaceAllUsesWith is redundant. I tried removing the replaceAllUsesWith, but then the dialect conversion no longer succeeds. I think the problem is that whenever the type changes during a rewriter.replaceOp, the dialect conversion tries to build a type conversion op using the type converter. Here no type conversion is needed, but the dialect conversion does not take that into account. The error that I saw was due to the fact that no type converter is specified. (Take a look at the OpConversionPattern constructor: you can pass a type converter.)

One thing you could try is passing a type converter that always returns the same SSA value, i.e., does not convert anything. Not sure if it will work.
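
The concern above about modifying the IR without going through the rewriter can be illustrated with a toy model. The sketch below is self-contained and does not use MLIR: `Use`, `ToyRewriter`, and `demoRollback` are hypothetical names, and the journal is only an assumed stand-in for the bookkeeping a conversion driver might keep. It shows that a change made through the rewriter can be undone, while a direct mutation is invisible to it.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an SSA use: a slot naming the value it reads.
// Hypothetical type, not an MLIR class.
struct Use {
  std::string value;
};

// Toy rewriter that journals every change so it can be rolled back.
class ToyRewriter {
public:
  // Redirect one use and remember its previous value.
  void replaceUse(Use &use, const std::string &newValue) {
    journal.push_back({&use, use.value});
    use.value = newValue;
  }

  // Undo all journaled changes, newest first.
  void rollback() {
    for (auto it = journal.rbegin(); it != journal.rend(); ++it)
      it->first->value = it->second;
    journal.clear();
  }

private:
  std::vector<std::pair<Use *, std::string>> journal;
};

// Returns the values of two uses after a rollback: the first was changed
// through the rewriter, the second was mutated directly.
inline std::pair<std::string, std::string> demoRollback() {
  Use tracked{"%old"};
  Use untracked{"%old"};
  ToyRewriter rewriter;
  rewriter.replaceUse(tracked, "%new"); // journaled by the rewriter
  untracked.value = "%new";             // bypasses the rewriter
  rewriter.rollback();
  return {tracked.value, untracked.value};
}
```

After the rollback, the tracked use reads `%old` again, while the directly mutated use still reads `%new`, leaving the two out of sync.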

Contributor Author

Thanks for your help.

I don't think it works. After following your suggestion I am getting null operand errors from operations referring to the replaced value. It looks like rewriter.replaceOp(op, base) is not replacing uses where the type changed (or maybe the old uses of the operation aren't removed before this check runs?).

I guess the right solution will be to use the greedy pattern rewriter after all, even though this pass is conceptually performing a partial dialect conversion.

tblah commented Feb 29, 2024

Closing this to re-do with a different solution

@tblah tblah closed this Feb 29, 2024

tblah commented Feb 29, 2024

New approach using GreedyPatternRewriter: #83438
