-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Harmonize the behavior of the folding API functions #88508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Christian Ulmann (Dinistro) ChangesThis commit changes The added test contains a simplified version of a breakage we observed downstream. Full diff: https://github.com/llvm/llvm-project/pull/88508.diff 8 Files Affected:
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3beade017d1ab9..e74505e5dbfdf4 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -517,7 +517,7 @@ class OpBuilder : public Builder {
/// Create an operation of specific op type at the current insertion point,
/// and immediately try to fold it. This functions populates 'results' with
- /// the results after folding the operation.
+ /// the results of the operation.
template <typename OpTy, typename... Args>
void createOrFold(SmallVectorImpl<Value> &results, Location location,
Args &&...args) {
@@ -530,10 +530,17 @@ class OpBuilder : public Builder {
if (block)
block->getOperations().insert(insertPoint, op);
- // Fold the operation. If successful erase it, otherwise notify.
- if (succeeded(tryFold(op, results)))
+ // Attempt to fold the operation.
+ if (succeeded(tryFold(op, results)) && !results.empty()) {
+ // Erase the operation, if the fold removed the need for this operation.
+ // Note: The fold already populated the results in this case.
op->erase();
- else if (block && listener)
+ return;
+ }
+
+ ResultRange opResults = op->getResults();
+ results.assign(opResults.begin(), opResults.end());
+ if (block && listener)
listener->notifyOperationInserted(op, /*previous=*/{});
}
@@ -561,6 +568,7 @@ class OpBuilder : public Builder {
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
+ /// If the fold was in-place, `results` will not be filled.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f90240a67dcc5f..0fff06df39c1e7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2763,7 +2763,8 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
/// Folds a cast op that can be chained.
template <typename T>
-static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
+static OpFoldResult foldChainableCast(T castOp,
+ typename T::FoldAdaptor adaptor) {
// cast(x : T0, T0) -> x
if (castOp.getArg().getType() == castOp.getType())
return castOp.getArg();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 18ca3c332e0204..36e17609eab609 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -476,16 +476,14 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-/// Attempts to fold the given operation and places new results within
-/// 'results'. Returns success if the operation was folded, failure otherwise.
-/// Note: This function does not erase the operation on a successful fold.
LogicalResult OpBuilder::tryFold(Operation *op,
SmallVectorImpl<Value> &results) {
+ assert(results.empty());
ResultRange opResults = op->getResults();
results.reserve(opResults.size());
auto cleanupFailure = [&] {
- results.assign(opResults.begin(), opResults.end());
+ results.clear();
return failure();
};
@@ -495,20 +493,24 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// Try to fold the operation.
SmallVector<OpFoldResult, 4> foldResults;
- if (failed(op->fold(foldResults)) || foldResults.empty())
+ if (failed(op->fold(foldResults)))
return cleanupFailure();
+ // An in-place fold does not require generation of any constants.
+ if (foldResults.empty())
+ return success();
+
// A temporary builder used for creating constants during folding.
OpBuilder cstBuilder(context);
SmallVector<Operation *, 1> generatedConstants;
// Populate the results with the folded results.
Dialect *dialect = op->getDialect();
- for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
- Type expectedType = std::get<1>(it);
+ for (auto [foldResult, expectedType] :
+ llvm::zip_equal(foldResults, opResults.getTypes())) {
// Normal values get pushed back directly.
- if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
+ if (auto value = llvm::dyn_cast_if_present<Value>(foldResult)) {
results.push_back(value);
continue;
}
@@ -518,7 +520,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
return cleanupFailure();
// Ask the dialect to materialize a constant operation for this value.
- Attribute attr = std::get<0>(it).get<Attribute>();
+ Attribute attr = foldResult.get<Attribute>();
auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
op->getLoc());
if (!constOp) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8671c1008902a0..18d6f7daa4bea7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2072,6 +2072,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
+ // An empty list of replacement values indicates that the fold was in-place.
+ // As the operation changed, a new legalization needs to be attempted.
+ if (replacementValues.empty())
+ return legalize(op, rewriter);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d552f0346644b3..7530b300d57b8b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -427,3 +427,13 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
}) : (i64) -> (i64)
"test.invalid"(%0) : (i64) -> ()
}
+
+// -----
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+ // CHECK: op_in_place_self_fold
+ // CHECK-SAME: folded = true
+ %1 = "test.op_in_place_self_fold"() : () -> (i32)
+ "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 380c74a47e509a..becd0d68bf1d3e 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -568,6 +568,15 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
return {};
}
+OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
+ if (!getProperties().folded) {
+ // The folder adds the "folded" if not present.
+ getProperties().folded = BoolAttr::get(getContext(), true);
+ return getResult();
+ }
+ return {};
+}
+
OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
int64_t sum = 0;
if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e6c3601d08dad0..663064d51f1bbe 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
let hasFolder = 1;
}
+def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
+ let arguments = (ins OptionalAttr<BoolAttr>:$folded);
+ let results = (outs I32);
+ let hasFolder = 1;
+}
+
// Test op that simply returns success.
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
let results = (outs Variadic<I1>);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 76dc825fe44515..285e39dd9016e1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1167,6 +1167,12 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
[](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
+ // Create a dynamically legal rule that can only be legalized by folding it.
+ target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
+ [](TestOpInPlaceSelfFold op) {
+ return op.getProperties().folded != nullptr;
+ });
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
|
@matthias-springer friendly ping 🙂 |
mlir/lib/IR/Builders.cpp
Outdated
LogicalResult OpBuilder::tryFold(Operation *op, | ||
SmallVectorImpl<Value> &results) { | ||
assert(results.empty()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: && "expected empty results"
mlir/include/mlir/IR/Builders.h
Outdated
@@ -561,6 +568,7 @@ class OpBuilder : public Builder { | |||
|
|||
/// Attempts to fold the given operation and places new results within | |||
/// 'results'. Returns success if the operation was folded, failure otherwise. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: use backticks
@@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> { | |||
let hasFolder = 1; | |||
} | |||
|
|||
def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> { | |||
let arguments = (ins OptionalAttr<BoolAttr>:$folded); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about turning this into a UnitAttr
?
// Create a dynamically legal rule that can only be legalized by folding it. | ||
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( | ||
[](TestOpInPlaceSelfFold op) { | ||
return op.getProperties().folded != nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When this is a unit attr, you can write return op.getFolded();
.
OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) { | ||
if (!getProperties().folded) { | ||
// The folder adds the "folded" if not present. | ||
getProperties().folded = BoolAttr::get(getContext(), true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use setFolded
here?
This commit changes `OpBuilder::tryFold` to behave more similarly to `Operation::fold`. Concretely, this ensures that even an in-place fold returns `success`. This is necessary to fix a bug in the dialect conversion that occurred when an in-place folding made an operation legal. The dialect conversion infrastructure did not check if the result of an in-place folding legalized the operation and just went ahead and tried to apply pattern anyways.
50edf67
to
c4e1f79
Compare
This commit changes
OpBuilder::tryFold
to behave more similarly toOperation::fold
. Concretely, this ensures that even an in-place fold returnssuccess
.This is necessary to fix a bug in the dialect conversion that occurred when an in-place folding made an operation legal. The dialect conversion infrastructure did not check if the result of an in-place folding legalized the operation and just went ahead and tried to apply pattern anyways.
The added test contains a simplified version of a breakage we observed downstream.