Skip to content

Commit f499c73

Browse files
Make TzInfo picklable (#770)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent 16a05d3 commit f499c73

File tree

5 files changed

+58
-1
lines changed

5 files changed

+58
-1
lines changed

python/pydantic_core/_pydantic_core.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import datetime
34
import decimal
45
import sys
56
from typing import Any, Callable, Generic, Optional, Type, TypeVar
@@ -43,6 +44,7 @@ __all__ = [
4344
'to_json',
4445
'to_jsonable_python',
4546
'list_all_errors',
47+
'TzInfo',
4648
]
4749
__version__: str
4850
build_profile: str
@@ -323,3 +325,10 @@ def list_all_errors() -> list[ErrorTypeInfo]:
323325
"""
324326
Get information about all built-in errors.
325327
"""
328+
329+
@final
330+
class TzInfo(datetime.tzinfo):
331+
def tzname(self, _dt: datetime.datetime | None) -> str | None: ...
332+
def utcoffset(self, _dt: datetime.datetime | None) -> datetime.timedelta: ...
333+
def dst(self, _dt: datetime.datetime | None) -> datetime.timedelta: ...
334+
def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ...

src/input/datetime.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use pyo3::intern;
22
use pyo3::prelude::*;
3+
34
use pyo3::types::{PyDate, PyDateTime, PyDelta, PyDeltaAccess, PyDict, PyTime, PyTzInfo};
45
use speedate::MicrosecondsPrecisionOverflowBehavior;
56
use speedate::{Date, DateTime, Duration, ParseError, Time, TimeConfig};
67
use std::borrow::Cow;
8+
79
use strum::EnumMessage;
810

911
use super::Input;
@@ -463,7 +465,7 @@ pub fn float_as_duration<'a>(input: &'a impl Input<'a>, total_seconds: f64) -> V
463465
#[pyclass(module = "pydantic_core._pydantic_core", extends = PyTzInfo)]
464466
#[derive(Clone)]
465467
#[cfg_attr(debug_assertions, derive(Debug))]
466-
struct TzInfo {
468+
pub struct TzInfo {
467469
seconds: i32,
468470
}
469471

@@ -502,4 +504,10 @@ impl TzInfo {
502504
fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> PyResult<Py<Self>> {
503505
Py::new(py, self.clone())
504506
}
507+
508+
pub fn __reduce__(&self, py: Python) -> PyResult<PyObject> {
509+
let args = (self.seconds,);
510+
let cls = Py::new(py, self.clone())?.getattr(py, "__class__")?;
511+
Ok((cls, args).into_py(py))
512+
}
505513
}

src/input/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ mod parse_json;
1010
mod return_enums;
1111
mod shared;
1212

13+
pub use datetime::TzInfo;
1314
pub(crate) use datetime::{
1415
duration_as_pytimedelta, pydate_as_date, pydatetime_as_datetime, pytime_as_time, pytimedelta_as_duration,
1516
EitherDate, EitherDateTime, EitherTime, EitherTimedelta,

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ mod url;
2424
mod validators;
2525

2626
// required for benchmarks
27+
pub use self::input::TzInfo;
2728
pub use self::url::{PyMultiHostUrl, PyUrl};
2829
pub use argument_markers::{ArgsKwargs, PydanticUndefinedType};
2930
pub use build_tools::SchemaError;
@@ -93,6 +94,7 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> {
9394
m.add_class::<PyMultiHostUrl>()?;
9495
m.add_class::<ArgsKwargs>()?;
9596
m.add_class::<SchemaSerializer>()?;
97+
m.add_class::<TzInfo>()?;
9698
m.add_function(wrap_pyfunction!(to_json, m)?)?;
9799
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
98100
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;

tests/validators/test_datetime.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import copy
22
import json
3+
import pickle
34
import platform
45
import re
56
from datetime import date, datetime, time, timedelta, timezone, tzinfo
67
from decimal import Decimal
8+
from typing import Dict
79

810
import pytest
911
import pytz
@@ -476,3 +478,38 @@ def test_tz_constraint_too_high():
476478
def test_tz_constraint_wrong():
477479
with pytest.raises(SchemaError, match="Input should be 'aware' or 'naive"):
478480
SchemaValidator(core_schema.datetime_schema(tz_constraint='wrong'))
481+
482+
483+
def test_tz_pickle() -> None:
484+
"""
485+
https://github.com/pydantic/pydantic-core/issues/589
486+
"""
487+
v = SchemaValidator(core_schema.datetime_schema())
488+
original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15)))
489+
validated = v.validate_python('2022-06-08T12:13:14-12:15')
490+
assert validated == original
491+
assert pickle.loads(pickle.dumps(validated)) == validated == original
492+
493+
494+
def test_tz_hash() -> None:
495+
v = SchemaValidator(core_schema.datetime_schema())
496+
lookup: Dict[datetime, str] = {}
497+
for day in range(1, 10):
498+
input_str = f'2022-06-{day:02}T12:13:14-12:15'
499+
validated = v.validate_python(input_str)
500+
lookup[validated] = input_str
501+
502+
assert len(lookup) == 9
503+
assert (
504+
lookup[datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15)))]
505+
== '2022-06-08T12:13:14-12:15'
506+
)
507+
508+
509+
def test_tz_cmp() -> None:
510+
v = SchemaValidator(core_schema.datetime_schema())
511+
validated1 = v.validate_python('2022-06-08T12:13:14-12:15')
512+
validated2 = v.validate_python('2022-06-08T12:13:14-12:14')
513+
514+
assert validated1 > validated2
515+
assert validated2 < validated1

0 commit comments

Comments
 (0)