-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Transform] Allow stateInitializer and stateExporter for applyTransforms #101186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, but let's give @ftynse a chance to review before merging this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see the discussion. Connecting to extensions via early initialization isn't much harder. And it will keep the single entry point to the system instead of letting everyone roll their own.
175a1c6
to
a7f7b24
Compare
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Amy Wang (kaitingwang) ChangesThis is discussed in RFC: Full diff: https://github.com/llvm/llvm-project/pull/101186.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 842e244dcde56c..0bb6037a77a16d 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -135,7 +135,9 @@ LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
const TransformOptions &options = TransformOptions(),
- bool enforceToplevelTransformOp = true);
+ bool enforceToplevelTransformOp = true,
+ function_ref<void (TransformState &)> stateInitializer = nullptr,
+ function_ref<LogicalResult (TransformState &)> stateExporter = nullptr);
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
@@ -217,7 +219,9 @@ class TransformState {
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &,
- const TransformOptions &, bool);
+ const TransformOptions &, bool,
+ function_ref<void (TransformState &)>,
+ function_ref<LogicalResult (TransformState &)>);
friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index f8f85e4615c500..5bc6d4ee5033f1 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1999,7 +1999,9 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
LogicalResult transform::applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
- const TransformOptions &options, bool enforceToplevelTransformOp) {
+ const TransformOptions &options, bool enforceToplevelTransformOp,
+ function_ref<void(TransformState &)> stateInitializer,
+ function_ref<LogicalResult(TransformState &)> stateExporter) {
if (enforceToplevelTransformOp) {
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
@@ -2013,7 +2015,13 @@ LogicalResult transform::applyTransforms(
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
- return state.applyTransform(transform).checkAndReport();
+ if (stateInitializer)
+ stateInitializer(state);
+ if (state.applyTransform(transform).checkAndReport().failed())
+ return failure();
+ if (stateExporter)
+ return stateExporter(state);
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir b/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
new file mode 100644
index 00000000000000..9fb4d1b8689164
--- /dev/null
+++ b/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -test-pass-state-extension-communication -verify-diagnostics | FileCheck %s
+
+// CHECK: Printing opCollection before processing transform ops, size: 1
+// CHECK: PASS-TRANSFORMOP-PASS
+
+// CHECK: Printing opCollection after processing transform ops, size: 4
+// CHECK: PASS-TRANSFORMOP-PASS transform.test_initializer_extension_A transform.test_initializer_extension_B transform.test_initializer_extension_C
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-remark @below {{Number of currently registered op: 1}}
+ transform.test_initializer_extension "A"
+ // expected-remark @below {{Number of currently registered op: 2}}
+ transform.test_initializer_extension "B"
+ // expected-remark @below {{Number of currently registered op: 3}}
+ transform.test_initializer_extension "C"
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index e6ab915a657b6f..ca141d2778ee2d 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -6,6 +6,7 @@ mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -type
add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
add_mlir_library(MLIRTestTransformDialect
+ TestPassStateExtensionCommunication.cpp
TestTransformDialectExtension.cpp
TestTransformDialectInterpreter.cpp
TestTransformStateExtension.cpp
diff --git a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
new file mode 100644
index 00000000000000..4b5958af21d014
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
@@ -0,0 +1,101 @@
+//===- TestPassStateExtensionCommunication.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a test pass that showcases how communication can be
+// conducted between a regular mlir pass and transform ops through the
+// transform state extension stateInitializer and stateExporter mechanism.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformStateExtension.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::test;
+
+namespace {
+template <typename Derived>
+class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};
+
+struct TestPassStateExtensionCommunication
+ : public PassWrapper<TestPassStateExtensionCommunication,
+ OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestPassStateExtensionCommunication)
+
+ StringRef getArgument() const final {
+ return "test-pass-state-extension-communication";
+ }
+
+ StringRef getDescription() const final {
+ return "test state communciation between a mlir pass and transform ops";
+ }
+
+ static void printVector(const SmallVector<std::string> &opCollection,
+ const std::string &extraMessage = {}) {
+ outs() << "Printing opCollection" << extraMessage
+ << ", size: " << opCollection.size() << "\n";
+ for (const auto &subVector : opCollection) {
+ outs() << subVector << " ";
+ }
+ outs() << "\n";
+ }
+
+ void runOnOperation() override {
+ ModuleOp module = getOperation();
+
+ // Create an opCollection vector.
+ SmallVector<std::string> opCollection = {"PASS-TRANSFORMOP-PASS "};
+ printVector(opCollection, " before processing transform ops");
+
+ auto stateInitializer =
+ [&opCollection](mlir::transform::TransformState &state) -> void {
+ TransformStateInitializerExtension *ext =
+ state.getExtension<TransformStateInitializerExtension>();
+ if (!ext)
+ state.addExtension<TransformStateInitializerExtension>(0, opCollection);
+ };
+
+ auto stateExporter =
+ [&opCollection](
+ mlir::transform::TransformState &state) -> LogicalResult {
+ TransformStateInitializerExtension *ext =
+ state.getExtension<TransformStateInitializerExtension>();
+ if (!ext) {
+ errs() << "Target transform state extension not found!\n";
+ return failure();
+ }
+ opCollection.clear();
+ opCollection = ext->getRegisteredOps();
+ return success();
+ };
+
+ // Process transform ops with stateInitializer and stateExporter.
+ for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
+ if (failed(transform::applyTransforms(
+ module, op, {}, mlir::transform::TransformOptions(), false,
+ stateInitializer, stateExporter)))
+ return signalPassFailure();
+
+ // Print the opCollection vector after processing transform ops.
+ printVector(opCollection, " after processing transform ops");
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the test pass here.
+void registerTestPassStateExtensionCommunication() {
+ PassRegistration<TestPassStateExtensionCommunication> reg;
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index c023aad4a3ee77..a0a7afce66d9a1 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -804,6 +804,28 @@ void mlir::test::TestProduceInvalidIR::getEffects(
transform::modifiesPayload(effects);
}
+DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
+ transform::TransformRewriter &rewriter,
+ transform::TransformResults &results, transform::TransformState &state) {
+ std::string opName =
+ this->getOperationName().str() + "_" + getTypeAttr().str();
+ TransformStateInitializerExtension *initExt =
+ state.getExtension<TransformStateInitializerExtension>();
+ if (!initExt) {
+ emitRemark() << "\nSpecified extension not found, adding a new one!\n";
+ SmallVector<std::string> opCollection = {opName};
+ state.addExtension<TransformStateInitializerExtension>(1, opCollection);
+ } else {
+ initExt->setNumOp(initExt->getNumOp() + 1);
+ initExt->pushRegisteredOps(opName);
+ InFlightDiagnostic diag = emitRemark()
+ << "Number of currently registered op: "
+ << initExt->getNumOp() << "\n"
+ << initExt->printMessage() << "\n";
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
namespace {
/// Test conversion pattern that replaces ops with the "replace_with_new_op"
/// attribute with "test.new_op".
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 4f2cf34f7d3347..76375dba369448 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -549,4 +549,13 @@ def TestProduceInvalidIR
}];
}
+def TestInitializerExtensionOp
+ : Op<Transform_Dialect, "test_initializer_extension",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ NoMemoryEffect]> {
+ let arguments = (ins StrAttr:$type);
+ let assemblyFormat = "$type attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
index 0bfa6bed015c0f..bbcbabea010b33 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
@@ -34,6 +34,31 @@ class TestTransformStateExtension
private:
StringAttr message;
};
+
+class TransformStateInitializerExtension
+ : public transform::TransformState::Extension {
+public:
+ TransformStateInitializerExtension(transform::TransformState &state,
+ int numOp, SmallVector<std::string>& registeredOps)
+ : Extension(state), numOp(numOp), registeredOps(registeredOps) {}
+
+ int getNumOp() { return numOp; }
+ void setNumOp(int num) { numOp = num; }
+ SmallVector<std::string> getRegisteredOps() { return registeredOps; }
+ void pushRegisteredOps(const std::string& newOp) { registeredOps.push_back(newOp); }
+ std::string printMessage() const {
+ std::string message = "Registered transformOps are: ";
+ for (const auto& op : registeredOps) {
+ message += op + " | ";
+ }
+ return message;
+ }
+
+private:
+ int numOp;
+ SmallVector<std::string> registeredOps;
+};
+
} // namespace test
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 1842fa158e75a9..36b142484bb04a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -148,6 +148,7 @@ void registerTestTensorCopyInsertionPass();
void registerTestTensorTransforms();
void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
+void registerTestPassStateExtensionCommunication();
void registerTestVectorLowerings();
void registerTestVectorReductionToSPIRVDotProd();
void registerTestWrittenToPass();
@@ -283,6 +284,7 @@ void registerTestPasses() {
mlir::test::registerTestTensorTransforms();
mlir::test::registerTestTopologicalSortAnalysisPass();
mlir::test::registerTestTransformDialectEraseSchedulePass();
+ mlir::test::registerTestPassStateExtensionCommunication();
mlir::test::registerTestVectorLowerings();
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestWrittenToPass();
|
@ftynse I've updated the MR per our discussion on the RFC. Would appreciate your comments. Thank you! |
✅ With the latest revision this PR passed the C/C++ code formatter. |
a7f7b24
to
a369381
Compare
…form framework and back to the pass.
9937bf0
to
8d60864
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/4/builds/2009 Here is the relevant piece of the build log for the reference
|
Hello @llvm-ci ,
Both : ninja check-flang as well as ninja check-mlir pass completely. I cannot reproduce your reported compilation error. Would you paste me the command to reproduce this error? Thank you! |
After rebuild, everything passed! https://lab.llvm.org/buildbot/#/builders/4/builds/2016 |
This is discussed in RFC:
https://discourse.llvm.org/t/rfc-making-the-constructor-of-the-transformstate-class-protected/80377