diff --git a/Doc/library/threading.rst b/Doc/library/threading.rst --- a/Doc/library/threading.rst +++ b/Doc/library/threading.rst @@ -57,6 +57,12 @@ This module defines the following functi and threads that have not yet been started. +.. function:: main_thread() + + Return the main :class:`Thread` object. The main thread is the + thread that the OS creates to run application. + + .. function:: settrace(func) .. index:: single: trace function diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -19,6 +19,15 @@ import subprocess from test import lock_tests + +# Between fork() and exec(), only async-safe functions are allowed (issues +# #12316 and #11870), and fork() from a worker thread is known to trigger +# problems with some operating systems (issue #3863): skip problematic tests +# on platforms known to behave badly. +platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5', + 'hp-ux11') + + # A trivial mutable counter. class Counter(object): def __init__(self): @@ -445,16 +454,72 @@ class ThreadTests(BaseTestCase): self.assertEqual(out, b'') self.assertEqual(err, b'') + def test_main_thread(self): + main = threading.main_thread() + self.assertEqual(main.name, 'MainThread') + self.assertEqual(main.ident, threading.current_thread().ident) + self.assertEqual(main.ident, threading.get_ident()) + + def f(): + self.assertNotEqual(threading.main_thread().ident, + threading.current_thread().ident) + th = threading.Thread(target=f) + th.start() + th.join() + + @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()") + @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()") + def test_main_thread_after_fork(self): + code = """if 1: + import os, threading + + pid = os.fork() + if pid == 0: + main = threading.main_thread() + print(main.name) + print(main.ident == threading.current_thread().ident) + print(main.ident == threading.get_ident()) + else: + os.waitpid(pid, 0) + """ + rc, out, err = assert_python_ok("-c", code) + self.assertEqual(rc, 0) + data = out.decode().replace('\r', '') + self.assertEqual(err, b"") + self.assertEqual(data, "MainThread\nTrue\nTrue\n") + + @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") + @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()") + @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()") + def test_main_thread_after_fork_from_nonmain_thread(self): + code = """if 1: + import os, threading, sys + + def f(): + pid = os.fork() + if pid == 0: + main = threading.main_thread() + print(main.name) + print(main.ident == threading.current_thread().ident) + print(main.ident == threading.get_ident()) + # stdout is fully buffered because not a tty, + # we have to flush before exit. + sys.stdout.flush() + else: + os.waitpid(pid, 0) + + th = threading.Thread(target=f) + th.start() + th.join() + """ + _, out, err = assert_python_ok("-c", code) + data = out.decode().replace('\r', '') + self.assertEqual(err, b"") + self.assertEqual(data, "Thread-1\nTrue\nTrue\n") + class ThreadJoinOnShutdown(BaseTestCase): - # Between fork() and exec(), only async-safe functions are allowed (issues - # #12316 and #11870), and fork() from a worker thread is known to trigger - # problems with some operating systems (issue #3863): skip problematic tests - # on platforms known to behave badly. - platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5', - 'hp-ux11') - def _run_and_join(self, script): script = """if 1: import sys, os, time, threading diff --git a/Lib/threading.py b/Lib/threading.py --- a/Lib/threading.py +++ b/Lib/threading.py @@ -1,5 +1,6 @@ """Thread module emulating a subset of Java's threading model.""" +import functools import sys as _sys import _thread @@ -840,13 +841,13 @@ class _MainThread(Thread): with _active_limbo_lock: _active[self._ident] = self - def _exitfunc(self): - self._stop() +def _exitfunc(main_thread): + main_thread._stop() + t = _pickSomeNonDaemonThread() + while t: + t.join() t = _pickSomeNonDaemonThread() - while t: - t.join() - t = _pickSomeNonDaemonThread() - self._delete() + main_thread._delete() def _pickSomeNonDaemonThread(): for t in enumerate(): @@ -915,7 +916,12 @@ from _thread import stack_size # and make it available for the interpreter # (Py_Main) as threading._shutdown. -_shutdown = _MainThread()._exitfunc +_main_thread = _MainThread() + +def main_thread(): + return _main_thread + +_shutdown = functools.partial(_exitfunc, _main_thread) # get thread-local implementation, either from the thread # module, or from the python fallback @@ -933,12 +939,14 @@ def _after_fork(): # Reset _active_limbo_lock, in case we forked while the lock was held # by another (non-forked) thread. http://bugs.python.org/issue874900 - global _active_limbo_lock + global _active_limbo_lock, _main_thread, _shutdown _active_limbo_lock = _allocate_lock() # fork() only copied the current thread; clear references to others. new_active = {} current = current_thread() + _main_thread = current + _shutdown = functools.partial(_exitfunc, _main_thread) with _active_limbo_lock: for thread in _active.values(): # Any lock/condition variable may be currently locked or in an