diff -r 220b34cbd711 Lib/test/test_re.py --- a/Lib/test/test_re.py Sun Sep 29 18:59:27 2013 -0700 +++ b/Lib/test/test_re.py Tue Oct 01 22:25:20 2013 +0300 @@ -17,8 +17,26 @@ import unittest +class S(str): + def __getitem__(self, index): + return S(super().__getitem__(index)) + +class B(bytes): + def __getitem__(self, index): + return B(super().__getitem__(index)) + class ReTests(unittest.TestCase): + def assertTypedEqual(self, actual, expect, msg=None): + self.assertEqual(actual, expect, msg) + def recurse(actual, expect): + if isinstance(expect, (tuple, list)): + for x, y in zip(actual, expect): + recurse(x, y) + else: + self.assertIs(type(actual), type(expect), msg) + recurse(actual, expect) + def test_keep_buffer(self): # See bug 14212 b = bytearray(b'x') @@ -53,6 +71,13 @@ return str(int_value + 1) def test_basic_re_sub(self): + self.assertTypedEqual(re.sub('y', 'a', 'xyz'), 'xaz') + self.assertTypedEqual(re.sub('y', S('a'), S('xyz')), 'xaz') + self.assertTypedEqual(re.sub(b'y', b'a', b'xyz'), b'xaz') + self.assertTypedEqual(re.sub(b'y', B(b'a'), B(b'xyz')), b'xaz') + self.assertTypedEqual(re.sub(b'y', bytearray(b'a'), bytearray(b'xyz')), b'xaz') + self.assertTypedEqual(re.sub(b'y', memoryview(b'a'), memoryview(b'xyz')), b'xaz') + self.assertEqual(re.sub("(?i)b+", "x", "bbbb BBBB"), 'x x') self.assertEqual(re.sub(r'\d+', self.bump_num, '08.2 -2 23x99y'), '9.3 -3 24x100y') @@ -210,10 +235,22 @@ self.assertEqual(re.subn("b*", "x", "xyz", 2), ('xxxyz', 2)) def test_re_split(self): - self.assertEqual(re.split(":", ":a:b::c"), ['', 'a', 'b', '', 'c']) - self.assertEqual(re.split(":*", ":a:b::c"), ['', 'a', 'b', 'c']) - self.assertEqual(re.split("(:*)", ":a:b::c"), - ['', ':', 'a', ':', 'b', '::', 'c']) + for string in ":a:b::c", S(":a:b::c"): + self.assertTypedEqual(re.split(":", string), + ['', 'a', 'b', '', 'c']) + self.assertTypedEqual(re.split(":*", string), + ['', 'a', 'b', 'c']) + self.assertTypedEqual(re.split("(:*)", string), + ['', ':', 'a', ':', 'b', '::', 'c']) + for string in (b":a:b::c", B(b":a:b::c"), bytearray(b":a:b::c"), + memoryview(b":a:b::c")): + self.assertTypedEqual(re.split(b":", string), + [b'', b'a', b'b', b'', b'c']) + self.assertTypedEqual(re.split(b":*", string), + [b'', b'a', b'b', b'c']) + self.assertTypedEqual(re.split(b"(:*)", string), + [b'', b':', b'a', b':', b'b', b'::', b'c']) + self.assertEqual(re.split("(?::*)", ":a:b::c"), ['', 'a', 'b', 'c']) self.assertEqual(re.split("(:)*", ":a:b::c"), ['', ':', 'a', ':', 'b', ':', 'c']) @@ -235,22 +272,39 @@ def test_re_findall(self): self.assertEqual(re.findall(":+", "abc"), []) - self.assertEqual(re.findall(":+", "a:b::c:::d"), [":", "::", ":::"]) - self.assertEqual(re.findall("(:+)", "a:b::c:::d"), [":", "::", ":::"]) - self.assertEqual(re.findall("(:)(:*)", "a:b::c:::d"), [(":", ""), - (":", ":"), - (":", "::")]) + for string in "a:b::c:::d", S("a:b::c:::d"): + self.assertTypedEqual(re.findall(":+", string), + [":", "::", ":::"]) + self.assertTypedEqual(re.findall("(:+)", string), + [":", "::", ":::"]) + self.assertTypedEqual(re.findall("(:)(:*)", string), + [(":", ""), (":", ":"), (":", "::")]) + for string in (b"a:b::c:::d", B(b"a:b::c:::d"), bytearray(b"a:b::c:::d"), + memoryview(b"a:b::c:::d")): + self.assertTypedEqual(re.findall(b":+", string), + [b":", b"::", b":::"]) + self.assertTypedEqual(re.findall(b"(:+)", string), + [b":", b"::", b":::"]) + self.assertTypedEqual(re.findall(b"(:)(:*)", string), + [(b":", b""), (b":", b":"), (b":", b"::")]) def test_bug_117612(self): self.assertEqual(re.findall(r"(a|(b))", "aba"), [("a", ""),("b", "b"),("a", "")]) def test_re_match(self): - self.assertEqual(re.match('a', 'a').groups(), ()) - self.assertEqual(re.match('(a)', 'a').groups(), ('a',)) - self.assertEqual(re.match(r'(a)', 'a').group(0), 'a') - self.assertEqual(re.match(r'(a)', 'a').group(1), 'a') - self.assertEqual(re.match(r'(a)', 'a').group(1, 1), ('a', 'a')) + for string in 'a', S('a'): + self.assertEqual(re.match('a', string).groups(), ()) + self.assertEqual(re.match('(a)', string).groups(), ('a',)) + self.assertEqual(re.match(r'(a)', string).group(0), 'a') + self.assertEqual(re.match(r'(a)', string).group(1), 'a') + self.assertEqual(re.match(r'(a)', string).group(1, 1), ('a', 'a')) + for string in b'a', bytearray(b'a'), memoryview(b'a'): + self.assertEqual(re.match(b'a', string).groups(), ()) + self.assertEqual(re.match(b'(a)', string).groups(), (b'a',)) + self.assertEqual(re.match(b'(a)', string).group(0), b'a') + self.assertEqual(re.match(b'(a)', string).group(1), b'a') + self.assertEqual(re.match(b'(a)', string).group(1, 1), (b'a', b'a')) pat = re.compile('((a)|(b))(c)?') self.assertEqual(pat.match('a').groups(), ('a', 'a', None, None)) diff -r 220b34cbd711 Modules/_sre.c --- a/Modules/_sre.c Sun Sep 29 18:59:27 2013 -0700 +++ b/Modules/_sre.c Tue Oct 01 22:25:20 2013 +0300 @@ -1812,6 +1812,24 @@ (((char*)(member) - (char*)(state)->beginning) / (state)->charsize) LOCAL(PyObject*) +getslice(int logical_charsize, const void *ptr, + PyObject* string, Py_ssize_t start, Py_ssize_t end) +{ + if (logical_charsize == 1) { + if (PyBytes_CheckExact(string) && + start == 0 && end == PyBytes_GET_SIZE(string)) { + Py_INCREF(string); + return string; + } + return PyBytes_FromStringAndSize( + (const char *)ptr + start, end - start); + } + else { + return PyUnicode_Substring(string, start, end); + } +} + +LOCAL(PyObject*) state_getslice(SRE_STATE* state, Py_ssize_t index, PyObject* string, int empty) { Py_ssize_t i, j; @@ -1831,7 +1849,7 @@ j = STATE_OFFSET(state, state->mark[index+1]); } - return PySequence_GetSlice(string, i, j); + return getslice(state->logical_charsize, state->beginning, string, i, j); } static void @@ -1993,45 +2011,6 @@ #endif static PyObject* -join_list(PyObject* list, PyObject* string) -{ - /* join list elements */ - - PyObject* joiner; - PyObject* function; - PyObject* args; - PyObject* result; - - joiner = PySequence_GetSlice(string, 0, 0); - if (!joiner) - return NULL; - - if (PyList_GET_SIZE(list) == 0) { - Py_DECREF(list); - return joiner; - } - - function = PyObject_GetAttrString(joiner, "join"); - if (!function) { - Py_DECREF(joiner); - return NULL; - } - args = PyTuple_New(1); - if (!args) { - Py_DECREF(function); - Py_DECREF(joiner); - return NULL; - } - PyTuple_SET_ITEM(args, 0, list); - result = PyObject_CallObject(function, args); - Py_DECREF(args); /* also removes list */ - Py_DECREF(function); - Py_DECREF(joiner); - - return result; -} - -static PyObject* pattern_findall(PatternObject* self, PyObject* args, PyObject* kw) { SRE_STATE state; @@ -2086,7 +2065,8 @@ case 0: b = STATE_OFFSET(&state, state.start); e = STATE_OFFSET(&state, state.ptr); - item = PySequence_GetSlice(string, b, e); + item = getslice(state.logical_charsize, state.beginning, + string, b, e); if (!item) goto error; break; @@ -2216,7 +2196,7 @@ } /* get segment before this match */ - item = PySequence_GetSlice( + item = getslice(state.logical_charsize, state.beginning, string, STATE_OFFSET(&state, last), STATE_OFFSET(&state, state.start) ); @@ -2245,7 +2225,7 @@ } /* get segment following last match (even if empty) */ - item = PySequence_GetSlice( + item = getslice(state.logical_charsize, state.beginning, string, STATE_OFFSET(&state, last), state.endpos ); if (!item) @@ -2271,6 +2251,7 @@ { SRE_STATE state; PyObject* list; + PyObject* joiner; PyObject* item; PyObject* filter; PyObject* args; @@ -2360,7 +2341,8 @@ if (i < b) { /* get segment before this match */ - item = PySequence_GetSlice(string, i, b); + item = getslice(state.logical_charsize, state.beginning, + string, i, b); if (!item) goto error; status = PyList_Append(list, item); @@ -2415,7 +2397,8 @@ /* get segment following last match */ if (i < state.endpos) { - item = PySequence_GetSlice(string, i, state.endpos); + item = getslice(state.logical_charsize, state.beginning, + string, i, state.endpos); if (!item) goto error; status = PyList_Append(list, item); @@ -2429,10 +2412,21 @@ Py_DECREF(filter); /* convert list to single string (also removes list) */ - item = join_list(list, string); - - if (!item) + joiner = getslice(state.logical_charsize, state.beginning, string, 0, 0); + if (!joiner) { + Py_DECREF(list); return NULL; + } + if (PyList_GET_SIZE(list) == 0) { + Py_DECREF(list); + item = joiner; + } + else { + item = PyObject_CallMethod(joiner, "join", "N", list); + Py_DECREF(joiner); + if (!item) + return NULL; + } if (subn) return Py_BuildValue("Nn", item, n); @@ -3189,6 +3183,12 @@ static PyObject* match_getslice_by_index(MatchObject* self, Py_ssize_t index, PyObject* def) { + Py_ssize_t length; + int logical_charsize, charsize; + Py_buffer view; + PyObject *result; + void* ptr; + if (index < 0 || index >= self->groups) { /* raise IndexError if we were given a bad group number */ PyErr_SetString( @@ -3206,9 +3206,14 @@ return def; } - return PySequence_GetSlice( - self->string, self->mark[index], self->mark[index+1] - ); + ptr = getstring(self->string, &length, &logical_charsize, &charsize, &view); + if (ptr == NULL) + return NULL; + result = getslice(logical_charsize, ptr, + self->string, self->mark[index], self->mark[index+1]); + if (logical_charsize == 1 && view.buf != NULL) + PyBuffer_Release(&view); + return result; } static Py_ssize_t