Skip to content

Commit a677a17

Browse files
committed
[mlir][py] Enable AsmState overload for operation.
1 parent 8586cd5 commit a677a17

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3479,6 +3479,8 @@ void mlir::python::populateIRCore(py::module &m) {
34793479

34803480
py::class_<PyAsmState>(m, "AsmState", py::module_local())
34813481
.def(py::init<PyValue &, bool>(), py::arg("value"),
3482+
py::arg("use_local_scope") = false)
3483+
.def(py::init<PyOperationBase &, bool>(), py::arg("op"),
34823484
py::arg("use_local_scope") = false);
34833485

34843486
//----------------------------------------------------------------------------

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,16 @@ class PyAsmState {
759759
mlirOpPrintingFlagsUseLocalScope(flags);
760760
state = mlirAsmStateCreateForValue(value, flags);
761761
}
762+
763+
PyAsmState(PyOperationBase &operation, bool useLocalScope) {
764+
flags = mlirOpPrintingFlagsCreate();
765+
// The OpPrintingFlags are not exposed Python side, create locally and
766+
// associate lifetime with the state.
767+
if (useLocalScope)
768+
mlirOpPrintingFlagsUseLocalScope(flags);
769+
state =
770+
mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
771+
}
762772
~PyAsmState() {
763773
mlirOpPrintingFlagsDestroy(flags);
764774
}

mlir/test/python/ir/value.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ def testValuePrintAsOperand():
165165
# CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
166166
print(value2)
167167

168-
f = func.FuncOp("test", ([i32, i32], []))
169-
entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
168+
topFn = func.FuncOp("test", ([i32, i32], []))
169+
entry_block1 = Block.create_at_start(topFn.operation.regions[0], [i32, i32])
170170

171171
with InsertionPoint(entry_block1):
172172
value3 = Operation.create("custom.op3", results=[i32]).results[0]
@@ -201,7 +201,7 @@ def testValuePrintAsOperand():
201201

202202
print("With AsmState")
203203
# CHECK-LABEL: With AsmState
204-
state = AsmState(value3, use_local_scope=True)
204+
state = AsmState(topFn.operation, use_local_scope=True)
205205
# CHECK: %0
206206
print(value3.get_name(state=state))
207207
# CHECK: %1

0 commit comments

Comments
 (0)