diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -445,6 +445,66 @@ class ThreadTests(BaseTestCase): self.assertEqual(out, b'') self.assertEqual(err, b'') + def test_main_thread(self): + main = threading.main_thread() + self.assertEqual(main.name, 'MainThread') + self.assertEqual(main.ident, threading.current_thread().ident) + self.assertEqual(main.ident, threading.get_ident()) + + def f(): + self.assertNotEqual(threading.main_thread().ident, + threading.current_thread().ident) + th = threading.Thread(target=f) + th.start() + th.join() + + @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()") + @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()") + def test_main_thread_after_fork(self): + code = """if 1: + import os, threading + + pid = os.fork() + if pid == 0: + main = threading.main_thread() + print(main.name) + print(main.ident == threading.current_thread().ident) + print(main.ident == threading.get_ident()) + else: + os.waitpid(pid, 0) + """ + rc, out, err = assert_python_ok("-c", code) + self.assertEqual(rc, 0) + data = out.decode().replace('\r', '') + self.assertEqual(err, b"") + self.assertEqual(data, "MainThread\nTrue\nTrue\n") + + @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()") + @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()") + def test_main_thread_after_fork_from_nonmain_thread(self): + code = """if 1: + import os, threading + + def f(): + pid = os.fork() + if pid == 0: + main = threading.main_thread() + print(main.name) + print(main.ident == threading.current_thread().ident) + print(main.ident == threading.get_ident()) + else: + os.waitpid(pid, 0) + + th = threading.Thread(target=f) + th.start() + th.join() + """ + rc, out, err = assert_python_ok("-c", code) + self.assertEqual(rc, 0) + data = out.decode().replace('\r', '') + self.assertEqual(err, b"") + self.assertEqual(data, "MainThread\nTrue\nTrue\n") + class ThreadJoinOnShutdown(BaseTestCase): diff --git a/Lib/threading.py b/Lib/threading.py --- a/Lib/threading.py +++ b/Lib/threading.py @@ -915,7 +915,12 @@ from _thread import stack_size # and make it available for the interpreter # (Py_Main) as threading._shutdown. -_shutdown = _MainThread()._exitfunc +_main_thread = _MainThread() + +def main_thread(): + return _main_thread + +_shutdown = _main_thread._exitfunc # get thread-local implementation, either from the thread # module, or from the python fallback