Skip to content

Commit 3216b26

Browse files
[mlir][Transforms] Keep track of nested ignored/replaced ops
The dialect conversion maintains sets of "ignored" and "replaced" ops. This change simplifies the two sets, such that all nested ops are included. (This was previously not the case and sometimes only the parent op was included.) This change allows for more aggressive assertions to prevent incorrect rewriter API usage. E.g., accessing ops/blocks/regions within an erased op. A concrete example: I have seen conversion patterns in downstream projects where an op is replaced with a new op, and the region of the old op is afterwards inlined into the newly created op. This is invalid rewriter API usage: ops that were replaced/erased should not be accessed. Nested ops will be considered "ignored", even if they are moved to a different region after the region's parent op was erased (which is illegal API usage). Instead, create a new op, inline the regions, then replace the old op with the new op. BEGIN_PUBLIC No commit message needed for presubmit. END_PUBLIC
1 parent 6a884a9 commit 3216b26

File tree

2 files changed

+52
-39
lines changed

2 files changed

+52
-39
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -798,13 +798,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
798798
PatternRewriter &rewriter, ValueRange values,
799799
SmallVectorImpl<Value> &remapped);
800800

801-
/// Returns true if the given operation is ignored, and does not need to be
801+
/// Return "true" if the given operation is ignored, and does not need to be
802802
/// converted.
803803
bool isOpIgnored(Operation *op) const;
804804

805-
/// Recursively marks the nested operations under 'op' as ignored. This
806-
/// removes them from being considered for legalization.
807-
void markNestedOpsIgnored(Operation *op);
805+
/// Return "true" if the given operation was replaced or erased.
806+
bool wasOpReplaced(Operation *op) const;
808807

809808
//===--------------------------------------------------------------------===//
810809
// Type Conversion
@@ -946,18 +945,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
946945
/// Ordered list of block operations (creations, splits, motions).
947946
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
948947

949-
/// A set of operations that should no longer be considered for legalization,
950-
/// but were not directly replace/erased/etc. by a pattern. These are
951-
/// generally child operations of other operations who were
952-
/// replaced/erased/etc. This is not meant to be an exhaustive list of all
953-
/// operations, but the minimal set that can be used to detect if a given
954-
/// operation should be `ignored`. For example, we may add the operations that
955-
/// define non-empty regions to the set, but not any of the others. This
956-
/// simplifies the amount of memory needed as we can query if the parent
957-
/// operation was ignored.
948+
/// A set of operations that should no longer be considered for legalization.
949+
/// E.g., ops that are recursively legal. Ops that were replaced/erased are
950+
/// tracked separately.
958951
SetVector<Operation *> ignoredOps;
959952

960-
// A set of operations that were erased.
953+
/// A set of operations that were replaced/erased. Such ops are not erased
954+
/// immediately but only when the dialect conversion succeeds. In the mean
955+
/// time, they should no longer be considered for legalization and any attempt
956+
/// to modify/access them is invalid rewriter API usage.
961957
SetVector<Operation *> replacedOps;
962958

963959
/// The current type converter, or nullptr if no type converter is currently
@@ -1237,24 +1233,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12371233
return success();
12381234
}
12391235

1240-
// TODO: This function is a misnomer. It does not actually check if `op` is in
1241-
// `ignoredOps`.
12421236
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1243-
// Check to see if this operation or the parent operation is ignored.
1244-
return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
1237+
// Check to see if this operation is ignored or was replaced.
1238+
return replacedOps.count(op) || ignoredOps.count(op);
12451239
}
12461240

1247-
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1248-
// Walk this operation and collect nested operations that define non-empty
1249-
// regions. We mark such operations as 'ignored' so that we know we don't have
1250-
// to convert them, or their nested ops.
1251-
if (op->getNumRegions() == 0)
1252-
return;
1253-
op->walk([&](Operation *op) {
1254-
if (llvm::any_of(op->getRegions(),
1255-
[](Region &region) { return !region.empty(); }))
1256-
ignoredOps.insert(op);
1257-
});
1241+
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
1242+
// Check to see if this operation was replaced.
1243+
return replacedOps.count(op);
12581244
}
12591245

12601246
//===----------------------------------------------------------------------===//
@@ -1476,6 +1462,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14761462
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
14771463
<< ")\n";
14781464
});
1465+
assert(!wasOpReplaced(op) &&
1466+
"attempting to insert into a block within a replaced/erased op");
1467+
14791468
if (!previous.isSet()) {
14801469
// This is a newly created op.
14811470
appendRewrite<CreateOperationRewrite>(op);
@@ -1490,7 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14901479
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
14911480
ValueRange newValues) {
14921481
assert(newValues.size() == op->getNumResults());
1493-
assert(!replacedOps.contains(op) && "operation was already replaced");
1482+
assert(!ignoredOps.contains(op) && "operation was already replaced");
14941483

14951484
// Track if any of the results changed, e.g. erased and replaced with null.
14961485
bool resultChanged = false;
@@ -1509,10 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
15091498
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
15101499
resultChanged);
15111500

1512-
// Mark this operation as recursively ignored so that we don't need to
1513-
// convert any nested operations.
1514-
replacedOps.insert(op);
1515-
markNestedOpsIgnored(op);
1501+
// Mark this operation and all nested ops as replaced.
1502+
op->walk([&](Operation *op) { replacedOps.insert(op); });
15161503
}
15171504

15181505
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
@@ -1604,6 +1591,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
16041591
}
16051592

16061593
void ConversionPatternRewriter::eraseBlock(Block *block) {
1594+
assert(!impl->wasOpReplaced(block->getParentOp()) &&
1595+
"attempting to erase a block within a replaced/erased op");
1596+
16071597
// Mark all ops for erasure.
16081598
for (Operation &op : *block)
16091599
eraseOp(&op);
@@ -1619,18 +1609,27 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
16191609
Block *ConversionPatternRewriter::applySignatureConversion(
16201610
Region *region, TypeConverter::SignatureConversion &conversion,
16211611
const TypeConverter *converter) {
1612+
assert(!impl->wasOpReplaced(region->getParentOp()) &&
1613+
"attempting to apply a signature conversion to a block within a "
1614+
"replaced/erased op");
16221615
return impl->applySignatureConversion(region, conversion, converter);
16231616
}
16241617

16251618
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
16261619
Region *region, const TypeConverter &converter,
16271620
TypeConverter::SignatureConversion *entryConversion) {
1621+
assert(!impl->wasOpReplaced(region->getParentOp()) &&
1622+
"attempting to apply a signature conversion to a block within a "
1623+
"replaced/erased op");
16281624
return impl->convertRegionTypes(region, converter, entryConversion);
16291625
}
16301626

16311627
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
16321628
Region *region, const TypeConverter &converter,
16331629
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1630+
assert(!impl->wasOpReplaced(region->getParentOp()) &&
1631+
"attempting to apply a signature conversion to a block within a "
1632+
"replaced/erased op");
16341633
return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
16351634
}
16361635

@@ -1665,6 +1664,8 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
16651664

16661665
Block *ConversionPatternRewriter::splitBlock(Block *block,
16671666
Block::iterator before) {
1667+
assert(!impl->wasOpReplaced(block->getParentOp()) &&
1668+
"attempting to split a block within a replaced/erased op");
16681669
auto *continuation = block->splitBlock(before);
16691670
impl->notifySplitBlock(block, continuation);
16701671
return continuation;
@@ -1673,15 +1674,19 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
16731674
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
16741675
Block::iterator before,
16751676
ValueRange argValues) {
1677+
#ifndef NDEBUG
16761678
assert(argValues.size() == source->getNumArguments() &&
16771679
"incorrect # of argument replacement values");
1678-
#ifndef NDEBUG
1680+
assert(!impl->wasOpReplaced(source->getParentOp()) &&
1681+
"attempting to inline a block from a replaced/erased op");
1682+
assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1683+
"attempting to inline a block into a replaced/erased op");
16791684
auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1680-
#endif // NDEBUG
16811685
// The source block will be deleted, so it should not have any users (i.e.,
16821686
// there should be no predecessors).
16831687
assert(llvm::all_of(source->getUsers(), opIgnored) &&
16841688
"expected 'source' to have no predecessors");
1689+
#endif // NDEBUG
16851690

16861691
impl->notifyBlockBeingInlined(dest, source, before);
16871692
for (auto it : llvm::zip(source->getArguments(), argValues))
@@ -1691,13 +1696,17 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
16911696
}
16921697

16931698
void ConversionPatternRewriter::startOpModification(Operation *op) {
1699+
assert(!impl->wasOpReplaced(op) &&
1700+
"attempting to modify a replaced/erased op");
16941701
#ifndef NDEBUG
16951702
impl->pendingRootUpdates.insert(op);
16961703
#endif
16971704
impl->appendRewrite<ModifyOperationRewrite>(op);
16981705
}
16991706

17001707
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
1708+
assert(!impl->wasOpReplaced(op) &&
1709+
"attempting to modify a replaced/erased op");
17011710
PatternRewriter::finalizeOpModification(op);
17021711
// There is nothing to do here, we only need to track the operation at the
17031712
// start of the update.
@@ -1912,8 +1921,13 @@ OperationLegalizer::legalize(Operation *op,
19121921

19131922
// If this operation is recursively legal, mark its children as ignored so
19141923
// that we don't consider them for legalization.
1915-
if (legalityInfo->isRecursivelyLegal)
1916-
rewriter.getImpl().markNestedOpsIgnored(op);
1924+
if (legalityInfo->isRecursivelyLegal) {
1925+
op->walk([&](Operation *nested) {
1926+
if (op != nested)
1927+
rewriter.getImpl().ignoredOps.insert(nested);
1928+
});
1929+
}
1930+
19171931
return success();
19181932
}
19191933

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1768,7 +1768,6 @@ struct TestMergeSingleBlockOps
17681768
rewriter.inlineBlockBefore(&innerBlock, op);
17691769
rewriter.eraseOp(innerTerminator);
17701770
rewriter.eraseOp(op);
1771-
rewriter.modifyOpInPlace(op, [] {});
17721771
return success();
17731772
}
17741773
};

0 commit comments

Comments
 (0)