Skip to content

[mlir][inliner] Add doClone and canHandleMultipleBlocks callbacks to Inliner #131226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 5, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions mlir/include/mlir/Transforms/Inliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class InlinerConfig {
public:
using DefaultPipelineTy = std::function<void(OpPassManager &)>;
using OpPipelinesTy = llvm::StringMap<OpPassManager>;
using CloneCallbackTy =
std::function<void(OpBuilder &builder, Region *src, Block *inlineBlock,
Block *postInsertBlock, IRMapping &mapper,
bool shouldCloneInlinedRegion)>;

InlinerConfig() = default;
InlinerConfig(DefaultPipelineTy defaultPipeline,
Expand All @@ -39,13 +43,22 @@ class InlinerConfig {
}
const OpPipelinesTy &getOpPipelines() const { return opPipelines; }
unsigned getMaxInliningIterations() const { return maxInliningIterations; }
const CloneCallbackTy &getCloneCallback() const { return cloneCallback; }
bool getCanHandleMultipleBlocks() const { return canHandleMultipleBlocks; }

void setDefaultPipeline(DefaultPipelineTy pipeline) {
defaultPipeline = std::move(pipeline);
}
void setOpPipelines(OpPipelinesTy pipelines) {
opPipelines = std::move(pipelines);
}
void setMaxInliningIterations(unsigned max) { maxInliningIterations = max; }
void setCloneCallback(CloneCallbackTy callback) {
cloneCallback = std::move(callback);
}
void setCanHandleMultipleBlocks(bool value) {
canHandleMultipleBlocks = value;
}

private:
/// An optional function that constructs an optimization pipeline for
Expand All @@ -60,6 +73,28 @@ class InlinerConfig {
/// For SCC-based inlining algorithms, specifies maximum number of iterations
/// when inlining within an SCC.
unsigned maxInliningIterations{0};
/// Callback that transfers the blocks of 'src' into the caller during
/// inlining. The default implementation either clones 'src' (using 'mapper')
/// or splices its blocks directly, depending on 'shouldCloneInlinedRegion'.
/// In both cases the blocks are inserted before 'postInsertBlock' in the
/// region that contains 'inlineBlock'. The default does not use 'builder'.
CloneCallbackTy cloneCallback = [](OpBuilder &builder, Region *src,
Block *inlineBlock, Block *postInsertBlock,
IRMapping &mapper,
bool shouldCloneInlinedRegion) {
// The region is either cloned or moved inline. In either case, place the
// new blocks immediately before 'postInsertBlock' to improve IR
// readability.
Region *insertRegion = inlineBlock->getParent();
if (shouldCloneInlinedRegion)
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
else
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
src->getBlocks(), src->begin(),
src->end());
};
/// Whether the inliner may inline a function containing multiple blocks
/// into a region that requires a single block. By default, this is not
/// allowed. If it is true, cloneCallback should perform the extra
/// transformation. See the example in
/// mlir/test/lib/Transforms/TestInliningCallback.cpp
bool canHandleMultipleBlocks{false};
};

/// This is an implementation of the inliner
Expand Down
16 changes: 11 additions & 5 deletions mlir/include/mlir/Transforms/InliningUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/Inliner.h"
#include <optional>

namespace mlir {
Expand Down Expand Up @@ -253,13 +254,15 @@ class InlinerInterface
/// provided, will be used to update the inlined operations' location
/// information. 'shouldCloneInlinedRegion' corresponds to whether the source
/// region should be cloned into the 'inlinePoint' or spliced directly.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the documentation for all the API changes should be updated; can you do it in a follow-up PR please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I will.

LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
LogicalResult inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Operation *inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
LogicalResult inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes,
Expand All @@ -269,12 +272,14 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
/// This function is an overload of the above 'inlineRegion' that allows for
/// providing the set of operands ('inlinedOperands') that should be used
/// in-favor of the region arguments when inlining.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
LogicalResult inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Operation *inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
LogicalResult inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
Expand All @@ -287,7 +292,8 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
LogicalResult inlineCall(InlinerInterface &interface,
const InlinerConfig &config, CallOpInterface call,
CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion = true);

Expand Down
31 changes: 17 additions & 14 deletions mlir/lib/Transforms/Utils/Inliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

LogicalResult inlineResult =
inlineCall(inlinerIface, call,
inlineCall(inlinerIface, inliner.config, call,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
Expand Down Expand Up @@ -729,19 +729,22 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {

// Don't allow inlining if the callee has multiple blocks (unstructured
// control flow) but we cannot be sure that the caller region supports that.
bool calleeHasMultipleBlocks =
llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
// If both parent ops have the same type, it is safe to inline. Otherwise,
// decide based on whether the op has the SingleBlock trait or not.
// Note: This check does currently not account for SizedRegion/MaxSizedRegion.
auto callerRegionSupportsMultipleBlocks = [&]() {
return callableRegion->getParentOp()->getName() ==
resolvedCall.call->getParentOp()->getName() ||
!resolvedCall.call->getParentOp()
->mightHaveTrait<OpTrait::SingleBlock>();
};
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
return false;
if (!inliner.config.getCanHandleMultipleBlocks()) {
bool calleeHasMultipleBlocks =
llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
// If both parent ops have the same type, it is safe to inline. Otherwise,
// decide based on whether the op has the SingleBlock trait or not.
// Note: This check does not currently account for
// SizedRegion/MaxSizedRegion.
auto callerRegionSupportsMultipleBlocks = [&]() {
return callableRegion->getParentOp()->getName() ==
resolvedCall.call->getParentOp()->getName() ||
!resolvedCall.call->getParentOp()
->mightHaveTrait<OpTrait::SingleBlock>();
};
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
return false;
}

if (!inliner.isProfitableToInline(resolvedCall))
return false;
Expand Down
82 changes: 40 additions & 42 deletions mlir/lib/Transforms/Utils/InliningUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Inliner.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
Expand Down Expand Up @@ -245,10 +246,10 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
}

static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
inlineRegionImpl(InlinerInterface &interface, const InlinerConfig &config,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes, std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
assert(resultsToReplace.size() == regionResultTypes.size());
// We expect the region to have at least one block.
Expand All @@ -275,16 +276,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
if (call && callable)
handleArgumentImpl(interface, builder, call, callable, mapper);

// Check to see if the region is being cloned, or moved inline. In either
// case, move the new blocks after the 'insertBlock' to improve IR
// readability.
// Clone the callee's source into the caller.
Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
if (shouldCloneInlinedRegion)
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
else
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
src->getBlocks(), src->begin(),
src->end());
config.getCloneCallback()(builder, src, inlineBlock, postInsertBlock, mapper,
shouldCloneInlinedRegion);

// Get the range of newly inserted blocks.
auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
Expand All @@ -309,7 +304,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
bool singleBlockFastPath = interface.allowSingleBlockOptimization(newBlocks);

// Handle the case where only a single block was inlined.
if (singleBlockFastPath && std::next(newBlocks.begin()) == newBlocks.end()) {
if (singleBlockFastPath && llvm::hasSingleElement(newBlocks)) {
// Run the result attribute handler on the terminator operands.
Operation *firstBlockTerminator = firstNewBlock->getTerminator();
builder.setInsertionPoint(firstBlockTerminator);
Expand Down Expand Up @@ -353,9 +348,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
}

static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
inlineRegionImpl(InlinerInterface &interface, const InlinerConfig &config,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands, ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
// We expect the region to have at least one block.
if (src->empty())
Expand All @@ -377,51 +373,52 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
}

// Call into the main region inliner function.
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
resultsToReplace, resultsToReplace.getTypes(),
return inlineRegionImpl(interface, config, src, inlineBlock, inlinePoint,
mapper, resultsToReplace, resultsToReplace.getTypes(),
inlineLoc, shouldCloneInlinedRegion, call);
}

LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
LogicalResult mlir::inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Operation *inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegion(interface, src, inlinePoint->getBlock(),
return inlineRegion(interface, config, src, inlinePoint->getBlock(),
++inlinePoint->getIterator(), mapper, resultsToReplace,
regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
}
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Block *inlineBlock,
Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
resultsToReplace, regionResultTypes, inlineLoc,
shouldCloneInlinedRegion);

LogicalResult mlir::inlineRegion(
InlinerInterface &interface, const InlinerConfig &config, Region *src,
Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, config, src, inlineBlock, inlinePoint,
mapper, resultsToReplace, regionResultTypes,
inlineLoc, shouldCloneInlinedRegion);
}

LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
LogicalResult mlir::inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Operation *inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegion(interface, src, inlinePoint->getBlock(),
return inlineRegion(interface, config, src, inlinePoint->getBlock(),
++inlinePoint->getIterator(), inlinedOperands,
resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
}
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Block *inlineBlock,
Block::iterator inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,

LogicalResult
mlir::inlineRegion(InlinerInterface &interface, const InlinerConfig &config,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands, ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, config, src, inlineBlock, inlinePoint,
inlinedOperands, resultsToReplace, inlineLoc,
shouldCloneInlinedRegion);
}
Expand Down Expand Up @@ -455,6 +452,7 @@ static Value materializeConversion(const DialectInlinerInterface *interface,
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
const InlinerConfig &config,
CallOpInterface call,
CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion) {
Expand Down Expand Up @@ -531,7 +529,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
return cleanupState();

// Attempt to inline the call.
if (failed(inlineRegionImpl(interface, src, call->getBlock(),
if (failed(inlineRegionImpl(interface, config, src, call->getBlock(),
++call->getIterator(), mapper, callResults,
callableResultTypes, call.getLoc(),
shouldCloneInlinedRegion, call)))
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Transforms/test-inlining-callback.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -test-inline-callback | FileCheck %s

// Test inlining with multiple blocks and scf.execute_region transformation
// CHECK-LABEL: func @test_inline_multiple_blocks
func.func @test_inline_multiple_blocks(%arg0: i32) -> i32 {
// CHECK: %[[RES:.*]] = scf.execute_region -> i32
// CHECK-NEXT: %[[ADD1:.*]] = arith.addi %arg0, %arg0
// CHECK-NEXT: cf.br ^bb1(%[[ADD1]] : i32)
// CHECK: ^bb1(%[[ARG:.*]]: i32):
// CHECK-NEXT: %[[ADD2:.*]] = arith.addi %[[ARG]], %[[ARG]]
// CHECK-NEXT: scf.yield %[[ADD2]]
// CHECK: return %[[RES]]
%fn = "test.functional_region_op"() ({
^bb0(%a : i32):
%b = arith.addi %a, %a : i32
cf.br ^bb1(%b: i32)
^bb1(%c: i32):
%d = arith.addi %c, %c : i32
"test.return"(%d) : (i32) -> ()
}) : () -> ((i32) -> i32)

%0 = call_indirect %fn(%arg0) : (i32) -> i32
return %0 : i32
}
1 change: 1 addition & 0 deletions mlir/test/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
TestInliningCallback.cpp
TestMakeIsolatedFromAbove.cpp
TestTransformsOps.cpp
${MLIRTestTransformsPDLSrc}
Expand Down
Loading