@@ -441,8 +441,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
441
441
442
442
void commit () override ;
443
443
444
- void cleanup () override ;
445
-
446
444
void rollback () override ;
447
445
448
446
private:
@@ -791,24 +789,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
791
789
// / block is returned containing the new arguments. Returns `block` if it did
792
790
// / not require conversion.
793
791
FailureOr<Block *> convertBlockSignature (
794
- Block *block, const TypeConverter *converter,
792
+ ConversionPatternRewriter &rewriter, Block *block,
793
+ const TypeConverter *converter,
795
794
TypeConverter::SignatureConversion *conversion = nullptr );
796
795
797
796
// / Convert the types of non-entry block arguments within the given region.
798
797
LogicalResult convertNonEntryRegionTypes (
799
- Region *region, const TypeConverter &converter,
798
+ ConversionPatternRewriter &rewriter, Region *region,
799
+ const TypeConverter &converter,
800
800
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
801
801
802
802
// / Apply a signature conversion on the given region, using `converter` for
803
803
// / materializations if not null.
804
804
Block *
805
- applySignatureConversion (Region *region,
805
+ applySignatureConversion (ConversionPatternRewriter &rewriter, Region *region,
806
806
TypeConverter::SignatureConversion &conversion,
807
807
const TypeConverter *converter);
808
808
809
809
// / Convert the types of block arguments within the given region.
810
810
FailureOr<Block *>
811
- convertRegionTypes (Region *region, const TypeConverter &converter,
811
+ convertRegionTypes (ConversionPatternRewriter &rewriter, Region *region,
812
+ const TypeConverter &converter,
812
813
TypeConverter::SignatureConversion *entryConversion);
813
814
814
815
// / Apply the given signature conversion on the given block. The new block
@@ -818,7 +819,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
818
819
// / translate between the origin argument types and those specified in the
819
820
// / signature conversion.
820
821
Block *applySignatureConversion (
821
- Block *block, const TypeConverter *converter,
822
+ ConversionPatternRewriter &rewriter, Block *block,
823
+ const TypeConverter *converter,
822
824
TypeConverter::SignatureConversion &signatureConversion);
823
825
824
826
// ===--------------------------------------------------------------------===//
@@ -991,24 +993,8 @@ void BlockTypeConversionRewrite::commit() {
991
993
}
992
994
}
993
995
994
- void BlockTypeConversionRewrite::cleanup () {
995
- assert (origBlock->empty () && " expected empty block" );
996
- origBlock->dropAllDefinedValueUses ();
997
- delete origBlock;
998
- origBlock = nullptr ;
999
- }
1000
-
1001
996
void BlockTypeConversionRewrite::rollback () {
1002
- // Drop all uses of the new block arguments and replace uses of the new block.
1003
- for (int i = block->getNumArguments () - 1 ; i >= 0 ; --i)
1004
- block->getArgument (i).dropAllUses ();
1005
997
block->replaceAllUsesWith (origBlock);
1006
-
1007
- // Move the operations back the original block, move the original block back
1008
- // into its original location and the delete the new block.
1009
- origBlock->getOperations ().splice (origBlock->end (), block->getOperations ());
1010
- block->getParent ()->getBlocks ().insert (Region::iterator (block), origBlock);
1011
- eraseBlock (block);
1012
998
}
1013
999
1014
1000
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions (
@@ -1224,10 +1210,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
1224
1210
// Type Conversion
1225
1211
1226
1212
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature (
1227
- Block *block, const TypeConverter *converter,
1213
+ ConversionPatternRewriter &rewriter, Block *block,
1214
+ const TypeConverter *converter,
1228
1215
TypeConverter::SignatureConversion *conversion) {
1229
1216
if (conversion)
1230
- return applySignatureConversion (block, converter, *conversion);
1217
+ return applySignatureConversion (rewriter, block, converter, *conversion);
1231
1218
1232
1219
// If a converter wasn't provided, and the block wasn't already converted,
1233
1220
// there is nothing we can do.
@@ -1236,35 +1223,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1236
1223
1237
1224
// Try to convert the signature for the block with the provided converter.
1238
1225
if (auto conversion = converter->convertBlockSignature (block))
1239
- return applySignatureConversion (block, converter, *conversion);
1226
+ return applySignatureConversion (rewriter, block, converter, *conversion);
1240
1227
return failure ();
1241
1228
}
1242
1229
1243
1230
Block *ConversionPatternRewriterImpl::applySignatureConversion (
1244
- Region *region, TypeConverter::SignatureConversion &conversion,
1231
+ ConversionPatternRewriter &rewriter, Region *region,
1232
+ TypeConverter::SignatureConversion &conversion,
1245
1233
const TypeConverter *converter) {
1246
1234
if (!region->empty ())
1247
- return *convertBlockSignature (®ion->front (), converter, &conversion);
1235
+ return *convertBlockSignature (rewriter, ®ion->front (), converter,
1236
+ &conversion);
1248
1237
return nullptr ;
1249
1238
}
1250
1239
1251
1240
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes (
1252
- Region *region, const TypeConverter &converter,
1241
+ ConversionPatternRewriter &rewriter, Region *region,
1242
+ const TypeConverter &converter,
1253
1243
TypeConverter::SignatureConversion *entryConversion) {
1254
1244
regionToConverter[region] = &converter;
1255
1245
if (region->empty ())
1256
1246
return nullptr ;
1257
1247
1258
- if (failed (convertNonEntryRegionTypes (region, converter)))
1248
+ if (failed (convertNonEntryRegionTypes (rewriter, region, converter)))
1259
1249
return failure ();
1260
1250
1261
- FailureOr<Block *> newEntry =
1262
- convertBlockSignature ( ®ion->front (), &converter, entryConversion);
1251
+ FailureOr<Block *> newEntry = convertBlockSignature (
1252
+ rewriter, ®ion->front (), &converter, entryConversion);
1263
1253
return newEntry;
1264
1254
}
1265
1255
1266
1256
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes (
1267
- Region *region, const TypeConverter &converter,
1257
+ ConversionPatternRewriter &rewriter, Region *region,
1258
+ const TypeConverter &converter,
1268
1259
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1269
1260
regionToConverter[region] = &converter;
1270
1261
if (region->empty ())
@@ -1285,16 +1276,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
1285
1276
: const_cast <TypeConverter::SignatureConversion *>(
1286
1277
&blockConversions[blockIdx++]);
1287
1278
1288
- if (failed (convertBlockSignature (&block, &converter, blockConversion)))
1279
+ if (failed (convertBlockSignature (rewriter, &block, &converter,
1280
+ blockConversion)))
1289
1281
return failure ();
1290
1282
}
1291
1283
return success ();
1292
1284
}
1293
1285
1294
1286
Block *ConversionPatternRewriterImpl::applySignatureConversion (
1295
- Block *block, const TypeConverter *converter,
1287
+ ConversionPatternRewriter &rewriter, Block *block,
1288
+ const TypeConverter *converter,
1296
1289
TypeConverter::SignatureConversion &signatureConversion) {
1297
- MLIRContext *ctx = eraseRewriter .getContext ();
1290
+ MLIRContext *ctx = rewriter .getContext ();
1298
1291
1299
1292
// If no arguments are being changed or added, there is nothing to do.
1300
1293
unsigned origArgCount = block->getNumArguments ();
@@ -1304,11 +1297,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1304
1297
1305
1298
// Split the block at the beginning to get a new block to use for the updated
1306
1299
// signature.
1307
- Block *newBlock = block-> splitBlock (block->begin ());
1300
+ Block *newBlock = rewriter. splitBlock (block, block->begin ());
1308
1301
block->replaceAllUsesWith (newBlock);
1309
- // Unlink the block, but do not erase it yet, so that the change can be rolled
1310
- // back.
1311
- block->getParent ()->getBlocks ().remove (block);
1312
1302
1313
1303
// Map all new arguments to the location of the argument they originate from.
1314
1304
SmallVector<Location> newLocs (convertedTypes.size (),
@@ -1384,6 +1374,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1384
1374
1385
1375
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1386
1376
converter);
1377
+
1378
+ // Erase the old block. (It is just unlinked for now and will be erased during
1379
+ // cleanup.)
1380
+ rewriter.eraseBlock (block);
1381
+
1387
1382
return newBlock;
1388
1383
}
1389
1384
@@ -1592,7 +1587,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
1592
1587
assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1593
1588
" attempting to apply a signature conversion to a block within a "
1594
1589
" replaced/erased op" );
1595
- return impl->applySignatureConversion (region, conversion, converter);
1590
+ return impl->applySignatureConversion (* this , region, conversion, converter);
1596
1591
}
1597
1592
1598
1593
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes (
@@ -1601,7 +1596,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1601
1596
assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1602
1597
" attempting to apply a signature conversion to a block within a "
1603
1598
" replaced/erased op" );
1604
- return impl->convertRegionTypes (region, converter, entryConversion);
1599
+ return impl->convertRegionTypes (* this , region, converter, entryConversion);
1605
1600
}
1606
1601
1607
1602
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes (
@@ -1610,7 +1605,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
1610
1605
assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1611
1606
" attempting to apply a signature conversion to a block within a "
1612
1607
" replaced/erased op" );
1613
- return impl->convertNonEntryRegionTypes (region, converter, blockConversions);
1608
+ return impl->convertNonEntryRegionTypes (*this , region, converter,
1609
+ blockConversions);
1614
1610
}
1615
1611
1616
1612
void ConversionPatternRewriter::replaceUsesOfBlockArgument (BlockArgument from,
@@ -2104,7 +2100,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2104
2100
// If the region of the block has a type converter, try to convert the block
2105
2101
// directly.
2106
2102
if (auto *converter = impl.regionToConverter .lookup (block->getParent ())) {
2107
- if (failed (impl.convertBlockSignature (block, converter))) {
2103
+ if (failed (impl.convertBlockSignature (rewriter, block, converter))) {
2108
2104
LLVM_DEBUG (logFailure (impl.logger , " failed to convert types of moved "
2109
2105
" block" ));
2110
2106
return failure ();
0 commit comments