|
1 | 1 | #
|
2 |
| -# Module providing the `SyncManager` class for dealing |
| 2 | +# Module providing manager classes for dealing |
3 | 3 | # with shared objects
|
4 | 4 | #
|
5 | 5 | # multiprocessing/managers.py
|
|
8 | 8 | # Licensed to PSF under a Contributor Agreement.
|
9 | 9 | #
|
10 | 10 |
|
11 |
| -__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token' ] |
| 11 | +__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token', |
| 12 | + 'SharedMemoryManager' ] |
12 | 13 |
|
13 | 14 | #
|
14 | 15 | # Imports
|
|
19 | 20 | import array
|
20 | 21 | import queue
|
21 | 22 | import time
|
| 23 | +from os import getpid |
22 | 24 |
|
23 | 25 | from traceback import format_exc
|
24 | 26 |
|
|
28 | 30 | from . import process
|
29 | 31 | from . import util
|
30 | 32 | from . import get_context
|
| 33 | +try: |
| 34 | + from . import shared_memory |
| 35 | + HAS_SHMEM = True |
| 36 | +except ImportError: |
| 37 | + HAS_SHMEM = False |
31 | 38 |
|
32 | 39 | #
|
33 | 40 | # Register some things for pickling
|
@@ -1200,3 +1207,143 @@ class SyncManager(BaseManager):
|
1200 | 1207 | # types returned by methods of PoolProxy
|
1201 | 1208 | SyncManager.register('Iterator', proxytype=IteratorProxy, create_method=False)
|
1202 | 1209 | SyncManager.register('AsyncResult', create_method=False)
|
| 1210 | + |
| 1211 | +# |
| 1212 | +# Definition of SharedMemoryManager and SharedMemoryServer |
| 1213 | +# |
| 1214 | + |
| 1215 | +if HAS_SHMEM: |
| 1216 | + class _SharedMemoryTracker: |
| 1217 | + "Manages one or more shared memory segments." |
| 1218 | + |
| 1219 | + def __init__(self, name, segment_names=[]): |
| 1220 | + self.shared_memory_context_name = name |
| 1221 | + self.segment_names = segment_names |
| 1222 | + |
| 1223 | + def register_segment(self, segment_name): |
| 1224 | + "Adds the supplied shared memory block name to tracker." |
| 1225 | + util.debug(f"Register segment {segment_name!r} in pid {getpid()}") |
| 1226 | + self.segment_names.append(segment_name) |
| 1227 | + |
| 1228 | + def destroy_segment(self, segment_name): |
| 1229 | + """Calls unlink() on the shared memory block with the supplied name |
| 1230 | + and removes it from the list of blocks being tracked.""" |
| 1231 | + util.debug(f"Destroy segment {segment_name!r} in pid {getpid()}") |
| 1232 | + self.segment_names.remove(segment_name) |
| 1233 | + segment = shared_memory.SharedMemory(segment_name) |
| 1234 | + segment.close() |
| 1235 | + segment.unlink() |
| 1236 | + |
| 1237 | + def unlink(self): |
| 1238 | + "Calls destroy_segment() on all tracked shared memory blocks." |
| 1239 | + for segment_name in self.segment_names[:]: |
| 1240 | + self.destroy_segment(segment_name) |
| 1241 | + |
| 1242 | + def __del__(self): |
| 1243 | + util.debug(f"Call {self.__class__.__name__}.__del__ in {getpid()}") |
| 1244 | + self.unlink() |
| 1245 | + |
| 1246 | + def __getstate__(self): |
| 1247 | + return (self.shared_memory_context_name, self.segment_names) |
| 1248 | + |
| 1249 | + def __setstate__(self, state): |
| 1250 | + self.__init__(*state) |
| 1251 | + |
| 1252 | + |
| 1253 | + class SharedMemoryServer(Server): |
| 1254 | + |
| 1255 | + public = Server.public + \ |
| 1256 | + ['track_segment', 'release_segment', 'list_segments'] |
| 1257 | + |
| 1258 | + def __init__(self, *args, **kwargs): |
| 1259 | + Server.__init__(self, *args, **kwargs) |
| 1260 | + self.shared_memory_context = \ |
| 1261 | + _SharedMemoryTracker(f"shmm_{self.address}_{getpid()}") |
| 1262 | + util.debug(f"SharedMemoryServer started by pid {getpid()}") |
| 1263 | + |
| 1264 | + def create(self, c, typeid, *args, **kwargs): |
| 1265 | + """Create a new distributed-shared object (not backed by a shared |
| 1266 | + memory block) and return its id to be used in a Proxy Object.""" |
| 1267 | + # Unless set up as a shared proxy, don't make shared_memory_context |
| 1268 | + # a standard part of kwargs. This makes things easier for supplying |
| 1269 | + # simple functions. |
| 1270 | + if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"): |
| 1271 | + kwargs['shared_memory_context'] = self.shared_memory_context |
| 1272 | + return Server.create(self, c, typeid, *args, **kwargs) |
| 1273 | + |
| 1274 | + def shutdown(self, c): |
| 1275 | + "Call unlink() on all tracked shared memory, terminate the Server." |
| 1276 | + self.shared_memory_context.unlink() |
| 1277 | + return Server.shutdown(self, c) |
| 1278 | + |
| 1279 | + def track_segment(self, c, segment_name): |
| 1280 | + "Adds the supplied shared memory block name to Server's tracker." |
| 1281 | + self.shared_memory_context.register_segment(segment_name) |
| 1282 | + |
| 1283 | + def release_segment(self, c, segment_name): |
| 1284 | + """Calls unlink() on the shared memory block with the supplied name |
| 1285 | + and removes it from the tracker instance inside the Server.""" |
| 1286 | + self.shared_memory_context.destroy_segment(segment_name) |
| 1287 | + |
| 1288 | + def list_segments(self, c): |
| 1289 | + """Returns a list of names of shared memory blocks that the Server |
| 1290 | + is currently tracking.""" |
| 1291 | + return self.shared_memory_context.segment_names |
| 1292 | + |
| 1293 | + |
| 1294 | + class SharedMemoryManager(BaseManager): |
| 1295 | + """Like SyncManager but uses SharedMemoryServer instead of Server. |
| 1296 | +
|
| 1297 | + It provides methods for creating and returning SharedMemory instances |
| 1298 | + and for creating a list-like object (ShareableList) backed by shared |
| 1299 | + memory. It also provides methods that create and return Proxy Objects |
| 1300 | + that support synchronization across processes (i.e. multi-process-safe |
| 1301 | + locks and semaphores). |
| 1302 | + """ |
| 1303 | + |
| 1304 | + _Server = SharedMemoryServer |
| 1305 | + |
| 1306 | + def __init__(self, *args, **kwargs): |
| 1307 | + BaseManager.__init__(self, *args, **kwargs) |
| 1308 | + util.debug(f"{self.__class__.__name__} created by pid {getpid()}") |
| 1309 | + |
| 1310 | + def __del__(self): |
| 1311 | + util.debug(f"{self.__class__.__name__}.__del__ by pid {getpid()}") |
| 1312 | + pass |
| 1313 | + |
| 1314 | + def get_server(self): |
| 1315 | + 'Better than monkeypatching for now; merge into Server ultimately' |
| 1316 | + if self._state.value != State.INITIAL: |
| 1317 | + if self._state.value == State.STARTED: |
| 1318 | + raise ProcessError("Already started SharedMemoryServer") |
| 1319 | + elif self._state.value == State.SHUTDOWN: |
| 1320 | + raise ProcessError("SharedMemoryManager has shut down") |
| 1321 | + else: |
| 1322 | + raise ProcessError( |
| 1323 | + "Unknown state {!r}".format(self._state.value)) |
| 1324 | + return self._Server(self._registry, self._address, |
| 1325 | + self._authkey, self._serializer) |
| 1326 | + |
| 1327 | + def SharedMemory(self, size): |
| 1328 | + """Returns a new SharedMemory instance with the specified size in |
| 1329 | + bytes, to be tracked by the manager.""" |
| 1330 | + with self._Client(self._address, authkey=self._authkey) as conn: |
| 1331 | + sms = shared_memory.SharedMemory(None, create=True, size=size) |
| 1332 | + try: |
| 1333 | + dispatch(conn, None, 'track_segment', (sms.name,)) |
| 1334 | + except BaseException as e: |
| 1335 | + sms.unlink() |
| 1336 | + raise e |
| 1337 | + return sms |
| 1338 | + |
| 1339 | + def ShareableList(self, sequence): |
| 1340 | + """Returns a new ShareableList instance populated with the values |
| 1341 | + from the input sequence, to be tracked by the manager.""" |
| 1342 | + with self._Client(self._address, authkey=self._authkey) as conn: |
| 1343 | + sl = shared_memory.ShareableList(sequence) |
| 1344 | + try: |
| 1345 | + dispatch(conn, None, 'track_segment', (sl.shm.name,)) |
| 1346 | + except BaseException as e: |
| 1347 | + sl.shm.unlink() |
| 1348 | + raise e |
| 1349 | + return sl |
0 commit comments