# HG changeset patch
# Parent f13af83967a0d711cd3fc28be1637e1294e3d223

diff -r f13af83967a0 Lib/atfork.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/Lib/atfork.py	Mon Jan 23 20:18:16 2012 +0000
@@ -0,0 +1,106 @@
+'''
+Module for registering callbacks to be run before and after fork(),
+modelled after pthread_atfork().  Unix only.
+
+Access is also given to a recursive lock object which is held whenever
+child processes are created (using os.fork(), subprocess.Popen() or
+multiprocessing.Process()).  Among other things, in a multithreaded
+program this allows one to reliably control which processes inherit
+the file descriptors one creates.  For example, one can do
+
+    with atfork.get_fork_lock():
+        r, w = os.pipe()
+        pid = os.fork()
+        if pid == 0:
+            os.close(r)
+            try:
+                pass        # do something that writes to w
+            except:
+                os._exit(1)
+            else:
+                os._exit(0)
+        else:
+            os.close(w)
+
+    # read from r and expect EOF as soon as forked process exits
+'''
+
+__all__ = ['atfork', 'get_fork_lock']
+
+
+from posix import fork as _posix_fork
+try:
+    from threading import RLock as _RLock
+except ImportError:
+    from dummy_threading import RLock as _RLock
+
+
+_garbage = []
+_callbacks = []
+_fork_lock = _RLock()
+
+
+def get_fork_lock():
+    '''
+    Return the recursive lock which is held when fork() runs.
+
+    Note that different processes may use different fork lock objects,
+    so this function should be called each time the fork lock needs to
+    be acquired.
+    '''
+    return _fork_lock
+
+
+def atfork(prepare=None, parent=None, child=None):
+    '''
+    Register callbacks to be run before and after fork().
+
+    prepare callbacks are run before calling fork() in LIFO order.
+    parent callbacks are run by the parent process after calling
+    fork() in FIFO order.  child callbacks are run by the child
+    process after it forks in FIFO order.
+    '''
+    with _fork_lock:
+        _callbacks.append((prepare, parent, child))
+
+
+def _fork():
+    '''
+    Wrapper for posix.fork() which uses a lock and supports atfork callbacks
+    '''
+    global _fork_lock
+    _fork_lock.acquire()
+    try:
+        parent_callbacks = []
+        pid = None
+
+        # call prepare callbacks in LIFO order
+        for (prepare, parent, child) in reversed(_callbacks):
+            if prepare is not None:
+                prepare()
+            parent_callbacks.append(parent)
+
+        pid = _posix_fork()
+        if pid == 0:
+            # _fork_lock may now be broken, so we should not let it be
+            # garbage collected.  We just stick the old lock in a
+            # garbage list and replace it by a new one.
+            _garbage.append(_fork_lock)
+            _fork_lock = _RLock()
+
+            # call child callbacks in FIFO order
+            for (prepare, parent, child) in _callbacks:
+                if child is not None:
+                    child()
+    finally:
+        if pid != 0:
+            try:
+                # We call a parent callback for each prepare callback
+                # which succeeded (but in the opposite order).
+                for parent in reversed(parent_callbacks):
+                    if parent is not None:
+                        parent()
+            finally:
+                _fork_lock.release()
+
+    return pid
diff -r f13af83967a0 Lib/multiprocessing/forking.py
--- a/Lib/multiprocessing/forking.py	Sun Jan 22 21:31:39 2012 +0100
+++ b/Lib/multiprocessing/forking.py	Mon Jan 23 20:18:16 2012 +0000
@@ -100,6 +100,7 @@
 #
 
 if sys.platform != 'win32':
+    import atfork
     import select
 
     exit = os._exit
@@ -119,24 +120,25 @@
             sys.stderr.flush()
             self.returncode = None
 
-            r, w = os.pipe()
-            self.sentinel = r
+            with atfork.get_fork_lock():
+                r, w = os.pipe()
+                self.sentinel = r
 
-            self.pid = os.fork()
-            if self.pid == 0:
-                os.close(r)
-                if 'random' in sys.modules:
-                    import random
-                    random.seed()
-                code = process_obj._bootstrap()
-                sys.stdout.flush()
-                sys.stderr.flush()
-                os._exit(code)
+                self.pid = os.fork()
+                if self.pid == 0:
+                    os.close(r)
+                    if 'random' in sys.modules:
+                        import random
+                        random.seed()
+                    code = process_obj._bootstrap()
+                    sys.stdout.flush()
+                    sys.stderr.flush()
+                    os._exit(code)
 
-            # `w` will be closed when the child exits, at which point `r`
-            # will become ready for reading (using e.g. select()).
-            os.close(w)
-            util.Finalize(self, os.close, (r,))
+                # `w` will be closed when the child exits, at which point `r`
+                # will become ready for reading (using e.g. select()).
+                os.close(w)
+                util.Finalize(self, os.close, (r,))
 
         def poll(self, flag=os.WNOHANG):
             if self.returncode is None:
diff -r f13af83967a0 Lib/os.py
--- a/Lib/os.py	Sun Jan 22 21:31:39 2012 +0100
+++ b/Lib/os.py	Mon Jan 23 20:18:16 2012 +0000
@@ -829,3 +829,13 @@
         raise TypeError("invalid fd type (%s, expected integer)" % type(fd))
     import io
     return io.open(fd, *args, **kwargs)
+
+# replace fork() with a wrapper which runs the atfork callbacks
+if _exists("fork"):
+    _doc = fork.__doc__
+    def fork():
+        # for bootstrapping reasons os should not directly import atfork
+        import atfork
+        return atfork._fork()
+    fork.__doc__ = _doc
+    del _doc
diff -r f13af83967a0 Lib/subprocess.py
--- a/Lib/subprocess.py	Sun Jan 22 21:31:39 2012 +0100
+++ b/Lib/subprocess.py	Mon Jan 23 20:18:16 2012 +0000
@@ -394,11 +394,13 @@
         wShowWindow = 0
     class pywintypes:
         error = IOError
+    atfork = None
 else:
     import select
     _has_poll = hasattr(select, 'poll')
     import _posixsubprocess
     _create_pipe = _posixsubprocess.cloexec_pipe
+    import atfork
 
 # When select or poll has indicated that the file is writable,
 # we can write up to _PIPE_BUF bytes without risk of blocking.
@@ -708,68 +710,79 @@
         self.returncode = None
         self.universal_newlines = universal_newlines
 
-        # Input and output objects.  The general principle is like
-        # this:
-        #
-        # Parent                   Child
-        # ------                   -----
-        # p2cwrite   ---stdin--->  p2cread
-        # c2pread    <--stdout---  c2pwrite
-        # errread    <--stderr---  errwrite
-        #
-        # On POSIX, the child objects are file descriptors.  On
-        # Windows, these are Windows file handles.  The parent objects
-        # are file descriptors on both platforms.  The parent objects
-        # are -1 when not using PIPEs.  The child objects are -1
-        # when not redirecting.
+        if atfork is not None:
+            # to avoid accidental inheritance of fds by processes
+            # forked by other threads, we acquire the recursive fork lock
+            lock = atfork.get_fork_lock()
+            lock.acquire()
+        else:
+            lock = None
+        try:
+            # Input and output objects.  The general principle is like
+            # this:
+            #
+            # Parent                   Child
+            # ------                   -----
+            # p2cwrite   ---stdin--->  p2cread
+            # c2pread    <--stdout---  c2pwrite
+            # errread    <--stderr---  errwrite
+            #
+            # On POSIX, the child objects are file descriptors.  On
+            # Windows, these are Windows file handles.
+            # The parent objects are file descriptors on both platforms.
+            # The parent objects are -1 when not using PIPEs.  The child
+            # objects are -1 when not redirecting.
 
-        (p2cread, p2cwrite,
-         c2pread, c2pwrite,
-         errread, errwrite) = self._get_handles(stdin, stdout, stderr)
+            (p2cread, p2cwrite,
+             c2pread, c2pwrite,
+             errread, errwrite) = self._get_handles(stdin, stdout, stderr)
 
-        # We wrap OS handles *before* launching the child, otherwise a
-        # quickly terminating child could make our fds unwrappable
-        # (see #8458).
+            # We wrap OS handles *before* launching the child, otherwise a
+            # quickly terminating child could make our fds unwrappable
+            # (see #8458).
 
-        if mswindows:
+            if mswindows:
+                if p2cwrite != -1:
+                    p2cwrite = msvcrt.open_osfhandle(p2cwrite.Detach(), 0)
+                if c2pread != -1:
+                    c2pread = msvcrt.open_osfhandle(c2pread.Detach(), 0)
+                if errread != -1:
+                    errread = msvcrt.open_osfhandle(errread.Detach(), 0)
+
             if p2cwrite != -1:
-                p2cwrite = msvcrt.open_osfhandle(p2cwrite.Detach(), 0)
+                self.stdin = io.open(p2cwrite, 'wb', bufsize)
+                if self.universal_newlines:
+                    self.stdin = io.TextIOWrapper(self.stdin, write_through=True)
             if c2pread != -1:
-                c2pread = msvcrt.open_osfhandle(c2pread.Detach(), 0)
+                self.stdout = io.open(c2pread, 'rb', bufsize)
+                if universal_newlines:
+                    self.stdout = io.TextIOWrapper(self.stdout)
             if errread != -1:
-                errread = msvcrt.open_osfhandle(errread.Detach(), 0)
+                self.stderr = io.open(errread, 'rb', bufsize)
+                if universal_newlines:
+                    self.stderr = io.TextIOWrapper(self.stderr)
 
-        if p2cwrite != -1:
-            self.stdin = io.open(p2cwrite, 'wb', bufsize)
-            if self.universal_newlines:
-                self.stdin = io.TextIOWrapper(self.stdin, write_through=True)
-        if c2pread != -1:
-            self.stdout = io.open(c2pread, 'rb', bufsize)
-            if universal_newlines:
-                self.stdout = io.TextIOWrapper(self.stdout)
-        if errread != -1:
-            self.stderr = io.open(errread, 'rb', bufsize)
-            if universal_newlines:
-                self.stderr = io.TextIOWrapper(self.stderr)
+            try:
+                self._execute_child(args, executable, preexec_fn, close_fds,
+                                    pass_fds, cwd, env, universal_newlines,
+                                    startupinfo, creationflags, shell,
+                                    p2cread, p2cwrite,
+                                    c2pread, c2pwrite,
+                                    errread, errwrite,
+                                    restore_signals, start_new_session)
+            except:
+                # Cleanup if the child failed starting
+                for f in filter(None, [self.stdin, self.stdout, self.stderr]):
+                    try:
+                        f.close()
+                    except EnvironmentError:
+                        # Ignore EBADF or other errors
+                        pass
+                raise
 
-        try:
-            self._execute_child(args, executable, preexec_fn, close_fds,
-                                pass_fds, cwd, env, universal_newlines,
-                                startupinfo, creationflags, shell,
-                                p2cread, p2cwrite,
-                                c2pread, c2pwrite,
-                                errread, errwrite,
-                                restore_signals, start_new_session)
-        except:
-            # Cleanup if the child failed starting
-            for f in filter(None, [self.stdin, self.stdout, self.stderr]):
-                try:
-                    f.close()
-                except EnvironmentError:
-                    # Ignore EBADF or other errors
-                    pass
-            raise
-
+        finally:
+            if lock is not None:
+                lock.release()
+
     def _translate_newlines(self, data, encoding):
         data = data.replace(b"\r\n", b"\n").replace(b"\r", b"\n")
diff -r f13af83967a0 Lib/test/test_atfork.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/Lib/test/test_atfork.py	Mon Jan 23 20:18:16 2012 +0000
@@ -0,0 +1,199 @@
+import unittest
+import atfork
+import os
+import time
+import threading
+import queue
+import subprocess
+import signal
+import errno
+
+try:
+    import multiprocessing as mp
+except ImportError:
+    mp = None
+
+from functools import partial
+
+
+@unittest.skipUnless(hasattr(os, "fork"), "requires os.fork()")
+class TestAtfork(unittest.TestCase):
+
+    THREADS = 5
+    PERIOD = 10.0
+    DELTA = 0.01
+
+    def _start_using_fork(self, q):
+        with atfork.get_fork_lock():
+            r, w = os.pipe()
+            pid = os.fork()
+            if pid == 0:
+                # child
+                os.close(r)
+                buf = str(os.getpid()).encode('ascii')
+                while buf:
+                    n = os.write(w, buf)
+                    buf = buf[n:]
+                os.close(w)
+                time.sleep(self.PERIOD)
+                os._exit(0)
+
+            # parent
+            time.sleep(self.DELTA)      # sleep to make race easy to hit
+            os.close(w)
+
+        data = []
+        while True:
+            s = os.read(r, 8)
+            if not s:
+                break
+            data.append(s)
+
+        read_pid = int(b''.join(data).decode('ascii'))
+        q.put((pid, read_pid))
+
+    @classmethod
+    def _start_using_mp_inner(cls, r, w):
+        r.close()
+        w.send(os.getpid())
+        w.close()
+        time.sleep(cls.PERIOD)
+
+    def _start_using_mp(self, q):
+        with atfork.get_fork_lock():
+            r, w = mp.Pipe(duplex=False)
+            p = mp.Process(target=self._start_using_mp_inner, args=(r, w))
+            p.daemon = True
+            p.start()
+            time.sleep(self.DELTA)      # sleep to make race easy to hit
+            w.close()
+
+        read_pid = r.recv()
+        q.put((p.pid, read_pid))
+
+    def _start_using_sp(self, q):
+        p = subprocess.Popen(['/bin/sleep', str(self.PERIOD)], close_fds=False)
+        q.put((p.pid, p.pid))
+
+    def test_lock(self):
+        # We want to know that holding the fork lock before creating
+        # pipe fds ensures that no forked process accidentally inherits
+        # a writable fd that was not intended for it.  Otherwise there
+        # will be a long delay before we get EOF when reading from the
+        # associated pipe.
+        fork_threads = []
+        mp_threads = []
+        sp_threads = []
+        q = queue.Queue()
+
+        for i in range(self.THREADS):
+            t = threading.Thread(target=self._start_using_fork, args=(q,))
+            fork_threads.append(t)
+
+        for i in range(self.THREADS):
+            t = threading.Thread(target=self._start_using_sp, args=(q,))
+            sp_threads.append(t)
+
+        if mp is not None:
+            for i in range(self.THREADS):
+                t = threading.Thread(target=self._start_using_mp, args=(q,))
+                mp_threads.append(t)
+
+        num_threads = len(fork_threads) + len(mp_threads) + len(sp_threads)
+        pids = []
+        start = time.time()
+
+        for i in range(self.THREADS):
+            fork_threads[i].start()
+            sp_threads[i].start()
+            if mp:
+                mp_threads[i].start()
+
+        for i in range(num_threads):
+            pid, read_pid = q.get()
+            self.assertEqual(pid, read_pid)
+            pids.append(pid)
+
+        elapsed = time.time() - start
+        expected = self.DELTA * (len(fork_threads) + len(mp_threads))
+
+        # If any of the writable pipe handles gets accidentally
+        # inherited by a child process due to a race, then `elapsed`
+        # will be approximately self.PERIOD (i.e. the time that the
+        # child process lives, preventing EOF on that pipe).
+        # `expected` is the total amount of time that the threads slept
+        # while holding the lock.
+        self.assertLess(elapsed, expected + 1.0)
+
+        for pid in pids:
+            try:
+                os.kill(pid, signal.SIGTERM)
+            except OSError as e:
+                if e.errno != errno.ESRCH:
+                    raise
+
+    def test_atfork(self):
+        with atfork.get_fork_lock():
+            old_callbacks = atfork._callbacks[:]
+            try:
+                # we check that callbacks get called in the expected order
+                # (i.e. the order documented for pthread_atfork())
+                L = list(range(5))
+                reversed_L = list(reversed(L))
+
+                prepare_list = []
+                parent_list = []
+                child_list = []
+
+                for obj in L:
+                    atfork.atfork(prepare=partial(prepare_list.append, obj),
+                                  parent=partial(parent_list.append, obj),
+                                  child=partial(child_list.append, obj))
+
+                pid = os.fork()
+                if pid == 0:
+                    if child_list == L:
+                        status = 0
+                    else:
+                        status = 1
+                    os._exit(status)
+
+                (pid, status) = os.waitpid(pid, 0)
+                self.assertEqual(prepare_list, reversed_L)
+                self.assertEqual(parent_list, L)
+                self.assertEqual(status, 0)        # implies child_list == L
+
+                # now we check that callbacks get called in the expected order
+                # if one of the prepare callbacks raises an error
+                M = list(range(100, 105))
+                reversed_M = list(reversed(M))
+
+                prepare_list[:] = []
+                parent_list[:] = []
+                child_list[:] = []
+
+                atfork.atfork(prepare=lambda: 1/0)
+                for obj in M:
+                    atfork.atfork(prepare=partial(prepare_list.append, obj),
+                                  parent=partial(parent_list.append, obj),
+                                  child=partial(child_list.append, obj))
+
+                self.assertEqual(len(atfork._callbacks), 11)
+
+                # the 6th registered prepare callback (of 11) raises an error so:
+                # * only the last 5 prepare callbacks run (successfully)
+                # * no child process is forked, so no child callbacks will be run
+                # * only the last 5 registered parent callbacks will be run
+
+                with self.assertRaises(ZeroDivisionError):
+                    if os.fork() == 0:
+                        os._exit(0)
+
+                self.assertEqual(prepare_list, reversed_M)
+                self.assertEqual(parent_list, M)
+            finally:
+                atfork._callbacks = old_callbacks
+
+
+if __name__ == '__main__':
+    unittest.main()
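
For reference, a minimal sketch of how the new callback API could be used once
this changeset is applied.  It is illustrative only: it assumes the os.fork()
wrapper from the Lib/os.py hunk above is in place, and the helper name
_reseed_child is made up for the example.

    import os
    import random
    import atfork

    def _reseed_child():
        # hypothetical helper: give each forked child its own random state
        random.seed()

    # no prepare/parent work needed here; only a child callback is registered
    atfork.atfork(prepare=None, parent=None, child=_reseed_child)

    pid = os.fork()       # dispatched to atfork._fork() by the os.py wrapper
    if pid == 0:
        # _reseed_child() has already been run in the child at this point
        os._exit(0)
    os.waitpid(pid, 0)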