diff -r 8da9d539ecf4 Lib/ctypes/test/test_arrays.py --- a/Lib/ctypes/test/test_arrays.py Sun Apr 17 09:18:04 2011 -0500 +++ b/Lib/ctypes/test/test_arrays.py Sun Apr 17 20:00:02 2011 -0500 @@ -127,5 +127,45 @@ t2 = my_int * 1 self.assertTrue(t1 is t2) + def test_subclass(self): + class T(Array): + _type_ = c_int + _length_ = 13 + class U(T): + pass + class V(T): + pass + class W(T): + _type_ = c_short + class X(T): + _length_ = 187 + + for c in [T, U, V]: + self.assertEqual(c._type_, c_int) + self.assertEqual(c._length_, 13) + self.assertEqual(c()._type_, c_int) + self.assertEqual(c()._length_, 13) + + self.assertEqual(W._type_, c_short) + self.assertEqual(W._length_, 13) + self.assertEqual(W()._type_, c_short) + self.assertEqual(W()._length_, 13) + + self.assertEqual(X._type_, c_int) + self.assertEqual(X._length_, 187) + self.assertEqual(X()._type_, c_int) + self.assertEqual(X()._length_, 187) + + def test_bad_subclass(self): + with self.assertRaises(AttributeError): + class T(Array): + pass + with self.assertRaises(AttributeError): + class T(Array): + _type_ = c_int + with self.assertRaises(AttributeError): + class T(Array): + _length_ = 13 + if __name__ == '__main__': unittest.main() diff -r 8da9d539ecf4 Modules/_ctypes/_ctypes.c --- a/Modules/_ctypes/_ctypes.c Sun Apr 17 09:18:04 2011 -0500 +++ b/Modules/_ctypes/_ctypes.c Sun Apr 17 20:00:02 2011 -0500 @@ -1256,7 +1256,7 @@ PyTypeObject *result; StgDictObject *stgdict; StgDictObject *itemdict; - PyObject *proto; + PyObject *length_attr, *type_attr; PyObject *typedict; long length; int overflow; @@ -1267,38 +1267,50 @@ if (!typedict) return NULL; - proto = PyDict_GetItemString(typedict, "_length_"); /* Borrowed ref */ - if (!proto || !PyLong_Check(proto)) { - PyErr_SetString(PyExc_AttributeError, - "class must define a '_length_' attribute, " - "which must be a positive integer"); + /* create the new instance (which is a class, + since we are a metatype!) */ + result = (PyTypeObject *)PyType_Type.tp_new(type, args, kwds); + if (result == NULL) return NULL; - } - length = PyLong_AsLongAndOverflow(proto, &overflow); + + length_attr = PyDict_GetItemString(typedict, "_length_"); + if (!length_attr || !PyLong_Check(length_attr)) { + length_attr = PyDict_GetItemString(result->tp_base->tp_dict, + "_length_"); + if (!length_attr) { + PyErr_SetString(PyExc_AttributeError, + "class must define a '_length_' attribute, " + "which must be a positive integer"); + goto error_decref_result; + } + } + length = PyLong_AsLongAndOverflow(length_attr, &overflow); if (overflow) { PyErr_SetString(PyExc_OverflowError, "The '_length_' attribute is too large"); - return NULL; - } - - proto = PyDict_GetItemString(typedict, "_type_"); /* Borrowed ref */ - if (!proto) { - PyErr_SetString(PyExc_AttributeError, - "class must define a '_type_' attribute"); - return NULL; + goto error_decref_result; + } + + type_attr = PyDict_GetItemString(typedict, "_type_"); + if (!type_attr) { + type_attr = PyDict_GetItemString(result->tp_base->tp_dict, "_type_"); + if (!type_attr) { + PyErr_SetString(PyExc_AttributeError, + "class must define a '_type_' attribute"); + goto error_decref_result; + } } stgdict = (StgDictObject *)PyObject_CallObject( (PyObject *)&PyCStgDict_Type, NULL); if (!stgdict) - return NULL; - - itemdict = PyType_stgdict(proto); + goto error_decref_result; + + itemdict = PyType_stgdict(type_attr); if (!itemdict) { PyErr_SetString(PyExc_TypeError, "_type_ must have storage info"); - Py_DECREF((PyObject *)stgdict); - return NULL; + goto error_decref_stgdict; } assert(itemdict->format); @@ -1309,16 +1321,12 @@ sprintf(buf, "(%ld)", length); stgdict->format = _ctypes_alloc_format_string(buf, itemdict->format); } - if (stgdict->format == NULL) { - Py_DECREF((PyObject *)stgdict); - return NULL; - } + if (stgdict->format == NULL) + goto error_decref_stgdict; stgdict->ndim = itemdict->ndim + 1; stgdict->shape = PyMem_Malloc(sizeof(Py_ssize_t *) * stgdict->ndim); - if (stgdict->shape == NULL) { - Py_DECREF((PyObject *)stgdict); - return NULL; - } + if (stgdict->shape == NULL) + goto error_decref_stgdict; stgdict->shape[0] = length; memmove(&stgdict->shape[1], itemdict->shape, sizeof(Py_ssize_t) * (stgdict->ndim - 1)); @@ -1327,7 +1335,7 @@ if (length * itemsize < 0) { PyErr_SetString(PyExc_OverflowError, "array too large"); - return NULL; + goto error_decref_stgdict; } itemalign = itemdict->align; @@ -1338,26 +1346,17 @@ stgdict->size = itemsize * length; stgdict->align = itemalign; stgdict->length = length; - Py_INCREF(proto); - stgdict->proto = proto; + Py_INCREF(type_attr); + stgdict->proto = type_attr; stgdict->paramfunc = &PyCArrayType_paramfunc; /* Arrays are passed as pointers to function calls. */ stgdict->ffi_type_pointer = ffi_type_pointer; - /* create the new instance (which is a class, - since we are a metatype!) */ - result = (PyTypeObject *)PyType_Type.tp_new(type, args, kwds); - if (result == NULL) - return NULL; - /* replace the class dict by our updated spam dict */ - if (-1 == PyDict_Update((PyObject *)stgdict, result->tp_dict)) { - Py_DECREF(result); - Py_DECREF((PyObject *)stgdict); - return NULL; - } + if (-1 == PyDict_Update((PyObject *)stgdict, result->tp_dict)) + goto error_decref_stgdict; Py_DECREF(result->tp_dict); result->tp_dict = (PyObject *)stgdict; @@ -1366,15 +1365,20 @@ */ if (itemdict->getfunc == _ctypes_get_fielddesc("c")->getfunc) { if (-1 == add_getset(result, CharArray_getsets)) - return NULL; + goto error_decref_stgdict; #ifdef CTYPES_UNICODE } else if (itemdict->getfunc == _ctypes_get_fielddesc("u")->getfunc) { if (-1 == add_getset(result, WCharArray_getsets)) - return NULL; + goto error_decref_stgdict; #endif } return (PyObject *)result; +error_decref_stgdict: + Py_DECREF((PyObject*)stgdict); +error_decref_result: + Py_DECREF(result); + return NULL; } PyTypeObject PyCArrayType_Type = {