diff --git a/Doc/library/stdtypes.rst b/Doc/library/stdtypes.rst index c72b63e..84b654a 100644 --- a/Doc/library/stdtypes.rst +++ b/Doc/library/stdtypes.rst @@ -1795,6 +1795,11 @@ strings, with the exception of :func:`encode`, :func:`format` and :func:`isidentifier`, which do not make sense with these types. For converting the objects to strings, they have a :func:`decode` method. +The functions :func:`count`, :func:`find`, :func:`index`, +:func:`rfind` and :func:`rindex` have additional semantics compared to +the corresponding string functions: They also accept an integer in +range 0 to 255 (a byte) as their first argument. + Wherever one of these methods needs to interpret the bytes as characters (e.g. the :func:`is...` methods), the ASCII character set is assumed. diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index fca38c3..5a02b7c 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -28,6 +28,11 @@ class BaseTest(unittest.TestCase): # Change in subclasses to change the behaviour of fixtesttype() type2test = None + # Whether the "contained items" of the container are integers in + # range(0, 256) (i.e. bytes, bytearray) or strings of length 1 + # (str) + contains_bytes = False + # All tests pass their arguments to the testing methods # as str objects. fixtesttype() can be used to propagate # these arguments to the appropriate type @@ -116,7 +121,11 @@ class BaseTest(unittest.TestCase): self.checkequal(0, '', 'count', 'xx', sys.maxsize, 0) self.checkraises(TypeError, 'hello', 'count') - self.checkraises(TypeError, 'hello', 'count', 42) + + if self.contains_bytes: + self.checkequal(0, 'hello', 'count', 42) + else: + self.checkraises(TypeError, 'hello', 'count', 42) # For a variety of combinations, # verify that str.count() matches an equivalent function @@ -162,7 +171,11 @@ class BaseTest(unittest.TestCase): self.checkequal( 2, 'rrarrrrrrrrra', 'find', 'a', None, 6) self.checkraises(TypeError, 'hello', 'find') - self.checkraises(TypeError, 'hello', 'find', 42) + + if self.contains_bytes: + self.checkequal(-1, 'hello', 'find', 42) + else: + self.checkraises(TypeError, 'hello', 'find', 42) self.checkequal(0, '', 'find', '') self.checkequal(-1, '', 'find', '', 1, 1) @@ -216,7 +229,11 @@ class BaseTest(unittest.TestCase): self.checkequal( 2, 'rrarrrrrrrrra', 'rfind', 'a', None, 6) self.checkraises(TypeError, 'hello', 'rfind') - self.checkraises(TypeError, 'hello', 'rfind', 42) + + if self.contains_bytes: + self.checkequal(-1, 'hello', 'rfind', 42) + else: + self.checkraises(TypeError, 'hello', 'rfind', 42) # For a variety of combinations, # verify that str.rfind() matches __contains__ @@ -263,7 +280,11 @@ class BaseTest(unittest.TestCase): self.checkequal( 2, 'rrarrrrrrrrra', 'index', 'a', None, 6) self.checkraises(TypeError, 'hello', 'index') - self.checkraises(TypeError, 'hello', 'index', 42) + + if self.contains_bytes: + self.checkraises(ValueError, 'hello', 'index', 42) + else: + self.checkraises(TypeError, 'hello', 'index', 42) def test_rindex(self): self.checkequal(12, 'abcdefghiabc', 'rindex', '') @@ -285,7 +306,11 @@ class BaseTest(unittest.TestCase): self.checkequal( 2, 'rrarrrrrrrrra', 'rindex', 'a', None, 6) self.checkraises(TypeError, 'hello', 'rindex') - self.checkraises(TypeError, 'hello', 'rindex', 42) + + if self.contains_bytes: + self.checkraises(ValueError, 'hello', 'rindex', 42) + else: + self.checkraises(TypeError, 'hello', 'rindex', 42) def test_lower(self): self.checkequal('hello', 'HeLLo', 'lower') diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 234b56c..9224a27 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -293,10 +293,27 @@ class BaseBytesTest(unittest.TestCase): def test_count(self): b = self.type2test(b'mississippi') + i = 105 + p = 112 + w = 119 + self.assertEqual(b.count(b'i'), 4) self.assertEqual(b.count(b'ss'), 2) self.assertEqual(b.count(b'w'), 0) + self.assertEqual(b.count(i), 4) + self.assertEqual(b.count(w), 0) + + self.assertEqual(b.count(b'i', 6), 2) + self.assertEqual(b.count(b'p', 6), 2) + self.assertEqual(b.count(b'i', 1, 3), 1) + self.assertEqual(b.count(b'p', 7, 9), 1) + + self.assertEqual(b.count(i, 6), 2) + self.assertEqual(b.count(p, 6), 2) + self.assertEqual(b.count(i, 1, 3), 1) + self.assertEqual(b.count(p, 7, 9), 1) + def test_startswith(self): b = self.type2test(b'hello') self.assertFalse(self.type2test().startswith(b"anything")) @@ -327,35 +344,81 @@ class BaseBytesTest(unittest.TestCase): def test_find(self): b = self.type2test(b'mississippi') + i = 105 + w = 119 + self.assertEqual(b.find(b'ss'), 2) + self.assertEqual(b.find(b'w'), -1) + self.assertEqual(b.find(b'mississippian'), -1) + + self.assertEqual(b.find(i), 1) + self.assertEqual(b.find(w), -1) + self.assertEqual(b.find(b'ss', 3), 5) self.assertEqual(b.find(b'ss', 1, 7), 2) self.assertEqual(b.find(b'ss', 1, 3), -1) - self.assertEqual(b.find(b'w'), -1) - self.assertEqual(b.find(b'mississippian'), -1) + + self.assertEqual(b.find(i, 6), 7) + self.assertEqual(b.find(i, 1, 3), 1) + self.assertEqual(b.find(w, 1, 3), -1) def test_rfind(self): b = self.type2test(b'mississippi') + i = 105 + w = 119 + self.assertEqual(b.rfind(b'ss'), 5) - self.assertEqual(b.rfind(b'ss', 3), 5) - self.assertEqual(b.rfind(b'ss', 0, 6), 2) self.assertEqual(b.rfind(b'w'), -1) self.assertEqual(b.rfind(b'mississippian'), -1) + self.assertEqual(b.rfind(i), 10) + self.assertEqual(b.rfind(w), -1) + + self.assertEqual(b.rfind(b'ss', 3), 5) + self.assertEqual(b.rfind(b'ss', 0, 6), 2) + + self.assertEqual(b.rfind(i, 1, 3), 1) + self.assertEqual(b.rfind(i, 3, 9), 7) + self.assertEqual(b.rfind(w, 1, 3), -1) + def test_index(self): - b = self.type2test(b'world') - self.assertEqual(b.index(b'w'), 0) - self.assertEqual(b.index(b'orl'), 1) - self.assertRaises(ValueError, b.index, b'worm') - self.assertRaises(ValueError, b.index, b'ldo') + b = self.type2test(b'mississippi') + i = 105 + w = 119 + + self.assertEqual(b.index(b'ss'), 2) + self.assertRaises(ValueError, b.index, b'w') + self.assertRaises(ValueError, b.index, b'mississippian') + + self.assertEqual(b.index(i), 1) + self.assertRaises(ValueError, b.index, w) + + self.assertEqual(b.index(b'ss', 3), 5) + self.assertEqual(b.index(b'ss', 1, 7), 2) + self.assertRaises(ValueError, b.index, b'ss', 1, 3) + + self.assertEqual(b.index(i, 6), 7) + self.assertEqual(b.index(i, 1, 3), 1) + self.assertRaises(ValueError, b.index, w, 1, 3) def test_rindex(self): - # XXX could be more rigorous - b = self.type2test(b'world') - self.assertEqual(b.rindex(b'w'), 0) - self.assertEqual(b.rindex(b'orl'), 1) - self.assertRaises(ValueError, b.rindex, b'worm') - self.assertRaises(ValueError, b.rindex, b'ldo') + b = self.type2test(b'mississippi') + i = 105 + w = 119 + + self.assertEqual(b.rindex(b'ss'), 5) + self.assertRaises(ValueError, b.rindex, b'w') + self.assertRaises(ValueError, b.rindex, b'mississippian') + + self.assertEqual(b.rindex(i), 10) + self.assertRaises(ValueError, b.rindex, w) + + self.assertEqual(b.rindex(b'ss', 3), 5) + self.assertEqual(b.rindex(b'ss', 0, 6), 2) + + self.assertEqual(b.rindex(i, 1, 3), 1) + self.assertEqual(b.rindex(i, 3, 9), 7) + self.assertRaises(ValueError, b.rindex, w, 1, 3) def test_replace(self): b = self.type2test(b'mississippi') @@ -531,6 +594,14 @@ class BaseBytesTest(unittest.TestCase): self.assertEqual(True, b.startswith(h, None, -2)) self.assertEqual(False, b.startswith(x, None, None)) + def test_integer_arguments_out_of_byte_range(self): + b = self.type2test(b'hello') + + for method in (b.count, b.find, b.index, b.rfind, b.rindex): + self.assertRaises(ValueError, method, -1) + self.assertRaises(ValueError, method, 256) + self.assertRaises(ValueError, method, 9999) + def test_find_etc_raise_correct_error_messages(self): # issue 11828 b = self.type2test(b'hello') @@ -1140,9 +1211,11 @@ class FixedStringTest(test.string_tests.BaseTest): class ByteArrayAsStringTest(FixedStringTest): type2test = bytearray + contains_bytes = True class BytesAsStringTest(FixedStringTest): type2test = bytes + contains_bytes = True class SubclassTest(unittest.TestCase): diff --git a/Objects/bytearrayobject.c b/Objects/bytearrayobject.c index fba5758..99822d9 100644 --- a/Objects/bytearrayobject.c +++ b/Objects/bytearrayobject.c @@ -1077,24 +1077,41 @@ Py_LOCAL_INLINE(Py_ssize_t) bytearray_find_internal(PyByteArrayObject *self, PyObject *args, int dir) { PyObject *subobj; + char byte; Py_buffer subbuf; + const char *sub; + Py_ssize_t sub_len; Py_ssize_t start=0, end=PY_SSIZE_T_MAX; Py_ssize_t res; - if (!stringlib_parse_args_finds("find/rfind/index/rindex", - args, &subobj, &start, &end)) - return -2; - if (_getbuffer(subobj, &subbuf) < 0) + if (!stringlib_parse_args_finds_byte("find/rfind/index/rindex", + args, &subobj, &byte, &start, &end)) return -2; + + if (subobj) { + if (_getbuffer(subobj, &subbuf) < 0) + return -2; + + sub = subbuf.buf; + sub_len = subbuf.len; + } + else { + sub = &byte; + sub_len = 1; + } + if (dir > 0) res = stringlib_find_slice( PyByteArray_AS_STRING(self), PyByteArray_GET_SIZE(self), - subbuf.buf, subbuf.len, start, end); + sub, sub_len, start, end); else res = stringlib_rfind_slice( PyByteArray_AS_STRING(self), PyByteArray_GET_SIZE(self), - subbuf.buf, subbuf.len, start, end); - PyBuffer_Release(&subbuf); + sub, sub_len, start, end); + + if (subobj) + PyBuffer_Release(&subbuf); + return res; } @@ -1127,23 +1144,39 @@ static PyObject * bytearray_count(PyByteArrayObject *self, PyObject *args) { PyObject *sub_obj; - const char *str = PyByteArray_AS_STRING(self); + const char *str = PyByteArray_AS_STRING(self), *sub; + Py_ssize_t sub_len; + char byte; Py_ssize_t start = 0, end = PY_SSIZE_T_MAX; + Py_buffer vsub; PyObject *count_obj; - if (!stringlib_parse_args_finds("count", args, &sub_obj, &start, &end)) + if (!stringlib_parse_args_finds_byte("count", args, &sub_obj, &byte, + &start, &end)) return NULL; - if (_getbuffer(sub_obj, &vsub) < 0) - return NULL; + if (sub_obj) { + if (_getbuffer(sub_obj, &vsub) < 0) + return NULL; + + sub = vsub.buf; + sub_len = vsub.len; + } + else { + sub = &byte; + sub_len = 1; + } ADJUST_INDICES(start, end, PyByteArray_GET_SIZE(self)); count_obj = PyLong_FromSsize_t( - stringlib_count(str + start, end - start, vsub.buf, vsub.len, PY_SSIZE_T_MAX) + stringlib_count(str + start, end - start, sub, sub_len, PY_SSIZE_T_MAX) ); - PyBuffer_Release(&vsub); + + if (sub_obj) + PyBuffer_Release(&vsub); + return count_obj; } diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c index ea14be6..decd7ef 100644 --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -1237,31 +1237,42 @@ Py_LOCAL_INLINE(Py_ssize_t) bytes_find_internal(PyBytesObject *self, PyObject *args, int dir) { PyObject *subobj; + char byte; + Py_buffer subbuf; const char *sub; Py_ssize_t sub_len; Py_ssize_t start=0, end=PY_SSIZE_T_MAX; + Py_ssize_t res; - if (!stringlib_parse_args_finds("find/rfind/index/rindex", - args, &subobj, &start, &end)) + if (!stringlib_parse_args_finds_byte("find/rfind/index/rindex", + args, &subobj, &byte, &start, &end)) return -2; - if (PyBytes_Check(subobj)) { - sub = PyBytes_AS_STRING(subobj); - sub_len = PyBytes_GET_SIZE(subobj); + if (subobj) { + if (_getbuffer(subobj, &subbuf) < 0) + return -2; + + sub = subbuf.buf; + sub_len = subbuf.len; + } + else { + sub = &byte; + sub_len = 1; } - else if (PyObject_AsCharBuffer(subobj, &sub, &sub_len)) - /* XXX - the "expected a character buffer object" is pretty - confusing for a non-expert. remap to something else ? */ - return -2; if (dir > 0) - return stringlib_find_slice( + res = stringlib_find_slice( PyBytes_AS_STRING(self), PyBytes_GET_SIZE(self), sub, sub_len, start, end); else - return stringlib_rfind_slice( + res = stringlib_rfind_slice( PyBytes_AS_STRING(self), PyBytes_GET_SIZE(self), sub, sub_len, start, end); + + if (subobj) + PyBuffer_Release(&subbuf); + + return res; } @@ -1487,23 +1498,38 @@ bytes_count(PyBytesObject *self, PyObject *args) PyObject *sub_obj; const char *str = PyBytes_AS_STRING(self), *sub; Py_ssize_t sub_len; + char byte; Py_ssize_t start = 0, end = PY_SSIZE_T_MAX; - if (!stringlib_parse_args_finds("count", args, &sub_obj, &start, &end)) + Py_buffer vsub; + PyObject *count_obj; + + if (!stringlib_parse_args_finds_byte("count", args, &sub_obj, &byte, + &start, &end)) return NULL; - if (PyBytes_Check(sub_obj)) { - sub = PyBytes_AS_STRING(sub_obj); - sub_len = PyBytes_GET_SIZE(sub_obj); + if (sub_obj) { + if (_getbuffer(sub_obj, &vsub) < 0) + return NULL; + + sub = vsub.buf; + sub_len = vsub.len; + } + else { + sub = &byte; + sub_len = 1; } - else if (PyObject_AsCharBuffer(sub_obj, &sub, &sub_len)) - return NULL; ADJUST_INDICES(start, end, PyBytes_GET_SIZE(self)); - return PyLong_FromSsize_t( + count_obj = PyLong_FromSsize_t( stringlib_count(str + start, end - start, sub, sub_len, PY_SSIZE_T_MAX) ); + + if (sub_obj) + PyBuffer_Release(&vsub); + + return count_obj; } diff --git a/Objects/stringlib/find.h b/Objects/stringlib/find.h index ce615dc..2c47f0f 100644 --- a/Objects/stringlib/find.h +++ b/Objects/stringlib/find.h @@ -140,6 +140,47 @@ stringlib_parse_args_finds(const char * function_name, PyObject *args, #undef FORMAT_BUFFER_SIZE +/* +Wraps stringlib_parse_args_finds() and additionally checks whether the +first argument is an integer in range(0, 256). + +If this is the case, writes the integer value to the byte parameter +and sets subobj to NULL. Otherwise, sets the first argument to subobj +and doesn't touch byte. The other parameters are similar to those of +stringlib_parse_args_finds(). +*/ + +Py_LOCAL_INLINE(int) +stringlib_parse_args_finds_byte(const char *function_name, PyObject *args, + PyObject **subobj, char *byte, + Py_ssize_t *start, Py_ssize_t *end) +{ + PyObject *tmp_subobj; + Py_ssize_t ival; + + if(!stringlib_parse_args_finds(function_name, args, &tmp_subobj, + start, end)) + return 0; + + ival = PyNumber_AsSsize_t(tmp_subobj, PyExc_ValueError); + if (ival == -1 && PyErr_Occurred()) { + PyErr_Clear(); + *subobj = tmp_subobj; + } + else { + /* The first argument was an integer */ + if(ival < 0 || ival > 255) { + PyErr_SetString(PyExc_ValueError, "byte must be in range(0, 256)"); + return 0; + } + + *subobj = NULL; + *byte = (char)ival; + } + + return 1; +} + #if STRINGLIB_IS_UNICODE /*