Skip to content

PYTHON-4663 Fix compatibility with dateutil timezones #1812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 28, 2024
283 changes: 158 additions & 125 deletions bson/_cbsonmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ struct module_state {
PyObject* Decimal128;
PyObject* Mapping;
PyObject* DatetimeMS;
PyObject* _min_datetime_ms;
PyObject* _max_datetime_ms;
PyObject* min_datetime;
PyObject* max_datetime;
PyObject* replace_args;
PyObject* replace_kwargs;
PyObject* _type_marker_str;
PyObject* _flags_str;
PyObject* _pattern_str;
Expand All @@ -80,6 +82,8 @@ struct module_state {
PyObject* _from_uuid_str;
PyObject* _as_uuid_str;
PyObject* _from_bid_str;
int64_t min_millis;
int64_t max_millis;
};

#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m))
Expand Down Expand Up @@ -253,7 +257,7 @@ static PyObject* datetime_from_millis(long long millis) {
* 2. Multiply that by 1000: 253402300799000
* 3. Add in microseconds divided by 1000 253402300799999
*
* (Note: BSON doesn't support microsecond accuracy, hence the rounding.)
* (Note: BSON doesn't support microsecond accuracy, hence the truncation.)
*
* To decode we could do:
* 1. Get seconds: timestamp / 1000: 253402300799
Expand Down Expand Up @@ -376,6 +380,118 @@ static int millis_from_datetime_ms(PyObject* dt, long long* out){
return 1;
}

static PyObject* decode_datetime(PyObject* self, long long millis, const codec_options_t* options){
PyObject* naive = NULL;
PyObject* replace = NULL;
PyObject* args = NULL;
PyObject* kwargs = NULL;
PyObject* value = NULL;
struct module_state *state = GETSTATE(self);
if (options->datetime_conversion == DATETIME_MS){
return datetime_ms_from_millis(self, millis);
}

int dt_clamp = options->datetime_conversion == DATETIME_CLAMP;
int dt_auto = options->datetime_conversion == DATETIME_AUTO;

if (dt_clamp || dt_auto){
int64_t min_millis = state->min_millis;
int64_t max_millis = state->max_millis;
int64_t min_millis_offset = 0;
int64_t max_millis_offset = 0;
if (options->tz_aware && options->tzinfo && options->tzinfo != Py_None) {
PyObject* utcoffset = PyObject_CallMethodObjArgs(options->tzinfo, state->_utcoffset_str, state->min_datetime, NULL);
if (utcoffset == NULL) {
return 0;
}
if (utcoffset != Py_None) {
if (!PyDelta_Check(utcoffset)) {
PyObject* BSONError = _error("BSONError");
if (BSONError) {
PyErr_SetString(BSONError, "tzinfo.utcoffset() did not return a datetime.timedelta");
Py_DECREF(BSONError);
}
Py_DECREF(utcoffset);
return 0;
}
min_millis_offset = (PyDateTime_DELTA_GET_DAYS(utcoffset) * 86400 +
PyDateTime_DELTA_GET_SECONDS(utcoffset)) * 1000 +
(PyDateTime_DELTA_GET_MICROSECONDS(utcoffset) / 1000);
}
Py_DECREF(utcoffset);
utcoffset = PyObject_CallMethodObjArgs(options->tzinfo, state->_utcoffset_str, state->max_datetime, NULL);
if (utcoffset == NULL) {
return 0;
}
if (utcoffset != Py_None) {
if (!PyDelta_Check(utcoffset)) {
PyObject* BSONError = _error("BSONError");
if (BSONError) {
PyErr_SetString(BSONError, "tzinfo.utcoffset() did not return a datetime.timedelta");
Py_DECREF(BSONError);
}
Py_DECREF(utcoffset);
return 0;
}
max_millis_offset = (PyDateTime_DELTA_GET_DAYS(utcoffset) * 86400 +
PyDateTime_DELTA_GET_SECONDS(utcoffset)) * 1000 +
(PyDateTime_DELTA_GET_MICROSECONDS(utcoffset) / 1000);
}
Py_DECREF(utcoffset);
}
if (min_millis_offset < 0) {
min_millis -= min_millis_offset;
}

if (max_millis_offset > 0) {
max_millis -= max_millis_offset;
}

if (dt_clamp) {
if (millis < min_millis) {
millis = min_millis;
} else if (millis > max_millis) {
millis = max_millis;
}
// Continues from here to return a datetime.
} else { // dt_auto
if (millis < min_millis || millis > max_millis){
return datetime_ms_from_millis(self, millis);
}
}
}

naive = datetime_from_millis(millis);
if (!naive) {
goto invalid;
}

if (!options->tz_aware) { /* In the naive case, we're done here. */
return naive;
}
replace = PyObject_GetAttr(naive, state->_replace_str);
if (!replace) {
goto invalid;
}
value = PyObject_Call(replace, state->replace_args, state->replace_kwargs);
if (!value) {
goto invalid;
}

/* convert to local time */
if (options->tzinfo != Py_None) {
PyObject* temp = PyObject_CallMethodObjArgs(value, state->_astimezone_str, options->tzinfo, NULL);
Py_DECREF(value);
value = temp;
}
invalid:
Py_XDECREF(naive);
Py_XDECREF(replace);
Py_XDECREF(args);
Py_XDECREF(kwargs);
return value;
}

/* Just make this compatible w/ the old API. */
int buffer_write_bytes(buffer_t buffer, const char* data, int size) {
if (pymongo_buffer_write(buffer, data, size)) {
Expand Down Expand Up @@ -482,6 +598,8 @@ static int _load_python_objects(PyObject* module) {
PyObject* empty_string = NULL;
PyObject* re_compile = NULL;
PyObject* compiled = NULL;
PyObject* min_datetime_ms = NULL;
PyObject* max_datetime_ms = NULL;
struct module_state *state = GETSTATE(module);
if (!state) {
return 1;
Expand Down Expand Up @@ -530,10 +648,34 @@ static int _load_python_objects(PyObject* module) {
_load_object(&state->UUID, "uuid", "UUID") ||
_load_object(&state->Mapping, "collections.abc", "Mapping") ||
_load_object(&state->DatetimeMS, "bson.datetime_ms", "DatetimeMS") ||
_load_object(&state->_min_datetime_ms, "bson.datetime_ms", "_min_datetime_ms") ||
_load_object(&state->_max_datetime_ms, "bson.datetime_ms", "_max_datetime_ms")) {
_load_object(&min_datetime_ms, "bson.datetime_ms", "_MIN_UTC_MS") ||
_load_object(&max_datetime_ms, "bson.datetime_ms", "_MAX_UTC_MS") ||
_load_object(&state->min_datetime, "bson.datetime_ms", "_MIN_UTC") ||
_load_object(&state->max_datetime, "bson.datetime_ms", "_MAX_UTC")) {
return 1;
}

state->min_millis = PyLong_AsLongLong(min_datetime_ms);
state->max_millis = PyLong_AsLongLong(max_datetime_ms);
Py_DECREF(min_datetime_ms);
Py_DECREF(max_datetime_ms);
if ((state->min_millis == -1 || state->max_millis == -1) && PyErr_Occurred()) {
return 1;
}

/* Speed up datetime.replace(tzinfo=utc) call */
state->replace_args = PyTuple_New(0);
if (!state->replace_args) {
return 1;
}
state->replace_kwargs = PyDict_New();
if (!state->replace_kwargs) {
return 1;
}
if (PyDict_SetItem(state->replace_kwargs, state->_tzinfo_str, state->UTC) == -1) {
return 1;
}

/* Reload our REType hack too. */
empty_string = PyBytes_FromString("");
if (empty_string == NULL) {
Expand Down Expand Up @@ -1247,15 +1389,16 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
return 0;
if (utcoffset != Py_None) {
PyObject* result = PyNumber_Subtract(value, utcoffset);
Py_DECREF(utcoffset);
if (!result) {
Py_DECREF(utcoffset);
return 0;
}
millis = millis_from_datetime(result);
Py_DECREF(result);
} else {
millis = millis_from_datetime(value);
}
Py_DECREF(utcoffset);
*(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x09;
return buffer_write_int64(buffer, (int64_t)millis);
} else if (PyObject_TypeCheck(value, state->REType)) {
Expand Down Expand Up @@ -2043,11 +2186,6 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
case 9:
{
PyObject* naive;
PyObject* replace;
PyObject* args;
PyObject* kwargs;
PyObject* astimezone;
int64_t millis;
if (max < 8) {
goto invalid;
Expand All @@ -2056,120 +2194,7 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
millis = (int64_t)BSON_UINT64_FROM_LE(millis);
*position += 8;

if (options->datetime_conversion == DATETIME_MS){
value = datetime_ms_from_millis(self, millis);
break;
}

int dt_clamp = options->datetime_conversion == DATETIME_CLAMP;
int dt_auto = options->datetime_conversion == DATETIME_AUTO;


if (dt_clamp || dt_auto){
PyObject *min_millis_fn_res;
PyObject *max_millis_fn_res;
int64_t min_millis;
int64_t max_millis;

if (options->tz_aware){
PyObject* tzinfo = options->tzinfo;
if (tzinfo == Py_None) {
// Default to UTC.
tzinfo = state->UTC;
}
min_millis_fn_res = PyObject_CallFunctionObjArgs(state->_min_datetime_ms, tzinfo, NULL);
max_millis_fn_res = PyObject_CallFunctionObjArgs(state->_max_datetime_ms, tzinfo, NULL);
} else {
min_millis_fn_res = PyObject_CallObject(state->_min_datetime_ms, NULL);
max_millis_fn_res = PyObject_CallObject(state->_max_datetime_ms, NULL);
}

if (!min_millis_fn_res || !max_millis_fn_res){
Py_XDECREF(min_millis_fn_res);
Py_XDECREF(max_millis_fn_res);
goto invalid;
}

min_millis = PyLong_AsLongLong(min_millis_fn_res);
max_millis = PyLong_AsLongLong(max_millis_fn_res);

if ((min_millis == -1 || max_millis == -1) && PyErr_Occurred())
{
// min/max_millis check
goto invalid;
}

if (dt_clamp) {
if (millis < min_millis) {
millis = min_millis;
} else if (millis > max_millis) {
millis = max_millis;
}
// Continues from here to return a datetime.
} else { // dt_auto
if (millis < min_millis || millis > max_millis){
value = datetime_ms_from_millis(self, millis);
break; // Out-of-range so done.
}
}
}

naive = datetime_from_millis(millis);
if (!options->tz_aware) { /* In the naive case, we're done here. */
value = naive;
break;
}

if (!naive) {
goto invalid;
}
replace = PyObject_GetAttr(naive, state->_replace_str);
Py_DECREF(naive);
if (!replace) {
goto invalid;
}
args = PyTuple_New(0);
if (!args) {
Py_DECREF(replace);
goto invalid;
}
kwargs = PyDict_New();
if (!kwargs) {
Py_DECREF(replace);
Py_DECREF(args);
goto invalid;
}
if (PyDict_SetItem(kwargs, state->_tzinfo_str, state->UTC) == -1) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}
value = PyObject_Call(replace, args, kwargs);
if (!value) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}

/* convert to local time */
if (options->tzinfo != Py_None) {
astimezone = PyObject_GetAttr(value, state->_astimezone_str);
Py_DECREF(value);
if (!astimezone) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}
value = PyObject_CallFunctionObjArgs(astimezone, options->tzinfo, NULL);
Py_DECREF(astimezone);
}

Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
value = decode_datetime(self, millis, options);
break;
}
case 11:
Expand Down Expand Up @@ -3053,6 +3078,10 @@ static int _cbson_traverse(PyObject *m, visitproc visit, void *arg) {
Py_VISIT(state->_from_uuid_str);
Py_VISIT(state->_as_uuid_str);
Py_VISIT(state->_from_bid_str);
Py_VISIT(state->min_datetime);
Py_VISIT(state->max_datetime);
Py_VISIT(state->replace_args);
Py_VISIT(state->replace_kwargs);
return 0;
}

Expand Down Expand Up @@ -3097,6 +3126,10 @@ static int _cbson_clear(PyObject *m) {
Py_CLEAR(state->_from_uuid_str);
Py_CLEAR(state->_as_uuid_str);
Py_CLEAR(state->_from_bid_str);
Py_CLEAR(state->min_datetime);
Py_CLEAR(state->max_datetime);
Py_CLEAR(state->replace_args);
Py_CLEAR(state->replace_kwargs);
return 0;
}

Expand Down
Loading
Loading