diff --git a/Include/pythonrun.h b/Include/pythonrun.h --- a/Include/pythonrun.h +++ b/Include/pythonrun.h @@ -214,6 +214,8 @@ PyAPI_FUNC(void) PyByteArray_Fini(void); PyAPI_FUNC(void) PyFloat_Fini(void); PyAPI_FUNC(void) PyOS_FiniInterrupts(void); PyAPI_FUNC(void) _PyGC_Fini(void); + +PyAPI_DATA(PyThreadState *) _Py_Finalizing; #endif /* Stuff with no proper home (yet) */ diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -12,6 +12,7 @@ import unittest import weakref import os import subprocess +from test.script_helper import assert_python_ok from test import lock_tests @@ -463,7 +464,6 @@ class ThreadJoinOnShutdown(BaseTestCase) """ self._run_and_join(script) - @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") def test_2_join_in_forked_process(self): # Like the test above, but from a forked interpreter @@ -655,6 +655,49 @@ class ThreadJoinOnShutdown(BaseTestCase) output = "end of worker thread\nend of main thread\n" self.assertScriptHasOutput(script, output) + def test_6_daemon_threads(self): + # Check that a daemon thread cannot crash the interpreter on shutdown + # by manipulating internal structures that are being disposed of in + # the main thread. + script = """if True: + import os + import random + import sys + import time + import threading + + thread_has_run = set() + + def random_io(): + '''Loop for a while sleeping random tiny amounts and doing some I/O.''' + blank = b'x' * 200 + while True: + in_f = open(os.__file__, 'r') + stuff = in_f.read(200) + null_f = open(os.devnull, 'w') + null_f.write(stuff) + time.sleep(random.random() / 1995) + null_f.close() + in_f.close() + thread_has_run.add(threading.current_thread()) + + def main(): + count = 0 + for _ in range(40): + new_thread = threading.Thread(target=random_io) + new_thread.daemon = True + new_thread.start() + count += 1 + while len(thread_has_run) < count: + time.sleep(0.001) + # Trigger process shutdown + sys.exit(0) + + main() + """ + rc, out, err = assert_python_ok('-c', script) + self.assertFalse(err) + class ThreadingExceptionTests(BaseTestCase): # A RuntimeError should be raised if Thread.start() is called diff --git a/Python/ceval.c b/Python/ceval.c --- a/Python/ceval.c +++ b/Python/ceval.c @@ -440,6 +440,13 @@ PyEval_RestoreThread(PyThreadState *tsta if (gil_created()) { int err = errno; take_gil(tstate); + /* _Py_Finalizing is protected by the GIL */ + assert(main_thread); + if (_Py_Finalizing && tstate != _Py_Finalizing) { + drop_gil(tstate); + PyThread_exit_thread(); + assert(0); /* unreachable */ + } errno = err; } #endif diff --git a/Python/pythonrun.c b/Python/pythonrun.c --- a/Python/pythonrun.c +++ b/Python/pythonrun.c @@ -90,6 +90,8 @@ int Py_IgnoreEnvironmentFlag; /* e.g. PY int Py_NoUserSiteDirectory = 0; /* for -s and site.py */ int Py_UnbufferedStdioFlag = 0; /* Unbuffered binary std{in,out,err} */ +PyThreadState *_Py_Finalizing = NULL; + /* PyModule_GetWarningsModule is no longer necessary as of 2.6 since _warnings is builtin. This API should not be used. */ PyObject * @@ -188,6 +190,7 @@ Py_InitializeEx(int install_sigs) if (initialized) return; initialized = 1; + _Py_Finalizing = NULL; #if defined(HAVE_LANGINFO_H) && defined(HAVE_SETLOCALE) /* Set up the LC_CTYPE locale, so we can obtain @@ -388,15 +391,19 @@ Py_Finalize(void) * the threads created via Threading. */ call_py_exitfuncs(); + + /* Get current thread state and interpreter pointer */ + tstate = PyThreadState_GET(); + interp = tstate->interp; + + /* Remaining threads (e.g. daemon threads) will automatically exit + after taking the GIL (in PyEval_RestoreThread()). */ + _Py_Finalizing = tstate; initialized = 0; /* Flush stdout+stderr */ flush_std_files(); - /* Get current thread state and interpreter pointer */ - tstate = PyThreadState_GET(); - interp = tstate->interp; - /* Disable signal handling */ PyOS_FiniInterrupts(); diff --git a/Python/thread_pthread.h b/Python/thread_pthread.h --- a/Python/thread_pthread.h +++ b/Python/thread_pthread.h @@ -250,9 +250,9 @@ void PyThread_exit_thread(void) { dprintf(("PyThread_exit_thread called\n")); - if (!initialized) { + if (!initialized) exit(0); - } + pthread_exit(0); } #ifdef USE_SEMAPHORES