Skip to content

bpo-22393: Fix multiprocessing.Pool hangs if a worker process dies unexpectedly #10441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
14 changes: 14 additions & 0 deletions Doc/library/multiprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2108,6 +2108,12 @@ with the :class:`Pool` class.
.. versionadded:: 3.4
*context*

.. versionchanged:: 3.8
When one of the worker processes terminates abruptly (e.g. the
Out Of Memory Killer of linux kicked in), a :exc:`BrokenProcessPool`
error is now raised. Previously, behavior was undefined and
the :class:`Pool` or its workers would often freeze or deadlock.

.. note::

Worker processes within a :class:`Pool` typically live for the complete
Expand Down Expand Up @@ -2225,6 +2231,14 @@ with the :class:`Pool` class.
:ref:`typecontextmanager`. :meth:`~contextmanager.__enter__` returns the
pool object, and :meth:`~contextmanager.__exit__` calls :meth:`terminate`.

.. exception:: BrokenProcessPool

Derived from :exc:`RuntimeError`, this exception class is raised when
one of the workers of a :class:`Pool` has terminated in a non-clean
fashion (for example, if it was killed from the outside).

.. versionadded:: 3.8


.. class:: AsyncResult

Expand Down
98 changes: 81 additions & 17 deletions Lib/multiprocessing/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,29 @@
RUN = "RUN"
CLOSE = "CLOSE"
TERMINATE = "TERMINATE"
BROKEN = "BROKEN"

#
# Miscellaneous
#

job_counter = itertools.count()


def mapstar(args):
return list(map(*args))


def starmapstar(args):
return list(itertools.starmap(args[0], args[1]))


class BrokenProcessPool(RuntimeError):
"""
Raised when a process in a ProcessPoolExecutor terminated abruptly
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe avoid using ProcessPoolExecutor and future terms, which are objects of the concurrent.futures package and not the multiprocessing package.

while a future was in the running state.
"""

#
# Hack to embed stringification of remote traceback in local traceback
#
Expand Down Expand Up @@ -104,6 +114,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
if initializer is not None:
initializer(*initargs)

util.debug('worker started')
completed = 0
while maxtasks is None or (maxtasks and completed < maxtasks):
try:
Expand Down Expand Up @@ -189,6 +200,7 @@ def __init__(self, processes=None, initializer=None, initargs=(),
)
self._worker_handler.daemon = True
self._worker_handler._state = RUN
self._worker_state_lock = self._ctx.Lock()
self._worker_handler.start()


Expand Down Expand Up @@ -225,17 +237,31 @@ def __repr__(self):

def _join_exited_workers(self):
"""Cleanup after any worker processes which have exited due to reaching
their specified lifetime. Returns True if any workers were cleaned up.
their specified lifetime.
Returns the number of workers that were cleaned up.
Returns None if the process pool is broken.
"""
cleaned = False
for i in reversed(range(len(self._pool))):
worker = self._pool[i]
if worker.exitcode is not None:
cleaned = 0
broken = False
for i, p in reversed(list(enumerate(self._pool))):
broken = broken or (p.exitcode not in (None, 0))
if p.exitcode is not None:
# worker exited
util.debug('cleaning up worker %d' % i)
worker.join()
cleaned = True
p.join()
cleaned += 1
del self._pool[i]

if broken:
# Stop all workers
util.info('worker handler: process pool is broken, terminating workers...')
for p in self._pool:
if p.exitcode is None:
p.terminate()
for p in self._pool:
p.join()
del self._pool[:]
return None
return cleaned

def _repopulate_pool(self):
Expand All @@ -256,11 +282,21 @@ def _repopulate_pool(self):
util.debug('added worker')

def _maintain_pool(self):
"""Clean up any exited workers and start replacements for them.
"""
if self._join_exited_workers():
"""Clean up any exited workers and start replacements for them."""
need_repopulate = self._join_exited_workers()
if need_repopulate:
self._repopulate_pool()

if need_repopulate is None:
with self._worker_state_lock:
self._worker_handler._state = BROKEN

err = BrokenProcessPool(
'A worker in the pool terminated abruptly.')
# Exhaust MapResult with errors
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also applies to ApplyResult right?

for i, cache_ent in list(self._cache.items()):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, is there any reason why we iterate on a list of of self._cache?

cache_ent._set_all((False, err))

def _setup_queues(self):
self._inqueue = self._ctx.SimpleQueue()
self._outqueue = self._ctx.SimpleQueue()
Expand Down Expand Up @@ -419,6 +455,7 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
@staticmethod
def _handle_workers(pool):
thread = threading.current_thread()
util.debug('worker handler entering')

# Keep maintaining workers until the cache gets drained, unless the pool
# is terminated.
Expand All @@ -432,6 +469,7 @@ def _handle_workers(pool):
@staticmethod
def _handle_tasks(taskqueue, put, outqueue, pool, cache):
thread = threading.current_thread()
util.debug('task handler entering')

for taskseq, set_length in iter(taskqueue.get, None):
task = None
Expand Down Expand Up @@ -477,6 +515,7 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache):

@staticmethod
def _handle_results(outqueue, get, cache):
util.debug('result handler entering')
thread = threading.current_thread()

while 1:
Expand Down Expand Up @@ -553,7 +592,10 @@ def close(self):
util.debug('closing pool')
if self._state == RUN:
self._state = CLOSE
self._worker_handler._state = CLOSE
# Avert race condition in broken pools
with self._worker_state_lock:
if self._worker_handler._state != BROKEN:
self._worker_handler._state = CLOSE

def terminate(self):
util.debug('terminating pool')
Expand Down Expand Up @@ -586,13 +628,21 @@ def _help_stuff_finish(inqueue, task_handler, size):
def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
worker_handler, task_handler, result_handler, cache):
# this is guaranteed to only be called once
util.debug('finalizing pool')
util.debug('terminate pool entering')
is_broken = BROKEN in (task_handler._state,
worker_handler._state,
result_handler._state)

worker_handler._state = TERMINATE
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to use the _worker_state_lock here? And in other places where _worker_handler._state is manipulated?

task_handler._state = TERMINATE

util.debug('helping task handler/workers to finish')
cls._help_stuff_finish(inqueue, task_handler, len(pool))
# Skip _help_finish_stuff if the pool is broken, because
# the broken process may have been holding the inqueue lock.
if not is_broken:
util.debug('helping task handler/workers to finish')
cls._help_stuff_finish(inqueue, task_handler, len(pool))
else:
util.debug('finishing BROKEN process pool')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens here if the task_handler is blocked, but we do not run _help_stuff_finish?


if (not result_handler.is_alive()) and (len(cache) != 0):
raise AssertionError(
Expand All @@ -603,8 +653,8 @@ def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,

# We must wait for the worker handler to exit before terminating
# workers because we don't want workers to be restarted behind our back.
util.debug('joining worker handler')
if threading.current_thread() is not worker_handler:
util.debug('joining worker handler')
worker_handler.join()

# Terminate workers which haven't already finished.
Expand All @@ -614,12 +664,12 @@ def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
if p.exitcode is None:
p.terminate()

util.debug('joining task handler')
if threading.current_thread() is not task_handler:
util.debug('joining task handler')
task_handler.join()

util.debug('joining result handler')
if threading.current_thread() is not result_handler:
util.debug('joining result handler')
result_handler.join()

if pool and hasattr(pool[0], 'terminate'):
Expand All @@ -629,6 +679,7 @@ def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
# worker has not yet exited
util.debug('cleaning up worker %d' % p.pid)
p.join()
util.debug('terminate pool finished')

def __enter__(self):
self._check_running()
Expand Down Expand Up @@ -680,6 +731,9 @@ def _set(self, i, obj):
self._event.set()
del self._cache[self._job]

def _set_all(self, obj):
self._set(0, obj)

AsyncResult = ApplyResult # create alias -- see #17805

#
Expand Down Expand Up @@ -723,6 +777,12 @@ def _set(self, i, success_result):
del self._cache[self._job]
self._event.set()

def _set_all(self, obj):
item = 0
while self._number_left > 0:
self._set(item, obj)
item += 1

#
# Class whose instances are returned by `Pool.imap()`
#
Expand Down Expand Up @@ -780,6 +840,10 @@ def _set(self, i, obj):
if self._index == self._length:
del self._cache[self._job]

def _set_all(self, obj):
while self._index != self._length:
self._set(self._index, obj)

def _set_length(self, length):
with self._cond:
self._length = length
Expand Down
108 changes: 108 additions & 0 deletions Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2584,6 +2584,18 @@ def raising():
def unpickleable_result():
return lambda: 42

def bad_exit_os(value):
if value:
from os import _exit as exit
# from sys import exit
exit(123)

def bad_exit_sys(value):
if value:
from sys import exit
exit(123)


class _TestPoolWorkerErrors(BaseTestCase):
ALLOWED_TYPES = ('processes', )

Expand Down Expand Up @@ -2624,6 +2636,102 @@ def errback(exc):
p.close()
p.join()

def test_external_signal_kills_worker_apply_async(self):
"""mimics that a worker was killed from external signal"""
from multiprocessing.pool import BrokenProcessPool
p = multiprocessing.Pool(2)
res = p.apply_async(time.sleep, (5,))
res = p.apply_async(time.sleep, (2,))
res = p.apply_async(time.sleep, (1,))
# Kill one of the pool workers, after some have entered
# execution (hence, the 0.5s wait)
time.sleep(0.5)
pid = p._pool[0].pid
os.kill(pid, signal.SIGTERM)
with self.assertRaises(BrokenProcessPool):
res.get()
p.close()
p.join()

def test_external_signal_kills_worker_imap_unordered(self):
"""mimics that a worker was killed from external signal"""
from multiprocessing.pool import BrokenProcessPool
p = multiprocessing.Pool(2)
with self.assertRaises(BrokenProcessPool):
res = list(p.imap_unordered(bad_exit_os, [0, 0, 1, 0]))
p.close()
p.join()

def test_external_signal_kills_worker_map_async1(self):
"""mimics that a worker was killed from external signal"""
from multiprocessing.pool import BrokenProcessPool
p = multiprocessing.Pool(2)
res = p.map_async(time.sleep, [5] * 10)
# Kill one of the pool workers, after some have entered
# execution (hence, the 0.5s wait)
time.sleep(0.5)
pid = p._pool[0].pid
os.kill(pid, signal.SIGTERM)
with self.assertRaises(BrokenProcessPool):
res.get()
p.close()
p.join()

def test_external_signal_kills_worker_map_async2(self):
"""mimics that a worker was killed from external signal"""
from multiprocessing.pool import BrokenProcessPool
p = multiprocessing.Pool(2)
res = p.map_async(time.sleep, (2, ))
# Kill one of the pool workers.
pid = p._pool[0].pid
os.kill(pid, signal.SIGTERM)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

with self.assertRaises(BrokenProcessPool):
res.get()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've only launched a single task, so what if it was scheduled on the other worker? I don't think this test is reliable.

Copy link
Author

@oesteban oesteban Dec 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point (without the patch) you would inevitably get the same behavior: the pool hangs forever. With this patch, if you launch a single task on a broken pool (or the pool will be broken before the result is collected), you'll get the BrokenPoolError, regardless of the worker that was killed. We could keep track of sane results and try to rescue the most, but the original fix didn't look into that and it might be subject for a different PR. Similarly (matter of another PR), we could identify when the pool could be recovered (e.g., the worker died when the pool was idle waiting for tasks).

p.close()
p.join()

def test_map_async_with_broken_pool(self):
"""submit task to a broken pool"""
from multiprocessing.pool import BrokenProcessPool
p = multiprocessing.Pool(2)
pid = p._pool[0].pid
res = p.map_async(time.sleep, (2, ))
# Kill one of the pool workers.
os.kill(pid, signal.SIGTERM)
with self.assertRaises(BrokenProcessPool):
res.get()
p.close()
p.join()

def test_internal_signal_kills_worker_map1(self):
from multiprocessing.pool import BrokenProcessPool
p = multiprocessing.Pool(2)
with self.assertRaises(BrokenProcessPool):
res = p.map(bad_exit_os, [0, 0, 1, 0])
p.close()
p.join()

def test_internal_signal_kills_worker_map2(self):
from multiprocessing.pool import BrokenProcessPool
p = multiprocessing.Pool(2)
with self.assertRaises(BrokenProcessPool):
res = p.map(bad_exit_sys, [0, 0, 1, 0])
p.close()
p.join()

def test_internal_signal_kills_worker_map_async3(self):
from multiprocessing.pool import BrokenProcessPool
p = multiprocessing.Pool(2)
res = p.map_async(time.sleep, [5] * 10)
# Kill one of the pool workers, after some have entered
# execution (hence, the 0.5s wait)
time.sleep(0.5)
p._pool[0].terminate()
with self.assertRaises(BrokenProcessPool):
res.get()
p.close()
p.join()

class _TestPoolWorkerLifetime(BaseTestCase):
ALLOWED_TYPES = ('processes', )

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``multiprocessing.Pool`` indefintely hang when a worker process dies
unexpectedly. Patch by Oscar Esteban, based on code from Dan O'Reilly.