Skip to content

Commit 9f8e090

Browse files
authored
bpo-28699: fix abnormal behaviour of pools in multiprocessing.pool (GH-884)
An exception raised at the very start of iterating over an iterable would cause pools to behave abnormally (swallowing the exception or hanging).
1 parent 0957f26 commit 9f8e090

File tree

3 files changed

+117
-25
lines changed

3 files changed

+117
-25
lines changed

Lib/multiprocessing/pool.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
118118
try:
119119
result = (True, func(*args, **kwds))
120120
except Exception as e:
121-
if wrap_exception:
121+
if wrap_exception and func is not _helper_reraises_exception:
122122
e = ExceptionWithTraceback(e, e.__traceback__)
123123
result = (False, e)
124124
try:
@@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
133133
completed += 1
134134
util.debug('worker exiting after %d tasks' % completed)
135135

136+
def _helper_reraises_exception(ex):
137+
'Pickle-able helper function for use by _guarded_task_generation.'
138+
raise ex
139+
136140
#
137141
# Class representing a process pool
138142
#
@@ -277,6 +281,17 @@ def starmap_async(self, func, iterable, chunksize=None, callback=None,
277281
return self._map_async(func, iterable, starmapstar, chunksize,
278282
callback, error_callback)
279283

284+
def _guarded_task_generation(self, result_job, func, iterable):
285+
'''Provides a generator of tasks for imap and imap_unordered with
286+
appropriate handling for iterables which throw exceptions during
287+
iteration.'''
288+
try:
289+
i = -1
290+
for i, x in enumerate(iterable):
291+
yield (result_job, i, func, (x,), {})
292+
except Exception as e:
293+
yield (result_job, i+1, _helper_reraises_exception, (e,), {})
294+
280295
def imap(self, func, iterable, chunksize=1):
281296
'''
282297
Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
@@ -285,15 +300,23 @@ def imap(self, func, iterable, chunksize=1):
285300
raise ValueError("Pool not running")
286301
if chunksize == 1:
287302
result = IMapIterator(self._cache)
288-
self._taskqueue.put((((result._job, i, func, (x,), {})
289-
for i, x in enumerate(iterable)), result._set_length))
303+
self._taskqueue.put(
304+
(
305+
self._guarded_task_generation(result._job, func, iterable),
306+
result._set_length
307+
))
290308
return result
291309
else:
292310
assert chunksize > 1
293311
task_batches = Pool._get_tasks(func, iterable, chunksize)
294312
result = IMapIterator(self._cache)
295-
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
296-
for i, x in enumerate(task_batches)), result._set_length))
313+
self._taskqueue.put(
314+
(
315+
self._guarded_task_generation(result._job,
316+
mapstar,
317+
task_batches),
318+
result._set_length
319+
))
297320
return (item for chunk in result for item in chunk)
298321

299322
def imap_unordered(self, func, iterable, chunksize=1):
@@ -304,15 +327,23 @@ def imap_unordered(self, func, iterable, chunksize=1):
304327
raise ValueError("Pool not running")
305328
if chunksize == 1:
306329
result = IMapUnorderedIterator(self._cache)
307-
self._taskqueue.put((((result._job, i, func, (x,), {})
308-
for i, x in enumerate(iterable)), result._set_length))
330+
self._taskqueue.put(
331+
(
332+
self._guarded_task_generation(result._job, func, iterable),
333+
result._set_length
334+
))
309335
return result
310336
else:
311337
assert chunksize > 1
312338
task_batches = Pool._get_tasks(func, iterable, chunksize)
313339
result = IMapUnorderedIterator(self._cache)
314-
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
315-
for i, x in enumerate(task_batches)), result._set_length))
340+
self._taskqueue.put(
341+
(
342+
self._guarded_task_generation(result._job,
343+
mapstar,
344+
task_batches),
345+
result._set_length
346+
))
316347
return (item for chunk in result for item in chunk)
317348

318349
def apply_async(self, func, args=(), kwds={}, callback=None,
@@ -323,7 +354,7 @@ def apply_async(self, func, args=(), kwds={}, callback=None,
323354
if self._state != RUN:
324355
raise ValueError("Pool not running")
325356
result = ApplyResult(self._cache, callback, error_callback)
326-
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
357+
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
327358
return result
328359

329360
def map_async(self, func, iterable, chunksize=None, callback=None,
@@ -354,8 +385,14 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
354385
task_batches = Pool._get_tasks(func, iterable, chunksize)
355386
result = MapResult(self._cache, chunksize, len(iterable), callback,
356387
error_callback=error_callback)
357-
self._taskqueue.put((((result._job, i, mapper, (x,), {})
358-
for i, x in enumerate(task_batches)), None))
388+
self._taskqueue.put(
389+
(
390+
self._guarded_task_generation(result._job,
391+
mapper,
392+
task_batches),
393+
None
394+
)
395+
)
359396
return result
360397

361398
@staticmethod
@@ -377,33 +414,27 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache):
377414

378415
for taskseq, set_length in iter(taskqueue.get, None):
379416
task = None
380-
i = -1
381417
try:
382-
for i, task in enumerate(taskseq):
418+
# iterating taskseq cannot fail
419+
for task in taskseq:
383420
if thread._state:
384421
util.debug('task handler found thread._state != RUN')
385422
break
386423
try:
387424
put(task)
388425
except Exception as e:
389-
job, ind = task[:2]
426+
job, idx = task[:2]
390427
try:
391-
cache[job]._set(ind, (False, e))
428+
cache[job]._set(idx, (False, e))
392429
except KeyError:
393430
pass
394431
else:
395432
if set_length:
396433
util.debug('doing set_length()')
397-
set_length(i+1)
434+
idx = task[1] if task else -1
435+
set_length(idx + 1)
398436
continue
399437
break
400-
except Exception as ex:
401-
job, ind = task[:2] if task else (0, 0)
402-
if job in cache:
403-
cache[job]._set(ind + 1, (False, ex))
404-
if set_length:
405-
util.debug('doing set_length()')
406-
set_length(i+1)
407438
finally:
408439
task = taskseq = job = None
409440
else:

Lib/test/_test_multiprocessing.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1685,6 +1685,8 @@ def __del__(self):
16851685
class SayWhenError(ValueError): pass
16861686

16871687
def exception_throwing_generator(total, when):
1688+
if when == -1:
1689+
raise SayWhenError("Somebody said when")
16881690
for i in range(total):
16891691
if i == when:
16901692
raise SayWhenError("Somebody said when")
@@ -1763,6 +1765,32 @@ def test_map_chunksize(self):
17631765
except multiprocessing.TimeoutError:
17641766
self.fail("pool.map_async with chunksize stalled on null list")
17651767

1768+
def test_map_handle_iterable_exception(self):
1769+
if self.TYPE == 'manager':
1770+
self.skipTest('test not appropriate for {}'.format(self.TYPE))
1771+
1772+
# SayWhenError seen at the very first of the iterable
1773+
with self.assertRaises(SayWhenError):
1774+
self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
1775+
# again, make sure it's reentrant
1776+
with self.assertRaises(SayWhenError):
1777+
self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
1778+
1779+
with self.assertRaises(SayWhenError):
1780+
self.pool.map(sqr, exception_throwing_generator(10, 3), 1)
1781+
1782+
class SpecialIterable:
1783+
def __iter__(self):
1784+
return self
1785+
def __next__(self):
1786+
raise SayWhenError
1787+
def __len__(self):
1788+
return 1
1789+
with self.assertRaises(SayWhenError):
1790+
self.pool.map(sqr, SpecialIterable(), 1)
1791+
with self.assertRaises(SayWhenError):
1792+
self.pool.map(sqr, SpecialIterable(), 1)
1793+
17661794
def test_async(self):
17671795
res = self.pool.apply_async(sqr, (7, TIMEOUT1,))
17681796
get = TimingWrapper(res.get)
@@ -1793,6 +1821,13 @@ def test_imap_handle_iterable_exception(self):
17931821
if self.TYPE == 'manager':
17941822
self.skipTest('test not appropriate for {}'.format(self.TYPE))
17951823

1824+
# SayWhenError seen at the very first of the iterable
1825+
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
1826+
self.assertRaises(SayWhenError, it.__next__)
1827+
# again, make sure it's reentrant
1828+
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
1829+
self.assertRaises(SayWhenError, it.__next__)
1830+
17961831
it = self.pool.imap(sqr, exception_throwing_generator(10, 3), 1)
17971832
for i in range(3):
17981833
self.assertEqual(next(it), i*i)
@@ -1819,6 +1854,17 @@ def test_imap_unordered_handle_iterable_exception(self):
18191854
if self.TYPE == 'manager':
18201855
self.skipTest('test not appropriate for {}'.format(self.TYPE))
18211856

1857+
# SayWhenError seen at the very first of the iterable
1858+
it = self.pool.imap_unordered(sqr,
1859+
exception_throwing_generator(1, -1),
1860+
1)
1861+
self.assertRaises(SayWhenError, it.__next__)
1862+
# again, make sure it's reentrant
1863+
it = self.pool.imap_unordered(sqr,
1864+
exception_throwing_generator(1, -1),
1865+
1)
1866+
self.assertRaises(SayWhenError, it.__next__)
1867+
18221868
it = self.pool.imap_unordered(sqr,
18231869
exception_throwing_generator(10, 3),
18241870
1)
@@ -1900,7 +1946,7 @@ def test_traceback(self):
19001946
except Exception as e:
19011947
exc = e
19021948
else:
1903-
raise AssertionError('expected RuntimeError')
1949+
self.fail('expected RuntimeError')
19041950
self.assertIs(type(exc), RuntimeError)
19051951
self.assertEqual(exc.args, (123,))
19061952
cause = exc.__cause__
@@ -1914,6 +1960,17 @@ def test_traceback(self):
19141960
sys.excepthook(*sys.exc_info())
19151961
self.assertIn('raise RuntimeError(123) # some comment',
19161962
f1.getvalue())
1963+
# _helper_reraises_exception should not make the error
1964+
# a remote exception
1965+
with self.Pool(1) as p:
1966+
try:
1967+
p.map(sqr, exception_throwing_generator(1, -1), 1)
1968+
except Exception as e:
1969+
exc = e
1970+
else:
1971+
self.fail('expected SayWhenError')
1972+
self.assertIs(type(exc), SayWhenError)
1973+
self.assertIs(exc.__cause__, None)
19171974

19181975
@classmethod
19191976
def _test_wrapped_exception(cls):

Misc/NEWS

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ Extension Modules
4646
Library
4747
-------
4848

49+
- bpo-28699: Fixed a bug in pools in multiprocessing.pool where an exception
50+
raised at the very start of an iterable could be swallowed or could make
51+
the program hang. Patch by Davin Potts and Xiang Zhang.
52+
4953
- bpo-25803: Avoid incorrect errors raised by Path.mkdir(exist_ok=True)
5054
when the OS gives priority to errors such as EACCES over EEXIST.
5155

0 commit comments

Comments
 (0)