Skip to content

Commit 59f5ad4

Browse files
jorickertAlexisPerry
authored andcommitted
[mlir] Expose skipRegions option for Op printing in the C and Python bindings (llvm#96150)
The MLIR C and Python Bindings expose various methods from `mlir::OpPrintingFlags` . This PR adds a binding for the `skipRegions` method, which allows to skip the printing of Regions when printing Ops. It also exposes this option as parameter in the python `get_asm` and `print` methods
1 parent 7b62f4d commit 59f5ad4

File tree

7 files changed

+51
-13
lines changed

7 files changed

+51
-13
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,10 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
450450
MLIR_CAPI_EXPORTED void
451451
mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
452452

453+
/// Skip printing regions.
454+
MLIR_CAPI_EXPORTED void
455+
mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags);
456+
453457
//===----------------------------------------------------------------------===//
454458
// Bytecode printing flags API.
455459
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ static const char kOperationPrintDocstring[] =
108108
and report failures in a more robust fashion. Set this to True if doing this
109109
in order to avoid running a redundant verification. If the IR is actually
110110
invalid, behavior is undefined.
111+
skip_regions: Whether to skip printing regions. Defaults to False.
111112
)";
112113

113114
static const char kOperationPrintStateDocstring[] =
@@ -1221,7 +1222,7 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
12211222
bool enableDebugInfo, bool prettyDebugInfo,
12221223
bool printGenericOpForm, bool useLocalScope,
12231224
bool assumeVerified, py::object fileObject,
1224-
bool binary) {
1225+
bool binary, bool skipRegions) {
12251226
PyOperation &operation = getOperation();
12261227
operation.checkValid();
12271228
if (fileObject.is_none())
@@ -1239,6 +1240,8 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
12391240
mlirOpPrintingFlagsUseLocalScope(flags);
12401241
if (assumeVerified)
12411242
mlirOpPrintingFlagsAssumeVerified(flags);
1243+
if (skipRegions)
1244+
mlirOpPrintingFlagsSkipRegions(flags);
12421245

12431246
PyFileAccumulator accum(fileObject, binary);
12441247
mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
@@ -1314,7 +1317,7 @@ py::object PyOperationBase::getAsm(bool binary,
13141317
std::optional<int64_t> largeElementsLimit,
13151318
bool enableDebugInfo, bool prettyDebugInfo,
13161319
bool printGenericOpForm, bool useLocalScope,
1317-
bool assumeVerified) {
1320+
bool assumeVerified, bool skipRegions) {
13181321
py::object fileObject;
13191322
if (binary) {
13201323
fileObject = py::module::import("io").attr("BytesIO")();
@@ -1328,7 +1331,8 @@ py::object PyOperationBase::getAsm(bool binary,
13281331
/*useLocalScope=*/useLocalScope,
13291332
/*assumeVerified=*/assumeVerified,
13301333
/*fileObject=*/fileObject,
1331-
/*binary=*/binary);
1334+
/*binary=*/binary,
1335+
/*skipRegions=*/skipRegions);
13321336

13331337
return fileObject.attr("getvalue")();
13341338
}
@@ -3043,7 +3047,8 @@ void mlir::python::populateIRCore(py::module &m) {
30433047
/*prettyDebugInfo=*/false,
30443048
/*printGenericOpForm=*/false,
30453049
/*useLocalScope=*/false,
3046-
/*assumeVerified=*/false);
3050+
/*assumeVerified=*/false,
3051+
/*skipRegions=*/false);
30473052
},
30483053
"Returns the assembly form of the operation.")
30493054
.def("print",
@@ -3053,15 +3058,17 @@ void mlir::python::populateIRCore(py::module &m) {
30533058
py::arg("binary") = false, kOperationPrintStateDocstring)
30543059
.def("print",
30553060
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3056-
bool, py::object, bool>(&PyOperationBase::print),
3061+
bool, py::object, bool, bool>(
3062+
&PyOperationBase::print),
30573063
// Careful: Lots of arguments must match up with print method.
30583064
py::arg("large_elements_limit") = py::none(),
30593065
py::arg("enable_debug_info") = false,
30603066
py::arg("pretty_debug_info") = false,
30613067
py::arg("print_generic_op_form") = false,
30623068
py::arg("use_local_scope") = false,
30633069
py::arg("assume_verified") = false, py::arg("file") = py::none(),
3064-
py::arg("binary") = false, kOperationPrintDocstring)
3070+
py::arg("binary") = false, py::arg("skip_regions") = false,
3071+
kOperationPrintDocstring)
30653072
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
30663073
py::arg("desired_version") = py::none(),
30673074
kOperationPrintBytecodeDocstring)
@@ -3073,7 +3080,8 @@ void mlir::python::populateIRCore(py::module &m) {
30733080
py::arg("pretty_debug_info") = false,
30743081
py::arg("print_generic_op_form") = false,
30753082
py::arg("use_local_scope") = false,
3076-
py::arg("assume_verified") = false, kOperationGetAsmDocstring)
3083+
py::arg("assume_verified") = false, py::arg("skip_regions") = false,
3084+
kOperationGetAsmDocstring)
30773085
.def("verify", &PyOperationBase::verify,
30783086
"Verify the operation. Raises MLIRError if verification fails, and "
30793087
"returns true otherwise.")

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -574,14 +574,15 @@ class PyOperationBase {
574574
/// Implements the bound 'print' method and helps with others.
575575
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
576576
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
577-
bool assumeVerified, py::object fileObject, bool binary);
577+
bool assumeVerified, py::object fileObject, bool binary,
578+
bool skipRegions);
578579
void print(PyAsmState &state, py::object fileObject, bool binary);
579580

580581
pybind11::object getAsm(bool binary,
581582
std::optional<int64_t> largeElementsLimit,
582583
bool enableDebugInfo, bool prettyDebugInfo,
583584
bool printGenericOpForm, bool useLocalScope,
584-
bool assumeVerified);
585+
bool assumeVerified, bool skipRegions);
585586

586587
// Implement the bound 'writeBytecode' method.
587588
void writeBytecode(const pybind11::object &fileObject,

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
219219
unwrap(flags)->assumeVerified();
220220
}
221221

222+
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) {
223+
unwrap(flags)->skipRegions();
224+
}
222225
//===----------------------------------------------------------------------===//
223226
// Bytecode printing flags API.
224227
//===----------------------------------------------------------------------===//

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ class _OperationBase:
209209
print_generic_op_form: bool = False,
210210
use_local_scope: bool = False,
211211
assume_verified: bool = False,
212+
skip_regions: bool = False,
212213
) -> Union[io.BytesIO, io.StringIO]:
213214
"""
214215
Gets the assembly form of the operation with all options available.
@@ -256,6 +257,7 @@ class _OperationBase:
256257
assume_verified: bool = False,
257258
file: Optional[Any] = None,
258259
binary: bool = False,
260+
skip_regions: bool = False,
259261
) -> None:
260262
"""
261263
Prints the assembly form of the operation to a file like object.
@@ -281,6 +283,7 @@ class _OperationBase:
281283
and report failures in a more robust fashion. Set this to True if doing this
282284
in order to avoid running a redundant verification. If the IR is actually
283285
invalid, behavior is undefined.
286+
skip_regions: Whether to skip printing regions. Defaults to False.
284287
"""
285288
def verify(self) -> bool:
286289
"""

mlir/test/CAPI/ir.c

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,9 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
340340
// function.
341341
MlirRegion region = mlirOperationGetRegion(operation, 0);
342342
MlirBlock block = mlirRegionGetFirstBlock(region);
343-
operation = mlirBlockGetFirstOperation(block);
344-
region = mlirOperationGetRegion(operation, 0);
345-
MlirOperation parentOperation = operation;
343+
MlirOperation function = mlirBlockGetFirstOperation(block);
344+
region = mlirOperationGetRegion(function, 0);
345+
MlirOperation parentOperation = function;
346346
block = mlirRegionGetFirstBlock(region);
347347
operation = mlirBlockGetFirstOperation(block);
348348
assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
@@ -490,6 +490,18 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
490490
// CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
491491
// clang-format on
492492

493+
mlirOpPrintingFlagsDestroy(flags);
494+
flags = mlirOpPrintingFlagsCreate();
495+
mlirOpPrintingFlagsSkipRegions(flags);
496+
fprintf(stderr, "Op print with skip regions flag: ");
497+
mlirOperationPrintWithFlags(function, flags, printToStderr, NULL);
498+
fprintf(stderr, "\n");
499+
// clang-format off
500+
// CHECK: Op print with skip regions flag: func.func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>)
501+
// CHECK-NOT: constant
502+
// CHECK-NOT: return
503+
// clang-format on
504+
493505
fprintf(stderr, "With state: |");
494506
mlirValuePrintAsOperand(value, state, printToStderr, NULL);
495507
// CHECK: With state: |%0|

mlir/test/python/ir/operation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def testOperationPrint():
631631
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
632632
module.operation.print(state)
633633

634-
# Test get_asm with options.
634+
# Test print with options.
635635
# CHECK: value = dense_resource<__elided__> : tensor<4xi32>
636636
# CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
637637
module.operation.print(
@@ -642,6 +642,13 @@ def testOperationPrint():
642642
use_local_scope=True,
643643
)
644644

645+
# Test print with skip_regions option
646+
# CHECK: func.func @f1(%arg0: i32) -> i32
647+
# CHECK-NOT: func.return
648+
module.body.operations[0].print(
649+
skip_regions=True,
650+
)
651+
645652

646653
# CHECK-LABEL: TEST: testKnownOpView
647654
@run

0 commit comments

Comments
 (0)