Skip to content

Commit 204acc5

Browse files
[mlir][py] Overload print with state. (llvm#72064)
Enables reusing the AsmState when printing from Python. Also moves the fileObject and binary to the end (pybind11::object was resulting in the overload not working unless `state=` was specified). --------- Co-authored-by: Maksim Levental <[email protected]>
1 parent ad20a9e commit 204acc5

File tree

3 files changed

+50
-13
lines changed

3 files changed

+50
-13
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,15 @@ static const char kOperationPrintDocstring[] =
110110
invalid, behavior is undefined.
111111
)";
112112

113+
static const char kOperationPrintStateDocstring[] =
114+
R"(Prints the assembly form of the operation to a file like object.
115+
116+
Args:
117+
file: The file like object to write to. Defaults to sys.stdout.
118+
binary: Whether to write bytes (True) or str (False). Defaults to False.
119+
state: AsmState capturing the operation numbering and flags.
120+
)";
121+
113122
static const char kOperationGetAsmDocstring[] =
114123
R"(Gets the assembly form of the operation with all options available.
115124
@@ -1169,11 +1178,11 @@ void PyOperation::checkValid() const {
11691178
}
11701179
}
11711180

1172-
void PyOperationBase::print(py::object fileObject, bool binary,
1173-
std::optional<int64_t> largeElementsLimit,
1181+
void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
11741182
bool enableDebugInfo, bool prettyDebugInfo,
11751183
bool printGenericOpForm, bool useLocalScope,
1176-
bool assumeVerified) {
1184+
bool assumeVerified, py::object fileObject,
1185+
bool binary) {
11771186
PyOperation &operation = getOperation();
11781187
operation.checkValid();
11791188
if (fileObject.is_none())
@@ -1198,6 +1207,17 @@ void PyOperationBase::print(py::object fileObject, bool binary,
11981207
mlirOpPrintingFlagsDestroy(flags);
11991208
}
12001209

1210+
void PyOperationBase::print(PyAsmState &state, py::object fileObject,
1211+
bool binary) {
1212+
PyOperation &operation = getOperation();
1213+
operation.checkValid();
1214+
if (fileObject.is_none())
1215+
fileObject = py::module::import("sys").attr("stdout");
1216+
PyFileAccumulator accum(fileObject, binary);
1217+
mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1218+
accum.getUserData());
1219+
}
1220+
12011221
void PyOperationBase::writeBytecode(const py::object &fileObject,
12021222
std::optional<int64_t> bytecodeVersion) {
12031223
PyOperation &operation = getOperation();
@@ -1230,13 +1250,14 @@ py::object PyOperationBase::getAsm(bool binary,
12301250
} else {
12311251
fileObject = py::module::import("io").attr("StringIO")();
12321252
}
1233-
print(fileObject, /*binary=*/binary,
1234-
/*largeElementsLimit=*/largeElementsLimit,
1253+
print(/*largeElementsLimit=*/largeElementsLimit,
12351254
/*enableDebugInfo=*/enableDebugInfo,
12361255
/*prettyDebugInfo=*/prettyDebugInfo,
12371256
/*printGenericOpForm=*/printGenericOpForm,
12381257
/*useLocalScope=*/useLocalScope,
1239-
/*assumeVerified=*/assumeVerified);
1258+
/*assumeVerified=*/assumeVerified,
1259+
/*fileObject=*/fileObject,
1260+
/*binary=*/binary);
12401261

12411262
return fileObject.attr("getvalue")();
12421263
}
@@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) {
29462967
/*assumeVerified=*/false);
29472968
},
29482969
"Returns the assembly form of the operation.")
2949-
.def("print", &PyOperationBase::print,
2970+
.def("print",
2971+
py::overload_cast<PyAsmState &, pybind11::object, bool>(
2972+
&PyOperationBase::print),
2973+
py::arg("state"), py::arg("file") = py::none(),
2974+
py::arg("binary") = false, kOperationPrintStateDocstring)
2975+
.def("print",
2976+
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
2977+
bool, py::object, bool>(
2978+
&PyOperationBase::print),
29502979
// Careful: Lots of arguments must match up with print method.
2951-
py::arg("file") = py::none(), py::arg("binary") = false,
29522980
py::arg("large_elements_limit") = py::none(),
29532981
py::arg("enable_debug_info") = false,
29542982
py::arg("pretty_debug_info") = false,
29552983
py::arg("print_generic_op_form") = false,
29562984
py::arg("use_local_scope") = false,
2957-
py::arg("assume_verified") = false, kOperationPrintDocstring)
2985+
py::arg("assume_verified") = false, py::arg("file") = py::none(),
2986+
py::arg("binary") = false, kOperationPrintDocstring)
29582987
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
29592988
py::arg("desired_version") = py::none(),
29602989
kOperationPrintBytecodeDocstring)

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -550,16 +550,19 @@ class PyModule : public BaseContextObject {
550550
pybind11::handle handle;
551551
};
552552

553+
class PyAsmState;
554+
553555
/// Base class for PyOperation and PyOpView which exposes the primary, user
554556
/// visible methods for manipulating it.
555557
class PyOperationBase {
556558
public:
557559
virtual ~PyOperationBase() = default;
558560
/// Implements the bound 'print' method and helps with others.
559-
void print(pybind11::object fileObject, bool binary,
560-
std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
561+
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
561562
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
562-
bool assumeVerified);
563+
bool assumeVerified, py::object fileObject, bool binary);
564+
void print(PyAsmState &state, py::object fileObject, bool binary);
565+
563566
pybind11::object getAsm(bool binary,
564567
std::optional<int64_t> largeElementsLimit,
565568
bool enableDebugInfo, bool prettyDebugInfo,

mlir/test/python/ir/operation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,15 @@ def testOperationPrint():
622622
print(bytes_value.__class__)
623623
print(bytes_value)
624624

625-
# Test get_asm local_scope.
625+
# Test print local_scope.
626626
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
627627
module.operation.print(enable_debug_info=True, use_local_scope=True)
628628

629+
# Test printing using state.
630+
state = AsmState(module.operation)
631+
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
632+
module.operation.print(state)
633+
629634
# Test get_asm with options.
630635
# CHECK: value = dense_resource<__elided__> : tensor<4xi32>
631636
# CHECK: "func.return"(%arg0) : (i32) -> () -:4:7

0 commit comments

Comments
 (0)