Skip to content

Commit 8ce9f16

Browse files
committed
PyUnicode_Join(): Two primary aims:
1. u1.join([u2]) is u2 2. Be more careful about C-level int overflow. Since PySequence_Fast() isn't needed to achieve #1, it's not used -- but the code could sure be simpler if it were.
1 parent 00f8da7 commit 8ce9f16

File tree

1 file changed

+120
-40
lines changed

1 file changed

+120
-40
lines changed

Objects/unicodeobject.c

Lines changed: 120 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3975,49 +3975,110 @@ int fixtitle(PyUnicodeObject *self)
39753975
return 1;
39763976
}
39773977

3978-
PyObject *PyUnicode_Join(PyObject *separator,
3979-
PyObject *seq)
3978+
PyObject *
3979+
PyUnicode_Join(PyObject *separator, PyObject *seq)
39803980
{
3981+
PyObject *internal_separator = NULL;
39813982
Py_UNICODE *sep;
3982-
int seplen;
3983+
size_t seplen;
39833984
PyUnicodeObject *res = NULL;
3984-
int reslen = 0;
3985-
Py_UNICODE *p;
3986-
int sz = 100;
3985+
size_t sz; /* # allocated bytes for string in res */
3986+
size_t reslen; /* # used bytes */
3987+
Py_UNICODE *p; /* pointer to free byte in res's string area */
3988+
PyObject *it; /* iterator */
3989+
PyObject *item;
39873990
int i;
3988-
PyObject *it;
3991+
PyObject *temp;
39893992

39903993
it = PyObject_GetIter(seq);
39913994
if (it == NULL)
39923995
return NULL;
39933996

3997+
item = PyIter_Next(it);
3998+
if (item == NULL) {
3999+
if (PyErr_Occurred())
4000+
goto onError;
4001+
/* empty sequence; return u"" */
4002+
res = _PyUnicode_New(0);
4003+
goto Done;
4004+
}
4005+
4006+
/* If this is the only item, maybe we can get out cheap. */
4007+
res = (PyUnicodeObject *)item;
4008+
item = PyIter_Next(it);
4009+
if (item == NULL) {
4010+
if (PyErr_Occurred())
4011+
goto onError;
4012+
/* There's only one item in the sequence. */
4013+
if (PyUnicode_CheckExact(res)) /* whatever.join([u]) -> u */
4014+
goto Done;
4015+
}
4016+
4017+
/* There are at least two to join (item != NULL), or there's only
4018+
* one but it's not an exact Unicode (item == NULL). res needs
4019+
* conversion to Unicode in either case.
4020+
* Caution: we may need to ensure a copy is made, and that's trickier
4021+
* than it sounds because, e.g., PyUnicode_FromObject() may return
4022+
* a shared object (which must not be mutated).
4023+
*/
4024+
if (! PyUnicode_Check(res) && ! PyString_Check(res)) {
4025+
PyErr_Format(PyExc_TypeError,
4026+
"sequence item 0: expected string or Unicode,"
4027+
" %.80s found",
4028+
res->ob_type->tp_name);
4029+
Py_XDECREF(item);
4030+
goto onError;
4031+
}
4032+
temp = PyUnicode_FromObject((PyObject *)res);
4033+
if (temp == NULL) {
4034+
Py_XDECREF(item);
4035+
goto onError;
4036+
}
4037+
Py_DECREF(res);
4038+
if (item == NULL) {
4039+
/* res was the only item */
4040+
res = (PyUnicodeObject *)temp;
4041+
goto Done;
4042+
}
4043+
/* There are at least two items. As above, temp may be a shared object,
4044+
* so we need to copy it.
4045+
*/
4046+
reslen = PyUnicode_GET_SIZE(temp);
4047+
sz = reslen + 100; /* breathing room */
4048+
if (sz < reslen || sz > INT_MAX) /* overflow -- no breathing room */
4049+
sz = reslen;
4050+
res = _PyUnicode_New(sz);
4051+
if (res == NULL) {
4052+
Py_DECREF(item);
4053+
goto onError;
4054+
}
4055+
p = PyUnicode_AS_UNICODE(res);
4056+
Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(temp), (int)reslen);
4057+
p += reslen;
4058+
Py_DECREF(temp);
4059+
39944060
if (separator == NULL) {
39954061
Py_UNICODE blank = ' ';
39964062
sep = &blank;
39974063
seplen = 1;
39984064
}
39994065
else {
4000-
separator = PyUnicode_FromObject(separator);
4001-
if (separator == NULL)
4066+
internal_separator = PyUnicode_FromObject(separator);
4067+
if (internal_separator == NULL) {
4068+
Py_DECREF(item);
40024069
goto onError;
4003-
sep = PyUnicode_AS_UNICODE(separator);
4004-
seplen = PyUnicode_GET_SIZE(separator);
4070+
}
4071+
sep = PyUnicode_AS_UNICODE(internal_separator);
4072+
seplen = PyUnicode_GET_SIZE(internal_separator);
40054073
}
40064074

4007-
res = _PyUnicode_New(sz);
4008-
if (res == NULL)
4009-
goto onError;
4010-
p = PyUnicode_AS_UNICODE(res);
4011-
reslen = 0;
4075+
i = 1;
4076+
do {
4077+
size_t itemlen;
4078+
size_t newreslen;
40124079

4013-
for (i = 0; ; ++i) {
4014-
int itemlen;
4015-
PyObject *item = PyIter_Next(it);
4016-
if (item == NULL) {
4017-
if (PyErr_Occurred())
4018-
goto onError;
4019-
break;
4020-
}
4080+
/* Catenate the separator, then item. */
4081+
/* First convert item to Unicode. */
40214082
if (!PyUnicode_Check(item)) {
40224083
PyObject *v;
40234084
if (!PyString_Check(item)) {
@@ -4034,36 +4095,55 @@ PyObject *PyUnicode_Join(PyObject *separator,
40344095
if (item == NULL)
40354096
goto onError;
40364097
}
4098+
/* Make sure we have enough space for the separator and the item. */
40374099
itemlen = PyUnicode_GET_SIZE(item);
4038-
while (reslen + itemlen + seplen >= sz) {
4039-
if (_PyUnicode_Resize(&res, sz*2) < 0) {
4100+
newreslen = reslen + seplen + itemlen;
4101+
if (newreslen < reslen || newreslen > INT_MAX)
4102+
goto Overflow;
4103+
if (newreslen > sz) {
4104+
do {
4105+
size_t oldsize = sz;
4106+
sz += sz;
4107+
if (sz < oldsize || sz > INT_MAX)
4108+
goto Overflow;
4109+
} while (newreslen > sz);
4110+
if (_PyUnicode_Resize(&res, (int)sz) < 0) {
40404111
Py_DECREF(item);
40414112
goto onError;
40424113
}
4043-
sz *= 2;
4044-
p = PyUnicode_AS_UNICODE(res) + reslen;
4045-
}
4046-
if (i > 0) {
4047-
Py_UNICODE_COPY(p, sep, seplen);
4048-
p += seplen;
4049-
reslen += seplen;
4114+
p = PyUnicode_AS_UNICODE(res) + reslen;
40504115
}
4051-
Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(item), itemlen);
4116+
Py_UNICODE_COPY(p, sep, (int)seplen);
4117+
p += seplen;
4118+
Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(item), (int)itemlen);
40524119
p += itemlen;
4053-
reslen += itemlen;
40544120
Py_DECREF(item);
4055-
}
4056-
if (_PyUnicode_Resize(&res, reslen) < 0)
4121+
reslen = newreslen;
4122+
4123+
++i;
4124+
item = PyIter_Next(it);
4125+
} while (item != NULL);
4126+
if (PyErr_Occurred())
40574127
goto onError;
40584128

4059-
Py_XDECREF(separator);
4129+
if (_PyUnicode_Resize(&res, (int)reslen) < 0)
4130+
goto onError;
4131+
4132+
Done:
4133+
Py_XDECREF(internal_separator);
40604134
Py_DECREF(it);
40614135
return (PyObject *)res;
40624136

4137+
Overflow:
4138+
PyErr_SetString(PyExc_OverflowError,
4139+
"join() is too long for a Python string");
4140+
Py_DECREF(item);
4141+
/* fall through */
4142+
40634143
onError:
4064-
Py_XDECREF(separator);
4065-
Py_XDECREF(res);
4144+
Py_XDECREF(internal_separator);
40664145
Py_DECREF(it);
4146+
Py_XDECREF(res);
40674147
return NULL;
40684148
}
40694149

0 commit comments

Comments
 (0)