-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] UnsignedWhenEquivalent: use greedy rewriter instead of dialect conversion #112454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
… conversion UnsignedWhenEquivalent doesn't really need any dialect conversion features and switching it normal patterns makes it more composable with other patterns-based transformations.
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) Changes
Full diff: https://github.com/llvm/llvm-project/pull/112454.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..e866ac518dbbcb 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -70,6 +70,10 @@ std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver);
+/// Replace signed ops with unsigned ones where they are proven equivalent.
+void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
+ DataFlowSolver &solver);
+
/// Create a pass which do optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index 4edce84bafd416..c76f56279db706 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -13,7 +13,8 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace arith {
@@ -85,35 +86,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
}
namespace {
+class DataFlowListener : public RewriterBase::Listener {
+public:
+ DataFlowListener(DataFlowSolver &s) : s(s) {}
+
+protected:
+ void notifyOperationErased(Operation *op) override {
+ s.eraseState(s.getProgramPointAfter(op));
+ for (Value res : op->getResults())
+ s.eraseState(res);
+ }
+
+ DataFlowSolver &s;
+};
+
template <typename Signed, typename Unsigned>
-struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
- using OpConversionPattern<Signed>::OpConversionPattern;
+struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> {
+ ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
+ : OpRewritePattern<Signed>(context), solver(s) {}
- LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
- ConversionPatternRewriter &rw) const override {
- rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
- adaptor.getOperands(), op->getAttrs());
+ LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
+ if (failed(
+ staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
+ return failure();
+
+ rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
+ op->getAttrs());
return success();
}
+
+private:
+ DataFlowSolver &solver;
};
-struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
- using OpConversionPattern<CmpIOp>::OpConversionPattern;
+struct ConvertCmpIToUnsigned final : public OpRewritePattern<CmpIOp> {
+ ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
+ : OpRewritePattern<CmpIOp>(context), solver(s) {}
+
+ LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
+ if (failed(isCmpIConvertable(this->solver, op)))
+ return failure();
- LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
- ConversionPatternRewriter &rw) const override {
rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
op.getLhs(), op.getRhs());
return success();
}
+
+private:
+ DataFlowSolver &solver;
};
struct ArithUnsignedWhenEquivalentPass
: public arith::impl::ArithUnsignedWhenEquivalentBase<
ArithUnsignedWhenEquivalentPass> {
- /// Implementation structure: first find all equivalent ops and collect them,
- /// then perform all the rewrites in a second pass over the target op. This
- /// ensures that analysis results are not invalidated during rewriting.
+
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
@@ -123,35 +149,32 @@ struct ArithUnsignedWhenEquivalentPass
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
- ConversionTarget target(*ctx);
- target.addLegalDialect<ArithDialect>();
- target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
- MinSIOp, MaxSIOp, ExtSIOp>(
- [&solver](Operation *op) -> std::optional<bool> {
- return failed(staticallyNonNegative(solver, op));
- });
- target.addDynamicallyLegalOp<CmpIOp>(
- [&solver](CmpIOp op) -> std::optional<bool> {
- return failed(isCmpIConvertable(solver, op));
- });
+ DataFlowListener listener(solver);
RewritePatternSet patterns(ctx);
- patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
- ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
- ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
- ConvertOpToUnsigned<RemSIOp, RemUIOp>,
- ConvertOpToUnsigned<MinSIOp, MinUIOp>,
- ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
- ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
- ctx);
-
- if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
+ populateUnsignedWhenEquivalentPatterns(patterns, solver);
+
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
- }
}
};
} // end anonymous namespace
+void mlir::arith::populateUnsignedWhenEquivalentPatterns(
+ RewritePatternSet &patterns, DataFlowSolver &solver) {
+ patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
+ ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
+ ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
+ ConvertOpToUnsigned<RemSIOp, RemUIOp>,
+ ConvertOpToUnsigned<MinSIOp, MinUIOp>,
+ ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
+ ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
+ patterns.getContext(), solver);
+}
+
std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
return std::make_unique<ArithUnsignedWhenEquivalentPass>();
}
diff --git a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
index 49bd74cfe9124a..e015d2d7543c93 100644
--- a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
+++ b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
@@ -12,7 +12,7 @@
// CHECK: arith.cmpi slt
// CHECK: arith.cmpi sge
// CHECK: arith.cmpi sgt
-func.func @not_with_maybe_overflow(%arg0 : i32) {
+func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
%ci32_smax = arith.constant 0x7fffffff : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
@@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
%10 = arith.cmpi slt, %1, %c4 : i32
%11 = arith.cmpi sge, %1, %c4 : i32
%12 = arith.cmpi sgt, %1, %c4 : i32
- func.return
+ func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
}
// CHECK-LABEL: func @yes_with_no_overflow
@@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
// CHECK: arith.cmpi ult
// CHECK: arith.cmpi uge
// CHECK: arith.cmpi ugt
-func.func @yes_with_no_overflow(%arg0 : i32) {
+func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
%ci32_almost_smax = arith.constant 0x7ffffffe : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
@@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) {
%10 = arith.cmpi slt, %1, %c4 : i32
%11 = arith.cmpi sge, %1, %c4 : i32
%12 = arith.cmpi sgt, %1, %c4 : i32
- func.return
+ func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
}
// CHECK-LABEL: func @preserves_structure
@@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) {
func.func private @external() -> i8
// CHECK-LABEL: @dead_code
-func.func @dead_code() {
+func.func @dead_code() -> i8 {
%0 = call @external() : () -> i8
// CHECK: arith.floordivsi
%1 = arith.floordivsi %0, %0 : i8
- return
+ return %1 : i8
}
// Make sure not crash.
// CHECK-LABEL: @no_integer_or_index
-func.func @no_integer_or_index() {
+func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> {
// CHECK: arith.cmpi
%cst_0 = arith.constant dense<[0]> : vector<1xi32>
- %cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32>
- return
+ %cmp = arith.cmpi slt, %cst_0, %arg0 : vector<1xi32>
+ return %cmp : vector<1xi1>
}
// CHECK-LABEL: @gpu_func
@@ -113,4 +113,4 @@ func.func @gpu_func(%arg0: memref<2x32xf32>, %arg1: memref<2x32xf32>, %arg2: mem
gpu.terminator
}
return %arg1 : memref<2x32xf32>
-}
+}
|
adaptor.getOperands(), op->getAttrs()); | ||
LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override { | ||
if (failed( | ||
staticallyNonNegative(this->solver, static_cast<Operation *>(op)))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not exactly sure where to leave this comment so will just do so here: the backing range analysis makes an assumption that IndexType is 64bits. While mostly harmless to the analysis, this pattern builds on that to apply an optimization that is not valid for a 32bit IndexType.
I don't have a good solution to this but it took me some sleuthing in an earlier incarnation to understand.
Commenting here because "staticallyNonNegative" is only applicable to 64bit IndexType. Probably should at least call for a comment somewhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the cleanup.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
template <typename Signed, typename Unsigned> | ||
struct ConvertOpToUnsigned : OpConversionPattern<Signed> { | ||
using OpConversionPattern<Signed>::OpConversionPattern; | ||
struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> { | |
struct ConvertOpToUnsigned final : OpRewritePattern<Signed> { |
also below
} | ||
|
||
// Make sure not crash. | ||
// CHECK-LABEL: @no_integer_or_index | ||
func.func @no_integer_or_index() { | ||
func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> { | |
func.func @no_integer_or_index(%arg0: vector<1xi32>) -> vector<1xi1> { |
if (failed(applyPartialConversion(op, target, std::move(patterns)))) { | ||
populateUnsignedWhenEquivalentPatterns(patterns, solver); | ||
|
||
GreedyRewriteConfig config; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we expect to have to iterate to convergence here? Otherwise can we set options to limit to a single iteration?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think, it should finish in single iteration.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, adding maxIterations = 1
causing it to fail to converge. I think I will leave things as is for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably also need to change it to top down iteration in order to get the same convergence behavior as before.
I've definitely seen combined passes that are doing other optimizations along with unsigned conversions require multiple iterations to converge (and be more efficient with bottom up iteration), but I expect that this simple test pass just needs one top down pass through the IR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It fails to converge (i.e. applyPatternsAndFoldGreedily
returns failure) even with
config.maxIterations = 1;
config.useTopDownTraversal = true;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is mostly test pass anyway so I don't see much problem here. Downstream we plan to combine it with other patterns. And using greedy driver here may not be ideal, but it's still better than current dialect conversion driver.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is a test pass, it should be moved to the test folder and be named accordingly though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't intended as a test pass - you're meant to be able to run this (barring the usual philosophical disagreements about even having non-test passes upstream)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wrote this code using the dialect converter because I wanted a one-shot "walk this function exactly once and apply matching patterns"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't intended as a test pass
Well, then my point stands: we shouldn't involve the greedy rewriter here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved, generally ... oh, it's merged
@@ -29,6 +30,9 @@ using namespace mlir::dataflow; | |||
/// Succeeds when a value is statically non-negative in that it has a lower | |||
/// bound on its value (if it is treated as signed) and that bound is | |||
/// non-negative. | |||
// TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is true - all that IntegerRangeAnalysis
does is store index
as 64-bit. The implementations for various ops on index
ought to handle both cases - if they don't might be a bug
This is intended as a fast pattern rewrite driver for the cases when a simple walk gets the job done but we would still want to implement it in terms of rewrite patterns (that can be used with the greedy pattern rewrite driver downstream). The new driver is inspired by the discussion in #112454 and the LLVM Dev presentation from @matthias-springer earlier this week. This limitation comes with some limitations: * It does not repeat until a fixpoint or revisit ops modified in place or newly created ops. In general, it only walks forward (in the post-order). * `matchAndRewrite` can only erase the matched op or its descendants. This is verified under expensive checks. * It does not perform folding / DCE. We could probably relax some of these in the future without sacrificing too much performance.
This is intended as a fast pattern rewrite driver for the cases when a simple walk gets the job done but we would still want to implement it in terms of rewrite patterns (that can be used with the greedy pattern rewrite driver downstream). The new driver is inspired by the discussion in llvm#112454 and the LLVM Dev presentation from @matthias-springer earlier this week. This limitation comes with some limitations: * It does not repeat until a fixpoint or revisit ops modified in place or newly created ops. In general, it only walks forward (in the post-order). * `matchAndRewrite` can only erase the matched op or its descendants. This is verified under expensive checks. * It does not perform folding / DCE. We could probably relax some of these in the future without sacrificing too much performance.
This is intended as a fast pattern rewrite driver for the cases when a simple walk gets the job done but we would still want to implement it in terms of rewrite patterns (that can be used with the greedy pattern rewrite driver downstream). The new driver is inspired by the discussion in llvm#112454 and the LLVM Dev presentation from @matthias-springer earlier this week. This limitation comes with some limitations: * It does not repeat until a fixpoint or revisit ops modified in place or newly created ops. In general, it only walks forward (in the post-order). * `matchAndRewrite` can only erase the matched op or its descendants. This is verified under expensive checks. * It does not perform folding / DCE. We could probably relax some of these in the future without sacrificing too much performance.
UnsignedWhenEquivalent
doesn't really need any dialect conversion features and switching it normal patterns makes it more composable with other patterns-based transformations (and probably faster).