Skip to content

Commit bbc2976

Browse files
[mlir][python] Make the Context/Operation capsule creation methods work as documented. (#76010)
This fixes a longstanding bug in the `Context._CAPICreate` method whereby it was not taking ownership of the PyMlirContext wrapper when casting to a Python object. The result was minimally that all such contexts transferred in that way would leak. In addition, counter to the documentation for the `_CAPICreate` helper (see `mlir-c/Bindings/Python/Interop.h`) and the `forContext` / `forOperation` methods, we were silently upgrading any unknown context/operation pointer to steal-ownership semantics. This is dangerous and was causing some subtle bugs downstream where this facility is getting the most use. This patch corrects the semantics and will only do an ownership transfer for `_CAPICreate`, and it will further require that it is an ownership transfer (if already transferred, it was just silently succeeding). Removing the mis-aligned behavior made it clear where the downstream was doing the wrong thing. It also adds some `_testing_` functions to create unowned context and operation capsules so that this can be fully tested upstream, reworking the tests to verify the behavior. In some torture testing downstream, I was not able to trigger any memory corruption with the newly enforced semantics. When getting it wrong, a regular exception is raised.
1 parent d84c640 commit bbc2976

File tree

4 files changed

+129
-26
lines changed

4 files changed

+129
-26
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) {
602602
MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
603603
if (mlirContextIsNull(rawContext))
604604
throw py::error_already_set();
605-
return forContext(rawContext).releaseObject();
605+
return stealExternalContext(rawContext).releaseObject();
606606
}
607607

608608
PyMlirContext *PyMlirContext::createNewContextForInit() {
@@ -615,18 +615,35 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
615615
auto &liveContexts = getLiveContexts();
616616
auto it = liveContexts.find(context.ptr);
617617
if (it == liveContexts.end()) {
618-
// Create.
619-
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
620-
py::object pyRef = py::cast(unownedContextWrapper);
621-
assert(pyRef && "cast to py::object failed");
622-
liveContexts[context.ptr] = unownedContextWrapper;
623-
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
618+
throw std::runtime_error(
619+
"Cannot use a context that is not owned by the Python bindings.");
624620
}
621+
625622
// Use existing.
626623
py::object pyRef = py::cast(it->second);
627624
return PyMlirContextRef(it->second, std::move(pyRef));
628625
}
629626

627+
PyMlirContextRef PyMlirContext::stealExternalContext(MlirContext context) {
628+
py::gil_scoped_acquire acquire;
629+
auto &liveContexts = getLiveContexts();
630+
auto it = liveContexts.find(context.ptr);
631+
if (it != liveContexts.end()) {
632+
throw std::runtime_error(
633+
"Cannot transfer ownership of the context to Python "
634+
"as it is already owned by Python.");
635+
}
636+
637+
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
638+
// Note that the default return value policy on cast is automatic_reference,
639+
// which does not take ownership (delete will not be called).
640+
// Just be explicit.
641+
py::object pyRef =
642+
py::cast(unownedContextWrapper, py::return_value_policy::take_ownership);
643+
assert(pyRef && "cast to py::object failed");
644+
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
645+
}
646+
630647
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
631648
static LiveContextMap liveContexts;
632649
return liveContexts;
@@ -1145,6 +1162,18 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
11451162
return PyOperationRef(existing, std::move(pyRef));
11461163
}
11471164

1165+
PyOperationRef PyOperation::stealExternalOperation(PyMlirContextRef contextRef,
1166+
MlirOperation operation) {
1167+
auto &liveOperations = contextRef->liveOperations;
1168+
auto it = liveOperations.find(operation.ptr);
1169+
if (it != liveOperations.end()) {
1170+
throw std::runtime_error(
1171+
"Cannot transfer ownership of the operation to Python "
1172+
"as it is already owned by Python.");
1173+
}
1174+
return createInstance(std::move(contextRef), operation, py::none());
1175+
}
1176+
11481177
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
11491178
MlirOperation operation,
11501179
py::object parentKeepAlive) {
@@ -1316,7 +1345,8 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
13161345
if (mlirOperationIsNull(rawOperation))
13171346
throw py::error_already_set();
13181347
MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1319-
return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1348+
return stealExternalOperation(PyMlirContext::forContext(rawCtxt),
1349+
rawOperation)
13201350
.releaseObject();
13211351
}
13221352

@@ -2548,6 +2578,16 @@ void mlir::python::populateIRCore(py::module &m) {
25482578
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
25492579
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
25502580
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2581+
.def_static("_testing_create_raw_context_capsule",
2582+
[]() {
2583+
// Creates an MlirContext not known to the Python bindings
2584+
// and puts it in a capsule. Used to test interop. Using
2585+
// this without passing it back to the capsule creation
2586+
// API will leak.
2587+
return py::reinterpret_steal<py::object>(
2588+
mlirPythonContextToCapsule(
2589+
mlirContextCreateWithThreading(false)));
2590+
})
25512591
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
25522592
&PyMlirContext::getCapsule)
25532593
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -2973,8 +3013,7 @@ void mlir::python::populateIRCore(py::module &m) {
29733013
py::arg("binary") = false, kOperationPrintStateDocstring)
29743014
.def("print",
29753015
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
2976-
bool, py::object, bool>(
2977-
&PyOperationBase::print),
3016+
bool, py::object, bool>(&PyOperationBase::print),
29783017
// Careful: Lots of arguments must match up with print method.
29793018
py::arg("large_elements_limit") = py::none(),
29803019
py::arg("enable_debug_info") = false,
@@ -3046,6 +3085,25 @@ void mlir::python::populateIRCore(py::module &m) {
30463085
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
30473086
&PyOperation::getCapsule)
30483087
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3088+
.def_static(
3089+
"_testing_create_raw_capsule",
3090+
[](std::string sourceStr) {
3091+
// Creates a raw context and an operation via parsing the given
3092+
// source and returns them in a capsule. Error handling is
3093+
// minimal as this is purely intended for testing interop with
3094+
// operation creation from capsule functions.
3095+
MlirContext context = mlirContextCreateWithThreading(false);
3096+
MlirOperation op = mlirOperationCreateParse(
3097+
context, toMlirStringRef(sourceStr), toMlirStringRef("temp"));
3098+
if (mlirOperationIsNull(op)) {
3099+
mlirContextDestroy(context);
3100+
throw std::invalid_argument("Failed to parse");
3101+
}
3102+
return py::make_tuple(py::reinterpret_steal<py::object>(
3103+
mlirPythonContextToCapsule(context)),
3104+
py::reinterpret_steal<py::object>(
3105+
mlirPythonOperationToCapsule(op)));
3106+
})
30493107
.def_property_readonly("operation", [](py::object self) { return self; })
30503108
.def_property_readonly("opview", &PyOperation::createOpView)
30513109
.def_property_readonly(

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,19 @@ class PyMlirContext {
176176
static PyMlirContext *createNewContextForInit();
177177

178178
/// Returns a context reference for the singleton PyMlirContext wrapper for
179-
/// the given context.
179+
/// the given context. It is only valid to call this on an MlirContext that
180+
/// is already owned by the Python bindings. Typically this will be because
181+
/// it came in some fashion from createNewContextForInit(). However, it
182+
/// is also possible to explicitly transfer ownership of an existing
183+
/// MlirContext to the Python bindings via stealExternalContext().
180184
static PyMlirContextRef forContext(MlirContext context);
185+
186+
/// Explicitly takes ownership of an MlirContext that must not already be
187+
/// known to the Python bindings. Once done, the life-cycle of the context
188+
/// will be controlled by the Python bindings, and it will be destroyed
189+
/// when the reference count goes to zero.
190+
static PyMlirContextRef stealExternalContext(MlirContext context);
191+
181192
~PyMlirContext();
182193

183194
/// Accesses the underlying MlirContext.
@@ -606,6 +617,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
606617
forOperation(PyMlirContextRef contextRef, MlirOperation operation,
607618
pybind11::object parentKeepAlive = pybind11::object());
608619

620+
/// Explicitly takes ownership of an operation that must not already be known
621+
/// to the Python bindings. Once done, the life-cycle of the operation
622+
/// will be controlled by the Python bindings.
623+
static PyOperationRef stealExternalOperation(PyMlirContextRef contextRef,
624+
MlirOperation operation);
625+
609626
/// Creates a detached operation. The operation must not be associated with
610627
/// any existing live operation.
611628
static PyOperationRef

mlir/test/python/ir/context_lifecycle.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,46 @@
4545
c4 = mlir.ir.Context()
4646
c4_capsule = c4._CAPIPtr
4747
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
48-
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
49-
assert c4 is c5
48+
# Because the context is already owned by Python, it cannot be created
49+
# a second time.
50+
try:
51+
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
52+
except RuntimeError:
53+
pass
54+
else:
55+
raise AssertionError(
56+
"Should have gotten a RuntimeError when attempting to "
57+
"re-create an already owned context"
58+
)
59+
c4 = None
60+
c4_capsule = None
61+
gc.collect()
62+
assert mlir.ir.Context._get_live_count() == 0
63+
64+
# Use a private testing method to create an unowned context capsule and
65+
# import it.
66+
c6_capsule = mlir.ir.Context._testing_create_raw_context_capsule()
67+
c6 = mlir.ir.Context._CAPICreate(c6_capsule)
68+
assert mlir.ir.Context._get_live_count() == 1
69+
c6_capsule = None
70+
c6 = None
71+
gc.collect()
72+
assert mlir.ir.Context._get_live_count() == 0
73+
74+
# Also test operation import/export as it is tightly coupled to the context.
75+
(
76+
raw_context_capsule,
77+
raw_operation_capsule,
78+
) = mlir.ir.Operation._testing_create_raw_capsule("builtin.module {}")
79+
assert '"mlir.ir.Operation._CAPIPtr"' in repr(raw_operation_capsule)
80+
# Attempting to import an operation for an unknown context should fail.
81+
try:
82+
mlir.ir.Operation._CAPICreate(raw_operation_capsule)
83+
except RuntimeError:
84+
pass
85+
else:
86+
raise AssertionError("Expected exception for unknown context")
87+
88+
# Try again having imported the context.
89+
c7 = mlir.ir.Context._CAPICreate(raw_context_capsule)
90+
op7 = mlir.ir.Operation._CAPICreate(raw_operation_capsule)

mlir/test/python/ir/operation.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -844,19 +844,6 @@ def testOperationName():
844844
print(op.operation.name)
845845

846846

847-
# CHECK-LABEL: TEST: testCapsuleConversions
848-
@run
849-
def testCapsuleConversions():
850-
ctx = Context()
851-
ctx.allow_unregistered_dialects = True
852-
with Location.unknown(ctx):
853-
m = Operation.create("custom.op1").operation
854-
m_capsule = m._CAPIPtr
855-
assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
856-
m2 = Operation._CAPICreate(m_capsule)
857-
assert m2 is m
858-
859-
860847
# CHECK-LABEL: TEST: testOperationErase
861848
@run
862849
def testOperationErase():

0 commit comments

Comments
 (0)