Skip to content

Commit e2c49a4

Browse files
authored
[mlir python] Add locking around PyMlirContext::liveOperations. (#122720)
In JAX, I observed a race between two PyOperation destructors from different threads updating the same `liveOperations` map, despite not intentionally sharing the context between different threads. Since I don't think we can be completely sure when GC happens and on which thread, it seems safest simply to add locking here. We may also want to explicitly support sharing a context between threads in the future, which would require this change or something similar.
1 parent 3318a72 commit e2c49a4

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -677,29 +677,44 @@ size_t PyMlirContext::getLiveCount() {
677677
return getLiveContexts().size();
678678
}
679679

680-
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
680+
size_t PyMlirContext::getLiveOperationCount() {
681+
nb::ft_lock_guard lock(liveOperationsMutex);
682+
return liveOperations.size();
683+
}
681684

682685
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
683686
std::vector<PyOperation *> liveObjects;
687+
nb::ft_lock_guard lock(liveOperationsMutex);
684688
for (auto &entry : liveOperations)
685689
liveObjects.push_back(entry.second.second);
686690
return liveObjects;
687691
}
688692

689693
size_t PyMlirContext::clearLiveOperations() {
690-
for (auto &op : liveOperations)
694+
695+
LiveOperationMap operations;
696+
{
697+
nb::ft_lock_guard lock(liveOperationsMutex);
698+
std::swap(operations, liveOperations);
699+
}
700+
for (auto &op : operations)
691701
op.second.second->setInvalid();
692-
size_t numInvalidated = liveOperations.size();
693-
liveOperations.clear();
702+
size_t numInvalidated = operations.size();
694703
return numInvalidated;
695704
}
696705

697706
void PyMlirContext::clearOperation(MlirOperation op) {
698-
auto it = liveOperations.find(op.ptr);
699-
if (it != liveOperations.end()) {
700-
it->second.second->setInvalid();
707+
PyOperation *py_op;
708+
{
709+
nb::ft_lock_guard lock(liveOperationsMutex);
710+
auto it = liveOperations.find(op.ptr);
711+
if (it == liveOperations.end()) {
712+
return;
713+
}
714+
py_op = it->second.second;
701715
liveOperations.erase(it);
702716
}
717+
py_op->setInvalid();
703718
}
704719

705720
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -1183,7 +1198,6 @@ PyOperation::~PyOperation() {
11831198
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
11841199
MlirOperation operation,
11851200
nb::object parentKeepAlive) {
1186-
auto &liveOperations = contextRef->liveOperations;
11871201
// Create.
11881202
PyOperation *unownedOperation =
11891203
new PyOperation(std::move(contextRef), operation);
@@ -1195,19 +1209,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
11951209
if (parentKeepAlive) {
11961210
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
11971211
}
1198-
liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
11991212
return PyOperationRef(unownedOperation, std::move(pyRef));
12001213
}
12011214

12021215
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12031216
MlirOperation operation,
12041217
nb::object parentKeepAlive) {
1218+
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
12051219
auto &liveOperations = contextRef->liveOperations;
12061220
auto it = liveOperations.find(operation.ptr);
12071221
if (it == liveOperations.end()) {
12081222
// Create.
1209-
return createInstance(std::move(contextRef), operation,
1210-
std::move(parentKeepAlive));
1223+
PyOperationRef result = createInstance(std::move(contextRef), operation,
1224+
std::move(parentKeepAlive));
1225+
liveOperations[operation.ptr] =
1226+
std::make_pair(result.getObject(), result.get());
1227+
return result;
12111228
}
12121229
// Use existing.
12131230
PyOperation *existing = it->second.second;
@@ -1218,13 +1235,15 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12181235
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
12191236
MlirOperation operation,
12201237
nb::object parentKeepAlive) {
1238+
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
12211239
auto &liveOperations = contextRef->liveOperations;
12221240
assert(liveOperations.count(operation.ptr) == 0 &&
12231241
"cannot create detached operation that already exists");
12241242
(void)liveOperations;
1225-
12261243
PyOperationRef created = createInstance(std::move(contextRef), operation,
12271244
std::move(parentKeepAlive));
1245+
liveOperations[operation.ptr] =
1246+
std::make_pair(created.getObject(), created.get());
12281247
created->attached = false;
12291248
return created;
12301249
}

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ class PyMlirContext {
277277
// attempt to access it will raise an error.
278278
using LiveOperationMap =
279279
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
280+
nanobind::ft_mutex liveOperationsMutex;
281+
282+
// Guarded by liveOperationsMutex in free-threading mode.
280283
LiveOperationMap liveOperations;
281284

282285
bool emitErrorDiagnostics = false;

0 commit comments

Comments
 (0)