Skip to content

Commit 362813d

Browse files
[mlir][Transforms][NFC] Decouple ConversionPatternRewriterImpl from ConversionPatternRewriter
`ConversionPatternRewriterImpl` no longer maintains a reference to the respective `ConversionPatternRewriter`. An `MLIRContext` is sufficient. This commit simplifies the internal state of `ConversionPatternRewriterImpl`.
1 parent 819e5f9 commit 362813d

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -725,10 +725,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
725725
namespace mlir {
726726
namespace detail {
727727
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
728-
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
728+
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
729729
const ConversionConfig &config)
730-
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
731-
config(config) {}
730+
: eraseRewriter(ctx), config(config) {}
732731

733732
//===--------------------------------------------------------------------===//
734733
// State Management
@@ -823,8 +822,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
823822
Type origOutputType,
824823
const TypeConverter *converter);
825824

826-
Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
827-
Location loc, ValueRange inputs,
825+
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
826+
ValueRange inputs,
828827
Type origOutputType,
829828
Type outputType,
830829
const TypeConverter *converter);
@@ -903,8 +902,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
903902
// State
904903
//===--------------------------------------------------------------------===//
905904

906-
PatternRewriter &rewriter;
907-
908905
/// This rewriter must be used for erasing ops/blocks.
909906
SingleEraseRewriter eraseRewriter;
910907

@@ -1008,8 +1005,12 @@ void BlockTypeConversionRewrite::rollback() {
10081005

10091006
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
10101007
function_ref<Operation *(Value)> findLiveUser) {
1008+
auto builder = OpBuilder::atBlockBegin(block, /*listener=*/&rewriterImpl);
1009+
10111010
// Process the remapping for each of the original arguments.
10121011
for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
1012+
OpBuilder::InsertionGuard g(builder);
1013+
10131014
// If the type of this argument changed and the argument is still live, we
10141015
// need to materialize a conversion.
10151016
BlockArgument origArg = origBlock->getArgument(i);
@@ -1021,14 +1022,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
10211022

10221023
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
10231024
bool isDroppedArg = replacementValue == origArg;
1024-
if (isDroppedArg)
1025-
rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
1026-
else
1027-
rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
1025+
if (!isDroppedArg)
1026+
builder.setInsertionPointAfterValue(replacementValue);
10281027
Value newArg;
10291028
if (converter) {
10301029
newArg = converter->materializeSourceConversion(
1031-
rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
1030+
builder, origArg.getLoc(), origArg.getType(),
10321031
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
10331032
assert((!newArg || newArg.getType() == origArg.getType()) &&
10341033
"materialization hook did not provide a value of the expected "
@@ -1293,6 +1292,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
12931292
Block *ConversionPatternRewriterImpl::applySignatureConversion(
12941293
Block *block, const TypeConverter *converter,
12951294
TypeConverter::SignatureConversion &signatureConversion) {
1295+
MLIRContext *ctx = block->getParentOp()->getContext();
1296+
12961297
// If no arguments are being changed or added, there is nothing to do.
12971298
unsigned origArgCount = block->getNumArguments();
12981299
auto convertedTypes = signatureConversion.getConvertedTypes();
@@ -1309,7 +1310,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13091310

13101311
// Map all new arguments to the location of the argument they originate from.
13111312
SmallVector<Location> newLocs(convertedTypes.size(),
1312-
rewriter.getUnknownLoc());
1313+
Builder(ctx).getUnknownLoc());
13131314
for (unsigned i = 0; i < origArgCount; ++i) {
13141315
auto inputMap = signatureConversion.getInputMapping(i);
13151316
if (!inputMap || inputMap->replacementValue)
@@ -1328,8 +1329,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13281329
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
13291330
argInfo.resize(origArgCount);
13301331

1331-
OpBuilder::InsertionGuard guard(rewriter);
1332-
rewriter.setInsertionPointToStart(newBlock);
13331332
for (unsigned i = 0; i != origArgCount; ++i) {
13341333
auto inputMap = signatureConversion.getInputMapping(i);
13351334
if (!inputMap)
@@ -1372,7 +1371,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13721371
outputType = legalOutputType;
13731372

13741373
newArg = buildUnresolvedArgumentMaterialization(
1375-
rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
1374+
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
13761375
converter);
13771376
}
13781377

@@ -1410,12 +1409,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14101409
return convertOp.getResult(0);
14111410
}
14121411
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
1413-
PatternRewriter &rewriter, Location loc, ValueRange inputs,
1414-
Type origOutputType, Type outputType, const TypeConverter *converter) {
1415-
return buildUnresolvedMaterialization(
1416-
MaterializationKind::Argument, rewriter.getInsertionBlock(),
1417-
rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
1418-
converter);
1412+
Block *block, Location loc, ValueRange inputs, Type origOutputType,
1413+
Type outputType, const TypeConverter *converter) {
1414+
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
1415+
block->begin(), loc, inputs, outputType,
1416+
origOutputType, converter);
14191417
}
14201418
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14211419
Location loc, Value input, Type outputType,
@@ -1527,7 +1525,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
15271525
ConversionPatternRewriter::ConversionPatternRewriter(
15281526
MLIRContext *ctx, const ConversionConfig &config)
15291527
: PatternRewriter(ctx),
1530-
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
1528+
impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
15311529
setListener(impl.get());
15321530
}
15331531

0 commit comments

Comments
 (0)