Skip to content

[mlir] Python: write bytecode to a file path #127118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 25, 2025

Conversation

nikalra
Copy link
Contributor

@nikalra nikalra commented Feb 13, 2025

The current write_bytecode implementation necessarily requires the serialized module to be duplicated in memory when the python bytes object is created and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary.

The current `write_bytecode` implementation necessarily requires the serialized module to be duplicated in memory when the python `bytes` object is created and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary.
@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2025

@llvm/pr-subscribers-mlir

Author: Nikhil Kalra (nikalra)

Changes

The current write_bytecode implementation necessarily requires the serialized module to be duplicated in memory when the python bytes object is created and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary.


Full diff: https://github.com/llvm/llvm-project/pull/127118.diff

4 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+31-9)
  • (modified) mlir/lib/Bindings/Python/NanobindUtils.h (+21-1)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+3-3)
  • (modified) mlir/test/python/ir/operation.py (+5)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 827db5f3eba84..fbe54f8d81cf0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
+#include <system_error>
 #include <utility>
 
 #include "Globals.h"
@@ -20,8 +21,10 @@
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/nanobind.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -1329,20 +1332,18 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
                               accum.getUserData());
 }
 
-void PyOperationBase::writeBytecode(const nb::object &fileObject,
-                                    std::optional<int64_t> bytecodeVersion) {
-  PyOperation &operation = getOperation();
-  operation.checkValid();
-  PyFileAccumulator accum(fileObject, /*binary=*/true);
-
+template <typename T>
+static void
+writeBytecodeForOperation(T &accumulator, MlirOperation operation,
+                          const std::optional<int64_t> &bytecodeVersion) {
   if (!bytecodeVersion.has_value())
-    return mlirOperationWriteBytecode(operation, accum.getCallback(),
-                                      accum.getUserData());
+    return mlirOperationWriteBytecode(operation, accumulator.getCallback(),
+                                      accumulator.getUserData());
 
   MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
   mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
   MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
-      operation, config, accum.getCallback(), accum.getUserData());
+      operation, config, accumulator.getCallback(), accumulator.getUserData());
   mlirBytecodeWriterConfigDestroy(config);
   if (mlirLogicalResultIsFailure(res))
     throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
@@ -1351,6 +1352,27 @@ void PyOperationBase::writeBytecode(const nb::object &fileObject,
                               .c_str());
 }
 
+void PyOperationBase::writeBytecode(const nb::object &fileObject,
+                                    std::optional<int64_t> bytecodeVersion) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+
+  std::string filePath;
+  if (nb::try_cast<std::string>(fileObject, filePath)) {
+    std::error_code ec;
+    llvm::raw_fd_ostream ostream(filePath, ec);
+    if (ec) {
+      throw nb::value_error("Unable to open file for writing");
+    }
+
+    OstreamAccumulator accum(ostream);
+    writeBytecodeForOperation(accum, operation, bytecodeVersion);
+  } else {
+    PyFileAccumulator accum(fileObject, /*binary=*/true);
+    writeBytecodeForOperation(accum, operation, bytecodeVersion);
+  }
+}
+
 void PyOperationBase::walk(
     std::function<MlirWalkResult(MlirOperation)> callback,
     MlirWalkOrder walkOrder) {
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index ee193cf9f8ef8..ca9aa064219cd 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -13,8 +13,10 @@
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/DataTypes.h"
+#include "llvm/Support/raw_ostream.h"
 
 template <>
 struct std::iterator_traits<nanobind::detail::fast_iterator> {
@@ -128,7 +130,7 @@ struct PyPrintAccumulator {
   }
 };
 
-/// Accumulates int a python file-like object, either writing text (default)
+/// Accumulates into a python file-like object, either writing text (default)
 /// or binary.
 class PyFileAccumulator {
 public:
@@ -158,6 +160,24 @@ class PyFileAccumulator {
   bool binary;
 };
 
+/// Accumulates into a LLVM ostream.
+class OstreamAccumulator {
+public:
+  OstreamAccumulator(llvm::raw_ostream &ostream) : ostream(ostream) {}
+
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](MlirStringRef part, void *userData) {
+      OstreamAccumulator *accum = static_cast<OstreamAccumulator *>(userData);
+      accum->ostream << llvm::StringRef(part.data, part.length);
+    };
+  }
+
+private:
+  llvm::raw_ostream &ostream;
+};
+
 /// Accumulates into a python string from a method that is expected to make
 /// one (no more, no less) call to the callback (asserts internally on
 /// violation).
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index ab975a6954044..c93de2fe3154e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -47,7 +47,7 @@ import collections
 from collections.abc import Callable, Sequence
 import io
 from pathlib import Path
-from typing import Any, ClassVar, TypeVar, overload
+from typing import Any, BinaryIO, ClassVar, TypeVar, overload
 
 __all__ = [
     "AffineAddExpr",
@@ -285,12 +285,12 @@ class _OperationBase:
         """
         Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
         """
-    def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
+    def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
         """
         Write the bytecode form of the operation to a file like object.
 
         Args:
-          file: The file like object to write to.
+          file: The file like object or path to write to.
           desired_version: The version of bytecode to emit.
         Returns:
           The bytecode writer status.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c2d3aed8808b4..43836abb74f5e 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -3,6 +3,7 @@
 import gc
 import io
 import itertools
+from tempfile import NamedTemporaryFile
 from mlir.ir import *
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import arith
@@ -617,6 +618,10 @@ def testOperationPrint():
     module.operation.write_bytecode(bytecode_stream, desired_version=1)
     bytecode = bytecode_stream.getvalue()
     assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+    with NamedTemporaryFile() as tmpfile:
+        module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
+        tmpfile.seek(0)
+        assert tmpfile.read().startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
     ctx2 = Context()
     module_roundtrip = Module.parse(bytecode, ctx2)
     f = io.StringIO()

Copy link

github-actions bot commented Feb 13, 2025

✅ With the latest revision this PR passed the Python code formatter.

Copy link

github-actions bot commented Feb 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

writeBytecodeForOperation(accum, operation, bytecodeVersion);
} else {
PyFileAccumulator accum(fileObject, /*binary=*/true);
writeBytecodeForOperation(accum, operation, bytecodeVersion);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this logic be something we could just move to the PyFileAccumulator itself so that every usage benefits from this optimization instead of specializing it everywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure -- done!

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

std::error_code ec;
writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec);
if (ec) {
throw nanobind::value_error("Unable to open file for writing");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if the ec can be used to provide a better error message? (like can we know if it is "permission denied" or "no space left on device" or...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good point! Added the ec message to the Python exception

@nikalra nikalra merged commit a60e8a2 into llvm:main Feb 25, 2025
11 checks passed
@nikalra nikalra deleted the write-bytecode-file branch February 25, 2025 01:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants