diff --git a/Include/pystate.h b/Include/pystate.h --- a/Include/pystate.h +++ b/Include/pystate.h @@ -40,6 +40,10 @@ typedef struct _is { int tscdump; #endif + /* For the atexit module */ + void (*py_atexit_func)(void *); + void *py_atexit_arg; + } PyInterpreterState; #endif diff --git a/Include/pythonrun.h b/Include/pythonrun.h --- a/Include/pythonrun.h +++ b/Include/pythonrun.h @@ -125,7 +125,7 @@ PyAPI_FUNC(void) PyErr_Display(PyObject * exit functions. */ #ifndef Py_LIMITED_API -PyAPI_FUNC(void) _Py_PyAtExit(void (*func)(void)); +PyAPI_FUNC(void) _Py_PyAtExit(void (*func)(void *), void *); #endif PyAPI_FUNC(int) Py_AtExit(void (*func)(void)); diff --git a/Lib/test/test_atexit.py b/Lib/test/test_atexit.py --- a/Lib/test/test_atexit.py +++ b/Lib/test/test_atexit.py @@ -1,9 +1,16 @@ import sys import unittest import io +import os +import pickle import atexit from test import support +try: + import _testcapi +except ImportError: + _testcapi = None + ### helpers def h1(): print("h1") @@ -121,6 +128,29 @@ class TestCase(unittest.TestCase): atexit._run_exitfuncs() self.assertEqual(l, [5]) + @unittest.skipUnless(_testcapi, "test needs the _testcapi module") + def test_subinterps(self): + code = """if 1: + import sys, pickle, atexit + f = open({:d}, "wb", closefd=False) + def exit_func(): + pickle.dump(id(sys.modules), f) + pickle.dump(id(sys.modules), f) + atexit.register(exit_func) + """ + for i in range(2): + r, w = os.pipe() + with open(r, "rb") as f: + try: + ret = _testcapi.run_in_subinterp(code.format(w)) + finally: + os.close(w) + self.assertEqual(ret, 0) + mod = pickle.load(f) + # The atexit handler was invoked with the right interpreter + self.assertNotEqual(mod, id(sys.modules)) + self.assertEqual(pickle.load(f), mod) + def test_main(): support.run_unittest(TestCase) diff --git a/Modules/atexitmodule.c b/Modules/atexitmodule.c --- a/Modules/atexitmodule.c +++ b/Modules/atexitmodule.c @@ -36,7 +36,7 @@ typedef struct { /* Installed into pythonrun.c's atexit mechanism */ static void -atexit_callfuncs(void) +atexit_callfuncs(void *arg) { PyObject *exc_type = NULL, *exc_value, *exc_tb, *r; atexit_callback *cb; @@ -44,9 +44,9 @@ atexit_callfuncs(void) atexitmodule_state *modstate; int i; - module = PyState_FindModule(&atexitmodule); - if (module == NULL) - return; + module = (PyObject *) arg; + assert(module != NULL); + assert(PyModule_Check(module)); modstate = GET_ATEXIT_STATE(module); if (modstate->ncallbacks == 0) @@ -79,6 +79,7 @@ atexit_callfuncs(void) } atexit_cleanup(module); + Py_DECREF(module); if (exc_type) PyErr_Restore(exc_type, exc_value, exc_tb); @@ -180,7 +181,8 @@ Run all registered exit functions."); static PyObject * atexit_run_exitfuncs(PyObject *self, PyObject *unused) { - atexit_callfuncs(); + Py_INCREF(self); + atexit_callfuncs((void *) self); if (PyErr_Occurred()) return NULL; Py_RETURN_NONE; @@ -296,6 +298,7 @@ PyInit_atexit(void) if (modstate->atexit_callbacks == NULL) return NULL; - _Py_PyAtExit(atexit_callfuncs); + Py_INCREF(m); + _Py_PyAtExit(atexit_callfuncs, (PyObject *) m); return m; } diff --git a/Python/pystate.c b/Python/pystate.c --- a/Python/pystate.c +++ b/Python/pystate.c @@ -90,6 +90,8 @@ PyInterpreterState_New(void) #ifdef WITH_TSC interp->tscdump = 0; #endif + interp->py_atexit_func = NULL; + interp->py_atexit_arg = NULL; HEAD_LOCK(); interp->next = interp_head; diff --git a/Python/pythonrun.c b/Python/pythonrun.c --- a/Python/pythonrun.c +++ b/Python/pythonrun.c @@ -63,7 +63,7 @@ static PyObject *run_pyc_file(FILE *, co PyCompilerFlags *); static void err_input(perrdetail *); static void initsigs(void); -static void call_py_exitfuncs(void); +static void call_py_exitfuncs(PyInterpreterState *); static void wait_for_thread_shutdown(void); static void call_ll_exitfuncs(void); extern void _PyUnicode_Init(void); @@ -397,6 +397,10 @@ Py_Finalize(void) wait_for_thread_shutdown(); + /* Get current thread state and interpreter pointer */ + tstate = PyThreadState_GET(); + interp = tstate->interp; + /* The interpreter is still entirely intact at this point, and the * exit funcs may be relying on that. In particular, if some thread * or exit func is still waiting to do an import, the import machinery @@ -406,11 +410,7 @@ Py_Finalize(void) * threads created thru it, so this also protects pending imports in * the threads created via Threading. */ - call_py_exitfuncs(); - - /* Get current thread state and interpreter pointer */ - tstate = PyThreadState_GET(); - interp = tstate->interp; + call_py_exitfuncs(interp); /* Remaining threads (e.g. daemon threads) will automatically exit after taking the GIL (in PyEval_RestoreThread()). */ @@ -683,6 +683,8 @@ Py_EndInterpreter(PyThreadState *tstate) if (tstate != interp->tstate_head || tstate->next != NULL) Py_FatalError("Py_EndInterpreter: not the last thread"); + call_py_exitfuncs(interp); + PyImport_Cleanup(); PyInterpreterState_Clear(interp); PyThreadState_Swap(NULL); @@ -2167,20 +2169,21 @@ Py_FatalError(const char *msg) #include "pythread.h" #endif -static void (*pyexitfunc)(void) = NULL; /* For the atexit module. */ -void _Py_PyAtExit(void (*func)(void)) +void _Py_PyAtExit(void (*func)(void *), void *arg) { - pyexitfunc = func; + PyInterpreterState *interp = PyThreadState_GET()->interp; + interp->py_atexit_func = func; + interp->py_atexit_arg = arg; } static void -call_py_exitfuncs(void) +call_py_exitfuncs(PyInterpreterState *interp) { - if (pyexitfunc == NULL) + if (interp->py_atexit_func == NULL) return; - (*pyexitfunc)(); + (*interp->py_atexit_func)(interp->py_atexit_arg); PyErr_Clear(); }