@@ -746,24 +746,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
746
746
// / block is returned containing the new arguments. Returns `block` if it did
747
747
// / not require conversion.
748
748
FailureOr<Block *> convertBlockSignature (
749
- Block *block, const TypeConverter *converter,
749
+ ConversionPatternRewriter &rewriter, Block *block,
750
+ const TypeConverter *converter,
750
751
TypeConverter::SignatureConversion *conversion = nullptr );
751
752
752
753
// / Convert the types of non-entry block arguments within the given region.
753
754
LogicalResult convertNonEntryRegionTypes (
754
- Region *region, const TypeConverter &converter,
755
+ ConversionPatternRewriter &rewriter, Region *region,
756
+ const TypeConverter &converter,
755
757
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
756
758
757
759
// / Apply a signature conversion on the given region, using `converter` for
758
760
// / materializations if not null.
759
761
Block *
760
- applySignatureConversion (Region *region,
762
+ applySignatureConversion (ConversionPatternRewriter &rewriter, Region *region,
761
763
TypeConverter::SignatureConversion &conversion,
762
764
const TypeConverter *converter);
763
765
764
766
// / Convert the types of block arguments within the given region.
765
767
FailureOr<Block *>
766
- convertRegionTypes (Region *region, const TypeConverter &converter,
768
+ convertRegionTypes (ConversionPatternRewriter &rewriter, Region *region,
769
+ const TypeConverter &converter,
767
770
TypeConverter::SignatureConversion *entryConversion);
768
771
769
772
// / Apply the given signature conversion on the given block. The new block
@@ -773,7 +776,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
773
776
// / translate between the origin argument types and those specified in the
774
777
// / signature conversion.
775
778
Block *applySignatureConversion (
776
- Block *block, const TypeConverter *converter,
779
+ ConversionPatternRewriter &rewriter, Block *block,
780
+ const TypeConverter *converter,
777
781
TypeConverter::SignatureConversion &signatureConversion);
778
782
779
783
// ===--------------------------------------------------------------------===//
@@ -940,24 +944,10 @@ void BlockTypeConversionRewrite::commit() {
940
944
rewriterImpl.mapping .lookupOrDefault (castValue, origArg.getType ()));
941
945
}
942
946
}
943
-
944
- assert (origBlock->empty () && " expected empty block" );
945
- origBlock->dropAllDefinedValueUses ();
946
- delete origBlock;
947
- origBlock = nullptr ;
948
947
}
949
948
950
949
void BlockTypeConversionRewrite::rollback () {
951
- // Drop all uses of the new block arguments and replace uses of the new block.
952
- for (int i = block->getNumArguments () - 1 ; i >= 0 ; --i)
953
- block->getArgument (i).dropAllUses ();
954
950
block->replaceAllUsesWith (origBlock);
955
-
956
- // Move the operations back the original block, move the original block back
957
- // into its original location and the delete the new block.
958
- origBlock->getOperations ().splice (origBlock->end (), block->getOperations ());
959
- block->getParent ()->getBlocks ().insert (Region::iterator (block), origBlock);
960
- eraseBlock (block);
961
951
}
962
952
963
953
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions (
@@ -1173,10 +1163,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
1173
1163
// Type Conversion
1174
1164
1175
1165
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature (
1176
- Block *block, const TypeConverter *converter,
1166
+ ConversionPatternRewriter &rewriter, Block *block,
1167
+ const TypeConverter *converter,
1177
1168
TypeConverter::SignatureConversion *conversion) {
1178
1169
if (conversion)
1179
- return applySignatureConversion (block, converter, *conversion);
1170
+ return applySignatureConversion (rewriter, block, converter, *conversion);
1180
1171
1181
1172
// If a converter wasn't provided, and the block wasn't already converted,
1182
1173
// there is nothing we can do.
@@ -1185,35 +1176,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1185
1176
1186
1177
// Try to convert the signature for the block with the provided converter.
1187
1178
if (auto conversion = converter->convertBlockSignature (block))
1188
- return applySignatureConversion (block, converter, *conversion);
1179
+ return applySignatureConversion (rewriter, block, converter, *conversion);
1189
1180
return failure ();
1190
1181
}
1191
1182
1192
1183
Block *ConversionPatternRewriterImpl::applySignatureConversion (
1193
- Region *region, TypeConverter::SignatureConversion &conversion,
1184
+ ConversionPatternRewriter &rewriter, Region *region,
1185
+ TypeConverter::SignatureConversion &conversion,
1194
1186
const TypeConverter *converter) {
1195
1187
if (!region->empty ())
1196
- return *convertBlockSignature (®ion->front (), converter, &conversion);
1188
+ return *convertBlockSignature (rewriter, ®ion->front (), converter,
1189
+ &conversion);
1197
1190
return nullptr ;
1198
1191
}
1199
1192
1200
1193
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes (
1201
- Region *region, const TypeConverter &converter,
1194
+ ConversionPatternRewriter &rewriter, Region *region,
1195
+ const TypeConverter &converter,
1202
1196
TypeConverter::SignatureConversion *entryConversion) {
1203
1197
regionToConverter[region] = &converter;
1204
1198
if (region->empty ())
1205
1199
return nullptr ;
1206
1200
1207
- if (failed (convertNonEntryRegionTypes (region, converter)))
1201
+ if (failed (convertNonEntryRegionTypes (rewriter, region, converter)))
1208
1202
return failure ();
1209
1203
1210
- FailureOr<Block *> newEntry =
1211
- convertBlockSignature ( ®ion->front (), &converter, entryConversion);
1204
+ FailureOr<Block *> newEntry = convertBlockSignature (
1205
+ rewriter, ®ion->front (), &converter, entryConversion);
1212
1206
return newEntry;
1213
1207
}
1214
1208
1215
1209
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes (
1216
- Region *region, const TypeConverter &converter,
1210
+ ConversionPatternRewriter &rewriter, Region *region,
1211
+ const TypeConverter &converter,
1217
1212
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1218
1213
regionToConverter[region] = &converter;
1219
1214
if (region->empty ())
@@ -1234,16 +1229,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
1234
1229
: const_cast <TypeConverter::SignatureConversion *>(
1235
1230
&blockConversions[blockIdx++]);
1236
1231
1237
- if (failed (convertBlockSignature (&block, &converter, blockConversion)))
1232
+ if (failed (convertBlockSignature (rewriter, &block, &converter,
1233
+ blockConversion)))
1238
1234
return failure ();
1239
1235
}
1240
1236
return success ();
1241
1237
}
1242
1238
1243
1239
Block *ConversionPatternRewriterImpl::applySignatureConversion (
1244
- Block *block, const TypeConverter *converter,
1240
+ ConversionPatternRewriter &rewriter, Block *block,
1241
+ const TypeConverter *converter,
1245
1242
TypeConverter::SignatureConversion &signatureConversion) {
1246
- MLIRContext *ctx = eraseRewriter .getContext ();
1243
+ MLIRContext *ctx = rewriter .getContext ();
1247
1244
1248
1245
// If no arguments are being changed or added, there is nothing to do.
1249
1246
unsigned origArgCount = block->getNumArguments ();
@@ -1253,11 +1250,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1253
1250
1254
1251
// Split the block at the beginning to get a new block to use for the updated
1255
1252
// signature.
1256
- Block *newBlock = block-> splitBlock (block->begin ());
1253
+ Block *newBlock = rewriter. splitBlock (block, block->begin ());
1257
1254
block->replaceAllUsesWith (newBlock);
1258
- // Unlink the block, but do not erase it yet, so that the change can be rolled
1259
- // back.
1260
- block->getParent ()->getBlocks ().remove (block);
1261
1255
1262
1256
// Map all new arguments to the location of the argument they originate from.
1263
1257
SmallVector<Location> newLocs (convertedTypes.size (),
@@ -1333,6 +1327,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1333
1327
1334
1328
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1335
1329
converter);
1330
+
1331
+ // Erase the old block. (It is just unlinked for now and will be erased during
1332
+ // cleanup.)
1333
+ rewriter.eraseBlock (block);
1334
+
1336
1335
return newBlock;
1337
1336
}
1338
1337
@@ -1531,7 +1530,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
1531
1530
assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1532
1531
" attempting to apply a signature conversion to a block within a "
1533
1532
" replaced/erased op" );
1534
- return impl->applySignatureConversion (region, conversion, converter);
1533
+ return impl->applySignatureConversion (* this , region, conversion, converter);
1535
1534
}
1536
1535
1537
1536
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes (
@@ -1540,7 +1539,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1540
1539
assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1541
1540
" attempting to apply a signature conversion to a block within a "
1542
1541
" replaced/erased op" );
1543
- return impl->convertRegionTypes (region, converter, entryConversion);
1542
+ return impl->convertRegionTypes (* this , region, converter, entryConversion);
1544
1543
}
1545
1544
1546
1545
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes (
@@ -1549,7 +1548,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
1549
1548
assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1550
1549
" attempting to apply a signature conversion to a block within a "
1551
1550
" replaced/erased op" );
1552
- return impl->convertNonEntryRegionTypes (region, converter, blockConversions);
1551
+ return impl->convertNonEntryRegionTypes (*this , region, converter,
1552
+ blockConversions);
1553
1553
}
1554
1554
1555
1555
void ConversionPatternRewriter::replaceUsesOfBlockArgument (BlockArgument from,
@@ -2051,7 +2051,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2051
2051
// If the region of the block has a type converter, try to convert the block
2052
2052
// directly.
2053
2053
if (auto *converter = impl.regionToConverter .lookup (block->getParent ())) {
2054
- if (failed (impl.convertBlockSignature (block, converter))) {
2054
+ if (failed (impl.convertBlockSignature (rewriter, block, converter))) {
2055
2055
LLVM_DEBUG (logFailure (impl.logger , " failed to convert types of moved "
2056
2056
" block" ));
2057
2057
return failure ();
0 commit comments