Skip to content

Commit ab78e8c

Browse files
[mlir][Transforms][NFC] Simplify BlockTypeConversionRewrite
1 parent 7e07450 commit ab78e8c

File tree

1 file changed

+40
-40
lines changed

1 file changed

+40
-40
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -746,24 +746,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
746746
/// block is returned containing the new arguments. Returns `block` if it did
747747
/// not require conversion.
748748
FailureOr<Block *> convertBlockSignature(
749-
Block *block, const TypeConverter *converter,
749+
ConversionPatternRewriter &rewriter, Block *block,
750+
const TypeConverter *converter,
750751
TypeConverter::SignatureConversion *conversion = nullptr);
751752

752753
/// Convert the types of non-entry block arguments within the given region.
753754
LogicalResult convertNonEntryRegionTypes(
754-
Region *region, const TypeConverter &converter,
755+
ConversionPatternRewriter &rewriter, Region *region,
756+
const TypeConverter &converter,
755757
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
756758

757759
/// Apply a signature conversion on the given region, using `converter` for
758760
/// materializations if not null.
759761
Block *
760-
applySignatureConversion(Region *region,
762+
applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
761763
TypeConverter::SignatureConversion &conversion,
762764
const TypeConverter *converter);
763765

764766
/// Convert the types of block arguments within the given region.
765767
FailureOr<Block *>
766-
convertRegionTypes(Region *region, const TypeConverter &converter,
768+
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
769+
const TypeConverter &converter,
767770
TypeConverter::SignatureConversion *entryConversion);
768771

769772
/// Apply the given signature conversion on the given block. The new block
@@ -773,7 +776,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
773776
/// translate between the origin argument types and those specified in the
774777
/// signature conversion.
775778
Block *applySignatureConversion(
776-
Block *block, const TypeConverter *converter,
779+
ConversionPatternRewriter &rewriter, Block *block,
780+
const TypeConverter *converter,
777781
TypeConverter::SignatureConversion &signatureConversion);
778782

779783
//===--------------------------------------------------------------------===//
@@ -940,24 +944,10 @@ void BlockTypeConversionRewrite::commit() {
940944
rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
941945
}
942946
}
943-
944-
assert(origBlock->empty() && "expected empty block");
945-
origBlock->dropAllDefinedValueUses();
946-
delete origBlock;
947-
origBlock = nullptr;
948947
}
949948

950949
void BlockTypeConversionRewrite::rollback() {
951-
// Drop all uses of the new block arguments and replace uses of the new block.
952-
for (int i = block->getNumArguments() - 1; i >= 0; --i)
953-
block->getArgument(i).dropAllUses();
954950
block->replaceAllUsesWith(origBlock);
955-
956-
// Move the operations back the original block, move the original block back
957-
// into its original location and the delete the new block.
958-
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
959-
block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
960-
eraseBlock(block);
961951
}
962952

963953
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
@@ -1173,10 +1163,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
11731163
// Type Conversion
11741164

11751165
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1176-
Block *block, const TypeConverter *converter,
1166+
ConversionPatternRewriter &rewriter, Block *block,
1167+
const TypeConverter *converter,
11771168
TypeConverter::SignatureConversion *conversion) {
11781169
if (conversion)
1179-
return applySignatureConversion(block, converter, *conversion);
1170+
return applySignatureConversion(rewriter, block, converter, *conversion);
11801171

11811172
// If a converter wasn't provided, and the block wasn't already converted,
11821173
// there is nothing we can do.
@@ -1185,35 +1176,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
11851176

11861177
// Try to convert the signature for the block with the provided converter.
11871178
if (auto conversion = converter->convertBlockSignature(block))
1188-
return applySignatureConversion(block, converter, *conversion);
1179+
return applySignatureConversion(rewriter, block, converter, *conversion);
11891180
return failure();
11901181
}
11911182

11921183
Block *ConversionPatternRewriterImpl::applySignatureConversion(
1193-
Region *region, TypeConverter::SignatureConversion &conversion,
1184+
ConversionPatternRewriter &rewriter, Region *region,
1185+
TypeConverter::SignatureConversion &conversion,
11941186
const TypeConverter *converter) {
11951187
if (!region->empty())
1196-
return *convertBlockSignature(&region->front(), converter, &conversion);
1188+
return *convertBlockSignature(rewriter, &region->front(), converter,
1189+
&conversion);
11971190
return nullptr;
11981191
}
11991192

12001193
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1201-
Region *region, const TypeConverter &converter,
1194+
ConversionPatternRewriter &rewriter, Region *region,
1195+
const TypeConverter &converter,
12021196
TypeConverter::SignatureConversion *entryConversion) {
12031197
regionToConverter[region] = &converter;
12041198
if (region->empty())
12051199
return nullptr;
12061200

1207-
if (failed(convertNonEntryRegionTypes(region, converter)))
1201+
if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
12081202
return failure();
12091203

1210-
FailureOr<Block *> newEntry =
1211-
convertBlockSignature(&region->front(), &converter, entryConversion);
1204+
FailureOr<Block *> newEntry = convertBlockSignature(
1205+
rewriter, &region->front(), &converter, entryConversion);
12121206
return newEntry;
12131207
}
12141208

12151209
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
1216-
Region *region, const TypeConverter &converter,
1210+
ConversionPatternRewriter &rewriter, Region *region,
1211+
const TypeConverter &converter,
12171212
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
12181213
regionToConverter[region] = &converter;
12191214
if (region->empty())
@@ -1234,16 +1229,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
12341229
: const_cast<TypeConverter::SignatureConversion *>(
12351230
&blockConversions[blockIdx++]);
12361231

1237-
if (failed(convertBlockSignature(&block, &converter, blockConversion)))
1232+
if (failed(convertBlockSignature(rewriter, &block, &converter,
1233+
blockConversion)))
12381234
return failure();
12391235
}
12401236
return success();
12411237
}
12421238

12431239
Block *ConversionPatternRewriterImpl::applySignatureConversion(
1244-
Block *block, const TypeConverter *converter,
1240+
ConversionPatternRewriter &rewriter, Block *block,
1241+
const TypeConverter *converter,
12451242
TypeConverter::SignatureConversion &signatureConversion) {
1246-
MLIRContext *ctx = eraseRewriter.getContext();
1243+
MLIRContext *ctx = rewriter.getContext();
12471244

12481245
// If no arguments are being changed or added, there is nothing to do.
12491246
unsigned origArgCount = block->getNumArguments();
@@ -1253,11 +1250,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12531250

12541251
// Split the block at the beginning to get a new block to use for the updated
12551252
// signature.
1256-
Block *newBlock = block->splitBlock(block->begin());
1253+
Block *newBlock = rewriter.splitBlock(block, block->begin());
12571254
block->replaceAllUsesWith(newBlock);
1258-
// Unlink the block, but do not erase it yet, so that the change can be rolled
1259-
// back.
1260-
block->getParent()->getBlocks().remove(block);
12611255

12621256
// Map all new arguments to the location of the argument they originate from.
12631257
SmallVector<Location> newLocs(convertedTypes.size(),
@@ -1333,6 +1327,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13331327

13341328
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
13351329
converter);
1330+
1331+
// Erase the old block. (It is just unlinked for now and will be erased during
1332+
// cleanup.)
1333+
rewriter.eraseBlock(block);
1334+
13361335
return newBlock;
13371336
}
13381337

@@ -1531,7 +1530,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
15311530
assert(!impl->wasOpReplaced(region->getParentOp()) &&
15321531
"attempting to apply a signature conversion to a block within a "
15331532
"replaced/erased op");
1534-
return impl->applySignatureConversion(region, conversion, converter);
1533+
return impl->applySignatureConversion(*this, region, conversion, converter);
15351534
}
15361535

15371536
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1540,7 +1539,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
15401539
assert(!impl->wasOpReplaced(region->getParentOp()) &&
15411540
"attempting to apply a signature conversion to a block within a "
15421541
"replaced/erased op");
1543-
return impl->convertRegionTypes(region, converter, entryConversion);
1542+
return impl->convertRegionTypes(*this, region, converter, entryConversion);
15441543
}
15451544

15461545
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
@@ -1549,7 +1548,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
15491548
assert(!impl->wasOpReplaced(region->getParentOp()) &&
15501549
"attempting to apply a signature conversion to a block within a "
15511550
"replaced/erased op");
1552-
return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
1551+
return impl->convertNonEntryRegionTypes(*this, region, converter,
1552+
blockConversions);
15531553
}
15541554

15551555
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2051,7 +2051,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
20512051
// If the region of the block has a type converter, try to convert the block
20522052
// directly.
20532053
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2054-
if (failed(impl.convertBlockSignature(block, converter))) {
2054+
if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
20552055
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
20562056
"block"));
20572057
return failure();

0 commit comments

Comments
 (0)