diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py index b84863f..a7ceded 100644 --- a/Lib/test/test_super.py +++ b/Lib/test/test_super.py @@ -143,6 +143,87 @@ class TestSuper(unittest.TestCase): return __class__ self.assertIs(X.f(), X) + def test___class___new(self): + test_class = None + + class Meta(type): + def __new__(cls, name, bases, namespace): + nonlocal test_class + self = super().__new__(cls, name, bases, namespace) + test_class = self.f() + return self + + class A(metaclass=Meta): + @staticmethod + def f(): + return __class__ + + self.assertIs(test_class, A) + + def test___class___delayed(self): + test_namespace = None + + class Meta(type): + def __new__(cls, name, bases, namespace): + nonlocal test_namespace + test_namespace = namespace + return None + + class A(metaclass=Meta): + @staticmethod + def f(): + return __class__ + + self.assertIs(A, None) + + B = type("B", (), test_namespace) + self.assertIs(B.f(), B) + + def test___class___mro(self): + test_class = None + + class Meta(type): + def mro(self): + # self.f() doesn't work yet... + self.__dict__["f"]() + return super().mro() + + class A(metaclass=Meta): + def f(): + nonlocal test_class + test_class = __class__ + + self.assertIs(test_class, A) + + def test___classcell___deleted(self): + class Meta(type): + def __new__(cls, name, bases, namespace): + del namespace['__classcell__'] + return super().__new__(cls, name, bases, namespace) + + class A(metaclass=Meta): + @staticmethod + def f(): + __class__ + + with self.assertRaises(NameError): + A.f() + + def test___classcell___reset(self): + class Meta(type): + def __new__(cls, name, bases, namespace): + namespace['__classcell__'] = 0 + return super().__new__(cls, name, bases, namespace) + + class A(metaclass=Meta): + @staticmethod + def f(): + __class__ + + with self.assertRaises(NameError): + A.f() + self.assertEqual(A.__classcell__, 0) + def test_obscure_super_errors(self): def f(): super() diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 2675a48..5028304 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -2285,7 +2285,7 @@ static PyObject * type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) { PyObject *name, *bases = NULL, *orig_dict, *dict = NULL; - PyObject *qualname, *slots = NULL, *tmp, *newslots; + PyObject *qualname, *slots = NULL, *tmp, *newslots, *cell; PyTypeObject *type = NULL, *base, *tmptype, *winner; PyHeapTypeObject *et; PyMemberDef *mp; @@ -2293,6 +2293,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) int j, may_add_dict, may_add_weak, add_dict, add_weak; _Py_IDENTIFIER(__qualname__); _Py_IDENTIFIER(__slots__); + _Py_IDENTIFIER(__classcell__); assert(args != NULL && PyTuple_Check(args)); assert(kwds == NULL || PyDict_Check(kwds)); @@ -2559,7 +2560,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) } et->ht_qualname = qualname ? qualname : et->ht_name; Py_INCREF(et->ht_qualname); - if (qualname != NULL && PyDict_DelItem(dict, PyId___qualname__.object) < 0) + if (qualname != NULL && _PyDict_DelItemId(dict, &PyId___qualname__) < 0) goto error; /* Set tp_doc to a copy of dict['__doc__'], if the latter is there @@ -2685,6 +2686,14 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) else type->tp_free = PyObject_Del; + /* store type in class' cell */ + cell = _PyDict_GetItemId(dict, &PyId___classcell__); + if (cell != NULL && PyCell_Check(cell)) { + PyCell_Set(cell, (PyObject *) type); + _PyDict_DelItemId(dict, &PyId___classcell__); + PyErr_Clear(); + } + /* Initialize the rest */ if (PyType_Ready(type) < 0) goto error; diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index da9ad55..cee013f 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -54,7 +54,7 @@ _Py_IDENTIFIER(stderr); static PyObject * builtin___build_class__(PyObject *self, PyObject *args, PyObject *kwds) { - PyObject *func, *name, *bases, *mkw, *meta, *winner, *prep, *ns, *cell; + PyObject *func, *name, *bases, *mkw, *meta, *winner, *prep, *ns, *none; PyObject *cls = NULL; Py_ssize_t nargs; int isclass = 0; /* initialize to prevent gcc warning */ @@ -167,15 +167,13 @@ builtin___build_class__(PyObject *self, PyObject *args, PyObject *kwds) Py_DECREF(bases); return NULL; } - cell = PyEval_EvalCodeEx(PyFunction_GET_CODE(func), PyFunction_GET_GLOBALS(func), ns, + none = PyEval_EvalCodeEx(PyFunction_GET_CODE(func), PyFunction_GET_GLOBALS(func), ns, NULL, 0, NULL, 0, NULL, 0, NULL, PyFunction_GET_CLOSURE(func)); - if (cell != NULL) { + if (none != NULL) { PyObject *margs[3] = {name, bases, ns}; cls = _PyObject_FastCallDict(meta, margs, 3, mkw); - if (cls != NULL && PyCell_Check(cell)) - PyCell_Set(cell, cls); - Py_DECREF(cell); + Py_DECREF(none); } Py_DECREF(ns); Py_DECREF(meta); diff --git a/Python/compile.c b/Python/compile.c index 36683d3..4631dbb 100644 --- a/Python/compile.c +++ b/Python/compile.c @@ -1968,7 +1968,7 @@ compiler_class(struct compiler *c, stmt_ty s) return 0; } if (c->u->u_ste->ste_needs_class_closure) { - /* return the (empty) __class__ cell */ + /* store __classcell__ into class namespace */ str = PyUnicode_InternFromString("__class__"); if (str == NULL) { compiler_exit_scope(c); @@ -1981,15 +1981,20 @@ compiler_class(struct compiler *c, stmt_ty s) return 0; } assert(i == 0); - /* Return the cell where to store __class__ */ + ADDOP_I(c, LOAD_CLOSURE, i); + str = PyUnicode_InternFromString("__classcell__"); + if (!str || !compiler_nameop(c, str, Store)) { + Py_XDECREF(str); + compiler_exit_scope(c); + return 0; + } + Py_DECREF(str); } else { + /* This happens when nobody references the cell. */ assert(PyDict_Size(c->u->u_cellvars) == 0); - /* This happens when nobody references the cell. Return None. */ - ADDOP_O(c, LOAD_CONST, Py_None, consts); } - ADDOP_IN_SCOPE(c, RETURN_VALUE); /* create the code object */ co = assemble(c, 1); }