diff -r c1a475587775 Lib/test/test_bytes.py --- a/Lib/test/test_bytes.py Mon Sep 17 08:39:13 2007 -0600 +++ b/Lib/test/test_bytes.py Mon Sep 17 10:32:54 2007 -0600 @@ -8,6 +8,7 @@ import unittest import unittest import test.test_support import test.string_tests +import copy class BytesTest(unittest.TestCase): @@ -754,9 +755,88 @@ class BytesAsStringTest(test.string_test pass +class BytesSubclass(bytes): + pass + +class BytesSubclassTest(unittest.TestCase): + + def test_basic(self): + self.assert_(issubclass(BytesSubclass, bytes)) + self.assert_(isinstance(BytesSubclass(), bytes)) + + a, b = b"abcd", b"efgh" + _a, _b = BytesSubclass(a), BytesSubclass(b) + + # test comparison operators with subclass instances + self.assert_(_a == _a) + self.assert_(_a != _b) + self.assert_(_a < _b) + self.assert_(_a <= _b) + self.assert_(_b >= _a) + self.assert_(_b > _a) + self.assert_(_a is not a) + + # test concat of subclass instances + self.assertEqual(a + b, _a + _b) + self.assertEqual(a + b, a + _b) + self.assertEqual(a + b, _a + b) + + # test repeat + self.assert_(a*5 == _a*5) + + def test_join(self): + # Make sure join returns a NEW object for single item sequences + # involving a subclass. + # Make sure that it is of the appropriate type. + s1 = BytesSubclass(b"abcd") + s2 = b"".join([s1]) + self.assert_(s1 is not s2) + self.assert_(type(s2) is bytes) + + # Test reverse, calling join on subclass + s3 = s1.join([b"abcd"]) + self.assert_(type(s3) is bytes) + + def test_pickle(self): + a = BytesSubclass(b"abcd") + a.x = 10 + a.y = BytesSubclass(b"efgh") + for proto in range(pickle.HIGHEST_PROTOCOL): + b = pickle.loads(pickle.dumps(a, proto)) + self.assertNotEqual(id(a), id(b)) + self.assertEqual(a, b) + self.assertEqual(a.x, b.x) + self.assertEqual(a.y, b.y) + self.assertEqual(type(a), type(b)) + self.assertEqual(type(a.y), type(b.y)) + + def test_copy(self): + a = BytesSubclass(b"abcd") + a.x = 10 + a.y = BytesSubclass(b"efgh") + for copy_method in (copy.copy, copy.deepcopy): + b = copy_method(a) + self.assertNotEqual(id(a), id(b)) + self.assertEqual(a, b) + self.assertEqual(a.x, b.x) + self.assertEqual(a.y, b.y) + self.assertEqual(type(a), type(b)) + self.assertEqual(type(a.y), type(b.y)) + + def test_init_override(self): + class subclass(bytes): + def __init__(self, newarg=1, *args, **kwargs): + bytes.__init__(self, *args, **kwargs) + x = subclass(4, source=b"abcd") + self.assertEqual(x, b"abcd") + x = subclass(newarg=4, source=b"abcd") + self.assertEqual(x, b"abcd") + + def test_main(): test.test_support.run_unittest(BytesTest) test.test_support.run_unittest(BytesAsStringTest) + test.test_support.run_unittest(BytesSubclassTest) if __name__ == "__main__": diff -r c1a475587775 Objects/bytesobject.c --- a/Objects/bytesobject.c Mon Sep 17 08:39:13 2007 -0600 +++ b/Objects/bytesobject.c Mon Sep 17 10:32:54 2007 -0600 @@ -2912,13 +2912,21 @@ static PyObject * static PyObject * bytes_reduce(PyBytesObject *self) { - PyObject *latin1; + PyObject *latin1, *dict; if (self->ob_bytes) latin1 = PyUnicode_DecodeLatin1(self->ob_bytes, Py_Size(self), NULL); else latin1 = PyUnicode_FromString(""); - return Py_BuildValue("(O(Ns))", Py_Type(self), latin1, "latin-1"); + + dict = PyObject_GetAttrString((PyObject *)self, "__dict__"); + if (dict == NULL) { + PyErr_Clear(); + dict = Py_None; + Py_INCREF(dict); + } + + return Py_BuildValue("(O(Ns)N)", Py_Type(self), latin1, "latin-1", dict); } static PySequenceMethods bytes_as_sequence = { @@ -3004,8 +3012,7 @@ PyTypeObject PyBytes_Type = { PyObject_GenericGetAttr, /* tp_getattro */ 0, /* tp_setattro */ &bytes_as_buffer, /* tp_as_buffer */ - /* bytes is 'final' or 'sealed' */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ bytes_doc, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */