diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -288,8 +288,22 @@ class BaseBytesTest(unittest.TestCase): self.assertEqual(self.type2test(b"").join(lst), b"abc") self.assertEqual(self.type2test(b"").join(tuple(lst)), b"abc") self.assertEqual(self.type2test(b"").join(iter(lst)), b"abc") - self.assertEqual(self.type2test(b".").join([b"ab", b"cd"]), b"ab.cd") - # XXX more... + dot_join = self.type2test(b".:").join + self.assertEqual(dot_join([b"ab", b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([memoryview(b"ab"), b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([b"ab", memoryview(b"cd")]), b"ab.:cd") + self.assertEqual(dot_join([bytearray(b"ab"), b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([b"ab", bytearray(b"cd")]), b"ab.:cd") + # Stress it with many items + seq = [b"abc"] * 1000 + expected = b"abc" + b".:abc" * 999 + self.assertEqual(dot_join(seq), expected) + # Error handling and cleanup when some item in the middle of the + # sequence has the wrong type. + with self.assertRaises(TypeError): + dot_join([bytearray(b"ab"), "cd", b"ef"]) + with self.assertRaises(TypeError): + dot_join([memoryview(b"ab"), "cd", b"ef"]) def test_count(self): b = self.type2test(b'mississippi') diff --git a/Objects/bytearrayobject.c b/Objects/bytearrayobject.c --- a/Objects/bytearrayobject.c +++ b/Objects/bytearrayobject.c @@ -2569,75 +2569,106 @@ Concatenate any number of bytes/bytearra in between each pair, and return the result as a new bytearray."); static PyObject * -bytearray_join(PyByteArrayObject *self, PyObject *it) +bytearray_join(PyObject *self, PyObject *orig) { - PyObject *seq; - Py_ssize_t mysize = Py_SIZE(self); - Py_ssize_t i; - Py_ssize_t n; - PyObject **items; - Py_ssize_t totalsize = 0; - PyObject *result; - char *dest; - - seq = PySequence_Fast(it, "can only join an iterable"); - if (seq == NULL) + char *sep = PyByteArray_AS_STRING(self); + const Py_ssize_t seplen = PyByteArray_GET_SIZE(self); + PyObject *res = NULL; + char *p; + Py_ssize_t seqlen = 0; + Py_ssize_t sz = 0; + Py_ssize_t i, nbufs; + PyObject *seq, *item; + Py_buffer *buffers = NULL; +#define NB_STATIC_BUFFERS 10 + Py_buffer static_buffers[NB_STATIC_BUFFERS]; + + seq = PySequence_Fast(orig, "can only join an iterable"); + if (seq == NULL) { return NULL; - n = PySequence_Fast_GET_SIZE(seq); - items = PySequence_Fast_ITEMS(seq); - - /* Compute the total size, and check that they are all bytes */ - /* XXX Shouldn't we use _getbuffer() on these items instead? */ - for (i = 0; i < n; i++) { - PyObject *obj = items[i]; - if (!PyByteArray_Check(obj) && !PyBytes_Check(obj)) { + } + + seqlen = PySequence_Size(seq); + if (seqlen == 0) { + Py_DECREF(seq); + return PyByteArray_FromStringAndSize("", 0); + } + if (seqlen > NB_STATIC_BUFFERS) { + buffers = PyMem_NEW(Py_buffer, seqlen); + if (buffers == NULL) { + Py_DECREF(seq); + return NULL; + } + } + else { + buffers = static_buffers; + } + + /* There is at least one thing to join. + * Do a pre-pass to figure out the total amount of space we'll + * need (sz), and see whether all arguments are buffer-compatible. + */ + for (i = 0, nbufs = 0; i < seqlen; i++) { + Py_ssize_t itemlen; + item = PySequence_Fast_GET_ITEM(seq, i); + if (_getbuffer(item, &buffers[i]) < 0) { PyErr_Format(PyExc_TypeError, - "can only join an iterable of bytes " - "(item %ld has type '%.100s')", - /* XXX %ld isn't right on Win64 */ - (long)i, Py_TYPE(obj)->tp_name); + "sequence item %zd: expected bytes, bytearray, " + "or an object with a buffer interface, %.80s found", + i, Py_TYPE(item)->tp_name); goto error; } - if (i > 0) - totalsize += mysize; - totalsize += Py_SIZE(obj); - if (totalsize < 0) { - PyErr_NoMemory(); + nbufs = i + 1; /* for error cleanup */ + itemlen = buffers[i].len; + if (itemlen > PY_SSIZE_T_MAX - sz) { + PyErr_SetString(PyExc_OverflowError, + "join() result is too long for bytes"); goto error; } + sz += itemlen; + if (i != 0) { + if (seplen > PY_SSIZE_T_MAX - sz) { + PyErr_SetString(PyExc_OverflowError, + "join() result is too long for bytes"); + goto error; + } + sz += seplen; + } } - /* Allocate the result, and copy the bytes */ - result = PyByteArray_FromStringAndSize(NULL, totalsize); - if (result == NULL) + /* Allocate result space. */ + res = PyByteArray_FromStringAndSize((char *) NULL, sz); + if (res == NULL) goto error; - dest = PyByteArray_AS_STRING(result); - for (i = 0; i < n; i++) { - PyObject *obj = items[i]; - Py_ssize_t size = Py_SIZE(obj); - char *buf; - if (PyByteArray_Check(obj)) - buf = PyByteArray_AS_STRING(obj); - else - buf = PyBytes_AS_STRING(obj); + + /* Catenate everything. */ + p = PyByteArray_AS_STRING(res); + for (i = 0; i < seqlen; ++i) { + Py_ssize_t n; + char *q; if (i) { - memcpy(dest, self->ob_bytes, mysize); - dest += mysize; + Py_MEMCPY(p, sep, seplen); + p += seplen; } - memcpy(dest, buf, size); - dest += size; + n = buffers[i].len; + q = buffers[i].buf; + Py_MEMCPY(p, q, n); + p += n; } - - /* Done */ + goto done; + +error: + res = NULL; +done: Py_DECREF(seq); - return result; - - /* Error handling */ - error: - Py_DECREF(seq); - return NULL; + for (i = 0; i < nbufs; i++) + PyBuffer_Release(&buffers[i]); + if (buffers != static_buffers) + PyMem_FREE(buffers); + return res; } + PyDoc_STRVAR(splitlines__doc__, "B.splitlines([keepends]) -> list of lines\n\ \n\ diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -10,9 +10,18 @@ static Py_ssize_t _getbuffer(PyObject *obj, Py_buffer *view) { - PyBufferProcs *buffer = Py_TYPE(obj)->tp_as_buffer; - - if (buffer == NULL || buffer->bf_getbuffer == NULL) + PyBufferProcs *bufferprocs; + if (PyBytes_CheckExact(obj)) { + /* Fast path, e.g. for .join() of many bytes objects */ + Py_INCREF(obj); + view->obj = obj; + view->buf = PyBytes_AS_STRING(obj); + view->len = PyBytes_GET_SIZE(obj); + return view->len; + } + + bufferprocs = Py_TYPE(obj)->tp_as_buffer; + if (bufferprocs == NULL || bufferprocs->bf_getbuffer == NULL) { PyErr_Format(PyExc_TypeError, "Type %.100s doesn't support the buffer API", @@ -20,7 +29,7 @@ static Py_ssize_t return -1; } - if (buffer->bf_getbuffer(obj, view, PyBUF_SIMPLE) < 0) + if (bufferprocs->bf_getbuffer(obj, view, PyBUF_SIMPLE) < 0) return -1; return view->len; } @@ -1114,11 +1123,14 @@ bytes_join(PyObject *self, PyObject *ori PyObject *res = NULL; char *p; Py_ssize_t seqlen = 0; - size_t sz = 0; - Py_ssize_t i; + Py_ssize_t sz = 0; + Py_ssize_t i, nbufs; PyObject *seq, *item; - - seq = PySequence_Fast(orig, ""); + Py_buffer *buffers = NULL; +#define NB_STATIC_BUFFERS 10 + Py_buffer static_buffers[NB_STATIC_BUFFERS]; + + seq = PySequence_Fast(orig, "can only join an iterable"); if (seq == NULL) { return NULL; } @@ -1136,64 +1148,79 @@ bytes_join(PyObject *self, PyObject *ori return item; } } - - /* There are at least two things to join, or else we have a subclass - * of the builtin types in the sequence. - * Do a pre-pass to figure out the total amount of space we'll - * need (sz), and see whether all argument are bytes. - */ - /* XXX Shouldn't we use _getbuffer() on these items instead? */ - for (i = 0; i < seqlen; i++) { - const size_t old_sz = sz; - item = PySequence_Fast_GET_ITEM(seq, i); - if (!PyBytes_Check(item) && !PyByteArray_Check(item)) { - PyErr_Format(PyExc_TypeError, - "sequence item %zd: expected bytes," - " %.80s found", - i, Py_TYPE(item)->tp_name); - Py_DECREF(seq); - return NULL; - } - sz += Py_SIZE(item); - if (i != 0) - sz += seplen; - if (sz < old_sz || sz > PY_SSIZE_T_MAX) { - PyErr_SetString(PyExc_OverflowError, - "join() result is too long for bytes"); + if (seqlen > NB_STATIC_BUFFERS) { + buffers = PyMem_NEW(Py_buffer, seqlen); + if (buffers == NULL) { Py_DECREF(seq); return NULL; } } + else { + buffers = static_buffers; + } + + /* There are at least two things to join, or else we have a subclass + * of the builtin types in the sequence. + * Do a pre-pass to figure out the total amount of space we'll + * need (sz), and see whether all arguments are buffer-compatible. + */ + for (i = 0, nbufs = 0; i < seqlen; i++) { + Py_ssize_t itemlen; + item = PySequence_Fast_GET_ITEM(seq, i); + if (_getbuffer(item, &buffers[i]) < 0) { + PyErr_Format(PyExc_TypeError, + "sequence item %zd: expected bytes, bytearray, " + "or an object with a buffer interface, %.80s found", + i, Py_TYPE(item)->tp_name); + goto error; + } + nbufs = i + 1; /* for error cleanup */ + itemlen = buffers[i].len; + if (itemlen > PY_SSIZE_T_MAX - sz) { + PyErr_SetString(PyExc_OverflowError, + "join() result is too long for bytes"); + goto error; + } + sz += itemlen; + if (i != 0) { + if (seplen > PY_SSIZE_T_MAX - sz) { + PyErr_SetString(PyExc_OverflowError, + "join() result is too long for bytes"); + goto error; + } + sz += seplen; + } + } /* Allocate result space. */ - res = PyBytes_FromStringAndSize((char*)NULL, sz); - if (res == NULL) { - Py_DECREF(seq); - return NULL; - } + res = PyBytes_FromStringAndSize((char *) NULL, sz); + if (res == NULL) + goto error; /* Catenate everything. */ - /* I'm not worried about a PyByteArray item growing because there's - nowhere in this function where we release the GIL. */ p = PyBytes_AS_STRING(res); for (i = 0; i < seqlen; ++i) { - size_t n; + Py_ssize_t n; char *q; if (i) { Py_MEMCPY(p, sep, seplen); p += seplen; } - item = PySequence_Fast_GET_ITEM(seq, i); - n = Py_SIZE(item); - if (PyBytes_Check(item)) - q = PyBytes_AS_STRING(item); - else - q = PyByteArray_AS_STRING(item); + n = buffers[i].len; + q = buffers[i].buf; Py_MEMCPY(p, q, n); p += n; } - + goto done; + +error: + res = NULL; +done: Py_DECREF(seq); + for (i = 0; i < nbufs; i++) + PyBuffer_Release(&buffers[i]); + if (buffers != static_buffers) + PyMem_FREE(buffers); return res; }