diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py index b84863f..a7ceded 100644 --- a/Lib/test/test_super.py +++ b/Lib/test/test_super.py @@ -143,6 +143,87 @@ class TestSuper(unittest.TestCase): return __class__ self.assertIs(X.f(), X) + def test___class___new(self): + test_class = None + + class Meta(type): + def __new__(cls, name, bases, namespace): + nonlocal test_class + self = super().__new__(cls, name, bases, namespace) + test_class = self.f() + return self + + class A(metaclass=Meta): + @staticmethod + def f(): + return __class__ + + self.assertIs(test_class, A) + + def test___class___delayed(self): + test_namespace = None + + class Meta(type): + def __new__(cls, name, bases, namespace): + nonlocal test_namespace + test_namespace = namespace + return None + + class A(metaclass=Meta): + @staticmethod + def f(): + return __class__ + + self.assertIs(A, None) + + B = type("B", (), test_namespace) + self.assertIs(B.f(), B) + + def test___class___mro(self): + test_class = None + + class Meta(type): + def mro(self): + # self.f() doesn't work yet... + self.__dict__["f"]() + return super().mro() + + class A(metaclass=Meta): + def f(): + nonlocal test_class + test_class = __class__ + + self.assertIs(test_class, A) + + def test___classcell___deleted(self): + class Meta(type): + def __new__(cls, name, bases, namespace): + del namespace['__classcell__'] + return super().__new__(cls, name, bases, namespace) + + class A(metaclass=Meta): + @staticmethod + def f(): + __class__ + + with self.assertRaises(NameError): + A.f() + + def test___classcell___reset(self): + class Meta(type): + def __new__(cls, name, bases, namespace): + namespace['__classcell__'] = 0 + return super().__new__(cls, name, bases, namespace) + + class A(metaclass=Meta): + @staticmethod + def f(): + __class__ + + with self.assertRaises(NameError): + A.f() + self.assertEqual(A.__classcell__, 0) + def test_obscure_super_errors(self): def f(): super() diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 5227f6a..63b1502 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -2270,7 +2270,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) { PyObject *name, *bases = NULL, *orig_dict, *dict = NULL; static char *kwlist[] = {"name", "bases", "dict", 0}; - PyObject *qualname, *slots = NULL, *tmp, *newslots; + PyObject *qualname, *slots = NULL, *tmp, *newslots, *cell; PyTypeObject *type = NULL, *base, *tmptype, *winner; PyHeapTypeObject *et; PyMemberDef *mp; @@ -2278,6 +2278,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) int j, may_add_dict, may_add_weak, add_dict, add_weak; _Py_IDENTIFIER(__qualname__); _Py_IDENTIFIER(__slots__); + _Py_IDENTIFIER(__classcell__); assert(args != NULL && PyTuple_Check(args)); assert(kwds == NULL || PyDict_Check(kwds)); @@ -2544,7 +2545,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) } et->ht_qualname = qualname ? qualname : et->ht_name; Py_INCREF(et->ht_qualname); - if (qualname != NULL && PyDict_DelItem(dict, PyId___qualname__.object) < 0) + if (qualname != NULL && _PyDict_DelItemId(dict, &PyId___qualname__) < 0) goto error; /* Set tp_doc to a copy of dict['__doc__'], if the latter is there @@ -2656,6 +2657,14 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) else type->tp_free = PyObject_Del; + /* store type in class' cell */ + cell = _PyDict_GetItemId(dict, &PyId___classcell__); + if (cell != NULL && PyCell_Check(cell)) { + PyCell_Set(cell, (PyObject *) type); + _PyDict_DelItemId(dict, &PyId___classcell__); + PyErr_Clear(); + } + /* Initialize the rest */ if (PyType_Ready(type) < 0) goto error; diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index 220c92d..4635fd0 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -52,7 +52,7 @@ _Py_IDENTIFIER(stderr); static PyObject * builtin___build_class__(PyObject *self, PyObject *args, PyObject *kwds) { - PyObject *func, *name, *bases, *mkw, *meta, *winner, *prep, *ns, *cell; + PyObject *func, *name, *bases, *mkw, *meta, *winner, *prep, *ns, *none; PyObject *cls = NULL; Py_ssize_t nargs; int isclass = 0; /* initialize to prevent gcc warning */ @@ -173,19 +173,17 @@ builtin___build_class__(PyObject *self, PyObject *args, PyObject *kwds) Py_DECREF(bases); return NULL; } - cell = PyEval_EvalCodeEx(PyFunction_GET_CODE(func), PyFunction_GET_GLOBALS(func), ns, + none = PyEval_EvalCodeEx(PyFunction_GET_CODE(func), PyFunction_GET_GLOBALS(func), ns, NULL, 0, NULL, 0, NULL, 0, NULL, PyFunction_GET_CLOSURE(func)); - if (cell != NULL) { + if (none != NULL) { PyObject *margs; margs = PyTuple_Pack(3, name, bases, ns); if (margs != NULL) { cls = PyEval_CallObjectWithKeywords(meta, margs, mkw); Py_DECREF(margs); } - if (cls != NULL && PyCell_Check(cell)) - PyCell_Set(cell, cls); - Py_DECREF(cell); + Py_DECREF(none); } Py_DECREF(ns); Py_DECREF(meta); diff --git a/Python/compile.c b/Python/compile.c index e46676c..5278eed 100644 --- a/Python/compile.c +++ b/Python/compile.c @@ -1883,7 +1883,7 @@ compiler_class(struct compiler *c, stmt_ty s) return 0; } if (c->u->u_ste->ste_needs_class_closure) { - /* return the (empty) __class__ cell */ + /* store __classcell__ into class namespace */ str = PyUnicode_InternFromString("__class__"); if (str == NULL) { compiler_exit_scope(c); @@ -1896,15 +1896,20 @@ compiler_class(struct compiler *c, stmt_ty s) return 0; } assert(i == 0); - /* Return the cell where to store __class__ */ + ADDOP_I(c, LOAD_CLOSURE, i); + str = PyUnicode_InternFromString("__classcell__"); + if (!str || !compiler_nameop(c, str, Store)) { + Py_XDECREF(str); + compiler_exit_scope(c); + return 0; + } + Py_DECREF(str); } else { + /* This happens when nobody references the cell. */ assert(PyDict_Size(c->u->u_cellvars) == 0); - /* This happens when nobody references the cell. Return None. */ - ADDOP_O(c, LOAD_CONST, Py_None, consts); } - ADDOP_IN_SCOPE(c, RETURN_VALUE); /* create the code object */ co = assemble(c, 1); }