@@ -756,10 +756,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
756
756
namespace mlir {
757
757
namespace detail {
758
758
struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
759
- explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter ,
759
+ explicit ConversionPatternRewriterImpl (MLIRContext *ctx ,
760
760
const ConversionConfig &config)
761
- : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
762
- config(config) {}
761
+ : eraseRewriter(ctx), config(config) {}
763
762
764
763
// ===--------------------------------------------------------------------===//
765
764
// State Management
@@ -854,8 +853,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
854
853
Type origOutputType,
855
854
const TypeConverter *converter);
856
855
857
- Value buildUnresolvedArgumentMaterialization (PatternRewriter &rewriter ,
858
- Location loc, ValueRange inputs,
856
+ Value buildUnresolvedArgumentMaterialization (Block *block, Location loc ,
857
+ ValueRange inputs,
859
858
Type origOutputType,
860
859
Type outputType,
861
860
const TypeConverter *converter);
@@ -934,8 +933,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
934
933
// State
935
934
// ===--------------------------------------------------------------------===//
936
935
937
- PatternRewriter &rewriter;
938
-
939
936
// / This rewriter must be used for erasing ops/blocks.
940
937
SingleEraseRewriter eraseRewriter;
941
938
@@ -1037,8 +1034,12 @@ void BlockTypeConversionRewrite::rollback() {
1037
1034
1038
1035
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions (
1039
1036
function_ref<Operation *(Value)> findLiveUser) {
1037
+ auto builder = OpBuilder::atBlockBegin (block, /* listener=*/ &rewriterImpl);
1038
+
1040
1039
// Process the remapping for each of the original arguments.
1041
1040
for (auto it : llvm::enumerate (origBlock->getArguments ())) {
1041
+ OpBuilder::InsertionGuard g (builder);
1042
+
1042
1043
// If the type of this argument changed and the argument is still live, we
1043
1044
// need to materialize a conversion.
1044
1045
BlockArgument origArg = it.value ();
@@ -1050,14 +1051,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1050
1051
1051
1052
Value replacementValue = rewriterImpl.mapping .lookupOrDefault (origArg);
1052
1053
bool isDroppedArg = replacementValue == origArg;
1053
- if (isDroppedArg)
1054
- rewriterImpl.rewriter .setInsertionPointToStart (getBlock ());
1055
- else
1056
- rewriterImpl.rewriter .setInsertionPointAfterValue (replacementValue);
1054
+ if (!isDroppedArg)
1055
+ builder.setInsertionPointAfterValue (replacementValue);
1057
1056
Value newArg;
1058
1057
if (converter) {
1059
1058
newArg = converter->materializeSourceConversion (
1060
- rewriterImpl. rewriter , origArg.getLoc (), origArg.getType (),
1059
+ builder , origArg.getLoc (), origArg.getType (),
1061
1060
isDroppedArg ? ValueRange () : ValueRange (replacementValue));
1062
1061
assert ((!newArg || newArg.getType () == origArg.getType ()) &&
1063
1062
" materialization hook did not provide a value of the expected "
@@ -1322,6 +1321,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
1322
1321
Block *ConversionPatternRewriterImpl::applySignatureConversion (
1323
1322
Block *block, const TypeConverter *converter,
1324
1323
TypeConverter::SignatureConversion &signatureConversion) {
1324
+ MLIRContext *ctx = block->getParentOp ()->getContext ();
1325
+
1325
1326
// If no arguments are being changed or added, there is nothing to do.
1326
1327
unsigned origArgCount = block->getNumArguments ();
1327
1328
auto convertedTypes = signatureConversion.getConvertedTypes ();
@@ -1338,7 +1339,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1338
1339
1339
1340
// Map all new arguments to the location of the argument they originate from.
1340
1341
SmallVector<Location> newLocs (convertedTypes.size (),
1341
- rewriter .getUnknownLoc ());
1342
+ Builder (ctx) .getUnknownLoc ());
1342
1343
for (unsigned i = 0 ; i < origArgCount; ++i) {
1343
1344
auto inputMap = signatureConversion.getInputMapping (i);
1344
1345
if (!inputMap || inputMap->replacementValue )
@@ -1357,8 +1358,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1357
1358
SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo;
1358
1359
argInfo.resize (origArgCount);
1359
1360
1360
- OpBuilder::InsertionGuard guard (rewriter);
1361
- rewriter.setInsertionPointToStart (newBlock);
1362
1361
for (unsigned i = 0 ; i != origArgCount; ++i) {
1363
1362
auto inputMap = signatureConversion.getInputMapping (i);
1364
1363
if (!inputMap)
@@ -1401,7 +1400,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1401
1400
outputType = legalOutputType;
1402
1401
1403
1402
newArg = buildUnresolvedArgumentMaterialization (
1404
- rewriter , origArg.getLoc (), replArgs, origOutputType, outputType,
1403
+ newBlock , origArg.getLoc (), replArgs, origOutputType, outputType,
1405
1404
converter);
1406
1405
}
1407
1406
@@ -1439,12 +1438,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1439
1438
return convertOp.getResult (0 );
1440
1439
}
1441
1440
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1442
- PatternRewriter &rewriter, Location loc, ValueRange inputs,
1443
- Type origOutputType, Type outputType, const TypeConverter *converter) {
1444
- return buildUnresolvedMaterialization (
1445
- MaterializationKind::Argument, rewriter.getInsertionBlock (),
1446
- rewriter.getInsertionPoint (), loc, inputs, outputType, origOutputType,
1447
- converter);
1441
+ Block *block, Location loc, ValueRange inputs, Type origOutputType,
1442
+ Type outputType, const TypeConverter *converter) {
1443
+ return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
1444
+ block->begin (), loc, inputs, outputType,
1445
+ origOutputType, converter);
1448
1446
}
1449
1447
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
1450
1448
Location loc, Value input, Type outputType,
@@ -1556,7 +1554,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
1556
1554
ConversionPatternRewriter::ConversionPatternRewriter (
1557
1555
MLIRContext *ctx, const ConversionConfig &config)
1558
1556
: PatternRewriter(ctx),
1559
- impl(new detail::ConversionPatternRewriterImpl(* this , config)) {
1557
+ impl(new detail::ConversionPatternRewriterImpl(ctx , config)) {
1560
1558
setListener (impl.get ());
1561
1559
}
1562
1560
0 commit comments