diff -r d403eaec64df Lib/ctypes/test/test_arrays.py --- a/Lib/ctypes/test/test_arrays.py Fri Aug 12 18:03:30 2011 +0200 +++ b/Lib/ctypes/test/test_arrays.py Fri Aug 12 15:14:54 2011 -0500 @@ -127,5 +127,57 @@ t2 = my_int * 1 self.assertTrue(t1 is t2) + def test_subclass(self): + class T(Array): + _type_ = c_int + _length_ = 13 + class U(T): + pass + class V(U): + pass + class W(V): + pass + class X(T): + _type_ = c_short + class Y(T): + _length_ = 187 + + for c in [T, U, V, W]: + self.assertEqual(c._type_, c_int) + self.assertEqual(c._length_, 13) + self.assertEqual(c()._type_, c_int) + self.assertEqual(c()._length_, 13) + + self.assertEqual(X._type_, c_short) + self.assertEqual(X._length_, 13) + self.assertEqual(X()._type_, c_short) + self.assertEqual(X()._length_, 13) + + self.assertEqual(Y._type_, c_int) + self.assertEqual(Y._length_, 187) + self.assertEqual(Y()._type_, c_int) + self.assertEqual(Y()._length_, 187) + + def test_bad_subclass(self): + import sys + + with self.assertRaises(AttributeError): + class T(Array): + pass + with self.assertRaises(AttributeError): + class T(Array): + _type_ = c_int + with self.assertRaises(AttributeError): + class T(Array): + _length_ = 13 + with self.assertRaises(OverflowError): + class T(Array): + _type_ = c_int + _length_ = sys.maxsize * 2 + with self.assertRaises(AttributeError): + class T(Array): + _type_ = c_int + _length_ = 1.87 + if __name__ == '__main__': unittest.main() diff -r d403eaec64df Modules/_ctypes/_ctypes.c --- a/Modules/_ctypes/_ctypes.c Fri Aug 12 18:03:30 2011 +0200 +++ b/Modules/_ctypes/_ctypes.c Fri Aug 12 15:14:54 2011 -0500 @@ -1256,49 +1256,57 @@ PyTypeObject *result; StgDictObject *stgdict; StgDictObject *itemdict; - PyObject *proto; - PyObject *typedict; + PyObject *length_attr, *type_attr; long length; int overflow; Py_ssize_t itemsize, itemalign; char buf[32]; - typedict = PyTuple_GetItem(args, 2); - if (!typedict) + /* create the new instance (which is a class, + since we are a metatype!) */ + result = (PyTypeObject *)PyType_Type.tp_new(type, args, kwds); + if (result == NULL) return NULL; - proto = PyDict_GetItemString(typedict, "_length_"); /* Borrowed ref */ - if (!proto || !PyLong_Check(proto)) { + /* Initialize these variables to NULL so that we can simplify error + handling by using Py_XDECREF. */ + stgdict = NULL; + type_attr = NULL; + + length_attr = PyObject_GetAttrString((PyObject *)result, "_length_"); + if (!length_attr || !PyLong_Check(length_attr)) { PyErr_SetString(PyExc_AttributeError, "class must define a '_length_' attribute, " "which must be a positive integer"); - return NULL; - } - length = PyLong_AsLongAndOverflow(proto, &overflow); + Py_XDECREF(length_attr); + goto error; + } + length = PyLong_AsLongAndOverflow(length_attr, &overflow); if (overflow) { PyErr_SetString(PyExc_OverflowError, "The '_length_' attribute is too large"); - return NULL; - } - - proto = PyDict_GetItemString(typedict, "_type_"); /* Borrowed ref */ - if (!proto) { + Py_DECREF(length_attr); + goto error; + } + Py_DECREF(length_attr); + + type_attr = PyObject_GetAttrString((PyObject *)result, "_type_"); + if (!type_attr) { PyErr_SetString(PyExc_AttributeError, "class must define a '_type_' attribute"); - return NULL; + goto error; } stgdict = (StgDictObject *)PyObject_CallObject( (PyObject *)&PyCStgDict_Type, NULL); if (!stgdict) - return NULL; - - itemdict = PyType_stgdict(proto); + goto error; + + itemdict = PyType_stgdict(type_attr); if (!itemdict) { PyErr_SetString(PyExc_TypeError, "_type_ must have storage info"); - Py_DECREF((PyObject *)stgdict); - return NULL; + goto error; } assert(itemdict->format); @@ -1309,16 +1317,12 @@ sprintf(buf, "(%ld)", length); stgdict->format = _ctypes_alloc_format_string(buf, itemdict->format); } - if (stgdict->format == NULL) { - Py_DECREF((PyObject *)stgdict); - return NULL; - } + if (stgdict->format == NULL) + goto error; stgdict->ndim = itemdict->ndim + 1; stgdict->shape = PyMem_Malloc(sizeof(Py_ssize_t *) * stgdict->ndim); - if (stgdict->shape == NULL) { - Py_DECREF((PyObject *)stgdict); - return NULL; - } + if (stgdict->shape == NULL) + goto error; stgdict->shape[0] = length; memmove(&stgdict->shape[1], itemdict->shape, sizeof(Py_ssize_t) * (stgdict->ndim - 1)); @@ -1327,7 +1331,7 @@ if (length * itemsize < 0) { PyErr_SetString(PyExc_OverflowError, "array too large"); - return NULL; + goto error; } itemalign = itemdict->align; @@ -1338,26 +1342,16 @@ stgdict->size = itemsize * length; stgdict->align = itemalign; stgdict->length = length; - Py_INCREF(proto); - stgdict->proto = proto; + stgdict->proto = type_attr; stgdict->paramfunc = &PyCArrayType_paramfunc; /* Arrays are passed as pointers to function calls. */ stgdict->ffi_type_pointer = ffi_type_pointer; - /* create the new instance (which is a class, - since we are a metatype!) */ - result = (PyTypeObject *)PyType_Type.tp_new(type, args, kwds); - if (result == NULL) - return NULL; - /* replace the class dict by our updated spam dict */ - if (-1 == PyDict_Update((PyObject *)stgdict, result->tp_dict)) { - Py_DECREF(result); - Py_DECREF((PyObject *)stgdict); - return NULL; - } + if (-1 == PyDict_Update((PyObject *)stgdict, result->tp_dict)) + goto error; Py_DECREF(result->tp_dict); result->tp_dict = (PyObject *)stgdict; @@ -1366,15 +1360,20 @@ */ if (itemdict->getfunc == _ctypes_get_fielddesc("c")->getfunc) { if (-1 == add_getset(result, CharArray_getsets)) - return NULL; + goto error; #ifdef CTYPES_UNICODE } else if (itemdict->getfunc == _ctypes_get_fielddesc("u")->getfunc) { if (-1 == add_getset(result, WCharArray_getsets)) - return NULL; + goto error; #endif } return (PyObject *)result; +error: + Py_XDECREF((PyObject*)stgdict); + Py_XDECREF(type_attr); + Py_DECREF(result); + return NULL; } PyTypeObject PyCArrayType_Type = {