@@ -677,29 +677,44 @@ size_t PyMlirContext::getLiveCount() {
677
677
return getLiveContexts ().size ();
678
678
}
679
679
680
- size_t PyMlirContext::getLiveOperationCount () { return liveOperations.size (); }
680
+ size_t PyMlirContext::getLiveOperationCount () {
681
+ nb::ft_lock_guard lock (liveOperationsMutex);
682
+ return liveOperations.size ();
683
+ }
681
684
682
685
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects () {
683
686
std::vector<PyOperation *> liveObjects;
687
+ nb::ft_lock_guard lock (liveOperationsMutex);
684
688
for (auto &entry : liveOperations)
685
689
liveObjects.push_back (entry.second .second );
686
690
return liveObjects;
687
691
}
688
692
689
693
size_t PyMlirContext::clearLiveOperations () {
690
- for (auto &op : liveOperations)
694
+
695
+ LiveOperationMap operations;
696
+ {
697
+ nb::ft_lock_guard lock (liveOperationsMutex);
698
+ std::swap (operations, liveOperations);
699
+ }
700
+ for (auto &op : operations)
691
701
op.second .second ->setInvalid ();
692
- size_t numInvalidated = liveOperations.size ();
693
- liveOperations.clear ();
702
+ size_t numInvalidated = operations.size ();
694
703
return numInvalidated;
695
704
}
696
705
697
706
void PyMlirContext::clearOperation (MlirOperation op) {
698
- auto it = liveOperations.find (op.ptr );
699
- if (it != liveOperations.end ()) {
700
- it->second .second ->setInvalid ();
707
+ PyOperation *py_op;
708
+ {
709
+ nb::ft_lock_guard lock (liveOperationsMutex);
710
+ auto it = liveOperations.find (op.ptr );
711
+ if (it == liveOperations.end ()) {
712
+ return ;
713
+ }
714
+ py_op = it->second .second ;
701
715
liveOperations.erase (it);
702
716
}
717
+ py_op->setInvalid ();
703
718
}
704
719
705
720
void PyMlirContext::clearOperationsInside (PyOperationBase &op) {
@@ -1183,7 +1198,6 @@ PyOperation::~PyOperation() {
1183
1198
PyOperationRef PyOperation::createInstance (PyMlirContextRef contextRef,
1184
1199
MlirOperation operation,
1185
1200
nb::object parentKeepAlive) {
1186
- auto &liveOperations = contextRef->liveOperations ;
1187
1201
// Create.
1188
1202
PyOperation *unownedOperation =
1189
1203
new PyOperation (std::move (contextRef), operation);
@@ -1195,19 +1209,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1195
1209
if (parentKeepAlive) {
1196
1210
unownedOperation->parentKeepAlive = std::move (parentKeepAlive);
1197
1211
}
1198
- liveOperations[operation.ptr ] = std::make_pair (pyRef, unownedOperation);
1199
1212
return PyOperationRef (unownedOperation, std::move (pyRef));
1200
1213
}
1201
1214
1202
1215
PyOperationRef PyOperation::forOperation (PyMlirContextRef contextRef,
1203
1216
MlirOperation operation,
1204
1217
nb::object parentKeepAlive) {
1218
+ nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1205
1219
auto &liveOperations = contextRef->liveOperations ;
1206
1220
auto it = liveOperations.find (operation.ptr );
1207
1221
if (it == liveOperations.end ()) {
1208
1222
// Create.
1209
- return createInstance (std::move (contextRef), operation,
1210
- std::move (parentKeepAlive));
1223
+ PyOperationRef result = createInstance (std::move (contextRef), operation,
1224
+ std::move (parentKeepAlive));
1225
+ liveOperations[operation.ptr ] =
1226
+ std::make_pair (result.getObject (), result.get ());
1227
+ return result;
1211
1228
}
1212
1229
// Use existing.
1213
1230
PyOperation *existing = it->second .second ;
@@ -1218,13 +1235,15 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
1218
1235
PyOperationRef PyOperation::createDetached (PyMlirContextRef contextRef,
1219
1236
MlirOperation operation,
1220
1237
nb::object parentKeepAlive) {
1238
+ nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1221
1239
auto &liveOperations = contextRef->liveOperations ;
1222
1240
assert (liveOperations.count (operation.ptr ) == 0 &&
1223
1241
" cannot create detached operation that already exists" );
1224
1242
(void )liveOperations;
1225
-
1226
1243
PyOperationRef created = createInstance (std::move (contextRef), operation,
1227
1244
std::move (parentKeepAlive));
1245
+ liveOperations[operation.ptr ] =
1246
+ std::make_pair (created.getObject (), created.get ());
1228
1247
created->attached = false ;
1229
1248
return created;
1230
1249
}
0 commit comments