Skip to content

[mlir][Transforms] Add 1:N support to replaceUsesOfBlockArgument #145171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);

/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
/// Replace all the uses of the block argument `from` with `to`. This
/// function supports both 1:1 and 1:N replacements.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);

/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());

auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
}
}
Expand Down
40 changes: 25 additions & 15 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// uses.
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);

/// Replace the given block argument with the given values. The specified
/// converter is used to build materializations (if necessary).
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
const TypeConverter *converter);

/// Erase the given block and its contents.
void eraseBlock(Block *block);

Expand Down Expand Up @@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
Value mat =
buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
.front();
replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
}

Expand All @@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValues);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
converter);
continue;
}

// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
mapping.map(origArg, std::move(replArgVals));
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
replaceUsesOfBlockArgument(origArg, replArgs, converter);
}

appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
Expand Down Expand Up @@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}

void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
BlockArgument from, ValueRange to, const TypeConverter *converter) {
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
}

void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
Expand Down Expand Up @@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
}

void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
ValueRange to) {
LLVM_DEBUG({
impl->logger.startLine() << "** Replace Argument : '" << from << "'";
if (Operation *parentOp = from.getOwner()->getParentOp()) {
Expand All @@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
impl->logger.getOStream() << " (unlinked block)\n";
}
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
impl->mapping.map(from, to);
impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
}

Value ConversionPatternRewriter::getRemappedValue(Value key) {
Expand Down
31 changes: 24 additions & 7 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -300,18 +300,35 @@ func.func @create_illegal_block() {
// -----

// CHECK-LABEL: @undo_block_arg_replace
// expected-remark@+1{{applyPartialConversion failed}}
module {
func.func @undo_block_arg_replace() {
// expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}}
"test.undo_block_arg_replace"() ({
^bb0(%arg0: i32):
// CHECK: ^bb0(%[[ARG:.*]]: i32):
// CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
// expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
"test.block_arg_replace"() ({
^bb0(%arg0: i32, %arg1: i16):
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
// CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)

"test.return"(%arg0) : (i32) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
}) {trigger_rollback} : () -> ()
return
}
}

// -----

// CHECK-LABEL: @replace_block_arg_1_to_n
func.func @replace_block_arg_1_to_n() {
// CHECK: "test.block_arg_replace"
"test.block_arg_replace"() ({
^bb0(%arg0: i32, %arg1: i16):
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
// CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
// CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
"test.return"(%arg0) : (i32) -> ()
}) : () -> ()
"test.return"() : () -> ()
}

// -----

Expand Down
51 changes: 29 additions & 22 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
}
};

/// A simple pattern that tests the undo mechanism when replacing the uses of a
/// block argument.
struct TestUndoBlockArgReplace : public ConversionPattern {
TestUndoBlockArgReplace(MLIRContext *ctx)
: ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
struct TestBlockArgReplace : public ConversionPattern {
TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
: ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
ctx) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto illegalOp =
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
// Replace the first block argument with 2x the second block argument.
Value repl = op->getRegion(0).getArgument(1);
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
illegalOp->getResult(0));
rewriter.modifyOpInPlace(op, [] {});
{repl, repl});
rewriter.modifyOpInPlace(op, [&] {
// If the "trigger_rollback" attribute is set, keep the op illegal, so
// that a rollback is triggered.
if (!op->hasAttr("trigger_rollback"))
op->setAttr("is_legal", rewriter.getUnitAttr());
});
return success();
}
};
Expand Down Expand Up @@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
patterns
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
TestUpdateConsumerType, TestNonRootReplacement,
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
&getContext(), converter);
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestBlockArgReplace>(&getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
Expand All @@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
});
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return converter.isLegal(op); });
target.addDynamicallyLegalOp(
OperationName("test.block_arg_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });

// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test
Expand Down
Loading