Skip to content

Commit 91f1161

Browse files
authored
[mlir] expose transform interpreter to Python (#82365)
Transform interpreter functionality can be used standalone without going through the interpreter pass, make it available in Python.
1 parent 5db49f7 commit 91f1161

File tree

17 files changed

+506
-31
lines changed

17 files changed

+506
-31
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===-- mlir-c/Dialect/Transform/Interpreter.h --------------------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// C interface to the transform dialect interpreter.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir-c/IR.h"
15+
#include "mlir-c/Support.h"
16+
17+
#ifdef __cplusplus
18+
extern "C" {
19+
#endif
20+
21+
#define DEFINE_C_API_STRUCT(name, storage) \
22+
struct name { \
23+
storage *ptr; \
24+
}; \
25+
typedef struct name name
26+
27+
DEFINE_C_API_STRUCT(MlirTransformOptions, void);
28+
29+
#undef DEFINE_C_API_STRUCT
30+
31+
//----------------------------------------------------------------------------//
32+
// MlirTransformOptions
33+
//----------------------------------------------------------------------------//
34+
35+
/// Creates a default-initialized transform options object.
36+
MLIR_CAPI_EXPORTED MlirTransformOptions mlirTransformOptionsCreate(void);
37+
38+
/// Enables or disables expensive checks in transform options.
39+
MLIR_CAPI_EXPORTED void
40+
mlirTransformOptionsEnableExpensiveChecks(MlirTransformOptions transformOptions,
41+
bool enable);
42+
43+
/// Returns true if expensive checks are enabled in transform options.
44+
MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetExpensiveChecksEnabled(
45+
MlirTransformOptions transformOptions);
46+
47+
/// Enables or disables the enforcement of the top-level transform op being
48+
/// single in transform options.
49+
MLIR_CAPI_EXPORTED void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
50+
MlirTransformOptions transformOptions, bool enable);
51+
52+
/// Returns true if the enforcement of the top-level transform op being single
53+
/// is enabled in transform options.
54+
MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
55+
MlirTransformOptions transformOptions);
56+
57+
/// Destroys a transform options object previously created by
58+
/// mlirTransformOptionsCreate.
59+
MLIR_CAPI_EXPORTED void
60+
mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
61+
62+
//----------------------------------------------------------------------------//
63+
// Transform interpreter.
64+
//----------------------------------------------------------------------------//
65+
66+
/// Applies the transformation script starting at the given transform root
67+
/// operation to the given payload operation. The module containing the
68+
/// transform root as well as the transform options should be provided. The
69+
/// transform operation must implement TransformOpInterface and the module must
70+
/// be a ModuleOp. Returns the status of the application.
71+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
72+
MlirOperation payload, MlirOperation transformRoot,
73+
MlirOperation transformModule, MlirTransformOptions transformOptions);
74+
75+
#ifdef __cplusplus
76+
}
77+
#endif

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <pybind11/stl.h>
2424

2525
#include "mlir-c/Bindings/Python/Interop.h"
26+
#include "mlir-c/Diagnostics.h"
2627
#include "mlir-c/IR.h"
2728

2829
#include "llvm/ADT/Twine.h"
@@ -569,6 +570,41 @@ class mlir_value_subclass : public pure_subclass {
569570
};
570571

571572
} // namespace adaptors
573+
574+
/// RAII scope intercepting all diagnostics into a string. The message must be
575+
/// checked before this goes out of scope.
576+
class CollectDiagnosticsToStringScope {
577+
public:
578+
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
579+
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
580+
/*deleteUserData=*/nullptr);
581+
}
582+
~CollectDiagnosticsToStringScope() {
583+
assert(errorMessage.empty() && "unchecked error message");
584+
mlirContextDetachDiagnosticHandler(context, handlerID);
585+
}
586+
587+
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
588+
589+
private:
590+
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
591+
auto printer = +[](MlirStringRef message, void *data) {
592+
*static_cast<std::string *>(data) +=
593+
llvm::StringRef(message.data, message.length);
594+
};
595+
MlirLocation loc = mlirDiagnosticGetLocation(diag);
596+
*static_cast<std::string *>(data) += "at ";
597+
mlirLocationPrint(loc, printer, data);
598+
*static_cast<std::string *>(data) += ": ";
599+
mlirDiagnosticPrint(diag, printer, data);
600+
return mlirLogicalResultSuccess();
601+
}
602+
603+
MlirContext context;
604+
MlirDiagnosticHandlerID handlerID;
605+
std::string errorMessage = "";
606+
};
607+
572608
} // namespace python
573609
} // namespace mlir
574610

mlir/lib/Bindings/Python/DialectLLVM.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir-c/Diagnostics.h"
109
#include "mlir-c/Dialect/LLVM.h"
1110
#include "mlir-c/IR.h"
1211
#include "mlir-c/Support.h"
@@ -19,36 +18,6 @@ using namespace mlir;
1918
using namespace mlir::python;
2019
using namespace mlir::python::adaptors;
2120

22-
/// RAII scope intercepting all diagnostics into a string. The message must be
23-
/// checked before this goes out of scope.
24-
class CollectDiagnosticsToStringScope {
25-
public:
26-
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
27-
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
28-
/*deleteUserData=*/nullptr);
29-
}
30-
~CollectDiagnosticsToStringScope() {
31-
assert(errorMessage.empty() && "unchecked error message");
32-
mlirContextDetachDiagnosticHandler(context, handlerID);
33-
}
34-
35-
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
36-
37-
private:
38-
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
39-
auto printer = +[](MlirStringRef message, void *data) {
40-
*static_cast<std::string *>(data) +=
41-
StringRef(message.data, message.length);
42-
};
43-
mlirDiagnosticPrint(diag, printer, data);
44-
return mlirLogicalResultSuccess();
45-
}
46-
47-
MlirContext context;
48-
MlirDiagnosticHandlerID handlerID;
49-
std::string errorMessage = "";
50-
};
51-
5221
void populateDialectLLVMSubmodule(const pybind11::module &m) {
5322
auto llvmStructType =
5423
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,10 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
678678
mlirOperationWalk(op.getOperation(), invalidatingCallback,
679679
static_cast<void *>(&data), MlirWalkPreOrder);
680680
}
681+
void PyMlirContext::clearOperationsInside(MlirOperation op) {
682+
PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
683+
clearOperationsInside(opRef->getOperation());
684+
}
681685

682686
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
683687

@@ -2556,6 +2560,9 @@ void mlir::python::populateIRCore(py::module &m) {
25562560
.def("_get_live_operation_objects",
25572561
&PyMlirContext::getLiveOperationObjects)
25582562
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2563+
.def("_clear_live_operations_inside",
2564+
py::overload_cast<MlirOperation>(
2565+
&PyMlirContext::clearOperationsInside))
25592566
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
25602567
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
25612568
&PyMlirContext::getCapsule)

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ class PyMlirContext {
223223
/// Clears all operations nested inside the given op using
224224
/// `clearOperation(MlirOperation)`.
225225
void clearOperationsInside(PyOperationBase &op);
226+
void clearOperationsInside(MlirOperation op);
226227

227228
/// Gets the count of live modules associated with this context.
228229
/// Used for testing.
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//===- TransformInterpreter.cpp -------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Pybind classes for the transform dialect interpreter.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir-c/Dialect/Transform/Interpreter.h"
14+
#include "mlir-c/IR.h"
15+
#include "mlir-c/Support.h"
16+
#include "mlir/Bindings/Python/PybindAdaptors.h"
17+
18+
#include <pybind11/detail/common.h>
19+
#include <pybind11/pybind11.h>
20+
21+
namespace py = pybind11;
22+
23+
namespace {
24+
struct PyMlirTransformOptions {
25+
PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
26+
PyMlirTransformOptions(PyMlirTransformOptions &&other) {
27+
options = other.options;
28+
other.options.ptr = nullptr;
29+
}
30+
PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
31+
32+
~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
33+
34+
MlirTransformOptions options;
35+
};
36+
} // namespace
37+
38+
static void populateTransformInterpreterSubmodule(py::module &m) {
39+
py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local())
40+
.def(py::init())
41+
.def_property(
42+
"expensive_checks",
43+
[](const PyMlirTransformOptions &self) {
44+
return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
45+
},
46+
[](PyMlirTransformOptions &self, bool value) {
47+
mlirTransformOptionsEnableExpensiveChecks(self.options, value);
48+
})
49+
.def_property(
50+
"enforce_single_top_level_transform_op",
51+
[](const PyMlirTransformOptions &self) {
52+
return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
53+
self.options);
54+
},
55+
[](PyMlirTransformOptions &self, bool value) {
56+
mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
57+
value);
58+
});
59+
60+
m.def(
61+
"apply_named_sequence",
62+
[](MlirOperation payloadRoot, MlirOperation transformRoot,
63+
MlirOperation transformModule, const PyMlirTransformOptions &options) {
64+
mlir::python::CollectDiagnosticsToStringScope scope(
65+
mlirOperationGetContext(transformRoot));
66+
67+
// Calling back into Python to invalidate everything under the payload
68+
// root. This is awkward, but we don't have access to PyMlirContext
69+
// object here otherwise.
70+
py::object obj = py::cast(payloadRoot);
71+
obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
72+
73+
MlirLogicalResult result = mlirTransformApplyNamedSequence(
74+
payloadRoot, transformRoot, transformModule, options.options);
75+
if (mlirLogicalResultIsSuccess(result))
76+
return;
77+
78+
throw py::value_error(
79+
"Failed to apply named transform sequence.\nDiagnostic message " +
80+
scope.takeMessage());
81+
},
82+
py::arg("payload_root"), py::arg("transform_root"),
83+
py::arg("transform_module"),
84+
py::arg("transform_options") = PyMlirTransformOptions());
85+
}
86+
87+
PYBIND11_MODULE(_mlirTransformInterpreter, m) {
88+
m.doc() = "MLIR Transform dialect interpreter functionality.";
89+
populateTransformInterpreterSubmodule(m);
90+
}

mlir/lib/CAPI/Dialect/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITransformDialect
198198
MLIRTransformDialect
199199
)
200200

201+
add_mlir_upstream_c_api_library(MLIRCAPITransformDialectTransforms
202+
TransformInterpreter.cpp
203+
204+
PARTIAL_SOURCES_INTENDED
205+
LINK_LIBS PUBLIC
206+
MLIRCAPIIR
207+
MLIRTransformDialectTransforms
208+
)
209+
201210
add_mlir_upstream_c_api_library(MLIRCAPIQuant
202211
Quant.cpp
203212

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//===- TransformTransforms.cpp - C Interface for Transform dialect --------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// C interface to transforms for the transform dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir-c/Dialect/Transform/Interpreter.h"
14+
#include "mlir-c/Support.h"
15+
#include "mlir/CAPI/IR.h"
16+
#include "mlir/CAPI/Support.h"
17+
#include "mlir/CAPI/Wrap.h"
18+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19+
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
20+
21+
using namespace mlir;
22+
23+
DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions)
24+
25+
extern "C" {
26+
27+
MlirTransformOptions mlirTransformOptionsCreate() {
28+
return wrap(new transform::TransformOptions);
29+
}
30+
31+
void mlirTransformOptionsEnableExpensiveChecks(
32+
MlirTransformOptions transformOptions, bool enable) {
33+
unwrap(transformOptions)->enableExpensiveChecks(enable);
34+
}
35+
36+
bool mlirTransformOptionsGetExpensiveChecksEnabled(
37+
MlirTransformOptions transformOptions) {
38+
return unwrap(transformOptions)->getExpensiveChecksEnabled();
39+
}
40+
41+
void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
42+
MlirTransformOptions transformOptions, bool enable) {
43+
unwrap(transformOptions)->enableEnforceSingleToplevelTransformOp(enable);
44+
}
45+
46+
bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
47+
MlirTransformOptions transformOptions) {
48+
return unwrap(transformOptions)->getEnforceSingleToplevelTransformOp();
49+
}
50+
51+
void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions) {
52+
delete unwrap(transformOptions);
53+
}
54+
55+
MlirLogicalResult mlirTransformApplyNamedSequence(
56+
MlirOperation payload, MlirOperation transformRoot,
57+
MlirOperation transformModule, MlirTransformOptions transformOptions) {
58+
Operation *transformRootOp = unwrap(transformRoot);
59+
Operation *transformModuleOp = unwrap(transformModule);
60+
if (!isa<transform::TransformOpInterface>(transformRootOp)) {
61+
transformRootOp->emitError()
62+
<< "must implement TransformOpInterface to be used as transform root";
63+
return mlirLogicalResultFailure();
64+
}
65+
if (!isa<ModuleOp>(transformModuleOp)) {
66+
transformModuleOp->emitError()
67+
<< "must be a " << ModuleOp::getOperationName();
68+
return mlirLogicalResultFailure();
69+
}
70+
return wrap(transform::applyTransformNamedSequence(
71+
unwrap(payload), unwrap(transformRoot),
72+
cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
73+
}
74+
}

0 commit comments

Comments
 (0)