Skip to content

Commit aaf5c81

Browse files
[mlir][Transforms][NFC] Simplify BlockTypeConversionRewrite (#83286)
When a block signature is converted during dialect conversion, a `BlockTypeConversionRewrite` object is stored in the stack of rewrites. Such an object represents multiple steps: - Splitting the old block, i.e., creating a new block and moving all operations over. - Rewriting block arguments. - Erasing the old block. We have dedicated `IRRewrite` objects that represent "creating a block", "moving an op" and "erasing a block". This commit reuses those rewrite objects, so that there is less work to do in `BlockTypeConversionRewrite::rollback` and `BlockTypeConversionRewrite::commit`/`cleanup`. Note: This change is in preparation of adding listener support to the dialect conversion. The less work is done in a `commit` function, the fewer notifications will have to be sent.
1 parent da5966e commit aaf5c81

File tree

1 file changed

+40
-44
lines changed

1 file changed

+40
-44
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
441441

442442
void commit() override;
443443

444-
void cleanup() override;
445-
446444
void rollback() override;
447445

448446
private:
@@ -791,24 +789,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
791789
/// block is returned containing the new arguments. Returns `block` if it did
792790
/// not require conversion.
793791
FailureOr<Block *> convertBlockSignature(
794-
Block *block, const TypeConverter *converter,
792+
ConversionPatternRewriter &rewriter, Block *block,
793+
const TypeConverter *converter,
795794
TypeConverter::SignatureConversion *conversion = nullptr);
796795

797796
/// Convert the types of non-entry block arguments within the given region.
798797
LogicalResult convertNonEntryRegionTypes(
799-
Region *region, const TypeConverter &converter,
798+
ConversionPatternRewriter &rewriter, Region *region,
799+
const TypeConverter &converter,
800800
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
801801

802802
/// Apply a signature conversion on the given region, using `converter` for
803803
/// materializations if not null.
804804
Block *
805-
applySignatureConversion(Region *region,
805+
applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
806806
TypeConverter::SignatureConversion &conversion,
807807
const TypeConverter *converter);
808808

809809
/// Convert the types of block arguments within the given region.
810810
FailureOr<Block *>
811-
convertRegionTypes(Region *region, const TypeConverter &converter,
811+
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
812+
const TypeConverter &converter,
812813
TypeConverter::SignatureConversion *entryConversion);
813814

814815
/// Apply the given signature conversion on the given block. The new block
@@ -818,7 +819,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
818819
/// translate between the origin argument types and those specified in the
819820
/// signature conversion.
820821
Block *applySignatureConversion(
821-
Block *block, const TypeConverter *converter,
822+
ConversionPatternRewriter &rewriter, Block *block,
823+
const TypeConverter *converter,
822824
TypeConverter::SignatureConversion &signatureConversion);
823825

824826
//===--------------------------------------------------------------------===//
@@ -991,24 +993,8 @@ void BlockTypeConversionRewrite::commit() {
991993
}
992994
}
993995

994-
void BlockTypeConversionRewrite::cleanup() {
995-
assert(origBlock->empty() && "expected empty block");
996-
origBlock->dropAllDefinedValueUses();
997-
delete origBlock;
998-
origBlock = nullptr;
999-
}
1000-
1001996
void BlockTypeConversionRewrite::rollback() {
1002-
// Drop all uses of the new block arguments and replace uses of the new block.
1003-
for (int i = block->getNumArguments() - 1; i >= 0; --i)
1004-
block->getArgument(i).dropAllUses();
1005997
block->replaceAllUsesWith(origBlock);
1006-
1007-
// Move the operations back the original block, move the original block back
1008-
// into its original location and the delete the new block.
1009-
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
1010-
block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
1011-
eraseBlock(block);
1012998
}
1013999

10141000
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
@@ -1224,10 +1210,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
12241210
// Type Conversion
12251211

12261212
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1227-
Block *block, const TypeConverter *converter,
1213+
ConversionPatternRewriter &rewriter, Block *block,
1214+
const TypeConverter *converter,
12281215
TypeConverter::SignatureConversion *conversion) {
12291216
if (conversion)
1230-
return applySignatureConversion(block, converter, *conversion);
1217+
return applySignatureConversion(rewriter, block, converter, *conversion);
12311218

12321219
// If a converter wasn't provided, and the block wasn't already converted,
12331220
// there is nothing we can do.
@@ -1236,35 +1223,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
12361223

12371224
// Try to convert the signature for the block with the provided converter.
12381225
if (auto conversion = converter->convertBlockSignature(block))
1239-
return applySignatureConversion(block, converter, *conversion);
1226+
return applySignatureConversion(rewriter, block, converter, *conversion);
12401227
return failure();
12411228
}
12421229

12431230
Block *ConversionPatternRewriterImpl::applySignatureConversion(
1244-
Region *region, TypeConverter::SignatureConversion &conversion,
1231+
ConversionPatternRewriter &rewriter, Region *region,
1232+
TypeConverter::SignatureConversion &conversion,
12451233
const TypeConverter *converter) {
12461234
if (!region->empty())
1247-
return *convertBlockSignature(&region->front(), converter, &conversion);
1235+
return *convertBlockSignature(rewriter, &region->front(), converter,
1236+
&conversion);
12481237
return nullptr;
12491238
}
12501239

12511240
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1252-
Region *region, const TypeConverter &converter,
1241+
ConversionPatternRewriter &rewriter, Region *region,
1242+
const TypeConverter &converter,
12531243
TypeConverter::SignatureConversion *entryConversion) {
12541244
regionToConverter[region] = &converter;
12551245
if (region->empty())
12561246
return nullptr;
12571247

1258-
if (failed(convertNonEntryRegionTypes(region, converter)))
1248+
if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
12591249
return failure();
12601250

1261-
FailureOr<Block *> newEntry =
1262-
convertBlockSignature(&region->front(), &converter, entryConversion);
1251+
FailureOr<Block *> newEntry = convertBlockSignature(
1252+
rewriter, &region->front(), &converter, entryConversion);
12631253
return newEntry;
12641254
}
12651255

12661256
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
1267-
Region *region, const TypeConverter &converter,
1257+
ConversionPatternRewriter &rewriter, Region *region,
1258+
const TypeConverter &converter,
12681259
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
12691260
regionToConverter[region] = &converter;
12701261
if (region->empty())
@@ -1285,16 +1276,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
12851276
: const_cast<TypeConverter::SignatureConversion *>(
12861277
&blockConversions[blockIdx++]);
12871278

1288-
if (failed(convertBlockSignature(&block, &converter, blockConversion)))
1279+
if (failed(convertBlockSignature(rewriter, &block, &converter,
1280+
blockConversion)))
12891281
return failure();
12901282
}
12911283
return success();
12921284
}
12931285

12941286
Block *ConversionPatternRewriterImpl::applySignatureConversion(
1295-
Block *block, const TypeConverter *converter,
1287+
ConversionPatternRewriter &rewriter, Block *block,
1288+
const TypeConverter *converter,
12961289
TypeConverter::SignatureConversion &signatureConversion) {
1297-
MLIRContext *ctx = eraseRewriter.getContext();
1290+
MLIRContext *ctx = rewriter.getContext();
12981291

12991292
// If no arguments are being changed or added, there is nothing to do.
13001293
unsigned origArgCount = block->getNumArguments();
@@ -1304,11 +1297,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13041297

13051298
// Split the block at the beginning to get a new block to use for the updated
13061299
// signature.
1307-
Block *newBlock = block->splitBlock(block->begin());
1300+
Block *newBlock = rewriter.splitBlock(block, block->begin());
13081301
block->replaceAllUsesWith(newBlock);
1309-
// Unlink the block, but do not erase it yet, so that the change can be rolled
1310-
// back.
1311-
block->getParent()->getBlocks().remove(block);
13121302

13131303
// Map all new arguments to the location of the argument they originate from.
13141304
SmallVector<Location> newLocs(convertedTypes.size(),
@@ -1384,6 +1374,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13841374

13851375
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
13861376
converter);
1377+
1378+
// Erase the old block. (It is just unlinked for now and will be erased during
1379+
// cleanup.)
1380+
rewriter.eraseBlock(block);
1381+
13871382
return newBlock;
13881383
}
13891384

@@ -1592,7 +1587,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
15921587
assert(!impl->wasOpReplaced(region->getParentOp()) &&
15931588
"attempting to apply a signature conversion to a block within a "
15941589
"replaced/erased op");
1595-
return impl->applySignatureConversion(region, conversion, converter);
1590+
return impl->applySignatureConversion(*this, region, conversion, converter);
15961591
}
15971592

15981593
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1601,7 +1596,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
16011596
assert(!impl->wasOpReplaced(region->getParentOp()) &&
16021597
"attempting to apply a signature conversion to a block within a "
16031598
"replaced/erased op");
1604-
return impl->convertRegionTypes(region, converter, entryConversion);
1599+
return impl->convertRegionTypes(*this, region, converter, entryConversion);
16051600
}
16061601

16071602
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
@@ -1610,7 +1605,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
16101605
assert(!impl->wasOpReplaced(region->getParentOp()) &&
16111606
"attempting to apply a signature conversion to a block within a "
16121607
"replaced/erased op");
1613-
return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
1608+
return impl->convertNonEntryRegionTypes(*this, region, converter,
1609+
blockConversions);
16141610
}
16151611

16161612
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2104,7 +2100,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21042100
// If the region of the block has a type converter, try to convert the block
21052101
// directly.
21062102
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2107-
if (failed(impl.convertBlockSignature(block, converter))) {
2103+
if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
21082104
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
21092105
"block"));
21102106
return failure();

0 commit comments

Comments
 (0)