Skip to content

bpo-34481: Fix surrogate-handling in strftime #8983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
49 changes: 37 additions & 12 deletions Lib/test/datetimetester.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ def test_tzname(self):
self.assertEqual('UTC+09:30', timezone(9.5 * HOUR).tzname(None))
self.assertEqual('UTC-00:01', timezone(timedelta(minutes=-1)).tzname(None))
self.assertEqual('XYZ', timezone(-5 * HOUR, 'XYZ').tzname(None))
# bpo-34482: Check that surrogates are handled properly.
self.assertEqual('\ud800', timezone(ZERO, '\ud800').tzname(None))

# Sub-minute offsets:
self.assertEqual('UTC+01:06:40', timezone(timedelta(0, 4000)).tzname(None))
Expand Down Expand Up @@ -1307,6 +1309,12 @@ def test_strftime(self):
except ValueError:
pass

# bpo-34482: Check that surrogates don't cause a crash.
no_op_strs = ['\ud8888', '\ud800', '\\ud800\ud800']
for tstr in no_op_strs:
with self.subTest(f'strftime: {tstr!r}'):
self.assertEqual(t.strftime(tstr), tstr)

#check that this standard extension works
t.strftime("%f")

Expand Down Expand Up @@ -1746,6 +1754,9 @@ def test_isoformat(self):
self.assertEqual(t.isoformat('T'), "0001-02-03T04:05:01.000123")
self.assertEqual(t.isoformat(' '), "0001-02-03 04:05:01.000123")
self.assertEqual(t.isoformat('\x00'), "0001-02-03\x0004:05:01.000123")
# bpo-34482: Check that surrogates are handled properly.
self.assertEqual(t.isoformat('\ud800'),
"0001-02-03\ud80004:05:01.000123")
self.assertEqual(t.isoformat(timespec='hours'), "0001-02-03T04")
self.assertEqual(t.isoformat(timespec='minutes'), "0001-02-03T04:05")
self.assertEqual(t.isoformat(timespec='seconds'), "0001-02-03T04:05:01")
Expand All @@ -1754,6 +1765,8 @@ def test_isoformat(self):
self.assertEqual(t.isoformat(timespec='auto'), "0001-02-03T04:05:01.000123")
self.assertEqual(t.isoformat(sep=' ', timespec='minutes'), "0001-02-03 04:05")
self.assertRaises(ValueError, t.isoformat, timespec='foo')
# bpo-34482: Check that surrogates are handled properly.
self.assertRaises(ValueError, t.isoformat, timespec='\ud800')
# str is ISO format with the separator forced to a blank.
self.assertEqual(str(t), "0001-02-03 04:05:01.000123")

Expand Down Expand Up @@ -2277,13 +2290,21 @@ def test_utcnow(self):
self.assertLessEqual(abs(from_timestamp - from_now), tolerance)

def test_strptime(self):
string = '2004-12-01 13:02:47.197'
format = '%Y-%m-%d %H:%M:%S.%f'
expected = _strptime._strptime_datetime(self.theclass, string, format)
got = self.theclass.strptime(string, format)
self.assertEqual(expected, got)
self.assertIs(type(expected), self.theclass)
self.assertIs(type(got), self.theclass)
inputs = [
('2004-12-01 13:02:47.197', '%Y-%m-%d %H:%M:%S.%f'),
# bpo-34482: Check that surrogates are handled properly.
('2004-12-01\ud80013:02:47.197', '%Y-%m-%d\ud800%H:%M:%S.%f'),
('2004\ud80012-01 13:02:47.197', '%Y\ud800%m-%d %H:%M:%S.%f'),
('2004-12-01 13:02\ud80047.197', '%Y-%m-%d %H:%M\ud800%S.%f'),
]
for string, format in inputs:
with self.subTest(string=string, format=format):
expected = _strptime._strptime_datetime(self.theclass, string,
format)
got = self.theclass.strptime(string, format)
self.assertEqual(expected, got)
self.assertIs(type(expected), self.theclass)
self.assertIs(type(got), self.theclass)

strptime = self.theclass.strptime
self.assertEqual(strptime("+0002", "%z").utcoffset(), 2 * MINUTE)
Expand Down Expand Up @@ -2869,6 +2890,8 @@ def test_isoformat(self):
self.assertEqual(t.isoformat(timespec='microseconds'), "12:34:56.123456")
self.assertEqual(t.isoformat(timespec='auto'), "12:34:56.123456")
self.assertRaises(ValueError, t.isoformat, timespec='monkey')
# bpo-34482: Check that surrogates are handled properly.
self.assertRaises(ValueError, t.isoformat, timespec='\ud800')

t = self.theclass(hour=12, minute=34, second=56, microsecond=999500)
self.assertEqual(t.isoformat(timespec='milliseconds'), "12:34:56.999")
Expand Down Expand Up @@ -2919,6 +2942,13 @@ def test_strftime(self):
# A naive object replaces %z and %Z with empty strings.
self.assertEqual(t.strftime("'%z' '%Z'"), "'' ''")

# bpo-34482: Check that surrogates don't cause a crash.
self.assertEqual(t.strftime('\ud800'), '\ud800')

# bpo-34482: Check that surrogates in tzinfo don't crash
tzinfo = timezone(timedelta(hours=1), '\ud800')
t.replace(tzinfo=tzinfo).strftime("%Z")

def test_format(self):
t = self.theclass(1, 2, 3, 4)
self.assertEqual(t.__format__(''), str(t))
Expand Down Expand Up @@ -3324,11 +3354,6 @@ def tzname(self, dt): return self.tz
self.assertEqual(t.strftime("%H:%M:%S"), "02:03:04")
self.assertRaises(TypeError, t.strftime, "%Z")

# Issue #6697:
if '_Fast' in self.__class__.__name__:
Badtzname.tz = '\ud800'
self.assertRaises(ValueError, t.strftime, "%Z")

def test_hash_edge_cases(self):
# Offsets that overflow a basic time.
t1 = self.theclass(0, 1, 2, 3, tzinfo=FixedOffset(1439, ""))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add tests for proper handling of non-UTF-8-encodable strings in
:mod:`datetime` classes. Patch by Alexey Izbyshev.
60 changes: 53 additions & 7 deletions Modules/_datetimemodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -1475,8 +1475,23 @@ wrap_strftime(PyObject *object, PyObject *format, PyObject *timetuple,
assert(PyUnicode_Check(format));
/* Convert the input format to a C string and size */
pin = PyUnicode_AsUTF8AndSize(format, &flen);
if (!pin)
return NULL;
PyObject* format_esc = NULL;
if (pin == NULL) {
PyErr_Clear();

format_esc = PyUnicode_AsEncodedString(format, "utf-8", "surrogatepass");
if (format_esc == NULL) {
return NULL;
}
char *pin_tmp;
if (PyBytes_AsStringAndSize(format_esc, &pin_tmp, &flen)) {
Py_DECREF(format_esc);
return NULL;
}

pin = pin_tmp;
}
int contains_surrogates = (format_esc != NULL);

/* Scan the input format, looking for %z/%Z/%f escapes, building
* a new format. Since computing the replacements for those codes
Expand Down Expand Up @@ -1537,15 +1552,33 @@ wrap_strftime(PyObject *object, PyObject *format, PyObject *timetuple,
if (Zreplacement == NULL) {
Zreplacement = make_Zreplacement(object,
tzinfoarg);
if (Zreplacement == NULL)
if (Zreplacement == NULL) {
goto Done;
}
}
assert(Zreplacement != NULL);
assert(PyUnicode_Check(Zreplacement));
ptoappend = PyUnicode_AsUTF8AndSize(Zreplacement,
&ntoappend);
if (ptoappend == NULL)
goto Done;
if (ptoappend == NULL) {
PyErr_Clear();

PyObject *Zreplacement_old = Zreplacement;
Zreplacement = PyUnicode_AsEncodedString(
Zreplacement, "utf-8", "surrogatepass"
);
Py_DECREF(Zreplacement_old);
if (Zreplacement == NULL) {
goto Done;
}

char *p_tmp;
if (PyBytes_AsStringAndSize(Zreplacement, &p_tmp, &ntoappend)) {
goto Done;
}
ptoappend = p_tmp;
contains_surrogates = 1;
}
}
else if (ch == 'f') {
/* format microseconds */
Expand Down Expand Up @@ -1596,7 +1629,13 @@ wrap_strftime(PyObject *object, PyObject *format, PyObject *timetuple,

if (time == NULL)
goto Done;
format = PyUnicode_FromString(PyBytes_AS_STRING(newfmt));
if (!contains_surrogates) {
format = PyUnicode_FromString(PyBytes_AS_STRING(newfmt));
} else {
format = PyUnicode_Decode(PyBytes_AS_STRING(newfmt),
PyBytes_GET_SIZE(newfmt),
"utf-8", "surrogatepass");
}
if (format != NULL) {
result = _PyObject_CallMethodIdObjArgs(time, &PyId_strftime,
format, timetuple, NULL);
Expand All @@ -1605,6 +1644,8 @@ wrap_strftime(PyObject *object, PyObject *format, PyObject *timetuple,
Py_DECREF(time);
}
Done:

Py_XDECREF(format_esc);
Py_XDECREF(freplacement);
Py_XDECREF(zreplacement);
Py_XDECREF(Zreplacement);
Expand Down Expand Up @@ -4898,7 +4939,12 @@ datetime_fromisoformat(PyObject* cls, PyObject *dtstr) {
const char * dt_ptr = PyUnicode_AsUTF8AndSize(dtstr, &len);

if (dt_ptr == NULL) {
goto invalid_string_error;
if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) {
// Encoding errors are invalid string errors at this point
goto invalid_string_error;
} else {
goto error;
}
}

const char *p = dt_ptr;
Expand Down
99 changes: 96 additions & 3 deletions Modules/timemodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,81 @@ the C library strftime function.\n"
#define time_strlen strlen
#endif

static PyObject *
_escape_strftime_chars(PyObject* format_arg)
{
    // Build an escaped byte-string version of format_arg.
    //
    // time_strftime() falls back to this helper when encoding the format
    // with the locale codec and "surrogateescape" raises UnicodeEncodeError
    // (for example for a lone surrogate such as \ud800).  The escaping is
    // undone on the strftime() output by _from_escaped_strftime_chars(), so
    // such characters are "passed through" to the result.
    //
    // Steps:
    //   1. Double every backslash already present in format_arg.
    //   2. Encode the result with the unicode-escape codec.
    //
    // Returns the object produced by PyUnicode_AsUnicodeEscapeString(), or
    // NULL if any of the C-API calls fails.
    PyObject *single_bslash = PyUnicode_FromString("\\");
    if (single_bslash == NULL) {
        return NULL;
    }

    PyObject *double_bslash = PyUnicode_FromString("\\\\");
    if (double_bslash == NULL) {
        Py_DECREF(single_bslash);
        return NULL;
    }

    PyObject *doubled = PyUnicode_Replace(format_arg, single_bslash,
                                          double_bslash, -1);
    // The two pattern strings are no longer needed, whatever the outcome.
    Py_DECREF(single_bslash);
    Py_DECREF(double_bslash);
    if (doubled == NULL) {
        return NULL;
    }

    PyObject *encoded = PyUnicode_AsUnicodeEscapeString(doubled);
    Py_DECREF(doubled);

    return encoded;
}

static PyObject *
_from_escaped_strftime_chars(const char* buf, Py_ssize_t buflen)
{
    // Reverse the transformation applied by _escape_strftime_chars().
    //
    // buf/buflen describe the bytes to convert.  They are first decoded
    // with the unicode-escape codec (error handler "strict"), and then
    // every doubled backslash in the decoded text is collapsed to a
    // single one.
    //
    // Returns the object produced by PyUnicode_Replace(), or NULL if any
    // of the C-API calls fails.
    PyObject *decoded = PyUnicode_DecodeUnicodeEscape(buf, buflen, "strict");
    if (decoded == NULL) {
        return NULL;
    }

    PyObject *double_bslash = PyUnicode_FromString("\\\\");
    if (double_bslash == NULL) {
        Py_DECREF(decoded);
        return NULL;
    }

    PyObject *single_bslash = PyUnicode_FromString("\\");
    if (single_bslash == NULL) {
        Py_DECREF(double_bslash);
        Py_DECREF(decoded);
        return NULL;
    }

    PyObject *unescaped = PyUnicode_Replace(decoded, double_bslash,
                                            single_bslash, -1);

    // Release every temporary; unescaped may legitimately be NULL here.
    Py_DECREF(single_bslash);
    Py_DECREF(double_bslash);
    Py_DECREF(decoded);

    return unescaped;
}

static PyObject *
time_strftime(PyObject *self, PyObject *args)
{
Expand Down Expand Up @@ -736,8 +811,21 @@ time_strftime(PyObject *self, PyObject *args)
#else
/* Convert the unicode string to an ascii one */
format = PyUnicode_EncodeLocale(format_arg, "surrogateescape");
if (format == NULL)
return NULL;
int decode_escaped = 0;
if (format == NULL) {
if (!PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) {
return NULL;
}
PyErr_Clear();

// Surrogate characters with 0x00 bytes - needs backslashreplace
decode_escaped = 1;
format = _escape_strftime_chars(format_arg);
if (format == NULL) {
return NULL;
}
}

fmt = PyBytes_AS_STRING(format);
#endif

Expand Down Expand Up @@ -809,7 +897,12 @@ time_strftime(PyObject *self, PyObject *args)
#ifdef HAVE_WCSFTIME
ret = PyUnicode_FromWideChar(outbuf, buflen);
#else
ret = PyUnicode_DecodeLocaleAndSize(outbuf, buflen, "surrogateescape");
if (decode_escaped) {
ret = _from_escaped_strftime_chars(outbuf, buflen);
}
else {
ret = PyUnicode_DecodeLocaleAndSize(outbuf, buflen, "surrogateescape");
}
#endif
PyMem_Free(outbuf);
break;
Expand Down