Skip to content

Commit 4513050

Browse files
authored
[MLIR] Harmonize the behavior of the folding API functions (#88508)
This commit changes `OpBuilder::tryFold` to behave more similarly to `Operation::fold`. Concretely, this ensures that even an in-place fold returns `success`. This is necessary to fix a bug in the dialect conversion that occurred when an in-place folding made an operation legal. The dialect conversion infrastructure did not check if the result of an in-place folding legalized the operation and just went ahead and tried to apply pattern anyways. The added test contains a simplified version of a breakage we observed downstream.
1 parent b28f4d4 commit 4513050

File tree

8 files changed

+63
-15
lines changed

8 files changed

+63
-15
lines changed

mlir/include/mlir/IR/Builders.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ class OpBuilder : public Builder {
517517

518518
/// Create an operation of specific op type at the current insertion point,
519519
/// and immediately try to fold it. This functions populates 'results' with
520-
/// the results after folding the operation.
520+
/// the results of the operation.
521521
template <typename OpTy, typename... Args>
522522
void createOrFold(SmallVectorImpl<Value> &results, Location location,
523523
Args &&...args) {
@@ -530,10 +530,17 @@ class OpBuilder : public Builder {
530530
if (block)
531531
block->getOperations().insert(insertPoint, op);
532532

533-
// Fold the operation. If successful erase it, otherwise notify.
534-
if (succeeded(tryFold(op, results)))
533+
// Attempt to fold the operation.
534+
if (succeeded(tryFold(op, results)) && !results.empty()) {
535+
// Erase the operation, if the fold removed the need for this operation.
536+
// Note: The fold already populated the results in this case.
535537
op->erase();
536-
else if (block && listener)
538+
return;
539+
}
540+
541+
ResultRange opResults = op->getResults();
542+
results.assign(opResults.begin(), opResults.end());
543+
if (block && listener)
537544
listener->notifyOperationInserted(op, /*previous=*/{});
538545
}
539546

@@ -560,7 +567,8 @@ class OpBuilder : public Builder {
560567
}
561568

562569
/// Attempts to fold the given operation and places new results within
563-
/// 'results'. Returns success if the operation was folded, failure otherwise.
570+
/// `results`. Returns success if the operation was folded, failure otherwise.
571+
/// If the fold was in-place, `results` will not be filled.
564572
/// Note: This function does not erase the operation on a successful fold.
565573
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
566574

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2831,7 +2831,8 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
28312831

28322832
/// Folds a cast op that can be chained.
28332833
template <typename T>
2834-
static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
2834+
static OpFoldResult foldChainableCast(T castOp,
2835+
typename T::FoldAdaptor adaptor) {
28352836
// cast(x : T0, T0) -> x
28362837
if (castOp.getArg().getType() == castOp.getType())
28372838
return castOp.getArg();

mlir/lib/IR/Builders.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -476,16 +476,14 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
476476
return create(state);
477477
}
478478

479-
/// Attempts to fold the given operation and places new results within
480-
/// 'results'. Returns success if the operation was folded, failure otherwise.
481-
/// Note: This function does not erase the operation on a successful fold.
482479
LogicalResult OpBuilder::tryFold(Operation *op,
483480
SmallVectorImpl<Value> &results) {
481+
assert(results.empty() && "expected empty results");
484482
ResultRange opResults = op->getResults();
485483

486484
results.reserve(opResults.size());
487485
auto cleanupFailure = [&] {
488-
results.assign(opResults.begin(), opResults.end());
486+
results.clear();
489487
return failure();
490488
};
491489

@@ -495,20 +493,24 @@ LogicalResult OpBuilder::tryFold(Operation *op,
495493

496494
// Try to fold the operation.
497495
SmallVector<OpFoldResult, 4> foldResults;
498-
if (failed(op->fold(foldResults)) || foldResults.empty())
496+
if (failed(op->fold(foldResults)))
499497
return cleanupFailure();
500498

499+
// An in-place fold does not require generation of any constants.
500+
if (foldResults.empty())
501+
return success();
502+
501503
// A temporary builder used for creating constants during folding.
502504
OpBuilder cstBuilder(context);
503505
SmallVector<Operation *, 1> generatedConstants;
504506

505507
// Populate the results with the folded results.
506508
Dialect *dialect = op->getDialect();
507-
for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
508-
Type expectedType = std::get<1>(it);
509+
for (auto [foldResult, expectedType] :
510+
llvm::zip_equal(foldResults, opResults.getTypes())) {
509511

510512
// Normal values get pushed back directly.
511-
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
513+
if (auto value = llvm::dyn_cast_if_present<Value>(foldResult)) {
512514
results.push_back(value);
513515
continue;
514516
}
@@ -518,7 +520,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
518520
return cleanupFailure();
519521

520522
// Ask the dialect to materialize a constant operation for this value.
521-
Attribute attr = std::get<0>(it).get<Attribute>();
523+
Attribute attr = foldResult.get<Attribute>();
522524
auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
523525
op->getLoc());
524526
if (!constOp) {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,6 +2072,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
20722072
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
20732073
return failure();
20742074
}
2075+
// An empty list of replacement values indicates that the fold was in-place.
2076+
// As the operation changed, a new legalization needs to be attempted.
2077+
if (replacementValues.empty())
2078+
return legalize(op, rewriter);
20752079

20762080
// Insert a replacement for 'op' with the folded replacement values.
20772081
rewriter.replaceOp(op, replacementValues);

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,13 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
427427
}) : (i64) -> (i64)
428428
"test.invalid"(%0) : (i64) -> ()
429429
}
430+
431+
// -----
432+
433+
// CHECK-LABEL: @fold_legalization
434+
func.func @fold_legalization() -> i32 {
435+
// CHECK: op_in_place_self_fold
436+
// CHECK-SAME: folded
437+
%1 = "test.op_in_place_self_fold"() : () -> (i32)
438+
"test.return"(%1) : (i32) -> ()
439+
}

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,19 @@ LogicalResult CompareOp::verify() {
825825
return success();
826826
}
827827

828+
//===----------------------------------------------------------------------===//
829+
// TestOpInPlaceSelfFold
830+
//===----------------------------------------------------------------------===//
831+
832+
OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
833+
if (!getFolded()) {
834+
// The folder adds the "folded" if not present.
835+
setFolded(true);
836+
return getResult();
837+
}
838+
return {};
839+
}
840+
828841
//===----------------------------------------------------------------------===//
829842
// TestOpFoldWithFoldAdaptor
830843
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
13511351
let hasFolder = 1;
13521352
}
13531353

1354+
def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
1355+
let arguments = (ins UnitAttr:$folded);
1356+
let results = (outs I32);
1357+
let hasFolder = 1;
1358+
}
1359+
13541360
// Test op that simply returns success.
13551361
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
13561362
let results = (outs Variadic<I1>);

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,10 @@ struct TestLegalizePatternDriver
11681168
target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
11691169
[](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
11701170

1171+
// Create a dynamically legal rule that can only be legalized by folding it.
1172+
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1173+
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1174+
11711175
// Handle a partial conversion.
11721176
if (mode == ConversionMode::Partial) {
11731177
DenseSet<Operation *> unlegalizedOps;

0 commit comments

Comments
 (0)