Skip to content

Commit 794623b

Browse files
authored
bpo-28699: fix abnormal behaviour of pools in multiprocessing.pool (GH-693)
An exception raised at the very first item of an iterable would cause pools to behave abnormally (swallowing the exception or hanging).
1 parent ec1f5df commit 794623b

File tree

3 files changed

+117
-25
lines changed

3 files changed

+117
-25
lines changed

Lib/multiprocessing/pool.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
118118
try:
119119
result = (True, func(*args, **kwds))
120120
except Exception as e:
121-
if wrap_exception:
121+
if wrap_exception and func is not _helper_reraises_exception:
122122
e = ExceptionWithTraceback(e, e.__traceback__)
123123
result = (False, e)
124124
try:
@@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
133133
completed += 1
134134
util.debug('worker exiting after %d tasks' % completed)
135135

136+
def _helper_reraises_exception(ex):
137+
'Pickle-able helper function for use by _guarded_task_generation.'
138+
raise ex
139+
136140
#
137141
# Class representing a process pool
138142
#
@@ -277,6 +281,17 @@ def starmap_async(self, func, iterable, chunksize=None, callback=None,
277281
return self._map_async(func, iterable, starmapstar, chunksize,
278282
callback, error_callback)
279283

284+
def _guarded_task_generation(self, result_job, func, iterable):
285+
'''Provides a generator of tasks for imap and imap_unordered with
286+
appropriate handling for iterables which throw exceptions during
287+
iteration.'''
288+
try:
289+
i = -1
290+
for i, x in enumerate(iterable):
291+
yield (result_job, i, func, (x,), {})
292+
except Exception as e:
293+
yield (result_job, i+1, _helper_reraises_exception, (e,), {})
294+
280295
def imap(self, func, iterable, chunksize=1):
281296
'''
282297
Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
@@ -285,15 +300,23 @@ def imap(self, func, iterable, chunksize=1):
285300
raise ValueError("Pool not running")
286301
if chunksize == 1:
287302
result = IMapIterator(self._cache)
288-
self._taskqueue.put((((result._job, i, func, (x,), {})
289-
for i, x in enumerate(iterable)), result._set_length))
303+
self._taskqueue.put(
304+
(
305+
self._guarded_task_generation(result._job, func, iterable),
306+
result._set_length
307+
))
290308
return result
291309
else:
292310
assert chunksize > 1
293311
task_batches = Pool._get_tasks(func, iterable, chunksize)
294312
result = IMapIterator(self._cache)
295-
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
296-
for i, x in enumerate(task_batches)), result._set_length))
313+
self._taskqueue.put(
314+
(
315+
self._guarded_task_generation(result._job,
316+
mapstar,
317+
task_batches),
318+
result._set_length
319+
))
297320
return (item for chunk in result for item in chunk)
298321

299322
def imap_unordered(self, func, iterable, chunksize=1):
@@ -304,15 +327,23 @@ def imap_unordered(self, func, iterable, chunksize=1):
304327
raise ValueError("Pool not running")
305328
if chunksize == 1:
306329
result = IMapUnorderedIterator(self._cache)
307-
self._taskqueue.put((((result._job, i, func, (x,), {})
308-
for i, x in enumerate(iterable)), result._set_length))
330+
self._taskqueue.put(
331+
(
332+
self._guarded_task_generation(result._job, func, iterable),
333+
result._set_length
334+
))
309335
return result
310336
else:
311337
assert chunksize > 1
312338
task_batches = Pool._get_tasks(func, iterable, chunksize)
313339
result = IMapUnorderedIterator(self._cache)
314-
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
315-
for i, x in enumerate(task_batches)), result._set_length))
340+
self._taskqueue.put(
341+
(
342+
self._guarded_task_generation(result._job,
343+
mapstar,
344+
task_batches),
345+
result._set_length
346+
))
316347
return (item for chunk in result for item in chunk)
317348

318349
def apply_async(self, func, args=(), kwds={}, callback=None,
@@ -323,7 +354,7 @@ def apply_async(self, func, args=(), kwds={}, callback=None,
323354
if self._state != RUN:
324355
raise ValueError("Pool not running")
325356
result = ApplyResult(self._cache, callback, error_callback)
326-
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
357+
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
327358
return result
328359

329360
def map_async(self, func, iterable, chunksize=None, callback=None,
@@ -354,8 +385,14 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
354385
task_batches = Pool._get_tasks(func, iterable, chunksize)
355386
result = MapResult(self._cache, chunksize, len(iterable), callback,
356387
error_callback=error_callback)
357-
self._taskqueue.put((((result._job, i, mapper, (x,), {})
358-
for i, x in enumerate(task_batches)), None))
388+
self._taskqueue.put(
389+
(
390+
self._guarded_task_generation(result._job,
391+
mapper,
392+
task_batches),
393+
None
394+
)
395+
)
359396
return result
360397

361398
@staticmethod
@@ -377,33 +414,27 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache):
377414

378415
for taskseq, set_length in iter(taskqueue.get, None):
379416
task = None
380-
i = -1
381417
try:
382-
for i, task in enumerate(taskseq):
418+
# iterating taskseq cannot fail
419+
for task in taskseq:
383420
if thread._state:
384421
util.debug('task handler found thread._state != RUN')
385422
break
386423
try:
387424
put(task)
388425
except Exception as e:
389-
job, ind = task[:2]
426+
job, idx = task[:2]
390427
try:
391-
cache[job]._set(ind, (False, e))
428+
cache[job]._set(idx, (False, e))
392429
except KeyError:
393430
pass
394431
else:
395432
if set_length:
396433
util.debug('doing set_length()')
397-
set_length(i+1)
434+
idx = task[1] if task else -1
435+
set_length(idx + 1)
398436
continue
399437
break
400-
except Exception as ex:
401-
job, ind = task[:2] if task else (0, 0)
402-
if job in cache:
403-
cache[job]._set(ind + 1, (False, ex))
404-
if set_length:
405-
util.debug('doing set_length()')
406-
set_length(i+1)
407438
finally:
408439
task = taskseq = job = None
409440
else:

Lib/test/_test_multiprocessing.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,8 @@ def __del__(self):
17551755
class SayWhenError(ValueError): pass
17561756

17571757
def exception_throwing_generator(total, when):
1758+
if when == -1:
1759+
raise SayWhenError("Somebody said when")
17581760
for i in range(total):
17591761
if i == when:
17601762
raise SayWhenError("Somebody said when")
@@ -1833,6 +1835,32 @@ def test_map_chunksize(self):
18331835
except multiprocessing.TimeoutError:
18341836
self.fail("pool.map_async with chunksize stalled on null list")
18351837

1838+
def test_map_handle_iterable_exception(self):
1839+
if self.TYPE == 'manager':
1840+
self.skipTest('test not appropriate for {}'.format(self.TYPE))
1841+
1842+
# SayWhenError seen at the very first of the iterable
1843+
with self.assertRaises(SayWhenError):
1844+
self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
1845+
# again, make sure it's reentrant
1846+
with self.assertRaises(SayWhenError):
1847+
self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
1848+
1849+
with self.assertRaises(SayWhenError):
1850+
self.pool.map(sqr, exception_throwing_generator(10, 3), 1)
1851+
1852+
class SpecialIterable:
1853+
def __iter__(self):
1854+
return self
1855+
def __next__(self):
1856+
raise SayWhenError
1857+
def __len__(self):
1858+
return 1
1859+
with self.assertRaises(SayWhenError):
1860+
self.pool.map(sqr, SpecialIterable(), 1)
1861+
with self.assertRaises(SayWhenError):
1862+
self.pool.map(sqr, SpecialIterable(), 1)
1863+
18361864
def test_async(self):
18371865
res = self.pool.apply_async(sqr, (7, TIMEOUT1,))
18381866
get = TimingWrapper(res.get)
@@ -1863,6 +1891,13 @@ def test_imap_handle_iterable_exception(self):
18631891
if self.TYPE == 'manager':
18641892
self.skipTest('test not appropriate for {}'.format(self.TYPE))
18651893

1894+
# SayWhenError seen at the very first of the iterable
1895+
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
1896+
self.assertRaises(SayWhenError, it.__next__)
1897+
# again, make sure it's reentrant
1898+
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
1899+
self.assertRaises(SayWhenError, it.__next__)
1900+
18661901
it = self.pool.imap(sqr, exception_throwing_generator(10, 3), 1)
18671902
for i in range(3):
18681903
self.assertEqual(next(it), i*i)
@@ -1889,6 +1924,17 @@ def test_imap_unordered_handle_iterable_exception(self):
18891924
if self.TYPE == 'manager':
18901925
self.skipTest('test not appropriate for {}'.format(self.TYPE))
18911926

1927+
# SayWhenError seen at the very first of the iterable
1928+
it = self.pool.imap_unordered(sqr,
1929+
exception_throwing_generator(1, -1),
1930+
1)
1931+
self.assertRaises(SayWhenError, it.__next__)
1932+
# again, make sure it's reentrant
1933+
it = self.pool.imap_unordered(sqr,
1934+
exception_throwing_generator(1, -1),
1935+
1)
1936+
self.assertRaises(SayWhenError, it.__next__)
1937+
18921938
it = self.pool.imap_unordered(sqr,
18931939
exception_throwing_generator(10, 3),
18941940
1)
@@ -1970,7 +2016,7 @@ def test_traceback(self):
19702016
except Exception as e:
19712017
exc = e
19722018
else:
1973-
raise AssertionError('expected RuntimeError')
2019+
self.fail('expected RuntimeError')
19742020
self.assertIs(type(exc), RuntimeError)
19752021
self.assertEqual(exc.args, (123,))
19762022
cause = exc.__cause__
@@ -1984,6 +2030,17 @@ def test_traceback(self):
19842030
sys.excepthook(*sys.exc_info())
19852031
self.assertIn('raise RuntimeError(123) # some comment',
19862032
f1.getvalue())
2033+
# _helper_reraises_exception should not make the error
2034+
# a remote exception
2035+
with self.Pool(1) as p:
2036+
try:
2037+
p.map(sqr, exception_throwing_generator(1, -1), 1)
2038+
except Exception as e:
2039+
exc = e
2040+
else:
2041+
self.fail('expected SayWhenError')
2042+
self.assertIs(type(exc), SayWhenError)
2043+
self.assertIs(exc.__cause__, None)
19872044

19882045
@classmethod
19892046
def _test_wrapped_exception(cls):

Misc/NEWS

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,10 @@ Extension Modules
291291
Library
292292
-------
293293

294+
- bpo-28699: Fixed a bug in pools in multiprocessing.pool where raising an
295+
exception at the very first item of an iterable could swallow the exception or
296+
make the program hang. Patch by Davin Potts and Xiang Zhang.
297+
294298
- bpo-23890: unittest.TestCase.assertRaises() now manually breaks a reference
295299
cycle to not keep objects alive longer than expected.
296300

0 commit comments

Comments
 (0)