diff --git a/Modules/_pickle.c b/Modules/_pickle.c --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -1535,18 +1535,18 @@ memo_put(PicklerObject *self, PyObject * } static PyObject * -getattribute(PyObject *obj, PyObject *name, int allow_qualname) { - PyObject *dotted_path; - Py_ssize_t i; +get_dotted_path(PyObject *obj, PyObject *name, int allow_qualname) { _Py_static_string(PyId_dot, "."); _Py_static_string(PyId_locals, ""); + PyObject *dotted_path; + Py_ssize_t i, n; dotted_path = PyUnicode_Split(name, _PyUnicode_FromId(&PyId_dot), -1); - if (dotted_path == NULL) { + if (dotted_path == NULL) return NULL; - } - assert(Py_SIZE(dotted_path) >= 1); - if (!allow_qualname && Py_SIZE(dotted_path) > 1) { + n = PyList_GET_SIZE(dotted_path); + assert(n >= 1); + if (!allow_qualname && n > 1) { PyErr_Format(PyExc_AttributeError, "Can't get qualified attribute %R on %R;" "use protocols >= 4 to enable support", @@ -1554,10 +1554,8 @@ getattribute(PyObject *obj, PyObject *na Py_DECREF(dotted_path); return NULL; } - Py_INCREF(obj); - for (i = 0; i < Py_SIZE(dotted_path); i++) { + for (i = 0; i < n; i++) { PyObject *subpath = PyList_GET_ITEM(dotted_path, i); - PyObject *tmp; PyObject *result = PyUnicode_RichCompare( subpath, _PyUnicode_FromId(&PyId_locals), Py_EQ); int is_equal = (result == Py_True); @@ -1567,24 +1565,56 @@ getattribute(PyObject *obj, PyObject *na PyErr_Format(PyExc_AttributeError, "Can't get local attribute %R on %R", name, obj); Py_DECREF(dotted_path); - Py_DECREF(obj); return NULL; } - tmp = PyObject_GetAttr(obj, subpath); + } + return dotted_path; +} + +static PyObject * +get_deep_attribute(PyObject *obj, PyObject *names) +{ + Py_ssize_t i, n; + + assert(PyList_CheckExact(names)); + Py_INCREF(obj); + n = PyList_GET_SIZE(names); + for (i = 0; i < n; i++) { + PyObject *name = PyList_GET_ITEM(names, i); + PyObject *tmp; + tmp = PyObject_GetAttr(obj, name); Py_DECREF(obj); - if (tmp == NULL) { - if (PyErr_ExceptionMatches(PyExc_AttributeError)) { - PyErr_Clear(); - PyErr_Format(PyExc_AttributeError, - "Can't get attribute %R on %R", name, obj); - } - Py_DECREF(dotted_path); + if (tmp == NULL) return NULL; - } obj = tmp; } + return obj; +} + +static void +reformat_attribute_error(PyObject *obj, PyObject *name) +{ + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + PyErr_Format(PyExc_AttributeError, + "Can't get attribute %R on %R", name, obj); + } +} + + +static PyObject * +getattribute(PyObject *obj, PyObject *name, int allow_qualname) +{ + PyObject *dotted_path, *attr; + + dotted_path = get_dotted_path(obj, name, allow_qualname); + if (dotted_path == NULL) + return NULL; + attr = get_deep_attribute(obj, dotted_path); Py_DECREF(dotted_path); - return obj; + if (attr == NULL) + reformat_attribute_error(obj, name); + return attr; } static PyObject * @@ -1593,11 +1623,11 @@ whichmodule(PyObject *global, PyObject * PyObject *module_name; PyObject *modules_dict; PyObject *module; - PyObject *obj; - Py_ssize_t i, j; + Py_ssize_t i; _Py_IDENTIFIER(__module__); _Py_IDENTIFIER(modules); _Py_IDENTIFIER(__main__); + PyObject *dotted_path; module_name = _PyObject_GetAttrId(global, &PyId___module__); @@ -1616,43 +1646,50 @@ whichmodule(PyObject *global, PyObject * } assert(module_name == NULL); + /* Fallback on walking sys.modules */ modules_dict = _PySys_GetObjectId(&PyId_modules); if (modules_dict == NULL) { PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules"); return NULL; } + dotted_path = get_dotted_path(module, global_name, allow_qualname); + if (dotted_path == NULL) + return NULL; + i = 0; - while ((j = PyDict_Next(modules_dict, &i, &module_name, &module))) { - PyObject *result = PyUnicode_RichCompare( - module_name, _PyUnicode_FromId(&PyId___main__), Py_EQ); - int is_equal = (result == Py_True); - assert(PyBool_Check(result)); - Py_DECREF(result); - if (is_equal) + while (PyDict_Next(modules_dict, &i, &module_name, &module)) { + PyObject *candidate; + if (PyUnicode_Check(module_name) && + !PyUnicode_CompareWithASCIIString(module_name, "__main__")) continue; if (module == Py_None) continue; - obj = getattribute(module, global_name, allow_qualname); - if (obj == NULL) { - if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + candidate = get_deep_attribute(module, dotted_path); + if (candidate == NULL) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + Py_DECREF(dotted_path); + Py_DECREF(candidate); return NULL; + } PyErr_Clear(); continue; } - if (obj == global) { - Py_DECREF(obj); + if (candidate == global) { Py_INCREF(module_name); + Py_DECREF(dotted_path); + Py_DECREF(candidate); return module_name; } - Py_DECREF(obj); + Py_DECREF(candidate); } /* If no module is found, use __main__. */ module_name = _PyUnicode_FromId(&PyId___main__); Py_INCREF(module_name); + Py_DECREF(dotted_path); return module_name; }