Index: Python/pystate.c =================================================================== --- Python/pystate.c (revision 74311) +++ Python/pystate.c (working copy) @@ -192,6 +192,7 @@ tstate->c_tracefunc = NULL; tstate->c_profileobj = NULL; tstate->c_traceobj = NULL; + tstate->main_thread = 0; #ifdef WITH_THREAD _PyGILState_NoteThreadState(tstate); @@ -642,5 +643,3 @@ #endif #endif /* WITH_THREAD */ - - Index: Python/pythonrun.c =================================================================== --- Python/pythonrun.c (revision 74311) +++ Python/pythonrun.c (working copy) @@ -172,6 +172,7 @@ tstate = PyThreadState_New(interp); if (tstate == NULL) Py_FatalError("Py_Initialize: can't make first thread"); + tstate->main_thread = 1; (void) PyThreadState_Swap(tstate); _Py_ReadyTypes(); @@ -360,6 +361,33 @@ } +/* Wait until threading._shutdown completes, provided + the threading module was imported in the first place. + The shutdown routine will wait until all non-daemon + "threading" threads have completed. */ +#include "abstract.h" +void +Py_WaitForThreadShutdown(void) +{ +#ifdef WITH_THREAD + PyObject *result; + PyThreadState *tstate = PyThreadState_GET(); + PyObject *threading = PyMapping_GetItemString(tstate->interp->modules, + "threading"); + if (threading == NULL) { + /* threading not imported */ + PyErr_Clear(); + return; + } + result = PyObject_CallMethod(threading, "_shutdown", ""); + if (result == NULL) + PyErr_WriteUnraisable(threading); + else + Py_DECREF(result); + Py_DECREF(threading); +#endif +} + #ifdef COUNT_ALLOCS extern void dump_counts(FILE*); #endif @@ -1958,4 +1986,3 @@ #ifdef __cplusplus } #endif - Index: Include/pystate.h =================================================================== --- Include/pystate.h (revision 74311) +++ Include/pystate.h (working copy) @@ -65,8 +65,11 @@ This is to prevent the actual trace/profile code from being recorded in the trace/profile. */ int tracing; - int use_tracing; + short use_tracing; + /* Flag for whether this thread is the main thread. */ + short main_thread; + Py_tracefunc c_profilefunc; Py_tracefunc c_tracefunc; PyObject *c_profileobj; Index: Include/pythonrun.h =================================================================== --- Include/pythonrun.h (revision 74311) +++ Include/pythonrun.h (working copy) @@ -27,6 +27,7 @@ PyAPI_FUNC(void) Py_Initialize(void); PyAPI_FUNC(void) Py_InitializeEx(int); +PyAPI_FUNC(void) Py_WaitForThreadShutdown(void); PyAPI_FUNC(void) Py_Finalize(void); PyAPI_FUNC(int) Py_IsInitialized(void); PyAPI_FUNC(PyThreadState *) Py_NewInterpreter(void); Index: Lib/threading.py =================================================================== --- Lib/threading.py (revision 74311) +++ Lib/threading.py (working copy) @@ -225,6 +225,16 @@ else: return True + def _reset_lock(self, lock=None): + """Throw away the old lock and replace it with this one.""" + if lock is None: + lock = Lock() + self.__lock = lock + # Reset these exported bound methods. If we don't do this, these + # bound methods will still refer to the old lock. + self.acquire = lock.acquire + self.release = lock.release + def wait(self, timeout=None): if not self._is_owned(): raise RuntimeError("cannot wait on un-aquired lock") @@ -746,7 +756,10 @@ return False def _exitfunc(self): - self._Thread__stop() + try: + self._Thread__stop() + except: + pass # Swallow errors, since this may not work if we've forked. t = _pickSomeNonDaemonThread() if t: if __debug__: @@ -756,7 +769,10 @@ t = _pickSomeNonDaemonThread() if __debug__: self._note("%s: exiting", self) - self._Thread__delete() + try: + self._Thread__delete() + except: + pass # Swallow errors, since this may not work if we've forked. def _pickSomeNonDaemonThread(): for t in enumerate(): @@ -850,6 +866,7 @@ # fork() only copied the current thread; clear references to others. new_active = {} + current = current_thread() with _active_limbo_lock: for thread in _active.itervalues(): @@ -858,6 +875,10 @@ # its new value since it can have changed. ident = _get_ident() thread._Thread__ident = ident + # Any locks hanging off of the active thread may be in an + # invalid state, so we reset them. + thread._Thread__block._reset_lock() + thread._Thread__started._Event__cond._reset_lock() new_active[ident] = thread else: # All the others are already stopped. Index: Lib/test/test_threading.py =================================================================== --- Lib/test/test_threading.py (revision 74311) +++ Lib/test/test_threading.py (working copy) @@ -1,5 +1,7 @@ # Very rudimentary test of threading module +from __future__ import with_statement + import test.test_support from test.test_support import verbose import random @@ -419,6 +421,172 @@ self._run_and_join(script) +class ThreadAndForkTests(unittest.TestCase): + + def _run_and_check_output(self, script, expected_output): + import subprocess + p = subprocess.Popen([sys.executable, "-c", script], + stdout=subprocess.PIPE) + rc = p.wait() + data = p.stdout.read().decode().replace('\r', '') + self.assertEqual(data, expected_output) + self.assertFalse(rc == 2, "interpreter was blocked") + self.assertTrue(rc == 0, "Unexpected error") + + def test_join_fork_stop_deadlock(self): + # There used to be a possible deadlock when forking from a child + # thread. See http://bugs.python.org/issue6643. + + import os + if not hasattr(os, 'fork'): + return + # Skip platforms with known problems forking from a worker thread. + # See http://bugs.python.org/issue3863. + if sys.platform in ('freebsd4', 'freebsd5', 'freebsd6', 'os2emx'): + raise unittest.SkipTest('due to known OS bugs on ' + sys.platform) + script = """if 1: + import os, sys, time, threading + + finish_fork = False + finish_join = False + + def worker(): + # Wait just a bit before forking so that the original thread + # makes it into my_acquire. + global finish_fork + global finish_join + while not finish_fork: + pass + childpid = os.fork() + finish_join = True + if childpid != 0: + # Parent process just waits for child. + os.waitpid(childpid, 0) + # Child process should just return. + + w = threading.Thread(target=worker) + + # Stub out the private condition variable's lock acquire method. + # There is a race to acquire this between w.join() and w.__stop(), + # which is called when the thread returns. + condition = w._Thread__block + orig_acquire = condition.acquire + call_count = 0 + def my_acquire(): + global call_count + global finish_join + global finish_fork + orig_acquire() + finish_fork = True + if call_count == 0: + while not finish_join: + pass + call_count += 1 + condition.acquire = my_acquire + + w.start() + w.join() + print('end of main') + """ + self._run_and_check_output(script, "end of main\n") + + def test_thread_fork_thread_hang(self): + # Check that a thread that forks and then spawns a daemon thread can + # exit properly. Previously, it would just exit one thread instead of + # shutting down the entire process, so the daemon thread would prevent + # the process from exiting, causing the parent process to hang in + # waitpid(). + + import os + if not hasattr(os, 'fork'): + return + # Skip platforms with known problems forking from a worker thread. + # See http://bugs.python.org/issue3863. + if sys.platform in ('freebsd4', 'freebsd5', 'freebsd6', 'os2emx'): + raise unittest.SkipTest('due to known OS bugs on ' + sys.platform) + script = """if 1: + import os, sys, threading + + def worker(): + childpid = os.fork() + if childpid != 0: + # Parent waits for child. + os.waitpid(childpid, 0) + else: + # Child spawns a daemon thread and then returns + # immediately. + def daemon(): + while True: + pass + d = threading.Thread(target=daemon) + d.daemon = True + d.start() + # Return, do not call sys.exit(0) or d.join(). The process + # should exit without waiting for the daemon thread, but we + # expect that due to a bug relating to os.fork and threads + # it will hang. + + w = threading.Thread(target=worker) + w.start() + w.join() + print('end of main') + """ + self._run_and_check_output(script, "end of main\n") + + def test_thread_fork_atexit(self): + # Check that a thread that forks and then spawns a daemon thread + # properly executes atexit handlers. Previously, Py_Finalize was not + # being called because the main thread in the child process does not + # return through Py_Main(). Now there is a check in + # thread_PyThread_exit_thread() to run these finalizers if the exiting + # thread happens to be the main thread. + + import os + if not hasattr(os, 'fork'): + return + # Skip platforms with known problems forking from a worker thread. + # See http://bugs.python.org/issue3863. + if sys.platform in ('freebsd4', 'freebsd5', 'freebsd6', 'os2emx'): + raise unittest.SkipTest('due to known OS bugs on ' + sys.platform) + script = """if 1: + import atexit, os, sys, threading + + def worker(): + # Setup a pipe between the processes, and register an atexit + # handler to write the pid to the pipe. Note, these are file + # descriptors, not file-like objects. + (reader, writer) = os.pipe() + def write_atexit(): + if writer is None: + return + os.write(writer, str(os.getpid())) + os.close(writer) + atexit.register(write_atexit) + + childpid = os.fork() + if childpid != 0: # Parent + os.close(writer) + # Throw away the writer so that the atexit handler does + # nothing in the parent. + writer = None + os.waitpid(childpid, 0) + output = os.read(reader, 100) + if output == str(childpid): + print "successfully read child pid" + else: + print "got bad output:", output + else: # Child + os.close(reader) + # The child should just return without exiting. We want to + # verify that the atexit handler gets called. + + w = threading.Thread(target=worker) + w.start() + w.join() + """ + self._run_and_check_output(script, "successfully read child pid\n") + + class ThreadingExceptionTests(unittest.TestCase): # A RuntimeError should be raised if Thread.start() is called # multiple times. @@ -461,6 +629,7 @@ test.test_support.run_unittest(ThreadTests, ThreadJoinOnShutdown, ThreadingExceptionTests, + ThreadAndForkTests, ) if __name__ == "__main__": Index: Modules/threadmodule.c =================================================================== --- Modules/threadmodule.c (revision 74311) +++ Modules/threadmodule.c (working copy) @@ -446,9 +446,22 @@ Py_DECREF(boot->args); Py_XDECREF(boot->keyw); PyMem_DEL(boot_raw); - PyThreadState_Clear(tstate); - PyThreadState_DeleteCurrent(); - PyThread_exit_thread(); + /* Shut down the entire interpreter if we're the main thread, otherwise + * just terminate the thread. */ + if (tstate->main_thread) { + Py_WaitForThreadShutdown(); + Py_Finalize(); + /* Note that we must call _exit() directly instead of returning + * in order to terminate any daemon threads that we didn't wait + * for. We also need to call _exit() instead of exit() so that + * we don't double flush any open file descriptors. */ + _exit(0); + } + else { + PyThreadState_Clear(tstate); + PyThreadState_DeleteCurrent(); + PyThread_exit_thread(); + } } static PyObject * Index: Modules/signalmodule.c =================================================================== --- Modules/signalmodule.c (revision 74311) +++ Modules/signalmodule.c (working copy) @@ -922,6 +922,14 @@ PyOS_AfterFork(void) { #ifdef WITH_THREAD + PyThreadState *tstate; + + /* After the fork, there is only one thread in the child and + * that is this thread. Therefore we mark it as the main thread + * so that it will shutdown properly. */ + tstate = PyThreadState_GET(); + tstate->main_thread = 1; + PyEval_ReInitThreads(); main_thread = PyThread_get_thread_ident(); main_pid = getpid(); Index: Modules/main.c =================================================================== --- Modules/main.c (revision 74311) +++ Modules/main.c (working copy) @@ -222,33 +222,6 @@ } -/* Wait until threading._shutdown completes, provided - the threading module was imported in the first place. - The shutdown routine will wait until all non-daemon - "threading" threads have completed. */ -#include "abstract.h" -static void -WaitForThreadShutdown(void) -{ -#ifdef WITH_THREAD - PyObject *result; - PyThreadState *tstate = PyThreadState_GET(); - PyObject *threading = PyMapping_GetItemString(tstate->interp->modules, - "threading"); - if (threading == NULL) { - /* threading not imported */ - PyErr_Clear(); - return; - } - result = PyObject_CallMethod(threading, "_shutdown", ""); - if (result == NULL) - PyErr_WriteUnraisable(threading); - else - Py_DECREF(result); - Py_DECREF(threading); -#endif -} - /* Main program */ int @@ -620,7 +593,7 @@ sts = PyRun_AnyFileFlags(stdin, "", &cf) != 0; } - WaitForThreadShutdown(); + Py_WaitForThreadShutdown(); Py_Finalize(); #ifdef RISCOS @@ -661,4 +634,3 @@ #ifdef __cplusplus } #endif -