Skip to content

[mlir][python] Make the Context/Operation capsule creation methods work as documented. #76010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 68 additions & 10 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) {
MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
if (mlirContextIsNull(rawContext))
throw py::error_already_set();
return forContext(rawContext).releaseObject();
return stealExternalContext(rawContext).releaseObject();
}

PyMlirContext *PyMlirContext::createNewContextForInit() {
Expand All @@ -615,18 +615,35 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
// Create.
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
py::object pyRef = py::cast(unownedContextWrapper);
assert(pyRef && "cast to py::object failed");
liveContexts[context.ptr] = unownedContextWrapper;
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
throw std::runtime_error(
"Cannot use a context that is not owned by the Python bindings.");
}

// Use existing.
py::object pyRef = py::cast(it->second);
return PyMlirContextRef(it->second, std::move(pyRef));
}

PyMlirContextRef PyMlirContext::stealExternalContext(MlirContext context) {
py::gil_scoped_acquire acquire;
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it != liveContexts.end()) {
throw std::runtime_error(
"Cannot transfer ownership of the context to Python "
"as it is already owned by Python.");
}

PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
Comment on lines +638 to +639
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: is this specific to py::cast and/or a bug in pybind? The docs say about return_value_policy::automatic:

This policy falls back to the policy return_value_policy::take_ownership when the return value is a pointer... This is the default policy for py::class_-wrapped types.

and about return_value_policy::automatic_reference:

As above, but use policy return_value_policy::reference when the return value is a pointer. This is the default conversion policy for function arguments when calling Python functions manually from C++ code (i.e. via handle::operator()) and the casters in pybind11/stl.h.

Are we somehow in either of those two cases here?

Copy link
Contributor Author

@stellaraccident stellaraccident Dec 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's somewhat poorly documented -- the behavior is specific to cast whereas the documentation is talking about the more common cases of function args/returns.

Note that I got that comment from the cast in fromOperation, which gets it right. That path is used heavily and was fixed long ago. This branch of context creation, though, had no upstream uses and was just wrong.

Looking for the reference on how I know this. I remember researching it some time ago.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a couple of links below - explanation by the author and the source.

// Just be explicit.
py::object pyRef =
py::cast(unownedContextWrapper, py::return_value_policy::take_ownership);
assert(pyRef && "cast to py::object failed");
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
}

PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
Expand Down Expand Up @@ -1145,6 +1162,18 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
return PyOperationRef(existing, std::move(pyRef));
}

PyOperationRef PyOperation::stealExternalOperation(PyMlirContextRef contextRef,
MlirOperation operation) {
auto &liveOperations = contextRef->liveOperations;
auto it = liveOperations.find(operation.ptr);
if (it != liveOperations.end()) {
throw std::runtime_error(
"Cannot transfer ownership of the operation to Python "
"as it is already owned by Python.");
}
return createInstance(std::move(contextRef), operation, py::none());
}

PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
Expand Down Expand Up @@ -1316,7 +1345,8 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
if (mlirOperationIsNull(rawOperation))
throw py::error_already_set();
MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
return stealExternalOperation(PyMlirContext::forContext(rawCtxt),
rawOperation)
.releaseObject();
}

Expand Down Expand Up @@ -2548,6 +2578,16 @@ void mlir::python::populateIRCore(py::module &m) {
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_static("_testing_create_raw_context_capsule",
[]() {
// Creates an MlirContext not known to the Python bindings
// and puts it in a capsule. Used to test interop. Using
// this without passing it back to the capsule creation
// API will leak.
return py::reinterpret_steal<py::object>(
mlirPythonContextToCapsule(
mlirContextCreateWithThreading(false)));
})
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
Expand Down Expand Up @@ -2973,8 +3013,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
bool, py::object, bool>(
&PyOperationBase::print),
bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
Expand Down Expand Up @@ -3046,6 +3085,25 @@ void mlir::python::populateIRCore(py::module &m) {
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyOperation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
.def_static(
"_testing_create_raw_capsule",
[](std::string sourceStr) {
// Creates a raw context and an operation via parsing the given
// source and returns them in a capsule. Error handling is
// minimal as this is purely intended for testing interop with
// operation creation from capsule functions.
MlirContext context = mlirContextCreateWithThreading(false);
MlirOperation op = mlirOperationCreateParse(
context, toMlirStringRef(sourceStr), toMlirStringRef("temp"));
if (mlirOperationIsNull(op)) {
mlirContextDestroy(context);
throw std::invalid_argument("Failed to parse");
}
return py::make_tuple(py::reinterpret_steal<py::object>(
mlirPythonContextToCapsule(context)),
py::reinterpret_steal<py::object>(
mlirPythonOperationToCapsule(op)));
})
.def_property_readonly("operation", [](py::object self) { return self; })
.def_property_readonly("opview", &PyOperation::createOpView)
.def_property_readonly(
Expand Down
19 changes: 18 additions & 1 deletion mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,19 @@ class PyMlirContext {
static PyMlirContext *createNewContextForInit();

/// Returns a context reference for the singleton PyMlirContext wrapper for
/// the given context.
/// the given context. It is only valid to call this on an MlirContext that
/// is already owned by the Python bindings. Typically this will be because
/// it came in some fashion from createNewContextForInit(). However, it
/// is also possible to explicitly transfer ownership of an existing
/// MlirContext to the Python bindings via stealExternalContext().
static PyMlirContextRef forContext(MlirContext context);

/// Explicitly takes ownership of an MlirContext that must not already be
/// known to the Python bindings. Once done, the life-cycle of the context
/// will be controlled by the Python bindings, and it will be destroyed
/// when the reference count goes to zero.
static PyMlirContextRef stealExternalContext(MlirContext context);

~PyMlirContext();

/// Accesses the underlying MlirContext.
Expand Down Expand Up @@ -606,6 +617,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
forOperation(PyMlirContextRef contextRef, MlirOperation operation,
pybind11::object parentKeepAlive = pybind11::object());

/// Explicitly takes ownership of an operation that must not already be known
/// to the Python bindings. Once done, the life-cycle of the operation
/// will be controlled by the Python bindings.
static PyOperationRef stealExternalOperation(PyMlirContextRef contextRef,
MlirOperation operation);

/// Creates a detached operation. The operation must not be associated with
/// any existing live operation.
static PyOperationRef
Expand Down
45 changes: 43 additions & 2 deletions mlir/test/python/ir/context_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,46 @@
c4 = mlir.ir.Context()
c4_capsule = c4._CAPIPtr
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
assert c4 is c5
# Because the context is already owned by Python, it cannot be created
# a second time.
try:
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
except RuntimeError:
pass
else:
raise AssertionError(
"Should have gotten a RuntimeError when attempting to "
"re-create an already owned context"
)
c4 = None
c4_capsule = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 0

# Use a private testing method to create an unowned context capsule and
# import it.
c6_capsule = mlir.ir.Context._testing_create_raw_context_capsule()
c6 = mlir.ir.Context._CAPICreate(c6_capsule)
assert mlir.ir.Context._get_live_count() == 1
c6_capsule = None
c6 = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 0

# Also test operation import/export as it is tightly coupled to the context.
(
raw_context_capsule,
raw_operation_capsule,
) = mlir.ir.Operation._testing_create_raw_capsule("builtin.module {}")
assert '"mlir.ir.Operation._CAPIPtr"' in repr(raw_operation_capsule)
# Attempting to import an operation for an unknown context should fail.
try:
mlir.ir.Operation._CAPICreate(raw_operation_capsule)
except RuntimeError:
pass
else:
raise AssertionError("Expected exception for unknown context")

# Try again having imported the context.
c7 = mlir.ir.Context._CAPICreate(raw_context_capsule)
op7 = mlir.ir.Operation._CAPICreate(raw_operation_capsule)
13 changes: 0 additions & 13 deletions mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,19 +844,6 @@ def testOperationName():
print(op.operation.name)


# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
m = Operation.create("custom.op1").operation
m_capsule = m._CAPIPtr
assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
m2 = Operation._CAPICreate(m_capsule)
assert m2 is m


# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
Expand Down