Skip to content

[mlir][py] Overload print with state. #72064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 13, 2023
Merged

[mlir][py] Overload print with state. #72064

merged 4 commits into from
Nov 13, 2023

Conversation

jpienaar
Copy link
Member

Enables reusing the AsmState when printing from Python. Also moves the fileObject and binary to the end (pybind11::object was resulting in the overload not working unless state= was specified).

Enables reusing the AsmState when printing from Python. Also moves the
fileObject and binary to the end (pybind11::object was resulting in the
overload not working unless `state=` was specified).
@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2023

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

Changes

Enables reusing the AsmState when printing from Python. Also moves the fileObject and binary to the end (pybind11::object was resulting in the overload not working unless state= was specified).


Full diff: https://github.com/llvm/llvm-project/pull/72064.diff

3 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+38-9)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+6-3)
  • (modified) mlir/test/python/ir/operation.py (+6-1)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0f2ca666ccc050e..a4330b062532763 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -110,6 +110,15 @@ static const char kOperationPrintDocstring[] =
     invalid, behavior is undefined.
 )";
 
+static const char kOperationPrintStateDocstring[] =
+    R"(Prints the assembly form of the operation to a file like object.
+
+Args:
+  file: The file like object to write to. Defaults to sys.stdout.
+  binary: Whether to write bytes (True) or str (False). Defaults to False.
+  state: AsmState capturing the operation numbering and flags.
+)";
+
 static const char kOperationGetAsmDocstring[] =
     R"(Gets the assembly form of the operation with all options available.
 
@@ -1169,11 +1178,11 @@ void PyOperation::checkValid() const {
   }
 }
 
-void PyOperationBase::print(py::object fileObject, bool binary,
-                            std::optional<int64_t> largeElementsLimit,
+void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
                             bool enableDebugInfo, bool prettyDebugInfo,
                             bool printGenericOpForm, bool useLocalScope,
-                            bool assumeVerified) {
+                            bool assumeVerified, py::object fileObject,
+                            bool binary) {
   PyOperation &operation = getOperation();
   operation.checkValid();
   if (fileObject.is_none())
@@ -1198,6 +1207,17 @@ void PyOperationBase::print(py::object fileObject, bool binary,
   mlirOpPrintingFlagsDestroy(flags);
 }
 
+void PyOperationBase::print(PyAsmState &state, pybind11::object fileObject,
+                            bool binary) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+  if (fileObject.is_none())
+    fileObject = py::module::import("sys").attr("stdout");
+  PyFileAccumulator accum(fileObject, binary);
+  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
+                              accum.getUserData());
+}
+
 void PyOperationBase::writeBytecode(const py::object &fileObject,
                                     std::optional<int64_t> bytecodeVersion) {
   PyOperation &operation = getOperation();
@@ -1230,13 +1250,14 @@ py::object PyOperationBase::getAsm(bool binary,
   } else {
     fileObject = py::module::import("io").attr("StringIO")();
   }
-  print(fileObject, /*binary=*/binary,
-        /*largeElementsLimit=*/largeElementsLimit,
+  print(/*largeElementsLimit=*/largeElementsLimit,
         /*enableDebugInfo=*/enableDebugInfo,
         /*prettyDebugInfo=*/prettyDebugInfo,
         /*printGenericOpForm=*/printGenericOpForm,
         /*useLocalScope=*/useLocalScope,
-        /*assumeVerified=*/assumeVerified);
+        /*assumeVerified=*/assumeVerified,
+        /*fileObject=*/fileObject,
+        /*binary=*/binary);
 
   return fileObject.attr("getvalue")();
 }
@@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) {
                                /*assumeVerified=*/false);
           },
           "Returns the assembly form of the operation.")
-      .def("print", &PyOperationBase::print,
+      .def("print",
+           py::overload_cast<PyAsmState &, pybind11::object, bool>(
+               &PyOperationBase::print),
+           py::arg("state"), py::arg("file") = py::none(),
+           py::arg("binary") = false, kOperationPrintStateDocstring)
+      .def("print",
+           py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
+                             bool, pybind11::object, bool>(
+               &PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
-           py::arg("file") = py::none(), py::arg("binary") = false,
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
            py::arg("pretty_debug_info") = false,
            py::arg("print_generic_op_form") = false,
            py::arg("use_local_scope") = false,
-           py::arg("assume_verified") = false, kOperationPrintDocstring)
+           py::arg("assume_verified") = false, py::arg("file") = py::none(),
+           py::arg("binary") = false, kOperationPrintDocstring)
       .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
            py::arg("desired_version") = py::none(),
            kOperationPrintBytecodeDocstring)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index af55693f18fbbf9..3f856681881829a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -550,16 +550,19 @@ class PyModule : public BaseContextObject {
   pybind11::handle handle;
 };
 
+class PyAsmState;
+
 /// Base class for PyOperation and PyOpView which exposes the primary, user
 /// visible methods for manipulating it.
 class PyOperationBase {
 public:
   virtual ~PyOperationBase() = default;
   /// Implements the bound 'print' method and helps with others.
-  void print(pybind11::object fileObject, bool binary,
-             std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+  void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
-             bool assumeVerified);
+             bool assumeVerified, pybind11::object fileObject, bool binary);
+  void print(PyAsmState &state, pybind11::object fileObject, bool binary);
+
   pybind11::object getAsm(bool binary,
                           std::optional<int64_t> largeElementsLimit,
                           bool enableDebugInfo, bool prettyDebugInfo,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04239b048c1c641..04f8a9936e31f79 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -622,10 +622,15 @@ def testOperationPrint():
     print(bytes_value.__class__)
     print(bytes_value)
 
-    # Test get_asm local_scope.
+    # Test print local_scope.
     # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
     module.operation.print(enable_debug_info=True, use_local_scope=True)
 
+    # Test printing using state.
+    state = AsmState(module.operation)
+    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+    module.operation.print(state)
+
     # Test get_asm with options.
     # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
     # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I'm not sure that the test actually tests the intent (reusing AsmState) but I don't think it matters (it sufficiently exercises the code).

@@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) {
/*assumeVerified=*/false);
},
"Returns the assembly form of the operation.")
.def("print", &PyOperationBase::print,
.def("print",
py::overload_cast<PyAsmState &, pybind11::object, bool>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
py::overload_cast<PyAsmState &, pybind11::object, bool>(
py::overload_cast<PyAsmState &, py::object, bool>(

@jpienaar jpienaar merged commit 204acc5 into llvm:main Nov 13, 2023
Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff c42d006f05a03cecc0417d77b3bfb8f936d7594a da8c0492a83040c1a3f569ef11b962e45880b491 -- mlir/lib/Bindings/Python/IRCore.cpp mlir/lib/Bindings/Python/IRModule.h
View the diff from clang-format here.
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d4e947a61a..0493a7ba07 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2974,8 +2974,7 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(
-               &PyOperationBase::print),
+                             bool, py::object, bool>(&PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff c42d006f05a03cecc0417d77b3bfb8f936d7594a e9017419ab5bd8bdb89b455e403faf82586e48ae -- mlir/lib/Bindings/Python/IRCore.cpp mlir/lib/Bindings/Python/IRModule.h
View the diff from clang-format here.
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d4e947a61a..0493a7ba07 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2974,8 +2974,7 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(
-               &PyOperationBase::print),
+                             bool, py::object, bool>(&PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,

@stellaraccident
Copy link
Contributor

Thanks.

zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
Enables reusing the AsmState when printing from Python. Also moves the
fileObject and binary to the end (pybind11::object was resulting in the
overload not working unless `state=` was specified).

---------

Co-authored-by: Maksim Levental <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants