Skip to content

Commit 90824e4

Browse files
[mlir][Transforms] Add 1:N support to replaceUsesOfBlockArgument
1 parent 613c38a commit 90824e4

File tree

5 files changed

+82
-47
lines changed

5 files changed

+82
-47
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
763763
Region *region, const TypeConverter &converter,
764764
TypeConverter::SignatureConversion *entryConversion = nullptr);
765765

766-
/// Replace all the uses of the block argument `from` with value `to`.
767-
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
766+
/// Replace all the uses of the block argument `from` with `to`. This
767+
/// function supports both 1:1 and 1:N replacements.
768+
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
768769

769770
/// Return the converted value of 'key' with a type defined by the type
770771
/// converter of the currently executing pattern. Return nullptr in the case

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ static void restoreByValRefArgumentType(
294294
Type resTy = typeConverter.convertType(
295295
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
296296

297-
auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
297+
Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
298298
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
299299
}
300300
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
948948
/// uses.
949949
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
950950

951+
/// Replace the given block argument with the given values. The specified
952+
/// converter is used to build materializations (if necessary).
953+
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
954+
const TypeConverter *converter);
955+
951956
/// Erase the given block and its contents.
952957
void eraseBlock(Block *block);
953958

@@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14341439
if (!inputMap) {
14351440
// This block argument was dropped and no replacement value was provided.
14361441
// Materialize a replacement value "out of thin air".
1437-
buildUnresolvedMaterialization(
1438-
MaterializationKind::Source,
1439-
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1440-
/*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
1441-
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
1442-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1442+
Value mat =
1443+
buildUnresolvedMaterialization(
1444+
MaterializationKind::Source,
1445+
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1446+
origArg.getLoc(),
1447+
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1448+
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
1449+
.front();
1450+
replaceUsesOfBlockArgument(origArg, mat, converter);
14431451
continue;
14441452
}
14451453

@@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14481456
assert(inputMap->size == 0 &&
14491457
"invalid to provide a replacement value when the argument isn't "
14501458
"dropped");
1451-
mapping.map(origArg, inputMap->replacementValues);
1452-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1459+
replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
1460+
converter);
14531461
continue;
14541462
}
14551463

14561464
// This is a 1->1+ mapping.
14571465
auto replArgs =
14581466
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1459-
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1460-
mapping.map(origArg, std::move(replArgVals));
1461-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1467+
replaceUsesOfBlockArgument(origArg, replArgs, converter);
14621468
}
14631469

14641470
appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
@@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp(
16121618
op->walk([&](Operation *op) { replacedOps.insert(op); });
16131619
}
16141620

1621+
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
1622+
BlockArgument from, ValueRange to, const TypeConverter *converter) {
1623+
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
1624+
mapping.map(from, to);
1625+
}
1626+
16151627
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
16161628
assert(!wasOpReplaced(block->getParentOp()) &&
16171629
"attempting to erase a block within a replaced/erased op");
@@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
17441756
}
17451757

17461758
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1747-
Value to) {
1759+
ValueRange to) {
17481760
LLVM_DEBUG({
17491761
impl->logger.startLine() << "** Replace Argument : '" << from << "'";
17501762
if (Operation *parentOp = from.getOwner()->getParentOp()) {
@@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
17541766
impl->logger.getOStream() << " (unlinked block)\n";
17551767
}
17561768
});
1757-
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1758-
impl->currentTypeConverter);
1759-
impl->mapping.map(from, to);
1769+
impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
17601770
}
17611771

17621772
Value ConversionPatternRewriter::getRemappedValue(Value key) {

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,18 +300,35 @@ func.func @create_illegal_block() {
300300
// -----
301301

302302
// CHECK-LABEL: @undo_block_arg_replace
303+
// expected-remark@+1{{applyPartialConversion failed}}
304+
module {
303305
func.func @undo_block_arg_replace() {
304-
// expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}}
305-
"test.undo_block_arg_replace"() ({
306-
^bb0(%arg0: i32):
307-
// CHECK: ^bb0(%[[ARG:.*]]: i32):
308-
// CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
306+
// expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
307+
"test.block_arg_replace"() ({
308+
^bb0(%arg0: i32, %arg1: i16):
309+
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
310+
// CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
309311

310312
"test.return"(%arg0) : (i32) -> ()
311-
}) : () -> ()
312-
// expected-remark@+1 {{op 'func.return' is not legalizable}}
313+
}) {trigger_rollback} : () -> ()
313314
return
314315
}
316+
}
317+
318+
// -----
319+
320+
// CHECK-LABEL: @replace_block_arg_1_to_n
321+
func.func @replace_block_arg_1_to_n() {
322+
// CHECK: "test.block_arg_replace"
323+
"test.block_arg_replace"() ({
324+
^bb0(%arg0: i32, %arg1: i16):
325+
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
326+
// CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
327+
// CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
328+
"test.return"(%arg0) : (i32) -> ()
329+
}) : () -> ()
330+
"test.return"() : () -> ()
331+
}
315332

316333
// -----
317334

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
891891
}
892892
};
893893

894-
/// A simple pattern that tests the undo mechanism when replacing the uses of a
895-
/// block argument.
896-
struct TestUndoBlockArgReplace : public ConversionPattern {
897-
TestUndoBlockArgReplace(MLIRContext *ctx)
898-
: ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
894+
/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
895+
struct TestBlockArgReplace : public ConversionPattern {
896+
TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
897+
: ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
898+
ctx) {}
899899

900900
LogicalResult
901901
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
902902
ConversionPatternRewriter &rewriter) const final {
903-
auto illegalOp =
904-
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
903+
// Replace the first block argument with 2x the second block argument.
904+
Value repl = op->getRegion(0).getArgument(1);
905905
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
906-
illegalOp->getResult(0));
907-
rewriter.modifyOpInPlace(op, [] {});
906+
{repl, repl});
907+
rewriter.modifyOpInPlace(op, [&] {
908+
// If the "trigger_rollback" attribute is set, keep the op illegal, so
909+
// that a rollback is triggered.
910+
if (!op->hasAttr("trigger_rollback"))
911+
op->setAttr("is_legal", rewriter.getUnitAttr());
912+
});
908913
return success();
909914
}
910915
};
@@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
13751380
TestTypeConverter converter;
13761381
mlir::RewritePatternSet patterns(&getContext());
13771382
populateWithGenerated(patterns);
1378-
patterns
1379-
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1380-
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1381-
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1382-
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1383-
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1384-
TestNonRootReplacement, TestBoundedRecursiveRewrite,
1385-
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1386-
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1387-
TestUndoPropertiesModification, TestEraseOp,
1388-
TestRepetitive1ToNConsumer>(&getContext());
1383+
patterns.add<
1384+
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1385+
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1386+
TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1387+
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1388+
TestUpdateConsumerType, TestNonRootReplacement,
1389+
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1390+
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1391+
TestUndoPropertiesModification, TestEraseOp,
1392+
TestRepetitive1ToNConsumer>(&getContext());
13891393
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1390-
TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
1391-
&getContext(), converter);
1394+
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
1395+
TestBlockArgReplace>(&getContext(), converter);
13921396
patterns.add<TestConvertBlockArgs>(converter, &getContext());
13931397
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
13941398
converter);
@@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
14131417
});
14141418
target.addDynamicallyLegalOp<func::CallOp>(
14151419
[&](func::CallOp op) { return converter.isLegal(op); });
1420+
target.addDynamicallyLegalOp(
1421+
OperationName("test.block_arg_replace", &getContext()),
1422+
[](Operation *op) { return op->hasAttr("is_legal"); });
14161423

14171424
// TestCreateUnregisteredOp creates `arith.constant` operation,
14181425
// which was not added to target intentionally to test

0 commit comments

Comments
 (0)