diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -288,8 +288,22 @@ class BaseBytesTest(unittest.TestCase): self.assertEqual(self.type2test(b"").join(lst), b"abc") self.assertEqual(self.type2test(b"").join(tuple(lst)), b"abc") self.assertEqual(self.type2test(b"").join(iter(lst)), b"abc") - self.assertEqual(self.type2test(b".").join([b"ab", b"cd"]), b"ab.cd") - # XXX more... + dot_join = self.type2test(b".:").join + self.assertEqual(dot_join([b"ab", b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([memoryview(b"ab"), b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([b"ab", memoryview(b"cd")]), b"ab.:cd") + self.assertEqual(dot_join([bytearray(b"ab"), b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([b"ab", bytearray(b"cd")]), b"ab.:cd") + # Stress it with many items + seq = [b"abc"] * 1000 + expected = b"abc" + b".:abc" * 999 + self.assertEqual(dot_join(seq), expected) + # Error handling and cleanup when some item in the middle of the + # sequence has the wrong type. + with self.assertRaises(TypeError): + dot_join([bytearray(b"ab"), "cd", b"ef"]) + with self.assertRaises(TypeError): + dot_join([memoryview(b"ab"), "cd", b"ef"]) def test_count(self): b = self.type2test(b'mississippi') @@ -1249,6 +1263,11 @@ class BytearrayPEP3137Test(unittest.Test self.assertEqual(val, newval) self.assertTrue(val is not newval, expr+' returned val on a mutable object') + sep = self.marshal(b'') + newval = sep.join([val]) + self.assertEqual(val, newval) + self.assertIsNot(val, newval) + class FixedStringTest(test.string_tests.BaseTest): diff --git a/Objects/bytearrayobject.c b/Objects/bytearrayobject.c --- a/Objects/bytearrayobject.c +++ b/Objects/bytearrayobject.c @@ -1032,6 +1032,7 @@ bytearray_dealloc(PyByteArrayObject *sel #define FASTSEARCH fastsearch #define STRINGLIB(F) stringlib_##F #define STRINGLIB_CHAR char +#define STRINGLIB_SIZEOF_CHAR 1 #define STRINGLIB_LEN PyByteArray_GET_SIZE #define STRINGLIB_STR PyByteArray_AS_STRING #define STRINGLIB_NEW PyByteArray_FromStringAndSize @@ -1043,6 +1044,7 @@ bytearray_dealloc(PyByteArrayObject *sel #include "stringlib/fastsearch.h" #include "stringlib/count.h" #include "stringlib/find.h" +#include "stringlib/join.h" #include "stringlib/partition.h" #include "stringlib/split.h" #include "stringlib/ctype.h" @@ -2569,73 +2571,9 @@ Concatenate any number of bytes/bytearra in between each pair, and return the result as a new bytearray."); static PyObject * -bytearray_join(PyByteArrayObject *self, PyObject *it) +bytearray_join(PyObject *self, PyObject *iterable) { - PyObject *seq; - Py_ssize_t mysize = Py_SIZE(self); - Py_ssize_t i; - Py_ssize_t n; - PyObject **items; - Py_ssize_t totalsize = 0; - PyObject *result; - char *dest; - - seq = PySequence_Fast(it, "can only join an iterable"); - if (seq == NULL) - return NULL; - n = PySequence_Fast_GET_SIZE(seq); - items = PySequence_Fast_ITEMS(seq); - - /* Compute the total size, and check that they are all bytes */ - /* XXX Shouldn't we use _getbuffer() on these items instead? */ - for (i = 0; i < n; i++) { - PyObject *obj = items[i]; - if (!PyByteArray_Check(obj) && !PyBytes_Check(obj)) { - PyErr_Format(PyExc_TypeError, - "can only join an iterable of bytes " - "(item %ld has type '%.100s')", - /* XXX %ld isn't right on Win64 */ - (long)i, Py_TYPE(obj)->tp_name); - goto error; - } - if (i > 0) - totalsize += mysize; - totalsize += Py_SIZE(obj); - if (totalsize < 0) { - PyErr_NoMemory(); - goto error; - } - } - - /* Allocate the result, and copy the bytes */ - result = PyByteArray_FromStringAndSize(NULL, totalsize); - if (result == NULL) - goto error; - dest = PyByteArray_AS_STRING(result); - for (i = 0; i < n; i++) { - PyObject *obj = items[i]; - Py_ssize_t size = Py_SIZE(obj); - char *buf; - if (PyByteArray_Check(obj)) - buf = PyByteArray_AS_STRING(obj); - else - buf = PyBytes_AS_STRING(obj); - if (i) { - memcpy(dest, self->ob_bytes, mysize); - dest += mysize; - } - memcpy(dest, buf, size); - dest += size; - } - - /* Done */ - Py_DECREF(seq); - return result; - - /* Error handling */ - error: - Py_DECREF(seq); - return NULL; + return stringlib_bytes_join(self, iterable); } PyDoc_STRVAR(splitlines__doc__, diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -10,9 +10,18 @@ static Py_ssize_t _getbuffer(PyObject *obj, Py_buffer *view) { - PyBufferProcs *buffer = Py_TYPE(obj)->tp_as_buffer; + PyBufferProcs *bufferprocs; + if (PyBytes_CheckExact(obj)) { + /* Fast path, e.g. for .join() of many bytes objects */ + Py_INCREF(obj); + view->obj = obj; + view->buf = PyBytes_AS_STRING(obj); + view->len = PyBytes_GET_SIZE(obj); + return view->len; + } - if (buffer == NULL || buffer->bf_getbuffer == NULL) + bufferprocs = Py_TYPE(obj)->tp_as_buffer; + if (bufferprocs == NULL || bufferprocs->bf_getbuffer == NULL) { PyErr_Format(PyExc_TypeError, "Type %.100s doesn't support the buffer API", @@ -20,7 +29,7 @@ static Py_ssize_t return -1; } - if (buffer->bf_getbuffer(obj, view, PyBUF_SIMPLE) < 0) + if (bufferprocs->bf_getbuffer(obj, view, PyBUF_SIMPLE) < 0) return -1; return view->len; } @@ -555,6 +564,7 @@ PyBytes_AsStringAndSize(register PyObjec #include "stringlib/fastsearch.h" #include "stringlib/count.h" #include "stringlib/find.h" +#include "stringlib/join.h" #include "stringlib/partition.h" #include "stringlib/split.h" #include "stringlib/ctype.h" @@ -1107,94 +1117,9 @@ Concatenate any number of bytes objects, Example: b'.'.join([b'ab', b'pq', b'rs']) -> b'ab.pq.rs'."); static PyObject * -bytes_join(PyObject *self, PyObject *orig) +bytes_join(PyObject *self, PyObject *iterable) { - char *sep = PyBytes_AS_STRING(self); - const Py_ssize_t seplen = PyBytes_GET_SIZE(self); - PyObject *res = NULL; - char *p; - Py_ssize_t seqlen = 0; - size_t sz = 0; - Py_ssize_t i; - PyObject *seq, *item; - - seq = PySequence_Fast(orig, ""); - if (seq == NULL) { - return NULL; - } - - seqlen = PySequence_Size(seq); - if (seqlen == 0) { - Py_DECREF(seq); - return PyBytes_FromString(""); - } - if (seqlen == 1) { - item = PySequence_Fast_GET_ITEM(seq, 0); - if (PyBytes_CheckExact(item)) { - Py_INCREF(item); - Py_DECREF(seq); - return item; - } - } - - /* There are at least two things to join, or else we have a subclass - * of the builtin types in the sequence. - * Do a pre-pass to figure out the total amount of space we'll - * need (sz), and see whether all argument are bytes. - */ - /* XXX Shouldn't we use _getbuffer() on these items instead? */ - for (i = 0; i < seqlen; i++) { - const size_t old_sz = sz; - item = PySequence_Fast_GET_ITEM(seq, i); - if (!PyBytes_Check(item) && !PyByteArray_Check(item)) { - PyErr_Format(PyExc_TypeError, - "sequence item %zd: expected bytes," - " %.80s found", - i, Py_TYPE(item)->tp_name); - Py_DECREF(seq); - return NULL; - } - sz += Py_SIZE(item); - if (i != 0) - sz += seplen; - if (sz < old_sz || sz > PY_SSIZE_T_MAX) { - PyErr_SetString(PyExc_OverflowError, - "join() result is too long for bytes"); - Py_DECREF(seq); - return NULL; - } - } - - /* Allocate result space. */ - res = PyBytes_FromStringAndSize((char*)NULL, sz); - if (res == NULL) { - Py_DECREF(seq); - return NULL; - } - - /* Catenate everything. */ - /* I'm not worried about a PyByteArray item growing because there's - nowhere in this function where we release the GIL. */ - p = PyBytes_AS_STRING(res); - for (i = 0; i < seqlen; ++i) { - size_t n; - char *q; - if (i) { - Py_MEMCPY(p, sep, seplen); - p += seplen; - } - item = PySequence_Fast_GET_ITEM(seq, i); - n = Py_SIZE(item); - if (PyBytes_Check(item)) - q = PyBytes_AS_STRING(item); - else - q = PyByteArray_AS_STRING(item); - Py_MEMCPY(p, q, n); - p += n; - } - - Py_DECREF(seq); - return res; + return stringlib_bytes_join(self, iterable); } PyObject * diff --git a/Objects/stringlib/join.h b/Objects/stringlib/join.h new file mode 100644 --- /dev/null +++ b/Objects/stringlib/join.h @@ -0,0 +1,122 @@ +/* stringlib: bytes joining implementation */ + +#if STRINGLIB_SIZEOF_CHAR != 1 +#error join.h only compatible with byte-wise strings +#endif + +Py_LOCAL_INLINE(PyObject *) +STRINGLIB(bytes_join)(PyObject *sep, PyObject *iterable) +{ + char *sepstr = STRINGLIB_STR(sep); + const Py_ssize_t seplen = STRINGLIB_LEN(sep); + PyObject *res = NULL; + char *p; + Py_ssize_t seqlen = 0; + Py_ssize_t sz = 0; + Py_ssize_t i, nbufs; + PyObject *seq, *item; + Py_buffer *buffers = NULL; +#define NB_STATIC_BUFFERS 10 + Py_buffer static_buffers[NB_STATIC_BUFFERS]; + + seq = PySequence_Fast(iterable, "can only join an iterable"); + if (seq == NULL) { + return NULL; + } + + seqlen = PySequence_Fast_GET_SIZE(seq); + if (seqlen == 0) { + Py_DECREF(seq); + return STRINGLIB_NEW(NULL, 0); + } +#ifndef STRINGLIB_MUTABLE + if (seqlen == 1) { + item = PySequence_Fast_GET_ITEM(seq, 0); + if (STRINGLIB_CHECK_EXACT(item)) { + Py_INCREF(item); + Py_DECREF(seq); + return item; + } + } +#endif + if (seqlen > NB_STATIC_BUFFERS) { + buffers = PyMem_NEW(Py_buffer, seqlen); + if (buffers == NULL) { + Py_DECREF(seq); + return NULL; + } + } + else { + buffers = static_buffers; + } + + /* Here is the general case. Do a pre-pass to figure out the total + * amount of space we'll need (sz), and see whether all arguments are + * buffer-compatible. + */ + for (i = 0, nbufs = 0; i < seqlen; i++) { + Py_ssize_t itemlen; + item = PySequence_Fast_GET_ITEM(seq, i); + if (_getbuffer(item, &buffers[i]) < 0) { + PyErr_Format(PyExc_TypeError, + "sequence item %zd: expected bytes, bytearray, " + "or an object with the buffer interface, %.80s found", + i, Py_TYPE(item)->tp_name); + goto error; + } + nbufs = i + 1; /* for error cleanup */ + itemlen = buffers[i].len; + if (itemlen > PY_SSIZE_T_MAX - sz) { + PyErr_SetString(PyExc_OverflowError, + "join() result is too long"); + goto error; + } + sz += itemlen; + if (i != 0) { + if (seplen > PY_SSIZE_T_MAX - sz) { + PyErr_SetString(PyExc_OverflowError, + "join() result is too long"); + goto error; + } + sz += seplen; + } + if (seqlen != PySequence_Fast_GET_SIZE(seq)) { + PyErr_SetString(PyExc_RuntimeError, + "sequence changed size during iteration"); + goto error; + } + } + + /* Allocate result space. */ + res = STRINGLIB_NEW(NULL, sz); + if (res == NULL) + goto error; + + /* Catenate everything. */ + p = STRINGLIB_STR(res); + for (i = 0; i < nbufs; i++) { + Py_ssize_t n; + char *q; + if (i) { + Py_MEMCPY(p, sepstr, seplen); + p += seplen; + } + n = buffers[i].len; + q = buffers[i].buf; + Py_MEMCPY(p, q, n); + p += n; + } + goto done; + +error: + res = NULL; +done: + Py_DECREF(seq); + for (i = 0; i < nbufs; i++) + PyBuffer_Release(&buffers[i]); + if (buffers != static_buffers) + PyMem_FREE(buffers); + return res; +} + +