Skip to content

Commit 2532497

Browse files
Issue #18468: The re.split, re.findall, and re.sub functions and the group()
and groups() methods of match object now always return a string or a bytes object.
1 parent 355dda8 commit 2532497

File tree

3 files changed

+131
-65
lines changed

3 files changed

+131
-65
lines changed

Lib/test/test_re.py

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,26 @@
1717

1818
import unittest
1919

20+
class S(str):
21+
def __getitem__(self, index):
22+
return S(super().__getitem__(index))
23+
24+
class B(bytes):
25+
def __getitem__(self, index):
26+
return B(super().__getitem__(index))
27+
2028
class ReTests(unittest.TestCase):
2129

30+
def assertTypedEqual(self, actual, expect, msg=None):
31+
self.assertEqual(actual, expect, msg)
32+
def recurse(actual, expect):
33+
if isinstance(expect, (tuple, list)):
34+
for x, y in zip(actual, expect):
35+
recurse(x, y)
36+
else:
37+
self.assertIs(type(actual), type(expect), msg)
38+
recurse(actual, expect)
39+
2240
def test_keep_buffer(self):
2341
# See bug 14212
2442
b = bytearray(b'x')
@@ -53,6 +71,13 @@ def bump_num(self, matchobj):
5371
return str(int_value + 1)
5472

5573
def test_basic_re_sub(self):
74+
self.assertTypedEqual(re.sub('y', 'a', 'xyz'), 'xaz')
75+
self.assertTypedEqual(re.sub('y', S('a'), S('xyz')), 'xaz')
76+
self.assertTypedEqual(re.sub(b'y', b'a', b'xyz'), b'xaz')
77+
self.assertTypedEqual(re.sub(b'y', B(b'a'), B(b'xyz')), b'xaz')
78+
self.assertTypedEqual(re.sub(b'y', bytearray(b'a'), bytearray(b'xyz')), b'xaz')
79+
self.assertTypedEqual(re.sub(b'y', memoryview(b'a'), memoryview(b'xyz')), b'xaz')
80+
5681
self.assertEqual(re.sub("(?i)b+", "x", "bbbb BBBB"), 'x x')
5782
self.assertEqual(re.sub(r'\d+', self.bump_num, '08.2 -2 23x99y'),
5883
'9.3 -3 24x100y')
@@ -210,10 +235,22 @@ def test_re_subn(self):
210235
self.assertEqual(re.subn("b*", "x", "xyz", 2), ('xxxyz', 2))
211236

212237
def test_re_split(self):
213-
self.assertEqual(re.split(":", ":a:b::c"), ['', 'a', 'b', '', 'c'])
214-
self.assertEqual(re.split(":*", ":a:b::c"), ['', 'a', 'b', 'c'])
215-
self.assertEqual(re.split("(:*)", ":a:b::c"),
216-
['', ':', 'a', ':', 'b', '::', 'c'])
238+
for string in ":a:b::c", S(":a:b::c"):
239+
self.assertTypedEqual(re.split(":", string),
240+
['', 'a', 'b', '', 'c'])
241+
self.assertTypedEqual(re.split(":*", string),
242+
['', 'a', 'b', 'c'])
243+
self.assertTypedEqual(re.split("(:*)", string),
244+
['', ':', 'a', ':', 'b', '::', 'c'])
245+
for string in (b":a:b::c", B(b":a:b::c"), bytearray(b":a:b::c"),
246+
memoryview(b":a:b::c")):
247+
self.assertTypedEqual(re.split(b":", string),
248+
[b'', b'a', b'b', b'', b'c'])
249+
self.assertTypedEqual(re.split(b":*", string),
250+
[b'', b'a', b'b', b'c'])
251+
self.assertTypedEqual(re.split(b"(:*)", string),
252+
[b'', b':', b'a', b':', b'b', b'::', b'c'])
253+
217254
self.assertEqual(re.split("(?::*)", ":a:b::c"), ['', 'a', 'b', 'c'])
218255
self.assertEqual(re.split("(:)*", ":a:b::c"),
219256
['', ':', 'a', ':', 'b', ':', 'c'])
@@ -235,22 +272,39 @@ def test_qualified_re_split(self):
235272

236273
def test_re_findall(self):
237274
self.assertEqual(re.findall(":+", "abc"), [])
238-
self.assertEqual(re.findall(":+", "a:b::c:::d"), [":", "::", ":::"])
239-
self.assertEqual(re.findall("(:+)", "a:b::c:::d"), [":", "::", ":::"])
240-
self.assertEqual(re.findall("(:)(:*)", "a:b::c:::d"), [(":", ""),
241-
(":", ":"),
242-
(":", "::")])
275+
for string in "a:b::c:::d", S("a:b::c:::d"):
276+
self.assertTypedEqual(re.findall(":+", string),
277+
[":", "::", ":::"])
278+
self.assertTypedEqual(re.findall("(:+)", string),
279+
[":", "::", ":::"])
280+
self.assertTypedEqual(re.findall("(:)(:*)", string),
281+
[(":", ""), (":", ":"), (":", "::")])
282+
for string in (b"a:b::c:::d", B(b"a:b::c:::d"), bytearray(b"a:b::c:::d"),
283+
memoryview(b"a:b::c:::d")):
284+
self.assertTypedEqual(re.findall(b":+", string),
285+
[b":", b"::", b":::"])
286+
self.assertTypedEqual(re.findall(b"(:+)", string),
287+
[b":", b"::", b":::"])
288+
self.assertTypedEqual(re.findall(b"(:)(:*)", string),
289+
[(b":", b""), (b":", b":"), (b":", b"::")])
243290

244291
def test_bug_117612(self):
245292
self.assertEqual(re.findall(r"(a|(b))", "aba"),
246293
[("a", ""),("b", "b"),("a", "")])
247294

248295
def test_re_match(self):
249-
self.assertEqual(re.match('a', 'a').groups(), ())
250-
self.assertEqual(re.match('(a)', 'a').groups(), ('a',))
251-
self.assertEqual(re.match(r'(a)', 'a').group(0), 'a')
252-
self.assertEqual(re.match(r'(a)', 'a').group(1), 'a')
253-
self.assertEqual(re.match(r'(a)', 'a').group(1, 1), ('a', 'a'))
296+
for string in 'a', S('a'):
297+
self.assertEqual(re.match('a', string).groups(), ())
298+
self.assertEqual(re.match('(a)', string).groups(), ('a',))
299+
self.assertEqual(re.match('(a)', string).group(0), 'a')
300+
self.assertEqual(re.match('(a)', string).group(1), 'a')
301+
self.assertEqual(re.match('(a)', string).group(1, 1), ('a', 'a'))
302+
for string in b'a', B(b'a'), bytearray(b'a'), memoryview(b'a'):
303+
self.assertEqual(re.match(b'a', string).groups(), ())
304+
self.assertEqual(re.match(b'(a)', string).groups(), (b'a',))
305+
self.assertEqual(re.match(b'(a)', string).group(0), b'a')
306+
self.assertEqual(re.match(b'(a)', string).group(1), b'a')
307+
self.assertEqual(re.match(b'(a)', string).group(1, 1), (b'a', b'a'))
254308

255309
pat = re.compile('((a)|(b))(c)?')
256310
self.assertEqual(pat.match('a').groups(), ('a', 'a', None, None))

Misc/NEWS

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ Core and Builtins
4242
Library
4343
-------
4444

45+
- Issue #18468: The re.split, re.findall, and re.sub functions and the group()
46+
and groups() methods of match object now always return a string or a bytes
47+
object.
48+
4549
- Issue #18725: The textwrap module now supports truncating multiline text.
4650

4751
- Issue #18776: atexit callbacks now display their full traceback when they

Modules/_sre.c

Lines changed: 59 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,6 +1811,24 @@ state_fini(SRE_STATE* state)
18111811
#define STATE_OFFSET(state, member)\
18121812
(((char*)(member) - (char*)(state)->beginning) / (state)->charsize)
18131813

1814+
LOCAL(PyObject*)
1815+
getslice(int logical_charsize, const void *ptr,
1816+
PyObject* string, Py_ssize_t start, Py_ssize_t end)
1817+
{
1818+
if (logical_charsize == 1) {
1819+
if (PyBytes_CheckExact(string) &&
1820+
start == 0 && end == PyBytes_GET_SIZE(string)) {
1821+
Py_INCREF(string);
1822+
return string;
1823+
}
1824+
return PyBytes_FromStringAndSize(
1825+
(const char *)ptr + start, end - start);
1826+
}
1827+
else {
1828+
return PyUnicode_Substring(string, start, end);
1829+
}
1830+
}
1831+
18141832
LOCAL(PyObject*)
18151833
state_getslice(SRE_STATE* state, Py_ssize_t index, PyObject* string, int empty)
18161834
{
@@ -1831,7 +1849,7 @@ state_getslice(SRE_STATE* state, Py_ssize_t index, PyObject* string, int empty)
18311849
j = STATE_OFFSET(state, state->mark[index+1]);
18321850
}
18331851

1834-
return PySequence_GetSlice(string, i, j);
1852+
return getslice(state->logical_charsize, state->beginning, string, i, j);
18351853
}
18361854

18371855
static void
@@ -1992,45 +2010,6 @@ deepcopy(PyObject** object, PyObject* memo)
19922010
}
19932011
#endif
19942012

1995-
static PyObject*
1996-
join_list(PyObject* list, PyObject* string)
1997-
{
1998-
/* join list elements */
1999-
2000-
PyObject* joiner;
2001-
PyObject* function;
2002-
PyObject* args;
2003-
PyObject* result;
2004-
2005-
joiner = PySequence_GetSlice(string, 0, 0);
2006-
if (!joiner)
2007-
return NULL;
2008-
2009-
if (PyList_GET_SIZE(list) == 0) {
2010-
Py_DECREF(list);
2011-
return joiner;
2012-
}
2013-
2014-
function = PyObject_GetAttrString(joiner, "join");
2015-
if (!function) {
2016-
Py_DECREF(joiner);
2017-
return NULL;
2018-
}
2019-
args = PyTuple_New(1);
2020-
if (!args) {
2021-
Py_DECREF(function);
2022-
Py_DECREF(joiner);
2023-
return NULL;
2024-
}
2025-
PyTuple_SET_ITEM(args, 0, list);
2026-
result = PyObject_CallObject(function, args);
2027-
Py_DECREF(args); /* also removes list */
2028-
Py_DECREF(function);
2029-
Py_DECREF(joiner);
2030-
2031-
return result;
2032-
}
2033-
20342013
static PyObject*
20352014
pattern_findall(PatternObject* self, PyObject* args, PyObject* kw)
20362015
{
@@ -2086,7 +2065,8 @@ pattern_findall(PatternObject* self, PyObject* args, PyObject* kw)
20862065
case 0:
20872066
b = STATE_OFFSET(&state, state.start);
20882067
e = STATE_OFFSET(&state, state.ptr);
2089-
item = PySequence_GetSlice(string, b, e);
2068+
item = getslice(state.logical_charsize, state.beginning,
2069+
string, b, e);
20902070
if (!item)
20912071
goto error;
20922072
break;
@@ -2216,7 +2196,7 @@ pattern_split(PatternObject* self, PyObject* args, PyObject* kw)
22162196
}
22172197

22182198
/* get segment before this match */
2219-
item = PySequence_GetSlice(
2199+
item = getslice(state.logical_charsize, state.beginning,
22202200
string, STATE_OFFSET(&state, last),
22212201
STATE_OFFSET(&state, state.start)
22222202
);
@@ -2245,7 +2225,7 @@ pattern_split(PatternObject* self, PyObject* args, PyObject* kw)
22452225
}
22462226

22472227
/* get segment following last match (even if empty) */
2248-
item = PySequence_GetSlice(
2228+
item = getslice(state.logical_charsize, state.beginning,
22492229
string, STATE_OFFSET(&state, last), state.endpos
22502230
);
22512231
if (!item)
@@ -2271,6 +2251,7 @@ pattern_subx(PatternObject* self, PyObject* ptemplate, PyObject* string,
22712251
{
22722252
SRE_STATE state;
22732253
PyObject* list;
2254+
PyObject* joiner;
22742255
PyObject* item;
22752256
PyObject* filter;
22762257
PyObject* args;
@@ -2360,7 +2341,8 @@ pattern_subx(PatternObject* self, PyObject* ptemplate, PyObject* string,
23602341

23612342
if (i < b) {
23622343
/* get segment before this match */
2363-
item = PySequence_GetSlice(string, i, b);
2344+
item = getslice(state.logical_charsize, state.beginning,
2345+
string, i, b);
23642346
if (!item)
23652347
goto error;
23662348
status = PyList_Append(list, item);
@@ -2415,7 +2397,8 @@ pattern_subx(PatternObject* self, PyObject* ptemplate, PyObject* string,
24152397

24162398
/* get segment following last match */
24172399
if (i < state.endpos) {
2418-
item = PySequence_GetSlice(string, i, state.endpos);
2400+
item = getslice(state.logical_charsize, state.beginning,
2401+
string, i, state.endpos);
24192402
if (!item)
24202403
goto error;
24212404
status = PyList_Append(list, item);
@@ -2429,10 +2412,24 @@ pattern_subx(PatternObject* self, PyObject* ptemplate, PyObject* string,
24292412
Py_DECREF(filter);
24302413

24312414
/* convert list to single string (also removes list) */
2432-
item = join_list(list, string);
2433-
2434-
if (!item)
2415+
joiner = getslice(state.logical_charsize, state.beginning, string, 0, 0);
2416+
if (!joiner) {
2417+
Py_DECREF(list);
24352418
return NULL;
2419+
}
2420+
if (PyList_GET_SIZE(list) == 0) {
2421+
Py_DECREF(list);
2422+
item = joiner;
2423+
}
2424+
else {
2425+
if (state.logical_charsize == 1)
2426+
item = _PyBytes_Join(joiner, list);
2427+
else
2428+
item = PyUnicode_Join(joiner, list);
2429+
Py_DECREF(joiner);
2430+
if (!item)
2431+
return NULL;
2432+
}
24362433

24372434
if (subn)
24382435
return Py_BuildValue("Nn", item, n);
@@ -3189,6 +3186,12 @@ match_dealloc(MatchObject* self)
31893186
static PyObject*
31903187
match_getslice_by_index(MatchObject* self, Py_ssize_t index, PyObject* def)
31913188
{
3189+
Py_ssize_t length;
3190+
int logical_charsize, charsize;
3191+
Py_buffer view;
3192+
PyObject *result;
3193+
void* ptr;
3194+
31923195
if (index < 0 || index >= self->groups) {
31933196
/* raise IndexError if we were given a bad group number */
31943197
PyErr_SetString(
@@ -3206,9 +3209,14 @@ match_getslice_by_index(MatchObject* self, Py_ssize_t index, PyObject* def)
32063209
return def;
32073210
}
32083211

3209-
return PySequence_GetSlice(
3210-
self->string, self->mark[index], self->mark[index+1]
3211-
);
3212+
ptr = getstring(self->string, &length, &logical_charsize, &charsize, &view);
3213+
if (ptr == NULL)
3214+
return NULL;
3215+
result = getslice(logical_charsize, ptr,
3216+
self->string, self->mark[index], self->mark[index+1]);
3217+
if (logical_charsize == 1 && view.buf != NULL)
3218+
PyBuffer_Release(&view);
3219+
return result;
32123220
}
32133221

32143222
static Py_ssize_t

0 commit comments

Comments
 (0)