Skip to content

Commit 1795b6d

Browse files
[mlir][Transforms] Support rolling back properties in dialect conversion
The dialect conversion rolls back inplace op modifications upon failure. Rolling back modifications of op properties was not supported before this commit.
1 parent 1c69f42 commit 1795b6d

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1002,12 +1002,33 @@ class ModifyOperationRewrite : public OperationRewrite {
10021002
: OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
10031003
loc(op->getLoc()), attrs(op->getAttrDictionary()),
10041004
operands(op->operand_begin(), op->operand_end()),
1005-
successors(op->successor_begin(), op->successor_end()) {}
1005+
successors(op->successor_begin(), op->successor_end()) {
1006+
if (OpaqueProperties prop = op->getPropertiesStorage()) {
1007+
// Make a copy of the properties.
1008+
propertiesStorage = operator new(op->getPropertiesStorageSize());
1009+
OpaqueProperties propCopy(propertiesStorage);
1010+
op->getName().initOpProperties(propCopy, /*init=*/prop);
1011+
}
1012+
}
10061013

10071014
static bool classof(const IRRewrite *rewrite) {
10081015
return rewrite->getKind() == Kind::ModifyOperation;
10091016
}
10101017

1018+
~ModifyOperationRewrite() override {
1019+
assert(!propertiesStorage &&
1020+
"rewrite was neither committed nor rolled back");
1021+
}
1022+
1023+
void commit() override {
1024+
if (propertiesStorage) {
1025+
OpaqueProperties propCopy(propertiesStorage);
1026+
op->getName().destroyOpProperties(propCopy);
1027+
operator delete(propertiesStorage);
1028+
propertiesStorage = nullptr;
1029+
}
1030+
}
1031+
10111032
/// Discard the transaction state and reset the state of the original
10121033
/// operation.
10131034
void rollback() override {
@@ -1016,13 +1037,21 @@ class ModifyOperationRewrite : public OperationRewrite {
10161037
op->setOperands(operands);
10171038
for (const auto &it : llvm::enumerate(successors))
10181039
op->setSuccessor(it.value(), it.index());
1040+
if (propertiesStorage) {
1041+
OpaqueProperties propCopy(propertiesStorage);
1042+
op->copyProperties(propCopy);
1043+
op->getName().destroyOpProperties(propCopy);
1044+
operator delete(propertiesStorage);
1045+
propertiesStorage = nullptr;
1046+
}
10191047
}
10201048

10211049
private:
10221050
LocationAttr loc;
10231051
DictionaryAttr attrs;
10241052
SmallVector<Value, 8> operands;
10251053
SmallVector<Block *, 2> successors;
1054+
void *propertiesStorage = nullptr;
10261055
};
10271056
} // namespace
10281057

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,15 @@ func.func @test_move_op_before_rollback() {
334334
}) : () -> ()
335335
"test.return"() : () -> ()
336336
}
337+
338+
// -----
339+
340+
// CHECK-LABEL: func @test_properties_rollback()
341+
func.func @test_properties_rollback() {
342+
// CHECK: test.with_properties <{a = 32 : i64,
343+
// expected-remark @below{{op 'test.with_properties' is not legalizable}}
344+
test.with_properties
345+
<{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}>
346+
{modify_inplace}
347+
"test.return"() : () -> ()
348+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,21 @@ struct TestUndoBlockErase : public ConversionPattern {
806806
}
807807
};
808808

809+
/// A pattern that modifies a property in-place, but keeps the op illegal.
810+
struct TestUndoPropertiesModification : public ConversionPattern {
811+
TestUndoPropertiesModification(MLIRContext *ctx)
812+
: ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
813+
LogicalResult
814+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
815+
ConversionPatternRewriter &rewriter) const final {
816+
if (!op->hasAttr("modify_inplace"))
817+
return failure();
818+
rewriter.modifyOpInPlace(
819+
op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
820+
return success();
821+
}
822+
};
823+
809824
//===----------------------------------------------------------------------===//
810825
// Type-Conversion Rewrite Testing
811826

@@ -1085,7 +1100,8 @@ struct TestLegalizePatternDriver
10851100
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
10861101
TestNonRootReplacement, TestBoundedRecursiveRewrite,
10871102
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1088-
TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
1103+
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1104+
TestUndoPropertiesModification>(&getContext());
10891105
patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
10901106
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
10911107
converter);

0 commit comments

Comments
 (0)