Skip to content

[MLIR] Harmonize the behavior of the folding API functions #88508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ class OpBuilder : public Builder {

/// Create an operation of specific op type at the current insertion point,
/// and immediately try to fold it. This functions populates 'results' with
/// the results after folding the operation.
/// the results of the operation.
template <typename OpTy, typename... Args>
void createOrFold(SmallVectorImpl<Value> &results, Location location,
Args &&...args) {
Expand All @@ -530,10 +530,17 @@ class OpBuilder : public Builder {
if (block)
block->getOperations().insert(insertPoint, op);

// Fold the operation. If successful erase it, otherwise notify.
if (succeeded(tryFold(op, results)))
// Attempt to fold the operation.
if (succeeded(tryFold(op, results)) && !results.empty()) {
// Erase the operation, if the fold removed the need for this operation.
// Note: The fold already populated the results in this case.
op->erase();
else if (block && listener)
return;
}

ResultRange opResults = op->getResults();
results.assign(opResults.begin(), opResults.end());
if (block && listener)
listener->notifyOperationInserted(op, /*previous=*/{});
}

Expand All @@ -560,7 +567,8 @@ class OpBuilder : public Builder {
}

/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// `results`. Returns success if the operation was folded, failure otherwise.
/// If the fold was in-place, `results` will not be filled.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2831,7 +2831,8 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }

/// Folds a cast op that can be chained.
template <typename T>
static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
static OpFoldResult foldChainableCast(T castOp,
typename T::FoldAdaptor adaptor) {
// cast(x : T0, T0) -> x
if (castOp.getArg().getType() == castOp.getType())
return castOp.getArg();
Expand Down
20 changes: 11 additions & 9 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,16 +476,14 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}

/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult OpBuilder::tryFold(Operation *op,
SmallVectorImpl<Value> &results) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();

results.reserve(opResults.size());
auto cleanupFailure = [&] {
results.assign(opResults.begin(), opResults.end());
results.clear();
return failure();
};

Expand All @@ -495,20 +493,24 @@ LogicalResult OpBuilder::tryFold(Operation *op,

// Try to fold the operation.
SmallVector<OpFoldResult, 4> foldResults;
if (failed(op->fold(foldResults)) || foldResults.empty())
if (failed(op->fold(foldResults)))
return cleanupFailure();

// An in-place fold does not require generation of any constants.
if (foldResults.empty())
return success();

// A temporary builder used for creating constants during folding.
OpBuilder cstBuilder(context);
SmallVector<Operation *, 1> generatedConstants;

// Populate the results with the folded results.
Dialect *dialect = op->getDialect();
for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
Type expectedType = std::get<1>(it);
for (auto [foldResult, expectedType] :
llvm::zip_equal(foldResults, opResults.getTypes())) {

// Normal values get pushed back directly.
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
if (auto value = llvm::dyn_cast_if_present<Value>(foldResult)) {
results.push_back(value);
continue;
}
Expand All @@ -518,7 +520,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
return cleanupFailure();

// Ask the dialect to materialize a constant operation for this value.
Attribute attr = std::get<0>(it).get<Attribute>();
Attribute attr = foldResult.get<Attribute>();
auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
op->getLoc());
if (!constOp) {
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2072,6 +2072,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);

// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,13 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
}) : (i64) -> (i64)
"test.invalid"(%0) : (i64) -> ()
}

// -----

// CHECK-LABEL: @fold_legalization
func.func @fold_legalization() -> i32 {
// CHECK: op_in_place_self_fold
// CHECK-SAME: folded
%1 = "test.op_in_place_self_fold"() : () -> (i32)
"test.return"(%1) : (i32) -> ()
}
13 changes: 13 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOpDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,19 @@ LogicalResult CompareOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// TestOpInPlaceSelfFold
//===----------------------------------------------------------------------===//

OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
if (!getFolded()) {
// The folder adds the "folded" if not present.
setFolded(true);
return getResult();
}
return {};
}

//===----------------------------------------------------------------------===//
// TestOpFoldWithFoldAdaptor
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
let hasFolder = 1;
}

def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
let arguments = (ins UnitAttr:$folded);
let results = (outs I32);
let hasFolder = 1;
}

// Test op that simply returns success.
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
let results = (outs Variadic<I1>);
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,10 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
[](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });

// Create a dynamically legal rule that can only be legalized by folding it.
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });

// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
Expand Down