Skip to content

Commit 7bb08ee

Browse files
[mlir][Transforms][NFC] Decouple ConversionPatternRewriterImpl from ConversionPatternRewriter (#82333)
`ConversionPatternRewriterImpl` no longer maintains a reference to the respective `ConversionPatternRewriter`. An `MLIRContext` is sufficient. This commit simplifies the internal state of `ConversionPatternRewriterImpl`.
1 parent 9dfb843 commit 7bb08ee

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -756,10 +756,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
756756
namespace mlir {
757757
namespace detail {
758758
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
759-
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
759+
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
760760
const ConversionConfig &config)
761-
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
762-
config(config) {}
761+
: eraseRewriter(ctx), config(config) {}
763762

764763
//===--------------------------------------------------------------------===//
765764
// State Management
@@ -854,8 +853,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
854853
Type origOutputType,
855854
const TypeConverter *converter);
856855

857-
Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
858-
Location loc, ValueRange inputs,
856+
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
857+
ValueRange inputs,
859858
Type origOutputType,
860859
Type outputType,
861860
const TypeConverter *converter);
@@ -934,8 +933,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
934933
// State
935934
//===--------------------------------------------------------------------===//
936935

937-
PatternRewriter &rewriter;
938-
939936
/// This rewriter must be used for erasing ops/blocks.
940937
SingleEraseRewriter eraseRewriter;
941938

@@ -1037,8 +1034,12 @@ void BlockTypeConversionRewrite::rollback() {
10371034

10381035
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
10391036
function_ref<Operation *(Value)> findLiveUser) {
1037+
auto builder = OpBuilder::atBlockBegin(block, /*listener=*/&rewriterImpl);
1038+
10401039
// Process the remapping for each of the original arguments.
10411040
for (auto it : llvm::enumerate(origBlock->getArguments())) {
1041+
OpBuilder::InsertionGuard g(builder);
1042+
10421043
// If the type of this argument changed and the argument is still live, we
10431044
// need to materialize a conversion.
10441045
BlockArgument origArg = it.value();
@@ -1050,14 +1051,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
10501051

10511052
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
10521053
bool isDroppedArg = replacementValue == origArg;
1053-
if (isDroppedArg)
1054-
rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
1055-
else
1056-
rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
1054+
if (!isDroppedArg)
1055+
builder.setInsertionPointAfterValue(replacementValue);
10571056
Value newArg;
10581057
if (converter) {
10591058
newArg = converter->materializeSourceConversion(
1060-
rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
1059+
builder, origArg.getLoc(), origArg.getType(),
10611060
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
10621061
assert((!newArg || newArg.getType() == origArg.getType()) &&
10631062
"materialization hook did not provide a value of the expected "
@@ -1322,6 +1321,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
13221321
Block *ConversionPatternRewriterImpl::applySignatureConversion(
13231322
Block *block, const TypeConverter *converter,
13241323
TypeConverter::SignatureConversion &signatureConversion) {
1324+
MLIRContext *ctx = block->getParentOp()->getContext();
1325+
13251326
// If no arguments are being changed or added, there is nothing to do.
13261327
unsigned origArgCount = block->getNumArguments();
13271328
auto convertedTypes = signatureConversion.getConvertedTypes();
@@ -1338,7 +1339,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13381339

13391340
// Map all new arguments to the location of the argument they originate from.
13401341
SmallVector<Location> newLocs(convertedTypes.size(),
1341-
rewriter.getUnknownLoc());
1342+
Builder(ctx).getUnknownLoc());
13421343
for (unsigned i = 0; i < origArgCount; ++i) {
13431344
auto inputMap = signatureConversion.getInputMapping(i);
13441345
if (!inputMap || inputMap->replacementValue)
@@ -1357,8 +1358,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13571358
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
13581359
argInfo.resize(origArgCount);
13591360

1360-
OpBuilder::InsertionGuard guard(rewriter);
1361-
rewriter.setInsertionPointToStart(newBlock);
13621361
for (unsigned i = 0; i != origArgCount; ++i) {
13631362
auto inputMap = signatureConversion.getInputMapping(i);
13641363
if (!inputMap)
@@ -1401,7 +1400,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14011400
outputType = legalOutputType;
14021401

14031402
newArg = buildUnresolvedArgumentMaterialization(
1404-
rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
1403+
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
14051404
converter);
14061405
}
14071406

@@ -1439,12 +1438,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14391438
return convertOp.getResult(0);
14401439
}
14411440
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
1442-
PatternRewriter &rewriter, Location loc, ValueRange inputs,
1443-
Type origOutputType, Type outputType, const TypeConverter *converter) {
1444-
return buildUnresolvedMaterialization(
1445-
MaterializationKind::Argument, rewriter.getInsertionBlock(),
1446-
rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
1447-
converter);
1441+
Block *block, Location loc, ValueRange inputs, Type origOutputType,
1442+
Type outputType, const TypeConverter *converter) {
1443+
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
1444+
block->begin(), loc, inputs, outputType,
1445+
origOutputType, converter);
14481446
}
14491447
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14501448
Location loc, Value input, Type outputType,
@@ -1556,7 +1554,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
15561554
ConversionPatternRewriter::ConversionPatternRewriter(
15571555
MLIRContext *ctx, const ConversionConfig &config)
15581556
: PatternRewriter(ctx),
1559-
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
1557+
impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
15601558
setListener(impl.get());
15611559
}
15621560

0 commit comments

Comments
 (0)