Skip to content

[mlir][IR] Change notifyBlockCreated to notifyBlockInserted #79472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
builder.notifyOperationInserted(op, previous);
rewriter.notifyOperationInserted(op, previous);
}
virtual void notifyBlockCreated(mlir::Block *block) override {
builder.notifyBlockCreated(block);
rewriter.notifyBlockCreated(block);
virtual void notifyBlockInserted(mlir::Block *block, mlir::Region *previous,
mlir::Region::iterator previousIt) override {
builder.notifyBlockInserted(block, previous, previousIt);
rewriter.notifyBlockInserted(block, previous, previousIt);
}
fir::FirOpBuilder &builder;
mlir::ConversionPatternRewriter &rewriter;
Expand Down
10 changes: 9 additions & 1 deletion mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,15 @@ class OpBuilder : public Builder {
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}

/// Notify the listener that the specified block was inserted.
virtual void notifyBlockCreated(Block *block) {}
///
/// * If the block was moved, then `previous` and `previousIt` are the
/// previous location of the block.
/// * If the block was unlinked before it was inserted, then `previous`
/// is "nullptr".
///
/// Note: Creating an (unlinked) block does not trigger this notification.
virtual void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) {}

protected:
Listener(Kind kind) : ListenerBase(kind) {}
Expand Down
9 changes: 5 additions & 4 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,9 @@ class RewriterBase : public OpBuilder {
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
listener->notifyOperationInserted(op, previous);
}
void notifyBlockCreated(Block *block) override {
listener->notifyBlockCreated(block);
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override {
listener->notifyBlockInserted(block, previous, previousIt);
}
void notifyBlockRemoved(Block *block) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
Expand Down Expand Up @@ -495,8 +496,8 @@ class RewriterBase : public OpBuilder {
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
/// of control to the region and passing it the correct block arguments.
virtual void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before);
void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);

/// Clone the blocks that belong to "region" before the given position in
Expand Down
8 changes: 2 additions & 6 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,8 @@ class ConversionPatternRewriter final : public PatternRewriter,
void eraseBlock(Block *block) override;

/// PatternRewriter hook creating a new block.
void notifyBlockCreated(Block *block) override;
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override;

/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;
Expand All @@ -723,11 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
ValueRange argValues = std::nullopt) override;
using PatternRewriter::inlineBlockBefore;

/// PatternRewriter hook for moving blocks out of a region.
void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before) override;
using PatternRewriter::inlineRegionBefore;

/// PatternRewriter hook for cloning blocks of one region into another. The
/// given region to clone *must* not have been modified as part of conversion
/// yet, i.e. it must be within an operation that is either in the process of
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
setInsertionPointToEnd(b);

if (listener)
listener->notifyBlockCreated(b);
listener->notifyBlockInserted(b, /*previous=*/nullptr, /*previousIt=*/{});
return b;
}

Expand Down
13 changes: 12 additions & 1 deletion mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,18 @@ Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
/// region and pass it the correct block arguments.
void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
Region::iterator before) {
parent.getBlocks().splice(before, region.getBlocks());
// Fast path: If no listener is attached, move all blocks at once.
if (!listener) {
parent.getBlocks().splice(before, region.getBlocks());
return;
}

// Move blocks from the beginning of the region one-by-one.
while (!region.empty()) {
Block *block = &region.front();
parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
listener->notifyBlockInserted(block, &region, region.begin());
}
}
void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
Expand Down
66 changes: 25 additions & 41 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,10 @@ enum class BlockActionKind {
};

/// Original position of the given block in its parent region. During undo
/// actions, the block needs to be placed after `insertAfterBlock`.
/// actions, the block needs to be placed before `insertBeforeBlock`.
struct BlockPosition {
Region *region;
Block *insertAfterBlock;
Block *insertBeforeBlock;
};

/// Information needed to undo inlining actions.
Expand Down Expand Up @@ -910,7 +910,8 @@ struct ConversionPatternRewriterImpl {
void notifyBlockIsBeingErased(Block *block);

/// Notifies that a block was created.
void notifyCreatedBlock(Block *block);
void notifyInsertedBlock(Block *block, Region *previous,
Region::iterator previousIt);

/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
Expand All @@ -919,10 +920,6 @@ struct ConversionPatternRewriterImpl {
void notifyBlockBeingInlined(Block *block, Block *srcBlock,
Block::iterator before);

/// Notifies that the blocks of a region are about to be moved.
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
Region::iterator before);

/// Notifies that a pattern match failed for the given reason.
LogicalResult
notifyMatchFailure(Location loc,
Expand Down Expand Up @@ -1173,10 +1170,9 @@ void ConversionPatternRewriterImpl::undoBlockActions(
// Put the block (owned by action) back into its original position.
case BlockActionKind::Erase: {
auto &blockList = action.originalPosition.region->getBlocks();
Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
blockList.insert((insertAfterBlock
? std::next(Region::iterator(insertAfterBlock))
: blockList.begin()),
Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
: blockList.end()),
action.block);
break;
}
Expand All @@ -1196,10 +1192,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
// Move the block back to its original position.
case BlockActionKind::Move: {
Region *originalRegion = action.originalPosition.region;
Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
originalRegion->getBlocks().splice(
(insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
: originalRegion->end()),
(insertBeforeBlock ? Region::iterator(insertBeforeBlock)
: originalRegion->end()),
action.block->getParent()->getBlocks(), action.block);
break;
}
Expand Down Expand Up @@ -1398,12 +1394,19 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,

void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
Region *region = block->getParent();
Block *origPrevBlock = block->getPrevNode();
blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
Block *origNextBlock = block->getNextNode();
blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
}

void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
blockActions.push_back(BlockAction::getCreate(block));
void ConversionPatternRewriterImpl::notifyInsertedBlock(
Block *block, Region *previous, Region::iterator previousIt) {
if (!previous) {
// This is a newly created block.
blockActions.push_back(BlockAction::getCreate(block));
return;
}
Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock}));
}

void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
Expand All @@ -1416,19 +1419,6 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
}

void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
Region &region, Region &parent, Region::iterator before) {
if (region.empty())
return;
Block *laterBlock = &region.back();
for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
blockActions.push_back(
BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
laterBlock = &earlierBlock;
}
blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
}

LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
Expand Down Expand Up @@ -1551,8 +1541,9 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
results);
}

void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
impl->notifyCreatedBlock(block);
void ConversionPatternRewriter::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
impl->notifyInsertedBlock(block, previous, previousIt);
}

Block *ConversionPatternRewriter::splitBlock(Block *block,
Expand Down Expand Up @@ -1582,13 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
eraseBlock(source);
}

void ConversionPatternRewriter::inlineRegionBefore(Region &region,
Region &parent,
Region::iterator before) {
impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
PatternRewriter::inlineRegionBefore(region, parent, before);
}

void ConversionPatternRewriter::cloneRegionBefore(Region &region,
Region &parent,
Region::iterator before,
Expand All @@ -1600,7 +1584,7 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,

for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
Block *cloned = mapping.lookup(&b);
impl->notifyCreatedBlock(cloned);
impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
}
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// simplifications.
void addOperandsToWorklist(ValueRange operands);

/// Notify the driver that the given block was created.
void notifyBlockCreated(Block *block) override;
/// Notify the driver that the given block was inserted.
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override;

/// Notify the driver that the given block is about to be removed.
void notifyBlockRemoved(Block *block) override;
Expand Down Expand Up @@ -638,9 +639,10 @@ void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
worklist.push(op);
}

void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
void GreedyPatternRewriteDriver::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
if (config.listener)
config.listener->notifyBlockCreated(block);
config.listener->notifyBlockInserted(block, previous, previousIt);
}

void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
Expand Down