Skip to content

[mlir][python][wip] remove liveOperations #92631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
/// The returned module is null when the input operation was not a ModuleOp.
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);

MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule mod, MlirModule other);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: module (mod just doesn't save enough). Given this is binary operator, I'd even consider lhs & rhs.


//===----------------------------------------------------------------------===//
// Operation state.
//===----------------------------------------------------------------------===//
Expand Down
138 changes: 22 additions & 116 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,58 +634,6 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {

size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }

size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }

std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
std::vector<PyOperation *> liveObjects;
for (auto &entry : liveOperations)
liveObjects.push_back(entry.second.second);
return liveObjects;
}

size_t PyMlirContext::clearLiveOperations() {
for (auto &op : liveOperations)
op.second.second->setInvalid();
size_t numInvalidated = liveOperations.size();
liveOperations.clear();
return numInvalidated;
}

void PyMlirContext::clearOperation(MlirOperation op) {
auto it = liveOperations.find(op.ptr);
if (it != liveOperations.end()) {
it->second.second->setInvalid();
liveOperations.erase(it);
}
}

void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
typedef struct {
PyOperation &rootOp;
bool rootSeen;
} callBackData;
callBackData data{op.getOperation(), false};
// Mark all ops below the op that the passmanager will be rooted
// at (but not op itself - note the preorder) as invalid.
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
callBackData *data = static_cast<callBackData *>(userData);
if (LLVM_LIKELY(data->rootSeen))
data->rootOp.getOperation().getContext()->clearOperation(op);
else
data->rootSeen = true;
return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
}
void PyMlirContext::clearOperationsInside(MlirOperation op) {
PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
clearOperationsInside(opRef->getOperation());
}

size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }

pybind11::object PyMlirContext::contextEnter() {
return PyThreadContextEntry::pushContext(*this);
}
Expand Down Expand Up @@ -1055,39 +1003,21 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}

PyModule::~PyModule() {
py::gil_scoped_acquire acquire;
auto &liveModules = getContext()->liveModules;
assert(liveModules.count(module.ptr) == 1 &&
"destroying module not in live map");
liveModules.erase(module.ptr);
mlirModuleDestroy(module);
}
PyModule::~PyModule() { mlirModuleDestroy(module); }

PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);

py::gil_scoped_acquire acquire;
auto &liveModules = contextRef->liveModules;
auto it = liveModules.find(module.ptr);
if (it == liveModules.end()) {
// Create.
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
py::object pyRef =
py::cast(unownedModule, py::return_value_policy::take_ownership);
unownedModule->handle = pyRef;
liveModules[module.ptr] =
std::make_pair(unownedModule->handle, unownedModule);
return PyModuleRef(unownedModule, std::move(pyRef));
}
// Use existing.
PyModule *existing = it->second.second;
py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
return PyModuleRef(existing, std::move(pyRef));
// Create.
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
py::object pyRef =
py::cast(unownedModule, py::return_value_policy::take_ownership);
unownedModule->handle = pyRef;
return PyModuleRef(unownedModule, std::move(pyRef));
}

py::object PyModule::createFromCapsule(py::object capsule) {
Expand All @@ -1112,10 +1042,6 @@ PyOperation::~PyOperation() {
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;
auto &liveOperations = getContext()->liveOperations;
assert(liveOperations.count(operation.ptr) == 1 &&
"destroying operation not in live map");
liveOperations.erase(operation.ptr);
if (!isAttached()) {
mlirOperationDestroy(operation);
}
Expand All @@ -1124,7 +1050,6 @@ PyOperation::~PyOperation() {
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
// Create.
PyOperation *unownedOperation =
new PyOperation(std::move(contextRef), operation);
Expand All @@ -1137,34 +1062,20 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
}
liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
return PyOperationRef(unownedOperation, std::move(pyRef));
}

PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
auto it = liveOperations.find(operation.ptr);
if (it == liveOperations.end()) {
// Create.
return createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
}
// Use existing.
PyOperation *existing = it->second.second;
py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
return PyOperationRef(existing, std::move(pyRef));
// Create.
return createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
}

PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
assert(liveOperations.count(operation.ptr) == 0 &&
"cannot create detached operation that already exists");
(void)liveOperations;

PyOperationRef created = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
created->attached = false;
Expand Down Expand Up @@ -1530,9 +1441,6 @@ void PyOperation::erase() {
// TODO: Fix memory hazards when erasing a tree of operations for which a deep
// Python reference to a child operation is live. All children should also
// have their `valid` bit set to false.
auto &liveOperations = getContext()->liveOperations;
if (liveOperations.count(operation.ptr))
liveOperations.erase(operation.ptr);
mlirOperationDestroy(operation);
valid = false;
}
Expand Down Expand Up @@ -2274,7 +2182,6 @@ class PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;

PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
intptr_t startIndex = 0, intptr_t length = -1,
Expand Down Expand Up @@ -2598,14 +2505,6 @@ void mlir::python::populateIRCore(py::module &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
.def("_get_live_operation_objects",
&PyMlirContext::getLiveOperationObjects)
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
.def("_clear_live_operations_inside",
py::overload_cast<MlirOperation>(
&PyMlirContext::clearOperationsInside))
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
Expand Down Expand Up @@ -2915,7 +2814,13 @@ void mlir::python::populateIRCore(py::module &m) {
// Defer to the operation's __str__.
return self.attr("operation").attr("__str__")();
},
kOperationStrDunderDocstring);
kOperationStrDunderDocstring)
.def(
"__eq__",
[](PyModule &self, PyModule &other) {
return mlirModuleEqual(self.get(), other.get());
},
"other"_a);

//----------------------------------------------------------------------------
// Mapping of Operation.
Expand All @@ -2927,7 +2832,8 @@ void mlir::python::populateIRCore(py::module &m) {
})
.def("__eq__",
[](PyOperationBase &self, PyOperationBase &other) {
return &self.getOperation() == &other.getOperation();
return mlirOperationEqual(self.getOperation().get(),
other.getOperation().get());
})
.def("__eq__",
[](PyOperationBase &self, py::object other) { return false; })
Expand Down
44 changes: 0 additions & 44 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,34 +201,6 @@ class PyMlirContext {
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();

/// Get a list of Python objects which are still in the live context map.
std::vector<PyOperation *> getLiveOperationObjects();

/// Gets the count of live operations associated with this context.
/// Used for testing.
size_t getLiveOperationCount();

/// Clears the live operations map, returning the number of entries which were
/// invalidated. To be used as a safety mechanism so that API end-users can't
/// corrupt by holding references they shouldn't have accessed in the first
/// place.
size_t clearLiveOperations();

/// Removes an operation from the live operations map and sets it invalid.
/// This is useful for when some non-bindings code destroys the operation and
/// the bindings need to made aware. For example, in the case when pass
/// manager is run.
void clearOperation(MlirOperation op);

/// Clears all operations nested inside the given op using
/// `clearOperation(MlirOperation)`.
void clearOperationsInside(PyOperationBase &op);
void clearOperationsInside(MlirOperation op);

/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();

/// Enter and exit the context manager.
pybind11::object contextEnter();
void contextExit(const pybind11::object &excType,
Expand All @@ -255,22 +227,6 @@ class PyMlirContext {
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static LiveContextMap &getLiveContexts();

// Interns all live modules associated with this context. Modules tracked
// in this map are valid. When a module is invalidated, it is removed
// from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveModuleMap =
llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
LiveModuleMap liveModules;

// Interns all live operations associated with this context. Operations
// tracked in this map are valid. When an operation is invalidated, it is
// removed from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveOperationMap =
llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
LiveOperationMap liveOperations;

bool emitErrorDiagnostics = false;

MlirContext context;
Expand Down
8 changes: 2 additions & 6 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"ValueError if the pipeline can't be parsed.")
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op,
bool invalidateOps) {
if (invalidateOps) {
op.getOperation().getContext()->clearOperationsInside(op);
}
[](PyPassManager &passManager, PyOperationBase &op) {
// Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
Expand All @@ -130,7 +126,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
throw MLIRError("Failure while executing pass pipeline",
errors.take());
},
"operation"_a, "invalidate_ops"_a = true,
"operation"_a,
"Run the pass manager on the provided operation, raising an "
"MLIRError on failure.")
.def(
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Bindings/Python/TransformInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
// root. This is awkward, but we don't have access to PyMlirContext
// object here otherwise.
py::object obj = py::cast(payloadRoot);
obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);

MlirLogicalResult result = mlirTransformApplyNamedSequence(
payloadRoot, transformRoot, transformModule, options.options);
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
return wrap(dyn_cast<ModuleOp>(unwrap(op)));
}

bool mlirModuleEqual(MlirModule mod, MlirModule other) {
return unwrap(mod) == unwrap(other);
}

//===----------------------------------------------------------------------===//
// Operation state API.
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 2 additions & 18 deletions mlir/test/python/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,27 +102,16 @@ def testRoundtripBinary():
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
assert ctx._get_live_module_count() == 1
op1 = module.operation
assert ctx._get_live_operation_count() == 1
live_ops = ctx._get_live_operation_objects()
assert len(live_ops) == 1
assert live_ops[0] is op1
live_ops = None
# CHECK: module @successfulParse
print(op1)

# Ensure that operations are the same on multiple calls.
op2 = module.operation
assert ctx._get_live_operation_count() == 1
assert op1 is op2
assert op1 == op2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does is and == differ here?


# Test live operation clearing.
op1 = module.operation
assert ctx._get_live_operation_count() == 1
num_invalidated = ctx._clear_live_operations()
assert num_invalidated == 1
assert ctx._get_live_operation_count() == 0
op1 = None
gc.collect()
op1 = module.operation
Expand All @@ -136,26 +125,21 @@ def testModuleOperation():
op1 = None
op2 = None
gc.collect()
print("LIVE OPERATIONS:", ctx._get_live_operation_count())
assert ctx._get_live_operation_count() == 0
assert ctx._get_live_module_count() == 0


# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
module_dup = Module._CAPICreate(module_capsule)
assert module is module_dup
assert module == module_dup
assert module_dup.context is ctx
# Gc and verify destructed.
module = None
module_capsule = None
module_dup = None
gc.collect()
assert ctx._get_live_module_count() == 0
2 changes: 1 addition & 1 deletion mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ def testCapsuleConversions():
m_capsule = m._CAPIPtr
assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
m2 = Operation._CAPICreate(m_capsule)
assert m2 is m
assert m2 == m


# CHECK-LABEL: TEST: testOperationErase
Expand Down
Loading
Loading