-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][py] Overload print with state. #72064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Enables reusing the AsmState when printing from Python. Also moves the fileObject and binary to the end (pybind11::object was resulting in the overload not working unless `state=` was specified).
@llvm/pr-subscribers-mlir Author: Jacques Pienaar (jpienaar) ChangesEnables reusing the AsmState when printing from Python. Also moves the fileObject and binary to the end (pybind11::object was resulting in the overload not working unless Full diff: https://github.com/llvm/llvm-project/pull/72064.diff 3 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0f2ca666ccc050e..a4330b062532763 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -110,6 +110,15 @@ static const char kOperationPrintDocstring[] =
invalid, behavior is undefined.
)";
+static const char kOperationPrintStateDocstring[] =
+ R"(Prints the assembly form of the operation to a file like object.
+
+Args:
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ state: AsmState capturing the operation numbering and flags.
+)";
+
static const char kOperationGetAsmDocstring[] =
R"(Gets the assembly form of the operation with all options available.
@@ -1169,11 +1178,11 @@ void PyOperation::checkValid() const {
}
}
-void PyOperationBase::print(py::object fileObject, bool binary,
- std::optional<int64_t> largeElementsLimit,
+void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified) {
+ bool assumeVerified, py::object fileObject,
+ bool binary) {
PyOperation &operation = getOperation();
operation.checkValid();
if (fileObject.is_none())
@@ -1198,6 +1207,17 @@ void PyOperationBase::print(py::object fileObject, bool binary,
mlirOpPrintingFlagsDestroy(flags);
}
+void PyOperationBase::print(PyAsmState &state, pybind11::object fileObject,
+ bool binary) {
+ PyOperation &operation = getOperation();
+ operation.checkValid();
+ if (fileObject.is_none())
+ fileObject = py::module::import("sys").attr("stdout");
+ PyFileAccumulator accum(fileObject, binary);
+ mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
+ accum.getUserData());
+}
+
void PyOperationBase::writeBytecode(const py::object &fileObject,
std::optional<int64_t> bytecodeVersion) {
PyOperation &operation = getOperation();
@@ -1230,13 +1250,14 @@ py::object PyOperationBase::getAsm(bool binary,
} else {
fileObject = py::module::import("io").attr("StringIO")();
}
- print(fileObject, /*binary=*/binary,
- /*largeElementsLimit=*/largeElementsLimit,
+ print(/*largeElementsLimit=*/largeElementsLimit,
/*enableDebugInfo=*/enableDebugInfo,
/*prettyDebugInfo=*/prettyDebugInfo,
/*printGenericOpForm=*/printGenericOpForm,
/*useLocalScope=*/useLocalScope,
- /*assumeVerified=*/assumeVerified);
+ /*assumeVerified=*/assumeVerified,
+ /*fileObject=*/fileObject,
+ /*binary=*/binary);
return fileObject.attr("getvalue")();
}
@@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) {
/*assumeVerified=*/false);
},
"Returns the assembly form of the operation.")
- .def("print", &PyOperationBase::print,
+ .def("print",
+ py::overload_cast<PyAsmState &, pybind11::object, bool>(
+ &PyOperationBase::print),
+ py::arg("state"), py::arg("file") = py::none(),
+ py::arg("binary") = false, kOperationPrintStateDocstring)
+ .def("print",
+ py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
+ bool, pybind11::object, bool>(
+ &PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
- py::arg("file") = py::none(), py::arg("binary") = false,
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
py::arg("pretty_debug_info") = false,
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
- py::arg("assume_verified") = false, kOperationPrintDocstring)
+ py::arg("assume_verified") = false, py::arg("file") = py::none(),
+ py::arg("binary") = false, kOperationPrintDocstring)
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
py::arg("desired_version") = py::none(),
kOperationPrintBytecodeDocstring)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index af55693f18fbbf9..3f856681881829a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -550,16 +550,19 @@ class PyModule : public BaseContextObject {
pybind11::handle handle;
};
+class PyAsmState;
+
/// Base class for PyOperation and PyOpView which exposes the primary, user
/// visible methods for manipulating it.
class PyOperationBase {
public:
virtual ~PyOperationBase() = default;
/// Implements the bound 'print' method and helps with others.
- void print(pybind11::object fileObject, bool binary,
- std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+ void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified);
+ bool assumeVerified, pybind11::object fileObject, bool binary);
+ void print(PyAsmState &state, pybind11::object fileObject, bool binary);
+
pybind11::object getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04239b048c1c641..04f8a9936e31f79 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -622,10 +622,15 @@ def testOperationPrint():
print(bytes_value.__class__)
print(bytes_value)
- # Test get_asm local_scope.
+ # Test print local_scope.
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
module.operation.print(enable_debug_info=True, use_local_scope=True)
+ # Test printing using state.
+ state = AsmState(module.operation)
+ # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ module.operation.print(state)
+
# Test get_asm with options.
# CHECK: value = dense_resource<__elided__> : tensor<4xi32>
# CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I'm not sure that the test actually tests the intent (reusing AsmState
) but I don't think it matters (it sufficiently exercises the code).
@@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) { | |||
/*assumeVerified=*/false); | |||
}, | |||
"Returns the assembly form of the operation.") | |||
.def("print", &PyOperationBase::print, | |||
.def("print", | |||
py::overload_cast<PyAsmState &, pybind11::object, bool>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
py::overload_cast<PyAsmState &, pybind11::object, bool>( | |
py::overload_cast<PyAsmState &, py::object, bool>( |
Co-authored-by: Maksim Levental <[email protected]>
Co-authored-by: Maksim Levental <[email protected]>
Co-authored-by: Maksim Levental <[email protected]>
You can test this locally with the following command:git-clang-format --diff c42d006f05a03cecc0417d77b3bfb8f936d7594a da8c0492a83040c1a3f569ef11b962e45880b491 -- mlir/lib/Bindings/Python/IRCore.cpp mlir/lib/Bindings/Python/IRModule.h View the diff from clang-format here.diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d4e947a61a..0493a7ba07 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2974,8 +2974,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool>(
- &PyOperationBase::print),
+ bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
|
You can test this locally with the following command:git-clang-format --diff c42d006f05a03cecc0417d77b3bfb8f936d7594a e9017419ab5bd8bdb89b455e403faf82586e48ae -- mlir/lib/Bindings/Python/IRCore.cpp mlir/lib/Bindings/Python/IRModule.h View the diff from clang-format here.diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d4e947a61a..0493a7ba07 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2974,8 +2974,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool>(
- &PyOperationBase::print),
+ bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
|
Thanks. |
Enables reusing the AsmState when printing from Python. Also moves the fileObject and binary to the end (pybind11::object was resulting in the overload not working unless `state=` was specified). --------- Co-authored-by: Maksim Levental <[email protected]>
Enables reusing the AsmState when printing from Python. Also moves the fileObject and binary to the end (pybind11::object was resulting in the overload not working unless
state=
was specified).