diff --git a/Lib/test/test_module.py b/Lib/test/test_module.py index 9da3536..a461743 100644 --- a/Lib/test/test_module.py +++ b/Lib/test/test_module.py @@ -235,6 +235,22 @@ a = A(destroyed)""" melon = Descr() self.assertRaises(RuntimeError, getattr, M("mymod"), "melon") + def test___class___assignment(self): + class SubType(ModuleType): + a = 1 + + m = ModuleType("m") + self.assertTrue(m.__class__ is ModuleType) + self.assertFalse(hasattr(m, "a")) + + m.__class__ = SubType + self.assertTrue(m.__class__ is SubType) + self.assertTrue(hasattr(m, "a")) + + m.__class__ = ModuleType + self.assertTrue(m.__class__ is ModuleType) + self.assertFalse(hasattr(m, "a")) + # frozen and namespace module reprs are tested in importlib. diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 941dedb..91ec7cc 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -3420,17 +3420,18 @@ object_get_class(PyObject *self, void *closure) } static int -equiv_structs(PyTypeObject *a, PyTypeObject *b) +compatible_with_tp_base(PyTypeObject *child) { - return a == b || - (a != NULL && - b != NULL && - a->tp_basicsize == b->tp_basicsize && - a->tp_itemsize == b->tp_itemsize && - a->tp_dictoffset == b->tp_dictoffset && - a->tp_weaklistoffset == b->tp_weaklistoffset && - ((a->tp_flags & Py_TPFLAGS_HAVE_GC) == - (b->tp_flags & Py_TPFLAGS_HAVE_GC))); + PyTypeObject *parent = child->tp_base; + return (parent != NULL && + child->tp_basicsize == parent->tp_basicsize && + child->tp_itemsize == parent->tp_itemsize && + child->tp_dictoffset == parent->tp_dictoffset && + child->tp_weaklistoffset == parent->tp_weaklistoffset && + ((child->tp_flags & Py_TPFLAGS_HAVE_GC) == + (parent->tp_flags & Py_TPFLAGS_HAVE_GC)) && + (child->tp_dealloc == subtype_dealloc || + child->tp_dealloc == parent->tp_dealloc)); } static int @@ -3448,6 +3449,10 @@ same_slots_added(PyTypeObject *a, PyTypeObject *b) size += sizeof(PyObject *); /* Check slots compliance */ + if (!(a->tp_flags & Py_TPFLAGS_HEAPTYPE) || + !(b->tp_flags & Py_TPFLAGS_HEAPTYPE)) { + return 0; + } slots_a = ((PyHeapTypeObject *)a)->ht_slots; slots_b = ((PyHeapTypeObject *)b)->ht_slots; if (slots_a && slots_b) { @@ -3463,8 +3468,7 @@ compatible_for_assignment(PyTypeObject* oldto, PyTypeObject* newto, char* attr) { PyTypeObject *newbase, *oldbase; - if (newto->tp_dealloc != oldto->tp_dealloc || - newto->tp_free != oldto->tp_free) + if (newto->tp_free != oldto->tp_free) { PyErr_Format(PyExc_TypeError, "%s assignment: " @@ -3474,11 +3478,21 @@ compatible_for_assignment(PyTypeObject* oldto, PyTypeObject* newto, char* attr) oldto->tp_name); return 0; } + /* It's tricky to tell if two arbitrary types are sufficiently compatible + as to be interchangeable; e.g., even if they have the same + tp_basicsize, they might have totally different struct fields. It's + much easier to tell if a type and its supertype are compatible; e.g., + if they have the same tp_basicsize, then that means they have identical + fields. So to check whether two arbitrary types are compatible, we + first find the highest supertype that each is compatible with, and then + if those supertypes are compatible then the original types must also be + compatible. + */ newbase = newto; oldbase = oldto; - while (equiv_structs(newbase, newbase->tp_base)) + while (compatible_with_tp_base(newbase)) newbase = newbase->tp_base; - while (equiv_structs(oldbase, oldbase->tp_base)) + while (compatible_with_tp_base(oldbase)) oldbase = oldbase->tp_base; if (newbase != oldbase && (newbase->tp_base != oldbase->tp_base || @@ -3513,17 +3527,12 @@ object_set_class(PyObject *self, PyObject *value, void *closure) return -1; } newto = (PyTypeObject *)value; - if (!(newto->tp_flags & Py_TPFLAGS_HEAPTYPE) || - !(oldto->tp_flags & Py_TPFLAGS_HEAPTYPE)) - { - PyErr_Format(PyExc_TypeError, - "__class__ assignment: only for heap types"); - return -1; - } if (compatible_for_assignment(oldto, newto, "__class__")) { - Py_INCREF(newto); + if (newto->tp_flags & Py_TPFLAGS_HEAPTYPE) + Py_INCREF(newto); Py_TYPE(self) = newto; - Py_DECREF(oldto); + if (oldto->tp_flags & Py_TPFLAGS_HEAPTYPE) + Py_DECREF(oldto); return 0; } else {