# HG changeset patch # Parent fb907ee4fdfa47b49e10c8a2f9990721d4d8823b diff -r fb907ee4fdfa Doc/library/atfork.rst --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Doc/library/atfork.rst Tue Apr 24 23:34:34 2012 +0100 @@ -0,0 +1,208 @@ +:mod:`atfork` --- Fork handlers +=============================== + +.. module:: atfork + :platform: Unix + :synopsis: Register handlers to be called before and after :func:`os.fork`. + + +The :mod:`atfork` module provides an :func:`atfork` function for +registering callbacks to be called before and after a process forks, +in the style of the POSIX function :manpage:`pthread_atfork(3)`. + +It also provides a recursive lock object which is held while the fork +and the callbacks are executed, allowing the thread safe manipulation +of file descriptors. + + +.. function:: atfork(prepare=None, parent=None, child=None) + + Register callbacks to be run before and after :func:`os.fork` (and + :func:`os.forkpty` if it exists). + + Note that starting processes using :mod:`multiprocessing` will also + trigger these callbacks on Unix, but that starting them using + :mod:`subprocess` will not. (The raw functions :func:`posix.fork` + and :func:`posix.forkpty` will not trigger the callbacks either.) + + The *prepare*, *parent* and *child* arguments should either be + ``None`` or a callback which takes no arguments. ``None`` is + treated as a trivial callback which does nothing. + + Callbacks registered using the *prepare* argument are run before + the process forks, in reverse order of registration. Callbacks + registered using the *parent* argument are run by the parent + process after the fork, in order of registration. Callbacks + registered with the *child* argument are run by the child process + after the fork, in order of registration. + + This means that if *prepare*, *parent* and *child* are all + callables, then doing :: + + atfork.atfork(prepare, parent, child) + + has a very similar effect on :func:`os.fork` as wrapping it with :: + + parent_pid = os.getpid() + prepare() + try: + ... + finally: + parent() if os.getpid() == parent_pid else child() + + Using :func:`atfork.atfork` multiple times has the effect of adding + multiple nested wrappers like this around the original + fork function. + + See :ref:`atfork-exceptions` for what happens if an exception is + raised by a callback or the fork fails. + + The *Fork Lock* -- see :func:`atfork.getlock` -- is held throughout + the time that the callbacks and fork are executed. + + +.. function:: getlock() + + Return a recursive lock object which must be held by a thread while + it forks a child process. This *Fork Lock* is also held by a thread + which starts a child process using the :mod:`subprocess` or + :mod:`multiprocessing` modules. Note, however, that there are some + other ways of starting processes which are not affected by the fork + lock, for instance :func:`os.system`. + + The lock is mainly intended for avoiding races when manipulating + file descriptors in a multithreaded program. See + :ref:`leaking-file-descriptors`. + + In a newly created child process started by :func:`os.fork`, + :func:`os.forkpty` or :class:`multiprocessing.Process` the fork + lock will be free. The child process may use a different fork lock + to its parent process. + + .. note:: + + It is generally best to avoid forking a multithreaded program. + Instead one should to try to use :mod:`subprocess` for starting new + processes. + + +.. _atfork-exceptions: + +Propogation of Exceptions +------------------------- + +The way that exceptions propogate is the same as if we had used nested +:keyword:`try`-:keyword:`finally` wrappers. Suppose we have done :: + + atfork.atfork(prepare1, parent1, child1) + atfork.atfork(prepare2, parent2, child2) + ... + atfork.atfork(prepareN, parentN, childN) + +If we ignore the fork lock then calling :func:`os.fork` does the +equivalent of :: + + pid = None + prepareN() + try: + ... + prepare2() + try: + prepare1() + try: + pid = posix.fork() + finally: + parent1() if pid != 0 else child1() + finally: + parent2() if pid != 0 else child2() + ... + finally: + parentN() if pid != 0 else childN() + +So if multiple exceptions occur then they will be implicitly chained +together using their :attr:`__context__` attributes -- see +:ref:`bltin-exceptions`. + +However, if the fork succeeds and one or more parent callbacks raise +exceptions then the chained traceback is printed to :data:`sys.stderr` +rather than raised. This ensures that the parent process does not +lose the opportunity to get the pid of the forked process. If one or +more child callbacks raises an error then the chained traceback is +printed to :data:`sys.stderr` and the child process exits with a +non-zero exit code. + + +.. _leaking-file-descriptors: + +Leaking File Descriptors +------------------------ + +Suppose we want to fork a child process and have it send data to the +parent process over a pipe. Then we might use a pattern like :: + + r, w = os.pipe() + + pid = os.fork() + if pid == 0: + os.close(r) + try: + child_func(w) + except: + os._exit(1) + else: + os._exit(0) + + os.close(w) + parent_func(r) + +Closing the writable end of the pipe in the parent process is +important since, in a single threaded program, it ensures that no +other process can own a reference to the writable end of the pipe. +So when the child exits or explicitly closes the writable end, no more +data can ever be written to the pipe. This means that reading from +the pipe will not block if it is empty; instead zero bytes are +returned indicating end-of-file. + +However, in a multithreaded program there is the possibility that a +process might be started by another thread after the pipe was created, +but before the parent process closed its copy of the writable end. +This will mean that the file descriptor has been "leaked" to that +second child process. End-of-file will not be signalled at the +readable end of the pipe until the second child process exits. + +To avoid the possibility of leaking the file descriptor one can use +the fork lock as follows:: + + with atfork.getlock(): + r, w = os.pipe() + + pid = os.fork() + if pid == 0: + os.close(r) + try: + child_func(w) + except: + os._exit(1) + else: + os._exit(0) + + os.close(w) + + parent_func(r) + +Using this pattern it does not matter that another thread might start +a process using :func:`os.fork`, :mod:`subprocess` or +:mod:`multiprocessing`. Having said that, processes started using +:class:`subprocess.Popen` with the *closefds* argument set to true +(the default) will not accidentally inherit file descriptors. + +Another potential race is where one wants to set the +:const:`fcntl.FD_CLOEXEC` flag on a newly created file descriptor +before a forked process can inherit it. The race can be fixed by +holding the fork lock while creating and modifing the file descriptor. +For example:: + + with atfork.getlock(): + fd = os.open("somefile", os.O_CREAT | os.O_WRONLY, 0600) + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | fcntl.FD_CLOEXEC) diff -r fb907ee4fdfa Doc/library/os.rst --- a/Doc/library/os.rst Thu Apr 05 00:04:20 2012 +0200 +++ b/Doc/library/os.rst Tue Apr 24 23:34:34 2012 +0100 @@ -2574,6 +2574,13 @@ Availability: Unix. + .. versionchanged:: 3.3 + The :mod:`atfork` module allows the registration of callbacks + to be called before and after a process forks. It also exposes + a recursive Fork Lock which is held when :func:`os.fork` forks + the process. The raw fork function is still available as + :func:`posix.fork`. + .. function:: forkpty() @@ -2585,6 +2592,10 @@ Availability: some flavors of Unix. + .. versionchanged:: 3.3 + The :mod:`atfork` module applies to :func:`os.forkpty` in the + same way that it does to :func:`os.fork`. + .. function:: kill(pid, sig) diff -r fb907ee4fdfa Doc/library/someos.rst --- a/Doc/library/someos.rst Thu Apr 05 00:04:20 2012 +0200 +++ b/Doc/library/someos.rst Tue Apr 24 23:34:34 2012 +0100 @@ -19,6 +19,7 @@ mmap.rst readline.rst rlcompleter.rst + atfork.rst dummy_threading.rst _thread.rst _dummy_thread.rst diff -r fb907ee4fdfa Lib/atfork.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Lib/atfork.py Tue Apr 24 23:34:34 2012 +0100 @@ -0,0 +1,148 @@ +''' +Module for registering callbacks to be run before and after fork(), +modelled after pthread_atfork(). Unix only. + +Access is also given to a recursive lock object which is held whenever +child processes are created (using os.fork(), subprocess.Popen() or +multiprocessing.Process()). Among other things, in a multithreaded +program this allows one to reliably control which processes inherit +the file descriptors one creates. For example, on can do + + with atfork.getlock(): + r, w = os.pipe() + pid = os.fork() + if pid == 0: + os.close(r) + try: + # do something that writes to w + except: + os._exit(1) + else: + os._exit(0) + else: + os.close(w) + + # read from r and expect EOF as soon as the process we just forked exits +''' + +__all__ = ['atfork', 'getlock'] + +import sys +import posix + +try: + from _thread import RLock as _RLock +except ImportError: + class _RLock: + def acquire(self, blocking=True): + return True + def release(self, *args): + pass + __enter__ = acquire + __exit__ = release + + +class AtforkContext: + # intial values for instances + pid = None + exception = None + + # class data + callbacks = [] + lock = _RLock() + _garbage = [] + + def __enter__(self): + self._saved_parents = [] + AtforkContext.lock.acquire() + try: + for prepare, parent, child in reversed(self.callbacks): + if prepare is not None: + prepare() + if parent is not None: + self._saved_parents.append(parent) + except BaseException as e: + self._chain_new_exception(e) + return self + + def __exit__(self, type, value, tb): + assert (self.pid is not None or self.exception is not None + or value is not None) + try: + if value is not None: + self._chain_new_exception(value) # fork failed + if self.pid == 0: + self._exit_child() + else: + self._exit_parent() + except BaseException as e: + self._chain_new_exception(e) # presumably got Ctrl-C + finally: + AtforkContext.lock.release() + return True + + def _exit_parent(self): + for parent in reversed(self._saved_parents): + try: + parent() + except BaseException as e: + self._chain_new_exception(e) + + def _exit_child(self): + # finalizer of old lock might cause segfault -- see Issue 6721 -- + # so prevent garbage collection of old lock and replace by new one + AtforkContext._garbage.append(AtforkContext.lock) + AtforkContext.lock = _RLock() + AtforkContext.lock.acquire() + + for prepare, parent, child in self.callbacks: + if child is not None: + try: + child() + except BaseException as e: + self._chain_new_exception(e) + + def _chain_new_exception(self, e): + e.__context__ = self.exception + self.exception = e + + +def excepthook(e, pid): + if pid == 0: + # child process + try: + print('Exception from atfork child callback ignored:', + file=sys.stderr) + sys.__excepthook__(type(e), e, e.__traceback__) + sys.stderr.close() + finally: + posix._exit(1) + elif pid is None: + # parent process; fork failed or never called + raise e + else: + # parent process; fork succeeded + print('Exception from atfork parent callback ignored:', + file=sys.stderr) + sys.__excepthook__(type(e), e, e.__traceback__) + sys.stderr.flush() + + +def atfork(prepare=None, parent=None, child=None): + ''' + Register callbacks to be run before and after fork() + + prepare callbacks are run before calling fork() in LIFO order. + parent callbacks are run by the parent process after calling + fork() in FIFO order. child callbacks are run by the child + process after it forks in FIFO order. + ''' + with AtforkContext.lock: + AtforkContext.callbacks.append((prepare, parent, child)) + + +def getlock(): + ''' + Return recursive fork lock + ''' + return AtforkContext.lock diff -r fb907ee4fdfa Lib/multiprocessing/forking.py --- a/Lib/multiprocessing/forking.py Thu Apr 05 00:04:20 2012 +0200 +++ b/Lib/multiprocessing/forking.py Tue Apr 24 23:34:34 2012 +0100 @@ -100,6 +100,7 @@ # if sys.platform != 'win32': + import atfork import select exit = os._exit @@ -119,21 +120,23 @@ sys.stderr.flush() self.returncode = None - r, w = os.pipe() - self.sentinel = r + with atfork.getlock(): + r, w = os.pipe() + self.sentinel = r - self.pid = os.fork() - if self.pid == 0: - os.close(r) - if 'random' in sys.modules: - import random - random.seed() - code = process_obj._bootstrap() - os._exit(code) + self.pid = os.fork() + if self.pid == 0: + os.close(r) + if 'random' in sys.modules: + import random + random.seed() + code = process_obj._bootstrap() + os._exit(code) - # `w` will be closed when the child exits, at which point `r` - # will become ready for reading (using e.g. select()). - os.close(w) + # `w` will be closed when the child exits, at which point `r` + # will become ready for reading (using e.g. select()). + os.close(w) + util.Finalize(self, os.close, (r,)) def poll(self, flag=os.WNOHANG): diff -r fb907ee4fdfa Lib/os.py --- a/Lib/os.py Thu Apr 05 00:04:20 2012 +0200 +++ b/Lib/os.py Tue Apr 24 23:34:34 2012 +0100 @@ -1,4 +1,4 @@ -r"""OS routines for Mac, NT, or Posix depending on what system we're on. +"""OS routines for Mac, NT, or Posix depending on what system we're on. This exports: - all functions from posix, nt, os2, or ce, e.g. unlink, stat, etc. @@ -903,3 +903,35 @@ raise TypeError("invalid fd type (%s, expected integer)" % type(fd)) import io return io.open(fd, *args, **kwargs) + +# replace fork() and forkpty() by wrappers which call atfork callbacks +if name == "posix": + if _exists("fork"): + doc = fork.__doc__ + + def fork(): + import atfork, posix + with atfork.AtforkContext() as ctx: + if ctx.exception is None: + ctx.pid = posix.fork() + if ctx.exception is not None: + atfork.excepthook(ctx.exception, ctx.pid) + return ctx.pid + + fork.__doc__ = doc + del doc + + if _exists("forkpty"): + doc = forkpty.__doc__ + + def forkpty(): + import atfork, posix + with atfork.AtforkContext() as ctx: + if ctx.exception is None: + ctx.pid, fd = posix.forkpty() + if ctx.exception is not None: + atfork.excepthook(ctx.exception, ctx.pid) + return ctx.pid, fd + + forkpty.__doc__ = doc + del doc diff -r fb907ee4fdfa Lib/subprocess.py --- a/Lib/subprocess.py Thu Apr 05 00:04:20 2012 +0200 +++ b/Lib/subprocess.py Tue Apr 24 23:34:34 2012 +0100 @@ -394,11 +394,13 @@ wShowWindow = 0 class pywintypes: error = IOError + atfork = None else: import select _has_poll = hasattr(select, 'poll') import _posixsubprocess _create_pipe = _posixsubprocess.cloexec_pipe + import atfork # When select or poll has indicated that the file is writable, # we can write up to _PIPE_BUF bytes without risk of blocking. @@ -708,68 +710,79 @@ self.returncode = None self.universal_newlines = universal_newlines - # Input and output objects. The general principle is like - # this: - # - # Parent Child - # ------ ----- - # p2cwrite ---stdin---> p2cread - # c2pread <--stdout--- c2pwrite - # errread <--stderr--- errwrite - # - # On POSIX, the child objects are file descriptors. On - # Windows, these are Windows file handles. The parent objects - # are file descriptors on both platforms. The parent objects - # are -1 when not using PIPEs. The child objects are -1 - # when not redirecting. + if atfork is not None: + # to avoid accidental inheritance of fds by processes + # forked by other threads we acquire the recursive fork lock + forklock = atfork.getlock() + forklock.acquire() + else: + forklock = None + try: + # Input and output objects. The general principle is like + # this: + # + # Parent Child + # ------ ----- + # p2cwrite ---stdin---> p2cread + # c2pread <--stdout--- c2pwrite + # errread <--stderr--- errwrite + # + # On POSIX, the child objects are file descriptors. On + # Windows, these are Windows file handles. The parent objects + # are file descriptors on both platforms. The parent objects + # are -1 when not using PIPEs. The child objects are -1 + # when not redirecting. - (p2cread, p2cwrite, - c2pread, c2pwrite, - errread, errwrite) = self._get_handles(stdin, stdout, stderr) + (p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) = self._get_handles(stdin, stdout, stderr) - # We wrap OS handles *before* launching the child, otherwise a - # quickly terminating child could make our fds unwrappable - # (see #8458). + # We wrap OS handles *before* launching the child, otherwise a + # quickly terminating child could make our fds unwrappable + # (see #8458). - if mswindows: + if mswindows: + if p2cwrite != -1: + p2cwrite = msvcrt.open_osfhandle(p2cwrite.Detach(), 0) + if c2pread != -1: + c2pread = msvcrt.open_osfhandle(c2pread.Detach(), 0) + if errread != -1: + errread = msvcrt.open_osfhandle(errread.Detach(), 0) + if p2cwrite != -1: - p2cwrite = msvcrt.open_osfhandle(p2cwrite.Detach(), 0) + self.stdin = io.open(p2cwrite, 'wb', bufsize) + if self.universal_newlines: + self.stdin = io.TextIOWrapper(self.stdin, write_through=True) if c2pread != -1: - c2pread = msvcrt.open_osfhandle(c2pread.Detach(), 0) + self.stdout = io.open(c2pread, 'rb', bufsize) + if universal_newlines: + self.stdout = io.TextIOWrapper(self.stdout) if errread != -1: - errread = msvcrt.open_osfhandle(errread.Detach(), 0) + self.stderr = io.open(errread, 'rb', bufsize) + if universal_newlines: + self.stderr = io.TextIOWrapper(self.stderr) - if p2cwrite != -1: - self.stdin = io.open(p2cwrite, 'wb', bufsize) - if self.universal_newlines: - self.stdin = io.TextIOWrapper(self.stdin, write_through=True) - if c2pread != -1: - self.stdout = io.open(c2pread, 'rb', bufsize) - if universal_newlines: - self.stdout = io.TextIOWrapper(self.stdout) - if errread != -1: - self.stderr = io.open(errread, 'rb', bufsize) - if universal_newlines: - self.stderr = io.TextIOWrapper(self.stderr) + try: + self._execute_child(args, executable, preexec_fn, close_fds, + pass_fds, cwd, env, universal_newlines, + startupinfo, creationflags, shell, + p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite, + restore_signals, start_new_session) + except: + # Cleanup if the child failed starting + for f in filter(None, [self.stdin, self.stdout, self.stderr]): + try: + f.close() + except EnvironmentError: + # Ignore EBADF or other errors + pass + raise - try: - self._execute_child(args, executable, preexec_fn, close_fds, - pass_fds, cwd, env, universal_newlines, - startupinfo, creationflags, shell, - p2cread, p2cwrite, - c2pread, c2pwrite, - errread, errwrite, - restore_signals, start_new_session) - except: - # Cleanup if the child failed starting - for f in filter(None, [self.stdin, self.stdout, self.stderr]): - try: - f.close() - except EnvironmentError: - # Ignore EBADF or other errors - pass - raise - + finally: + if forklock is not None: + forklock.release() def _translate_newlines(self, data, encoding): data = data.replace(b"\r\n", b"\n").replace(b"\r", b"\n") diff -r fb907ee4fdfa Lib/test/test_atfork.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Lib/test/test_atfork.py Tue Apr 24 23:34:34 2012 +0100 @@ -0,0 +1,374 @@ +import os + +if not hasattr(os, 'fork'): + raise unittest.SkipTest('fork required') + +import re +import sys +import unittest +import time +import signal +import errno +import io +import contextlib +import atfork +import posix + +try: + import multiprocessing as mp +except ImportError: + mp = None + +try: + import threading +except ImportError: + threading = None + +try: + import subprocess +except ImportError: + subprocess = None + + +def writeall(fd, buf): + buf = buf.encode('utf-8') + while buf: + n = os.write(fd, buf) + buf = buf[n:] + +def readall_and_close(fd): + data = [] + while True: + s = os.read(fd, 512) + if not s: + break + data.append(s) + os.close(fd) + return b''.join(data).decode('utf-8') + +def printer(msg): + return lambda : print(msg) + +def THROWER(msg): + def throw(): + raise RuntimeError(msg) + return throw + +def chain_as_list(e): + # unpack implicit exception chain as list, oldest first + L = [] + while e: + L.append(e) + e = e.__context__ + L.reverse() + return L + +@contextlib.contextmanager +def fresh_atfork_context(): + with atfork.getlock(): + old_callbacks = atfork.AtforkContext.callbacks + atfork.AtforkContext.callbacks = [] + yield + atfork.AtforkContext.callbacks = old_callbacks + + +@unittest.skipUnless(hasattr(os, 'fork'), 'requires os.fork()') +class TestAtfork(unittest.TestCase): + + THREADS = 5 + PERIOD = 10.0 + DELTA = 0.01 + + def _start_using_fork(self, q): + with atfork.getlock(): + r, w = os.pipe() + pid = os.fork() + if pid == 0: + # child + os.close(r) + data = str(os.getpid()) + writeall(w, data) + os.close(w) + time.sleep(self.PERIOD) + os._exit(0) + + # parent + time.sleep(self.DELTA) # sleep to make race easy to hit + os.close(w) + + data = readall_and_close(r) + read_pid = int(data) + q.put((pid, read_pid)) + + @classmethod + def _start_using_mp_inner(cls, r, w): + r.close() + w.send(os.getpid()) + w.close() + time.sleep(cls.PERIOD) + + def _start_using_mp(self, q): + with atfork.getlock(): + r, w = mp.Pipe(duplex=False) + p = mp.Process(target=self._start_using_mp_inner, args=(r, w)) + p.daemon = True + p.start() + time.sleep(self.DELTA) # sleep to make race easy to hit + w.close() + + read_pid = r.recv() + q.put((p.pid, read_pid)) + + def _start_using_sp(self, q): + p = subprocess.Popen(['/bin/sleep', str(self.PERIOD)], + close_fds=False) + q.put((p.pid, p.pid)) + + @unittest.skipUnless(threading, 'requires threading') + @unittest.skipUnless(mp, 'requires multiprocessing') + @unittest.skipUnless(mp, 'requires subprocess') + def test_lock(self): + # We want to know that holding the fork lock before creating + # pipe fds ensures that no forked process accidently inherits + # a writable fd that was not intended for it. Otherwise there + # will be a long delay before we get EOF when reading from the + # associated pipe. + import queue + fork_threads = [] + mp_threads = [] + sp_threads = [] + q = queue.Queue() + + for i in range(self.THREADS): + t = threading.Thread(target=self._start_using_fork, args=(q,)) + fork_threads.append(t) + + for i in range(self.THREADS): + t = threading.Thread(target=self._start_using_sp, args=(q,)) + sp_threads.append(t) + + for i in range(self.THREADS): + t = threading.Thread(target=self._start_using_mp, args=(q,)) + mp_threads.append(t) + + num_threads = len(fork_threads) + len(mp_threads) + len(sp_threads) + pids = [] + start = time.time() + + for i in range(self.THREADS): + fork_threads[i].start() + sp_threads[i].start() + mp_threads[i].start() + + for i in range(num_threads): + pid, read_pid = q.get(timeout=20) + self.assertEqual(pid, read_pid) + pids.append(pid) + + elapsed = time.time() - start + expected = self.DELTA * (len(fork_threads) + len(mp_threads)) + + # If any of the writables pipe handles gets accidentally + # inherited by a child process due to a race, then `elapsed` + # will be approximately self.PERIOD (i.e. the time that the + # child process lives, preventing EOF on that pipe). + # `expected` is the total amount of that the threads slept + # while holding the lock. + self.assertLess(elapsed, expected + 1.0) + + for pid in pids: + try: + os.kill(pid, signal.SIGTERM) + except OSError as e: + if e.errno != errno.ESRCH: + raise + + def test_recursive(self): + level = 1 + + def increment_level(): + nonlocal level + level += 1 + + def f(msg): + def inner(): + writeall(w, '%s %s\n' % (level, msg)) + return inner + + with fresh_atfork_context(): + atfork.atfork(None, None, increment_level) + atfork.atfork(f('prepare1'), f('parent1'), f('child1')) + atfork.atfork(f('prepare2'), f('parent2'), f('child2')) + + r, w = os.pipe() + pid1 = os.fork() + if pid1 == 0: + pid2 = os.fork() + if pid2 == 0: + pid3 = os.fork() + if pid3 == 0: + os._exit(0) + os._exit(0) + os._exit(0) + + os.close(w) + + # Note that POSIX guarantees that writes of less than 512 + # bytes to a pipe are atomic. Therefore the messages cannot be + # interleaved. + lines = readall_and_close(r).strip().split('\n') + + # Sort lines by 'level'. We assume sorting is stable, i.e. it + # does not change order of items of equal key. + lines.sort(key=lambda s:s.split()[0]) + + expected = ''' +1 prepare2 +1 prepare1 +1 parent1 +1 parent2 +2 child1 +2 child2 +2 prepare2 +2 prepare1 +2 parent1 +2 parent2 +3 child1 +3 child2 +3 prepare2 +3 prepare1 +3 parent1 +3 parent2 +4 child1 +4 child2 +'''.strip().split('\n') + + self.assertEqual(lines, expected) + + def test_fork_never_attempted(self): + messages = [] + def appender(msg): + return lambda : messages.append(msg) + + with fresh_atfork_context(): + atfork.atfork(THROWER('before 1'), THROWER('after 1'), None) + atfork.atfork(THROWER('before 2'), THROWER('after 2'), None) + atfork.atfork(appender('before 3'), appender('after 3'), None) + atfork.atfork(appender('before 4'), THROWER('after 4'), None) + atfork.atfork(appender('before 5'), THROWER('after 5'), None) + + with self.assertRaises(RuntimeError) as ctx: + if os.fork() == 0: + os._exit(1) + + # check exception chain + errors = [e.args[0] for e in chain_as_list(ctx.exception)] + expected_errors = ['before 2', 'after 4', 'after 5'] + self.assertEqual(errors, expected_errors) + + # check messages + expected_messages = ['before 5', 'before 4', 'before 3', 'after 3'] + self.assertEqual(messages, expected_messages) + + def test_fork_failed(self): + messages = [] + def appender(msg): + return lambda : messages.append(msg) + + def broken_fork(): + raise RuntimeError('fork failed') + + with fresh_atfork_context(): + atfork.atfork(appender('before 1'), THROWER('after 1'), None) + atfork.atfork(appender('before 2'), THROWER('after 2'), None) + atfork.atfork(appender('before 3'), appender('after 3'), None) + atfork.atfork(appender('before 4'), THROWER('after 4'), None) + + # temporarily break posix.fork() + posix_fork = posix.fork + posix.fork = broken_fork + try: + with self.assertRaises(RuntimeError) as ctx: + if os.fork() == 0: + os._exit(1) + finally: + posix.fork = posix_fork + + # check exception chain + errors = [e.args[0] for e in chain_as_list(ctx.exception)] + expected_errors = ['fork failed', 'after 1', 'after 2', 'after 4'] + self.assertEqual(errors, expected_errors) + + # check messages + expected_messages = ['before 4', 'before 3', 'before 2', + 'before 1', 'after 3'] + self.assertEqual(messages, expected_messages) + + def test_failure_after_fork(self): + def flush(): + sys.stdout.flush() + sys.stderr.flush() + + def redirect(): + sys.stdout = io.open(child_stdout_w, mode='w', encoding='utf-8') + sys.stderr = io.open(child_stderr_w, mode='w', encoding='utf-8') + + messages = [] + def appender(msg): + return lambda: messages.append(msg) + + with fresh_atfork_context(): + atfork.atfork(flush, None, redirect) + atfork.atfork(None, appender('parent 1'), THROWER('child 1')) + atfork.atfork(None, THROWER('parent 2'), printer('child 2')) + atfork.atfork(None, appender('parent 3'), THROWER('child 3')) + atfork.atfork(None, THROWER('parent 4'), flush) + + child_stdout_r, child_stdout_w = os.pipe() + child_stderr_r, child_stderr_w = os.pipe() + old_stderr = sys.stderr + sys.stderr = parent_stderr = io.StringIO() + try: + pid = os.fork() + if pid == 0: + print('should not get here') + sys.stdout.close() + sys.stderr.close() + os._exit(0) + finally: + sys.stderr = old_stderr + + os.close(child_stdout_w) + os.close(child_stderr_w) + + parent_stderr_text = parent_stderr.getvalue() + child_stdout_text = readall_and_close(child_stdout_r) + child_stderr_text = readall_and_close(child_stderr_r) + + # check child's stdout + self.assertEqual(child_stdout_text, 'child 2\n') + + # check child's stderr + pat = re.compile( + '^Exception from atfork child callback ignored:\n.*' + + '\nRuntimeError: child 1\n.*' + + '\nRuntimeError: child 3\n$', re.DOTALL) + self.assertRegex(child_stderr_text, pat) + + # check parent's messages + self.assertEqual(messages, ['parent 1', 'parent 3']) + + # check parent's stderr + pat = re.compile( + '^Exception from atfork parent callback ignored:\n.*' + + '\nRuntimeError: parent 2\n.*' + + '\nRuntimeError: parent 4\n$', re.DOTALL) + self.assertRegex(parent_stderr_text, pat) + + # check child's status non-zero + (pid, status) = os.waitpid(pid, 0) + self.assertNotEqual(pid, 0) + + +if __name__ == '__main__': + unittest.main()