diff --git a/Include/pyerrors.h b/Include/pyerrors.h --- a/Include/pyerrors.h +++ b/Include/pyerrors.h @@ -312,6 +312,7 @@ /* In sigcheck.c or signalmodule.c */ +PyAPI_FUNC(void) PyErr_AddPendingCheckSignals(void); PyAPI_FUNC(int) PyErr_CheckSignals(void); PyAPI_FUNC(void) PyErr_SetInterrupt(void); diff --git a/Lib/test/test_sys_settrace.py b/Lib/test/test_sys_settrace.py --- a/Lib/test/test_sys_settrace.py +++ b/Lib/test/test_sys_settrace.py @@ -5,6 +5,8 @@ import sys import difflib import gc +import time +import signal # A very basic example. If this fails, we're in deep trouble. def basic(): @@ -480,6 +482,31 @@ finally: sys.settrace(existing) + def test_exception_raised_in_alarm(self): + # Test that an exception raised in an alarm does not remove the trace + # function (issue 20601). + def trace(frame, event, arg): + funcname = frame.f_code.co_name + if event == 'call' and funcname == 'f': + time.sleep(1) + return trace + + def f(): pass + + def handler(*args): + 1/0 + + sys.settrace(trace) + signal.signal(signal.SIGALRM, handler) + signal.alarm(1) + err = None + try: + f() + except ZeroDivisionError as e: + err = e + self.assertIsInstance(err, ZeroDivisionError) + self.assertIs(sys.gettrace(), trace) + # 'Jump' tests: assigning to frame.f_lineno within a trace function # moves the execution position - it's how debuggers implement a Jump diff --git a/Modules/signalmodule.c b/Modules/signalmodule.c --- a/Modules/signalmodule.c +++ b/Modules/signalmodule.c @@ -1312,11 +1312,20 @@ /* Declared in pyerrors.h */ +/* Add a pending call to check for signals (postponed during tracing). */ +void +PyErr_AddPendingCheckSignals(void) +{ + is_tripped = 1; + Py_AddPendingCall(checksignals_witharg, NULL); +} + int PyErr_CheckSignals(void) { int i; PyObject *f; + PyThreadState *tstate; if (!is_tripped) return 0; @@ -1345,7 +1354,13 @@ if (!(f = (PyObject *)PyEval_GetFrame())) f = Py_None; + tstate = PyThreadState_GET(); + for (i = 1; i < NSIG; i++) { + /* Postpone signals while evaluating the trace function. */ + if (tstate->tracing && i != SIGINT) + continue; + if (Handlers[i].tripped) { PyObject *result = NULL; PyObject *arglist = Py_BuildValue("(iO)", i, f); diff --git a/Python/ceval.c b/Python/ceval.c --- a/Python/ceval.c +++ b/Python/ceval.c @@ -3912,6 +3912,7 @@ tstate->use_tracing = ((tstate->c_tracefunc != NULL) || (tstate->c_profilefunc != NULL)); tstate->tracing--; + PyErr_AddPendingCheckSignals(); return result; }