Skip to content

[MLIR][Transform] Allow stateInitializer and stateExporter for applyTransforms #101186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,13 @@ class TransformOptions {
/// will be executed following the internal logic of the operation. It must
/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
/// This function internally keeps track of the transformation state.
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
const TransformOptions &options = TransformOptions(),
bool enforceToplevelTransformOp = true);
LogicalResult applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
const TransformOptions &options = TransformOptions(),
bool enforceToplevelTransformOp = true,
function_ref<void(TransformState &)> stateInitializer = nullptr,
function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);

/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
Expand Down Expand Up @@ -215,9 +217,11 @@ class TransformState {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};

friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &,
const TransformOptions &, bool);
friend LogicalResult
applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &, const TransformOptions &,
bool, function_ref<void(TransformState &)>,
function_ref<LogicalResult(TransformState &)>);

friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,9 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
LogicalResult transform::applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
const TransformOptions &options, bool enforceToplevelTransformOp) {
const TransformOptions &options, bool enforceToplevelTransformOp,
function_ref<void(TransformState &)> stateInitializer,
function_ref<LogicalResult(TransformState &)> stateExporter) {
if (enforceToplevelTransformOp) {
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
Expand All @@ -2013,7 +2015,13 @@ LogicalResult transform::applyTransforms(

TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
return state.applyTransform(transform).checkAndReport();
if (stateInitializer)
stateInitializer(state);
if (state.applyTransform(transform).checkAndReport().failed())
return failure();
if (stateExporter)
return stateExporter(state);
return success();
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: mlir-opt %s -test-pass-state-extension-communication -verify-diagnostics | FileCheck %s

// CHECK: Printing opCollection before processing transform ops, size: 1
// CHECK: PASS-TRANSFORMOP-PASS

// CHECK: Printing opCollection after processing transform ops, size: 4
// CHECK: PASS-TRANSFORMOP-PASS transform.test_initializer_extension_A transform.test_initializer_extension_B transform.test_initializer_extension_C

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-remark @below {{Number of currently registered op: 1}}
transform.test_initializer_extension "A"
// expected-remark @below {{Number of currently registered op: 2}}
transform.test_initializer_extension "B"
// expected-remark @below {{Number of currently registered op: 3}}
transform.test_initializer_extension "C"
transform.yield
}
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -type
add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)

add_mlir_library(MLIRTestTransformDialect
TestPassStateExtensionCommunication.cpp
TestTransformDialectExtension.cpp
TestTransformDialectInterpreter.cpp
TestTransformStateExtension.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===- TestPassStateExtensionCommunication.cpp ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a test pass that showcases how communication can be
// conducted between a regular mlir pass and transform ops through the
// transform state extension stateInitializer and stateExporter mechanism.
//
//===----------------------------------------------------------------------===//

#include "TestTransformStateExtension.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace llvm;
using namespace mlir;
using namespace mlir::test;

namespace {
template <typename Derived>
class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};

struct TestPassStateExtensionCommunication
: public PassWrapper<TestPassStateExtensionCommunication,
OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestPassStateExtensionCommunication)

StringRef getArgument() const final {
return "test-pass-state-extension-communication";
}

StringRef getDescription() const final {
return "test state communciation between a mlir pass and transform ops";
}

static void printVector(const SmallVector<std::string> &opCollection,
const std::string &extraMessage = {}) {
outs() << "Printing opCollection" << extraMessage
<< ", size: " << opCollection.size() << "\n";
for (const auto &subVector : opCollection) {
outs() << subVector << " ";
}
outs() << "\n";
}

void runOnOperation() override {
ModuleOp module = getOperation();

// Create an opCollection vector.
SmallVector<std::string> opCollection = {"PASS-TRANSFORMOP-PASS "};
printVector(opCollection, " before processing transform ops");

auto stateInitializer =
[&opCollection](mlir::transform::TransformState &state) -> void {
TransformStateInitializerExtension *ext =
state.getExtension<TransformStateInitializerExtension>();
if (!ext)
state.addExtension<TransformStateInitializerExtension>(0, opCollection);
};

auto stateExporter =
[&opCollection](
mlir::transform::TransformState &state) -> LogicalResult {
TransformStateInitializerExtension *ext =
state.getExtension<TransformStateInitializerExtension>();
if (!ext) {
errs() << "Target transform state extension not found!\n";
return failure();
}
opCollection.clear();
opCollection = ext->getRegisteredOps();
return success();
};

// Process transform ops with stateInitializer and stateExporter.
for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
if (failed(transform::applyTransforms(
module, op, {}, mlir::transform::TransformOptions(), false,
stateInitializer, stateExporter)))
return signalPassFailure();

// Print the opCollection vector after processing transform ops.
printVector(opCollection, " after processing transform ops");
}
};
} // namespace

namespace mlir {
namespace test {
/// Registers the test pass here.
void registerTestPassStateExtensionCommunication() {
PassRegistration<TestPassStateExtensionCommunication> reg;
}
} // namespace test
} // namespace mlir
22 changes: 22 additions & 0 deletions mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,28 @@ void mlir::test::TestProduceInvalidIR::getEffects(
transform::modifiesPayload(effects);
}

DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
std::string opName =
this->getOperationName().str() + "_" + getTypeAttr().str();
TransformStateInitializerExtension *initExt =
state.getExtension<TransformStateInitializerExtension>();
if (!initExt) {
emitRemark() << "\nSpecified extension not found, adding a new one!\n";
SmallVector<std::string> opCollection = {opName};
state.addExtension<TransformStateInitializerExtension>(1, opCollection);
} else {
initExt->setNumOp(initExt->getNumOp() + 1);
initExt->pushRegisteredOps(opName);
InFlightDiagnostic diag = emitRemark()
<< "Number of currently registered op: "
<< initExt->getNumOp() << "\n"
<< initExt->printMessage() << "\n";
}
return DiagnosedSilenceableFailure::success();
}

namespace {
/// Test conversion pattern that replaces ops with the "replace_with_new_op"
/// attribute with "test.new_op".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,4 +549,13 @@ def TestProduceInvalidIR
}];
}

def TestInitializerExtensionOp
: Op<Transform_Dialect, "test_initializer_extension",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NoMemoryEffect]> {
let arguments = (ins StrAttr:$type);
let assemblyFormat = "$type attr-dict";
let cppNamespace = "::mlir::test";
}

#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
28 changes: 28 additions & 0 deletions mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,34 @@ class TestTransformStateExtension
private:
StringAttr message;
};

class TransformStateInitializerExtension
: public transform::TransformState::Extension {
public:
TransformStateInitializerExtension(transform::TransformState &state,
int numOp,
SmallVector<std::string> &registeredOps)
: Extension(state), numOp(numOp), registeredOps(registeredOps) {}

int getNumOp() { return numOp; }
void setNumOp(int num) { numOp = num; }
SmallVector<std::string> getRegisteredOps() { return registeredOps; }
void pushRegisteredOps(const std::string &newOp) {
registeredOps.push_back(newOp);
}
std::string printMessage() const {
std::string message = "Registered transformOps are: ";
for (const auto &op : registeredOps) {
message += op + " | ";
}
return message;
}

private:
int numOp;
SmallVector<std::string> registeredOps;
};

} // namespace test
} // namespace mlir

Expand Down
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ void registerTestTensorCopyInsertionPass();
void registerTestTensorTransforms();
void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
void registerTestPassStateExtensionCommunication();
void registerTestVectorLowerings();
void registerTestVectorReductionToSPIRVDotProd();
void registerTestWrittenToPass();
Expand Down Expand Up @@ -283,6 +284,7 @@ void registerTestPasses() {
mlir::test::registerTestTensorTransforms();
mlir::test::registerTestTopologicalSortAnalysisPass();
mlir::test::registerTestTransformDialectEraseSchedulePass();
mlir::test::registerTestPassStateExtensionCommunication();
mlir::test::registerTestVectorLowerings();
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestWrittenToPass();
Expand Down
Loading