Skip to content

[mlir python] Add locking around PyMlirContext::liveOperations. #122720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 13, 2025

Conversation

hawkinsp
Copy link
Contributor

In JAX, I observed a race between two PyOperation destructors from different threads updating the same liveOperations map, despite not intentionally sharing the context between different threads. Since I don't think we can be completely sure when GC happens and on which thread, it seems safest simply to add locking here.

We may also want to explicitly support sharing a context between threads in the future, which would require this change or something similar.

In JAX, I observed a race between two PyOperation destructors from
different threads updating the same `liveOperations` map, despite not
intentionally sharing the context between different threads. Since I
don't think we can be completely sure when GC happens and on which
thread, it seems safest simply to add locking here.

We may also want to explicitly support sharing a context between
threads in the future, which would require this change or something
similar.
@llvmbot llvmbot added the mlir label Jan 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

Changes

In JAX, I observed a race between two PyOperation destructors from different threads updating the same liveOperations map, despite not intentionally sharing the context between different threads. Since I don't think we can be completely sure when GC happens and on which thread, it seems safest simply to add locking here.

We may also want to explicitly support sharing a context between threads in the future, which would require this change or something similar.


Full diff: https://github.com/llvm/llvm-project/pull/122720.diff

2 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+31-12)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+3)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 463ebdebb3f3f6..53806ca9f04a49 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -677,29 +677,44 @@ size_t PyMlirContext::getLiveCount() {
   return getLiveContexts().size();
 }
 
-size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
+size_t PyMlirContext::getLiveOperationCount() {
+  nb::ft_lock_guard lock(liveOperationsMutex);
+  return liveOperations.size();
+}
 
 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
   std::vector<PyOperation *> liveObjects;
+  nb::ft_lock_guard lock(liveOperationsMutex);
   for (auto &entry : liveOperations)
     liveObjects.push_back(entry.second.second);
   return liveObjects;
 }
 
 size_t PyMlirContext::clearLiveOperations() {
-  for (auto &op : liveOperations)
+
+  LiveOperationMap operations;
+  {
+    nb::ft_lock_guard lock(liveOperationsMutex);
+    std::swap(operations, liveOperations);
+  }
+  for (auto &op : operations)
     op.second.second->setInvalid();
-  size_t numInvalidated = liveOperations.size();
-  liveOperations.clear();
+  size_t numInvalidated = operations.size();
   return numInvalidated;
 }
 
 void PyMlirContext::clearOperation(MlirOperation op) {
-  auto it = liveOperations.find(op.ptr);
-  if (it != liveOperations.end()) {
-    it->second.second->setInvalid();
+  PyOperation *py_op;
+  {
+    nb::ft_lock_guard lock(liveOperationsMutex);
+    auto it = liveOperations.find(op.ptr);
+    if (it == liveOperations.end()) {
+      return;
+    }
+    py_op = it->second.second;
     liveOperations.erase(it);
   }
+  py_op->setInvalid();
 }
 
 void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -1183,7 +1198,6 @@ PyOperation::~PyOperation() {
 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            nb::object parentKeepAlive) {
-  auto &liveOperations = contextRef->liveOperations;
   // Create.
   PyOperation *unownedOperation =
       new PyOperation(std::move(contextRef), operation);
@@ -1195,19 +1209,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
   if (parentKeepAlive) {
     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
   }
-  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
   return PyOperationRef(unownedOperation, std::move(pyRef));
 }
 
 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
                                          MlirOperation operation,
                                          nb::object parentKeepAlive) {
+  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
   auto &liveOperations = contextRef->liveOperations;
   auto it = liveOperations.find(operation.ptr);
   if (it == liveOperations.end()) {
     // Create.
-    return createInstance(std::move(contextRef), operation,
-                          std::move(parentKeepAlive));
+    PyOperationRef result = createInstance(std::move(contextRef), operation,
+                                           std::move(parentKeepAlive));
+    liveOperations[operation.ptr] =
+        std::make_pair(result.getObject(), result.get());
+    return result;
   }
   // Use existing.
   PyOperation *existing = it->second.second;
@@ -1218,13 +1235,15 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            nb::object parentKeepAlive) {
+  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
   auto &liveOperations = contextRef->liveOperations;
   assert(liveOperations.count(operation.ptr) == 0 &&
          "cannot create detached operation that already exists");
   (void)liveOperations;
-
   PyOperationRef created = createInstance(std::move(contextRef), operation,
                                           std::move(parentKeepAlive));
+  liveOperations[operation.ptr] =
+      std::make_pair(created.getObject(), created.get());
   created->attached = false;
   return created;
 }
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f5fbb6c61b57e2..d1fb4308dbb77c 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -277,6 +277,9 @@ class PyMlirContext {
   // attempt to access it will raise an error.
   using LiveOperationMap =
       llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
+  nanobind::ft_mutex liveOperationsMutex;
+
+  // Guarded by liveOperationsMutex in free-threading mode.
   LiveOperationMap liveOperations;
 
   bool emitErrorDiagnostics = false;

@hawkinsp
Copy link
Contributor Author

@jpienaar

@jpienaar jpienaar merged commit e2c49a4 into llvm:main Jan 13, 2025
8 of 9 checks passed
kazutakahirata pushed a commit to kazutakahirata/llvm-project that referenced this pull request Jan 13, 2025
…#122720)

In JAX, I observed a race between two PyOperation destructors from
different threads updating the same `liveOperations` map, despite not
intentionally sharing the context between different threads. Since I
don't think we can be completely sure when GC happens and on which
thread, it seems safest simply to add locking here.

We may also want to explicitly support sharing a context between threads
in the future, which would require this change or something similar.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants