@@ -725,10 +725,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
725
725
namespace mlir {
726
726
namespace detail {
727
727
struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
728
- explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter ,
728
+ explicit ConversionPatternRewriterImpl (MLIRContext *ctx ,
729
729
const ConversionConfig &config)
730
- : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
731
- config(config) {}
730
+ : eraseRewriter(ctx), config(config) {}
732
731
733
732
// ===--------------------------------------------------------------------===//
734
733
// State Management
@@ -823,8 +822,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
823
822
Type origOutputType,
824
823
const TypeConverter *converter);
825
824
826
- Value buildUnresolvedArgumentMaterialization (PatternRewriter &rewriter ,
827
- Location loc, ValueRange inputs,
825
+ Value buildUnresolvedArgumentMaterialization (Block *block, Location loc ,
826
+ ValueRange inputs,
828
827
Type origOutputType,
829
828
Type outputType,
830
829
const TypeConverter *converter);
@@ -903,8 +902,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
903
902
// State
904
903
// ===--------------------------------------------------------------------===//
905
904
906
- PatternRewriter &rewriter;
907
-
908
905
// / This rewriter must be used for erasing ops/blocks.
909
906
SingleEraseRewriter eraseRewriter;
910
907
@@ -1008,8 +1005,12 @@ void BlockTypeConversionRewrite::rollback() {
1008
1005
1009
1006
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions (
1010
1007
function_ref<Operation *(Value)> findLiveUser) {
1008
+ auto builder = OpBuilder::atBlockBegin (block, /* listener=*/ &rewriterImpl);
1009
+
1011
1010
// Process the remapping for each of the original arguments.
1012
1011
for (unsigned i = 0 , e = origBlock->getNumArguments (); i != e; ++i) {
1012
+ OpBuilder::InsertionGuard g (builder);
1013
+
1013
1014
// If the type of this argument changed and the argument is still live, we
1014
1015
// need to materialize a conversion.
1015
1016
BlockArgument origArg = origBlock->getArgument (i);
@@ -1021,14 +1022,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1021
1022
1022
1023
Value replacementValue = rewriterImpl.mapping .lookupOrDefault (origArg);
1023
1024
bool isDroppedArg = replacementValue == origArg;
1024
- if (isDroppedArg)
1025
- rewriterImpl.rewriter .setInsertionPointToStart (getBlock ());
1026
- else
1027
- rewriterImpl.rewriter .setInsertionPointAfterValue (replacementValue);
1025
+ if (!isDroppedArg)
1026
+ builder.setInsertionPointAfterValue (replacementValue);
1028
1027
Value newArg;
1029
1028
if (converter) {
1030
1029
newArg = converter->materializeSourceConversion (
1031
- rewriterImpl. rewriter , origArg.getLoc (), origArg.getType (),
1030
+ builder , origArg.getLoc (), origArg.getType (),
1032
1031
isDroppedArg ? ValueRange () : ValueRange (replacementValue));
1033
1032
assert ((!newArg || newArg.getType () == origArg.getType ()) &&
1034
1033
" materialization hook did not provide a value of the expected "
@@ -1293,6 +1292,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
1293
1292
Block *ConversionPatternRewriterImpl::applySignatureConversion (
1294
1293
Block *block, const TypeConverter *converter,
1295
1294
TypeConverter::SignatureConversion &signatureConversion) {
1295
+ MLIRContext *ctx = block->getParentOp ()->getContext ();
1296
+
1296
1297
// If no arguments are being changed or added, there is nothing to do.
1297
1298
unsigned origArgCount = block->getNumArguments ();
1298
1299
auto convertedTypes = signatureConversion.getConvertedTypes ();
@@ -1309,7 +1310,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1309
1310
1310
1311
// Map all new arguments to the location of the argument they originate from.
1311
1312
SmallVector<Location> newLocs (convertedTypes.size (),
1312
- rewriter .getUnknownLoc ());
1313
+ Builder (ctx) .getUnknownLoc ());
1313
1314
for (unsigned i = 0 ; i < origArgCount; ++i) {
1314
1315
auto inputMap = signatureConversion.getInputMapping (i);
1315
1316
if (!inputMap || inputMap->replacementValue )
@@ -1328,8 +1329,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1328
1329
SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo;
1329
1330
argInfo.resize (origArgCount);
1330
1331
1331
- OpBuilder::InsertionGuard guard (rewriter);
1332
- rewriter.setInsertionPointToStart (newBlock);
1333
1332
for (unsigned i = 0 ; i != origArgCount; ++i) {
1334
1333
auto inputMap = signatureConversion.getInputMapping (i);
1335
1334
if (!inputMap)
@@ -1372,7 +1371,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1372
1371
outputType = legalOutputType;
1373
1372
1374
1373
newArg = buildUnresolvedArgumentMaterialization (
1375
- rewriter , origArg.getLoc (), replArgs, origOutputType, outputType,
1374
+ newBlock , origArg.getLoc (), replArgs, origOutputType, outputType,
1376
1375
converter);
1377
1376
}
1378
1377
@@ -1410,12 +1409,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1410
1409
return convertOp.getResult (0 );
1411
1410
}
1412
1411
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1413
- PatternRewriter &rewriter, Location loc, ValueRange inputs,
1414
- Type origOutputType, Type outputType, const TypeConverter *converter) {
1415
- return buildUnresolvedMaterialization (
1416
- MaterializationKind::Argument, rewriter.getInsertionBlock (),
1417
- rewriter.getInsertionPoint (), loc, inputs, outputType, origOutputType,
1418
- converter);
1412
+ Block *block, Location loc, ValueRange inputs, Type origOutputType,
1413
+ Type outputType, const TypeConverter *converter) {
1414
+ return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
1415
+ block->begin (), loc, inputs, outputType,
1416
+ origOutputType, converter);
1419
1417
}
1420
1418
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
1421
1419
Location loc, Value input, Type outputType,
@@ -1527,7 +1525,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
1527
1525
ConversionPatternRewriter::ConversionPatternRewriter (
1528
1526
MLIRContext *ctx, const ConversionConfig &config)
1529
1527
: PatternRewriter(ctx),
1530
- impl(new detail::ConversionPatternRewriterImpl(* this , config)) {
1528
+ impl(new detail::ConversionPatternRewriterImpl(ctx , config)) {
1531
1529
setListener (impl.get ());
1532
1530
}
1533
1531
0 commit comments