Skip to content

[mlir][inliner] Add doClone and canHandleMultipleBlocks callbacks to Inliner #131226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions mlir/include/mlir/Transforms/Inliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ class InlinerConfig {
public:
using DefaultPipelineTy = std::function<void(OpPassManager &)>;
using OpPipelinesTy = llvm::StringMap<OpPassManager>;
using CloneCallbackSigTy = void(OpBuilder &builder, Region *src,
Block *inlineBlock, Block *postInsertBlock,
IRMapping &mapper,
bool shouldCloneInlinedRegion);
using CloneCallbackTy = std::function<CloneCallbackSigTy>;

InlinerConfig() = default;
InlinerConfig(DefaultPipelineTy defaultPipeline,
Expand All @@ -39,13 +44,22 @@ class InlinerConfig {
}
const OpPipelinesTy &getOpPipelines() const { return opPipelines; }
unsigned getMaxInliningIterations() const { return maxInliningIterations; }
const CloneCallbackTy &getCloneCallback() const { return cloneCallback; }
bool getCanHandleMultipleBlocks() const { return canHandleMultipleBlocks; }

void setDefaultPipeline(DefaultPipelineTy pipeline) {
defaultPipeline = std::move(pipeline);
}
void setOpPipelines(OpPipelinesTy pipelines) {
opPipelines = std::move(pipelines);
}
void setMaxInliningIterations(unsigned max) { maxInliningIterations = max; }
void setCloneCallback(CloneCallbackTy callback) {
cloneCallback = std::move(callback);
}
void setCanHandleMultipleBlocks(bool value = true) {
canHandleMultipleBlocks = value;
}

private:
/// An optional function that constructs an optimization pipeline for
Expand All @@ -60,6 +74,28 @@ class InlinerConfig {
/// For SCC-based inlining algorithms, specifies maximum number of iterations
/// when inlining within an SCC.
unsigned maxInliningIterations{0};
/// Callback for cloning operations during inlining
CloneCallbackTy cloneCallback = [](OpBuilder &builder, Region *src,
Block *inlineBlock, Block *postInsertBlock,
IRMapping &mapper,
bool shouldCloneInlinedRegion) {
// Check to see if the region is being cloned, or moved inline. In
// either case, move the new blocks after the 'insertBlock' to improve
// IR readability.
Region *insertRegion = inlineBlock->getParent();
if (shouldCloneInlinedRegion)
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
else
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
src->getBlocks(), src->begin(),
src->end());
};
/// Determine if the inliner can inline a function containing multiple
/// blocks into a region that requires a single block. By default, it is
/// not allowed. If it is true, cloneCallback should perform the extra
/// transformation. see the example in
/// mlir/test/lib/Transforms/TestInliningCallback.cpp
bool canHandleMultipleBlocks{false};
};

/// This is an implementation of the inliner
Expand Down
61 changes: 35 additions & 26 deletions mlir/include/mlir/Transforms/InliningUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/Inliner.h"
#include <optional>

namespace mlir {
Expand Down Expand Up @@ -253,43 +254,51 @@ class InlinerInterface
/// provided, will be used to update the inlined operations' location
/// information. 'shouldCloneInlinedRegion' corresponds to whether the source
/// region should be cloned into the 'inlinePoint' or spliced directly.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the documentation the all the API changes should be updated, can you do it in a follow-up PR please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I will.

LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult
inlineRegion(InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
Region *src, Operation *inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult
inlineRegion(InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);

/// This function is an overload of the above 'inlineRegion' that allows for
/// providing the set of operands ('inlinedOperands') that should be used
/// in-favor of the region arguments when inlining.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult
inlineRegion(InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
Region *src, Operation *inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult
inlineRegion(InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands, ValueRange resultsToReplace,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);

/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
/// function returns failure if inlining is not possible, success otherwise. On
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion = true);
LogicalResult
inlineCall(InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
CallOpInterface call, CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion = true);

} // namespace mlir

Expand Down
31 changes: 17 additions & 14 deletions mlir/lib/Transforms/Utils/Inliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

LogicalResult inlineResult =
inlineCall(inlinerIface, call,
inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
Expand Down Expand Up @@ -729,19 +729,22 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {

// Don't allow inlining if the callee has multiple blocks (unstructured
// control flow) but we cannot be sure that the caller region supports that.
bool calleeHasMultipleBlocks =
llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
// If both parent ops have the same type, it is safe to inline. Otherwise,
// decide based on whether the op has the SingleBlock trait or not.
// Note: This check does currently not account for SizedRegion/MaxSizedRegion.
auto callerRegionSupportsMultipleBlocks = [&]() {
return callableRegion->getParentOp()->getName() ==
resolvedCall.call->getParentOp()->getName() ||
!resolvedCall.call->getParentOp()
->mightHaveTrait<OpTrait::SingleBlock>();
};
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
return false;
if (!inliner.config.getCanHandleMultipleBlocks()) {
bool calleeHasMultipleBlocks =
llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
// If both parent ops have the same type, it is safe to inline. Otherwise,
// decide based on whether the op has the SingleBlock trait or not.
// Note: This check does currently not account for
// SizedRegion/MaxSizedRegion.
auto callerRegionSupportsMultipleBlocks = [&]() {
return callableRegion->getParentOp()->getName() ==
resolvedCall.call->getParentOp()->getName() ||
!resolvedCall.call->getParentOp()
->mightHaveTrait<OpTrait::SingleBlock>();
};
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
return false;
}

if (!inliner.isProfitableToInline(resolvedCall))
return false;
Expand Down
118 changes: 59 additions & 59 deletions mlir/lib/Transforms/Utils/InliningUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Inliner.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
Expand Down Expand Up @@ -245,10 +246,11 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
}

static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
inlineRegionImpl(InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes, std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
assert(resultsToReplace.size() == regionResultTypes.size());
// We expect the region to have at least one block.
Expand All @@ -275,16 +277,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
if (call && callable)
handleArgumentImpl(interface, builder, call, callable, mapper);

// Check to see if the region is being cloned, or moved inline. In either
// case, move the new blocks after the 'insertBlock' to improve IR
// readability.
// Clone the callee's source into the caller.
Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
if (shouldCloneInlinedRegion)
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
else
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
src->getBlocks(), src->begin(),
src->end());
cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper,
shouldCloneInlinedRegion);

// Get the range of newly inserted blocks.
auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
Expand All @@ -309,7 +305,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
bool singleBlockFastPath = interface.allowSingleBlockOptimization(newBlocks);

// Handle the case where only a single block was inlined.
if (singleBlockFastPath && std::next(newBlocks.begin()) == newBlocks.end()) {
if (singleBlockFastPath && llvm::hasSingleElement(newBlocks)) {
// Run the result attribute handler on the terminator operands.
Operation *firstBlockTerminator = firstNewBlock->getTerminator();
builder.setInsertionPoint(firstBlockTerminator);
Expand Down Expand Up @@ -353,9 +349,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
}

static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
inlineRegionImpl(InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands, ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
// We expect the region to have at least one block.
if (src->empty())
Expand All @@ -377,53 +375,54 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
}

// Call into the main region inliner function.
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
resultsToReplace, resultsToReplace.getTypes(),
inlineLoc, shouldCloneInlinedRegion, call);
return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
inlinePoint, mapper, resultsToReplace,
resultsToReplace.getTypes(), inlineLoc,
shouldCloneInlinedRegion, call);
}

LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegion(interface, src, inlinePoint->getBlock(),
LogicalResult mlir::inlineRegion(
InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes, std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
++inlinePoint->getIterator(), mapper, resultsToReplace,
regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
}
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Block *inlineBlock,
Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
resultsToReplace, regionResultTypes, inlineLoc,
shouldCloneInlinedRegion);

LogicalResult mlir::inlineRegion(
InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
return inlineRegionImpl(
interface, cloneCallback, src, inlineBlock, inlinePoint, mapper,
resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
}

LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegion(interface, src, inlinePoint->getBlock(),
LogicalResult mlir::inlineRegion(
InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
Operation *inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
++inlinePoint->getIterator(), inlinedOperands,
resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
}
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Block *inlineBlock,
Block::iterator inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
inlinedOperands, resultsToReplace, inlineLoc,
shouldCloneInlinedRegion);

LogicalResult mlir::inlineRegion(
InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
inlinePoint, inlinedOperands, resultsToReplace,
inlineLoc, shouldCloneInlinedRegion);
}

/// Utility function used to generate a cast operation from the given interface,
Expand Down Expand Up @@ -454,10 +453,11 @@ static Value materializeConversion(const DialectInlinerInterface *interface,
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
CallOpInterface call,
CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion) {
LogicalResult
mlir::inlineCall(InlinerInterface &interface,
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
CallOpInterface call, CallableOpInterface callable,
Region *src, bool shouldCloneInlinedRegion) {
// We expect the region to have at least one block.
if (src->empty())
return failure();
Expand Down Expand Up @@ -531,7 +531,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
return cleanupState();

// Attempt to inline the call.
if (failed(inlineRegionImpl(interface, src, call->getBlock(),
if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(),
++call->getIterator(), mapper, callResults,
callableResultTypes, call.getLoc(),
shouldCloneInlinedRegion, call)))
Expand Down
Loading