Skip to content

Commit 6634d44

Browse files
authored
[MLIR][Transform] Allow stateInitializer and stateExporter for applyTransforms (#101186)
This is discussed in RFC: https://discourse.llvm.org/t/rfc-making-the-constructor-of-the-transformstate-class-protected/80377
1 parent 111932d commit 6634d44

File tree

9 files changed

+204
-10
lines changed

9 files changed

+204
-10
lines changed

mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,13 @@ class TransformOptions {
131131
/// will be executed following the internal logic of the operation. It must
132132
/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
133133
/// This function internally keeps track of the transformation state.
134-
LogicalResult
135-
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
136-
const RaggedArray<MappedValue> &extraMapping = {},
137-
const TransformOptions &options = TransformOptions(),
138-
bool enforceToplevelTransformOp = true);
134+
LogicalResult applyTransforms(
135+
Operation *payloadRoot, TransformOpInterface transform,
136+
const RaggedArray<MappedValue> &extraMapping = {},
137+
const TransformOptions &options = TransformOptions(),
138+
bool enforceToplevelTransformOp = true,
139+
function_ref<void(TransformState &)> stateInitializer = nullptr,
140+
function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);
139141

140142
/// The state maintained across applications of various ops implementing the
141143
/// TransformOpInterface. The operations implementing this interface and the
@@ -215,9 +217,11 @@ class TransformState {
215217
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
216218
};
217219

218-
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
219-
const RaggedArray<MappedValue> &,
220-
const TransformOptions &, bool);
220+
friend LogicalResult
221+
applyTransforms(Operation *, TransformOpInterface,
222+
const RaggedArray<MappedValue> &, const TransformOptions &,
223+
bool, function_ref<void(TransformState &)>,
224+
function_ref<LogicalResult(TransformState &)>);
221225

222226
friend TransformState
223227
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);

mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1999,7 +1999,9 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
19991999
LogicalResult transform::applyTransforms(
20002000
Operation *payloadRoot, TransformOpInterface transform,
20012001
const RaggedArray<MappedValue> &extraMapping,
2002-
const TransformOptions &options, bool enforceToplevelTransformOp) {
2002+
const TransformOptions &options, bool enforceToplevelTransformOp,
2003+
function_ref<void(TransformState &)> stateInitializer,
2004+
function_ref<LogicalResult(TransformState &)> stateExporter) {
20032005
if (enforceToplevelTransformOp) {
20042006
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
20052007
transform->getNumOperands() != 0) {
@@ -2013,7 +2015,13 @@ LogicalResult transform::applyTransforms(
20132015

20142016
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
20152017
options);
2016-
return state.applyTransform(transform).checkAndReport();
2018+
if (stateInitializer)
2019+
stateInitializer(state);
2020+
if (state.applyTransform(transform).checkAndReport().failed())
2021+
return failure();
2022+
if (stateExporter)
2023+
return stateExporter(state);
2024+
return success();
20172025
}
20182026

20192027
//===----------------------------------------------------------------------===//
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-opt %s -test-pass-state-extension-communication -verify-diagnostics | FileCheck %s
2+
3+
// CHECK: Printing opCollection before processing transform ops, size: 1
4+
// CHECK: PASS-TRANSFORMOP-PASS
5+
6+
// CHECK: Printing opCollection after processing transform ops, size: 4
7+
// CHECK: PASS-TRANSFORMOP-PASS transform.test_initializer_extension_A transform.test_initializer_extension_B transform.test_initializer_extension_C
8+
9+
module attributes {transform.with_named_sequence} {
10+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
11+
// expected-remark @below {{Number of currently registered op: 1}}
12+
transform.test_initializer_extension "A"
13+
// expected-remark @below {{Number of currently registered op: 2}}
14+
transform.test_initializer_extension "B"
15+
// expected-remark @below {{Number of currently registered op: 3}}
16+
transform.test_initializer_extension "C"
17+
transform.yield
18+
}
19+
}

mlir/test/lib/Dialect/Transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -type
66
add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
77

88
add_mlir_library(MLIRTestTransformDialect
9+
TestPassStateExtensionCommunication.cpp
910
TestTransformDialectExtension.cpp
1011
TestTransformDialectInterpreter.cpp
1112
TestTransformStateExtension.cpp
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//===- TestPassStateExtensionCommunication.cpp ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines a test pass that showcases how communication can be
10+
// conducted between a regular mlir pass and transform ops through the
11+
// transform state extension stateInitializer and stateExporter mechanism.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "TestTransformStateExtension.h"
16+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17+
#include "mlir/IR/BuiltinOps.h"
18+
#include "mlir/Pass/Pass.h"
19+
20+
using namespace llvm;
21+
using namespace mlir;
22+
using namespace mlir::test;
23+
24+
namespace {
25+
template <typename Derived>
26+
class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};
27+
28+
struct TestPassStateExtensionCommunication
29+
: public PassWrapper<TestPassStateExtensionCommunication,
30+
OperationPass<ModuleOp>> {
31+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
32+
TestPassStateExtensionCommunication)
33+
34+
StringRef getArgument() const final {
35+
return "test-pass-state-extension-communication";
36+
}
37+
38+
StringRef getDescription() const final {
39+
return "test state communciation between a mlir pass and transform ops";
40+
}
41+
42+
static void printVector(const SmallVector<std::string> &opCollection,
43+
const std::string &extraMessage = {}) {
44+
outs() << "Printing opCollection" << extraMessage
45+
<< ", size: " << opCollection.size() << "\n";
46+
for (const auto &subVector : opCollection) {
47+
outs() << subVector << " ";
48+
}
49+
outs() << "\n";
50+
}
51+
52+
void runOnOperation() override {
53+
ModuleOp module = getOperation();
54+
55+
// Create an opCollection vector.
56+
SmallVector<std::string> opCollection = {"PASS-TRANSFORMOP-PASS "};
57+
printVector(opCollection, " before processing transform ops");
58+
59+
auto stateInitializer =
60+
[&opCollection](mlir::transform::TransformState &state) -> void {
61+
TransformStateInitializerExtension *ext =
62+
state.getExtension<TransformStateInitializerExtension>();
63+
if (!ext)
64+
state.addExtension<TransformStateInitializerExtension>(0, opCollection);
65+
};
66+
67+
auto stateExporter =
68+
[&opCollection](
69+
mlir::transform::TransformState &state) -> LogicalResult {
70+
TransformStateInitializerExtension *ext =
71+
state.getExtension<TransformStateInitializerExtension>();
72+
if (!ext) {
73+
errs() << "Target transform state extension not found!\n";
74+
return failure();
75+
}
76+
opCollection.clear();
77+
opCollection = ext->getRegisteredOps();
78+
return success();
79+
};
80+
81+
// Process transform ops with stateInitializer and stateExporter.
82+
for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
83+
if (failed(transform::applyTransforms(
84+
module, op, {}, mlir::transform::TransformOptions(), false,
85+
stateInitializer, stateExporter)))
86+
return signalPassFailure();
87+
88+
// Print the opCollection vector after processing transform ops.
89+
printVector(opCollection, " after processing transform ops");
90+
}
91+
};
92+
} // namespace
93+
94+
namespace mlir {
95+
namespace test {
96+
/// Registers the test pass here.
97+
void registerTestPassStateExtensionCommunication() {
98+
PassRegistration<TestPassStateExtensionCommunication> reg;
99+
}
100+
} // namespace test
101+
} // namespace mlir

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,28 @@ void mlir::test::TestProduceInvalidIR::getEffects(
804804
transform::modifiesPayload(effects);
805805
}
806806

807+
DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
808+
transform::TransformRewriter &rewriter,
809+
transform::TransformResults &results, transform::TransformState &state) {
810+
std::string opName =
811+
this->getOperationName().str() + "_" + getTypeAttr().str();
812+
TransformStateInitializerExtension *initExt =
813+
state.getExtension<TransformStateInitializerExtension>();
814+
if (!initExt) {
815+
emitRemark() << "\nSpecified extension not found, adding a new one!\n";
816+
SmallVector<std::string> opCollection = {opName};
817+
state.addExtension<TransformStateInitializerExtension>(1, opCollection);
818+
} else {
819+
initExt->setNumOp(initExt->getNumOp() + 1);
820+
initExt->pushRegisteredOps(opName);
821+
InFlightDiagnostic diag = emitRemark()
822+
<< "Number of currently registered op: "
823+
<< initExt->getNumOp() << "\n"
824+
<< initExt->printMessage() << "\n";
825+
}
826+
return DiagnosedSilenceableFailure::success();
827+
}
828+
807829
namespace {
808830
/// Test conversion pattern that replaces ops with the "replace_with_new_op"
809831
/// attribute with "test.new_op".

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,4 +549,13 @@ def TestProduceInvalidIR
549549
}];
550550
}
551551

552+
def TestInitializerExtensionOp
553+
: Op<Transform_Dialect, "test_initializer_extension",
554+
[DeclareOpInterfaceMethods<TransformOpInterface>,
555+
NoMemoryEffect]> {
556+
let arguments = (ins StrAttr:$type);
557+
let assemblyFormat = "$type attr-dict";
558+
let cppNamespace = "::mlir::test";
559+
}
560+
552561
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD

mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,34 @@ class TestTransformStateExtension
3434
private:
3535
StringAttr message;
3636
};
37+
38+
class TransformStateInitializerExtension
39+
: public transform::TransformState::Extension {
40+
public:
41+
TransformStateInitializerExtension(transform::TransformState &state,
42+
int numOp,
43+
SmallVector<std::string> &registeredOps)
44+
: Extension(state), numOp(numOp), registeredOps(registeredOps) {}
45+
46+
int getNumOp() { return numOp; }
47+
void setNumOp(int num) { numOp = num; }
48+
SmallVector<std::string> getRegisteredOps() { return registeredOps; }
49+
void pushRegisteredOps(const std::string &newOp) {
50+
registeredOps.push_back(newOp);
51+
}
52+
std::string printMessage() const {
53+
std::string message = "Registered transformOps are: ";
54+
for (const auto &op : registeredOps) {
55+
message += op + " | ";
56+
}
57+
return message;
58+
}
59+
60+
private:
61+
int numOp;
62+
SmallVector<std::string> registeredOps;
63+
};
64+
3765
} // namespace test
3866
} // namespace mlir
3967

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ void registerTestTensorCopyInsertionPass();
148148
void registerTestTensorTransforms();
149149
void registerTestTopologicalSortAnalysisPass();
150150
void registerTestTransformDialectEraseSchedulePass();
151+
void registerTestPassStateExtensionCommunication();
151152
void registerTestVectorLowerings();
152153
void registerTestVectorReductionToSPIRVDotProd();
153154
void registerTestWrittenToPass();
@@ -283,6 +284,7 @@ void registerTestPasses() {
283284
mlir::test::registerTestTensorTransforms();
284285
mlir::test::registerTestTopologicalSortAnalysisPass();
285286
mlir::test::registerTestTransformDialectEraseSchedulePass();
287+
mlir::test::registerTestPassStateExtensionCommunication();
286288
mlir::test::registerTestVectorLowerings();
287289
mlir::test::registerTestVectorReductionToSPIRVDotProd();
288290
mlir::test::registerTestWrittenToPass();

0 commit comments

Comments
 (0)