Skip to content

Commit 178df79

Browse files
[mypyc] Rewrite CPyStr_Build using a simplification of _PyUnicode_JoinArray (#10762)
This makes specialized `format()` calls faster. Closes mypyc/mypyc#876.
1 parent 6eafc5e commit 178df79

File tree

5 files changed

+99
-12
lines changed

5 files changed

+99
-12
lines changed

mypyc/irbuild/specialize.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
)
2626
from mypyc.ir.rtypes import (
2727
RType, RTuple, str_rprimitive, list_rprimitive, dict_rprimitive, set_rprimitive,
28-
bool_rprimitive, is_dict_rprimitive, c_int_rprimitive, is_str_rprimitive
28+
bool_rprimitive, is_dict_rprimitive, c_int_rprimitive, is_str_rprimitive,
29+
c_pyssize_t_rprimitive
2930
)
3031
from mypyc.primitives.dict_ops import (
3132
dict_keys_op, dict_values_op, dict_items_op, dict_setdefault_spec_init_op
@@ -378,7 +379,7 @@ def translate_str_format(
378379

379380
# The first parameter is the number of PyObject* arguments that follow,
380381
# interleaved alternately from the two lists (literals and variables).
381-
result_list: List[Value] = [Integer(0, c_int_rprimitive)]
382+
result_list: List[Value] = [Integer(0, c_pyssize_t_rprimitive)]
382383
for a, b in zip(literals, variables):
383384
if a:
384385
result_list.append(builder.load_str(a))
@@ -393,7 +394,7 @@ def translate_str_format(
393394
if not variables and len(result_list) == 2:
394395
return result_list[1]
395396

396-
result_list[0] = Integer(len(result_list) - 1, c_int_rprimitive)
397+
result_list[0] = Integer(len(result_list) - 1, c_pyssize_t_rprimitive)
397398
return builder.call_c(str_build_op, result_list, expr.line)
398399
return None
399400

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
384384
// Str operations
385385

386386

387-
PyObject *CPyStr_Build(int len, ...);
387+
PyObject *CPyStr_Build(Py_ssize_t len, ...);
388388
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
389389
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
390390
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);

mypyc/lib-rt/str_ops.c

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,90 @@ PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
4343
}
4444
}
4545

46-
PyObject *CPyStr_Build(int len, ...) {
47-
int i;
46+
// A simplification of _PyUnicode_JoinArray() from CPython 3.9.6
47+
PyObject *CPyStr_Build(Py_ssize_t len, ...) {
48+
Py_ssize_t i;
4849
va_list args;
50+
51+
// Calculate the total amount of space and check
52+
// whether all components have the same kind.
53+
Py_ssize_t sz = 0;
54+
Py_UCS4 maxchar = 0;
55+
int use_memcpy = 1; // Use memcpy by default
56+
PyObject *last_obj = NULL;
57+
4958
va_start(args, len);
59+
for (i = 0; i < len; i++) {
60+
PyObject *item = va_arg(args, PyObject *);
61+
if (!PyUnicode_Check(item)) {
62+
PyErr_Format(PyExc_TypeError,
63+
"sequence item %zd: expected str instance,"
64+
" %.80s found",
65+
i, Py_TYPE(item)->tp_name);
66+
return NULL;
67+
}
68+
if (PyUnicode_READY(item) == -1)
69+
return NULL;
5070

51-
PyObject *res = PyUnicode_FromObject(va_arg(args, PyObject *));
52-
for (i = 1; i < len; i++) {
53-
PyObject *str = va_arg(args, PyObject *);
54-
PyUnicode_Append(&res, str);
55-
}
71+
size_t add_sz = PyUnicode_GET_LENGTH(item);
72+
Py_UCS4 item_maxchar = PyUnicode_MAX_CHAR_VALUE(item);
73+
maxchar = Py_MAX(maxchar, item_maxchar);
5674

75+
// Using size_t to avoid overflow during arithmetic calculation
76+
if (add_sz > (size_t)(PY_SSIZE_T_MAX - sz)) {
77+
PyErr_SetString(PyExc_OverflowError,
78+
"join() result is too long for a Python string");
79+
return NULL;
80+
}
81+
sz += add_sz;
82+
83+
// If these strings have different kind, we would call
84+
// _PyUnicode_FastCopyCharacters() in the following part.
85+
if (use_memcpy && last_obj != NULL) {
86+
if (PyUnicode_KIND(last_obj) != PyUnicode_KIND(item))
87+
use_memcpy = 0;
88+
}
89+
last_obj = item;
90+
}
5791
va_end(args);
92+
93+
// Construct the string
94+
PyObject *res = PyUnicode_New(sz, maxchar);
95+
if (res == NULL)
96+
return NULL;
97+
98+
if (use_memcpy) {
99+
unsigned char *res_data = PyUnicode_1BYTE_DATA(res);
100+
unsigned int kind = PyUnicode_KIND(res);
101+
102+
va_start(args, len);
103+
for (i = 0; i < len; ++i) {
104+
PyObject *item = va_arg(args, PyObject *);
105+
Py_ssize_t itemlen = PyUnicode_GET_LENGTH(item);
106+
if (itemlen != 0) {
107+
memcpy(res_data, PyUnicode_DATA(item), kind * itemlen);
108+
res_data += kind * itemlen;
109+
}
110+
}
111+
va_end(args);
112+
assert(res_data == PyUnicode_1BYTE_DATA(res) + kind * PyUnicode_GET_LENGTH(res));
113+
} else {
114+
Py_ssize_t res_offset = 0;
115+
116+
va_start(args, len);
117+
for (i = 0; i < len; ++i) {
118+
PyObject *item = va_arg(args, PyObject *);
119+
Py_ssize_t itemlen = PyUnicode_GET_LENGTH(item);
120+
if (itemlen != 0) {
121+
_PyUnicode_FastCopyCharacters(res, res_offset, item, 0, itemlen);
122+
res_offset += itemlen;
123+
}
124+
}
125+
va_end(args);
126+
assert(res_offset == PyUnicode_GET_LENGTH(res));
127+
}
128+
129+
assert(_PyUnicode_CheckConsistency(res, 1));
58130
return res;
59131
}
60132

mypyc/primitives/str_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646

4747
str_build_op = custom_op(
48-
arg_types=[c_int_rprimitive],
48+
arg_types=[c_pyssize_t_rprimitive],
4949
return_type=str_rprimitive,
5050
c_function_name='CPyStr_Build',
5151
error_kind=ERR_MAGIC,

mypyc/test-data/run-strings.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,10 @@ def test_fstring_python_doc() -> None:
271271
from typing import Tuple
272272

273273
def test_format_method_basics() -> None:
274+
x = str()
275+
assert 'x{}'.format(x) == 'x'
276+
assert 'ā{}'.format(x) == 'ā'
277+
assert '😀{}'.format(x) == '😀'
274278
assert ''.format() == ''
275279
assert 'abc'.format() == 'abc'
276280
assert '{}{}'.format(1, 2) == '12'
@@ -342,6 +346,16 @@ def test_format_method_args() -> None:
342346
assert format_kwargs(x=10, y=2, z=1) == 'c10d2'
343347
assert format_kwargs_self(x=10, y=2, z=1) == "{'x': 10, 'y': 2, 'z': 1}"
344348

349+
def test_format_method_different_kind() -> None:
350+
s1 = "Literal['😀']"
351+
assert 'Revealed type is {}'.format(s1) == "Revealed type is Literal['😀']"
352+
s2 = "Revealed type is"
353+
assert "{} Literal['😀']".format(s2) == "Revealed type is Literal['😀']"
354+
s3 = "测试:"
355+
assert "{}{} {}".format(s3, s2, s1) == "测试:Revealed type is Literal['😀']"
356+
assert "Test: {}{}".format(s3, s1) == "Test: 测试:Literal['😀']"
357+
assert "Test: {}{}".format(s3, s2) == "Test: 测试:Revealed type is"
358+
345359
class Point:
346360
def __init__(self, x, y):
347361
self.x, self.y = x, y

0 commit comments

Comments
 (0)