diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -796,6 +796,137 @@ class X(int(), C): pass + def test___init_class__(self): + # PEP 422: Simple class initialisation hook + class C: + @classmethod + def __init_class__(cls): + cls.x = 0 + self.assertEqual(0, C.x) + # inherited: + class D(C): + pass + self.assertEqual(0, D.__dict__['x']) + # overwrite: + class E(C): + x = 1 + self.assertEqual(0, E.__dict__['x']) + # override: + class F(C): + @classmethod + def __init_class__(cls): + cls.y = 1 + self.assertEqual(1, F.y) + self.assertEqual(0, F.x) + self.assertNotIn('x', F.__dict__) + self.assertFalse(hasattr(C, 'y')) + # super: + class G(C): + @classmethod + def __init_class__(cls): + super().__init_class__() + cls.y = 1 + self.assertEqual(1, G.y) + self.assertEqual(0, G.__dict__['x']) + self.assertFalse(hasattr(C, 'y')) + # staticmethod: + class H(C): + @staticmethod + def __init_class__(): + __class__.z = 2 + self.assertEqual(2, H.z) + self.assertFalse(hasattr(C, 'z')) + # __class__: + class I: + @classmethod + def __init_class__(cls): + cls.x = 0 + __class__.y += 1 + y = 0 + self.assertEqual(0, I.x) + self.assertEqual(1, I.y) + class J(I): + pass + self.assertEqual(0, J.__dict__['x']) + self.assertEqual(2, I.y) + class K(J): + @classmethod + def __init_class__(cls): + super().__init_class__() + __class__.z += 1 + z = 0 + self.assertEqual(0, K.__dict__['x']) + self.assertEqual(3, I.y) + self.assertEqual(1, K.z) + self.assertFalse(hasattr(J, 'z')) + # multiple inheritance: + class L: + @classmethod + def __init_class__(cls): + pass + class M(L): + @classmethod + def __init_class__(cls): + super().__init_class__() + cls.x = 0 + self.assertEqual(0, M.x) + class N(L): + @classmethod + def __init_class__(cls): + super().__init_class__() + cls.y = 1 + self.assertEqual(1, N.y) + class O(M, N): + @classmethod + def __init_class__(cls): + super().__init_class__() + cls.z = 2 + self.assertEqual(0, O.__dict__['x']) + self.assertEqual(1, O.__dict__['y']) + self.assertEqual(2, O.__dict__['z']) + # decorators: + def dec1(cls): + cls.x = 1 + return cls + def dec2(cls): + cls.x = 2 + return cls + @dec2 + @dec1 + class P: + @classmethod + def __init_class__(cls): + cls.x = 0 + self.assertEqual(2, P.x) + # block class initialization: + class Meta(type): + def __getattribute__(cls, name): + if name == '__init_class__': + raise AttributeError('__init_class__') + return super().__getattribute__(name) + class Q(C, metaclass=Meta): + pass + self.assertNotIn('x', Q.__dict__) + # other exceptions should be propagated: + class Meta2(type): + def __getattribute__(cls, name): + if name == '__init_class__': + raise KeyError('xxx') + return super().__getattribute__(name) + R = sentinel = object() + with self.assertRaisesRegex(KeyError, 'xxx'): + class R(C, metaclass=Meta2): + pass + self.assertIs(sentinel, R) + # __init_class__ raises an exception: + S = sentinel + with self.assertRaisesRegex(KeyError, 'xxx'): + class S: + @classmethod + def __init_class__(cls): + raise KeyError('xxx') + self.assertIs(sentinel, S) + def test_module_subclasses(self): # Testing Python subclass of module... log = [] diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -995,6 +995,73 @@ with self.assertRaises(TypeError): X = types.new_class("X", (int(), C)) + def test___init_class__(self): + # PEP 422: Simple class initialisation hook + def c(ns): + @classmethod + def __init_class__(cls): + cls.x = 0 + ns['__init_class__'] = __init_class__ + C = types.new_class("C", exec_body=c) + self.assertEqual(0, C.x) + # inherited: + D = types.new_class("D", (C,)) + self.assertEqual(0, D.__dict__['x']) + # overwrite: + def e(ns): + ns['x'] = 1 + E = types.new_class("E", (C,), exec_body=e) + self.assertEqual(0, E.__dict__['x']) + # override: + def f(ns): + @classmethod + def __init_class__(cls): + cls.y = 1 + ns['__init_class__'] = __init_class__ + F = types.new_class("F", (C,), exec_body=f) + self.assertEqual(1, F.y) + self.assertEqual(0, F.x) + self.assertNotIn('x', F.__dict__) + self.assertFalse(hasattr(C, 'y')) + # staticmethod: + z = 0 + def h(ns): + @staticmethod + def __init_class__(): + nonlocal z + z = 2 + ns['__init_class__'] = __init_class__ + H = types.new_class("H", (C,), exec_body=h) + self.assertEqual(2, z) + # block class initialization: + class Meta(type): + def __getattribute__(cls, name): + if name == '__init_class__': + raise AttributeError('__init_class__') + return super().__getattribute__(name) + Q = types.new_class("Q", (C,), dict(metaclass=Meta)) + self.assertNotIn('x', Q.__dict__) + # other exceptions should be propagated: + class Meta2(type): + def __getattribute__(cls, name): + if name == '__init_class__': + raise KeyError('xxx') + return super().__getattribute__(name) + R = sentinel = object() + with self.assertRaisesRegex(KeyError, 'xxx'): + R = types.new_class("R", (C,), dict(metaclass=Meta2)) + self.assertIs(sentinel, R) + # __init_class__ raises an exception: + def s(ns): + @classmethod + def __init_class__(cls): + raise KeyError('xxx') + ns['__init_class__'] = __init_class__ + S = sentinel + with self.assertRaisesRegex(KeyError, 'xxx'): + S = types.new_class("S", exec_body=s) + self.assertIs(sentinel, S) + class SimpleNamespaceTests(unittest.TestCase): diff --git a/Lib/types.py b/Lib/types.py --- a/Lib/types.py +++ b/Lib/types.py @@ -49,7 +49,14 @@ meta, ns, kwds = prepare_class(name, bases, kwds) if exec_body is not None: exec_body(ns) - return meta(name, bases, ns, **kwds) + cls = meta(name, bases, ns, **kwds) + try: + initcl = cls.__init_class__ + except AttributeError: + pass + else: + initcl() + return cls def prepare_class(name, bases=(), kwds=None): """Call the __prepare__ method of the appropriate metaclass. diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -38,11 +38,13 @@ static PyObject * builtin___build_class__(PyObject *self, PyObject *args, PyObject *kwds) { - PyObject *func, *name, *bases, *mkw, *meta, *winner, *prep, *ns, *cell; + PyObject *func, *name, *bases, *mkw, *meta, *winner, *prep, *ns; + PyObject *cell, *initcl, *res; PyObject *cls = NULL; Py_ssize_t nargs; int isclass; _Py_IDENTIFIER(__prepare__); + _Py_IDENTIFIER(__init_class__); assert(args != NULL); if (!PyTuple_Check(args)) { @@ -163,8 +165,35 @@ cls = PyEval_CallObjectWithKeywords(meta, margs, mkw); Py_DECREF(margs); } - if (cls != NULL && PyCell_Check(cell)) - PyCell_Set(cell, cls); + if (cls != NULL) { + /* initialize the __class__ reference: */ + if (PyCell_Check(cell)) { + PyCell_Set(cell, cls); + } + /* call __init_class__: */ + initcl = _PyObject_GetAttrId(cls, &PyId___init_class__); + if (initcl == NULL) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + /* no __init_class__, nothing to do */ + } + else { + /* propagate other exceptions: */ + Py_DECREF(cls); + cls = NULL; + } + } + else { + res = PyObject_CallObject(initcl, NULL); + Py_DECREF(initcl); + if (res == NULL) { + /* __init_class__ raised an exception */ + Py_DECREF(cls); + cls = NULL; + } + Py_XDECREF(res); + } + } Py_DECREF(cell); } Py_DECREF(ns);