-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Python: write bytecode to a file path #127118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The current `write_bytecode` implementation necessarily requires the serialized module to be duplicated in memory when the python `bytes` object is created and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary.
@llvm/pr-subscribers-mlir Author: Nikhil Kalra (nikalra) ChangesThe current Full diff: https://github.com/llvm/llvm-project/pull/127118.diff 4 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 827db5f3eba84..fbe54f8d81cf0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include <optional>
+#include <system_error>
#include <utility>
#include "Globals.h"
@@ -20,8 +21,10 @@
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
namespace nb = nanobind;
using namespace nb::literals;
@@ -1329,20 +1332,18 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
accum.getUserData());
}
-void PyOperationBase::writeBytecode(const nb::object &fileObject,
- std::optional<int64_t> bytecodeVersion) {
- PyOperation &operation = getOperation();
- operation.checkValid();
- PyFileAccumulator accum(fileObject, /*binary=*/true);
-
+template <typename T>
+static void
+writeBytecodeForOperation(T &accumulator, MlirOperation operation,
+ const std::optional<int64_t> &bytecodeVersion) {
if (!bytecodeVersion.has_value())
- return mlirOperationWriteBytecode(operation, accum.getCallback(),
- accum.getUserData());
+ return mlirOperationWriteBytecode(operation, accumulator.getCallback(),
+ accumulator.getUserData());
MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
- operation, config, accum.getCallback(), accum.getUserData());
+ operation, config, accumulator.getCallback(), accumulator.getUserData());
mlirBytecodeWriterConfigDestroy(config);
if (mlirLogicalResultIsFailure(res))
throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
@@ -1351,6 +1352,27 @@ void PyOperationBase::writeBytecode(const nb::object &fileObject,
.c_str());
}
+void PyOperationBase::writeBytecode(const nb::object &fileObject,
+ std::optional<int64_t> bytecodeVersion) {
+ PyOperation &operation = getOperation();
+ operation.checkValid();
+
+ std::string filePath;
+ if (nb::try_cast<std::string>(fileObject, filePath)) {
+ std::error_code ec;
+ llvm::raw_fd_ostream ostream(filePath, ec);
+ if (ec) {
+ throw nb::value_error("Unable to open file for writing");
+ }
+
+ OstreamAccumulator accum(ostream);
+ writeBytecodeForOperation(accum, operation, bytecodeVersion);
+ } else {
+ PyFileAccumulator accum(fileObject, /*binary=*/true);
+ writeBytecodeForOperation(accum, operation, bytecodeVersion);
+ }
+}
+
void PyOperationBase::walk(
std::function<MlirWalkResult(MlirOperation)> callback,
MlirWalkOrder walkOrder) {
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index ee193cf9f8ef8..ca9aa064219cd 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -13,8 +13,10 @@
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"
+#include "llvm/Support/raw_ostream.h"
template <>
struct std::iterator_traits<nanobind::detail::fast_iterator> {
@@ -128,7 +130,7 @@ struct PyPrintAccumulator {
}
};
-/// Accumulates int a python file-like object, either writing text (default)
+/// Accumulates into a python file-like object, either writing text (default)
/// or binary.
class PyFileAccumulator {
public:
@@ -158,6 +160,24 @@ class PyFileAccumulator {
bool binary;
};
+/// Accumulates into a LLVM ostream.
+class OstreamAccumulator {
+public:
+ OstreamAccumulator(llvm::raw_ostream &ostream) : ostream(ostream) {}
+
+ void *getUserData() { return this; }
+
+ MlirStringCallback getCallback() {
+ return [](MlirStringRef part, void *userData) {
+ OstreamAccumulator *accum = static_cast<OstreamAccumulator *>(userData);
+ accum->ostream << llvm::StringRef(part.data, part.length);
+ };
+ }
+
+private:
+ llvm::raw_ostream &ostream;
+};
+
/// Accumulates into a python string from a method that is expected to make
/// one (no more, no less) call to the callback (asserts internally on
/// violation).
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index ab975a6954044..c93de2fe3154e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -47,7 +47,7 @@ import collections
from collections.abc import Callable, Sequence
import io
from pathlib import Path
-from typing import Any, ClassVar, TypeVar, overload
+from typing import Any, BinaryIO, ClassVar, TypeVar, overload
__all__ = [
"AffineAddExpr",
@@ -285,12 +285,12 @@ class _OperationBase:
"""
Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
"""
- def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
+ def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
"""
Write the bytecode form of the operation to a file like object.
Args:
- file: The file like object to write to.
+ file: The file like object or path to write to.
desired_version: The version of bytecode to emit.
Returns:
The bytecode writer status.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c2d3aed8808b4..43836abb74f5e 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -3,6 +3,7 @@
import gc
import io
import itertools
+from tempfile import NamedTemporaryFile
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
@@ -617,6 +618,10 @@ def testOperationPrint():
module.operation.write_bytecode(bytecode_stream, desired_version=1)
bytecode = bytecode_stream.getvalue()
assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+ with NamedTemporaryFile() as tmpfile:
+ module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
+ tmpfile.seek(0)
+ assert tmpfile.read().startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
ctx2 = Context()
module_roundtrip = Module.parse(bytecode, ctx2)
f = io.StringIO()
|
✅ With the latest revision this PR passed the Python code formatter. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
mlir/lib/Bindings/Python/IRCore.cpp
Outdated
writeBytecodeForOperation(accum, operation, bytecodeVersion); | ||
} else { | ||
PyFileAccumulator accum(fileObject, /*binary=*/true); | ||
writeBytecodeForOperation(accum, operation, bytecodeVersion); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this logic be something we could just move to the PyFileAccumulator
itself so that every usage benefits from this optimization instead of specializing it everywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure -- done!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
std::error_code ec; | ||
writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec); | ||
if (ec) { | ||
throw nanobind::value_error("Unable to open file for writing"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if the ec
can be used to provide a better error message? (like can we know if it is "permission denied" or "no space left on device" or...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah good point! Added the ec
message to the Python exception
The current
write_bytecode
implementation necessarily requires the serialized module to be duplicated in memory when the pythonbytes
object is created and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary.