Skip to content

[mlir] UnsignedWhenEquivalent: use greedy rewriter instead of dialect conversion #112454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 17, 2024

Conversation

Hardcode84
Copy link
Contributor

UnsignedWhenEquivalent doesn't really need any dialect conversion features and switching it normal patterns makes it more composable with other patterns-based transformations (and probably faster).

… conversion

UnsignedWhenEquivalent doesn't really need any dialect conversion features and switching it normal patterns makes it more composable with other patterns-based transformations.
@llvmbot
Copy link
Member

llvmbot commented Oct 16, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

UnsignedWhenEquivalent doesn't really need any dialect conversion features and switching it normal patterns makes it more composable with other patterns-based transformations (and probably faster).


Full diff: https://github.com/llvm/llvm-project/pull/112454.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp (+59-36)
  • (modified) mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir (+10-10)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..e866ac518dbbcb 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -70,6 +70,10 @@ std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
 void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
                                            DataFlowSolver &solver);
 
+/// Replace signed ops with unsigned ones where they are proven equivalent.
+void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
+                                            DataFlowSolver &solver);
+
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index 4edce84bafd416..c76f56279db706 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -13,7 +13,8 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 namespace arith {
@@ -85,35 +86,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
 }
 
 namespace {
+class DataFlowListener : public RewriterBase::Listener {
+public:
+  DataFlowListener(DataFlowSolver &s) : s(s) {}
+
+protected:
+  void notifyOperationErased(Operation *op) override {
+    s.eraseState(s.getProgramPointAfter(op));
+    for (Value res : op->getResults())
+      s.eraseState(res);
+  }
+
+  DataFlowSolver &s;
+};
+
 template <typename Signed, typename Unsigned>
-struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
-  using OpConversionPattern<Signed>::OpConversionPattern;
+struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> {
+  ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<Signed>(context), solver(s) {}
 
-  LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
-                                ConversionPatternRewriter &rw) const override {
-    rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
-                                    adaptor.getOperands(), op->getAttrs());
+  LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
+    if (failed(
+            staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
+      return failure();
+
+    rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
+                                    op->getAttrs());
     return success();
   }
+
+private:
+  DataFlowSolver &solver;
 };
 
-struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
-  using OpConversionPattern<CmpIOp>::OpConversionPattern;
+struct ConvertCmpIToUnsigned final : public OpRewritePattern<CmpIOp> {
+  ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<CmpIOp>(context), solver(s) {}
+
+  LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
+    if (failed(isCmpIConvertable(this->solver, op)))
+      return failure();
 
-  LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
-                                ConversionPatternRewriter &rw) const override {
     rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
                                   op.getLhs(), op.getRhs());
     return success();
   }
+
+private:
+  DataFlowSolver &solver;
 };
 
 struct ArithUnsignedWhenEquivalentPass
     : public arith::impl::ArithUnsignedWhenEquivalentBase<
           ArithUnsignedWhenEquivalentPass> {
-  /// Implementation structure: first find all equivalent ops and collect them,
-  /// then perform all the rewrites in a second pass over the target op. This
-  /// ensures that analysis results are not invalidated during rewriting.
+
   void runOnOperation() override {
     Operation *op = getOperation();
     MLIRContext *ctx = op->getContext();
@@ -123,35 +149,32 @@ struct ArithUnsignedWhenEquivalentPass
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
 
-    ConversionTarget target(*ctx);
-    target.addLegalDialect<ArithDialect>();
-    target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
-                                 MinSIOp, MaxSIOp, ExtSIOp>(
-        [&solver](Operation *op) -> std::optional<bool> {
-          return failed(staticallyNonNegative(solver, op));
-        });
-    target.addDynamicallyLegalOp<CmpIOp>(
-        [&solver](CmpIOp op) -> std::optional<bool> {
-          return failed(isCmpIConvertable(solver, op));
-        });
+    DataFlowListener listener(solver);
 
     RewritePatternSet patterns(ctx);
-    patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
-                 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
-                 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
-                 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
-                 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
-                 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
-                 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
-        ctx);
-
-    if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
+    populateUnsignedWhenEquivalentPatterns(patterns, solver);
+
+    GreedyRewriteConfig config;
+    config.listener = &listener;
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
       signalPassFailure();
-    }
   }
 };
 } // end anonymous namespace
 
+void mlir::arith::populateUnsignedWhenEquivalentPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver) {
+  patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
+               ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
+               ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
+               ConvertOpToUnsigned<RemSIOp, RemUIOp>,
+               ConvertOpToUnsigned<MinSIOp, MinUIOp>,
+               ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
+               ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
+      patterns.getContext(), solver);
+}
+
 std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
   return std::make_unique<ArithUnsignedWhenEquivalentPass>();
 }
diff --git a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
index 49bd74cfe9124a..e015d2d7543c93 100644
--- a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
+++ b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
@@ -12,7 +12,7 @@
 // CHECK: arith.cmpi slt
 // CHECK: arith.cmpi sge
 // CHECK: arith.cmpi sgt
-func.func @not_with_maybe_overflow(%arg0 : i32) {
+func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
     %ci32_smax = arith.constant 0x7fffffff : i32
     %c1 = arith.constant 1 : i32
     %c4 = arith.constant 4 : i32
@@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
     %10 = arith.cmpi slt, %1, %c4 : i32
     %11 = arith.cmpi sge, %1, %c4 : i32
     %12 = arith.cmpi sgt, %1, %c4 : i32
-    func.return
+    func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: func @yes_with_no_overflow
@@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
 // CHECK: arith.cmpi ult
 // CHECK: arith.cmpi uge
 // CHECK: arith.cmpi ugt
-func.func @yes_with_no_overflow(%arg0 : i32) {
+func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
     %ci32_almost_smax = arith.constant 0x7ffffffe : i32
     %c1 = arith.constant 1 : i32
     %c4 = arith.constant 4 : i32
@@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) {
     %10 = arith.cmpi slt, %1, %c4 : i32
     %11 = arith.cmpi sge, %1, %c4 : i32
     %12 = arith.cmpi sgt, %1, %c4 : i32
-    func.return
+    func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: func @preserves_structure
@@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) {
 func.func private @external() -> i8
 
 // CHECK-LABEL: @dead_code
-func.func @dead_code() {
+func.func @dead_code() -> i8 {
   %0 = call @external() : () -> i8
   // CHECK: arith.floordivsi
   %1 = arith.floordivsi %0, %0 : i8
-  return
+  return %1 : i8
 }
 
 // Make sure not crash.
 // CHECK-LABEL: @no_integer_or_index
-func.func @no_integer_or_index() { 
+func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> {
   // CHECK: arith.cmpi
   %cst_0 = arith.constant dense<[0]> : vector<1xi32> 
-  %cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32> 
-  return
+  %cmp = arith.cmpi slt, %cst_0, %arg0 : vector<1xi32>
+  return %cmp : vector<1xi1>
 }
 
 // CHECK-LABEL: @gpu_func
@@ -113,4 +113,4 @@ func.func @gpu_func(%arg0: memref<2x32xf32>, %arg1: memref<2x32xf32>, %arg2: mem
     gpu.terminator
   } 
   return %arg1 : memref<2x32xf32> 
-}  
+}

adaptor.getOperands(), op->getAttrs());
LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
if (failed(
staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not exactly sure where to leave this comment so will just do so here: the backing range analysis makes an assumption that IndexType is 64bits. While mostly harmless to the analysis, this pattern builds on that to apply an optimization that is not valid for a 32bit IndexType.

I don't have a good solution to this but it took me some sleuthing in an earlier incarnation to understand.

Commenting here because "staticallyNonNegative" is only applicable to 64bit IndexType. Probably should at least call for a comment somewhere.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a comment.

Copy link
Contributor

@stellaraccident stellaraccident left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the cleanup.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

template <typename Signed, typename Unsigned>
struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
using OpConversionPattern<Signed>::OpConversionPattern;
struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> {
struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {

also below

}

// Make sure not crash.
// CHECK-LABEL: @no_integer_or_index
func.func @no_integer_or_index() {
func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> {
func.func @no_integer_or_index(%arg0: vector<1xi32>) -> vector<1xi1> {

if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
populateUnsignedWhenEquivalentPatterns(patterns, solver);

GreedyRewriteConfig config;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we expect to have to iterate to convergence here? Otherwise can we set options to limit to a single iteration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, it should finish in single iteration.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, adding maxIterations = 1 causing it to fail to converge. I think I will leave things as is for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably also need to change it to top down iteration in order to get the same convergence behavior as before.

I've definitely seen combined passes that are doing other optimizations along with unsigned conversions require multiple iterations to converge (and be more efficient with bottom up iteration), but I expect that this simple test pass just needs one top down pass through the IR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It fails to converge (i.e. applyPatternsAndFoldGreedily returns failure) even with

    config.maxIterations = 1;
    config.useTopDownTraversal = true;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly test pass anyway so I don't see much problem here. Downstream we plan to combine it with other patterns. And using greedy driver here may not be ideal, but it's still better than current dialect conversion driver.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is a test pass, it should be moved to the test folder and be named accordingly though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't intended as a test pass - you're meant to be able to run this (barring the usual philosophical disagreements about even having non-test passes upstream)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this code using the dialect converter because I wanted a one-shot "walk this function exactly once and apply matching patterns"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't intended as a test pass

Well, then my point stands: we shouldn't involve the greedy rewriter here.

@Hardcode84 Hardcode84 merged commit 6902b39 into llvm:main Oct 17, 2024
8 checks passed
@Hardcode84 Hardcode84 deleted the unsigned-refac branch October 17, 2024 09:23
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved, generally ... oh, it's merged

@@ -29,6 +30,9 @@ using namespace mlir::dataflow;
/// Succeeds when a value is statically non-negative in that it has a lower
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
// TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is true - all that IntegerRangeAnalysis does is store index as 64-bit. The implementations for various ops on index ought to handle both cases - if they don't might be a bug

kuhar added a commit that referenced this pull request Oct 31, 2024
This is intended as a fast pattern rewrite driver for the cases when a
simple walk gets the job done but we would still want to implement it in
terms of rewrite patterns (that can be used with the greedy pattern
rewrite driver downstream).

The new driver is inspired by the discussion in
#112454 and the LLVM Dev
presentation from @matthias-springer earlier this week.

This limitation comes with some limitations:
* It does not repeat until a fixpoint or revisit ops modified in place
or newly created ops. In general, it only walks forward (in the
post-order).
* `matchAndRewrite` can only erase the matched op or its descendants.
  This is verified under expensive checks.
* It does not perform folding / DCE.
 
We could probably relax some of these in the future without sacrificing
too much performance.
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
This is intended as a fast pattern rewrite driver for the cases when a
simple walk gets the job done but we would still want to implement it in
terms of rewrite patterns (that can be used with the greedy pattern
rewrite driver downstream).

The new driver is inspired by the discussion in
llvm#112454 and the LLVM Dev
presentation from @matthias-springer earlier this week.

This limitation comes with some limitations:
* It does not repeat until a fixpoint or revisit ops modified in place
or newly created ops. In general, it only walks forward (in the
post-order).
* `matchAndRewrite` can only erase the matched op or its descendants.
  This is verified under expensive checks.
* It does not perform folding / DCE.
 
We could probably relax some of these in the future without sacrificing
too much performance.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
This is intended as a fast pattern rewrite driver for the cases when a
simple walk gets the job done but we would still want to implement it in
terms of rewrite patterns (that can be used with the greedy pattern
rewrite driver downstream).

The new driver is inspired by the discussion in
llvm#112454 and the LLVM Dev
presentation from @matthias-springer earlier this week.

This limitation comes with some limitations:
* It does not repeat until a fixpoint or revisit ops modified in place
or newly created ops. In general, it only walks forward (in the
post-order).
* `matchAndRewrite` can only erase the matched op or its descendants.
  This is verified under expensive checks.
* It does not perform folding / DCE.
 
We could probably relax some of these in the future without sacrificing
too much performance.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants