Skip to content

Commit 26be7a8

Browse files
[mlir][IR] Trigger nested operation/block insertion notifications for clones
* Trigger "operation inserted" and "block inserted" notifications when cloning nested ops. (No such notifications were sent for nested ops when running `OpBuilder::clone(Operation *)` or `RewriterBase::cloneRegionBefore` so far.) * `cloneRegionBefore` is moved from `RewriterBase` to `OpBuilder`. (`cloneRegionBefore` just builds IR, it does not modify/erase ops.) * `cloneRegionBefore` is no longer virtual. It used to be virtual so that the dialect conversion can override it and update the internal state in the correct order (for nested ops). The internal state is updated based on listener notifications. Now that the listener notifications are sent for nested clones (and sent in the right order), this workaround in the dialect conversion is no longer needed. Details: Ops/blocks are notified in such an order in which they could have been created one-by-one with an `OpBuilder`. This ensures that when a listener is notified about an op being created, it was already notified about the defining ops of the operands (unless there is a block cycle/graph region). The implementation first clones an entire op (including nested ops) and then sends all notifications. Ideally, notifications should be interleaved with the cloning process, but that would require duplicating `Region::cloneInto` (with listener support). This commit is an incremental improvement over not sending any notifications. There is a "fast path" in the `clone` functions in case no listener is attached. No IR traversal is needed in that case. Imported from Phabricator: https://reviews.llvm.org/D146943 BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1 parent a2046ca commit 26be7a8

File tree

8 files changed

+366
-77
lines changed

8 files changed

+366
-77
lines changed

mlir/include/mlir/IR/Builders.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,16 @@ class OpBuilder : public Builder {
451451
Block *createBlock(Block *insertBefore, TypeRange argTypes = std::nullopt,
452452
ArrayRef<Location> locs = std::nullopt);
453453

454+
/// Clone the blocks that belong to "region" before the given position in
455+
/// another region "parent". The two regions must be different. The caller is
456+
/// responsible for creating or updating the operation transferring flow of
457+
/// control to the region and passing it the correct block arguments.
458+
void cloneRegionBefore(Region &region, Region &parent,
459+
Region::iterator before, IRMapping &mapping);
460+
void cloneRegionBefore(Region &region, Region &parent,
461+
Region::iterator before);
462+
void cloneRegionBefore(Region &region, Block *before);
463+
454464
//===--------------------------------------------------------------------===//
455465
// Operation Creation
456466
//===--------------------------------------------------------------------===//

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -489,16 +489,6 @@ class RewriterBase : public OpBuilder {
489489
Region::iterator before);
490490
void inlineRegionBefore(Region &region, Block *before);
491491

492-
/// Clone the blocks that belong to "region" before the given position in
493-
/// another region "parent". The two regions must be different. The caller is
494-
/// responsible for creating or updating the operation transferring flow of
495-
/// control to the region and passing it the correct block arguments.
496-
virtual void cloneRegionBefore(Region &region, Region &parent,
497-
Region::iterator before, IRMapping &mapping);
498-
void cloneRegionBefore(Region &region, Region &parent,
499-
Region::iterator before);
500-
void cloneRegionBefore(Region &region, Block *before);
501-
502492
/// This method replaces the uses of the results of `op` with the values in
503493
/// `newValues` when the provided `functor` returns true for a specific use.
504494
/// The number of values in `newValues` is required to match the number of

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -727,14 +727,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
727727
Region::iterator before) override;
728728
using PatternRewriter::inlineRegionBefore;
729729

730-
/// PatternRewriter hook for cloning blocks of one region into another. The
731-
/// given region to clone *must* not have been modified as part of conversion
732-
/// yet, i.e. it must be within an operation that is either in the process of
733-
/// conversion, or has not yet been converted.
734-
void cloneRegionBefore(Region &region, Region &parent,
735-
Region::iterator before, IRMapping &mapping) override;
736-
using PatternRewriter::cloneRegionBefore;
737-
738730
/// PatternRewriter hook for inserting a new operation.
739731
void notifyOperationInserted(Operation *op) override;
740732

mlir/lib/IR/Builders.cpp

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
#include "mlir/IR/IRMapping.h"
1515
#include "mlir/IR/IntegerSet.h"
1616
#include "mlir/IR/Matchers.h"
17+
#include "mlir/IR/RegionGraphTraits.h"
1718
#include "mlir/IR/SymbolTable.h"
19+
#include "llvm/ADT/PostOrderIterator.h"
20+
#include "llvm/ADT/SetVector.h"
21+
#include "llvm/ADT/SmallPtrSet.h"
1822
#include "llvm/ADT/SmallVectorExtras.h"
1923
#include "llvm/Support/raw_ostream.h"
2024

@@ -525,22 +529,95 @@ LogicalResult OpBuilder::tryFold(Operation *op,
525529
return success();
526530
}
527531

532+
static void notifyOperationCloned(Operation *op, OpBuilder::Listener *listener);
533+
534+
/// Notify the listener that the given range of regions was cloned.
535+
///
536+
/// Blocks and operations are enumerated in an order in which they could have
537+
/// been created separately (without using the `clone` API): Within a region,
538+
/// blocks are notified according to their successor relationship. Within a
539+
/// block, operations are notified in forward mode. This ensures that defining
540+
/// ops are notified before ops that use their results (unless there are
541+
/// cycles/graph regions).
542+
static void notifyRegionsCloned(iterator_range<Region::iterator> range,
543+
OpBuilder::Listener *listener) {
544+
// Maintain a set of blocks that were not notified yet. This is needed because
545+
// the inverse_post_order iterator does not enumerate dead blocks.
546+
llvm::SetVector<Block *> remainingBlocks;
547+
// The order in which the set is initialized does not matter for correctness.
548+
// For better performance, "leaf" blocks with no successors should be starting
549+
// point for the block traversal. (Then there are fewer iterations of the
550+
// "while" loop.)
551+
for (Block &b : llvm::reverse(range))
552+
remainingBlocks.insert(&b);
553+
// Set of visited blocks that is shared among all inverse_post_order
554+
// iterations. This is to avoid that the same block is enumerated multiple
555+
// times.
556+
llvm::SmallPtrSet<Block *, 4> visited;
557+
while (!remainingBlocks.empty()) {
558+
// Enumerate predecessors before successors. I.e., reverse post-order.
559+
for (Block *b :
560+
llvm::inverse_post_order_ext(remainingBlocks.front(), visited)) {
561+
auto it = llvm::find(remainingBlocks, b);
562+
assert(it != remainingBlocks.end() &&
563+
"expected that only remaining blocks are visited");
564+
listener->notifyBlockCreated(b);
565+
remainingBlocks.erase(it);
566+
for (Operation &op : *b)
567+
notifyOperationCloned(&op, listener);
568+
}
569+
}
570+
}
571+
572+
/// Notify the listener that the given op was cloned.
573+
static void notifyOperationCloned(Operation *op,
574+
OpBuilder::Listener *listener) {
575+
listener->notifyOperationInserted(op);
576+
for (Region &r : op->getRegions())
577+
notifyRegionsCloned(r.getBlocks(), listener);
578+
}
579+
528580
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
581+
// TODO: The listener notifications should be interleaved with `clone`.
529582
Operation *newOp = op.clone(mapper);
530-
// The `insert` call below handles the notification for inserting `newOp`
531-
// itself. But if `newOp` has any regions, we need to notify the listener
532-
// about any ops that got inserted inside those regions as part of cloning.
533-
if (listener) {
534-
auto walkFn = [&](Operation *walkedOp) {
535-
listener->notifyOperationInserted(walkedOp);
536-
};
537-
for (Region &region : newOp->getRegions())
538-
region.walk(walkFn);
539-
}
583+
584+
// Fast path: If no listener is attached, the op can be inserted directly.
585+
if (!listener)
586+
return insert(newOp);
587+
588+
// The `insert` call below handles the notification for inserting `newOp`.
589+
// Just notify about nested op/block insertion.
590+
for (Region &r : newOp->getRegions())
591+
notifyRegionsCloned(r.getBlocks(), listener);
540592
return insert(newOp);
541593
}
542594

543595
Operation *OpBuilder::clone(Operation &op) {
544596
IRMapping mapper;
545597
return clone(op, mapper);
546598
}
599+
600+
void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
601+
Region::iterator before, IRMapping &mapping) {
602+
// TODO: The listener notifications should be interleaved with `clone`.
603+
region.cloneInto(&parent, before, mapping);
604+
605+
// Fast path: If no listener is attached, there is no more work to do.
606+
if (!listener)
607+
return;
608+
609+
// Notify about op/block insertion.
610+
Region::iterator clonedBeginIt =
611+
mapping.lookup(&region.front())->getIterator();
612+
notifyRegionsCloned(llvm::make_range(clonedBeginIt, before), listener);
613+
}
614+
615+
void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
616+
Region::iterator before) {
617+
IRMapping mapping;
618+
cloneRegionBefore(region, parent, before, mapping);
619+
}
620+
621+
void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
622+
cloneRegionBefore(region, *before->getParent(), before->getIterator());
623+
}

mlir/lib/IR/PatternMatch.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -458,21 +458,3 @@ void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
458458
void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
459459
inlineRegionBefore(region, *before->getParent(), before->getIterator());
460460
}
461-
462-
/// Clone the blocks that belong to "region" before the given position in
463-
/// another region "parent". The two regions must be different. The caller is
464-
/// responsible for creating or updating the operation transferring flow of
465-
/// control to the region and passing it the correct block arguments.
466-
void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
467-
Region::iterator before,
468-
IRMapping &mapping) {
469-
region.cloneInto(&parent, before, mapping);
470-
}
471-
void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
472-
Region::iterator before) {
473-
IRMapping mapping;
474-
cloneRegionBefore(region, parent, before, mapping);
475-
}
476-
void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
477-
cloneRegionBefore(region, *before->getParent(), before->getIterator());
478-
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,23 +1588,6 @@ void ConversionPatternRewriter::inlineRegionBefore(Region &region,
15881588
PatternRewriter::inlineRegionBefore(region, parent, before);
15891589
}
15901590

1591-
void ConversionPatternRewriter::cloneRegionBefore(Region &region,
1592-
Region &parent,
1593-
Region::iterator before,
1594-
IRMapping &mapping) {
1595-
if (region.empty())
1596-
return;
1597-
1598-
PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1599-
1600-
for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
1601-
Block *cloned = mapping.lookup(&b);
1602-
impl->notifyCreatedBlock(cloned);
1603-
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
1604-
[&](Operation *op) { notifyOperationInserted(op); });
1605-
}
1606-
}
1607-
16081591
void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
16091592
LLVM_DEBUG({
16101593
impl->logger.startLine()

0 commit comments

Comments
 (0)