Skip to content

[mlir][Transforms][NFC] Decouple ConversionPatternRewriterImpl from ConversionPatternRewriter #82333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
config(config) {}
: eraseRewriter(ctx), config(config) {}

//===--------------------------------------------------------------------===//
// State Management
Expand Down Expand Up @@ -854,8 +853,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Type origOutputType,
const TypeConverter *converter);

Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
Location loc, ValueRange inputs,
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
ValueRange inputs,
Type origOutputType,
Type outputType,
const TypeConverter *converter);
Expand Down Expand Up @@ -934,8 +933,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//

PatternRewriter &rewriter;

/// This rewriter must be used for erasing ops/blocks.
SingleEraseRewriter eraseRewriter;

Expand Down Expand Up @@ -1037,8 +1034,12 @@ void BlockTypeConversionRewrite::rollback() {

LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
function_ref<Operation *(Value)> findLiveUser) {
auto builder = OpBuilder::atBlockBegin(block, /*listener=*/&rewriterImpl);

// Process the remapping for each of the original arguments.
for (auto it : llvm::enumerate(origBlock->getArguments())) {
OpBuilder::InsertionGuard g(builder);

// If the type of this argument changed and the argument is still live, we
// need to materialize a conversion.
BlockArgument origArg = it.value();
Expand All @@ -1050,14 +1051,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(

Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
bool isDroppedArg = replacementValue == origArg;
if (isDroppedArg)
rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
else
rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
if (!isDroppedArg)
builder.setInsertionPointAfterValue(replacementValue);
Value newArg;
if (converter) {
newArg = converter->materializeSourceConversion(
rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
builder, origArg.getLoc(), origArg.getType(),
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
Expand Down Expand Up @@ -1322,6 +1321,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
Block *ConversionPatternRewriterImpl::applySignatureConversion(
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
MLIRContext *ctx = block->getParentOp()->getContext();

// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
auto convertedTypes = signatureConversion.getConvertedTypes();
Expand All @@ -1338,7 +1339,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(

// Map all new arguments to the location of the argument they originate from.
SmallVector<Location> newLocs(convertedTypes.size(),
rewriter.getUnknownLoc());
Builder(ctx).getUnknownLoc());
for (unsigned i = 0; i < origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap || inputMap->replacementValue)
Expand All @@ -1357,8 +1358,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
argInfo.resize(origArgCount);

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(newBlock);
for (unsigned i = 0; i != origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap)
Expand Down Expand Up @@ -1401,7 +1400,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
outputType = legalOutputType;

newArg = buildUnresolvedArgumentMaterialization(
rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
converter);
}

Expand Down Expand Up @@ -1439,12 +1438,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
PatternRewriter &rewriter, Location loc, ValueRange inputs,
Type origOutputType, Type outputType, const TypeConverter *converter) {
return buildUnresolvedMaterialization(
MaterializationKind::Argument, rewriter.getInsertionBlock(),
rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
converter);
Block *block, Location loc, ValueRange inputs, Type origOutputType,
Type outputType, const TypeConverter *converter) {
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
block->begin(), loc, inputs, outputType,
origOutputType, converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
Expand Down Expand Up @@ -1556,7 +1554,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
setListener(impl.get());
}

Expand Down